diff --git a/.gitignore b/.gitignore
index f9b99b52a826e7af468e07a81b9a537ed1907977..6828228ddf230944e5cd9532defcbf26028e3c8d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,7 +20,6 @@ peter_fiugres/
 consistency_data/
 dynamic_GLM_figures/
 dynamic_GLMiHMM_fits2/
-glm sim mice/
 dynamic_GLMiHMM_crossvals/
 __pycache__/
 fibre_data/
diff --git a/glm sim mice/Sim_03_5.py b/glm sim mice/Sim_03_5.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a17e61dc9607e3cb4a02b7eadd97179da2e862
--- /dev/null
+++ b/glm sim mice/Sim_03_5.py	
@@ -0,0 +1,90 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 2 states throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+print('BIAS NOT UPDATED')
+quit()
+subject = 'CSHL059'
+new_name = 'GLM_Sim_03_5'
+seed = 4
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['n_sessions']
+from_session = info_dict['bias_start']
+
+GLM_weights = [np.array([-4.5, 4.3, 0, 0., 1.2, -0.7]),
+               np.array([-4.5, 3.2, 0, 1.3, 0.3, -1.5])]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 2))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(5)
+    state_plot = np.zeros(2)
+    count = 0
+    curr_state = int(np.random.rand() > 0.5)
+    dur = nbinom.rvs(21 + curr_state * 15, 0.3)
+
+    prev_choice = 2 * int(np.random.rand() > 0.5) - 1
+    if (contrasts[0] < 0) == (prev_choice > 0):
+        prev_reward = 1.
+    elif contrasts[0] > 0:
+        prev_reward = 0.
+    else:
+        prev_reward = int(np.random.rand() > 0.5)
+    prev_side = - prev_choice if prev_reward == 0. else prev_choice
+    data[0, 1] = prev_choice + 1
+    side_info[0, 1] = prev_reward
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = prev_choice
+        predictors[3] = prev_reward
+        predictors[4] = prev_side
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state][:-1] * predictors) - GLM_weights[curr_state][-1])))
+        state_plot[curr_state] += 1
+        if dur == 0:
+            curr_state = (curr_state + 1) % 2
+            dur = nbinom.rvs(21 + curr_state * 15, 0.3)
+        else:
+            dur -= 1
+        prev_choice = data[i+1, 1] - 1
+        if (c < 0) == (prev_choice > 0):
+            prev_reward = 1.
+        elif c > 0:
+            prev_reward = 0.
+        else:
+            prev_reward = int(np.random.rand() > side_info[i+1, 0])
+        prev_side = - prev_choice if prev_reward == 0. else prev_choice
+        side_info[i+1, 1] = prev_reward
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(2):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.show()
diff --git a/glm sim mice/Sim_04.py b/glm sim mice/Sim_04.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a4af31225f841ef46b2420b40fbd3f2222fc65f
--- /dev/null
+++ b/glm sim mice/Sim_04.py	
@@ -0,0 +1,97 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSHL059'
+new_name = 'GLM_Sim_04'
+seed = 5
+
+print('BIAS NOT UPDATED')
+quit()
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['n_sessions']
+from_session = info_dict['bias_start']
+
+GLM_weights = [np.array([-4.5, 4.3, 0, 0., 1.2, -0.7]),
+               np.array([-4.5, 3.2, 1.3, 0., 0.3, -1.5]),
+               np.array([-1, 1.2, 2.1, 0.5, 0.1, 1]),
+               np.array([-0.1, 0.2, 0, 0.1, 3.9, -1.])]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(5)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = np.random.choice(4)
+    dur = nbinom.rvs(21 + (curr_state % 2) * 15, 0.3)
+
+    prev_choice = 2 * int(np.random.rand() > 0.5) - 1
+    if (contrasts[0] < 0) == (prev_choice > 0):
+        prev_reward = 1.
+    elif contrasts[0] > 0:
+        prev_reward = 0.
+    else:
+        prev_reward = int(np.random.rand() > 0.5)
+    prev_side = - prev_choice if prev_reward == 0. else prev_choice
+    data[0, 1] = prev_choice + 1
+    side_info[0, 1] = prev_reward
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = prev_choice
+        predictors[3] = prev_reward
+        predictors[4] = prev_side
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state][:-1] * predictors) - GLM_weights[curr_state][-1])))
+        state_plot[curr_state] += 1
+        if dur == 0:
+            while True:
+                next_state = np.random.choice(4)
+                if next_state != curr_state:
+                    curr_state = next_state
+                    break
+            dur = nbinom.rvs(21 + (curr_state % 2) * 15, 0.3)
+        else:
+            dur -= 1
+        prev_choice = data[i+1, 1] - 1
+        if (c < 0) == (prev_choice > 0):
+            prev_reward = 1.
+        elif c > 0:
+            prev_reward = 0.
+        else:
+            prev_reward = int(np.random.rand() > side_info[i+1, 0])
+        prev_side = - prev_choice if prev_reward == 0. else prev_choice
+        side_info[i+1, 1] = prev_reward
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.show()
diff --git a/glm sim mice/Sim_05.py b/glm sim mice/Sim_05.py
new file mode 100644
index 0000000000000000000000000000000000000000..036e171d243066e620e9661cc2277e98e2a34126
--- /dev/null
+++ b/glm sim mice/Sim_05.py	
@@ -0,0 +1,88 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+One fixed state per session (state index advances every 5 sessions via j // 5); no within-session alternation or duration sampling
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSHL059'
+new_name = 'GLM_Sim_05'
+seed = 6
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+GLM_weights = [np.array([-4.5, 4.3, 0., 1.2, -0.7]),
+               np.array([-4.5, 3.2, 1.3, 0.3, -1.5]),
+               np.array([-1, 1.2, 2.1, 0.1, 1]),
+               np.array([-0.1, 0.2, 0., 3.9, -1.])]
+GLM_weights = list(reversed(GLM_weights))
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(5)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = j // 5
+
+    prev_choice = 2 * int(np.random.rand() > 0.5) - 1
+    if (contrasts[0] < 0) == (prev_choice > 0):
+        prev_reward = 1.
+    elif contrasts[0] > 0:
+        prev_reward = 0.
+    else:
+        prev_reward = int(np.random.rand() > 0.5)
+    prev_side = - prev_choice if prev_reward == 0. else prev_choice
+    data[0, 1] = prev_choice + 1
+    side_info[0, 1] = prev_reward
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = prev_choice
+        predictors[3] = prev_side
+        predictors[4] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+
+        prev_choice = data[i+1, 1] - 1
+        if (c < 0) == (prev_choice > 0):
+            prev_reward = 1.
+        elif c > 0:
+            prev_reward = 0.
+        else:
+            prev_reward = int(np.random.rand() > side_info[i+1, 0])
+        prev_side = - prev_choice if prev_reward == 0. else prev_choice
+        side_info[i+1, 1] = prev_reward
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_05')
+plt.show()
diff --git a/glm sim mice/Sim_06.py b/glm sim mice/Sim_06.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd59566d6aa3f5a6bbaf4494eea13f541fecccfd
--- /dev/null
+++ b/glm sim mice/Sim_06.py	
@@ -0,0 +1,89 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+One fixed state per session, taken from a hand-written per-session schedule; no within-session alternation or duration sampling
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSHL059'
+new_name = 'GLM_Sim_06'
+seed = 7
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+GLM_weights = [np.array([-4.5, 4.3, 0., 1.2, -0.7]),
+               np.array([-4.5, 3.2, 1.3, 0.3, -1.5]),
+               np.array([-1, 1.2, 2.1, 0.1, 1]),
+               np.array([-0.1, 0.2, 0., 3.9, -1.])]
+GLM_weights = list(reversed(GLM_weights))
+states = [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 1, 1, 2, 3, 3, 3, 0, 3]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(5)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j]
+
+    prev_choice = 2 * int(np.random.rand() > 0.5) - 1
+    if (contrasts[0] < 0) == (prev_choice > 0):
+        prev_reward = 1.
+    elif contrasts[0] > 0:
+        prev_reward = 0.
+    else:
+        prev_reward = int(np.random.rand() > 0.5)
+    prev_side = - prev_choice if prev_reward == 0. else prev_choice
+    data[0, 1] = prev_choice + 1
+    side_info[0, 1] = prev_reward
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = prev_choice
+        predictors[3] = prev_side
+        predictors[4] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+
+        prev_choice = data[i+1, 1] - 1
+        if (c < 0) == (prev_choice > 0):
+            prev_reward = 1.
+        elif c > 0:
+            prev_reward = 0.
+        else:
+            prev_reward = int(np.random.rand() > side_info[i+1, 0])
+        prev_side = - prev_choice if prev_reward == 0. else prev_choice
+        side_info[i+1, 1] = prev_reward
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_06')
+plt.show()
diff --git a/glm sim mice/Sim_07.py b/glm sim mice/Sim_07.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e6bb66d0cf510fc05531e8a876ee2582f5b802
--- /dev/null
+++ b/glm sim mice/Sim_07.py	
@@ -0,0 +1,99 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+same as 06 but more data
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSHL059'
+new_name = 'GLM_Sim_07'
+seed = 8
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+info_dict['bias_start'] = info_dict['bias_start'] * 2
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+GLM_weights = [np.array([-4.5, 4.3, 0., 1.2, -0.7]),
+               np.array([-4.5, 3.2, 1.3, 0.3, -1.5]),
+               np.array([-1, 1.2, 2.1, 0.1, 1]),
+               np.array([-0.1, 0.2, 0., 3.9, -1.])]
+GLM_weights = list(reversed(GLM_weights))
+states = [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 1, 1, 2, 3, 3, 3, 0, 3]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j // 2), "rb"))
+    data = np.tile(data, (2, 1))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j // 2), "rb"))
+    side_info = np.tile(side_info, (2, 1))
+    if j % 2 == 1:
+        np.random.shuffle(data)
+        np.random.shuffle(side_info)
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(5)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j // 2]
+
+    prev_choice = 2 * int(np.random.rand() > 0.5) - 1
+    if (contrasts[0] < 0) == (prev_choice > 0):
+        prev_reward = 1.
+    elif contrasts[0] > 0:
+        prev_reward = 0.
+    else:
+        prev_reward = int(np.random.rand() > 0.5)
+    prev_side = - prev_choice if prev_reward == 0. else prev_choice
+    data[0, 1] = prev_choice + 1
+    side_info[0, 1] = prev_reward
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = prev_choice
+        predictors[3] = prev_side
+        predictors[4] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+
+        prev_choice = data[i+1, 1] - 1
+        if (c < 0) == (prev_choice > 0):
+            prev_reward = 1.
+        elif c > 0:
+            prev_reward = 0.
+        else:
+            prev_reward = int(np.random.rand() > side_info[i+1, 0])
+        prev_side = - prev_choice if prev_reward == 0. else prev_choice
+        side_info[i+1, 1] = prev_reward
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+
+plt.savefig('states_07')
+plt.show()
+
+truth = {'state_posterior': state_posterior, 'weights': list(reversed(GLM_weights)), 'state_map': dict(zip(list(range(len(GLM_weights))), list(range(len(GLM_weights)))))}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
diff --git a/glm sim mice/Sim_08.py b/glm sim mice/Sim_08.py
new file mode 100644
index 0000000000000000000000000000000000000000..974da73fcb4741632ae956751f72c384f24373ec
--- /dev/null
+++ b/glm sim mice/Sim_08.py	
@@ -0,0 +1,96 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+Same as 06 (but with within-session state cycling and negative-binomial durations), and data now taken from the favourite mouse
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_08'
+seed = 9
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+GLM_weights = [np.array([-3.5, 3.3, -0.7]),
+               np.array([-3.5, 1.7, -1.5]),
+               np.array([-0.3, 1.2, 1]),
+               np.array([-0.1, 0.2, -1.])]
+neg_bin_params = [(12, 0.3), (16, 0.25), (20, 0.15), (30, 0.15)]
+GLM_weights = list(reversed(GLM_weights))
+states = [(0,), (0,), (0,), (1,), (0, 1), (0, 1), (2,), (2, 3), (2, 3, 1), (2, 3), (3,), (1, 3), (1, 2), (2, 3), (2, 3), (2, 3, 1, 0)]
+
+for w in GLM_weights:
+    plt.plot(weights_to_pmf(w, with_bias=1))
+    plt.ylim(bottom=0, top=1)
+    plt.show()
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.show()
diff --git a/glm sim mice/Sim_09.py b/glm sim mice/Sim_09.py
new file mode 100644
index 0000000000000000000000000000000000000000..8674a99df99a5adc045b53685a48773b9b5940ce
--- /dev/null
+++ b/glm sim mice/Sim_09.py	
@@ -0,0 +1,105 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+Same as 08, but now also noise on the glm weights
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_09'
+seed = 10
+
+print("_____________________________________________________________________________________________________________")
+print("This is problematic, as the duration changes depending on whether a state is the only one in a session or not")
+print("_____________________________________________________________________________________________________________")
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+GLM_weights = [np.array([-3.5, 3.3, -0.7]),
+               np.array([-3.5, 1.7, -1.5]),
+               np.array([-0.3, 1.2, 1]),
+               np.array([-0.1, 0.2, -1.])]
+neg_bin_params = [(15, 0.3), (30, 0.25), (30, 0.15), (140, 0.15)]
+# [5, 15, 30, 50, 75, 105, 140, 180, 225, 275, 330, 390, 455, 525, 600, 680, 765, 855, 950]
+GLM_weights = list(reversed(GLM_weights))
+states = [(0,), (0,), (0,), (1,), (0, 1), (0, 1), (2,), (2, 3), (2, 3, 1), (2, 3), (3,), (1, 3), (1, 2), (2, 3), (2, 3), (2, 3, 1, 0)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+for k, j in enumerate(range(from_session, till_session)):
+    for i, w in enumerate(GLM_weights):
+        if i in states[j]:
+            # print(j, i)
+            w += np.random.normal(np.zeros(3), 0.03 * np.ones(3))
+    plt.plot(weights_to_pmf(GLM_weights[-3]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    print(GLM_weights[0])
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.show()
diff --git a/glm sim mice/Sim_10.py b/glm sim mice/Sim_10.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7904a39a2b23d405990f7a9b0258d5c8b623b3
--- /dev/null
+++ b/glm sim mice/Sim_10.py	
@@ -0,0 +1,100 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 2 states throughout
+They start a certain distance apart, one state at start, other later, duration is negative binomial
+Same as 09, but now only 2 states
+this seems to have changed a bit, there are now 18 sessions of this mouse, but we originally only did 16...
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_10'
+seed = 11
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+GLM_weights = [np.array([-3.5, 3.3, -0.7]),
+               np.array([-0.3, 1.2, 1])]
+neg_bin_params = [(12, 0.3), (30, 0.15)]
+GLM_weights = list(reversed(GLM_weights))
+states = [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    for w in GLM_weights:
+        w += np.random.normal(np.zeros(3), 0.03 * np.ones(3))
+    plt.plot(weights_to_pmf(GLM_weights[1]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_10')
+plt.show()
+
+truth = {'state_posterior': state_posterior, 'weights': list(reversed(GLM_weights)), 'state_map': dict(zip(list(range(len(GLM_weights))), list(range(len(GLM_weights))))), 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
diff --git a/glm sim mice/Sim_11.py b/glm sim mice/Sim_11.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74b8ef0e30d5dc5cce015caae5ba9b5efc7808d
--- /dev/null
+++ b/glm sim mice/Sim_11.py	
@@ -0,0 +1,99 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 1 state throughout, which goes from poor to good
+Fit var is 0.0315
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'KS014'
+new_name = 'GLM_Sim_11'
+seed = 12
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+start = np.array([-0.1, 0.2, -1.])
+end = np.array([-3.7, 3.5, -0.7])
+GLM_weights = np.tile(start, (till_session + 1, 1))
+for i, gw in enumerate(GLM_weights):
+    print(np.var((1 / (till_session + 1)) * (end - start)))
+    gw += (i / (till_session + 1)) * (end - start)
+
+neg_bin_params = [(30, 0.15)]
+states = [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    plt.plot(weights_to_pmf(GLM_weights[k]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[k] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_11')
+plt.show()
diff --git a/glm sim mice/Sim_11_sub.py b/glm sim mice/Sim_11_sub.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1deefb6e12314604c1494bda919d6d638f4067
--- /dev/null
+++ b/glm sim mice/Sim_11_sub.py	
@@ -0,0 +1,94 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 1 state throughout, which stays constant
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'KS014'
+new_name = 'GLM_Sim_11_sub'
+seed = 112
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+GLM_weights = np.array([[-1.8, 2.9, -0.9]])
+neg_bin_params = [(30, 0.15)]
+states = [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session - from_session + 1, 1))
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    plt.plot(weights_to_pmf(GLM_weights[0]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(1)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(1):
+    plt.fill_between(range(till_session - from_session + 1), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': {0: 0}}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
+plt.savefig('states_{}'.format(new_name))
+plt.show()
diff --git a/glm sim mice/Sim_11_trick.py b/glm sim mice/Sim_11_trick.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e654c14dab78dafa5cb0a612eb49a610fc320a9
--- /dev/null
+++ b/glm sim mice/Sim_11_trick.py	
@@ -0,0 +1,114 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 1 state throughout, which goes from poor to good -- not at session bounds, but continuously on every trial
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+
+
+subject = 'KS014'
+new_name = 'GLM_Sim_11_trick'
+seed = 212
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+start = np.array([-0.1, 0.2, -1.])
+end = np.array([-3.7, 3.5, -0.7])
+GLM_weights = np.tile(start, (till_session + 1, 1))
+
+neg_bin_params = [(30, 0.15)]
+states = [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session - from_session + 1, 1))
+
+total_trial_count = 0
+for k, j in enumerate(range(from_session, till_session + 1)):
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    total_trial_count += data.shape[0]
+
+internal_trial_count = 0
+for k, j in enumerate(range(from_session, till_session + 1)):
+
+    GLM_weights[k] = start + (internal_trial_count / total_trial_count) * (end - start)
+
+    plt.plot(weights_to_pmf(GLM_weights[k]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(1)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+    side_info[0, 1] = (prev_choice == 2 and contrasts[0] < 0) or ((prev_choice == 0 and contrasts[0] > 0))
+    if contrasts[0] == 0:
+        side_info[0, 1] = 0.5
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        internal_trial_count += 1
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(start + ((internal_trial_count / total_trial_count) * (end - start)) * predictors))))
+        state_plot[curr_state] += 1
+        side_info[i + 1, 1] = (data[i+1, 1] == 2 and c < 0) or ((data[i+1, 1] == 0 and c > 0))
+        if c == 0:
+            side_info[i + 1, 1] = 0.5
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+s = 0
+plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': {0: 0}}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
+
+plt.savefig('states_11_trick')
+plt.show()
diff --git a/glm sim mice/Sim_12.py b/glm sim mice/Sim_12.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d46d0cba6d43b91e639eaad8c01017128e6adc5
--- /dev/null
+++ b/glm sim mice/Sim_12.py	
@@ -0,0 +1,117 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+Same as 08, but now also noise on the glm weights
+this seems to have changed a bit, there are now 18 sessions of this mouse, but we originally only did 16...
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_12'
+seed = 13
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+GLM_weights = [np.array([-3.5, 3.3, -0.7]),
+               np.array([-3.5, 1.7, -1.5]),
+               np.array([-0.3, 1.2, 1]),
+               np.array([-0.1, 0.2, -1.])]
+neg_bin_params = [(180, 0.2), (75, 0.12), (105, 0.14), (140, 0.15)]
+# [5, 15, 30, 50, 75, 105, 140, 180, 225, 275, 330, 390, 455, 525, 600, 680, 765, 855, 950]
+GLM_weights = list(reversed(GLM_weights))
+states = [(0,), (0,), (0,), (1, 0), (1, 0), (1, 0), (2,), (2, 3), (2, 3, 1), (2, 3), (3,), (1, 3), (1, 2), (3, 2), (3, 2), (3, 2, 1, 0), (3, 2), (3, 2), (3, 2, 1, 0)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+for k, j in enumerate(range(from_session, till_session + 1)):
+    for i, w in enumerate(GLM_weights):
+        if i in states[j]:
+            # print(j, i)
+            w += np.random.normal(np.zeros(3), 0.03 * np.ones(3))
+    # plt.plot(weights_to_pmf(GLM_weights[-3]))
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    if len(states[j]) <= 1:
+        print(states[j])
+        print(data.shape)
+        print(1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]))
+        print()
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(4)
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+
+    plt.plot(weights_to_pmf(GLM_weights[0], with_bias=1))
+    plt.plot(weights_to_pmf(GLM_weights[1], with_bias=1))
+    plt.title(k)
+    plt.ylim(bottom=0, top=1)
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    print(k)
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.show()
+
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_12')
+plt.show()
+
+truth = {'state_posterior': state_posterior, 'weights': list(reversed(GLM_weights)), 'state_map': dict(zip(list(range(len(GLM_weights))), list(range(len(GLM_weights))))), 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
diff --git a/glm sim mice/Sim_13.py b/glm sim mice/Sim_13.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba58f4575f579faeb7d0bb99387069dcfca409dd
--- /dev/null
+++ b/glm sim mice/Sim_13.py	
@@ -0,0 +1,127 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 5 states throughout
+States alternate during the session, duration is negative binomial
+I paid attention that if a state filled out a session, its duration distribution reflected that
+"""
+import numpy as np
+# import matplotlib
+# matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_13'
+seed = 14
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+GLM_weights = [np.array([-2.5, 2.3, -1.3]),
+               np.array([-2.4, 2.1, 1.1]),
+               np.array([-0.5, 0.4, -2.5]),
+               np.array([0.1, 0.9, -0.6]),
+               np.array([0.1, -0.2, 1.])]
+
+for gw in GLM_weights:
+    plt.plot(weights_to_pmf(gw))
+plt.show()
+neg_bin_params = [(180, 0.2), (105, 0.132), (5, 0.2), (30, 0.5), (30, 0.28)]
+# [5, 15, 30, 50, 75, 105, 140, 180, 225, 275, 330, 390, 455, 525, 600, 680, 765, 855, 950]
+GLM_weights = list(reversed(GLM_weights))
+states = [(0,), (1,), (1,), (1, 2,), (1,), (1,), (3, 4,), (3, 4,), (4, 3, 2, 4, 3, 4, 3,), (3, 4,), (3, 4, 2,),
+          (3, 4, 3, 2,), (3, 4,), (3, 4, 3, 4, 2,), (3, 4, 3, 2,), (3, 4,), (3, 4,), (3, 4, 3, 2,), (3, 4,)]
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session - from_session + 1, len(GLM_weights)))
+state_counter_array = np.zeros((till_session - from_session + 1, len(GLM_weights)))
+for k, j in enumerate(range(from_session, till_session + 1)):
+    for i, w in enumerate(GLM_weights):
+        if i in states[j]:
+            # print(j, i)
+            w += np.random.normal(np.zeros(3), 0.03 * np.ones(3))
+    # plt.plot(weights_to_pmf(GLM_weights[-3]))
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.show()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    if len(states[j]) <= 1:
+        print(states[j])
+        print(1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]))
+    print(data.shape)
+    print()
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(len(GLM_weights))
+    count = 0
+    curr_state = states[j][0]
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+    side_info[0, 1] = (data[0, 1] == 2 and contrasts[0] < 0) or ((data[0, 1] == 0 and contrasts[0] > 0))
+    if contrasts[0] == 0:
+        side_info[0, 1] = 0.5
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        side_info[i + 1, 1] = (data[i+1, 1] == 2 and c < 0) or ((data[i+1, 1] == 0 and c > 0))
+        if c == 0:
+            side_info[i + 1, 1] = 0.5
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    state_counter_array[k] = state_plot
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.figure(figsize=(16, 9))
+for s in range(len(GLM_weights)):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': {0: 0, 1: 1, 2: 4, 3: 3, 4: 2}}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
+plt.savefig('states_13')
+plt.show()
diff --git a/glm sim mice/Sim_14.py b/glm sim mice/Sim_14.py
new file mode 100644
index 0000000000000000000000000000000000000000..72016aaa598de5ba60378a8d5c45d221d303cb53
--- /dev/null
+++ b/glm sim mice/Sim_14.py	
@@ -0,0 +1,140 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 6 states all throughout
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+"""
+import numpy as np
+# import matplotlib
+# matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+from scipy.linalg import eig
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_14'
+seed = 15
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    if weights.shape[0] == 3 or weights.shape[0] == 5:
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()
+
+
+GLM_weights = [np.array([0., 0., 0.]),
+               np.array([-2.8, 2.8, 0.]),
+               np.array([2.8, -2.8, 0.]),
+               np.array([-0.85, 0.85, 0.]),
+               np.array([-2.5, -2.5, -2.5]),
+               np.array([2.5, 2.5, 2.5])]
+
+plt.subplot(2, 1, 1)
+for gw in GLM_weights:
+    plt.plot(weights_to_pmf(gw))
+
+plt.subplot(2, 1, 2)
+for gw in GLM_weights:
+    plt.plot(weights_to_pmf(gw, with_bias=0))
+plt.close()
+neg_bin_params = [(30, 0.2), (75, 0.35), (180, 0.3), (15, 0.23), (140, 0.26), (5, 0.17)]
+# [5, 15, 30, 50, 75, 105, 140, 180, 225, 275, 330, 390, 455, 525, 600, 680, 765, 855, 950]
+
+
+transition_mat = np.random.dirichlet(np.ones(6) * 0.5, (6))
+np.fill_diagonal(transition_mat, 0)
+transition_mat = transition_mat / transition_mat.sum(1)[:, None]
+print(transition_mat)
+print(transition_mat.sum(1))
+
+# this computation doesn't work for some reason
+eigenvals, eigenvects = eig(transition_mat.T)
+close_to_1_idx = np.isclose(eigenvals, 1)
+target_eigenvect = eigenvects[:, close_to_1_idx]
+target_eigenvect = target_eigenvect[:, 0]
+# Turn the eigenvector elements into probabilities
+stationary_distrib = target_eigenvect / sum(target_eigenvect)
+print(stationary_distrib)
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, len(GLM_weights)))
+log_like_save = 0
+counter = 0
+for k, j in enumerate(range(from_session, till_session + 1)):
+    if j == till_session - 1:
+        plt.ylim(bottom=0, top=1)
+        plt.close()
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(len(GLM_weights))
+    count = 0
+    curr_state = np.random.choice(6)
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+    side_info[0, 1] = (data[0, 1] == 2 and contrasts[0] < 0) or ((data[0, 1] == 0 and contrasts[0] > 0))
+    if contrasts[0] == 0:
+        side_info[0, 1] = 0.5
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        prob = 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors)))
+        data[i+1, 1] = 2 * (np.random.rand() < prob)
+        log_like_save += np.log(prob) if data[i+1, 1] == 2 else np.log(1 - prob)
+        counter += 1
+        state_plot[curr_state] += 1
+        side_info[i + 1, 1] = (data[i+1, 1] == 2 and c < 0) or ((data[i+1, 1] == 0 and c > 0))
+        if c == 0:
+            side_info[i + 1, 1] = 0.5
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = np.random.choice(6, p=transition_mat[curr_state])
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+print("Groud truth log like: {}".format(log_like_save / counter))
+
+plt.figure(figsize=(16, 9))
+for s in range(len(GLM_weights)):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_14')
+plt.close()
+
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': {2: 0, 0: 1, 4: 2, 5: 3, 3: 4, 1: 5}, 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
\ No newline at end of file
diff --git a/glm sim mice/Sim_15.py b/glm sim mice/Sim_15.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcaf633a92f98f28ef82bfc7d66309cc6b1404c4
--- /dev/null
+++ b/glm sim mice/Sim_15.py	
@@ -0,0 +1,136 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 6 states all throughout, but duration prior is misspecified
+They start a certain distance apart, states alternate during the session, duration is negative binomial
+"""
+import numpy as np
+import matplotlib
+# matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+from scipy.linalg import eig
+from communal_funcs import weights_to_pmf
+
+subject = 'CSH_ZAD_022'
+new_name = 'GLM_Sim_15'
+seed = 20
+
+print(new_name)
+np.random.seed(seed)
+
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+till_session = info_dict['bias_start']
+from_session = 0
+
+
+GLM_weights = [np.array([0., 0., 0.]),
+               np.array([-2.8, 2.8, 0.]),
+               np.array([2.8, -2.8, 0.]),
+               np.array([-0.85, 0.85, 0.]),
+               np.array([-2.5, -2.5, -2.5]),
+               np.array([2.5, 2.5, 2.5])]
+
+plt.subplot(2, 1, 1)
+for gw in GLM_weights:
+    plt.plot(weights_to_pmf(gw))
+
+plt.subplot(2, 1, 2)
+for gw in GLM_weights:
+    plt.plot(weights_to_pmf(gw))
+plt.close()
+neg_bin_params = [(23, 0.2), (91, 0.38), (211, 0.42), (10, 0.2), (159, 0.52), (7, 0.15)]
+# [5, 15, 30, 50, 75, 105, 140, 180, 225, 275, 330, 390, 455, 525, 600, 680, 765, 855, 950]
+
+
+transition_mat = np.random.dirichlet(np.ones(6) * 0.5, (6))
+np.fill_diagonal(transition_mat, 0)
+transition_mat = transition_mat / transition_mat.sum(1)[:, None]
+print(transition_mat)
+print(transition_mat.sum(1))
+
+# this computation doesn't work for some reason
+eigenvals, eigenvects = eig(transition_mat.T)
+close_to_1_idx = np.isclose(eigenvals, 1)
+target_eigenvect = eigenvects[:, close_to_1_idx]
+target_eigenvect = target_eigenvect[:, 0]
+# Turn the eigenvector elements into probabilities
+stationary_distrib = target_eigenvect / sum(target_eigenvect)
+print(stationary_distrib)
+
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+state_posterior = np.zeros((till_session + 1 - from_session, len(GLM_weights)))
+state_counter_array = np.zeros((till_session + 1 - from_session, len(GLM_weights)))
+observed_states = []
+first_appearances = []
+session_bounds = [0]
+trial_counter = 0
+for k, j in enumerate(range(from_session, till_session + 1)):
+
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(3)
+    state_plot = np.zeros(len(GLM_weights))
+    count = 0
+    curr_state = np.random.choice(6)
+    if curr_state not in observed_states:
+        first_appearances.append(trial_counter)
+        observed_states.append(curr_state)
+
+    observed_states.append(curr_state)
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    prev_choice = 2 * int(np.random.rand() > 0.5)
+    data[0, 1] = prev_choice
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)
+        predictors[1] = abs(min(c, 0))
+        predictors[2] = 1
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            state_counter += 1
+            curr_state = np.random.choice(6, p=transition_mat[curr_state])
+            if curr_state not in observed_states:
+                first_appearances.append(trial_counter + i)
+                observed_states.append(curr_state)
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    state_counter_array[k] = state_plot
+    trial_counter += len(data[:, 0])
+    session_bounds.append(trial_counter)
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+plt.vlines(session_bounds, 0, 1.2, color='k')
+plt.vlines(first_appearances, 0, 1, color='r')
+plt.close()
+
+state_props = state_posterior.sum(0) / state_posterior.sum()
+print(state_props)
+print(state_props.min(), state_props.max())
+plt.figure(figsize=(16, 9))
+for s in range(len(GLM_weights)):
+    plt.fill_between(range(till_session + 1 - from_session), s, s + state_posterior[:, s])
+
+plt.savefig('states_15')
+plt.close()
+
+# 2: 0
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': {1: 1, 5: 5, 4: 3, 0: 4, 2: 0, 3: 2}, 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
\ No newline at end of file
diff --git a/glm sim mice/Sim_16.py b/glm sim mice/Sim_16.py
new file mode 100644
index 0000000000000000000000000000000000000000..980cca0a66461b7264412770f16ef98be716f7e6
--- /dev/null
+++ b/glm sim mice/Sim_16.py	
@@ -0,0 +1,85 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse uses 4 states throughout.
+They start a certain distance apart; exactly one state is active per session
+(durations are always entire sessions), now with the new exponentially
+filtered perseveration regressor.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom  # NOTE(review): unused here - kept for parity with sibling sims
+
+
+# Real mouse whose session structure (trial counts, contrast sequences) is reused.
+subject = 'CSHL059'
+new_name = 'GLM_Sim_16'
+seed = 17
+
+print(new_name)
+np.random.seed(seed)
+
+# Clone the real mouse's info dict under the simulated mouse's name.
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+# Only the pre-bias sessions are simulated.
+till_session = info_dict['bias_start']
+from_session = 0
+
+# One weight vector per state; slots pair with the predictors built below
+# (two contrast regressors, perseveration, constant).
+GLM_weights = [np.array([-4.5, 4.3, 0., -0.7]),
+               np.array([-4.5, 3.2, 1.3, -1.5]),
+               np.array([-1, 1.2, 2.1, 1]),
+               np.array([-0.1, 0.2, 0., -1.])]
+GLM_weights = list(reversed(GLM_weights))
+# Which single state is active in each session.
+states = [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 1, 1, 2, 3, 3, 3, 0, 3]
+
+# Map between the integer contrast codes stored in the session files and signed contrasts.
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+# Per-session fraction of trials spent in each of the 4 states.
+state_posterior = np.zeros((till_session + 1 - from_session, 4))
+
+# Exponential filter over the last 5 choices: the perseveration regressor.
+exp_decay, exp_length = 0.3, 5
+exp_filter = np.exp(- exp_decay * np.arange(exp_length))
+exp_filter /= exp_filter.sum()
+exp_filter = np.flip(exp_filter)  # because we don't convolve, we need to flip manually
+
+for k, j in enumerate(range(from_session, till_session + 1)):
+    # Load the real session to reuse its trial count and contrast sequence.
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(4)
+    previous_answers = np.zeros(5)  # last 5 choices, coded -1 / +1
+    state_plot = np.zeros(4)        # per-state trial counts for this session
+    count = 0
+    curr_state = states[j]
+
+    # Random initial "previous choice" in {-1, +1}; stored as 0 / 2 in the data.
+    previous_answers[-1] = 2 * int(np.random.rand() > 0.5) - 1
+
+    data[0, 1] = previous_answers[-1] + 1
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)       # contrast magnitude on the positive side
+        predictors[1] = abs(min(c, 0))  # contrast magnitude on the negative side
+        predictors[2] = np.sum(previous_answers * exp_filter)  # perseveration
+        predictors[3] = 1               # constant / bias regressor
+        # Bernoulli response (coded 0 / 2) from the logistic of the active state's GLM.
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+
+        # Shift choice history and append the new choice recoded to -1 / +1.
+        previous_answers[:-1] = previous_answers[1:]
+        previous_answers[-1] = data[i+1, 1] - 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    # Save the simulated session under the simulated mouse's name.
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+# Visualise per-session state occupancy as stacked bands.
+plt.figure(figsize=(16, 9))
+for s in range(4):
+    plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_16')
+plt.show()
+
+# Ground truth for later comparison with fitted models (weights back in original order).
+truth = {'state_posterior': state_posterior, 'weights': list(reversed(GLM_weights)), 'state_map': dict(zip(list(range(len(GLM_weights))), list(range(len(GLM_weights)))))}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
\ No newline at end of file
diff --git a/glm sim mice/Sim_17.py b/glm sim mice/Sim_17.py
new file mode 100644
index 0000000000000000000000000000000000000000..57eb545351653755cf00d694ffc5ffd7421a63a9
--- /dev/null
+++ b/glm sim mice/Sim_17.py	
@@ -0,0 +1,127 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse has rather a lot of sessions, that's the main test here.
+Also, we use all other tricks in the book:
+states alternate during the session, duration is negative binomial,
+now with new perseveration, and a bit of noise.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+# Real mouse whose session structure (trial counts, contrast sequences) is reused.
+subject = 'KS055'
+new_name = 'GLM_Sim_17'
+seed = 18
+
+print(new_name)
+np.random.seed(seed)
+
+# Clone the real mouse's info dict under the simulated mouse's name.
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+# Only the pre-bias sessions are simulated.
+till_session = info_dict['bias_start']
+from_session = 0
+
+# Contrast grids for plotting psychometric curves (11 points, left to right).
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    """Return the psychometric curve implied by one GLM weight vector."""
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))  # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting
+
+
+# One weight vector per state (two contrast regressors, perseveration, bias).
+GLM_weights = [np.array([-4.5, 4.3, 0., -0.7]),
+               np.array([-0.3, 3.2, 0.5, -0.8]),
+               np.array([-2.3, 1.5, 0.3, -0.2]),
+               np.array([-1.9, 0.3, 1.5, 0.2]),
+               np.array([-0.3, 0.5, -0.8, 2.5]),
+               np.array([-1, 1.2, 2.1, 1]),
+               np.array([0.2, 0.3, 2, -1.]),
+               np.array([-0.1, 0.2, 0., -1.])]
+GLM_weights = list(reversed(GLM_weights))
+# (r, p) parameters of each state's negative-binomial dwell-time distribution.
+neg_bin_params = [(190, 0.2), (75, 0.11), (105, 0.17), (50, 0.15), (120, 0.28), (100, 0.12), (150, 0.24), (150, 0.17)]
+# Sanity plot: show each state's psychometric curve before simulating.
+for i, gw in enumerate(GLM_weights):
+    plt.plot(weights_to_pmf(gw), label=i)
+plt.ylim(0, 1)
+plt.legend()
+plt.show()
+# Per-session tuples of states that the mouse cycles through within that session.
+states = [(0,), (0,), (1, 0), (1,), (1,), (0,), (1, 0), (1, 0), (2,), (1, 0), (1, 0), (3, 1), (3, 1), (2,), (3, 0), (3,), (1, 3),
+          (2,), (4,0), (4,), (4,0), (4,0), (4,0), (2,), (0, 1), (4, 5), (4, 5), (4, 5), (6, 2), (6, 2), (6, 2), (6, 7), (6, 7),
+          (6, 3, 6), (6, 7), (6, 7), (7,), (7,), (7,)]
+
+# Map between the integer contrast codes stored in the session files and signed contrasts.
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+# Per-session fraction of trials spent in each state.
+state_posterior = np.zeros((till_session - from_session, len(GLM_weights)))
+
+# Exponential filter over the last 5 choices: the perseveration regressor.
+exp_decay, exp_length = 0.3, 5
+exp_filter = np.exp(- exp_decay * np.arange(exp_length))
+exp_filter /= exp_filter.sum()
+exp_filter = np.flip(exp_filter)  # because we don't convolve, we need to flip manually
+
+for k, j in enumerate(range(from_session, till_session)):
+    # Perturb the weights of this session's active states; += mutates the arrays
+    # in place, so the noise accumulates across sessions (the weights drift).
+    for i, w in enumerate(GLM_weights):
+        if i in states[j]:
+            w += np.random.normal(np.zeros(4), 0.03 * np.ones(4))
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    print(data.shape)
+    # Diagnostic for single-state sessions: warn when the sampled dwell time is
+    # unlikely (tail prob < 0.1) to cover the whole session.
+    if len(states[j]) <= 1:
+        if 1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]) < 0.1:
+            print()
+            print(states[j])
+            print(j)
+            print(1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]))
+            print()
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(4)
+    previous_answers = np.zeros(5)  # last 5 choices, coded -1 / +1
+    state_plot = np.zeros(len(GLM_weights))  # per-state trial counts this session
+    count = 0
+    curr_state = states[j][0]
+    # +1 so every state lasts at least one trial.
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    # Random initial "previous choice" in {-1, +1}; stored as 0 / 2 in the data.
+    previous_answers[-1] = 2 * int(np.random.rand() > 0.5) - 1
+    data[0, 1] = previous_answers[-1] + 1
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)       # contrast magnitude on the positive side
+        predictors[1] = abs(min(c, 0))  # contrast magnitude on the negative side
+        predictors[2] = np.sum(previous_answers * exp_filter)  # perseveration
+        predictors[3] = 1               # constant / bias regressor
+        # Bernoulli response (coded 0 / 2) from the logistic of the active state's GLM.
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        curr_dur -= 1
+        if curr_dur == 0:
+            # Duration expired: cycle to the session's next listed state.
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+        # Shift choice history and append the new choice recoded to -1 / +1.
+        previous_answers[:-1] = previous_answers[1:]
+        previous_answers[-1] = data[i+1, 1] - 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    # Save the simulated session under the simulated mouse's name.
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+state_posterior = state_posterior[:, ::-1]  # we plot this in reverse
+
+# Visualise per-session state occupancy as stacked bands.
+plt.figure(figsize=(16, 9))
+for s in range(len(GLM_weights)):
+    plt.fill_between(range(till_session - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_17')
+plt.show()
+
+# Ground truth for later comparison with fitted models (weights back in original order).
+truth = {'state_posterior': state_posterior, 'weights': list(reversed(GLM_weights)), 'state_map': dict(zip(list(range(8)), list(range(8)))), 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
diff --git a/glm sim mice/Sim_18.py b/glm sim mice/Sim_18.py
new file mode 100644
index 0000000000000000000000000000000000000000..755aa4e4e077c907295985602617566a87e5f9ed
--- /dev/null
+++ b/glm sim mice/Sim_18.py	
@@ -0,0 +1,133 @@
+"""
+Generate data from a simulated mouse-GLM.
+
+This mouse has even more sessions, that's the main test here.
+Also, we use all other tricks in the book:
+states alternate during the session, duration is negative binomial,
+now with new perseveration, and a bit of noise.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import nbinom
+
+
+# Real mouse whose session structure (trial counts, contrast sequences) is reused.
+subject = 'DY_008'
+new_name = 'GLM_Sim_18'
+seed = 19
+
+print(new_name)
+np.random.seed(seed)
+
+# Clone the real mouse's info dict under the simulated mouse's name.
+info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
+assert info_dict['subject'] == subject
+info_dict['subject'] = new_name
+pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
+
+# Only the pre-bias sessions are simulated.
+till_session = info_dict['bias_start']
+from_session = 0
+
+# Contrast grids for plotting psychometric curves (11 points, left to right).
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    """Return the psychometric curve implied by one GLM weight vector."""
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))  # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting
+
+
+# One weight vector per state (two contrast regressors, perseveration, bias).
+GLM_weights = [np.array([-4.1, 5.3, 0., -2.5]),
+               np.array([-4.5, 4.3, 0., -0.7]),
+               np.array([-0.3, 3.2, 0.5, -0.8]),
+               np.array([-2.3, 1.5, 0.3, -0.2]),
+               np.array([-1.9, 0.3, 1.5, 0.2]),
+               np.array([-0.3, 0.5, -0.8, 2.5]),
+               np.array([-1, 1.2, 2.1, 1]),
+               np.array([0.2, 0.3, 2, -0.3]),
+               np.array([-0.1, 0.2, 0., -1.])]
+GLM_weights = list(reversed(GLM_weights))
+# (r, p) parameters of each state's negative-binomial dwell-time distribution.
+neg_bin_params = [(190, 0.2), (75, 0.15), (105, 0.2), (50, 0.18), (120, 0.42), (100, 0.12), (150, 0.43), (150, 0.17), (460, 0.56)]
+# Sanity plot: show each state's psychometric curve before simulating.
+for i, gw in enumerate(GLM_weights):
+    plt.plot(weights_to_pmf(gw), label=i)
+plt.ylim(0, 1)
+plt.legend()
+plt.show()
+# Per-session tuples of states that the mouse cycles through within that session.
+states = [(0,), (0,), (1, 0), (1,), (1,), (0,), (2,), (1, 0), (1, 0), (3, 1), (3, 1), (3, 0), (2,), (1, 3),
+          (2,), (3,), (4,0), (4, 5), (4,0), (4,0), (4,0), (2,), (0, 1), (4, 5), (4, 5), (4, 5), (6, 2), (6, 2), (6, 2), (6, 7), (6, 7),
+          (6, 3, 6), (6, 7), (6, 7), (7,), (7,), (7,), (0,), (8, 5), (8, 5), (6, 7), (8, 5), (8,), (8, 5), (0,), (7,), (7,), (7,),
+          (8,), (7,), (1, 0), (7,), (8, 7), (8,), (4, 5), (8, 7), (6, 7), (8, 7)]
+
+# Map between the integer contrast codes stored in the session files and signed contrasts.
+contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
+num_to_contrast = {v: k for k, v in contrast_to_num.items()}
+
+# Per-session fraction of trials spent in each state.
+state_posterior = np.zeros((till_session - from_session, len(GLM_weights)))
+
+# Exponential filter over the last 5 choices: the perseveration regressor.
+exp_decay, exp_length = 0.3, 5
+exp_filter = np.exp(- exp_decay * np.arange(exp_length))
+exp_filter /= exp_filter.sum()
+exp_filter = np.flip(exp_filter)  # because we don't convolve, we need to flip manually
+
+for k, j in enumerate(range(from_session, till_session)):
+    # Perturb the weights of this session's active states; += mutates the arrays
+    # in place, so the noise accumulates across sessions (the weights drift).
+    for i, w in enumerate(GLM_weights):
+        if i in states[j]:
+            w += np.random.normal(np.zeros(4), 0.03 * np.ones(4))
+    data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
+    side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
+
+    print(data.shape)
+    # Diagnostic for single-state sessions: warn when the sampled dwell time is
+    # unlikely (tail prob < 0.1) to cover the whole session.
+    if len(states[j]) <= 1:
+        if 1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]) < 0.1:
+            print()
+            print(states[j])
+            print(j)
+            print(1 - nbinom.cdf(data.shape[0], *neg_bin_params[states[j][0]]))
+            print()
+
+    contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
+
+    predictors = np.zeros(4)
+    previous_answers = np.zeros(5)  # last 5 choices, coded -1 / +1
+    state_plot = np.zeros(len(GLM_weights))  # per-state trial counts this session
+    count = 0
+    curr_state = states[j][0]
+    # +1 so every state lasts at least one trial.
+    curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+    # Random initial "previous choice" in {-1, +1}; stored as 0 / 2 in the data.
+    previous_answers[-1] = 2 * int(np.random.rand() > 0.5) - 1
+    data[0, 1] = previous_answers[-1] + 1
+    # Rebuild the side_info flag for the new simulated response.
+    # NOTE(review): looks like a correct/reward indicator (response 2 on negative
+    # contrast or response 0 on positive contrast), 0.5 on zero contrast - confirm
+    # the coding against the real side_info files.
+    side_info[0, 1] = (data[0, 1] == 2 and contrasts[0] < 0) or ((data[0, 1] == 0 and contrasts[0] > 0))
+    if contrasts[0] == 0:
+        side_info[0, 1] = 0.5
+
+    state_counter = 0
+
+    for i, c in enumerate(contrasts[1:]):
+        predictors[0] = max(c, 0)       # contrast magnitude on the positive side
+        predictors[1] = abs(min(c, 0))  # contrast magnitude on the negative side
+        predictors[2] = np.sum(previous_answers * exp_filter)  # perseveration
+        predictors[3] = 1               # constant / bias regressor
+        # Bernoulli response (coded 0 / 2) from the logistic of the active state's GLM.
+        data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
+        state_plot[curr_state] += 1
+        # Keep side_info consistent with the newly simulated response.
+        side_info[i + 1, 1] = (data[i+1, 1] == 2 and c < 0) or ((data[i+1, 1] == 0 and c > 0))
+        if c == 0:
+            side_info[i + 1, 1] = 0.5
+        curr_dur -= 1
+        if curr_dur == 0:
+            # Duration expired: cycle to the session's next listed state.
+            state_counter += 1
+            curr_state = states[j][state_counter % len(states[j])]
+            curr_dur = np.random.negative_binomial(*neg_bin_params[curr_state]) + 1
+
+        # Shift choice history and append the new choice recoded to -1 / +1.
+        previous_answers[:-1] = previous_answers[1:]
+        previous_answers[-1] = data[i+1, 1] - 1
+
+    state_posterior[k] = state_plot / len(data[:, 0])
+    # Save the simulated session under the simulated mouse's name.
+    pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
+    pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
+
+# Visualise per-session state occupancy as stacked bands.
+plt.figure(figsize=(16, 9))
+for s in range(len(GLM_weights)):
+    plt.fill_between(range(till_session - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
+
+plt.savefig('states_18')
+plt.show()
+
+# Ground truth for later comparison with fitted models.
+# NOTE(review): unlike Sim_17, the saved weights stay in the reversed (working)
+# order here - confirm downstream code expects this, or whether a
+# list(reversed(...)) was intended as in Sim_17.
+truth = {'state_posterior': state_posterior, 'weights': GLM_weights, 'state_map': dict(zip(list(range(9)), list(range(9)))), 'durs': neg_bin_params}
+pickle.dump(truth, open("truth_{}.p".format(new_name), "wb"))
diff --git a/glm sim mice/communal_funcs.py b/glm sim mice/communal_funcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c11764cf9280b76169509366cfb47107249824da
--- /dev/null
+++ b/glm sim mice/communal_funcs.py	
@@ -0,0 +1,17 @@
+import numpy as np
+
+# Contrast grids used to evaluate PMFs: 11 points from one full-contrast side to the other.
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=None):
+    """Turn a GLM weight vector into a psychometric function over the 11 contrasts.
+
+    weights: 1-d weight vector of length 3, 4 or 5 (bias last), or an array with
+        first dimension 11 that is already a PMF, whose first column is returned
+        unchanged.
+    with_bias: deprecated - the bias is now always included; passing anything
+        non-None aborts the program.
+    """
+    if with_bias is not None:
+        # NOTE(review): print + quit() in a library helper; raising ValueError
+        # would be friendlier to callers - confirm nothing relies on SystemExit.
+        print("Bias is now always in PMF")
+        quit()
+    if weights.shape[0] == 3 or weights.shape[0] == 4 or weights.shape[0] == 5:
+        # NOTE(review): sign/side convention (contrasts_L first, -psi) differs from
+        # the inline copies in Sim_17/Sim_18 (contrasts_R first, +psi) - confirm.
+        psi = weights[0] * contrasts_L + weights[1] * contrasts_R + weights[-1]
+        return 1 / (1 + np.exp(-psi))
+    elif weights.shape[0] == 11:
+        # Already a PMF evaluated on the 11 contrasts.
+        return weights[:, 0]
+    else:
+        print('new weight shape')
+        quit()