# NOTE(review): this chunk is a whitespace-mangled paste of several plotting
# fragments — keywords were fused to identifiers and indentation was lost.
# The tokens below are de-mangled back into valid Python and the flat statement
# order of the original is preserved.  Names such as `j`, `i` (outside the loop
# below) and the `x_lim_used_*` limits are not defined in this chunk and
# presumably come from an enclosing loop in the full script — verify the
# intended nesting against the original before running.

# Superseded: the very next subplots() call rebinds the same `f, axs` targets,
# so this first grid (fixed 3 columns) was dead code.  Kept for reference.
# f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))  # I added a third pseudo weight for distance between the two sensitivities, and do another extra to split the bias

# We add another weight slot to split the bias.
f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations,
                      n_types - sudden_changes,
                      figsize=(4 * (3 - sudden_changes), 9))


def _hist_edge(row, pair):
    """Largest x-extent of the histogram pair in `row`: the left limit of the
    left axis vs the right limit of the right axis.

    NOTE(review): mixing the *left* limit of axis `pair*2` with the *right*
    limit of axis `pair*2 + 1` looks asymmetric — confirm this is intended.
    `get_xlim()` replaces the original's no-argument `set_xlim()` calls; both
    return the current limits, the getter just states the intent.
    """
    return max(axs[row, pair * 2].get_xlim()[0],
               axs[row, pair * 2 + 1].get_xlim()[1])


for state_type, weight_gaps in enumerate(data):
    for gap in weight_gaps:
        for i in range(n_weights):
            if i == n_weights - 1:
                # Plot how the weight evolves from first to last appearance.
                # Biases (last weight index) are split into two rows by the
                # sign of their value at the first-plus-one timepoint.
                axs[i + (gap[1][i] > 0), state_type].plot(
                    [0, 1], [gap[0][i], gap[-1][i]], marker="o")
                # NOTE(review): the original contained this identical call a
                # second time, commented "plot how the weight of the new state
                # differs from the previous closest weight" — one of the two
                # lines likely belongs in an `else:` branch; confirm against
                # the original script.
                axs[i + (gap[1][i] > 0), state_type].plot(
                    [0, 1], [gap[0][i], gap[-1][i]], marker="o")

# Sanity checks that no histogram is cut off by its axis limits.
# NOTE(review): `assert` statements are stripped under `python -O`; raise an
# exception instead if these checks must survive optimized runs.
assert x_lim_used_normal > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_normal, _hist_edge(j, i))
assert x_lim_used_bias > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_bias, _hist_edge(j, i))
assert x_lim_used_augment > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_augment, _hist_edge(j, i))
assert x_lim_used_full > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_full, _hist_edge(j, i))
assert x_lim_used_half > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_half, _hist_edge(j, i))

# axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
# axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again

# Histogram bin edges, one set per panel type.  TODO: make show_augmentation robust.
bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30),
            np.linspace(-3.5, 3.5, 30), np.linspace(-0.2, 1, 30)]

if j < n_weights - 1:
    assert x_lim_used_full > _hist_edge(j, i), \
        "Hists are cut off ({} vs {})".format(x_lim_used_full, _hist_edge(j, i))
    assert x_lim_used_half > _hist_edge(j, i), \
        "Hists are cut off ({} vs {})".format(x_lim_used_half, _hist_edge(j, i))

average, counter, all_datapoints = plot_traces_and_collate_data(
    data=all_sudden_transition_changes,
    augmented_data=aug_all_sudden_transition_changes,
    title="sudden_weight_change at transitions")

if i == 0:
    # Label only the leftmost column.
    axs[j, i].set_ylabel(ylabels[j])
# Hide tick labels on the right-hand histogram of each pair.
# NOTE(review): placed outside the `if i == 0` guard — confirm against the
# original's (lost) indentation.
axs[j, i * 2 + 1].yaxis.set_ticklabels([])

# if j < n_weights - 1:
# axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
# axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again

# The original repeated each of these two asserts twice verbatim; the exact
# duplicates were dropped as pure redundancy.
assert x_lim_used_full > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_full, _hist_edge(j, i))
assert x_lim_used_half > _hist_edge(j, i), \
    "Hists are cut off ({} vs {})".format(x_lim_used_half, _hist_edge(j, i))