From dfe77870fe0bbc0557a23c7b53d1f336f20c4081 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Fri, 23 Jun 2023 18:16:11 +0200
Subject: [PATCH] updated dynamic glm code

---
 pybasicbayes/distributions/dynamic_glm.py | 94 +++++++++++------------
 1 file changed, 43 insertions(+), 51 deletions(-)

diff --git a/pybasicbayes/distributions/dynamic_glm.py b/pybasicbayes/distributions/dynamic_glm.py
index 3e1a38b..1a538a7 100644
--- a/pybasicbayes/distributions/dynamic_glm.py
+++ b/pybasicbayes/distributions/dynamic_glm.py
@@ -1,20 +1,18 @@
 from __future__ import division
 from builtins import zip
 from builtins import range
-__all__ = ['Dynamic_GLM']
-
-from pybasicbayes.abstractions import \
-    GibbsSampling
-
+from pybasicbayes.abstractions import GibbsSampling
 import numpy as np
 from warnings import warn
 from pypolyagamma import PyPolyaGamma
 
+__all__ = ['Dynamic_GLM']
 
 def local_multivariate_normal_draw(x, sigma, normal):
     """
+    Function to combine pre-drawn Normals (normal) with the desired mean x and variance sigma
     Cholesky doesn't like 0 cov matrix, but we want it.
-    TODO: This might need changing if in practice we see plently of 0 matrices
+    This might need changing if in practice we see plenty of 0 matrices
     """
     try:
         return x + np.linalg.cholesky(sigma).dot(normal)
@@ -30,46 +28,51 @@ ppgsseed = 4
 if ppgsseed == 4:
     print("Using default seed")
 ppgs = PyPolyaGamma(ppgsseed)
+
+
 class Dynamic_GLM(GibbsSampling):
     """
     This class enables a drifting input output iHMM with logistic link function.
 
     States are thus dynamic GLMs, giving us more freedom as to the inputs we give the model.
 
-    Hyperparaemters:
-
-        TODO
+    Hyperparameters:
 
-    Parameters:
-        [weights]
+        n_regressors: number of regressors for the GLM
+        T: number of timesteps (sessions)
+        prior_mean: mean of regressors at the beginning (usually 0 vector)
+        P_0: variance of regressors at the beginning
+        Q: variance of regressors between timesteps (can be different across steps, but we use the same matrix throughout)
+        jumplimit: for how many timesteps after last being used are the state weights allowed to change
     """
 
-    def __init__(self, n_inputs, T, prior_mean, P_0, Q, jumplimit=3, seed=4):
+    def __init__(self, n_regressors, T, prior_mean, P_0, Q, jumplimit=1):
 
-        self.n_inputs = n_inputs
+        self.n_regressors = n_regressors
         self.T = T
         self.jumplimit = jumplimit
         self.x_0 = prior_mean
         self.P_0, self.Q = P_0, Q
-        self.psi_diff_saves = []
-        self.noise_mean = np.zeros(self.n_inputs)  # save this, so as to not keep creating it
-        self.identity = np.eye(self.n_inputs)  # not really needed, but kinda useful for state sampling, mabye delete TODO
+        self.psi_diff_saves = []  # this can be used to resample the variance, but is currently unused
+        self.noise_mean = np.zeros(self.n_regressors)  # save this, so as to not keep creating it
+        self.identity = np.eye(self.n_regressors)  # not really needed, but kinda useful for state sampling
 
-        # if seed == 4:
-        #     print("Using default seed")
-        # self.ppgs = PyPolyaGamma(seed)
-        self.weights = np.empty((self.T, self.n_inputs))  # one more spot for bias
+        self.weights = np.empty((self.T, self.n_regressors))
         self.weights[0] = np.random.multivariate_normal(mean=self.x_0, cov=self.P_0)
         for t in range(1, T):
             self.weights[t] = self.weights[t - 1] + np.random.multivariate_normal(mean=self.noise_mean, cov=self.Q[t - 1])
 
     def rvs(self, inputs, times):
+        """Given the input features and their time points, create responses from the dynamic GLM weights."""
         outputs = []
         for input, t in zip(inputs, times):
             if input.shape[0] == 0:
-                output = np.zeros((0, self.n_inputs + 1))
+                output = np.zeros((0, 1))
             else:
-                types, inverses, counts = np.unique(input, return_inverse=1, return_counts=True, axis=0)
+                # find the distinct sets of features, how often they exist, and how to put the answers back in place
+                types, inverses, counts = np.unique(input, return_inverse=True, return_counts=True, axis=0)
+
+                # draw responses
                 output = np.append(input, np.empty((input.shape[0], 1)), axis=1)
                 for i, (type, c) in enumerate(zip(types, counts)):
                     temp = np.random.rand(c) < 1 / (1 + np.exp(- np.sum(self.weights[t] * type)))
@@ -125,15 +128,7 @@ class Dynamic_GLM(GibbsSampling):
             timepoint_map[t] = total_types - 1
             prev_t = t
 
-        # print(total_types)
-        # print(actual_obs_count)
-        # print(change_points)
-        # print(fake_times)
-        # print(all_times)
-        # print(timepoint_map)
-        # return timepoint_map
-
-        self.pseudo_Q = np.zeros((total_types, self.n_inputs, self.n_inputs))
+        self.pseudo_Q = np.zeros((total_types, self.n_regressors, self.n_regressors))
         # TODO: is it okay to cut off last timepoint here?
         for k in range(self.T):
             if k in timepoint_map:
@@ -158,8 +153,8 @@ class Dynamic_GLM(GibbsSampling):
         self.pseudo_obs = np.zeros(total_types)
         self.pseudo_obs[mask] = np.concatenate(pseudo_counts) / temp
         self.pseudo_obs = self.pseudo_obs.reshape(total_types, 1)
-        self.H = np.zeros((total_types, self.n_inputs, 1))
-        self.H[mask] = np.array(predictors).reshape(actual_obs_count, self.n_inputs, 1)
+        self.H = np.zeros((total_types, self.n_regressors, 1))
+        self.H[mask] = np.array(predictors).reshape(actual_obs_count, self.n_regressors, 1)
 
         """compute means and sigmas by filtering"""
         # if there is no obs, sigma_k = sigma_k_k_minus and x_hat_k = x_hat_k_k_minus (because R is infinite at that time)
@@ -168,10 +163,10 @@ class Dynamic_GLM(GibbsSampling):
 
         """sample states"""
         self.weights.fill(0)
-        pseudo_weights = np.empty((total_types, self.n_inputs))
+        pseudo_weights = np.empty((total_types, self.n_regressors))
         pseudo_weights[total_types - 1] = np.random.multivariate_normal(self.x_hat_k[total_types - 1], self.sigma_k[total_types - 1])
 
-        normals = np.random.standard_normal((total_types - 1, self.n_inputs))
+        normals = np.random.standard_normal((total_types - 1, self.n_regressors))
         for k in range(total_types - 2, -1, -1):  # normally -1, but we already did first sampling
             if np.all(self.pseudo_Q[k] == 0):
                 pseudo_weights[k] = pseudo_weights[k + 1]
@@ -179,7 +174,7 @@ class Dynamic_GLM(GibbsSampling):
                 updated_x = self.x_hat_k[k].copy()  # not sure whether copy is necessary here
                 updated_sigma = self.sigma_k[k].copy()
 
-                for m in range(self.n_inputs):
+                for m in range(self.n_regressors):
                     epsilon = pseudo_weights[k + 1, m] - updated_x[m]
                     state_R = updated_sigma[m, m] + self.pseudo_Q[k, m, m]
 
@@ -192,7 +187,7 @@ class Dynamic_GLM(GibbsSampling):
             if k in timepoint_map:
                 self.weights[k] = pseudo_weights[timepoint_map[k]]
 
-        """don't forget to sample before and after active times too"""
+        """Sample before and after active times too"""
         for k in range(all_times[0] - 1, -1, -1):
             if k > all_times[0] - self.jumplimit - 1:
                 self.weights[k] = self.weights[k + 1] + np.random.multivariate_normal(self.noise_mean, self.Q[k])
@@ -205,7 +200,7 @@ class Dynamic_GLM(GibbsSampling):
                 self.weights[k] = self.weights[k - 1]
 
         return pseudo_weights
-        # TODO:
+        # If one wants to resample variance...
         # self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
 
     def _get_statistics(self, data):
@@ -215,9 +210,6 @@ class Dynamic_GLM(GibbsSampling):
         if isinstance(data, np.ndarray):
             warn('What you are trying is probably stupid, at least the code is not implemented')
             quit()
-            # assert len(data.shape) == 2
-            # for d in data:
-            #     counts[tuple(d)] += 1
         else:
             for i, d in enumerate(data):
                 clean_d = d[~np.isnan(d[:, -1])]
@@ -255,25 +247,25 @@ class Dynamic_GLM(GibbsSampling):
             self.sigma_k_k_minus.append(self.sigma_k[k] + self.pseudo_Q[k])
 
     def compute_means(self, T):
-        """Compute the means, the estimates of the states."""
-        self.x_hat_k = []  # we have to reset this for repeating this calculation later for the resampling
-        self.x_hat_k_k_minus = [self.x_0]
+        """Compute the means, the estimates of the states.
+        Used to also contain self.x_hat_k_k_minus, but it's not necessary for our setup"""
+        self.x_hat_k = [self.x_0]  # we have to reset this for repeating this calculation later for the resampling
         for k in range(T):  # this will leave out last state which doesn't have observation
             if self.gain_save[k] is None:
-                self.x_hat_k.append(self.x_hat_k_k_minus[k])
-                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: still no purpose
+                self.x_hat_k.append(self.x_hat_k[k])
             else:
-                x, H = self.x_hat_k_k_minus[k], self.H[k]  # we will need this a lot, so shorten it
+                x, H = self.x_hat_k[k], self.H[k]  # we will need this a lot, so shorten it
                 self.x_hat_k.append(x + self.gain_save[k].dot(self.pseudo_obs[k] - H.T.dot(x)))
-                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: doesn't really have a purpose if F is identity
+
+        self.x_hat_k.pop(0)  # remove initialisation element from list
 
     def num_parameters(self):
         return self.weights.size
 
     ### Max likelihood
-
-    def max_likelihood(self,data,weights=None):
+    def max_likelihood(self, data, weights=None):
         warn('ML not implemented')
 
-    def MAP(self,data,weights=None):
+
+    def MAP(self, data, weights=None):
         warn('MAP not implemented')
-- 
GitLab
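Editor's note on usage: to make the new constructor signature and the documented hyperparameters concrete, here is a minimal sketch of how the pieces fit together. Everything in it is illustrative rather than part of the patch: the dimensions, the drift covariance, the import path, and the assumption that `rvs` returns the list of per-session arrays are example choices.

```python
import numpy as np
from pybasicbayes.distributions.dynamic_glm import Dynamic_GLM  # import path assumed from the file location

n_regressors, T = 3, 10  # made-up sizes: 3 regressors, 10 sessions

prior_mean = np.zeros(n_regressors)                  # weights start around 0
P_0 = np.eye(n_regressors)                           # initial weight variance
Q = np.tile(0.05 * np.eye(n_regressors), (T, 1, 1))  # same drift covariance for every timestep

glm = Dynamic_GLM(n_regressors, T, prior_mean, P_0, Q, jumplimit=1)

# rvs takes one design matrix per requested session and appends a column
# of Bernoulli responses drawn through the logistic link.
inputs = [np.random.randn(20, n_regressors) for _ in range(3)]
times = [0, 4, 9]
outputs = glm.rvs(inputs, times)
```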
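The response draw inside `rvs` is a plain logistic-link Bernoulli: for each distinct feature row (`type`) that occurs `c` times in a session, the success probability comes from the dot product with that session's weights. A standalone sketch of that single expression, with made-up numbers:

```python
import numpy as np

w = np.array([0.5, -1.0, 0.25])  # one session's weights (illustrative values)
x = np.array([1.0, 0.0, 2.0])    # one distinct input row ("type")
c = 5                            # how often this row occurs in the session

# same expression as in rvs: p = logistic(w . x), then c independent Bernoulli draws
p = 1 / (1 + np.exp(-np.sum(w * x)))
responses = np.random.rand(c) < p
```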