diff --git a/test_codes/pymc_compare/data_generation.py b/test_codes/pymc_compare/data_generation.py index 5c63b396bfe261abd26c0fd22b6dcb867ae408ed..077f61fbffd1f7ff9105a4c6146e213c60fc94f7 100644 --- a/test_codes/pymc_compare/data_generation.py +++ b/test_codes/pymc_compare/data_generation.py @@ -24,7 +24,7 @@ T = 12 n_inputs = 3 step_size = 0.03 Q = np.tile(np.eye(n_inputs), (T, 1, 1)) -test = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size) +test = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs)) w = np.zeros(n_inputs) # w = np.array([4.23061493, 2.14425199, -2.1125851]) # test.weights = w.reshape(T, n_inputs) @@ -45,7 +45,8 @@ for _ in range(T): sample = test.rvs(predictors, list(range(T))) pickle.dump(sample, open('test_data', 'wb')) -learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size) +learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs)) + def wrapper(w, t): learn.weights = np.tile(w, (T, 1)) @@ -71,34 +72,3 @@ for t in range(T): LL_weights[t] = minimize(f, np.zeros(n_inputs)).x pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb')) - -plt.figure(figsize=(16, 9)) -for i in range(n_inputs): - plt.subplot(n_inputs, 1, i+1) - label = 'Truth' if i == 0 else None - plt.plot(np.arange(T), test.weights[:, i], label=label) - label = 'LL' if i == 0 else None - plt.plot(np.arange(T), LL_weights[:, i], label=label) - sample_mean = np.mean(samples[:, :, i], axis=0) - label = 'Posterior mean' if i == 0 else None - plt.plot(np.arange(T), sample_mean, label=label, c='g') - credible_interval = np.percentile(samples[:, :, i], [2.5, 97.5], axis=0) - plt.fill_between(np.arange(T), credible_interval[1], credible_interval[0], alpha=0.2, color='g') - - print(i) - print(sample_mean) - print(credible_interval[1], credible_interval[0]) - - label = 'pymc mean' if i == 0 else None - plt.plot(np.arange(T), m[:, i], label=label, c='r') - plt.fill_between(np.arange(T), u[:, i], low[:, i], alpha=0.2, color='r') - - sns.despine() - if i == 0: - plt.legend(fontsize=18, frameon=False) - plt.xlim(left=0, right=T) - # for t in takefrom: - # plt.axvline(t) -plt.tight_layout() -plt.savefig("timepoint test") -plt.show() diff --git a/test_codes/pymc_compare/dynglm_optimisation_test.py b/test_codes/pymc_compare/dynglm_optimisation_test.py index 915400ac3d264aedc28e809a02a88c525056fd68..05df6e0cd7e1e02aa462861346bcc1e721d82e8f 100644 --- a/test_codes/pymc_compare/dynglm_optimisation_test.py +++ b/test_codes/pymc_compare/dynglm_optimisation_test.py @@ -22,7 +22,7 @@ T = 16 n_inputs = 3 step_size = 0.2 Q = np.tile(np.eye(n_inputs), (T, 1, 1)) -test = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size) +test = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs)) w = np.zeros(n_inputs) # w = np.array([4.23061493, 2.14425199, -2.1125851]) # test.weights = w.reshape(T, n_inputs) diff --git a/test_codes/pymc_compare/gibbs_posterior b/test_codes/pymc_compare/gibbs_posterior index 383917f8fe897dcb0afcb2ff3bc4529ffbcc771e..9956f9734279c1c49ddf940543147afd0c02cdc8 100644 Binary files a/test_codes/pymc_compare/gibbs_posterior and b/test_codes/pymc_compare/gibbs_posterior differ diff --git a/test_codes/pymc_compare/gibbs_sample.py b/test_codes/pymc_compare/gibbs_sample.py index 44f5c025c4bf346b6f3a455c734a74a526fd673e..ba07e3766f0f410ce0a27ddfe6c0a2a3a2a2eeae 100644 --- a/test_codes/pymc_compare/gibbs_sample.py +++ b/test_codes/pymc_compare/gibbs_sample.py @@ -10,7 +10,7 @@ step_size = 0.2 Q = np.tile(np.eye(n_inputs), (T, 1, 1)) sample = pickle.load(open('test_data', 'rb')) -learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size) +learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs)) def wrapper(w, t): diff --git a/test_codes/pymc_compare/pymc_posterior b/test_codes/pymc_compare/pymc_posterior index 9cd816741f1dd620ea44cde53a1254e7b47033cd..bddfe9db2d9337c1ef82131bb57af421a98b41d1 100644 Binary files a/test_codes/pymc_compare/pymc_posterior and b/test_codes/pymc_compare/pymc_posterior differ diff --git a/test_codes/pymc_compare/test_data b/test_codes/pymc_compare/test_data index 635eb2f04157bb19ba7638c795ef7b1ee15c39f1..0c79466950726e3dd85a57038135ee98cfc63ff4 100644 Binary files a/test_codes/pymc_compare/test_data and b/test_codes/pymc_compare/test_data differ diff --git a/test_codes/pymc_compare/truth b/test_codes/pymc_compare/truth index 82915e67299670118449b9a7788574ebb2bcc1ad..65de7816b3e0374a51ec2315b4f3698f4afd7f38 100644 Binary files a/test_codes/pymc_compare/truth and b/test_codes/pymc_compare/truth differ