diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 9b977369a2029101d52d48de0845217e14b6bdf6..fc7be25aa6fab9564c2d1753081bc62797bfad4c 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -719,7 +719,7 @@ def lapse_sides(test, state_sets, indices): """Compute and plot a lapse differential across sessions. Takes a single mouse and plots (1 - lapse_left) - lapse_right across sessions, with sessions boundaries shown.""" - + def func_init(): return {'lapse_side': np.zeros(test.results[0].n_datapoints) + 10, 'session_bounds': []} def first_for(test, results): @@ -1384,24 +1384,46 @@ len_to_bools = { } +def state_type_durs(states, pmfs): + # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts. + # A type lasts until a more advanced type takes up more than 50% of a session (and cannot return) + # Returns the durations for all the different state types, and an array which holds the state percentages + state_types = np.zeros((3, states.shape[1])) + for s, pmf in zip(states, pmfs): + pmf_counter = -1 + for i in range(states.shape[1]): + if s[i]: + pmf_counter += 1 + state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i] # indexing horror + + durs = (np.where(state_types[1] > 0.5)[0][0], + np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0], + states.shape[1] - np.where(state_types[2] > 0.5)[0][0]) + if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]: + durs = (np.where(state_types[2] > 0.5)[0][0], + 0, + states.shape[1] - np.where(state_types[2] > 0.5)[0][0]) + return durs, state_types + + def state_cluster_interpolation(states, pmfs): + # + # Used to contain a first_type_count variable, which seemed to just count the number of states? pmf_examples = [[], [], []] state_trans = np.zeros((3, 3)) state_types = np.zeros((3, states.shape[1])) - first_type_count = 0 for state, pmf in zip(states, pmfs): sessions = np.where(state)[0] for i, sess_pmf in enumerate(pmf[1]): if i == 0: state_type = pmf_type(sess_pmf[pmf[0]]) - first_type_count += state_type == 0 if i > 0 and state_type != pmf_type(sess_pmf[pmf[0]]): state_trans[state_type, pmf_type(sess_pmf[pmf[0]])] += 1 state_type = pmf_type(sess_pmf[pmf[0]]) pmf_examples[pmf_type(sess_pmf[pmf[0]])].append(sess_pmf[pmf[0]]) state_types[pmf_type(sess_pmf[pmf[0]]), sessions[i]] += state[sessions[i]] - return state_types, state_trans, first_type_count, pmf_examples + return state_types, state_trans, pmf_examples def pmf_type(pmf): @@ -1477,7 +1499,6 @@ if __name__ == "__main__": n_points = 150 state_trajs = np.zeros((3, n_points)) state_trans = np.zeros((3, 3)) - first_type_count = 0 not_yet = True @@ -1496,69 +1517,48 @@ if __name__ == "__main__": mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) - lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) - continue - # - state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) - states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True) + + # lapse differential + # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) + + # training overview + # states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True) # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) - consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - consistencies /= consistencies[0, 0] - contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies) - quit() + # session overview + # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies) - # state_types = np.zeros((3, states.shape[1])) - # for s, pmf in zip(states, pmfs): - # pmf_counter = -1 - # for i in range(states.shape[1]): - # if s[i]: - # pmf_counter += 1 - # state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i] # indexing horror - # - # durs = (np.where(state_types[1] > 0.5)[0][0], - # np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0], - # states.shape[1] - np.where(state_types[2] > 0.5)[0][0]) - # if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]: - # durs = (np.where(state_types[2] > 0.5)[0][0], - # 0, - # states.shape[1] - np.where(state_types[2] > 0.5)[0][0]) - # if durs[0] < 0 or durs[1] < 0 or durs[2] < 0: - # quit() + # duration of different state types (and also percentage of type activities) + # durs, _ = state_type_durs(states, pmfs) # abs_state_durs.append(durs) - # if durs[1] < 0: - # state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=False, separate_pmf=True) # rel_state_durs.append((durs[0] / states.shape[1], durs[1] / states.shape[1], durs[2] / states.shape[1])) - # continue - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True) - # continue - # ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs) - # state_trans += trans - # first_type_count += count - # points = np.linspace(1, test.results[0].n_sessions, n_points) - # - # # plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T) - # for i, r in enumerate(ret): - # state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r) - # # plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)) - # # plt.show() - # continue + ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs) + state_trans += trans + points = np.linspace(1, test.results[0].n_sessions, n_points) + # plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T) + for i, r in enumerate(ret): + state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r) + plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)) + plt.show() + continue - # plt.plot(ret.T, label=[0, 1, 2]) - # plt.legend() - # plt.show() - # continue + plt.plot(ret.T, label=[0, 1, 2]) + plt.legend() + plt.show() + continue - # for i, pmfs in enumerate(pmf_types): - # plt.subplot(1, 3, i + 1) - # plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3) - # for pmf in pmfs: - # plt.plot(len_to_bools[len(pmf)], pmf) - # plt.ylim(0, 1) - # plt.show() - # continue + for i, pmfs in enumerate(pmf_types): + plt.subplot(1, 3, i + 1) + plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3) + for pmf in pmfs: + plt.plot(len_to_bools[len(pmf)], pmf) + plt.ylim(0, 1) + plt.show() + continue # quit() # for pmf in pmfs: