diff --git a/__pycache__/mcmc_chain_analysis.cpython-34.pyc b/__pycache__/mcmc_chain_analysis.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f4261481b77817f0a5299f8eaec91dab934679b Binary files /dev/null and b/__pycache__/mcmc_chain_analysis.cpython-34.pyc differ diff --git a/__pycache__/pmf_analysis.cpython-37.pyc b/__pycache__/pmf_analysis.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d73cb225dcfed4bac0c68eecdffceec3453a4d1 Binary files /dev/null and b/__pycache__/pmf_analysis.cpython-37.pyc differ diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index 8fa16cdde21c4a7d3d6daa3797043fc2e294e842..cc971e6effeb8c5cd96ac67921d02e35f9142f1b 100644 Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ diff --git a/all_first_pmfs_typeless b/all_first_pmfs_typeless new file mode 100644 index 0000000000000000000000000000000000000000..336970346beed422f9657b5d9db49519023836d9 Binary files /dev/null and b/all_first_pmfs_typeless differ diff --git a/behaviour_overview.py b/behaviour_overview.py index 48418b9c0dabfdb4e7b7bec914109f255a1fb815..24180ed205a00a9cb7e864c39782d62c92f8d658 100644 --- a/behaviour_overview.py +++ b/behaviour_overview.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd from one.api import ONE +import pickle show = 1 @@ -21,27 +22,39 @@ def progression(data, contrasts, progression_variable='feedback', windowsize=6, plt.ylim(0, upper_bound) plt.gca().spines['top'].set_visible(False) plt.gca().spines['right'].set_visible(False) - # if title == "7 rt" or title == "7 fb": - # plt.plot([271, 286], [1, 1], color='r', lw=3) - # plt.plot([429, 439], [1, 1], color='r', lw=3) - # plt.plot([632, 636], [1, 1], color='r', lw=3) - if title == "KS014, RT session 12 / 15" or title == "KS014, performance session 12 / 15": - plt.plot([599, 603], [1, 1], color='m', lw=6) - plt.plot([658, 662], [1, 1], color='m', lw=6) - plt.plot([703, 708], [1, 1], color='m', lw=6) + # if title == "KS014, RT session 12 / 15" or title == "KS014, performance session 12 / 15": + # plt.plot([599, 603], [1, 1], color='m', lw=6) + # plt.plot([658, 662], [1, 1], color='m', lw=6) + # plt.plot([703, 708], [1, 1], color='m', lw=6) plt.title(title, size=22) plt.legend(fontsize=13) plt.ylabel(progression_variable, size=20) plt.xlabel('trial', size=20) - if title: - plt.savefig('./overview_figures/' + title.replace('/', '_') + '.png') + # if title: + # plt.savefig('./overview_figures/' + title.replace('/', '_') + '.png') if show: plt.show() else: plt.close() + # if progression_variable == 'feedback': + # means = data.groupby('signed_contrast').mean()['response'] + # stds = data.groupby('signed_contrast').std()['response'] + # plt.errorbar(means.index, means.values, stds.values) + # plt.ylim(-0.2, 1.2) + # plt.title(title, size=22) + # plt.savefig("temp {}".format(title).replace('/', '_')) + # plt.close() + # if progression_variable == 'rt': + # means = data.groupby('signed_contrast').mean()['rt'] + # stds = data.groupby('signed_contrast').std()['rt'] + # plt.errorbar(means.index, means.values, stds.values) + # plt.title(title, size=22) + # plt.savefig("temp {}".format(title).replace('/', '_')) + # plt.show() + dataset_types = ['choice', 'contrastLeft', 'contrastRight', \ 'feedbackType', 'probabilityLeft', 'response_times', \ @@ -49,7 +62,7 @@ dataset_types = ['choice', 'contrastLeft', 'contrastRight', \ def get_df(eid): # dangerously many changes, check again try: - data_dict = one.load_object(eid, 'trials', attribute=dataset_types) # download_only=True messes with my code, returns path for whatever reason + data_dict = one.load_object(eid, 'trials') # download_only=True messes with my code, returns path for whatever reason except: print('lol') return None, None @@ -82,25 +95,34 @@ exclude_eids = ['a66f1593-dafd-4982-9b66-f9554b6c86b5', 'ee40aece-cffd-4edb-a4b6 # project='ibl_neuropixel_brainwide_01') # traj.reverse() -for subject in ['KS014']: - eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True) +subject = 'KS014' +eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True) - start_times = [sess['date'] for sess in sess_info] - protocols = [sess['task_protocol'] for sess in sess_info] +start_times = [sess['date'] for sess in sess_info] +protocols = [sess['task_protocol'] for sess in sess_info] - eids = [x for _, x in sorted(zip(start_times, eids))] +eids = [x for _, x in sorted(zip(start_times, eids))] +protocols = [x for _, x in sorted(zip(start_times, protocols))] # for t in traj: counti = 0 -for i, eid in enumerate(eids): +for i, (prot, eid) in enumerate(zip(protocols, eids)): print(i) + if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'): + continue # eid = t['session']['id'] df, _ = get_df(eid) + if df is None: continue counti += 1 - if counti != 12: - continue - progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{}, performance session {} / 15".format(subject, counti)) - progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{}, RT session {} / 15".format(subject, counti)) + + rt_data = np.zeros((len(df), 3)) + rt_data[:, 0] = df['signed_contrast'] + rt_data[:, 1] = df['rt'] + rt_data[:, 2] = df['response'] + pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb')) + + progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {} / 15".format(subject, counti)) + progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {} / 15".format(subject, counti)) diff --git a/canonical_infos.json b/canonical_infos.json index c2ca57ff4832606bb4d5af5dad36768fc4162753..5f8da99817e18fb2831226f46d3018b12ba8b064 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file +{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9, "ignore": [14, 12, 0, 10, 9, 4, 5, 1]}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9, "ignore": [14, 13, 8, 4, 5, 12, 11, 9]}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4, "ignore": [3, 5, 15, 0, 2, 12, 11, 10]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9, "ignore": [8, 15, 0, 13, 7, 12, 11, 1]}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4, "ignore": [13, 4, 10, 9, 2, 1, 3, 6]}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9, "ignore": [4, 8, 7, 9, 1, 2, 10, 6]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9, "ignore": [3, 4, 6, 7, 5, 0, 15, 12]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 04907a5a632d01e90a49247b007a1febd4f1fe6e..c44c7897447e7df141ad2fd4e516f65052f99d28 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -11,7 +11,7 @@ import pyhsmm import pickle import seaborn as sns import sys -from scipy.stats import zscore +from scipy.stats import zscore, norm from scipy.optimize import minimize from itertools import combinations, product import matplotlib.gridspec as gridspec @@ -20,6 +20,8 @@ import json import time import multiprocessing as mp from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat +import pandas as pd +from pmf_analysis import pmf_type, type2color colors = np.genfromtxt('colors.csv', delimiter=',') @@ -38,12 +40,20 @@ bias_cont_ticks = (np.arange(9), [-1, -.25, -.125, -.062, 0, .062, .125, .25, 1] contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] -type_colours = ['g', 'b', 'r'] - def weights_to_pmf(weights, with_bias=1): - psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1] - return 1 / (1 + np.exp(-psi)) + psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] + return 1 / (1 + np.exp(psi)) # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting + +performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0]) +def pmf_to_perf(pmf, def_points): + # determine performance of a pmf + # we use this to determine regressions in the behaviour of animals + # therefore, we exclude 0 as performance on it is 0.5 regardless of PMF, but it might + # overall lower performance. The removal of 0.5 later might also be a problem, let's see + relevant_points = def_points + relevant_points[5] = False + return np.mean(np.abs(performance_points[relevant_points] + pmf[relevant_points])) class MCMC_result_list: @@ -477,9 +487,11 @@ def return_ascending_shuffled(): return temp -def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figure', save_append='', consistencies=None): +def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figure', save_append='', consistencies=None, CMF=False): n = test.results[0].n_sessions trial_counter = 0 + cnas = [] # contrasts aNd actions + for seq_num in range(n): # if seq_num + 1 != 12: # trial_counter += len(test.results[0].models[0].stateseqs[seq_num]) @@ -518,9 +530,13 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur plt.plot(active_trials, color=cmap(0.2 + 0.8 * seq_num / test.results[0].n_sessions), lw=4, label=label, alpha=0.7) + trial_counter += len(test.results[0].models[0].stateseqs[seq_num]) + ms = 6 noise = np.zeros(len(c_n_a))# np.random.rand(len(c_n_a)) * 0.4 - 0.2 + cnas.append(c_n_a) + mask = c_n_a[:, -1] == 0 plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6) @@ -557,8 +573,51 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur plt.show() else: plt.close() - trial_counter += len(test.results[0].models[0].stateseqs[seq_num]) + if CMF and not test.results[0].name.startswith('Sim_'): + rt_data = pickle.load(open("./session_data/{} rt info {}".format(subject, seq_num + 1), 'rb')) + rt_data = rt_data[1:] + assert c_n_a.shape[0] == rt_data.shape[0] + df = pd.DataFrame(rt_data, columns = ['Signed contrast', 'RT', 'Responses']) + means = df.groupby('Signed contrast').mean()['RT'] + stds = df.groupby('Signed contrast').sem()['RT'] + plt.errorbar(means.index, means.values, stds.values) + plt.title(seq_num + 1, size=22) + plt.show() + + return cnas + +def pmf_regressions(states_by_session, pmfs, durs): + # find out whether pmfs regressed + state_perfs = {} + state_counter = {} + current_best_state = -1 + counter = 0 + types = [0, 0, 0] + + for sess in range(states_by_session.shape[1]): + states = np.where(states_by_session[:, sess])[0] + + for s in states: + if s not in state_counter: + state_counter[s] = -1 + state_counter[s] += 1 + state_perfs[s] = pmf_to_perf(pmfs[s][1][state_counter[s]], pmfs[s][0]) + if current_best_state == -1 or state_perfs[current_best_state] < state_perfs[s]: + current_best_state = s + + if state_perfs[np.argmax(states_by_session[:, sess])] + 0.05 < state_perfs[current_best_state] and state_counter[np.argmax(states_by_session[:, sess])] > 1: + counter += 1 + if sess < durs[0]: + types[0] += 1 + elif sess < durs[0] + durs[1]: + types[1] += 1 + else: + types[2] += 1 + a, b = state_perfs[np.argmax(states_by_session[:, sess])], state_perfs[current_best_state] + print("Regression in session {} during {:.2f}% of session ({:.2f} instead of {:.2f})".format(sess + 1, np.max(states_by_session[:, sess]) * 100, a, b)) + + return [counter, states_by_session.shape[1], *types] def control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for): # Generalised control flow for iterating over samples of mode across individually across sessions @@ -602,6 +661,25 @@ def state_pmfs(test, trials, indices): return results['session_js'], results['pmfs'] +def state_weights(test, trials, indices): + def func_init(): return {'weightss': [], 'session_js': []} + + def first_for(test, results): + results['weights'] = np.zeros(test.results[0].models[0].obs_distns[0].weights.shape[1]) + + def second_for(m, j, session_trials, trial_counter, results): + states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) + for sub_state, c in zip(states, counts): + results['weights'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0] + + def end_first_for(results, indices, j, **kwargs): + results['weightss'].append(results['weights'] / len(indices)) + results['session_js'].append(j) + + results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for) + return results['session_js'], results['weightss'] + + def lapse_sides(test, state_sets, indices): """Compute and plot a lapse differential across sessions. @@ -875,7 +953,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show trial_counter += len(state_seq) counter += 1 pmfs_to_score.append(np.mean(pmfs)) - state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score)))) # double argsort for ranks + # test.state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score)))) # double argsort for ranks + test.state_mapping = dict(zip(range(len(state_sets)), np.flip(np.argsort((states_by_session != 0).argmax(axis=1))))) for state, trials in enumerate(state_sets): cmap = matplotlib.cm.get_cmap(cmaps[state]) if state < len(cmaps) else matplotlib.cm.get_cmap('Greys') @@ -914,8 +993,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show temp = np.sum(pmfs[:, defined_points]) / (np.sum(defined_points)) state_color = colors[int(temp * 101 - 1)] - ax1.fill_between(range(1, 1 + test.results[0].n_sessions), state_mapping[state] - 0.5, - state_mapping[state] + states_by_session[state] - 0.5, color=state_color) + ax1.fill_between(range(1, 1 + test.results[0].n_sessions), test.state_mapping[state] - 0.5, + test.state_mapping[state] + states_by_session[state] - 0.5, color=state_color) else: n_points = 150 @@ -924,11 +1003,11 @@ def state_development(test, state_sets, indices, save=True, save_append='', show for k in range(n_points-1): ax1.fill_between([points[k], points[k+1]], - state_mapping[state] - 0.5, [state_mapping[state] + interpolation[k] - 0.5, state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points)) - ax1.annotate(state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) + test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points)) + ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) if test.results[0].name.startswith('GLM_Sim_'): - ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r') + ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r') alpha_level = 0.3 ax2.axvline(0.5, c='grey', alpha=alpha_level, zorder=4) @@ -943,23 +1022,23 @@ def state_development(test, state_sets, indices, save=True, save_append='', show # defined_points[[0, 1, -2, -1]] = True if separate_pmf: for j, pmf in zip(session_js, pmfs): - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions)) - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions)) + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions)) + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions)) all_pmfs.append((defined_points, pmfs)) else: temp = np.percentile(pmfs, [2.5, 97.5], axis=0) - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + state_mapping[state], color=state_color) - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + state_mapping[state], ls='', ms=7, marker='*', color=state_color) - ax2.fill_between(np.where(defined_points)[0] / (len(defined_points)-1), temp[1, defined_points] - 0.5 + state_mapping[state], temp[0, defined_points] - 0.5 + state_mapping[state], alpha=0.2, color=state_color) + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + test.state_mapping[state], color=state_color) + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=state_color) + ax2.fill_between(np.where(defined_points)[0] / (len(defined_points)-1), temp[1, defined_points] - 0.5 + test.state_mapping[state], temp[0, defined_points] - 0.5 + test.state_mapping[state], alpha=0.2, color=state_color) all_pmfs.append((defined_points, pmfs[:, defined_points].mean(axis=0))) if test.results[0].name.startswith('GLM_Sim_'): sim_pmf = weights_to_pmf(truth['weights'][state]) - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][state_mapping[state]], color='r') + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][test.state_mapping[state]], color='r') - ax2.axhline(state_mapping[state] + 0.5, c='k') - ax2.axhline(state_mapping[state], c='grey', alpha=alpha_level, zorder=4) - ax1.axhline(state_mapping[state] + 0.5, c='grey', alpha=alpha_level, zorder=4) + ax2.axhline(test.state_mapping[state] + 0.5, c='k') + ax2.axhline(test.state_mapping[state], c='grey', alpha=alpha_level, zorder=4) + ax1.axhline(test.state_mapping[state] + 0.5, c='grey', alpha=alpha_level, zorder=4) if not test.results[0].name.startswith('Sim_'): perf = np.zeros(test.results[0].n_sessions) @@ -980,6 +1059,14 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax0.fill_between(range(1, 1 + test.results[0].n_sessions), perf - 0.5, -0.5, color='k') durs, state_types = state_type_durs(states_by_session, all_pmfs) + + # how many states per session per type + states_per_sess = np.sum(states_by_session > 0.05, axis=0) + if durs[0] > 0 and durs[1] > 0 and durs[2] > 1: + states_per_type = [np.mean(states_per_sess[:durs[0]]), np.mean(states_per_sess[durs[0]:durs[0]+durs[1]]), np.mean(states_per_sess[durs[0]+durs[1]:])] + else: + states_per_type = [] + # other statistics dur_counter = 1 contrast_intro_types = [0, 0, 0, 0] state, when = np.where(states_by_session > 0.05) @@ -987,7 +1074,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show covered_states = [] for i, d in enumerate(durs): if type_coloring: - ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3) + ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type2color[i], zorder=0, alpha=0.3) dur_counter += d # find out during which state type which contrast was introduced @@ -1041,13 +1128,86 @@ def state_development(test, state_sets, indices, save=True, save_append='', show plt.tight_layout() if save: print("saving with {} dpi".format(dpi)) - plt.savefig("dynamic_GLM_figures/meta state development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi) + plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi) if show: plt.show() else: plt.close() - return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage + return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type + +def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""): + """ + Take a set of states, and plot out their PMFs on all sessions on which they occur. + See how different they really are. + + Takes states_by_session and all_pmfs as input from state_development + """ + colors = ['blue', 'orange', 'green', 'black', 'red'] + assert len(states2compare) <= len(colors) + # subtract 1 to get internal numbering + states2compare = [s - 1 for s in states2compare] + # transform desired states into the actual numbering, before ordering by bias + states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare] + + sessions = np.where(states_by_session[states2compare].sum(0))[0] + + for i, state in enumerate(states2compare): + counter = 0 + for j, session in enumerate(sessions): + plt.subplot(1, len(sessions), j + 1) + if i == 0: + plt.title(session) + if states_by_session[state, session] > 0: + plt.plot(np.where(all_pmfs[state][0])[0], (all_pmfs[state][1][counter])[all_pmfs[state][0]], c=colors[i]) + counter += 1 + plt.ylim(0, 1) + if j != 0: + plt.gca().set_yticks([]) + # plt.tight_layout() + if title != "": + plt.savefig(title) + plt.show() + + +def compare_weights(test, state_sets, indices, states2compare, states_by_session, title=""): + """ + Take a set of states, and plot out their weights on all sessions on which they occur. + See how different they really are. + Similar to compare_pmfs + + Takes states_by_session as input from state_development + """ + colors = ['blue', 'orange', 'green', 'black', 'red'] + assert len(states2compare) <= len(colors) + # subtract 1 to get internal numbering + states2compare = [s - 1 for s in states2compare] + # transform desired states into the actual numbering, before ordering by bias + states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare] + + sessions = np.where(states_by_session[states2compare].sum(0))[0] + + state_counter = -1 + for state, trials in enumerate(state_sets): + if state not in states2compare: + continue + state_counter += 1 + _, weights = state_weights(test, trials, indices) + counter = 0 + for j, session in enumerate(sessions): + plt.subplot(1, len(sessions), j + 1) + if state == 0: + plt.title(session) + if states_by_session[state, session] > 0: + plt.plot(weights[counter], c=colors[state_counter]) + counter += 1 + plt.ylim(-6, 6) + if j != 0: + plt.gca().set_yticks([]) + # plt.tight_layout() + if title != "": + plt.savefig(title) + plt.show() def smart_divide(a, b): @@ -1097,11 +1257,13 @@ if __name__ == "__main__": fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) + r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) + r_hats = json.load(open("canonical_info_r_hats.json", 'r')) subjects = list(loading_info.keys()) - r_hats = [] + r_hats = {} # R^hat tests # test = MCMC_result_list([fake_result(100) for i in range(8)]) @@ -1112,10 +1274,9 @@ if __name__ == "__main__": check_r_hats = False if check_r_hats: subjects = list(loading_info.keys()) - subjects = ['KS014'] for subject in subjects: - # if subject.startswith('GLM'): - # continue + if subject.startswith('GLM_Sim_07') or subject.startswith('GLM_Sim_11'): + continue print("_________________________________") print(subject) fit_num = loading_info[subject]['fit_nums'][-1] @@ -1132,7 +1293,7 @@ if __name__ == "__main__": # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2)) sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2) - r_hats.append((subject, final_r_hat)) + r_hats[subject] = final_r_hat loading_info[subject]['ignore'] = sol print(r_hats) @@ -1233,171 +1394,88 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False): plt.close() -def pmf_type(pmf): - if pmf[-1] - pmf[0] < 0.2: - return 0 - elif pmf[-1] - pmf[0] < 0.6:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1: - return 1 - else: - return 2 - - -type2color = {0: 'green', 1: 'blue', 2: 'red'} +def two_sample_binom_test(s1, s2): + p1, p2 = np.mean(s1), np.mean(s2) + n1, n2 = s1.size, s2.size + p = (n1 * p1 + n2 * p2) / (n1 + n2) + if p == 1. or p == 0.: + return 0., 0.99 + z = (p1 - p2) / np.sqrt(p * (1 - p) * (1 / n1 + 1 / n2)) + p_val = (1 - norm.cdf(np.abs(z))) * 2 + return z, p_val + + +def compare_performance(cnas, contrasts=(1, 1), title=""): + # compare the performance on a certain contrast + # cnas is set of data for all sessions (from contrasts_plot) + # contrasts is a tuple specifying the contrast of interest (first or second column, contrast strength) + plt.figure(figsize=(16*0.75, 9*0.75)) + for i in range(len(cnas) - 1): + if cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1].shape[0] == 0 or cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1].shape[0] == 0: + continue + perf1, perf2 = np.mean(cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1]), np.mean(cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1]) + p = two_sample_binom_test(cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1], cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1])[1] + factor = (perf1 > perf2) * 2 - 1 + plt.annotate(xy=(i + 0.5, (perf1 + perf2) / 2 + factor * 0.025), text=("{:.3f}".format(p))[1:]) + color = 'red' if p < 0.05 else 'blue' + plt.plot([i, i+1], [perf1, perf2], color=color) + sns.despine() + cont_string = "Contrast {} ".format('right' if contrasts[0] else 'left') + plt.title(cont_string + str(contrasts[1])) + plt.ylim(-0.07, 1.07) + plt.tight_layout() + if title != "": + plt.savefig(title) + plt.show() -if False: - all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb')) - plt.figure(figsize=(16, 9)) - for i, pmf in enumerate(all_changing_pmfs): - plt.subplot(4, 7, i + 1) - for p in pmf[1]: - plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)]) - plt.ylim(0, 1) +def type_hist(data): + highest = int(data.max()) + if (data % 1 == 0).all(): + bins = np.arange(highest + 2) - 0.5 + else: + bins = np.histogram(data)[1] + hist_max = 0 + for i in range(data.shape[1]): + hist_max = max(hist_max, np.histogram(data[:, i], bins)[0].max()) + + plt.subplot(3, 1, 1) + assert np.histogram(data[:, 0])[0].sum() == np.histogram(data[:, 0], bins)[0].sum() + plt.hist(data[:, 0], alpha=1/3, label="type 1", align='mid', bins=bins) + plt.xlim(-0.5, highest + 1) + plt.ylim(0, hist_max + 1) + + plt.subplot(3, 1, 2) + assert np.histogram(data[:, 1])[0].sum() == np.histogram(data[:, 1], bins)[0].sum() + plt.hist(data[:, 1], alpha=1/3, label="type 2", align='mid', bins=bins) + plt.xlim(-0.5, highest + 1) + plt.ylim(0, hist_max + 1) + + plt.subplot(3, 1, 3) + assert np.histogram(data[:, 2])[0].sum() == np.histogram(data[:, 2], bins)[0].sum() + plt.hist(data[:, 2], alpha=1/3, label="type 3", align='mid', bins=bins) + plt.xlim(-0.5, highest + 1) + plt.ylim(0, hist_max + 1) + + # plt.legend() + plt.show() - sns.despine() - if i+1 != 22: - plt.gca().set_xticks([]) - plt.gca().set_yticks([]) - else: - plt.xlabel("Contrasts", size=22) - plt.ylabel("P(rightwards)", size=22) - plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16) - plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16) - if i + 1 == 30: - break +if 0: + all_intros = pickle.load(open("all_intros.p", 'rb')) + all_intros_div = pickle.load(open("all_intros_div.p", 'rb')) + all_states_per_type = pickle.load(open("all_states_per_type.p", 'rb')) - plt.tight_layout() - plt.savefig("changing pmfs") - plt.show() + # There are 5 mice with 0 type 2 intros, but only 3 mice with no type 2 stats. + # That is because they come up, but don't explain the necessary 50% to start phase 2. + type_hist(all_intros) + type_hist(all_intros_div) + type_hist(all_states_per_type) quit() - type_2_assyms = [] - tick_size = 14 - label_size = 26 - all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb')) - plt.figure(figsize=(16, 9)) - plt.subplot(1, 3, 1) - counter = [[0, 0], [0, 0]] - save_title = "all types" if False else "KS014 types" - if save_title == "KS014 types": - all_first_pmfs = {'KS014': all_first_pmfs['KS014']} - - for key in all_first_pmfs: - x = all_first_pmfs[key] - if type(x[0]) == int: - continue - linestyle = '-' if x[2] == 0 else '--' - plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g') - plt.ylim(0, 1) - plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size) - plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) - plt.gca().spines[['right', 'top']].set_visible(False) - plt.xlim(0, 10) - plt.xticks(rotation=45) - plt.gca().set_ylabel("P(rightwards)", size=label_size) - - plt.subplot(1, 3, 2) - for key in all_first_pmfs: - x = all_first_pmfs[key] - if type(x[3]) == int: - continue - type_2_assyms.append(np.abs(x[3][0] + x[3][-1] - 1)) - linestyle = '-' if x[5] == 0 else '--' - counter[0][0 if x[5] == 0 else 1] += 1 - if linestyle == '--': - continue - plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b') - plt.gca().set_yticks([]) - plt.ylim(0, 1) - plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size) - plt.gca().spines[['right', 'top']].set_visible(False) - plt.xticks(rotation=45) - plt.xlim(0, 10) - plt.gca().set_xlabel("Contrasts", size=label_size) - - plt.subplot(1, 3, 3) - for key in all_first_pmfs: - x = all_first_pmfs[key] - if type(x[6]) == int: - continue - linestyle = '-' if x[8] == 0 else '--' - counter[1][0 if x[8] == 0 else 1] += 1 - if linestyle == '--': - continue - plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r') - plt.gca().set_yticks([]) - plt.ylim(0, 1) - plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size) - plt.gca().spines[['right', 'top']].set_visible(False) - plt.xlim(0, 10) - plt.xticks(rotation=45) - - print(counter) - plt.tight_layout() - plt.savefig(save_title) - plt.show() - if save_title == "KS014 types": - quit() - - counter = 0 - fig, ax = plt.subplots(1, 3, figsize=(16, 9)) - for key in all_first_pmfs: - x = all_first_pmfs[key] - if type(x[3]) == int: - continue - linestyle = '-' if x[5] == 0 else '--' - if linestyle == '--': - continue - if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1: - counter += 1 - use_ax = 2 - else: - use_ax = int(x[3][0] > 1 - x[3][-1]) - - ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b') - ax[0].set_ylim(0, 1) - ax[0].set_xlim(0, 10) - ax[0].spines[['right', 'top']].set_visible(False) - ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) - ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - ax[0].set_ylabel("P(rightwards)", size=label_size) - - ax[1].set_ylim(0, 1) - ax[1].set_xlim(0, 10) - ax[1].set_yticks([]) - ax[1].spines[['right', 'top']].set_visible(False) - ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - ax[1].set_xlabel("Contrasts", size=label_size) - - ax[2].set_ylim(0, 1) - ax[2].set_xlim(0, 10) - ax[2].set_yticks([]) - ax[2].spines[['right', 'top']].set_visible(False) - ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - print(counter) - plt.tight_layout() - plt.savefig("differentiate type 2") - plt.show() - quit() if __name__ == "__main__": - # visualise pmf types - # lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9] - # test_pmf = np.zeros(4) - # for i, lapse_l in enumerate(lapses): - # plt.subplot(1, 10, 1+i) - # if i != 0: - # plt.gca().set_yticklabels([]) - # plt.ylim(0, 1) - # test_pmf[:2] = lapse_l - # for lapse_r in np.linspace(0.02, 0.98, 33): - # test_pmf[2:] = lapse_r - # plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)]) - # plt.show() - # quit() - fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) @@ -1405,8 +1483,11 @@ if __name__ == "__main__": elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) - no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13' - subjects = ['KS014'] # list(loading_info.keys()) + no_good_pcas = ['NYU-06', 'SWC_023'] + subjects = list(loading_info.keys()) + # subjects = ['ZM_1897'] + + # meh pmfs: KS021 print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' @@ -1433,11 +1514,16 @@ if __name__ == "__main__": abs_state_durs = [] all_first_pmfs = {} + all_first_pmfs_typeless = {} all_pmf_diffs = [] all_pmf_asymms = [] all_pmfs = [] all_changing_pmfs = [] + all_changing_pmf_names = [] all_intros = [] + all_intros_div = [] + all_states_per_type = [] + regressions = [] for subject in subjects: @@ -1445,48 +1531,66 @@ if __name__ == "__main__": continue print(subject) - results = [] try: + continue test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) print('loaded canoncial result') + mode_specifier = '' mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) - quit() state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) + # lapse differential # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # training overview - states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + regression = pmf_regressions(states, pmfs, durs) + regressions.append(regression) + continue + # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) + # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) + # quit() + all_first_pmfs_typeless[subject] = [] + for pmf in pmfs: + all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) all_intros.append(undiv_intros) + all_intros_div.append(intros_by_type) + if states_per_type != []: + all_states_per_type.append(states_per_type) + intros_by_type_sum += intros_by_type first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) for pmf in changing_pmfs: if type(pmf[0]) == int: continue + all_changing_pmf_names.append(subject) all_changing_pmfs.append(pmf) all_first_pmfs[subject] = first_pmfs for pmf in pmfs: + all_pmfs.append(pmf) for p in pmf[1]: all_pmf_diffs.append(p[-1] - p[0]) all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) - all_pmfs.append(p) contrast_intro_types.append(contrast_intro_type) + continue # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) # session overview - consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - consistencies /= consistencies[0, 0] - contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies) + # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True) + # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1))) + # compare_performance(cnas, (0, 0.848)) # duration of different state types (and also percentage of type activities) - abs_state_durs.append(durs) - simplex_durs = np.array(durs).reshape(1, 3) - print(simplex_durs / np.sum(simplex_durs)) - from simplex_plot import projectSimplex - print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None])) - continue + # abs_state_durs.append(durs) + # continue + # simplex_durs = np.array(durs).reshape(1, 3) + # print(simplex_durs / np.sum(simplex_durs)) + # from simplex_plot import projectSimplex + # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None])) # compute state type proportions and split the pmfs accordingly # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs) @@ -1552,47 +1656,46 @@ if __name__ == "__main__": # plt.savefig("temp") # plt.close() - quit() - - dim = 3 - ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_vecs='dists', dim=dim) - xy = np.vstack([dimreduc[i] for i in range(dim)]) - from scipy.stats import gaussian_kde - z = gaussian_kde(xy)(xy) - pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb')) - - if 'mode prob level' not in loading_info[subject]: - print(subject) - xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb')) - happy = False - while not happy: - print() - print("Pick level") - prob_level = float(input()) - print("Level is {}".format(prob_level)) - print("# of samples: {}".format((z > prob_level).sum())) - mode_indices = np.where(z > prob_level)[0] - if (z > prob_level).sum() > 0: - print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max()) - print("Happy?") - happy = 'yes' == input() - print("Subset by factor?") - if input() == 'yes': - print("Factor?") - print(mode_indices.shape) - factor = int(input()) - mode_indices = mode_indices[::factor] - print(mode_indices.shape) - loading_info[subject]['mode prob level'] = prob_level - - pickle.dump(mode_indices, open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'wb')) - consistencies = test.consistency_rsa(indices=mode_indices) - pickle.dump(consistencies, open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'wb')) - - string_prefix = '' + # dim = 3 + # ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_vecs='dists', dim=dim) + # xy = np.vstack([dimreduc[i] for i in range(dim)]) + # from scipy.stats import gaussian_kde + # z = gaussian_kde(xy)(xy) + # pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb')) + # # + # string_prefix = 'second_' + # + # print(subject) + # xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb')) + # quit() + # happy = False + # while not happy: + # print() + # print("Pick level") + # prob_level = float(input()) + # print("Level is {}".format(prob_level)) + # print("# of samples: {}".format((z > prob_level).sum())) + # mode_indices = np.where(z > prob_level)[0] + # if (z > prob_level).sum() > 0: + # print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max()) + # print("Happy?") + # happy = 'yes' == input() + # print("Subset by factor?") + # if input() == 'yes': + # print("Factor?") + # print(mode_indices.shape) + # factor = int(input()) + # mode_indices = mode_indices[::factor] + # print(mode_indices.shape) + # loading_info[subject]['mode prob level'] = prob_level + # + # pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(string_prefix, subject, fit_type), 'wb')) + # consistencies = test.consistency_rsa(indices=mode_indices) + # pickle.dump(consistencies, open("multi_chain_saves/{}consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'wb')) mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(string_prefix, subject, fit_type), 'rb')) - consistencies = pickle.load(open("multi_chain_saves/{}consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'rb')) + consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'rb')) + session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs])) import scipy.cluster.hierarchy as hc consistencies /= consistencies[0, 0] @@ -1620,10 +1723,9 @@ if __name__ == "__main__": for x, y in zip(b, c): state_sets.append(np.where(a == x)[0]) print("dumping state set") - pickle.dump(state_sets, open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'wb')) - quit() + pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(string_prefix, subject, fit_type), 'wb')) states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(string_prefix, criterion), show=True, separate_pmf=True) - contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True) + # contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True) # I think this is about finding out where states start and how long they last # state_id, session_appears = np.where(states) @@ -1660,18 +1762,14 @@ if __name__ == "__main__": plt.tight_layout() plt.savefig("peter figures/{}clustered_trials_{}_{}".format(string_prefix, subject, 'criteria comp').replace('.', '_')) - plt.show() + plt.close() except FileNotFoundError as e: - continue print(e) - r_hat = 1.5 - for r in r_hats: - if r[0] == subject: - r_hat = r[1] + continue print('no canoncial result') - print(r_hat) - if r_hat >= 1.05: + print(r_hats[subject]) + if r_hats[subject] >= 1.05: print("Skipping") continue else: @@ -1729,6 +1827,13 @@ if __name__ == "__main__": # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb')) # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb')) + # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb')) + # pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb')) + # pickle.dump(all_pmfs, open("all_pmfs.p", 'wb')) + # pickle.dump(all_intros, open("all_intros.p", 'wb')) + # pickle.dump(all_intros_div, open("all_intros_div.p", 'wb')) + # pickle.dump(all_states_per_type, open("all_states_per_type.p", 'wb')) + # pickle.dump(regressions, open("regressions.p", 'wb')) # # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] @@ -1758,6 +1863,16 @@ if __name__ == "__main__": plotSimplex(np.array(abs_state_durs), c='k', show=True) + plt.hist(abs_state_durs.sum(1), color='grey', bins=12) + sns.despine() + plt.xticks(size=26) + plt.yticks(size=26) + plt.ylabel("# of mice", size=40) + plt.xlabel('# of sessions', size=40) + plt.tight_layout() + plt.savefig("session_num_hist.png", dpi=300, transparent=True) + plt.show() + if False: ax[0].set_ylim(0, 1) ax[1].set_ylim(0, 1) diff --git a/index_mice.py b/index_mice.py index 59af129239564a66b3a9965c4f03454c26f8e81a..ebb58b840677ea1ffc0cdd34584d08c4403b3a75 100644 --- a/index_mice.py +++ b/index_mice.py @@ -34,7 +34,7 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): local_dict = bias_subinfo if subject not in local_dict: local_dict[subject] = {"seeds": [], "fit_nums": [], "chain_num": 0} - if int(chain_num) == 0: + if int(chain_num) == 0: # if this is the first file of that chain, save some info local_dict[subject]["seeds"].append(seed) local_dict[subject]["fit_nums"].append(fit_num) else: @@ -62,7 +62,6 @@ for s in prebias_subinfo.keys(): prebias_subinfo[s]["seeds"] = new_seeds - big = [] non_big = [] sim_subjects = [] diff --git a/pmf_analysis.py b/pmf_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7a0338c09ff15877860996f52d72999c16cd15 --- /dev/null +++ b/pmf_analysis.py @@ -0,0 +1,343 @@ +import pickle +import matplotlib.pyplot as plt +import numpy as np + + +type2color = {0: 'green', 1: 'blue', 2: 'red'} +all_conts = np.array([-1, -0.5, -.25, -.125, -.062, 0, .062, .125, .25, 0.5, 1]) + + +def pmf_type(pmf): + if pmf[-1] - pmf[0] < 0.2: + return 0 + # elif pmf[-1] - pmf[0] < 0.4:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1: + elif max(pmf[0], 1 - pmf[-1]) > 0.37 or pmf[-1] - pmf[0] < 0.55: + if pmf[0] > 0.37 and 1 - pmf[-1] > 0.37: + print("___________Troublesome PMF: {}___________".format(pmf)) + return 0 + return 1 + else: + return 2 + + +if __name__ == "__main__": + all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless", 'rb')) + all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) + + fewer_states_side = [] + for key in all_first_pmfs_typeless: + animal_biases = np.zeros(2) + for defined_points, pmf in all_first_pmfs_typeless[key]: + bias = np.mean(pmf[defined_points]) + if bias > 0.55: + animal_biases[0] += 1 + elif bias < 0.45: + animal_biases[1] += 1 + fewer_states_side.append(np.min(animal_biases / animal_biases.sum())) + plt.hist(fewer_states_side) + sns.despine() + plt.tight_layout() + plt.savefig("./meeting_figures/proportion_other_bias") + plt.show() + + total_counter = 0 + bias_counter = 0 + tendency_counter = 0 + lapse_counter = 0 + for pmf in all_pmfs: + max_b, min_b = 0, 1 + max_tendency, min_tendency = 0, 1 + max_lapse_diff, min_lapse_diff = 0, 1 + for p in pmf[1]: + if pmf[0][5]: # if this part of the pmf is defined, just take it + bias = p[5] + deviation = 0 + while True: # just take the closest thing + deviation += 1 + if pmf[0][5 - deviation] and pmf[0][5 + deviation]: + bias = (p[5 - deviation] + p[5 + deviation]) / 2 + break + max_b = max(max_b, bias) + min_b = min(min_b, bias) + max_tendency = max(max_tendency, np.mean(p[pmf[0]])) + min_tendency = min(min_tendency, np.mean(p[pmf[0]])) + max_lapse_diff = max(max_lapse_diff, p[0] + p[-1] - 1) + min_lapse_diff = min(min_lapse_diff, p[0] + p[-1] - 1) + bias_changed = max_b > 0.55 and min_b < 0.45 + tendency_changed = max_tendency > 0.55 and min_tendency < 0.45 + lapse_changed = max_lapse_diff > 0.1 and min_lapse_diff < -0.1 + bias_counter += bias_changed + tendency_counter += tendency_changed + lapse_counter += lapse_changed + if bias_changed or tendency_changed or lapse_changed: + total_counter += 1 + for p in pmf[1]: + plt.plot(np.where(pmf[0])[0], p[pmf[0]]) + plt.title("bias: {}, tendency: {}, lapse: {}".format(bias_changed, tendency_changed, lapse_changed)) + plt.ylim(0, 1) + plt.axvline(5, color='grey') + plt.axhline(0.5, color='grey') + plt.savefig("./meeting_figures/bias_change_{}".format(total_counter)) + plt.close() + print(total_counter) + print(bias_counter) + print(tendency_counter) + print(lapse_counter) + + pmf_ranges = [] + for key in all_first_pmfs_typeless: + for defined_points, pmf in all_first_pmfs_typeless[key]: + if pmf_type(pmf) == 2: + pmf_ranges.append(pmf[-1] - pmf[0]) + # if pmf_ranges[-1] < 0.6: + # plt.plot(pmf) + # plt.title(pmf_ranges[-1]) + # plt.ylim(0, 1) + # plt.show() + plt.hist(pmf_ranges, bins=40) + plt.title("Type 2 ranges") + plt.show() + + lapses = [] + for key in all_first_pmfs_typeless: + for defined_points, pmf in all_first_pmfs_typeless[key]: + if pmf_type(pmf) != 0: + lapses.append(max(pmf[0], 1 - pmf[-1])) + plt.hist(lapses, bins=40) + plt.title("Higher lapse rate of type != 0") + plt.show() + + n_rows, n_cols = 5, 6 + _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) + for i, key in enumerate(all_first_pmfs_typeless): + if i == n_rows * 2: + break + axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17) + for defined_points, pmf in all_first_pmfs_typeless[key]: + axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + for i, ax in enumerate(axs): + for j, a in enumerate(ax): + a.spines[['right', 'top']].set_visible(False) + a.set_ylim(0, 1) + if i != n_rows - 1 or j != 0: + a.set_xticks([]) + a.set_yticks([]) + else: + tick_size = 12 + a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + plt.tight_layout() + plt.savefig("animals 1") + plt.show() + + n_rows, n_cols = 5, 6 + _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) + for i, key in enumerate(all_first_pmfs_typeless): + if i < n_rows * 2: + continue + elif i < n_rows * 4: + i = i - n_rows * 2 + else: + break + if i == n_rows * 2: + break + axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17) + for defined_points, pmf in all_first_pmfs_typeless[key]: + axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + for i, ax in enumerate(axs): + for j, a in enumerate(ax): + a.spines[['right', 'top']].set_visible(False) + a.set_ylim(0, 1) + if i != n_rows - 1 or j != 0: + a.set_xticks([]) + a.set_yticks([]) + else: + tick_size = 12 + a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + plt.tight_layout() + plt.savefig("animals 2") + plt.show() + + n_rows, n_cols = 5, 6 + _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) + for i, key in enumerate(all_first_pmfs_typeless): + if i < n_rows * 4: + continue + elif i < n_rows * 6: + i = i - n_rows * 4 + else: + break + if i == n_rows * 4: + break + axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17) + for defined_points, pmf in all_first_pmfs_typeless[key]: + axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + for i, ax in enumerate(axs): + for j, a in enumerate(ax): + a.spines[['right', 'top']].set_visible(False) + a.set_ylim(0, 1) + if i != n_rows - 1 or j != 0: + a.set_xticks([]) + a.set_yticks([]) + else: + tick_size = 12 + a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + plt.tight_layout() + plt.savefig("animals 3") + plt.show() + + n_rows, n_cols = 5, 6 + _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) + for i, key in enumerate(all_first_pmfs_typeless): + if i < n_rows * 6: + continue + elif i < n_rows * 8: + i = i - n_rows * 6 + else: + break + if i == n_rows * 6: + break + axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17) + for defined_points, pmf in all_first_pmfs_typeless[key]: + axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + for i, ax in enumerate(axs): + for j, a in enumerate(ax): + a.spines[['right', 'top']].set_visible(False) + a.set_ylim(0, 1) + if i != n_rows - 1 or j != 0: + a.set_xticks([]) + a.set_yticks([]) + else: + tick_size = 12 + a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + plt.tight_layout() + plt.savefig("animals 4") + plt.show() + + # Collection of PMFs which change state type + all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb')) + all_changing_pmf_names = pickle.load(open("changing_pmf_names.p", 'rb')) + + plt.figure(figsize=(16, 9)) + for i, (pmf, name) in enumerate(zip(all_changing_pmfs, all_changing_pmf_names)): + plt.subplot(4, 7, i + 1) + plt.title(name) + for p in pmf[1]: + plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)]) + plt.ylim(0, 1) + + sns.despine() + if i+1 != 22: + plt.gca().set_xticks([]) + plt.gca().set_yticks([]) + else: + plt.xlabel("Contrasts", size=22) + plt.ylabel("P(rightwards)", size=22) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16) + plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16) + + if i + 1 == 30: + break + + plt.tight_layout() + plt.savefig("changing pmfs") + plt.show() + + # All first PMFs + tick_size = 14 + label_size = 26 + all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb')) + n_rows, n_cols = 1, 3 + _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) + counter = [[0, 0], [0, 0]] + save_title = "all types" if True else "KS014 types" + if save_title == "KS014 types": + all_first_pmfs_typeless = {'KS014': all_first_pmfs_typeless['KS014']} + + for key in all_first_pmfs_typeless: + x = all_first_pmfs_typeless[key] + for pmf in x: + axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])]) + axs[0].set_ylim(0, 1) + axs[1].set_ylim(0, 1) + axs[2].set_ylim(0, 1) + axs[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + axs[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + axs[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + axs[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + axs[1].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + axs[2].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + axs[0].spines[['right', 'top']].set_visible(False) + axs[1].spines[['right', 'top']].set_visible(False) + axs[2].spines[['right', 'top']].set_visible(False) + axs[0].set_xlim(0, 10) + axs[1].set_xlim(0, 10) + axs[2].set_xlim(0, 10) + axs[0].set_ylabel("P(rightwards)", size=label_size) + axs[0].set_xlabel("Contrasts", size=label_size) + + print(counter) + plt.tight_layout() + plt.savefig(save_title) + plt.show() + if save_title == "KS014 types": + quit() + + # Type 2 PMFs + counter = 0 + fig, ax = plt.subplots(1, 3, figsize=(16, 9)) + for key in all_first_pmfs_typeless: + for defined_points, pmf in all_first_pmfs_typeless[key]: + if pmf_type(pmf) != 1: + continue + # linestyle = '-' if x[5] == 0 else '--' + # if linestyle == '--': + # continue + if np.abs(pmf[0] + pmf[-1] - 1) <= 0.1: + counter += 1 + use_ax = 2 + else: + use_ax = int(pmf[0] > 1 - pmf[-1]) + + ax[use_ax].plot(np.where(defined_points)[0], pmf[defined_points], c='b') + ax[0].set_ylim(0, 1) + ax[0].set_xlim(0, 10) + ax[0].spines[['right', 'top']].set_visible(False) + ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + ax[0].set_ylabel("P(rightwards)", size=label_size) + + ax[1].set_ylim(0, 1) + ax[1].set_xlim(0, 10) + ax[1].set_yticks([]) + ax[1].spines[['right', 'top']].set_visible(False) + ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + ax[1].set_xlabel("Contrasts", size=label_size) + + ax[2].set_ylim(0, 1) + ax[2].set_xlim(0, 10) + ax[2].set_yticks([]) + ax[2].spines[['right', 'top']].set_visible(False) + ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + print(counter) + plt.tight_layout() + plt.savefig("differentiate type 2") + plt.show() + + + # visualise pmf types + # lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9] + # test_pmf = np.zeros(4) + # for i, lapse_l in enumerate(lapses): + # plt.subplot(1, 10, 1+i) + # if i != 0: + # plt.gca().set_yticklabels([]) + # plt.ylim(0, 1) + # test_pmf[:2] = lapse_l + # for lapse_r in np.linspace(0.02, 0.98, 33): + # test_pmf[2:] = lapse_r + # plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)]) + # plt.show() diff --git a/simplex_plot.py b/simplex_plot.py index dd8ba57d9f3d356cda8e04524404b4833a627a60..d518ca4025c531ac269c63afabe7eec0417be66f 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -50,7 +50,7 @@ def plotSimplex(points, fig=None, P.axis('off') - P.savefig("dur_simplex.png", bbox_inches='tight') + P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True) if show: P.show() else: