diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc index 34a4499ac8d78ac0139d1f5675718e1ab52cae6d..e1ab5f64794b712fd3d7753bfc894d24535085b8 100644 Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index 8417f61340255c4910e14c7ab0a3f0d48670b60d..d40c9990e4c7cc9a2781a81174008f4cb3e8b339 100644 Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ diff --git a/analysis_pmf.py b/analysis_pmf.py index 6729289c14616f8cbf4beb722a909c8f323c8601..637f06cd68c5139c19f7171c90f5b28c4a86c15a 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -28,32 +28,32 @@ if __name__ == "__main__": state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100 fs = 18 - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 1") - plt.show() - - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 2") - plt.show() - - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 3") - plt.show() + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 1") + # plt.show() + # + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 2") + # plt.show() + # + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 3") + # plt.show() all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb')) all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) @@ -161,37 +161,41 @@ if __name__ == "__main__": lw = 4 # Simplex example pmfs - # state_num = 7 - # defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # state_num = 6 - # defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # - # plt.ylim(0, 1) - # plt.xlim(0, 10) - # plt.yticks([]) - # plt.xticks([]) - # sns.despine() - # plt.tight_layout() - # plt.savefig("example type 1") - # plt.show() - # - # state_num = 1 - # defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # state_num = 3 - # defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # - # plt.ylim(0, 1) - # plt.xlim(0, 10) - # plt.yticks([]) - # plt.xticks([]) - # sns.despine() - # plt.tight_layout() - # plt.savefig("example type 2") - # plt.show() + state_num = 7 + defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + state_num = 6 + defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) + sns.despine() + plt.tight_layout() + plt.savefig("example type 1") + plt.show() + + state_num = 1 + defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + state_num = 3 + defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) + sns.despine() + plt.tight_layout() + plt.savefig("example type 2") + plt.show() state_num = 4 defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1] @@ -199,8 +203,10 @@ if __name__ == "__main__": plt.ylim(0, 1) plt.xlim(0, 10) - plt.yticks([]) - plt.xticks([]) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) sns.despine() plt.tight_layout() plt.savefig("example type 3") diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index f99904a98921b921be4f1a181e106a051eab03da..c61d4c7c659c1360646febd6889a35d401ebdc37 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -526,7 +526,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1") # input() - label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None + label = "State {}".format(len(state_sets) - test.state_mapping[state]) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None # state_c_n_a = c_n_a[relevant_trials - trial_counter] # print(state) @@ -571,7 +571,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur plt.xlabel('Trial', size=28) sns.despine() - # plt.xlim(left=250, right=450) + plt.xlim(left=250, right=450) plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5)) plt.tight_layout() if save: @@ -854,22 +854,25 @@ def state_set_and_plot(test, mode_prefix, subject, fit_type): def state_pmfs(test, trials, indices): - def func_init(): return {'pmfs': [], 'session_js': []} + def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []} def first_for(test, results): results['pmf'] = np.zeros(test.results[0].n_contrasts) + results['pmf_weight'] = np.zeros(4) def second_for(m, j, session_trials, trial_counter, results): states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0] + results['pmf_weight'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0] def end_first_for(results, indices, j, **kwargs): results['pmfs'].append(results['pmf'] / len(indices)) + results['pmf_weights'].append(results['pmf_weight'] / len(indices)) results['session_js'].append(j) results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for) - return results['session_js'], results['pmfs'] + return results['session_js'], results['pmfs'], results['pmf_weights'] def state_weights(test, trials, indices): @@ -1138,6 +1141,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax0.annotate('Bias', (test.results[0].infos['bias_start'] + 1 - 0.5, 0.68), fontsize=22) all_pmfs = [] + all_pmf_weights = [] cmaps = ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds', 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu'] np.random.seed(8) np.random.shuffle(cmaps) @@ -1147,7 +1151,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show for state, trials in enumerate(state_sets): if separate_pmf: n_trials = len(trials) - session_js, pmfs = state_pmfs(test, trials, indices) + session_js, pmfs, _ = state_pmfs(test, trials, indices) else: pmfs = np.zeros((len(indices), test.results[0].n_contrasts)) n_trials = len(trials) @@ -1173,9 +1177,10 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if separate_pmf: n_trials = len(trials) - session_js, pmfs = state_pmfs(test, trials, indices) + session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices) else: pmfs = np.zeros((len(indices), test.results[0].n_contrasts)) + pmf_weights = np.zeros((len(indices), test.results[0].obs_distns[0].weights.shape[0])) n_trials = len(trials) counter = 0 for i, m in enumerate([item for sublist in test.results for item in sublist.models]): @@ -1188,6 +1193,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show states, counts = np.unique(state_seq[session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): pmfs[counter] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / n_trials + pmf_weights[counter] += m.obs_distns[sub_state].weights[j] * c / n_trials trial_counter += len(state_seq) counter += 1 @@ -1218,7 +1224,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if not test.state_mapping[state] in dont_plot: ax1.fill_between([points[k], points[k+1]], test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points)) - ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) + ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) if test.results[0].name.startswith('GLM_Sim_'): ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r') @@ -1235,11 +1241,12 @@ def state_development(test, state_sets, indices, save=True, save_append='', show # defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool) # defined_points[[0, 1, -2, -1]] = True if separate_pmf: - for j, pmf in zip(session_js, pmfs): + for j, pmf, pmf_weight in zip(session_js, pmfs, pmf_weights): if not test.state_mapping[state] in dont_plot: ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions)) ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions)) all_pmfs.append((defined_points, pmfs)) + all_pmf_weights += pmf_weights else: temp = np.percentile(pmfs, [2.5, 97.5], axis=0) if not test.state_mapping[state] in dont_plot: @@ -1313,7 +1320,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax2.set_title('Psychometric\nfunction', size=16) ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20) ax0.set_ylabel('% correct', size=18) - ax2.set_ylabel('Probability', size=26, labelpad=-20) + ax2.set_ylabel('P(rightwards answer)', size=26, labelpad=-20) ax1.set_xlabel('Session', size=28) ax2.set_xlabel('Contrast', size=26) ax1.set_xlim(left=1, right=test.results[0].n_sessions) @@ -1352,7 +1359,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show else: plt.close() - return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type + return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""): """ @@ -1670,7 +1677,6 @@ if __name__ == "__main__": # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # subjects = ['KS014'] - # meh pmfs: KS021 print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' @@ -1709,6 +1715,7 @@ if __name__ == "__main__": regressions = [] regression_diffs = [] all_bias_flips = [] + all_pmf_weights = [] temp_counter = 0 @@ -1744,15 +1751,20 @@ if __name__ == "__main__": # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2) # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7) # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13) - states, pmfs, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - all_state_types.append(state_types) + states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + all_pmf_weights += pmf_weights continue + + consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) + + all_state_types.append(state_types) # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) # temp_counter += 1 - continue # b_flips = bias_flips(states, pmfs, durs) # all_bias_flips.append(b_flips) @@ -1915,6 +1927,8 @@ if __name__ == "__main__": # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb')) # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb')) # pickle.dump(all_state_types, open("all_state_types.p", 'wb')) + pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb')) + quit() if True: diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..00611bcc7e1200cbe4a76a79ebde4510408bea3b --- /dev/null +++ b/pmf_weight_analysis.py @@ -0,0 +1,91 @@ +import numpy as np +import matplotlib.pyplot as plt +import pickle +from scipy.stats import gaussian_kde +from analysis_pmf import pmf_type, type2color + +# all pmf weights +apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb'))) + +xy = np.vstack([apw[:, i] for i in range(4)]) +z = gaussian_kde(xy)(xy) + +plt.subplot(1, 3, 1) +plt.scatter(apw[:, 0], apw[:, 1], c=z) +plt.xlabel("Cont right") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 2) +plt.scatter(apw[:, 3], apw[:, 1], c=z) +plt.xlabel("Bias") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 3) +plt.scatter(apw[:, 3], apw[:, 0], c=z) +plt.xlabel("Bias") +plt.ylabel("Cont right") + +plt.show() + + +contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) +contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] + +def weights_to_pmf(weights, with_bias=1): + psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] + return 1 / (1 + np.exp(psi)) + +colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw] + +plt.subplot(1, 3, 1) +sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors) +fig, ax = plt.gcf(), plt.gca() +plt.xlabel("Cont right") +plt.ylabel("Cont left") + +annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) +annot.set_visible(False) + +def update_annot(ind): + print(ind) + pos = sc.get_offsets()[ind["ind"][0]] + annot.xy = pos + text = "{}".format(np.round(apw[ind["ind"][0]], 2)) + annot.set_text(text) + +def hover(event): + vis = annot.get_visible() + if event.inaxes == ax: + cont, ind = sc.contains(event) + if cont: + update_annot(ind) + annot.set_visible(True) + fig.canvas.draw_idle() + else: + if vis: + annot.set_visible(False) + fig.canvas.draw_idle() + +fig.canvas.mpl_connect("motion_notify_event", hover) + +plt.subplot(1, 3, 2) +plt.scatter(apw[:, 3], apw[:, 1], c=colors) +plt.xlabel("Bias") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 3) +plt.scatter(apw[:, 3], apw[:, 0], c=colors) +plt.xlabel("Bias") +plt.ylabel("Cont right") + +plt.show() + + +from mpl_toolkits import mplot3d + +fig = plt.figure() +ax = plt.axes(projection='3d') +ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors) +plt.show() diff --git a/simplex_animation.py b/simplex_animation.py index 2dd78ee1c91cd03961b826c39d3ba9604e06a08b..89dfe945c0994f472c74602d89541d12ac1d5fb1 100644 --- a/simplex_animation.py +++ b/simplex_animation.py @@ -38,6 +38,7 @@ assert (test_count == 1).all() # quit() session_counter = 0 +alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj'] # do as many sessions as it takes while True: @@ -64,9 +65,9 @@ while True: type_proportions[:, i] = sts[:, temp_counter] - plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title=None) - plt.title(session_counter) - plt.show() + plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter), + vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter])) + plt.close() if not_ended == 0: break diff --git a/simplex_plot.py b/simplex_plot.py index 62bf36276da1b9f0f199687e968371d0d6673a46..635885d05417cb2c7810ec74e6d5cb5572e6f0b7 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -15,7 +15,7 @@ import matplotlib.patches as PA def plotSimplex(points, fig=None, vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], - show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, save_title="dur_simplex.png", **kwargs): + save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): """ Plot Nx3 points array on the 3-simplex (with optionally labeled vertices) @@ -32,17 +32,17 @@ def plotSimplex(points, fig=None, fig.gca().xaxis.set_major_locator(MT.NullLocator()) fig.gca().yaxis.set_major_locator(MT.NullLocator()) # Draw vertex labels - fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False) - fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False) - fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False) + # fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False) + # fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False) + # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False) # Project and draw the actual points projected = projectSimplex(points / points.sum(1)[:, None]) - # print(projected) - P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=35, **kwargs)#s=points.sum(1) * 3.5 + print(projected) + P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs) # plot center with average size projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3)) - P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=50)#np.mean(points.sum(1)) * 3.5) + P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5) # Leave some buffer around the triangle for vertex labels fig.gca().set_xlim(-0.05, 1.05) @@ -51,10 +51,11 @@ def plotSimplex(points, fig=None, P.axis('off') P.tight_layout() - if save_title: - P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True) + P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True) if show: P.show() + else: + P.close() def projectSimplex(points):