diff --git a/__pycache__/analysis_pmf.cpython-310.pyc b/__pycache__/analysis_pmf.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9444f63cd5c90bb3a9aa6b583c61b347d9bbe11b
Binary files /dev/null and b/__pycache__/analysis_pmf.cpython-310.pyc differ
diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc
index e1ab5f64794b712fd3d7753bfc894d24535085b8..f9fd33bf94c4befd1c5937c83c151a084e749d91 100644
Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ
diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc
index 740060e1aee1c905d0d6e98e7fb190b6806d36f6..4a89c93f30137d03c0b66cc48856b5c8f6b5a111 100644
Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-310.pyc b/__pycache__/simplex_plot.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c4249fdd0ebd95ed28071c9f8d211aa84a2bb34
Binary files /dev/null and b/__pycache__/simplex_plot.cpython-310.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index d40c9990e4c7cc9a2781a81174008f4cb3e8b339..a8537b6f9da6be3317317d69450bbe8640da3809 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/analysis_pmf.py b/analysis_pmf.py
index 637f06cd68c5139c19f7171c90f5b28c4a86c15a..03e55c1d17a94fbdfd7a0b8f2c8e239d4a076e4f 100644
--- a/analysis_pmf.py
+++ b/analysis_pmf.py
@@ -7,8 +7,30 @@ import seaborn as sns
 type2color = {0: 'green', 1: 'blue', 2: 'red'}
 all_conts = np.array([-1, -0.5, -.25, -.125, -.062, 0, .062, .125, .25, 0.5, 1])
 
+performance_points = np.array([-1, -1, 0, 0])
+
+def pmf_to_perf(pmf):
+    # determine performance of a pmf, but only on the omnipresent strongest contrasts
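+    # performance_points is [-1, -1, 0, 0]: for the two strongest left contrasts the correct
+    # choice is leftwards, so performance is |pmf - 1|; for the two strongest right contrasts
+    # it is pmf itself. The returned value is the mean over these four points.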
+    return np.mean(np.abs(performance_points + pmf[[0, 1, -2, -1]]))
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
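+# regressor values for the 11 contrast levels (strongest left to strongest right); contrasts_R is contrasts_L mirrored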
+
+def weights_to_pmf(weights, with_bias=1):
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
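+    # note the sign: 1 / (1 + exp(psi)) is sigmoid(-psi)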
+    return 1 / (1 + np.exp(psi))
 
 def pmf_type(pmf):
+    rew = pmf_to_perf(pmf)
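+    # type is assigned by performance on the strongest contrasts:
+    # < 0.6 -> type 0, 0.6 to 0.7827 -> type 1, >= 0.7827 -> type 2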
+    if rew < 0.6:
+        return 0
+    elif rew < 0.7827:
+        return 1
+    else:
+        return 2
+
+
+def pmf_type_old(pmf):
     if pmf[-1] - pmf[0] < 0.2:
         return 0
     # elif pmf[-1] - pmf[0] < 0.4:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
@@ -59,6 +81,30 @@ if __name__ == "__main__":
     all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
     all_bias_flips = pickle.load(open("all_bias_flips.p", 'rb'))
 
+    _, axs = plt.subplots(1, 3, figsize=(16, 9))
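+    # KS014 overview: plot each of KS014's first PMFs in the panel of its type, annotated with its performance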
+    for defined_points, pmf in all_first_pmfs_typeless['KS014']:
+        axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
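+        # two hard-coded performance values get their annotation placed below the curve endpoint, the rest above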
+        if np.round(pmf_to_perf(pmf), 2) in [0.86, 0.81]:
+            axs[pmf_type(pmf)].annotate(np.round(pmf_to_perf(pmf), 2), (8.7, pmf[-1] - 0.05), size=19)
+        else:
+            axs[pmf_type(pmf)].annotate(np.round(pmf_to_perf(pmf), 2), (8.7, pmf[-1] + 0.02), size=19)
+    for i, a in enumerate(axs):
+        a.spines[['right', 'top']].set_visible(False)
+        a.set_ylim(0, 1)
+        if i != 0:
+            a.set_xticks([])
+            a.set_yticks([])
+        else:
+            tick_size = 14
+            a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+            a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+            a.set_ylabel("P(rightwards)", size=26)
+            a.set_xlabel("Contrasts", size=26)
+    plt.tight_layout()
+    plt.savefig("ks014_pmfs")
+    plt.close()
+
+
     # bias flips
     # plt.hist(all_bias_flips, bins=np.arange(0, max(all_bias_flips) + 1), color='grey', align='left')
     # plt.ylabel("# of mice")
@@ -159,7 +205,92 @@ if __name__ == "__main__":
     # plt.xlabel("Higher lapse rate")
     # plt.show()
 
-    lw = 4
+    lw = 7
+
+    # indiv examples: one single-state PMF per subject, saved as "single exam 1" .. "single exam 5"
+
+    examples = [('CSHL061', 6), ('NYU-06', 7), ('ibl_witten_14', 3), ('CSHL_018', 1), ('ibl_witten_17', 4)]
+    for example_counter, (example_subject, state_num) in enumerate(examples, 1):
+        defined_points, pmf = all_first_pmfs_typeless[example_subject][state_num][0], all_first_pmfs_typeless[example_subject][state_num][1]
+        plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+
+        plt.ylim(0, 1)
+        plt.xlim(0, 10)
+        plt.yticks([])
+        plt.gca().set_xticks([])
+        sns.despine()
+        plt.gca().spines['left'].set_linewidth(4)
+        plt.gca().spines['bottom'].set_linewidth(4)
+        plt.tight_layout()
+        plt.savefig("single exam {}".format(example_counter))
+        if example_counter == len(examples):
+            plt.show()
+        else:
+            plt.close()
+
+
+
     # Simplex example pmfs
     state_num = 7
     defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
@@ -211,7 +342,7 @@ if __name__ == "__main__":
     plt.tight_layout()
     plt.savefig("example type 3")
     plt.show()
-    quit()
+
 
     n_rows, n_cols = 5, 6
     _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
@@ -356,7 +487,7 @@ if __name__ == "__main__":
     # All first PMFs
     tick_size = 14
     label_size = 26
-    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    all_first_pmfs = pickle.load(open("all_first_pmfs.p", 'rb'))
     n_rows, n_cols = 1, 3
     _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
     save_title = "all types" if True else "KS014 types"
@@ -407,6 +538,9 @@ if __name__ == "__main__":
             else:
                 use_ax = int(pmf[0] > 1 - pmf[-1])
 
+            pmf_min = min(pmf[0], pmf[1])
+            pmf_max = max(pmf[-2], pmf[-1])
+            defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min))
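+            # plot only the points whose value lies between the extremes pmf_min and pmf_max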
             ax[use_ax].plot(np.where(defined_points)[0], pmf[defined_points], c='b')
     ax[0].set_ylim(0, 1)
     ax[0].set_xlim(0, 10)
diff --git a/behavioral_state_data.py b/behavioral_state_data.py
index 2f1f3f0256f7c9d2721d955c74ff2013a0bcdde2..be9aef832e5d8d480af615b126d5528cad2f8229 100644
--- a/behavioral_state_data.py
+++ b/behavioral_state_data.py
@@ -41,7 +41,7 @@ def get_df(trials):
     df['feedback'] = df['feedback'].replace(-1, 0)
     df['signed_contrast'] = df['contrastR'] - df['contrastL']
     df['signed_contrast'] = df['signed_contrast'].map(contrast_to_num)
-    df['response'] += 1
+    df['response'] += 1  # this is coded most unintuitively: 0 is rightwards and 1 is leftwards (which is why I negate this variable in other programs)
     df['block'] = trials['probabilityLeft']
     # df['rt'] = data_dict['response_times'] - data_dict['goCue_times']  # RTODO
 
@@ -49,11 +49,7 @@ def get_df(trials):
 
 misses = []
 to_introduce = [2, 3, 4, 5]
-subjects = ['ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18',
-            'ibl_witten_19', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26',
-            'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011',
-            'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025',
-            'CSH_ZAD_026', 'CSH_ZAD_029']
+# subjects = ['ZFM-05236', 'ZFM-05245']
 # subjects = ['CSHL045', 'CSHL046', 'CSHL049', 'CSHL051', 'CSHL052', 'CSHL053', 'CSHL054', 'CSHL055',
 #             'CSHL059', 'CSHL061', 'CSHL062', 'CSHL065', 'CSHL066']  # CSHL063 no good
 # subjects = ['CSHL_007', 'CSHL_014', 'CSHL_015', 'CSHL_018', 'CSHL_019', 'CSHL_020', 'CSH_ZAD_001', 'CSH_ZAD_011',
@@ -77,7 +73,7 @@ subjects = ['ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18',
 #             "ibl_witten_06", "ibl_witten_07", "ibl_witten_12", "ibl_witten_13", "ibl_witten_14", "ibl_witten_15",
 #             "ibl_witten_16", "KS003", "KS005", "KS019", "NYU-01", "NYU-02", "NYU-04", "NYU-06", "ZM_1367", "ZM_1369",
 #             "ZM_1371", "ZM_1372", "ZM_1743", "ZM_1745", "ZM_1746"]  # zoe's subjects
-# subjects = ['CSHL_018']
+subjects = ['ibl_witten_14']
 
 data_folder = 'session_data'
 # why does CSHL058 not work?
@@ -101,7 +97,7 @@ for subject in subjects:
     print('_____________________')
     print(subject)
 
-    eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2023-01-01'], details=True)
+    eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2025-01-01'], details=True)
 
     start_times = [sess['date'] for sess in sess_info]
     protocols = [sess['task_protocol'] for sess in sess_info]
diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py
new file mode 100644
index 0000000000000000000000000000000000000000..e526d7cc485df51c38469af52b0e669aee793b5c
--- /dev/null
+++ b/behavioral_state_data_easier.py
@@ -0,0 +1,200 @@
+from one.api import ONE
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import pickle
+import json
+
+one = ONE()
+
+
+contrast_to_num = {-1.: 0, -0.5: 1, -0.25: 2, -0.125: 3, -0.0625: 4, 0: 5, 0.0625: 6, 0.125: 7, 0.25: 8, 0.5: 9, 1.: 10}
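+# maps signed contrast (-1 .. 1) to an ordinal position 0 .. 10, matching the 11-entry contrast grid used in the analysis code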
+
+dataset_types = ['choice', 'contrastLeft', 'contrastRight',
+                 'feedbackType', 'probabilityLeft', 'response_times',
+                 'goCue_times']
+
+def get_df(trials):
+    if np.all(None == trials['choice']) or np.all(None == trials['contrastLeft']) or np.all(None == trials['contrastRight']) or np.all(None == trials['feedbackType']) or np.all(None == trials['probabilityLeft']):  # or np.all(None == data_dict['response_times']):
+        return None, None
+    d = {'response': trials['choice'], 'contrastL': trials['contrastLeft'], 'contrastR': trials['contrastRight'], 'feedback': trials['feedbackType']}
+
+    df = pd.DataFrame(data=d, index=range(len(trials['choice']))).fillna(0)
+    df['feedback'] = df['feedback'].replace(-1, 0)
+    df['signed_contrast'] = df['contrastR'] - df['contrastL']
+    df['signed_contrast'] = df['signed_contrast'].map(contrast_to_num)
+    df['response'] += 1  # this is coded most unintuitively: 0 is rightwards and 1 is leftwards (which is why I negate this variable in other programs)
+    df['block'] = trials['probabilityLeft']
+    # df['rt'] = data_dict['response_times'] - data_dict['goCue_times']  # RTODO
+
+    return df
+
+misses = []
+to_introduce = [2, 3, 4, 5]
+
+amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027']
+subjects = ['NYU-21']
+fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
+if fit_type == 'bias':
+    loading_info = json.load(open("canonical_infos_bias.json", 'r'))
+    r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
+elif fit_type == 'prebias':
+    loading_info = json.load(open("canonical_infos.json", 'r'))
+    r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
+already_fit = list(loading_info.keys())
+
+remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
+print(remaining_subs)
+
+
+
+data_folder = 'session_data_test'
+
+old_style = False
+if old_style:
+    print("Warning, data can have splits")
+    data_folder = 'session_data_old'
+bias_eids = []
+
+print("#########################################")
+print("Waring, rt's removed, find with   # RTODO")
+print("#########################################")
+
+short_subjs = []
+names = []
+
+pre_bias = []
+entire_training = []
+for subject in subjects:
+    print('_____________________')
+    print(subject)
+
+    if subject in already_fit or subject in amiss:
+        continue
+
+    trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
+
+    # Load training status and join to trials table
+    training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
+    trials = (trials
+              .set_index('session')
+              .join(training.set_index('session'))
+              .sort_values(by='session_start_time', kind='stable'))
+
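+    # recover one value per session in order of first occurrence
+    # (np.unique sorts its output, so re-index via the sorted first-occurrence indices)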
+    start_times, indices = np.unique(trials.session_start_time, return_index=True)
+    start_times = [trials.session_start_time[index] for index in sorted(indices)]
+    task_protocol, indices = np.unique(trials.task_protocol, return_index=True)
+    task_protocol = [trials.task_protocol[index] for index in sorted(indices)]
+    nums, indices = np.unique(trials.session_number, return_index=True)
+    nums = [trials.session_number[index] for index in sorted(indices)]
+    eids, indices = np.unique(trials.index, return_index=True)
+    eids = [trials.index[index] for index in sorted(indices)]
+
+    print("original # of eids {}".format(len(eids)))
+
+    test = [(y, x) for y, x in sorted(zip(start_times, eids))]
+    pickle.dump(test, open("./{}/{}_session_names.p".format(data_folder, subject), "wb"))
+
+    performance = np.zeros(len(eids))
+    easy_per = np.zeros(len(eids))
+    hard_per = np.zeros(len(eids))
+    bias_start = 0
+
+    info_dict = {'subject': subject, 'dates': [st.to_pydatetime() for st in start_times], 'eids': eids}
+    contrast_set = {0, 1, 9, 10}
+
+    rel_count = -1
+    # quit()  # debug stop; commented out so the per-session processing below actually runs
+    for i, start_time in enumerate(start_times):
+
+        rel_count += 1
+
+        df = trials[trials.session_start_time == start_time]
+        df.loc[:, 'contrastRight'] = df.loc[:, 'contrastRight'].fillna(0)
+        df.loc[:, 'contrastLeft'] = df.loc[:, 'contrastLeft'].fillna(0)
+        df.loc[:, 'feedbackType'] = df.loc[:, 'feedbackType'].replace(-1, 0)
+        df.loc[:, 'signed_contrast'] = df.loc[:, 'contrastRight'] - df.loc[:, 'contrastLeft']
+        df.loc[:, 'signed_contrast'] = df.loc[:, 'signed_contrast'].map(contrast_to_num)
+        df.loc[:, 'choice'] = df.loc[:, 'choice'] + 1
+
+        if any([df[x].isnull().any() for x in ['signed_contrast', 'choice', 'feedbackType', 'probabilityLeft']]):
+            quit()
+
+        assert len(np.unique(df['session_start_time'])) == 1
+
+        current_contrasts = set(df['signed_contrast'])
+        diff = current_contrasts.difference(contrast_set)
+        for c in to_introduce:
+            if c in diff:
+                info_dict[c] = rel_count
+        contrast_set.update(diff)
+
+        performance[i] = np.mean(df['feedbackType'])
+        easy_per[i] = np.mean(df['feedbackType'][np.logical_or(df['signed_contrast'] == 0, df['signed_contrast'] == 10)])
+        hard_per[i] = np.mean(df['feedbackType'][df['signed_contrast'] == 5])
+
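+        # record the first session run on a biased protocol; flag subjects with fewer than 33 pre-bias sessions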
+        if bias_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_biasedChoiceWorld'):
+            bias_start = i
+            print("bias start {}".format(rel_count))
+            info_dict['bias_start'] = rel_count
+            if bias_start < 33:
+                short_subjs.append(subject)
+
+        pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, rel_count), "wb"))
+
+        side_info = np.zeros((len(df), 2))
+        side_info[:, 0] = df['probabilityLeft']
+        side_info[:, 1] = df['feedbackType']
+        pickle.dump(side_info, open("./{}/{}_side_info_{}.p".format(data_folder, subject, rel_count), "wb"))
+
+        fit_info = np.zeros((len(df), 3))
+        fit_info[:, 0] = df['signed_contrast']
+        fit_info[:, 1] = df['choice']
+        print(len(df))
+        # fit_info[:, 2] = df['rt']  # RTODO
+        pickle.dump(fit_info, open("./{}/{}_fit_info_{}.p".format(data_folder, subject, rel_count), "wb"))
+
+    if rel_count == -1:
+        continue
+    plt.figure(figsize=(11, 8))
+    print(performance)
+    plt.plot(performance, label='Overall')
+    plt.plot(easy_per, label='100% contrasts')
+    plt.plot(hard_per, label='0% contrasts')
+    plt.axvline(bias_start - 0.5)
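+    # count leading sessions with zero performance so the contrast-introduction markers below are offset accordingly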
+    skip_count = 0
+    for p in performance:
+        if p == 0.:
+            skip_count += 1
+        else:
+            break
+    for c in to_introduce:
+        plt.axvline(info_dict[c] + skip_count, ymax=0.85, c='grey')
+    plt.annotate('Pre-bias', (bias_start / 2, 1.), size=20, ha='center')
+    plt.annotate('Bias', (bias_start + (i - bias_start) / 2, 1.), size=20, ha='center')
+    plt.title(subject, size=22)
+    plt.ylabel('Performance', size=22)
+    plt.xlabel('Session', size=22)
+    plt.xticks(size=16)
+    plt.yticks(size=16)
+    plt.ylim(bottom=0)
+    plt.xlim(left=0)
+
+    sns.despine()
+    plt.tight_layout()
+    if not old_style:
+        plt.savefig('./figures/behavior/all_of_trainig_{}'.format(subject))
+    plt.close()
+
+    # print(bias_eids)
+    pre_bias.append(info_dict['bias_start'])
+    entire_training.append(rel_count + 1)
+
+    info_dict['n_sessions'] = rel_count
+    pickle.dump(info_dict, open("./{}/{}_info_dict.p".format(data_folder, subject), "wb"))
+print(misses)
+print(short_subjs)
+
+print(pre_bias)
+print(entire_training)
diff --git a/behaviour_overview.py b/behaviour_overview.py
index 1a4899e65f47119af82ac806fd7babd966679c33..f38eca39245e5708bad090ad36b119c2ef4aa6f3 100644
--- a/behaviour_overview.py
+++ b/behaviour_overview.py
@@ -5,7 +5,7 @@ from one.api import ONE
 import pickle
 
 
-show = 0
+show = 1
 
 def progression(data, contrasts, progression_variable='feedback', windowsize=6, upper_bound=None, title=None):
     # looks somewhat irregular, red dots are not in middle of bump they cause, this is because distribution of specific contrasts is not uniform
@@ -47,12 +47,37 @@ def progression(data, contrasts, progression_variable='feedback', windowsize=6,
     #     plt.title(title, size=22)
     #     plt.savefig("temp {}".format(title).replace('/', '_'))
     #     plt.close()
+    if progression_variable == 'feedback':
+        # overlay PMFs for four consecutive chunks of the session (113 trials each, remainder in 'last')
+        for data_local, label in [(data[:113], '1'), (data[113:226], '2'), (data[226:339], '3'), (data[339:], 'last')]:
+            means = data_local.groupby('signed_contrast').mean()['response']
+            stds = data_local.groupby('signed_contrast').sem()['response']
+            plt.errorbar(means.index, means.values, stds.values, label=label)
+        plt.title(title, size=22)
+        plt.ylim(0, 1)
+        plt.legend()
+        # plt.savefig("temp {}".format(title).replace('/', '_'))
+        plt.show()
     if progression_variable == 'rt':
         means = data.groupby('signed_contrast').mean()['rt']
         stds = data.groupby('signed_contrast').sem()['rt']
         plt.errorbar(means.index, means.values, stds.values)
         plt.title(title, size=22)
-        plt.savefig("temp {}".format(title).replace('/', '_'))
+        # plt.savefig("temp {}".format(title).replace('/', '_'))
         plt.close()
 
 
@@ -95,8 +120,8 @@ exclude_eids = ['a66f1593-dafd-4982-9b66-f9554b6c86b5', 'ee40aece-cffd-4edb-a4b6
 #                      project='ibl_neuropixel_brainwide_01')
 # traj.reverse()
 
-subject = 'KS014'
-eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True)
+subject = 'KS022'
+eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2024-01-01'], details=True)
 
 start_times = [sess['date'] for sess in sess_info]
 protocols = [sess['task_protocol'] for sess in sess_info]
@@ -107,8 +132,11 @@ protocols = [x for _, x in sorted(zip(start_times, protocols))]
 # for t in traj:
 counti = 0
 for i, (prot, eid) in enumerate(zip(protocols, eids)):
+    print("Nice demonstration that the PMF jump is actually quite sudden!")
     print(i)
-    if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'):
+    # if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'):
+    #     continue
+    if 'habituation' in prot:
         continue
     # eid = t['session']['id']
     df, _ = get_df(eid)
@@ -116,13 +144,18 @@ for i, (prot, eid) in enumerate(zip(protocols, eids)):
     if df is None:
         continue
 
+    print(prot)
+
     counti += 1
 
-    rt_data = np.zeros((len(df), 3))
-    rt_data[:, 0] = df['signed_contrast']
-    rt_data[:, 1] = df['rt']
-    rt_data[:, 2] = df['response']
-    pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb'))
+    if counti != 7:
+        continue
+
+    # rt_data = np.zeros((len(df), 3))
+    # rt_data[:, 0] = df['signed_contrast']
+    # rt_data[:, 1] = df['rt']
+    # rt_data[:, 2] = df['response']
+    # pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb'))
 
-    progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {} / 15".format(subject, counti))
-    progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {} / 15".format(subject, counti))
+    progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {}".format(subject, counti))
+    # progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {}".format(subject, counti))
diff --git a/brain_wide_download.py b/brain_wide_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e44f83a0eae9abc58b3e8970385ed8313e7a13d
--- /dev/null
+++ b/brain_wide_download.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+
+import seaborn
+
+from one.api import ONE
+from one.remote import aws
+from one.alf.files import add_uuid_string
+root_path = Path("/datadisk/FlatIron/aggregates")
+
+files_parquet = list(root_path.rglob('_ibl_subjectTrials.table.pqt'))
+
+# one.load(subject_id, content_type='subject', name='_ibl_subjectTrials.table.pqt')  # stray snippet: 'one' and 'subject_id' are not defined at this point
+
+
+if len(files_parquet) == 0:
+    one = ONE()
+    s3, bucket_name = aws.get_s3_from_alyx(alyx=one.alyx)
+    datasets = one.alyx.rest('datasets', 'list', name='_ibl_subjectTrials.table.pqt')
+    for dset in datasets:
+        rel_path = dset['file_records'][0]['relative_path']
+        aws_path = add_uuid_string('aggregates/' + rel_path, dset['url'][-36:])
+        aws.s3_download_file(aws_path, root_path.joinpath(rel_path), s3=s3, bucket_name=bucket_name)
+
+
+## %%
+import dask.dataframe as dd
+import pandas as pd
+import numpy as np
+
+time_bins = np.arange(0, 4000, 20)
+
+
+trials = dd.read_parquet(files_parquet)
+# the ITI is the start time of the next trial minus the end time of the current plus .5 secs
+trials['iti'] = trials['intervals_0'].shift(-1) - trials['intervals_1'] + .5
+trials = trials[trials['iti'] > 0]  # we need to remove the session jumps
+# here we select only the ephys protocol trials
+trials = trials[trials['task_protocol'].apply(lambda x: 'ephys' in x, meta=('task_protocol', 'bool'))]
+# aggregate the iti per time bins
+evolution = trials['iti'].groupby(trials['intervals_0'].map_partitions(pd.cut, time_bins)).agg(['mean', 'std'])
+
+
+
+import time
+# crunch crunch
+now = time.time()
+itis = trials['iti'].compute()
+tev = evolution.compute()
+print(time.time() - now)
+
+
+
+## %%
+import seaborn as sns
+import matplotlib.pyplot as plt
+sns.set_theme('paper')
+sns.set_palette('deep')
+fig, axs = plt.subplots(1, 3, figsize=(16, 5))
+axs[0].hist(itis, bins=1000, range=[0.5, 3], linewidth=0, alpha=0.8)
+axs[0].set(title='ITI distribution for all trials', xlabel='ITI (s)', ylabel='Count')
+
+x = time_bins[:-1] + 10
+axs[1].fill_between(x, y1=tev['mean'].values - tev['std'].values, y2=tev['mean'].values + tev['std'].values, alpha=0.4)
+axs[1].plot(x, tev['mean'].values, color='orange', linewidth=2)
+axs[1].set(title='ITI duration', ylabel='ITI (s)', xlabel='Time elapsed in session (s)', ylim=[0.75, 2], xlim=[0, 3600])
+
+trials_ = pd.read_parquet(files_parquet[0])
+trials_['iti'] = trials_['intervals_0'].shift(-1) - trials_['intervals_1'] + .5
+session0 = trials_[trials_['session'] == 'dc1b7422-cc16-4a37-b552-c31dccdddbce']
+axs[2].plot(session0['intervals_0'], session0['iti'], label='dc1b7422', linewidth=0.5)
+trials_ = pd.read_parquet(files_parquet[50])
+trials_['iti'] = trials_['intervals_0'].shift(-1) - trials_['intervals_1'] + .5
+session1 = trials_[trials_['session'] == 'd9d83a6a-8fb5-41eb-b4ce-7c6dd1716d72']
+axs[2].plot(session1['intervals_0'], session1['iti'], label='d9d83a6a', linewidth=0.5)
+axs[2].set(title='ITI duration', ylabel='ITI (s)', xlabel='Time elapsed in session (s)', ylim=[0.75, 2], xlim=[0, 3600])
+axs[2].legend()
+##
diff --git a/bwm_mice_stats.py b/bwm_mice_stats.py
index 14696ff16c7811a63c1ee7027a6cd81cae596cf7..0ea04bba2db48c10ea81c0c62e416d058c173f1f 100644
--- a/bwm_mice_stats.py
+++ b/bwm_mice_stats.py
@@ -69,6 +69,7 @@ subjects = ['NYU-11', 'NYU-12', 'NYU-21', 'NYU-27', 'NYU-30', 'NYU-37',
    	'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011',
    	'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025',
    	'CSH_ZAD_026', 'CSH_ZAD_029']
+quit()
 # why does CSHL058 not work?
 
 old_style = False
diff --git a/canonical_infos.json b/canonical_infos.json
index c0868bc5f9a444da61e2e4c5010c9c759000ff3c..12a3fa3aae799f418267abc5e98f113db7d5d248 100644
--- a/canonical_infos.json
+++ b/canonical_infos.json
@@ -1 +1 @@
-{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19, "ignore": [5, 6, 12, 3, 1, 11, 10]}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19, "ignore": [6, 7, 9, 8, 5, 13, 12, 11]}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19, "ignore": [11, 4, 15, 6, 8, 0]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19, "ignore": [6, 1, 15, 10, 12]}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19, "ignore": [8, 10, 12, 5, 9, 7, 4]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19, "ignore": [2, 13, 1, 9, 0, 4, 6, 5]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", 
"307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19, "ignore": [1, 8, 9, 2, 4, 15, 0, 7]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": 
["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", 
"576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": 
["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
+{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "ZFM-05236": {"seeds": ["404", "409", "401", "412", "406", "411", "410", "402", "408", "405", "415", "403", "414", "400", "413", "407"], "fit_nums": ["106", "111", "333", "253", "395", "76", "186", "192", "221", "957", "989", "612", "632", "304", "50", "493"], "chain_num": 14, "ignore": [3, 0, 8, 2, 6, 15, 14, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", 
"26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "ZFM-05245": {"seeds": ["400", "413", "414", "409", "411", "403", "405", "412", "406", "410", "407", "415", "401", "404", "402", "408"], "fit_nums": ["512", "765", "704", "17", "539", "449", "584", "987", "138", "932", "869", "313", "253", "540", "37", "634"], "chain_num": 11}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "ZFM-04019": {"seeds": ["413", "404", "408", "403", "415", "406", "414", "410", "402", "405", "411", "400", "401", "412", "409", "407"], "fit_nums": ["493", "302", "590", "232", "121", "938", "270", "999", "95", "175", "576", "795", "728", "244", "32", "177"], "chain_num": 14, "ignore": [11, 12, 6, 3, 5, 2, 15, 1]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", 
"410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, 
"GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", 
"478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index c61d4c7c659c1360646febd6889a35d401ebdc37..3e8a74ceeab4065b6c0f37e16911e44f010f5cf1 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -545,10 +545,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         cnas.append(c_n_a)
 
         mask = c_n_a[:, -1] == 0
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - all_conts[cont_mapping(- c_n_a[mask, 0] + c_n_a[mask, 1])]), 'o', c='b', ms=ms, label='Rightward', alpha=0.6)
 
         mask = c_n_a[:, -1] == 1
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6)
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - all_conts[cont_mapping(- c_n_a[mask, 0] + c_n_a[mask, 1])]), 'o', c='r', ms=ms, label='Leftward', alpha=0.6)
 
         plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26)
         # plt.yticks(*self.cont_ticks, size=22-2)
@@ -566,13 +566,15 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
             1.03, functions=(cont_to_belief, belief_to_cont))
         secax_y2.set_ylabel("Contrast", size=26)
         secax_y2.set_yticks([-1, -0.5, 0., 0.5, 1])
+        # secax_y2.set_yticks(np.arange(11) / 5 - 1)
+        # secax_y2.set_yticklabels([-1, '', '', '', '', 0., '', '', '', '', 1])
         secax_y2.spines['right'].set_bounds(-1, 1)
         secax_y2.tick_params(axis='y', which='major', labelsize=fs)
         plt.xlabel('Trial', size=28)
         sns.despine()
 
-        plt.xlim(left=250, right=450)
-        plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5))
+        # plt.xlim(left=250, right=450)
+        plt.legend(frameon=False, fontsize=22, ncol=2, loc=(0.7, 0.05))
         plt.tight_layout()
         if save:
             plt.savefig("dynamic_GLM_figures/all posterior and contrasts {}, sess {}{}.png".format(subject, seq_num, save_append), dpi=dpi)#, bbox_inches='tight')
@@ -685,6 +687,7 @@ def create_mode_indices(test, subject, fit_type):
     try:
         xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
     except Exception:
+        print('Doing PCA')
         ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim)
         xy = np.vstack([dimreduc[i] for i in range(dim)])
         from scipy.stats import gaussian_kde
@@ -1128,7 +1131,6 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                 counter += 1
             else:
                 counter = 0
-            if test.results[0].infos[c] + 1 < plot_until:
-                ax0.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0)
-                ax1.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0)
-                ax0.plot(test.results[0].infos[c] + 1 - 0.25, 0.6 + counter * 0.2, 'ko', ms=18)
+            ax0.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0)
+            ax1.axvline(test.results[0].infos[c] + 1, color='gray', zorder=0)
+            ax0.plot(test.results[0].infos[c] + 1 - 0.25, 0.6 + counter * 0.2, 'ko', ms=18)
@@ -1281,7 +1283,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         ax0.axhline(-0.5, c='k')
         ax0.axhline(0.5, c='k')
         # print(perf)
-        ax0.fill_between(range(1, min(1 + test.results[0].n_sessions, plot_until)), perf[:plot_until - 1] - 0.5, -0.5, color='k')
+        ax0.fill_between(range(1, 1 + test.results[0].n_sessions), perf - 0.5, -0.5, color='k')
 
     durs, state_types = state_type_durs(states_by_session, all_pmfs)
 
@@ -1352,7 +1354,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
     plt.tight_layout()
     if save:
-        print("saving with {} dpi".format(dpi))
+        # print("saving with {} dpi".format(dpi))
         plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi)
     if show:
         plt.show()
@@ -1361,7 +1363,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
     return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
 
-def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""):
+def compare_pmfs(test, states2compare, states_by_session, all_pmfs, title=""):
     """
        Take a set of states, and plot out their PMFs on all sessions on which they occur.
        See how different they really are.
@@ -1395,6 +1397,31 @@ def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, a
     plt.show()
 
 
+def compare_params(test, session, state_sets, indices, states2compare, title=""):
+    """
+       Take a set of states and return their regression parameters for a specific session,
+       to see how different they really are.
+
+       Unfinished: plotting is not implemented yet; currently returns the pmf_weights of the first requested state.
+
+       Takes state_sets and mode_indices as input.
+    """
+    colors = ['blue', 'orange', 'green', 'black', 'red']
+    assert len(states2compare) <= len(colors)
+    # subtract 1 to get internal numbering
+    states2compare = [s - 1 for s in states2compare]
+    # transform desired states into the actual numbering, before ordering by bias
+    states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare]
+
+    for state, trials in enumerate(state_sets):
+        if state not in states2compare:
+            continue
+        session_js, pmfs, pmf_weights = state_weights(test, trials, indices)
+        return pmf_weights
+    plt.tight_layout()
+    plt.show()
+
+
 def compare_weights(test, state_sets, indices, states2compare, states_by_session, title=""):
     """
        Take a set of states, and plot out their weights on all sessions on which they occur.
@@ -1495,7 +1522,7 @@ if __name__ == "__main__":
     elif fit_type == 'prebias':
         loading_info = json.load(open("canonical_infos.json", 'r'))
         r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
-    subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']  # list(loading_info.keys())
+    subjects = ['ZFM-04019', 'ZFM-05236']  # list(loading_info.keys())
 
     r_hats = {}
 
@@ -1544,6 +1571,24 @@ def dist_helper(dist_matrix, state_hists, inds):
         dist_matrix[i, j] = np.sum(np.abs(state_hists[i] - state_hists[j]))
     return dist_matrix
 
+def type_2_appearance(states, pmfs):
+    # How does type 2 appear, is it a new state or continuation of a type 1?
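+    # Returns True if the state showing the first type-2 PMF appears for the first time in that
+    # session, False if an already-active state changed into type 2, and 2 if more than one
+    # state qualifies in the same session (returns None if no type-2 state is found).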
+    state_counter = {}
+    found_states = 0
+    for session_counter in range(states.shape[1]):
+        for state, pmf in zip(range(states.shape[0]), pmfs):
+            if states[state, session_counter]:
+                if state not in state_counter:
+                    state_counter[state] = -1
+                state_counter[state] += 1
+                if pmf_type(pmf[1][state_counter[state]][pmf[0]]) == 1:
+                    found_states += 1
+                    new = state_counter[state] == 0
+        if found_states > 1:
+            print("Problem")
+            return 2
+        if found_states == 1:
+            return new
 
 def state_type_durs(states, pmfs):
     # Takes states and pmfs, first creates an array of how active each type is in each session, then computes the number of sessions each type lasts.
@@ -1557,7 +1602,6 @@ def state_type_durs(states, pmfs):
                 pmf_counter += 1
                 state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
 
-    print(state_types)
     if np.any(state_types[1] > 0.5):
         durs = (np.where(state_types[1] > 0.5)[0][0],
                 np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
@@ -1662,6 +1706,36 @@ def compare_performance(cnas, contrasts=(1, 1), title=""):
         plt.savefig(title)
     plt.show()
 
+def write_results(test, state_sets, indices, consistencies=None):
+    n = test.results[0].n_sessions
+    trial_counter = 0
+    state_dict = {state: {'sessions': [], 'trials': [], 'pmfs': []} for state in range(len(state_sets))}
+
+    for state, trials in enumerate(state_sets):
+
+        session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices)
+        state_dict[state]['sessions'] = session_js
+        state_dict[state]['pmfs'] = pmfs
+
+
+    for seq_num in range(n):
+        for state, trials in enumerate(state_sets):
+            relevant_trials = trials[np.logical_and(trial_counter <= trials, trials < trial_counter + len(test.results[0].models[0].stateseqs[seq_num]))]
+            active_trials = np.zeros(len(test.results[0].models[0].stateseqs[seq_num]))
+
+            if consistencies is None:
+                active_trials[relevant_trials - trial_counter] = 1
+            else:
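+                # consistencies[i, j] (presumably) measures how often trials i and j are assigned
+                # the same state across posterior samples; assuming it is normalized so that
+                # self-consistency is 1, each trial's activity below becomes its average
+                # consistency with the other trials of this state, a soft membership in [0, 1]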
+                active_trials[relevant_trials - trial_counter] = np.sum(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0)
+                active_trials[relevant_trials - trial_counter] -= 1
+                active_trials[relevant_trials - trial_counter] = active_trials[relevant_trials - trial_counter] / (trials.shape[0] - 1)
+
+            if np.sum(active_trials) > 0:
+                state_dict[state]['trials'].append(active_trials)
+
+        trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
+    return state_dict
+
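+# A minimal sketch of consuming the output of write_results, assuming a dump like the
+# "state_dict_{subject}" pickles created in the __main__ block below:
+# state_dict = pickle.load(open("state_dict_KS014", 'rb'))
+# for state, info in state_dict.items():
+#     print(state, info['sessions'], len(info['trials']), len(info['pmfs']))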
 
 if __name__ == "__main__":
 
@@ -1675,7 +1749,7 @@ if __name__ == "__main__":
     no_good_pcas = ['NYU-06', 'SWC_023']
     subjects = list(loading_info.keys())
     # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']
-    # subjects = ['KS014']
+    subjects = ['KS021']
 
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
@@ -1717,11 +1791,13 @@ if __name__ == "__main__":
     all_bias_flips = []
     all_pmf_weights = []
 
-    temp_counter = 0
-
+    new_counter, transform_counter = 0, 0
     state_types_interpolation = np.zeros((3, 150))
     all_state_types = []
 
+    state_nums_5 = []
+    state_nums_10 = []
+
     for subject in subjects:
 
         if subject.startswith('GLM_Sim_') or subject == 'ibl_witten_18':
@@ -1748,22 +1824,64 @@ if __name__ == "__main__":
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
             # training overview
-            # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2)
-            # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7)
-            # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13)
+            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1)
+            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2)
+            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7)
+            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13)
             states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
-            all_pmf_weights += pmf_weights
-            continue
+            # state_nums_5.append((states > 0.05).sum(0))
+            # state_nums_10.append((states > 0.1).sum(0))
+            # continue
+            a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5])
+            compare_pmfs(test, [3, 5], states, pmfs, title="{} convergence pmf".format(subject))
+            quit()
+            new = type_2_appearance(states, pmfs)
+
+            if new == 2:
+                print('____________________________')
+                print(subject)
+                print('____________________________')
+            if new == 1:
+                new_counter += 1
+            if new == 0:
+                transform_counter += 1
+            print(new_counter, transform_counter)
 
             consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
             consistencies /= consistencies[0, 0]
-            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
+            temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False)
+            continue
 
-            all_state_types.append(state_types)
+            state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices)
+            pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb'))
+            quit()
+            # abs_state_durs.append(durs)
+            # continue
+            # all_pmf_weights += pmf_weights
+            # all_state_types.append(state_types)
+            #
             # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
             # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
             # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
-            # temp_counter += 1
+            #
+            # all_first_pmfs_typeless[subject] = []
+            # for pmf in pmfs:
+            #     all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
+            #     all_pmfs.append(pmf)
+            #
+            # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
+            # for pmf in changing_pmfs:
+            #     if type(pmf[0]) == int:
+            #         continue
+            #     all_changing_pmf_names.append(subject)
+            #     all_changing_pmfs.append(pmf)
+            #
+            # all_first_pmfs[subject] = first_pmfs
+            # continue
+            #
+            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            # consistencies /= consistencies[0, 0]
+            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
 
 
             # b_flips = bias_flips(states, pmfs, durs)
@@ -1776,22 +1894,12 @@ if __name__ == "__main__":
             # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject))
             # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject))
             # quit()
-            # all_first_pmfs_typeless[subject] = []
-            # for pmf in pmfs:
-            #     all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
             # all_intros.append(undiv_intros)
             # all_intros_div.append(intros_by_type)
             # if states_per_type != []:
             #     all_states_per_type.append(states_per_type)
             #
             # intros_by_type_sum += intros_by_type
-            # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
-            # for pmf in changing_pmfs:
-            #     if type(pmf[0]) == int:
-            #         continue
-            #     all_changing_pmf_names.append(subject)
-            #     all_changing_pmfs.append(pmf)
-            # all_first_pmfs[subject] = first_pmfs
             # for pmf in pmfs:
             #     all_pmfs.append(pmf)
             #     for p in pmf[1]:
@@ -1817,7 +1925,6 @@ if __name__ == "__main__":
             # compare_performance(cnas, (0, 0.848))
 
             # duration of different state types (and also percentage of type activities)
-            # abs_state_durs.append(durs)
             # continue
             # simplex_durs = np.array(durs).reshape(1, 3)
             # print(simplex_durs / np.sum(simplex_durs))
@@ -1860,6 +1967,7 @@ if __name__ == "__main__":
             # plt.savefig("temp")
             # plt.close()
 
+            print('Computing sub result')
             create_mode_indices(test, subject, fit_type)
             state_set_and_plot(test, 'first_', subject, fit_type)
             print("second mode?")
@@ -1868,8 +1976,8 @@ if __name__ == "__main__":
 
         except FileNotFoundError as e:
             print(e)
-            continue
             print('no canonical result')
+            continue
             print(r_hats[subject])
             if r_hats[subject] >= 1.05:
                 print("Skipping")
@@ -1915,7 +2023,7 @@ if __name__ == "__main__":
             test.r_hat_and_ess(alpha_func, True)
 
 
-    # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb'))
+    # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb'))
     # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
     # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb'))
     # pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb'))
@@ -1927,13 +2035,12 @@ if __name__ == "__main__":
     # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb'))
     # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb'))
     # pickle.dump(all_state_types, open("all_state_types.p", 'wb'))
-    pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb'))
-    quit()
-
+    # pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb'))
+    # pickle.dump(state_types_interpolation, open("state_types_interpolation.p", 'wb'))
+    # abs_state_durs = np.array(abs_state_durs)
+    # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
 
     if True:
-        abs_state_durs = np.array(abs_state_durs)
-        # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
         abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
 
         print("Correlations")
diff --git a/dynamic_GLMiHMM_fit (copy).py b/dynamic_GLMiHMM_fit (copy).py
new file mode 100644
index 0000000000000000000000000000000000000000..6ccb8afcc173f24b0fd7228f668917b4625a52f4
--- /dev/null
+++ b/dynamic_GLMiHMM_fit (copy).py	
@@ -0,0 +1,407 @@
+"""Start a (series) of iHMM fit(s)."""
+import os
+os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=4
+os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=4
+os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=6
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=4
+os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=6
+import pyhsmm
+import pyhsmm.basic.distributions as distributions
+import copy
+import warnings
+import pickle
+import time
+from scipy.special import digamma
+import os.path
+import numpy as np
+from itertools import product
+import json
+import sys
+
+
+def crp_expec(n, theta):
+    """
+    Return expected number of tables after n customers, given concentration theta.
+
+    From Wikipedia
+    """
+    return theta * (digamma(theta + n) - digamma(theta))
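+# e.g. crp_expec(100, 1) is roughly 5.19, i.e. with concentration 1 about five distinct states
+# are expected after 100 observations (illustrative value only, not used in the fit below)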
+
+
+def eleven2nine(x):
+    """Map from 11 possible contrasts to 9, for the non-training phases.
+
+    Contrast indices 1 and 9 can't appear there, so shift the other indices down accordingly.
+
+    E.g.:
+    [2, 0, 4, 8, 10] -> [1, 0, 3, 7, 8]
+    """
+    assert 1 not in x and 9 not in x
+    x[x > 9] -= 1
+    x[x > 1] -= 1
+    return x
+
+
+def eval_cross_val(models, data, unmasked_data, n_all_states):
+    """Eval cross_val."""
+    lls = np.zeros(len(models))
+    cross_val_n = 0
+    for sess_time, (d, full_data) in enumerate(zip(data, unmasked_data)):
+        held_out = np.isnan(d[:, -1])
+        cross_val_n += held_out.sum()
+        d[:, -1][held_out] = full_data[:, -1][held_out]
+        for i, m in enumerate(models):
+            for s in range(n_all_states):
+                mask = np.logical_and(held_out, m.stateseqs[sess_time] == s)
+                if mask.sum() > 0:
+                    ll = m.obs_distns[s].log_likelihood(d[mask], sess_time)
+                    lls[i] += np.sum(ll)
+    lls /= cross_val_n
+    ll_mean = np.mean(lls[-1000:])
+    return lls, ll_mean
+
+
+# following Nick Roy's contrasts: the tanh transformation of the contrasts x has a
+# free parameter p, which we set to p = 5 throughout the paper: x_p = tanh(p * x) / tanh(p).
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+cont_mapping = np.vectorize(num_to_contrast.get)
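+# A small sanity-check sketch (assumption: raw contrasts 0.5, 0.25, 0.125 and p = 5) showing
+# where the transformed contrast values above come from:
+# p = 5
+# raw = np.array([0.5, 0.25, 0.125])
+# print(np.round(np.tanh(p * raw) / np.tanh(p), 3))  # -> [0.987 0.848 0.555]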
+
+data_folder = 'session_data_test'
+old_style = False
+if old_style:
+    print("Warning, data can have splits")
+    print("Sure you want to use old data?")
+    temp = input()
+    if temp:
+        data_folder = 'session_data_old'
+    else:
+        quit()
+
+# available:
+# not great: ibl_witten_18
+
+# test subjects:
+subjects = ['NYU-21', 'NYU-27', 'NYU-30', 'NYU-37', 'NYU-39', 'NYU-40', 'NYU-45', 'NYU-46', 'NYU-47', 'NYU-48', 'CSHL045', 'CSHL047', 'CSHL052', 'CSHL053', 'CSHL055', 'CSHL058', 'CSHL060', 'UCLA005', 'UCLA006', 'UCLA011', 'UCLA012', 'UCLA014', 'UCLA015', 'UCLA017', 'UCLA033', 'UCLA035', 'KS042', 'KS043', 'KS044', 'KS045', 'KS046', 'KS051', 'KS052', 'KS055', 'KS084', 'KS086', 'KS091', 'KS094', 'KS096', 'DY_008', 'DY_009', 'DY_010', 'DY_011', 'DY_013', 'DY_014', 'DY_016', 'DY_018', 'DY_020', 'SWC_042', 'SWC_043', 'SWC_060', 'SWC_061', 'SWC_066', 'ZFM-01576', 'ZFM-01577', 'ZFM-01592', 'ZFM-01935', 'ZFM-01936', 'ZFM-01937', 'ZFM-02368', 'ZFM-02369', 'ZFM-02370', 'ZFM-02372', 'ZFM-02373', 'ZM_1898', 'ZM_2240', 'ZM_2241', 'ZM_2245', 'SWC_038', 'SWC_039', 'SWC_052', 'SWC_053', 'SWC_054', 'SWC_058', 'SWC_065', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26', 'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_019', 'CSH_ZAD_024', 'CSH_ZAD_029']
+
+
+# subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]]
+# (0.03, 0.3, 5, 'contR', 'contL', 'prevA', 'bias', 1, 0.1):
+
+print(subjects)
+
+cv_nums = [400 + int(sys.argv[1]) % 16, 400 + (2 * int(sys.argv[1])) % 16]
+subjects = [subjects[int(sys.argv[1]) // 16], subjects[(2 * int(sys.argv[1])) // 16]]
+
+print(cv_nums)
+print(subjects)
+
+for loop_count_i, (s, cv_num) in enumerate(zip(subjects, cv_nums)):
+    # if loop_count_i > 8:
+    # if loop_count_i <= 8 or loop_count_i > 16:
+    # if loop_count_i <= 16:
+    #     continue
+    params = {}
+    params['subject'] = s
+    params['cross_val_num'] = cv_num
+    params['fit_variance'] = 0.03
+    params['jumplimit'] = 1
+    all_regressors = ['contR', 'contL', 'cont', 'prevA', 'weighted_prevA', 'WSLS', 'bias']
+    params['regressors'] = [all_regressors[i] for i in [0, 1, 3, 6]]
+
+    # default (non-iteration) settings:
+    params['fit_type'] = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
+    # params['fit_variance'] = [0.0005, 0.002, 0.008, 0.02, 0.06, 0.1, 0.3, 0.6, 1., 2.4, 10, 16, 30, 'uniform'][6]
+    if 'prevA' in params['regressors'] or 'weighted_prevA' in params['regressors']:
+        params['exp_decay'], params['exp_length'] = 0.3, 5
+        params['exp_filter'] = np.exp(- params['exp_decay'] * np.arange(params['exp_length']))
+        params['exp_filter'] /= params['exp_filter'].sum()
+        print(params['exp_filter'])
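+        # for exp_decay = 0.3 and exp_length = 5 this normalized filter is roughly
+        # [0.334, 0.247, 0.183, 0.136, 0.100], so the most recent previous answer is weighted
+        # about three times more than the one four trials earlier (illustrative values)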
+    params['dur'] = 'yes'
+    params['obs_dur'] = ['glm', 'cat'][0]
+    # more obscure params:
+    params['gamma'] = None  # 0.005
+    params['alpha'] = None  # 1
+    if params['gamma'] is not None:
+        print("_______________________")
+        print("Warning, gamma is fixed")
+        print("_______________________")
+    params['gamma_a_0'] = 0.001
+    params['gamma_b_0'] = 1000
+    params['init_var'] = 8
+    params['init_mean'] = np.zeros(len(params['regressors']))
+    # normal:
+    r_support = np.cumsum(np.arange(5, 100, 5))
+    r_support = np.arange(5, 705, 4)
+    params['dur_params'] = dict(r_support=r_support,
+                                r_probs=np.ones(len(r_support))/len(r_support), alpha_0=1, beta_0=1)
+    # params['dur_params'] = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
+    #                             r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
+    # params['dur_params'] = dict(r_support=np.arange(1, 251),
+    #                             r_probs=np.ones(250)/250., alpha_0=1, beta_0=1)
+    params['alpha_a_0'] = 0.1
+    params['alpha_b_0'] = 10
+    # trying a smaller value here, should lower the appearance of ephemeral new states on session bounds
+    # hope this doesn't make real new states at session bound less likely, or hurt mixing...
+    params['init_state_concentration'] = 3
+    # cat params
+    params['conditioned_on'] = 'nothing'
+
+    params['cross_val'] = False
+    params['cross_val_fold'] = 10
+    params['CROSS_VAL_SEED'] = 4  # Do not change this, it's 4
+
+    params['seed'] = 100 + params['cross_val_num']
+
+    params['n_states'] = 15
+    params['n_samples'] = 6 if params['obs_dur'] == 'glm' else 4000
+    if params['cross_val']:
+        params['n_samples'] = 4000
+    if s.startswith("GLM_Sim"):
+        print("reduced sample size")
+        params['n_samples'] = 12000
+
+    print(params['n_samples'])
+    # now actual fit:
+    # new start names: uniform_start_, bias_fraction_, small_gamma_, high_init_, non_semi_, non_semi_normal_init_, correct_sol_, correct_sol_semi_
+    while True:
+        folder = "./dynamic_GLMiHMM_crossvals/"
+        rand_id = np.random.randint(1000)
+        if params['cross_val']:
+            id = "{}_crossval_{}_{}_var_{}_{}_{}".format(params['subject'], params['cross_val_num'], params['fit_type'],
+                                                         params['fit_variance'], params['seed'], rand_id)
+        else:
+            id = "{}_fittype_{}_var_{}_{}_{}".format(params['subject'], params['fit_type'],
+                                                     params['fit_variance'], params['seed'], rand_id)
+        if not os.path.isfile(folder + id + '_0.p'):
+            break
+    # dump params as a placeholder file to reserve this rand_id
+    pickle.dump(params, open(folder + id + '_0.p', 'wb'))
+    if params['obs_dur'] == 'glm':
+        print(params['regressors'])
+    else:
+        print('using categoricals')
+    print(id)
+    params['file_name'] = folder + id
+    np.random.seed(params['seed'])
+
+    info_dict = pickle.load(open("./{}/{}_info_dict.p".format(data_folder, params['subject']), "rb"))
+    # Determine session numbers
+    if params['fit_type'] == 'prebias':
+        till_session = info_dict['bias_start']
+    elif params['fit_type'] == 'bias' or params['fit_type'] == 'zoe_style' or params['fit_type'] == 'all':
+        till_session = info_dict['n_sessions']
+    elif params['fit_type'] == 'prebias_plus':
+        till_session = min(info_dict['bias_start'] + 6, info_dict['n_sessions'])  # 6 here will actually turn into 7 later
+
+    from_session = info_dict['bias_start'] if params['fit_type'] in ['bias', 'zoe_style'] else 0
+
+    models = []
+
+    if params['obs_dur'] == 'glm':
+        n_inputs = len(params['regressors'])
+        T = till_session - from_session + (params['fit_type'] != 'prebias')
+        obs_hypparams = {'n_regressors': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
+                         'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
+        obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
+    else:
+        n_inputs = 9 if params['fit_type'] == 'bias' else 11
+        obs_hypparams = {'n_regressors': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
+                         'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
+        obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
+
+    dur_distns = [distributions.NegativeBinomialIntegerR2Duration(**params['dur_params']) for state in range(params['n_states'])]
+
+    if params['dur'] == 'yes':
+        if params['gamma'] is None:
+            posteriormodel = pyhsmm.models.WeakLimitHDPHSMM(
+                    # https://math.stackexchange.com/questions/449234/vague-gamma-prior
+                    alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'],  # TODO: gamma vs alpha? gamma steers state number
+                    gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'],
+                    init_state_concentration=params['init_state_concentration'],
+                    obs_distns=obs_distns,
+                    dur_distns=dur_distns,
+                    var_prior=params['fit_variance'])  # TODO: I don't think this does anything
+        else:
+            posteriormodel = pyhsmm.models.WeakLimitHDPHSMM(
+                    # https://math.stackexchange.com/questions/449234/vague-gamma-prior
+                    alpha=params['alpha'], # TODO: gamma vs alpha? gamma steers state number
+                    gamma=params['gamma'],
+                    init_state_concentration=params['init_state_concentration'],
+                    obs_distns=obs_distns,
+                    dur_distns=dur_distns,
+                    var_prior=params['fit_variance'])  # TODO: I don't think this does anything
+    else:
+        if params['gamma'] is None:
+            posteriormodel = pyhsmm.models.WeakLimitHDPHMM(
+                    alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'],  # TODO: gamma vs alpha? gamma steers state number
+                    gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'],
+                    init_state_concentration=params['init_state_concentration'],
+                    obs_distns=obs_distns,
+                    var_prior=params['fit_variance'])  # TODO: I don't think this does anything
+        else:
+            posteriormodel = pyhsmm.models.WeakLimitHDPHMM(
+                    alpha=params['alpha'],  # TODO: gamma vs alpha? gamma steers state number
+                    gamma=params['gamma'],
+                    init_state_concentration=params['init_state_concentration'],
+                    obs_distns=obs_distns,
+                    var_prior=params['fit_variance'])  # TODO: I don't think this does anything
+
+    print(from_session, till_session + (params['fit_type'] != 'prebias'))
+
+    if params['cross_val']:
+        rng = np.random.RandomState(params['CROSS_VAL_SEED'])
+
+    data_save = []
+    for j in range(from_session, till_session + (params['fit_type'] != 'prebias')):
+        try:
+            data = pickle.load(open("./{}/{}_fit_info_{}.p".format(data_folder, params['subject'], j), "rb"))
+        except FileNotFoundError:
+            continue
+        if data.shape[0] == 0:
+            print("meh, skipped session")
+            continue
+
+        # if j == 15:
+        #     import matplotlib.pyplot as plt
+        #     for i in [0, 2, 3,4,5,6,7,8,10]:
+        #         plt.plot(i, data[data[:, 0] == i, 1].mean(), 'ko')
+        #     plt.show()
+
+        if params['obs_dur'] == 'glm':
+            for i in range(data.shape[0]):
+                data[i, 0] = num_to_contrast[data[i, 0]]
+            mask = data[:, 1] != 1
+            mask[0] = False
+            if params['fit_type'] == 'zoe_style':
+                mask[90:] = False
+            mega_data = np.empty((np.sum(mask), n_inputs + 1))
+
+            for i, reg in enumerate(params['regressors']):
+                # positive numbers are contrast on the right
+                if reg == 'contR':
+                    mega_data[:, i] = np.maximum(data[mask, 0], 0)
+                elif reg == 'contL':
+                    mega_data[:, i] = np.abs(np.minimum(data[mask, 0], 0))
+                elif reg == 'cont':
+                    mega_data[:, i] = data[mask, 0]
+                elif reg == 'prevA':
+                    # prev_ans = data[:, 1].copy()
+                    new_prev_ans = data[:, 1].copy()
+                    # prev_ans[1:] = prev_ans[:-1]
+                    # prev_ans -= 1
+                    new_prev_ans -= 1
+                    new_prev_ans = np.convolve(np.append(0, new_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])]
+                    mega_data[:, i] = new_prev_ans[mask]
+                elif reg == 'weighted_prevA':
+                    prev_ans = data[:, 1].copy()
+                    prev_ans -= 1
+                    # weigh the tendency by how clear the previous contrast was
+                    weighted_prev_ans = data[:, 0] + prev_ans
+                    weighted_prev_ans = np.convolve(np.append(0, weighted_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])]
+                    mega_data[:, i] = weighted_prev_ans[mask]
+                elif reg == 'WSLS':
+                    side_info = pickle.load(open("./{}/{}_side_info_{}.p".format(data_folder, params['subject'], j), "rb"))
+                    prev_reward = side_info[:, 1]
+                    prev_reward[1:] = prev_reward[:-1]
+                    prev_ans = data[:, 1].copy()
+                    prev_ans[1:] = prev_ans[:-1] - 1
+                    mega_data[:, i] = prev_ans[mask]
+                    mega_data[prev_reward[mask] == 0, i] *= -1
+                elif reg == 'bias':
+                    # have bias active only if contrasts further from 1 are in the session
+                    # if len(np.unique(data[mask, 0])) > 5:
+                    #     print("original 1")
+                    #     mega_data[:, i] = 1
+                    # else:
+                    #     print("original 0")
+                    #     mega_data[:, i] = 0
+
+                    # have bias active only if contrasts further from 1 are in the session, new version
+                    # if bias_active:
+                    #     mega_data[:, i] = 1
+                    # elif len(np.unique(data[mask, 0])) > 5:
+                    #     if np.where(np.abs(data[mask, 0]) == 0.848)[0][0] / data.shape[0] < 0.5:
+                    #         mega_data[:, i] = 1
+                    #     else:
+                    #         mega_data[:, i] = 0
+                    #     bias_active = True
+                    # else:
+                    #     mega_data[:, i] = 0
+
+                    # bias is now always active
+                    mega_data[:, i] = 1
+
+                mega_data[:, -1] = data[mask, 1] / 2
+        elif params['obs_dur'] == 'cat':
+            mask = data[:, 1] != 1
+            mask[0] = False
+            data = data[:, [0, 1]]
+            data[:, 1] = data[:, 1] / 2
+            mega_data = data[mask]
+
+        data_save.append(mega_data.copy())
+
+        if params['cross_val']:
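+        # every trial gets a fold label 0..cross_val_fold-1 (shuffled with the fixed
+        # CROSS_VAL_SEED); responses of the trials in fold params['cross_val_num'] are set
+        # to NaN so that eval_cross_val can score them as held-out data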
+            test_sets = np.tile(np.arange(params['cross_val_fold']), mega_data.shape[0] // params['cross_val_fold'] + 1)[:mega_data.shape[0]]
+            rng.shuffle(test_sets)
+            mega_data[:, -1][test_sets == params['cross_val_num']] = None
+
+        posteriormodel.add_data(mega_data)
+
+    # for d in posteriormodel.datas:
+    #     print(d.shape)
+
+    if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
+        pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
+
+    # states_solution = pickle.load(open("states_{}_{}_condition_{}_{}.p".format('DY_013', 'all', 'nothing', '0_01'), 'rb'))  # todo: remove!
+    time_save = time.time()
+    likes = np.zeros(params['n_samples'])
+    with warnings.catch_warnings():  # ignore the scipy warning
+        warnings.simplefilter("ignore")
+        for j in range(params['n_samples']):
+
+            if j % 400 == 0 or j == 3:
+                print(j)
+
+            posteriormodel.resample_model()
+
+            likes[j] = posteriormodel.log_likelihood()
+            model_save = copy.deepcopy(posteriormodel)
+            if j != params['n_samples'] - 1 and j != 0 and j % 2000 != 1:
+                # To save on memory:
+                model_save.delete_data()
+                model_save.delete_obs_data()
+                model_save.delete_dur_data()
+            models.append(model_save)
+
+            # save something in case of crash
+            if j % 400 == 0 and j > 0:
+                if params['n_samples'] <= 4000:
+                    pickle.dump(models, open(folder + id + '.p', 'wb'))
+                else:
+                    pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
+                    if j % 4000 == 0:
+                        models = []
+    print(time.time() - time_save)
+
+    if params['cross_val']:
+        lls, lls_mean = eval_cross_val(models, posteriormodel.datas, data_save, n_all_states=params['n_states'])
+        params['cross_val_preds'] = lls
+        params['cross_val_preds'] = params['cross_val_preds'].tolist()
+
+    print(id)
+    if 'exp_filter' in params:
+        params['exp_filter'] = params['exp_filter'].tolist()
+    params['dur_params']['r_support'] = params['dur_params']['r_support'].tolist()
+    params['dur_params']['r_probs'] = params['dur_params']['r_probs'].tolist()
+    params['ll'] = likes.tolist()
+    params['init_mean'] = params['init_mean'].tolist()
+    if params['cross_val']:
+        json.dump(params, open(folder + "infos_new/" + '{}_{}_cvll_{}_{}_{}_{}_{}.json'.format(params['subject'], params['cross_val_num'], str(np.round(lls_mean, 3)).replace('.', '_'),
+                                                                                               params['fit_type'], params['fit_variance'], params['seed'], rand_id), 'w'))
+    else:
+        json.dump(params, open(folder + "infos_new/" + '{}_{}_{}_{}_{}.json'.format(params['subject'], params['fit_type'],
+                                                                                    params['fit_variance'], params['seed'], rand_id), 'w'))
+    pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index dbbe447a55e51e03d20aef3e840ccd600d76f653..bf1d046da4a345335014298a708c541c5c6d4ec3 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -67,7 +67,7 @@ contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.
 num_to_contrast = {v: k for k, v in contrast_to_num.items()}
 cont_mapping = np.vectorize(num_to_contrast.get)
 
-data_folder = 'session_data'
+data_folder = 'session_data_test'
 old_style = False
 if old_style:
     print("Warning, data can have splits")
@@ -86,7 +86,7 @@ subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19',
             'CSH_ZAD_017', 'CSH_ZAD_025', 'CSH_ZAD_026', 'CSHL049', 'CSHL051', 'CSHL061']
 
 # test subjects:
-sim_subjects = ['GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_15']
+subjects = ['KS014']
 # subjects = ['KS021', 'KS016', 'ibl_witten_16', 'SWC_022', 'KS003', 'CSHL054', 'ZM_3003', 'KS015', 'ibl_witten_13', 'CSHL059', 'CSH_ZAD_022', 'CSHL_007', 'CSHL062', 'NYU-06', 'KS014', 'ibl_witten_14', 'SWC_023']
 
 # subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]]
@@ -314,30 +314,10 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
                     mega_data[:, i] = prev_ans[mask]
                     mega_data[prev_reward[mask] == 0, i] *= -1
                 elif reg == 'bias':
-                    # have bias active only if contrasts further from 1 are in the session
-                    # if len(np.unique(data[mask, 0])) > 5:
-                    #     print("original 1")
-                    #     mega_data[:, i] = 1
-                    # else:
-                    #     print("original 0")
-                    #     mega_data[:, i] = 0
-
-                    # have bias active only if contrasts further from 1 are in the session, new version
-                    # if bias_active:
-                    #     mega_data[:, i] = 1
-                    # elif len(np.unique(data[mask, 0])) > 5:
-                    #     if np.where(np.abs(data[mask, 0]) == 0.848)[0][0] / data.shape[0] < 0.5:
-                    #         mega_data[:, i] = 1
-                    #     else:
-                    #         mega_data[:, i] = 0
-                    #     bias_active = True
-                    # else:
-                    #     mega_data[:, i] = 0
-
                     # bias is now always active
                     mega_data[:, i] = 1
 
-                mega_data[:, -1] = data[mask, 1] / 2
+                mega_data[:, -1] = data[mask, 1] / 2 # 0 is rightwards, and 1 is leftwards, because the original data is weird
         elif params['obs_dur'] == 'cat':
             mask = data[:, 1] != 1
             mask[0] = False
@@ -357,9 +337,9 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
     # for d in posteriormodel.datas:
     #     print(d.shape)
 
-    if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
-        pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
-
+    # if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
+    pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
+    quit()
     # states_solution = pickle.load(open("states_{}_{}_condition_{}_{}.p".format('DY_013', 'all', 'nothing', '0_01'), 'rb'))  # todo: remove!
     time_save = time.time()
     likes = np.zeros(params['n_samples'])
diff --git a/index_mice.py b/index_mice.py
index e821de0e2e86ea42e5859cc88e8c545179cea4e8..add07bdaf3514b88e6aed44d1186c629072f3f64 100644
--- a/index_mice.py
+++ b/index_mice.py
@@ -20,14 +20,6 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
     fit_num = result.group(5)
     chain_num = result.group(6)
 
-    # print()
-    # print(filename)
-    # print(subject)
-    # print(fit_type)
-    # print(seed)
-    # print(fit_num)
-    # print(chain_num)
-    # continue
     if fit_type == 'prebias':
         local_dict = prebias_subinfo
     elif fit_type == 'bias':
diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py
index 00611bcc7e1200cbe4a76a79ebde4510408bea3b..059851501e92abf00fe662b44d35c296ea153c09 100644
--- a/pmf_weight_analysis.py
+++ b/pmf_weight_analysis.py
@@ -3,10 +3,148 @@ import matplotlib.pyplot as plt
 import pickle
 from scipy.stats import gaussian_kde
 from analysis_pmf import pmf_type, type2color
+from mpl_toolkits import mplot3d
+
+performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0])
+reduced_points = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=bool)
+
+folder = "./temp_reward_analysis/"
+
+def pmf_to_perf(pmf):
+    # determine performance of a pmf, but only on the omnipresent strongest contrasts
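+    # i.e. (reading pmf as P(rightwards)) the mean probability of a correct response on the two
+    # strongest contrasts of each side: |pmf - 1| on the leftmost entries, |pmf| on the rightmost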
+    return np.mean(np.abs(performance_points[reduced_points] + pmf[reduced_points]))
+
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+
+def weights_to_pmf(weights, with_bias=1):
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))
+
+def pmf_type_rew(weights):
+    rew = pmf_to_perf(weights_to_pmf(weights))
+    if rew < 0.6:
+        return 0
+    elif rew < 0.7827:
+        return 1
+    else:
+        return 2
 
 # all pmf weights
 apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb')))
 
+colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw]
+colors_rew = [type2color[pmf_type_rew(x)] for x in apw]
+
+
+if True:
+    for i, weights in enumerate(apw):
+        if pmf_type(weights_to_pmf(weights)) != pmf_type_rew(weights):
+            plt.plot(weights_to_pmf(weights))
+            plt.ylim(0, 1)
+            plt.title("Classic says {}, reward says {}, ({})".format(1 + pmf_type(weights_to_pmf(weights)), 1 + pmf_type_rew(weights), pmf_to_perf(weights_to_pmf(weights))))
+            print(weights_to_pmf(weights))
+            plt.tight_layout()
+            plt.savefig(folder + "divergent classification {}".format(i))
+            plt.close()
+
+
+type1_rews = []
+type2_rews = []
+type3_rews = []
+all_rews = []
+for weights in apw:
+    type = pmf_type(weights_to_pmf(weights))
+    all_rews.append(pmf_to_perf(weights_to_pmf(weights)))
+    if type == 0:
+        if pmf_to_perf(weights_to_pmf(weights)) > 0.8:
+            plt.plot(weights_to_pmf(weights))
+            plt.show()
+        type1_rews.append(pmf_to_perf(weights_to_pmf(weights)))
+    elif type == 1:
+        type2_rews.append(pmf_to_perf(weights_to_pmf(weights)))
+    elif type == 2:
+        type3_rews.append(pmf_to_perf(weights_to_pmf(weights)))
+
+type1_rews, type2_rews, type3_rews = np.array(type1_rews), np.array(type2_rews), np.array(type3_rews)
+
+bound1 = np.linspace(0.55, 0.65, 100)
+bound2 = np.linspace(0.75, 0.85, 100)
+
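+# grid-search the two reward-rate boundaries: pick the type-1/2 and type-2/3 cutoffs that
+# minimize the number of PMFs whose classic type label falls on the wrong side of the cut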
+opt_bound1, errors1 = 0, 100
+for b in bound1:
+    if errors1 > (np.sum(type1_rews > b) + np.sum(type2_rews < b)):
+        opt_bound1 = b
+        errors1 = np.sum(type1_rews > b) + np.sum(type2_rews < b)
+
+opt_bound2, errors2 = 0, 100
+for b in bound2:
+    if errors2 > (np.sum(type2_rews > b) + np.sum(type3_rews < b)):
+        opt_bound2 = b
+        errors2 = np.sum(type2_rews > b) + np.sum(type3_rews < b)
+
+print(opt_bound1, errors1, opt_bound2, errors2)
+
+opt_bound2 = 0.780303
+print("Optimal bound 2: {}, {}".format(np.sum(type2_rews > opt_bound2), np.sum(type3_rews < opt_bound2)))
+
+man_bound = 0.7827
+print("Manual bound 2: {}, {}".format(np.sum(type2_rews > man_bound), np.sum(type3_rews < man_bound)))
+
+
+bins = np.linspace(0, 1, 30)
+plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"])
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+plt.legend()
+plt.savefig(folder + "hist 1")
+plt.show()
+
+plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"], stacked=True)
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+plt.legend()
+plt.savefig(folder + "hist 2")
+plt.show()
+
+plt.subplot(1, 3, 1)
+plt.title("Bins = 25")
+plt.hist(all_rews, bins=25)
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+
+plt.subplot(1, 3, 2)
+plt.title("Bins = 40")
+plt.hist(all_rews, bins=40)
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+
+plt.subplot(1, 3, 3)
+plt.title("Bins = 55")
+plt.hist(all_rews, bins=55)
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+
+plt.savefig(folder + "hists compare")
+plt.show()
+
+
+fig = plt.figure(figsize=(13 * 3 / 5, 9 * 3 / 5))
+plt.hist(all_rews, bins=40, color='grey')
+plt.axvline(0.6, c='k')
+plt.axvline(0.7827, c='k')
+
+plt.ylabel("# of occurences", size=28)
+plt.xlabel("Reward rate", size=28)
+plt.gca().spines[['right', 'top']].set_visible(False)
+
+plt.tight_layout()
+plt.savefig(folder + "single hist")
+plt.show()
+
+
 xy = np.vstack([apw[:, i] for i in range(4)])
 z = gaussian_kde(xy)(xy)
 
@@ -25,67 +163,107 @@ plt.scatter(apw[:, 3], apw[:, 0], c=z)
 plt.xlabel("Bias")
 plt.ylabel("Cont right")
 
+plt.savefig(folder + "density scatter")
 plt.show()
 
 
-contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
-contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
-
-def weights_to_pmf(weights, with_bias=1):
-    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
-    return 1 / (1 + np.exp(psi))
-
-colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw]
-
-plt.subplot(1, 3, 1)
+plt.subplot(1, 4, 1)
 sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors)
-fig, ax = plt.gcf(), plt.gca()
+fig, ax1 = plt.gcf(), plt.gca()
 plt.xlabel("Cont right")
 plt.ylabel("Cont left")
 
-annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
-                    bbox=dict(boxstyle="round", fc="w"),
-                    arrowprops=dict(arrowstyle="->"))
-annot.set_visible(False)
+annot1 = ax1.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
+                      bbox=dict(boxstyle="round", fc="w"),
+                      arrowprops=dict(arrowstyle="->"))
+annot1.set_visible(False)
+
+plt.subplot(1, 4, 2)
+plt.scatter(apw[:, 3], apw[:, 1], c=colors)
+ax2 = plt.gca()
+plt.xlabel("Bias")
+plt.ylabel("Cont left")
+
+annot2 = ax2.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
+                      bbox=dict(boxstyle="round", fc="w"),
+                      arrowprops=dict(arrowstyle="->"))
+annot2.set_visible(False)
+
+plt.subplot(1, 4, 3)
+plt.scatter(apw[:, 3], apw[:, 0], c=colors)
+ax3 = plt.gca()
+plt.xlabel("Bias")
+plt.ylabel("Cont right")
+
+annot3 = ax3.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
+                      bbox=dict(boxstyle="round", fc="w"),
+                      arrowprops=dict(arrowstyle="->"))
+annot3.set_visible(False)
+
 
 def update_annot(ind):
-    print(ind)
     pos = sc.get_offsets()[ind["ind"][0]]
-    annot.xy = pos
+    annot1.xy = pos
     text = "{}".format(np.round(apw[ind["ind"][0]], 2))
-    annot.set_text(text)
+    annot1.set_text(text)
+
+    annot2.xy = apw[ind["ind"][0]][[3, 1]]
+    text = "{}".format(np.round(apw[ind["ind"][0]], 2))
+    annot2.set_text(text)
+
+    annot3.xy = apw[ind["ind"][0]][[3, 0]]
+    text = "{}".format(np.round(apw[ind["ind"][0]], 2))
+    annot3.set_text(text)
+
+    plt.subplot(1, 4, 4)
+    plt.cla()
+    plt.plot(weights_to_pmf(apw[ind["ind"][0]]))
+    plt.ylim(0, 1)
+    plt.ylabel("P(rightwards)")
+    plt.xlabel("Contrasts")
+
 
 def hover(event):
-    vis = annot.get_visible()
-    if event.inaxes == ax:
+    vis = annot1.get_visible()
+    if event.inaxes == ax1:
         cont, ind = sc.contains(event)
         if cont:
             update_annot(ind)
-            annot.set_visible(True)
+            annot1.set_visible(True)
+            annot2.set_visible(True)
+            annot3.set_visible(True)
             fig.canvas.draw_idle()
         else:
             if vis:
-                annot.set_visible(False)
+                annot1.set_visible(False)
+                annot2.set_visible(False)
+                annot3.set_visible(False)
                 fig.canvas.draw_idle()
 
+
 fig.canvas.mpl_connect("motion_notify_event", hover)
 
-plt.subplot(1, 3, 2)
-plt.scatter(apw[:, 3], apw[:, 1], c=colors)
-plt.xlabel("Bias")
-plt.ylabel("Cont left")
+plt.savefig(folder + "type scatter")
+plt.show()
 
-plt.subplot(1, 3, 3)
-plt.scatter(apw[:, 3], apw[:, 0], c=colors)
-plt.xlabel("Bias")
-plt.ylabel("Cont right")
 
+fig = plt.figure(figsize=(16, 9))
+ax = plt.axes(projection='3d')
+ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
+ax.view_init(27.5, -137)
+plt.savefig(folder + "3d types")
 plt.show()
 
+fig = plt.figure(figsize=(16, 9))
+ax = plt.axes(projection='3d')
+ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors_rew)
+ax.view_init(27.5, -137)
+plt.savefig(folder + "3d types new")
+plt.show()
 
-from mpl_toolkits import mplot3d
 
 fig = plt.figure()
 ax = plt.axes(projection='3d')
-ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
+ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=z)
+plt.savefig(folder + "3d density")
 plt.show()
diff --git a/process_many_chains.py b/process_many_chains.py
index cf8061356316a47aae18dedf91b2d0b4b2b27e9c..a0664e42c0946a573b715f894307b8c096c1c96b 100644
--- a/process_many_chains.py
+++ b/process_many_chains.py
@@ -17,7 +17,7 @@ if fit_type == 'bias':
     loading_info = json.load(open("canonical_infos_bias.json", 'r'))
 elif fit_type == 'prebias':
     loading_info = json.load(open("canonical_infos.json", 'r'))
-subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # list(loading_info.keys())
+subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys())
 
 fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
 func1 = state_num_helper(0.2)
diff --git a/simplex_animation.py b/simplex_animation.py
index 89dfe945c0994f472c74602d89541d12ac1d5fb1..ddf6df7925ac502d5319f6b4a6986eb27adb6fff 100644
--- a/simplex_animation.py
+++ b/simplex_animation.py
@@ -9,6 +9,7 @@ import copy
 
 all_state_types = pickle.load(open("all_state_types.p", 'rb'))
 
+
 # create a list of fixed offsets, to apply to mice when they are in a corner
 offsets = [(0, 0)]
 hex_size = 0
@@ -37,7 +38,7 @@ assert (test_count == 1).all()
 # plt.show()
 # quit()
 
-session_counter = 0
+session_counter = -1
 alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj']
 # do as many sessions as it takes
 while True:
@@ -65,9 +66,8 @@ while True:
 
         type_proportions[:, i] = sts[:, temp_counter]
 
-    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter),
+    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter),
                 vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter]))
-    plt.close()
 
     if not_ended == 0:
         break
diff --git a/simplex_plot.py b/simplex_plot.py
index 635885d05417cb2c7810ec74e6d5cb5572e6f0b7..420dd26c661da32a2e9f89e2fbb937edea411f25 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -14,8 +14,8 @@ import matplotlib.patches as PA
 
 
 def plotSimplex(points, fig=None,
-                vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'],
-                save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
+                vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], title='',
+                save_title="dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
     """
     Plot Nx3 points array on the 3-simplex
     (with optionally labeled vertices)
@@ -37,21 +37,25 @@ def plotSimplex(points, fig=None,
     # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
-    print(projected)
-    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
+    P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=points.sum(1) * 3.5, **kwargs)#s=35
 
     # plot center with average size
     projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3))
-    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)
+    print(points)
+    print()
+    print(projected)
+    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)#s=50
 
     # Leave some buffer around the triangle for vertex labels
     fig.gca().set_xlim(-0.05, 1.05)
     fig.gca().set_ylim(-0.05, 1.05)
 
     P.axis('off')
+    if title != '':
+        P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.025), size=24)
 
     P.tight_layout()
-    P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True)
+    P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True)
     if show:
         P.show()
     else:
diff --git a/state_dict_KS014 b/state_dict_KS014
new file mode 100644
index 0000000000000000000000000000000000000000..c183d111be82d6be94e31c438d9ece422bbbea76
Binary files /dev/null and b/state_dict_KS014 differ
diff --git a/state_dict_ZFM-04019 b/state_dict_ZFM-04019
new file mode 100644
index 0000000000000000000000000000000000000000..458cbe0f1f432c5e57ca3156688138077467873d
Binary files /dev/null and b/state_dict_ZFM-04019 differ
diff --git a/state_dict_ZFM-05236 b/state_dict_ZFM-05236
new file mode 100644
index 0000000000000000000000000000000000000000..216c9aeaafe342f1fda365a4d8261f91f3c2d74c
Binary files /dev/null and b/state_dict_ZFM-05236 differ