diff --git a/__pycache__/analysis_pmf.cpython-310.pyc b/__pycache__/analysis_pmf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9444f63cd5c90bb3a9aa6b583c61b347d9bbe11b Binary files /dev/null and b/__pycache__/analysis_pmf.cpython-310.pyc differ diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc index e1ab5f64794b712fd3d7753bfc894d24535085b8..f9fd33bf94c4befd1c5937c83c151a084e749d91 100644 Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index 740060e1aee1c905d0d6e98e7fb190b6806d36f6..4a89c93f30137d03c0b66cc48856b5c8f6b5a111 100644 Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ diff --git a/__pycache__/simplex_plot.cpython-310.pyc b/__pycache__/simplex_plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c4249fdd0ebd95ed28071c9f8d211aa84a2bb34 Binary files /dev/null and b/__pycache__/simplex_plot.cpython-310.pyc differ diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index d40c9990e4c7cc9a2781a81174008f4cb3e8b339..a8537b6f9da6be3317317d69450bbe8640da3809 100644 Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ diff --git a/analysis_pmf.py b/analysis_pmf.py index 637f06cd68c5139c19f7171c90f5b28c4a86c15a..03e55c1d17a94fbdfd7a0b8f2c8e239d4a076e4f 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -7,8 +7,30 @@ import seaborn as sns type2color = {0: 'green', 1: 'blue', 2: 'red'} all_conts = np.array([-1, -0.5, -.25, -.125, -.062, 0, .062, .125, .25, 0.5, 1]) +performance_points = np.array([-1, -1, 0, 0]) + +def pmf_to_perf(pmf): + # determine performance of a pmf, but only on the omnipresent strongest contrasts + return np.mean(np.abs(performance_points + pmf[[0, 1, -2, -1]])) + +contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) +contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] + +def weights_to_pmf(weights, with_bias=1): + psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] + return 1 / (1 + np.exp(psi)) def pmf_type(pmf): + rew = pmf_to_perf(pmf) + if rew < 0.6: + return 0 + elif rew < 0.7827: + return 1 + else: + return 2 + + +def pmf_type_old(pmf): if pmf[-1] - pmf[0] < 0.2: return 0 # elif pmf[-1] - pmf[0] < 0.4:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1: @@ -59,6 +81,30 @@ if __name__ == "__main__": all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) all_bias_flips = pickle.load(open("all_bias_flips.p", 'rb')) + _, axs = plt.subplots(1, 3, figsize=(16, 9)) + for defined_points, pmf in all_first_pmfs_typeless['KS014']: + axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + if np.round(pmf_to_perf(pmf), 2) in [0.86, 0.81]: + axs[pmf_type(pmf)].annotate(np.round(pmf_to_perf(pmf), 2), (8.7, pmf[-1] - 0.05), size=19) + else: + axs[pmf_type(pmf)].annotate(np.round(pmf_to_perf(pmf), 2), (8.7, pmf[-1] + 0.02), size=19) + for i, a in enumerate(axs): + a.spines[['right', 'top']].set_visible(False) + a.set_ylim(0, 1) + if i != 0: + a.set_xticks([]) + a.set_yticks([]) + else: + tick_size = 14 + a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + a.set_yticks([0, 0.25, 0.5, 
0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + a.set_ylabel("P(rightwards)", size=26) + a.set_xlabel("Contrasts", size=26) + plt.tight_layout() + plt.savefig("ks014_pmfs") + plt.close() + + # bias flips # plt.hist(all_bias_flips, bins=np.arange(0, max(all_bias_flips) + 1), color='grey', align='left') # plt.ylabel("# of mice") @@ -159,7 +205,92 @@ if __name__ == "__main__": # plt.xlabel("Higher lapse rate") # plt.show() - lw = 4 + lw = 7 + + # indiv examples + + state_num = 6 + defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.yticks([]) + plt.gca().set_xticks([]) + sns.despine() + plt.gca().spines['left'].set_linewidth(4) + plt.gca().spines['bottom'].set_linewidth(4) + plt.tight_layout() + plt.savefig("single exam 1") + plt.close() + + + state_num = 7 + defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.yticks([]) + plt.gca().set_xticks([]) + sns.despine() + plt.gca().spines['left'].set_linewidth(4) + plt.gca().spines['bottom'].set_linewidth(4) + plt.tight_layout() + plt.savefig("single exam 2") + plt.close() + + + + state_num = 3 + defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.yticks([]) + plt.gca().set_xticks([]) + sns.despine() + plt.gca().spines['left'].set_linewidth(4) + plt.gca().spines['bottom'].set_linewidth(4) + plt.tight_layout() + plt.savefig("single exam 3") + plt.close() + + + state_num = 1 + defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.yticks([]) + plt.gca().set_xticks([]) + sns.despine() + plt.gca().spines['left'].set_linewidth(4) + plt.gca().spines['bottom'].set_linewidth(4) + plt.tight_layout() + plt.savefig("single exam 4") + plt.close() + + + state_num = 4 + defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.yticks([]) + plt.gca().set_xticks([]) + sns.despine() + plt.gca().spines['left'].set_linewidth(4) + plt.gca().spines['bottom'].set_linewidth(4) + plt.tight_layout() + plt.savefig("single exam 5") + plt.show() + + + # Simplex example pmfs state_num = 7 defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] @@ -211,7 +342,7 @@ if __name__ == "__main__": plt.tight_layout() plt.savefig("example type 3") plt.show() - quit() + n_rows, n_cols = 5, 6 _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) @@ -356,7 +487,7 @@ if __name__ == "__main__": # All first PMFs tick_size = 14 label_size = 26 - all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb')) + all_first_pmfs = pickle.load(open("all_first_pmfs.p", 'rb')) n_rows, n_cols = 1, 3 _, 
axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) save_title = "all types" if True else "KS014 types" @@ -407,6 +538,9 @@ if __name__ == "__main__": else: use_ax = int(pmf[0] > 1 - pmf[-1]) + pmf_min = min(pmf[0], pmf[1]) + pmf_max = max(pmf[-2], pmf[-1]) + defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min)) ax[use_ax].plot(np.where(defined_points)[0], pmf[defined_points], c='b') ax[0].set_ylim(0, 1) ax[0].set_xlim(0, 10) diff --git a/behavioral_state_data.py b/behavioral_state_data.py index 2f1f3f0256f7c9d2721d955c74ff2013a0bcdde2..be9aef832e5d8d480af615b126d5528cad2f8229 100644 --- a/behavioral_state_data.py +++ b/behavioral_state_data.py @@ -41,7 +41,7 @@ def get_df(trials): df['feedback'] = df['feedback'].replace(-1, 0) df['signed_contrast'] = df['contrastR'] - df['contrastL'] df['signed_contrast'] = df['signed_contrast'].map(contrast_to_num) - df['response'] += 1 + df['response'] += 1 # this is coded most unintuitively: 0 is rightwards and 1 is leftwards (which is why I invert this variable in other programs) df['block'] = trials['probabilityLeft'] # df['rt'] = data_dict['response_times'] - data_dict['goCue_times'] # RTODO @@ -49,11 +49,7 @@ def get_df(trials): misses = [] to_introduce = [2, 3, 4, 5] -subjects = ['ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18', - 'ibl_witten_19', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26', - 'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011', - 'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025', - 'CSH_ZAD_026', 'CSH_ZAD_029'] +# subjects = ['ZFM-05236', 'ZFM-05245'] # subjects = ['CSHL045', 'CSHL046', 'CSHL049', 'CSHL051', 'CSHL052', 'CSHL053', 'CSHL054', 'CSHL055', # 'CSHL059', 'CSHL061', 'CSHL062', 'CSHL065', 'CSHL066'] # CSHL063 no good # subjects = ['CSHL_007', 'CSHL_014', 'CSHL_015', 'CSHL_018', 'CSHL_019', 'CSHL_020', 'CSH_ZAD_001', 'CSH_ZAD_011', @@ -77,7 +73,7 @@ subjects = ['ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18', # "ibl_witten_06", "ibl_witten_07", "ibl_witten_12", "ibl_witten_13", "ibl_witten_14", "ibl_witten_15", # "ibl_witten_16", "KS003", "KS005", "KS019", "NYU-01", "NYU-02", "NYU-04", "NYU-06", "ZM_1367", "ZM_1369", # "ZM_1371", "ZM_1372", "ZM_1743", "ZM_1745", "ZM_1746"] # zoe's subjects -# subjects = ['CSHL_018'] +subjects = ['ibl_witten_14'] data_folder = 'session_data' # why does CSHL058 not work? 
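Reviewer note on the analysis_pmf.py hunk above: the new state typing no longer looks at the endpoint spread of the PMF (kept as pmf_type_old) but scores each PMF by its accuracy on the four strongest, always-present contrasts and cuts at 0.6 and 0.7827. The following minimal sketch restates that logic with numpy only; the functions and thresholds are copied from the diff, while the weight vector at the bottom is a purely hypothetical illustration.

import numpy as np

contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
contrasts_R = contrasts_L[::-1]
performance_points = np.array([-1, -1, 0, 0])

def weights_to_pmf(weights, with_bias=1):
    # weights = (w_right, w_left, ..., bias); returns P(rightwards) at each of the 11 contrasts
    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
    return 1 / (1 + np.exp(psi))

def pmf_to_perf(pmf):
    # accuracy evaluated only on the two strongest contrasts of each side
    return np.mean(np.abs(performance_points + pmf[[0, 1, -2, -1]]))

def pmf_type(pmf):
    rew = pmf_to_perf(pmf)
    if rew < 0.6:
        return 0  # low-performance state
    elif rew < 0.7827:
        return 1  # intermediate state
    else:
        return 2  # high-performance state

weights = np.array([-3.5, 3.5, 0.])  # hypothetical GLM weights, zero bias
pmf = weights_to_pmf(weights)
print(np.round(pmf, 2))              # low P(rightwards) on left contrasts, high on right
print(round(pmf_to_perf(pmf), 2), pmf_type(pmf))  # ~0.97 -> type 2 (plotted red via type2color)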
@@ -101,7 +97,7 @@ for subject in subjects: print('_____________________') print(subject) - eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2023-01-01'], details=True) + eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2025-01-01'], details=True) start_times = [sess['date'] for sess in sess_info] protocols = [sess['task_protocol'] for sess in sess_info] diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py new file mode 100644 index 0000000000000000000000000000000000000000..e526d7cc485df51c38469af52b0e669aee793b5c --- /dev/null +++ b/behavioral_state_data_easier.py @@ -0,0 +1,200 @@ +from one.api import ONE +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import pickle +import json + +one = ONE() + + +contrast_to_num = {-1.: 0, -0.5: 1, -0.25: 2, -0.125: 3, -0.0625: 4, 0: 5, 0.0625: 6, 0.125: 7, 0.25: 8, 0.5: 9, 1.: 10} + +dataset_types = ['choice', 'contrastLeft', 'contrastRight', + 'feedbackType', 'probabilityLeft', 'response_times', + 'goCue_times'] + +def get_df(trials): + if np.all(None == trials['choice']) or np.all(None == trials['contrastLeft']) or np.all(None == trials['contrastRight']) or np.all(None == trials['feedbackType']) or np.all(None == trials['probabilityLeft']): # or np.all(None == data_dict['response_times']): + return None, None + d = {'response': trials['choice'], 'contrastL': trials['contrastLeft'], 'contrastR': trials['contrastRight'], 'feedback': trials['feedbackType']} + + df = pd.DataFrame(data=d, index=range(len(trials['choice']))).fillna(0) + df['feedback'] = df['feedback'].replace(-1, 0) + df['signed_contrast'] = df['contrastR'] - df['contrastL'] + df['signed_contrast'] = df['signed_contrast'].map(contrast_to_num) + df['response'] += 1 # this is coded most unintuitively: 0 is rightwards and 1 is leftwards (which is why I invert this variable in other programs) + df['block'] = trials['probabilityLeft'] + # df['rt'] = data_dict['response_times'] - data_dict['goCue_times'] # RTODO + + return df + +misses = [] +to_introduce = [2, 3, 4, 5] + +amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027'] +subjects = ['NYU-21'] +fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] +if fit_type == 'bias': + loading_info = json.load(open("canonical_infos_bias.json", 'r')) + r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r')) +elif fit_type == 'prebias': + loading_info = json.load(open("canonical_infos.json", 'r')) + r_hats = json.load(open("canonical_info_r_hats.json", 'r')) +already_fit = list(loading_info.keys()) + +remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit] +print(remaining_subs) + + + +data_folder = 'session_data_test' + +old_style = False +if old_style: + print("Warning, data can have splits") + data_folder = 'session_data_old' +bias_eids = [] + +print("#########################################") +print("Warning, rt's removed, find with # RTODO") +print("#########################################") + +short_subjs = [] +names = [] + +pre_bias = [] +entire_training = [] +for subject in subjects: + print('_____________________') + print(subject) + + if subject in already_fit or subject in amiss: + continue + + trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table') + + # Load training status and join to trials table + training = one.load_aggregate('subjects', subject, 
'_ibl_subjectTraining.table') + trials = (trials + .set_index('session') + .join(training.set_index('session')) + .sort_values(by='session_start_time', kind='stable')) + + start_times, indices = np.unique(trials.session_start_time, return_index=True) + start_times = [trials.session_start_time[index] for index in sorted(indices)] + task_protocol, indices = np.unique(trials.task_protocol, return_index=True) + task_protocol = [trials.task_protocol[index] for index in sorted(indices)] + nums, indices = np.unique(trials.session_number, return_index=True) + nums = [trials.session_number[index] for index in sorted(indices)] + eids, indices = np.unique(trials.index, return_index=True) + eids = [trials.index[index] for index in sorted(indices)] + + print("original # of eids {}".format(len(eids))) + + test = [(y, x) for y, x in sorted(zip(start_times, eids))] + pickle.dump(test, open("./{}/{}_session_names.p".format(data_folder, subject), "wb")) + + performance = np.zeros(len(eids)) + easy_per = np.zeros(len(eids)) + hard_per = np.zeros(len(eids)) + bias_start = 0 + + info_dict = {'subject': subject, 'dates': [st.to_pydatetime() for st in start_times], 'eids': eids} + contrast_set = {0, 1, 9, 10} + + rel_count = -1 + quit() + for i, start_time in enumerate(start_times): + + rel_count += 1 + + df = trials[trials.session_start_time == start_time] + df.loc[:, 'contrastRight'] = df.loc[:, 'contrastRight'].fillna(0) + df.loc[:, 'contrastLeft'] = df.loc[:, 'contrastLeft'].fillna(0) + df.loc[:, 'feedbackType'] = df.loc[:, 'feedbackType'].replace(-1, 0) + df.loc[:, 'signed_contrast'] = df.loc[:, 'contrastRight'] - df.loc[:, 'contrastLeft'] + df.loc[:, 'signed_contrast'] = df.loc[:, 'signed_contrast'].map(contrast_to_num) + df.loc[:, 'choice'] = df.loc[:, 'choice'] + 1 + + if any([df[x].isnull().any() for x in ['signed_contrast', 'choice', 'feedbackType', 'probabilityLeft']]): + quit() + + assert len(np.unique(df['session_start_time'])) == 1 + + current_contrasts = set(df['signed_contrast']) + diff = current_contrasts.difference(contrast_set) + for c in to_introduce: + if c in diff: + info_dict[c] = rel_count + contrast_set.update(diff) + + performance[i] = np.mean(df['feedbackType']) + easy_per[i] = np.mean(df['feedbackType'][np.logical_or(df['signed_contrast'] == 0, df['signed_contrast'] == 10)]) + hard_per[i] = np.mean(df['feedbackType'][df['signed_contrast'] == 5]) + + if bias_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_biasedChoiceWorld'): + bias_start = i + print("bias start {}".format(rel_count)) + info_dict['bias_start'] = rel_count + if bias_start < 33: + short_subjs.append(subject) + + pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, rel_count), "wb")) + + side_info = np.zeros((len(df), 2)) + side_info[:, 0] = df['probabilityLeft'] + side_info[:, 1] = df['feedbackType'] + pickle.dump(side_info, open("./{}/{}_side_info_{}.p".format(data_folder, subject, rel_count), "wb")) + + fit_info = np.zeros((len(df), 3)) + fit_info[:, 0] = df['signed_contrast'] + fit_info[:, 1] = df['choice'] + print(len(df)) + # fit_info[:, 2] = df['rt'] # RTODO + pickle.dump(fit_info, open("./{}/{}_fit_info_{}.p".format(data_folder, subject, rel_count), "wb")) + + if rel_count == -1: + continue + plt.figure(figsize=(11, 8)) + print(performance) + plt.plot(performance, label='Overall') + plt.plot(easy_per, label='100% contrasts') + plt.plot(hard_per, label='0% contrasts') + plt.axvline(bias_start - 0.5) + skip_count = 0 + for p in performance: + if p == 0.: + skip_count += 1 + else: + break 
+ for c in to_introduce: + plt.axvline(info_dict[c] + skip_count, ymax=0.85, c='grey') + plt.annotate('Pre-bias', (bias_start / 2, 1.), size=20, ha='center') + plt.annotate('Bias', (bias_start + (i - bias_start) / 2, 1.), size=20, ha='center') + plt.title(subject, size=22) + plt.ylabel('Performance', size=22) + plt.xlabel('Session', size=22) + plt.xticks(size=16) + plt.xticks(size=16) + plt.ylim(bottom=0) + plt.xlim(left=0) + + sns.despine() + plt.tight_layout() + if not old_style: + plt.savefig('./figures/behavior/all_of_trainig_{}'.format(subject)) + plt.close() + + # print(bias_eids) + pre_bias.append(info_dict['bias_start']) + entire_training.append(rel_count + 1) + + info_dict['n_sessions'] = rel_count + pickle.dump(info_dict, open("./{}/{}_info_dict.p".format(data_folder, subject), "wb")) +print(misses) +print(short_subjs) + +print(pre_bias) +print(entire_training) diff --git a/behaviour_overview.py b/behaviour_overview.py index 1a4899e65f47119af82ac806fd7babd966679c33..f38eca39245e5708bad090ad36b119c2ef4aa6f3 100644 --- a/behaviour_overview.py +++ b/behaviour_overview.py @@ -5,7 +5,7 @@ from one.api import ONE import pickle -show = 0 +show = 1 def progression(data, contrasts, progression_variable='feedback', windowsize=6, upper_bound=None, title=None): # looks somewhat irregular, red dots are not in middle of bump they cause, this is because distribution of specific contrasts is not uniform @@ -47,12 +47,37 @@ def progression(data, contrasts, progression_variable='feedback', windowsize=6, # plt.title(title, size=22) # plt.savefig("temp {}".format(title).replace('/', '_')) # plt.close() + if progression_variable == 'feedback': + data_local = data[:113] + means = data_local.groupby('signed_contrast').mean()['response'] + stds = data_local.groupby('signed_contrast').sem()['response'] + plt.errorbar(means.index, means.values, stds.values, label='1') + + data_local = data[113:226] + means = data_local.groupby('signed_contrast').mean()['response'] + stds = data_local.groupby('signed_contrast').sem()['response'] + plt.errorbar(means.index, means.values, stds.values, label='2') + + data_local = data[226:339] + means = data_local.groupby('signed_contrast').mean()['response'] + stds = data_local.groupby('signed_contrast').sem()['response'] + plt.errorbar(means.index, means.values, stds.values, label='3') + + data_local = data[339:] + means = data_local.groupby('signed_contrast').mean()['response'] + stds = data_local.groupby('signed_contrast').sem()['response'] + plt.errorbar(means.index, means.values, stds.values, label='last') + plt.title(title, size=22) + plt.ylim(0, 1) + plt.legend() + # plt.savefig("temp {}".format(title).replace('/', '_')) + plt.show() if progression_variable == 'rt': means = data.groupby('signed_contrast').mean()['rt'] stds = data.groupby('signed_contrast').sem()['rt'] plt.errorbar(means.index, means.values, stds.values) plt.title(title, size=22) - plt.savefig("temp {}".format(title).replace('/', '_')) + # plt.savefig("temp {}".format(title).replace('/', '_')) plt.close() @@ -95,8 +120,8 @@ exclude_eids = ['a66f1593-dafd-4982-9b66-f9554b6c86b5', 'ee40aece-cffd-4edb-a4b6 # project='ibl_neuropixel_brainwide_01') # traj.reverse() -subject = 'KS014' -eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True) +subject = 'KS022' +eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2024-01-01'], details=True) start_times = [sess['date'] for sess in sess_info] protocols = [sess['task_protocol'] for sess in 
sess_info] @@ -107,8 +132,11 @@ protocols = [x for _, x in sorted(zip(start_times, protocols))] # for t in traj: counti = 0 for i, (prot, eid) in enumerate(zip(protocols, eids)): + print("Nice demonstration that the PMF jump is actually quite sudden!") print(i) - if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'): + # if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'): + # continue + if 'habituation' in prot: continue # eid = t['session']['id'] df, _ = get_df(eid) @@ -116,13 +144,18 @@ for i, (prot, eid) in enumerate(zip(protocols, eids)): if df is None: continue + print(prot) + counti += 1 - rt_data = np.zeros((len(df), 3)) - rt_data[:, 0] = df['signed_contrast'] - rt_data[:, 1] = df['rt'] - rt_data[:, 2] = df['response'] - pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb')) + if counti != 7: + continue + + # rt_data = np.zeros((len(df), 3)) + # rt_data[:, 0] = df['signed_contrast'] + # rt_data[:, 1] = df['rt'] + # rt_data[:, 2] = df['response'] + # pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb')) - progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {} / 15".format(subject, counti)) - progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {} / 15".format(subject, counti)) + progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {}".format(subject, counti)) + # progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {}".format(subject, counti)) diff --git a/brain_wide_download.py b/brain_wide_download.py new file mode 100644 index 0000000000000000000000000000000000000000..3e44f83a0eae9abc58b3e8970385ed8313e7a13d --- /dev/null +++ b/brain_wide_download.py @@ -0,0 +1,77 @@ +from pathlib import Path + +import seaborn + +from one.api import ONE +from one.remote import aws +from one.alf.files import add_uuid_string +# root_path = Path("/datadisk/FlatIron/aggregates") +# +# files_parquet = list(root_path.rglob('_ibl_subjectTrials.table.pqt')) + +one.load(subject_id, content_type='subject', name='_ibl_subjectTrials.table.pqt') + + +if len(files_parquet) == 0: + one = ONE() + s3, bucket_name = aws.get_s3_from_alyx(alyx=one.alyx) + datasets = one.alyx.rest('datasets', 'list', name='_ibl_subjectTrials.table.pqt') + for dset in datasets: + rel_path = dset['file_records'][0]['relative_path'] + aws_path = add_uuid_string('aggregates/' + rel_path, dset['url'][-36:]) + aws.s3_download_file(aws_path, root_path.joinpath(rel_path), s3=s3, bucket_name=bucket_name) + + +## %% +import dask.dataframe as dd +import pandas as pd +import numpy as np + +time_bins = np.arange(0, 4000, 20) + + +trials = dd.read_parquet(files_parquet) +# the ITI is the start time of the next trial minus the end time of the current plus .5 secs +trials['iti'] = trials['intervals_0'].shift(-1) - trials['intervals_1'] + .5 +trials = trials[trials['iti'] > 0] # we need to remove the session jumps +# here we select only the ephys protocol trials +trials = trials[trials['task_protocol'].apply(lambda x: 'ephys' in x, meta=('task_protocol', 'bool'))] +# aggregate the iti per time bins +evolution = trials['iti'].groupby(trials['intervals_0'].map_partitions(pd.cut, time_bins)).agg(['mean', 'std']) + + + +import time +# crunch crunch +now = time.time() +itis = trials['iti'].compute() +tev = evolution.compute() +print(time.time() - now) + + + +## 
%% +import seaborn as sns +import matplotlib.pyplot as plt +sns.set_theme('paper') +sns.set_palette('deep') +fig, axs = plt.subplots(1, 3, figsize=(16, 5)) +axs[0].hist(itis, bins=1000, range=[0.5, 3], linewidth=0, alpha=0.8) +axs[0].set(title='ITI distribution for all trials', xlabel='ITI (s)', ylabel='Count') + +x = time_bins[:-1] + 10 +axs[1].fill_between(x, y1=tev['mean'].values - tev['std'].values, y2=tev['mean'].values + tev['std'].values, alpha=0.4) +axs[1].plot(x, tev['mean'].values, color='orange', linewidth=2) +axs[1].set(title='ITI duration', ylabel='ITI (s)', xlabel='Time elapsed in session (s)', ylim=[0.75, 2], xlim=[0, 3600]) + +trials_ = pd.read_parquet(files_parquet[0]) +trials_['iti'] = trials_['intervals_0'].shift(-1) - trials_['intervals_1'] + .5 +session0 = trials_[trials_['session'] == 'dc1b7422-cc16-4a37-b552-c31dccdddbce'] +axs[2].plot(session0['intervals_0'], session0['iti'], label='dc1b7422', linewidth=0.5) +trials_ = pd.read_parquet(files_parquet[50]) +trials_['iti'] = trials_['intervals_0'].shift(-1) - trials_['intervals_1'] + .5 +session1 = trials_[trials_['session'] == 'd9d83a6a-8fb5-41eb-b4ce-7c6dd1716d72'] +axs[2].plot(session1['intervals_0'], session1['iti'], label='d9d83a6a', linewidth=0.5) +axs[2].set(title='ITI duration', ylabel='ITI (s)', xlabel='Time elapsed in session (s)', ylim=[0.75, 2], xlim=[0, 3600]) +axs[2].legend() +## diff --git a/bwm_mice_stats.py b/bwm_mice_stats.py index 14696ff16c7811a63c1ee7027a6cd81cae596cf7..0ea04bba2db48c10ea81c0c62e416d058c173f1f 100644 --- a/bwm_mice_stats.py +++ b/bwm_mice_stats.py @@ -69,6 +69,7 @@ subjects = ['NYU-11', 'NYU-12', 'NYU-21', 'NYU-27', 'NYU-30', 'NYU-37', 'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011', 'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025', 'CSH_ZAD_026', 'CSH_ZAD_029'] +quit() # why does CSHL058 not work? 
old_style = False diff --git a/canonical_infos.json b/canonical_infos.json index c0868bc5f9a444da61e2e4c5010c9c759000ff3c..12a3fa3aae799f418267abc5e98f113db7d5d248 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19, "ignore": [5, 6, 12, 3, 1, 11, 10]}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19, "ignore": [6, 7, 9, 8, 5, 13, 12, 11]}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19, "ignore": [11, 4, 15, 6, 8, 0]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19, "ignore": [6, 1, 15, 10, 12]}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19, "ignore": [8, 10, 12, 5, 9, 7, 4]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": 
["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19, "ignore": [2, 13, 1, 9, 0, 4, 6, 5]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19, "ignore": [1, 8, 9, 2, 4, 15, 0, 7]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", 
"595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 
6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", 
"410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file +{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "ZFM-05236": {"seeds": ["404", "409", "401", "412", "406", "411", "410", "402", "408", "405", "415", "403", "414", "400", "413", "407"], "fit_nums": ["106", "111", "333", "253", "395", "76", "186", "192", "221", "957", "989", "612", "632", "304", "50", "493"], "chain_num": 14, "ignore": [3, 0, 8, 2, 6, 15, 14, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19}, "KS016": {"seeds": ["315", "301", "309", 
"313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "ZFM-05245": {"seeds": ["400", "413", "414", "409", "411", "403", "405", "412", "406", "410", "407", "415", "401", "404", "402", "408"], "fit_nums": ["512", "765", "704", "17", "539", "449", "584", "987", "138", "932", "869", "313", "253", "540", "37", "634"], "chain_num": 11}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", 
"283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "ZFM-04019": {"seeds": ["413", "404", "408", "403", "415", "406", "414", "410", "402", "405", "411", "400", "401", "412", "409", "407"], "fit_nums": ["493", "302", "590", "232", "121", "938", "270", "999", "95", "175", "576", "795", "728", "244", "32", "177"], "chain_num": 14, "ignore": [11, 12, 6, 3, 5, 2, 15, 1]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", 
"742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", 
"307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index c61d4c7c659c1360646febd6889a35d401ebdc37..3e8a74ceeab4065b6c0f37e16911e44f010f5cf1 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -545,10 +545,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur cnas.append(c_n_a) mask = c_n_a[:, -1] == 0 - plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6) + plt.plot(np.where(mask)[0], 0.5 + 0.25 
* (noise[mask] - all_conts[cont_mapping(- c_n_a[mask, 0] + c_n_a[mask, 1])]), 'o', c='b', ms=ms, label='Rightward', alpha=0.6) mask = c_n_a[:, -1] == 1 - plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6) + plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - all_conts[cont_mapping(- c_n_a[mask, 0] + c_n_a[mask, 1])]), 'o', c='r', ms=ms, label='Leftward', alpha=0.6) plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26) # plt.yticks(*self.cont_ticks, size=22-2) @@ -566,13 +566,15 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur 1.03, functions=(cont_to_belief, belief_to_cont)) secax_y2.set_ylabel("Contrast", size=26) secax_y2.set_yticks([-1, -0.5, 0., 0.5, 1]) + # secax_y2.set_yticks(np.arange(11) / 5 - 1) + # secax_y2.set_yticklabels([-1, '', '', '', '', 0., '', '', '', '', 1]) secax_y2.spines['right'].set_bounds(-1, 1) secax_y2.tick_params(axis='y', which='major', labelsize=fs) plt.xlabel('Trial', size=28) sns.despine() - plt.xlim(left=250, right=450) - plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5)) + # plt.xlim(left=250, right=450) + plt.legend(frameon=False, fontsize=22, ncol=2, loc=(0.7, 0.05)) plt.tight_layout() if save: plt.savefig("dynamic_GLM_figures/all posterior and contrasts {}, sess {}{}.png".format(subject, seq_num, save_append), dpi=dpi)#, bbox_inches='tight') @@ -685,6 +687,7 @@ def create_mode_indices(test, subject, fit_type): try: xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb')) except Exception: + print('Doing PCA') ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim) xy = np.vstack([dimreduc[i] for i in range(dim)]) from scipy.stats import gaussian_kde @@ -1128,7 +1131,6 @@ def state_development(test, state_sets, indices, save=True, save_append='', show counter += 1 else: counter = 0 - if test.results[0].infos[c] + 1 < plot_until: ax0.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0) ax1.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0) ax0.plot(test.results[0].infos[c] + 1 - 0.25, 0.6 + counter * 0.2, 'ko', ms=18) @@ -1281,7 +1283,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax0.axhline(-0.5, c='k') ax0.axhline(0.5, c='k') # print(perf) - ax0.fill_between(range(1, min(1 + test.results[0].n_sessions, plot_until)), perf[:plot_until - 1] - 0.5, -0.5, color='k') + ax0.fill_between(range(1, 1 + test.results[0].n_sessions), perf - 0.5, -0.5, color='k') durs, state_types = state_type_durs(states_by_session, all_pmfs) @@ -1352,7 +1354,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show plt.tight_layout() if save: - print("saving with {} dpi".format(dpi)) + # print("saving with {} dpi".format(dpi)) plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi) if show: plt.show() @@ -1361,7 +1363,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type -def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""): +def compare_pmfs(test, states2compare, states_by_session, all_pmfs, title=""): """ Take a set of 
states, and plot out their PMFs on all sessions on which they occur. See how different they really are. @@ -1395,6 +1397,31 @@ def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, a plt.show() +def compare_params(test, session, state_sets, indices, states2compare, title=""): + """ + Take a set of states, and plot a histogram of their regression parameters for a specific session. + See how different they really are. + + Unfinished + + Takes states_by_session and all_pmfs as input from state_development + """ + colors = ['blue', 'orange', 'green', 'black', 'red'] + assert len(states2compare) <= len(colors) + # subtract 1 to get internal numbering + states2compare = [s - 1 for s in states2compare] + # transform desired states into the actual numbering, before ordering by bias + states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare] + + for state, trials in enumerate(state_sets): + if state not in states2compare: + continue + session_js, pmfs, pmf_weights = state_weights(test, trials, indices) + return pmf_weights + plt.tight_layout() + plt.show() + + def compare_weights(test, state_sets, indices, states2compare, states_by_session, title=""): """ Take a set of states, and plot out their weights on all sessions on which they occur. @@ -1495,7 +1522,7 @@ if __name__ == "__main__": elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) - subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # list(loading_info.keys()) + subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys()) r_hats = {} @@ -1544,6 +1571,24 @@ def dist_helper(dist_matrix, state_hists, inds): dist_matrix[i, j] = np.sum(np.abs(state_hists[i] - state_hists[j])) return dist_matrix +def type_2_appearance(states, pmfs): + # How does type 2 appear, is it a new state or continuation of a type 1? + state_counter = {} + found_states = 0 + for session_counter in range(states.shape[1]): + for state, pmf in zip(range(states.shape[0]), pmfs): + if states[state, session_counter]: + if state not in state_counter: + state_counter[state] = -1 + state_counter[state] += 1 + if pmf_type(pmf[1][state_counter[state]][pmf[0]]) == 1: + found_states += 1 + new = state_counter[state] == 0 + if found_states > 1: + print("Problem") + return 2 + if found_states == 1: + return new def state_type_durs(states, pmfs): # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts. 
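A note on the new compare_params helper above: it maps the user-facing state numbers (1-indexed, ordered by bias through test.state_mapping) back to the sampler's internal state indices before collecting weights with state_weights. A minimal standalone sketch of that inversion, with made-up mapping values (the real mapping lives on the test object and is not part of this diff):

state_mapping = {0: 2, 1: 0, 2: 1}   # hypothetical: internal state index -> display number
states2compare = [1, 3]              # user-facing, 1-indexed
states2compare = [s - 1 for s in states2compare]                         # -> [0, 2]
internal = [k for k, v in state_mapping.items() if v in states2compare]
print(internal)                      # -> [0, 1], the indices kept when looping over state_sets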
@@ -1557,7 +1602,6 @@ def state_type_durs(states, pmfs): pmf_counter += 1 state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i] # indexing horror - print(state_types) if np.any(state_types[1] > 0.5): durs = (np.where(state_types[1] > 0.5)[0][0], np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0], @@ -1662,6 +1706,36 @@ def compare_performance(cnas, contrasts=(1, 1), title=""): plt.savefig(title) plt.show() +def write_results(test, state_sets, indices, consistencies=None): + n = test.results[0].n_sessions + trial_counter = 0 + state_dict = {state: {'sessions': [], 'trials': [], 'pmfs': []} for state in range(len(state_sets))} + + for state, trials in enumerate(state_sets): + + session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices) + state_dict[state]['sessions'] = session_js + state_dict[state]['pmfs'] = pmfs + + + for seq_num in range(n): + for state, trials in enumerate(state_sets): + relevant_trials = trials[np.logical_and(trial_counter <= trials, trials < trial_counter + len(test.results[0].models[0].stateseqs[seq_num]))] + active_trials = np.zeros(len(test.results[0].models[0].stateseqs[seq_num])) + + if consistencies is None: + active_trials[relevant_trials - trial_counter] = 1 + else: + active_trials[relevant_trials - trial_counter] = np.sum(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0) + active_trials[relevant_trials - trial_counter] -= 1 + active_trials[relevant_trials - trial_counter] = active_trials[relevant_trials - trial_counter] / (trials.shape[0] - 1) + + if np.sum(active_trials) > 0: + state_dict[state]['trials'].append(active_trials) + + trial_counter += len(test.results[0].models[0].stateseqs[seq_num]) + return state_dict + if __name__ == "__main__": @@ -1675,7 +1749,7 @@ if __name__ == "__main__": no_good_pcas = ['NYU-06', 'SWC_023'] subjects = list(loading_info.keys()) # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] - # subjects = ['KS014'] + subjects = ['KS021'] print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] @@ -1717,11 +1791,13 @@ if __name__ == "__main__": all_bias_flips = [] all_pmf_weights = [] - temp_counter = 0 - + new_counter, transform_counter = 0, 0 state_types_interpolation = np.zeros((3, 150)) all_state_types = [] + state_nums_5 = [] + state_nums_10 = [] + for subject in subjects: if subject.startswith('GLM_Sim_') or subject == 'ibl_witten_18': @@ -1748,22 +1824,64 @@ if __name__ == "__main__": # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # training overview - # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2) - # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7) - # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, 
save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13) states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - all_pmf_weights += pmf_weights - continue + # state_nums_5.append((states > 0.05).sum(0)) + # state_nums_10.append((states > 0.1).sum(0)) + # continue + a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5]) + compare_pmfs(test, [3, 5], states, pmfs, title="{} convergence pmf".format(subject)) + quit() + new = type_2_appearance(states, pmfs) + + if new == 2: + print('____________________________') + print(subject) + print('____________________________') + if new == 1: + new_counter += 1 + if new == 0: + transform_counter += 1 + print(new_counter, transform_counter) consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) consistencies /= consistencies[0, 0] - contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) + temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) + continue - all_state_types.append(state_types) + state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) + pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb')) + quit() + # abs_state_durs.append(durs) + # continue + # all_pmf_weights += pmf_weights + # all_state_types.append(state_types) + # # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) - # temp_counter += 1 + # + # all_first_pmfs_typeless[subject] = [] + # for pmf in pmfs: + # all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) + # all_pmfs.append(pmf) + # + # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) + # for pmf in changing_pmfs: + # if type(pmf[0]) == int: + # continue + # all_changing_pmf_names.append(subject) + # all_changing_pmfs.append(pmf) + # + # all_first_pmfs[subject] = first_pmfs + # continue + # + # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) # b_flips = bias_flips(states, 
pmfs, durs) @@ -1776,22 +1894,12 @@ if __name__ == "__main__": # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) # quit() - # all_first_pmfs_typeless[subject] = [] - # for pmf in pmfs: - # all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) # all_intros.append(undiv_intros) # all_intros_div.append(intros_by_type) # if states_per_type != []: # all_states_per_type.append(states_per_type) # # intros_by_type_sum += intros_by_type - # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) - # for pmf in changing_pmfs: - # if type(pmf[0]) == int: - # continue - # all_changing_pmf_names.append(subject) - # all_changing_pmfs.append(pmf) - # all_first_pmfs[subject] = first_pmfs # for pmf in pmfs: # all_pmfs.append(pmf) # for p in pmf[1]: @@ -1817,7 +1925,6 @@ if __name__ == "__main__": # compare_performance(cnas, (0, 0.848)) # duration of different state types (and also percentage of type activities) - # abs_state_durs.append(durs) # continue # simplex_durs = np.array(durs).reshape(1, 3) # print(simplex_durs / np.sum(simplex_durs)) @@ -1860,6 +1967,7 @@ if __name__ == "__main__": # plt.savefig("temp") # plt.close() + print('Computing sub result') create_mode_indices(test, subject, fit_type) state_set_and_plot(test, 'first_', subject, fit_type) print("second mode?") @@ -1868,8 +1976,8 @@ if __name__ == "__main__": except FileNotFoundError as e: print(e) - continue print('no canoncial result') + continue print(r_hats[subject]) if r_hats[subject] >= 1.05: print("Skipping") @@ -1915,7 +2023,7 @@ if __name__ == "__main__": test.r_hat_and_ess(alpha_func, True) - # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb')) + # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb')) # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb')) # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb')) # pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb')) @@ -1927,13 +2035,12 @@ if __name__ == "__main__": # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb')) # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb')) # pickle.dump(all_state_types, open("all_state_types.p", 'wb')) - pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb')) - quit() - + # pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb')) + # pickle.dump(state_types_interpolation, open("state_types_interpolation.p", 'wb')) + # abs_state_durs = np.array(abs_state_durs) + # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb')) if True: - abs_state_durs = np.array(abs_state_durs) - # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb')) abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) print("Correlations") diff --git a/dynamic_GLMiHMM_fit (copy).py b/dynamic_GLMiHMM_fit (copy).py new file mode 100644 index 0000000000000000000000000000000000000000..6ccb8afcc173f24b0fd7228f668917b4625a52f4 --- /dev/null +++ b/dynamic_GLMiHMM_fit (copy).py @@ -0,0 +1,407 @@ +"""Start a (series) of iHMM fit(s).""" +import os +os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=4 +os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=4 +os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=6 +os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export 
VECLIB_MAXIMUM_THREADS=4 +os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=6 +import pyhsmm +import pyhsmm.basic.distributions as distributions +import copy +import warnings +import pickle +import time +from scipy.special import digamma +import os.path +import numpy as np +from itertools import product +import json +import sys + + +def crp_expec(n, theta): + """ + Return expected number of tables after n customers, given concentration theta. + + From Wikipedia + """ + return theta * (digamma(theta + n) - digamma(theta)) + + +def eleven2nine(x): + """Map from 11 possible contrasts to 9, for the non-training phases. + + 1 and 9 can't appear there, make other elements consider this. + + E.g.: + [2, 0, 4, 8, 10] -> [1, 0, 3, 7, 8] + """ + assert 1 not in x and 9 not in x + x[x > 9] -= 1 + x[x > 1] -= 1 + return x + + +def eval_cross_val(models, data, unmasked_data, n_all_states): + """Eval cross_val.""" + lls = np.zeros(len(models)) + cross_val_n = 0 + for sess_time, (d, full_data) in enumerate(zip(data, unmasked_data)): + held_out = np.isnan(d[:, -1]) + cross_val_n += held_out.sum() + d[:, -1][held_out] = full_data[:, -1][held_out] + for i, m in enumerate(models): + for s in range(n_all_states): + mask = np.logical_and(held_out, m.stateseqs[sess_time] == s) + if mask.sum() > 0: + ll = m.obs_distns[s].log_likelihood(d[mask], sess_time) + lls[i] += np.sum(ll) + lls /= cross_val_n + ll_mean = np.mean(lls[-1000:]) + return lls, ll_mean + + +# following Nick Roys contrasts: following tanh transformation of the contrasts x has a +# free parameter p which we set as p= 5 throughout the paper: xp = tanh (px)/ tanh (p). +contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10} +num_to_contrast = {v: k for k, v in contrast_to_num.items()} +cont_mapping = np.vectorize(num_to_contrast.get) + +data_folder = 'session_data_test' +old_style = False +if old_style: + print("Warning, data can have splits") + print("Sure you want to use old data?") + temp = input() + if temp: + data_folder = 'session_data_old' + else: + quit() + +# available: +# not great: ibl_witten_18 + +# test subjects: +subjects = ['NYU-21', 'NYU-27', 'NYU-30', 'NYU-37', 'NYU-39', 'NYU-40', 'NYU-45', 'NYU-46', 'NYU-47', 'NYU-48', 'CSHL045', 'CSHL047', 'CSHL052', 'CSHL053', 'CSHL055', 'CSHL058', 'CSHL060', 'UCLA005', 'UCLA006', 'UCLA011', 'UCLA012', 'UCLA014', 'UCLA015', 'UCLA017', 'UCLA033', 'UCLA035', 'KS042', 'KS043', 'KS044', 'KS045', 'KS046', 'KS051', 'KS052', 'KS055', 'KS084', 'KS086', 'KS091', 'KS094', 'KS096', 'DY_008', 'DY_009', 'DY_010', 'DY_011', 'DY_013', 'DY_014', 'DY_016', 'DY_018', 'DY_020', 'SWC_042', 'SWC_043', 'SWC_060', 'SWC_061', 'SWC_066', 'ZFM-01576', 'ZFM-01577', 'ZFM-01592', 'ZFM-01935', 'ZFM-01936', 'ZFM-01937', 'ZFM-02368', 'ZFM-02369', 'ZFM-02370', 'ZFM-02372', 'ZFM-02373', 'ZM_1898', 'ZM_2240', 'ZM_2241', 'ZM_2245', 'SWC_038', 'SWC_039', 'SWC_052', 'SWC_053', 'SWC_054', 'SWC_058', 'SWC_065', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26', 'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_019', 'CSH_ZAD_024', 'CSH_ZAD_029'] + + +# subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]] +# (0.03, 0.3, 5, 'contR', 'contL', 'prevA', 'bias', 1, 0.1): + +print(subjects) + +cv_nums = [400 + int(sys.argv[1]) % 16, 400 + (2 * int(sys.argv[1])) % 16] +subjects = [subjects[int(sys.argv[1]) // 16], subjects[(2 * int(sys.argv[1])) // 16]] + +print(cv_nums) +print(subjects) + +for 
loop_count_i, (s, cv_num) in enumerate(zip(subjects, cv_nums)): + # if loop_count_i > 8: + # if loop_count_i <= 8 or loop_count_i > 16: + # if loop_count_i <= 16: + # continue + params = {} + params['subject'] = s + params['cross_val_num'] = cv_num + params['fit_variance'] = 0.03 + params['jumplimit'] = 1 + all_regressors = ['contR', 'contL', 'cont', 'prevA', 'weighted_prevA', 'WSLS', 'bias'] + params['regressors'] = [all_regressors[i] for i in [0, 1, 3, 6]] + + # default (non-iteration) settings: + params['fit_type'] = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] + # params['fit_variance'] = [0.0005, 0.002, 0.008, 0.02, 0.06, 0.1, 0.3, 0.6, 1., 2.4, 10, 16, 30, 'uniform'][6] + if 'prevA' in params['regressors'] or 'weighted_prevA' in params['regressors']: + params['exp_decay'], params['exp_length'] = 0.3, 5 + params['exp_filter'] = np.exp(- params['exp_decay'] * np.arange(params['exp_length'])) + params['exp_filter'] /= params['exp_filter'].sum() + print(params['exp_filter']) + params['dur'] = 'yes' + params['obs_dur'] = ['glm', 'cat'][0] + # more obscure params: + params['gamma'] = None # 0.005 + params['alpha'] = None # 1 + if params['gamma'] is not None: + print("_______________________") + print("Warning, gamma is fixed") + print("_______________________") + params['gamma_a_0'] = 0.001 + params['gamma_b_0'] = 1000 + params['init_var'] = 8 + params['init_mean'] = np.zeros(len(params['regressors'])) + # normal: + r_support = np.cumsum(np.arange(5, 100, 5)) + r_support = np.arange(5, 705, 4) + params['dur_params'] = dict(r_support=r_support, + r_probs=np.ones(len(r_support))/len(r_support), alpha_0=1, beta_0=1) + # params['dur_params'] = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]), + # r_probs=np.ones(13)/13., alpha_0=1, beta_0=1) + # params['dur_params'] = dict(r_support=np.arange(1, 251), + # r_probs=np.ones(250)/250., alpha_0=1, beta_0=1) + params['alpha_a_0'] = 0.1 + params['alpha_b_0'] = 10 + # trying a smaller value here, should lower the appearance of ephemeral new states on session bounds + # hope this doesn't make real new states at session bound less likely, or hurt mixing... 
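For reference, the exponential history filter built above for the prevA / weighted_prevA regressors (decay 0.3, length 5) works out to roughly [0.334, 0.247, 0.183, 0.136, 0.100] after normalization. A short standalone sketch reproducing that computation (not part of the fit script itself):

import numpy as np

exp_decay, exp_length = 0.3, 5
exp_filter = np.exp(-exp_decay * np.arange(exp_length))
exp_filter /= exp_filter.sum()
print(np.round(exp_filter, 3))   # ~[0.334 0.247 0.183 0.136 0.1]
# Further down, this kernel is convolved with the (recoded) previous answers,
# so the immediately preceding choice contributes about a third of the regressor's value.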
+ params['init_state_concentration'] = 3 + # cat params + params['conditioned_on'] = 'nothing' + + params['cross_val'] = False + params['cross_val_fold'] = 10 + params['CROSS_VAL_SEED'] = 4 # Do not change this, it's 4 + + params['seed'] = 100 + params['cross_val_num'] + + params['n_states'] = 15 + params['n_samples'] = 6 if params['obs_dur'] == 'glm' else 4000 + if params['cross_val']: + params['n_samples'] = 4000 + if s.startswith("GLM_Sim"): + print("reduced sample size") + params['n_samples'] = 12000 + + print(params['n_samples']) + # now actual fit: + # new start names: uniform_start_, bias_fraction_, small_gamma_, high_init_, non_semi_, non_semi_normal_init_, correct_sol_, correct_sol_semi_ + while True: + folder = "./dynamic_GLMiHMM_crossvals/" + rand_id = np.random.randint(1000) + if params['cross_val']: + id = "{}_crossval_{}_{}_var_{}_{}_{}".format(params['subject'], params['cross_val_num'], params['fit_type'], + params['fit_variance'], params['seed'], rand_id) + else: + id = "{}_fittype_{}_var_{}_{}_{}".format(params['subject'], params['fit_type'], + params['fit_variance'], params['seed'], rand_id) + if not os.path.isfile(folder + id + '_0.p'): + break + # create placeholder dataset for rand_id purposes + pickle.dump(params, open(folder + id + '_0.p', 'wb')) + if params['obs_dur'] == 'glm': + print(params['regressors']) + else: + print('using categoricals') + print(id) + params['file_name'] = folder + id + np.random.seed(params['seed']) + + info_dict = pickle.load(open("./{}/{}_info_dict.p".format(data_folder, params['subject']), "rb")) + # Determine session numbers + if params['fit_type'] == 'prebias': + till_session = info_dict['bias_start'] + elif params['fit_type'] == 'bias' or params['fit_type'] == 'zoe_style' or params['fit_type'] == 'all': + till_session = info_dict['n_sessions'] + elif params['fit_type'] == 'prebias_plus': + till_session = min(info_dict['bias_start'] + 6, info_dict['n_sessions']) # 6 here will actually turn into 7 later + + from_session = info_dict['bias_start'] if params['fit_type'] in ['bias', 'zoe_style'] else 0 + + models = [] + + if params['obs_dur'] == 'glm': + n_inputs = len(params['regressors']) + T = till_session - from_session + (params['fit_type'] != 'prebias') + obs_hypparams = {'n_regressors': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'], + 'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))} + obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])] + else: + n_inputs = 9 if params['fit_type'] == 'bias' else 11 + obs_hypparams = {'n_regressors': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'), + 'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']} + obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])] + + dur_distns = [distributions.NegativeBinomialIntegerR2Duration(**params['dur_params']) for state in range(params['n_states'])] + + if params['dur'] == 'yes': + if params['gamma'] is None: + posteriormodel = pyhsmm.models.WeakLimitHDPHSMM( + # https://math.stackexchange.com/questions/449234/vague-gamma-prior + alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'], # TODO: gamma vs alpha? 
gamma steers state number + gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'], + init_state_concentration=params['init_state_concentration'], + obs_distns=obs_distns, + dur_distns=dur_distns, + var_prior=params['fit_variance']) # TODO: I don't think this does anything + else: + posteriormodel = pyhsmm.models.WeakLimitHDPHSMM( + # https://math.stackexchange.com/questions/449234/vague-gamma-prior + alpha=params['alpha'], # TODO: gamma vs alpha? gamma steers state number + gamma=params['gamma'], + init_state_concentration=params['init_state_concentration'], + obs_distns=obs_distns, + dur_distns=dur_distns, + var_prior=params['fit_variance']) # TODO: I don't think this does anything + else: + if params['gamma'] is None: + posteriormodel = pyhsmm.models.WeakLimitHDPHMM( + alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'], # TODO: gamma vs alpha? gamma steers state number + gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'], + init_state_concentration=params['init_state_concentration'], + obs_distns=obs_distns, + var_prior=params['fit_variance']) # TODO: I don't think this does anything + else: + posteriormodel = pyhsmm.models.WeakLimitHDPHMM( + alpha=params['alpha'], # TODO: gamma vs alpha? gamma steers state number + gamma=params['gamma'], + init_state_concentration=params['init_state_concentration'], + obs_distns=obs_distns, + var_prior=params['fit_variance']) # TODO: I don't think this does anything + + print(from_session, till_session + (params['fit_type'] != 'prebias')) + + if params['cross_val']: + rng = np.random.RandomState(params['CROSS_VAL_SEED']) + + data_save = [] + for j in range(from_session, till_session + (params['fit_type'] != 'prebias')): + try: + data = pickle.load(open("./{}/{}_fit_info_{}.p".format(data_folder, params['subject'], j), "rb")) + except FileNotFoundError: + continue + if data.shape[0] == 0: + print("meh, skipped session") + continue + + # if j == 15: + # import matplotlib.pyplot as plt + # for i in [0, 2, 3,4,5,6,7,8,10]: + # plt.plot(i, data[data[:, 0] == i, 1].mean(), 'ko') + # plt.show() + + if params['obs_dur'] == 'glm': + for i in range(data.shape[0]): + data[i, 0] = num_to_contrast[data[i, 0]] + mask = data[:, 1] != 1 + mask[0] = False + if params['fit_type'] == 'zoe_style': + mask[90:] = False + mega_data = np.empty((np.sum(mask), n_inputs + 1)) + + for i, reg in enumerate(params['regressors']): + # positive numbers are contrast on the right + if reg == 'contR': + mega_data[:, i] = np.maximum(data[mask, 0], 0) + elif reg == 'contL': + mega_data[:, i] = np.abs(np.minimum(data[mask, 0], 0)) + elif reg == 'cont': + mega_data[:, i] = data[mask, 0] + elif reg == 'prevA': + # prev_ans = data[:, 1].copy() + new_prev_ans = data[:, 1].copy() + # prev_ans[1:] = prev_ans[:-1] + # prev_ans -= 1 + new_prev_ans -= 1 + new_prev_ans = np.convolve(np.append(0, new_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])] + mega_data[:, i] = new_prev_ans[mask] + elif reg == 'weighted_prevA': + prev_ans = data[:, 1].copy() + prev_ans -= 1 + # weigh the tendency by how clear the previous contrast was + weighted_prev_ans = data[:, 0] + prev_ans + weighted_prev_ans = np.convolve(np.append(0, weighted_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])] + mega_data[:, i] = weighted_prev_ans[mask] + elif reg == 'WSLS': + side_info = pickle.load(open("./{}/{}_side_info_{}.p".format(data_folder, params['subject'], j), "rb")) + prev_reward = side_info[:, 1] + prev_reward[1:] = prev_reward[:-1] + prev_ans = data[:, 1].copy() 
+ prev_ans[1:] = prev_ans[:-1] - 1 + mega_data[:, i] = prev_ans[mask] + mega_data[prev_reward[mask] == 0, i] *= -1 + elif reg == 'bias': + # have bias active only if contrasts further from 1 are in the session + # if len(np.unique(data[mask, 0])) > 5: + # print("original 1") + # mega_data[:, i] = 1 + # else: + # print("original 0") + # mega_data[:, i] = 0 + + # have bias active only if contrasts further from 1 are in the session, new version + # if bias_active: + # mega_data[:, i] = 1 + # elif len(np.unique(data[mask, 0])) > 5: + # if np.where(np.abs(data[mask, 0]) == 0.848)[0][0] / data.shape[0] < 0.5: + # mega_data[:, i] = 1 + # else: + # mega_data[:, i] = 0 + # bias_active = True + # else: + # mega_data[:, i] = 0 + + # bias is now always active + mega_data[:, i] = 1 + + mega_data[:, -1] = data[mask, 1] / 2 + elif params['obs_dur'] == 'cat': + mask = data[:, 1] != 1 + mask[0] = False + data = data[:, [0, 1]] + data[:, 1] = data[:, 1] / 2 + mega_data = data[mask] + + data_save.append(mega_data.copy()) + + if params['cross_val']: + test_sets = np.tile(np.arange(params['cross_val_fold']), mega_data.shape[0] // params['cross_val_fold'] + 1)[:mega_data.shape[0]] + rng.shuffle(test_sets) + mega_data[:, -1][test_sets == params['cross_val_num']] = None + + posteriormodel.add_data(mega_data) + + # for d in posteriormodel.datas: + # print(d.shape) + + if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])): + pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb')) + + # states_solution = pickle.load(open("states_{}_{}_condition_{}_{}.p".format('DY_013', 'all', 'nothing', '0_01'), 'rb')) # todo: remove! + time_save = time.time() + likes = np.zeros(params['n_samples']) + with warnings.catch_warnings(): # ignore the scipy warning + warnings.simplefilter("ignore") + for j in range(params['n_samples']): + + if j % 400 == 0 or j == 3: + print(j) + + posteriormodel.resample_model() + + likes[j] = posteriormodel.log_likelihood() + model_save = copy.deepcopy(posteriormodel) + if j != params['n_samples'] - 1 and j != 0 and j % 2000 != 1: + # To save on memory: + model_save.delete_data() + model_save.delete_obs_data() + model_save.delete_dur_data() + models.append(model_save) + + # save something in case of crash + if j % 400 == 0 and j > 0: + if params['n_samples'] <= 4000: + pickle.dump(models, open(folder + id + '.p', 'wb')) + else: + pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb')) + if j % 4000 == 0: + models = [] + print(time.time() - time_save) + + if params['cross_val']: + lls, lls_mean = eval_cross_val(models, posteriormodel.datas, data_save, n_all_states=params['n_states']) + params['cross_val_preds'] = lls + params['cross_val_preds'] = params['cross_val_preds'].tolist() + + print(id) + if 'exp_filter' in params: + params['exp_filter'] = params['exp_filter'].tolist() + params['dur_params']['r_support'] = params['dur_params']['r_support'].tolist() + params['dur_params']['r_probs'] = params['dur_params']['r_probs'].tolist() + params['ll'] = likes.tolist() + params['init_mean'] = params['init_mean'].tolist() + if params['cross_val']: + json.dump(params, open(folder + "infos_new/" + '{}_{}_cvll_{}_{}_{}_{}_{}.json'.format(params['subject'], params['cross_val_num'], str(np.round(lls_mean, 3)).replace('.', '_'), + params['fit_type'], params['fit_variance'], params['seed'], rand_id), 'w')) + else: + json.dump(params, open(folder + "infos_new/" + '{}_{}_{}_{}_{}.json'.format(params['subject'], params['fit_type'], + 
params['fit_variance'], params['seed'], rand_id), 'w')) + pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb')) diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py index dbbe447a55e51e03d20aef3e840ccd600d76f653..bf1d046da4a345335014298a708c541c5c6d4ec3 100644 --- a/dynamic_GLMiHMM_fit.py +++ b/dynamic_GLMiHMM_fit.py @@ -67,7 +67,7 @@ contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0. num_to_contrast = {v: k for k, v in contrast_to_num.items()} cont_mapping = np.vectorize(num_to_contrast.get) -data_folder = 'session_data' +data_folder = 'session_data_test' old_style = False if old_style: print("Warning, data can have splits") @@ -86,7 +86,7 @@ subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19', 'CSH_ZAD_017', 'CSH_ZAD_025', 'CSH_ZAD_026', 'CSHL049', 'CSHL051', 'CSHL061'] # test subjects: -sim_subjects = ['GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_15'] +subjects = ['KS014'] # subjects = ['KS021', 'KS016', 'ibl_witten_16', 'SWC_022', 'KS003', 'CSHL054', 'ZM_3003', 'KS015', 'ibl_witten_13', 'CSHL059', 'CSH_ZAD_022', 'CSHL_007', 'CSHL062', 'NYU-06', 'KS014', 'ibl_witten_14', 'SWC_023'] # subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]] @@ -314,30 +314,10 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)): mega_data[:, i] = prev_ans[mask] mega_data[prev_reward[mask] == 0, i] *= -1 elif reg == 'bias': - # have bias active only if contrasts further from 1 are in the session - # if len(np.unique(data[mask, 0])) > 5: - # print("original 1") - # mega_data[:, i] = 1 - # else: - # print("original 0") - # mega_data[:, i] = 0 - - # have bias active only if contrasts further from 1 are in the session, new version - # if bias_active: - # mega_data[:, i] = 1 - # elif len(np.unique(data[mask, 0])) > 5: - # if np.where(np.abs(data[mask, 0]) == 0.848)[0][0] / data.shape[0] < 0.5: - # mega_data[:, i] = 1 - # else: - # mega_data[:, i] = 0 - # bias_active = True - # else: - # mega_data[:, i] = 0 - # bias is now always active mega_data[:, i] = 1 - mega_data[:, -1] = data[mask, 1] / 2 + mega_data[:, -1] = data[mask, 1] / 2 # 0 is rightwards, and 1 is leftwards, because the original data is weird elif params['obs_dur'] == 'cat': mask = data[:, 1] != 1 mask[0] = False @@ -357,9 +337,9 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)): # for d in posteriormodel.datas: # print(d.shape) - if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])): - pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb')) - + # if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])): + pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb')) + quit() # states_solution = pickle.load(open("states_{}_{}_condition_{}_{}.p".format('DY_013', 'all', 'nothing', '0_01'), 'rb')) # todo: remove! 
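As a sanity check on the contrast encoding shared by both fit scripts: the keys of contrast_to_num are the tanh-compressed contrast levels described in the comment of the copied script (xp = tanh(p x) / tanh(p) with p = 5), up to rounding in the last digit. A small verification sketch, separate from the scripts themselves:

import numpy as np

p = 5
raw_contrasts = np.array([0., 0.0625, 0.125, 0.25, 0.5, 1.])   # nominal contrast levels
compressed = np.tanh(p * raw_contrasts) / np.tanh(p)
print(np.round(compressed, 3))   # ~[0. 0.303 0.555 0.848 0.987 1.]
# These line up with the positive keys of contrast_to_num (0.302, 0.555, 0.848, 0.987, 1.0);
# 0.303 vs 0.302 is only a rounding difference, and the negative keys mirror them.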
time_save = time.time() likes = np.zeros(params['n_samples']) diff --git a/index_mice.py b/index_mice.py index e821de0e2e86ea42e5859cc88e8c545179cea4e8..add07bdaf3514b88e6aed44d1186c629072f3f64 100644 --- a/index_mice.py +++ b/index_mice.py @@ -20,14 +20,6 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): fit_num = result.group(5) chain_num = result.group(6) - # print() - # print(filename) - # print(subject) - # print(fit_type) - # print(seed) - # print(fit_num) - # print(chain_num) - # continue if fit_type == 'prebias': local_dict = prebias_subinfo elif fit_type == 'bias': diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py index 00611bcc7e1200cbe4a76a79ebde4510408bea3b..059851501e92abf00fe662b44d35c296ea153c09 100644 --- a/pmf_weight_analysis.py +++ b/pmf_weight_analysis.py @@ -3,10 +3,148 @@ import matplotlib.pyplot as plt import pickle from scipy.stats import gaussian_kde from analysis_pmf import pmf_type, type2color +from mpl_toolkits import mplot3d + +performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0]) +reduced_points = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=bool) + +folder = "./temp_reward_analysis/" + +def pmf_to_perf(pmf): + # determine performance of a pmf, but only on the omnipresent strongest contrasts + return np.mean(np.abs(performance_points[reduced_points] + pmf[reduced_points])) + + +contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) +contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] + + +def weights_to_pmf(weights, with_bias=1): + psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] + return 1 / (1 + np.exp(psi)) + +def pmf_type_rew(weights): + rew = pmf_to_perf(weights_to_pmf(weights)) + if rew < 0.6: + return 0 + elif rew < 0.7827: + return 1 + else: + return 2 # all pmf weights apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb'))) +colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw] +colors_rew = [type2color[pmf_type_rew(x)] for x in apw] + + +if True: + for i, weights in enumerate(apw): + if pmf_type(weights_to_pmf(weights)) != pmf_type_rew(weights): + plt.plot(weights_to_pmf(weights)) + plt.ylim(0, 1) + plt.title("Classic says {}, reward says {}, ({})".format(1 + pmf_type(weights_to_pmf(weights)), 1 + pmf_type_rew(weights), pmf_to_perf(weights_to_pmf(weights)))) + print(weights_to_pmf(weights)) + plt.tight_layout() + plt.savefig(folder + "divergent classification {}".format(i)) + plt.close() + + +type1_rews = [] +type2_rews = [] +type3_rews = [] +all_rews = [] +for weights in apw: + type = pmf_type(weights_to_pmf(weights)) + all_rews.append(pmf_to_perf(weights_to_pmf(weights))) + if type == 0: + if pmf_to_perf(weights_to_pmf(weights)) > 0.8: + plt.plot(weights_to_pmf(weights)) + plt.show() + type1_rews.append(pmf_to_perf(weights_to_pmf(weights))) + elif type == 1: + type2_rews.append(pmf_to_perf(weights_to_pmf(weights))) + elif type == 2: + type3_rews.append(pmf_to_perf(weights_to_pmf(weights))) + +type1_rews, type2_rews, type3_rews = np.array(type1_rews), np.array(type2_rews), np.array(type3_rews) + +bound1 = np.linspace(0.55, 0.65, 100) +bound2 = np.linspace(0.75, 0.85, 100) + +opt_bound1, errors1 = 0, 100 +for b in bound1: + if errors1 > (np.sum(type1_rews > b) + np.sum(type2_rews < b)): + opt_bound1 = b + errors1 = np.sum(type1_rews > b) + np.sum(type2_rews < b) + +opt_bound2, errors2 = 0, 100 +for b in bound2: + if errors2 > (np.sum(type2_rews > b) + np.sum(type3_rews < b)): + opt_bound2 = b + errors2 = 
np.sum(type2_rews > b) + np.sum(type3_rews < b) + +print(opt_bound1, errors1, opt_bound2, errors2) + +opt_bound2 = 0.780303 +print("Optimal bound 2: {}, {}".format(np.sum(type2_rews > opt_bound2), np.sum(type3_rews < opt_bound2))) + +man_bound = 0.7827 +print("Manual bound 2: {}, {}".format(np.sum(type2_rews > man_bound), np.sum(type3_rews < man_bound))) + + +bins = np.linspace(0, 1, 30) +plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"]) +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') +plt.legend() +plt.savefig(folder + "hist 1") +plt.show() + +plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"], stacked=True) +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') +plt.legend() +plt.savefig(folder + "hist 2") +plt.show() + +plt.subplot(1, 3, 1) +plt.title("Bins = 25") +plt.hist(all_rews, bins=25) +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') + +plt.subplot(1, 3, 2) +plt.title("Bins = 40") +plt.hist(all_rews, bins=40) +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') + +plt.subplot(1, 3, 3) +plt.title("Bins = 55") +plt.hist(all_rews, bins=55) +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') + +plt.savefig(folder + "hists compare") +plt.show() + + +fig = plt.figure(figsize=(13 * 3 / 5, 9 * 3 / 5)) +plt.hist(all_rews, bins=40, color='grey') +plt.axvline(0.6, c='k') +plt.axvline(0.7827, c='k') + +plt.ylabel("# of occurences", size=28) +plt.xlabel("Reward rate", size=28) +plt.gca().spines[['right', 'top']].set_visible(False) + +plt.tight_layout() +plt.savefig(folder + "single hist") +plt.show() + + xy = np.vstack([apw[:, i] for i in range(4)]) z = gaussian_kde(xy)(xy) @@ -25,67 +163,107 @@ plt.scatter(apw[:, 3], apw[:, 0], c=z) plt.xlabel("Bias") plt.ylabel("Cont right") +plt.savefig(folder + "density scatter") plt.show() -contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) -contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] - -def weights_to_pmf(weights, with_bias=1): - psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] - return 1 / (1 + np.exp(psi)) - -colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw] - -plt.subplot(1, 3, 1) +plt.subplot(1, 4, 1) sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors) -fig, ax = plt.gcf(), plt.gca() +fig, ax1 = plt.gcf(), plt.gca() plt.xlabel("Cont right") plt.ylabel("Cont left") -annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", - bbox=dict(boxstyle="round", fc="w"), - arrowprops=dict(arrowstyle="->")) -annot.set_visible(False) +annot1 = ax1.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) +annot1.set_visible(False) + +plt.subplot(1, 4, 2) +plt.scatter(apw[:, 3], apw[:, 1], c=colors) +ax2 = plt.gca() +plt.xlabel("Bias") +plt.ylabel("Cont left") + +annot2 = ax2.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) +annot2.set_visible(False) + +plt.subplot(1, 4, 3) +plt.scatter(apw[:, 3], apw[:, 0], c=colors) +ax3 = plt.gca() +plt.xlabel("Bias") +plt.ylabel("Cont right") + +annot3 = ax3.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) +annot3.set_visible(False) + def update_annot(ind): - print(ind) pos = sc.get_offsets()[ind["ind"][0]] - annot.xy = pos + 
annot1.xy = pos text = "{}".format(np.round(apw[ind["ind"][0]], 2)) - annot.set_text(text) + annot1.set_text(text) + + annot2.xy = apw[ind["ind"][0]][[3, 1]] + text = "{}".format(np.round(apw[ind["ind"][0]], 2)) + annot2.set_text(text) + + annot3.xy = apw[ind["ind"][0]][[3, 0]] + text = "{}".format(np.round(apw[ind["ind"][0]], 2)) + annot3.set_text(text) + + plt.subplot(1, 4, 4) + plt.cla() + plt.plot(weights_to_pmf(apw[ind["ind"][0]])) + plt.ylim(0, 1) + plt.ylabel("P(rightwards)") + plt.xlabel("Contrasts") + def hover(event): - vis = annot.get_visible() - if event.inaxes == ax: + vis = annot1.get_visible() + if event.inaxes == ax1: cont, ind = sc.contains(event) if cont: update_annot(ind) - annot.set_visible(True) + annot1.set_visible(True) + annot2.set_visible(True) + annot3.set_visible(True) fig.canvas.draw_idle() else: if vis: - annot.set_visible(False) + annot1.set_visible(False) + annot2.set_visible(False) + annot3.set_visible(False) fig.canvas.draw_idle() + fig.canvas.mpl_connect("motion_notify_event", hover) -plt.subplot(1, 3, 2) -plt.scatter(apw[:, 3], apw[:, 1], c=colors) -plt.xlabel("Bias") -plt.ylabel("Cont left") +plt.savefig(folder + "type scatter") +plt.show() -plt.subplot(1, 3, 3) -plt.scatter(apw[:, 3], apw[:, 0], c=colors) -plt.xlabel("Bias") -plt.ylabel("Cont right") +fig = plt.figure(figsize=(16, 9)) +ax = plt.axes(projection='3d') +ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors) +ax.view_init(27.5, -137) +plt.savefig(folder + "3d types") plt.show() +fig = plt.figure(figsize=(16, 9)) +ax = plt.axes(projection='3d') +ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors_rew) +ax.view_init(27.5, -137) +plt.savefig(folder + "3d types new") +plt.show() -from mpl_toolkits import mplot3d fig = plt.figure() ax = plt.axes(projection='3d') -ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors) +ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=z) +plt.savefig(folder + "3d density") plt.show() diff --git a/process_many_chains.py b/process_many_chains.py index cf8061356316a47aae18dedf91b2d0b4b2b27e9c..a0664e42c0946a573b715f894307b8c096c1c96b 100644 --- a/process_many_chains.py +++ b/process_many_chains.py @@ -17,7 +17,7 @@ if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) -subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # list(loading_info.keys()) +subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys()) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] func1 = state_num_helper(0.2) diff --git a/simplex_animation.py b/simplex_animation.py index 89dfe945c0994f472c74602d89541d12ac1d5fb1..ddf6df7925ac502d5319f6b4a6986eb27adb6fff 100644 --- a/simplex_animation.py +++ b/simplex_animation.py @@ -9,6 +9,7 @@ import copy all_state_types = pickle.load(open("all_state_types.p", 'rb')) + # create a list of fixed offsets, to apply to mice when they are in a corner offsets = [(0, 0)] hex_size = 0 @@ -37,7 +38,7 @@ assert (test_count == 1).all() # plt.show() # quit() -session_counter = 0 +session_counter = -1 alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj'] # do as many sessions as it takes while True: @@ -65,9 +66,8 @@ while True: type_proportions[:, i] = sts[:, temp_counter] - plotSimplex(type_proportions.T, x_offset=x_offset, 
y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter), + plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter), vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter])) - plt.close() if not_ended == 0: break diff --git a/simplex_plot.py b/simplex_plot.py index 635885d05417cb2c7810ec74e6d5cb5572e6f0b7..420dd26c661da32a2e9f89e2fbb937edea411f25 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -14,8 +14,8 @@ import matplotlib.patches as PA def plotSimplex(points, fig=None, - vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], - save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): + vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], title='', + save_title="dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): """ Plot Nx3 points array on the 3-simplex (with optionally labeled vertices) @@ -37,21 +37,25 @@ def plotSimplex(points, fig=None, # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False) # Project and draw the actual points projected = projectSimplex(points / points.sum(1)[:, None]) - print(projected) - P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs) + P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=points.sum(1) * 3.5, **kwargs)#s=35 # plot center with average size projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3)) - P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5) + print(points) + print() + print(projected) + P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)#s=50 # Leave some buffer around the triangle for vertex labels fig.gca().set_xlim(-0.05, 1.05) fig.gca().set_ylim(-0.05, 1.05) P.axis('off') + if title != '': + P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.025), size=24) P.tight_layout() - P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True) + P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True) if show: P.show() else: diff --git a/state_dict_KS014 b/state_dict_KS014 new file mode 100644 index 0000000000000000000000000000000000000000..c183d111be82d6be94e31c438d9ece422bbbea76 Binary files /dev/null and b/state_dict_KS014 differ diff --git a/state_dict_ZFM-04019 b/state_dict_ZFM-04019 new file mode 100644 index 0000000000000000000000000000000000000000..458cbe0f1f432c5e57ca3156688138077467873d Binary files /dev/null and b/state_dict_ZFM-04019 differ diff --git a/state_dict_ZFM-05236 b/state_dict_ZFM-05236 new file mode 100644 index 0000000000000000000000000000000000000000..216c9aeaafe342f1fda365a4d8261f91f3c2d74c Binary files /dev/null and b/state_dict_ZFM-05236 differ
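On the simplex_plot.py changes above: plotSimplex normalizes each row of the Nx3 points array to proportions and passes it to projectSimplex before scattering, now with the per-mouse x/y offsets applied to the projected coordinates. projectSimplex itself is not shown in this diff; a plausible stand-in, assuming the standard equilateral-triangle embedding with vertices (0, 0), (1, 0) and (0.5, sqrt(3)/2), would be:

import numpy as np

def project_simplex_sketch(points):
    # Normalize rows to proportions, then map (a, b, c) onto the 2-D triangle:
    # vertex 1 -> (0, 0), vertex 2 -> (1, 0), vertex 3 -> (0.5, sqrt(3)/2).
    points = np.asarray(points, dtype=float)
    points = points / points.sum(axis=1, keepdims=True)
    x = points[:, 1] + 0.5 * points[:, 2]
    y = (np.sqrt(3) / 2) * points[:, 2]
    return np.column_stack([x, y])

print(project_simplex_sketch([[1, 1, 1]]))   # equal thirds land at the centroid, ~[[0.5, 0.289]]

This is only a sketch of a projection convention consistent with the axis limits and annotation positions used in plotSimplex; the repository's actual projectSimplex may differ in its details.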