diff --git a/pybasicbayes/distributions/dynamic_glm.py b/pybasicbayes/distributions/dynamic_glm.py
index 1a538a7e7d16ddbba219ae6349610a47796a6d7e..117a210a75310deadb42b5b52509b4445a16b68a 100644
--- a/pybasicbayes/distributions/dynamic_glm.py
+++ b/pybasicbayes/distributions/dynamic_glm.py
@@ -7,12 +7,12 @@
 from warnings import warn
 from pypolyagamma import PyPolyaGamma
 __all__ = ['Dynamic_GLM']
+
 def local_multivariate_normal_draw(x, sigma, normal):
-    """
-    Function to combine pre-drawn Normals (normal) with the desired mean x and variance sigma
-    Cholesky doesn't like 0 cov matrix, but we want it.
+    """Combine pre-drawn standard Normals (normal) with the desired mean x and covariance sigma.
 
-    This might need changing if in practice we see plently of 0 matrices
+    Cholesky doesn't like a 0 covariance matrix, but we want to allow it.
+    This may need revisiting if in practice we see plenty of 0 matrices.
     """
     try:
         return x + np.linalg.cholesky(sigma).dot(normal)
@@ -41,7 +41,7 @@ class Dynamic_GLM(GibbsSampling):
     n_regressors: number of regressors for the GLM
     T: number of timesteps (sessions)
     prior_mean: mean of regressors at the beginning (usually 0 vector)
-    P_0: variance of regressors at the beginning
+    P_0: variance of regressors at the beginning (vague prior -> diagonal matrix with large variances)
     Q: variance of regressors between timesteps (can be different across steps, but we use the same matrix throughout)
     jumplimit: for how many timesteps after last being used are the state weights allowed to change
     """
@@ -58,12 +58,16 @@ class Dynamic_GLM(GibbsSampling):
         self.identity = np.eye(self.n_regressors)  # not really needed, but kinda useful for state sampling
 
         self.weights = np.empty((self.T, self.n_regressors))
-        self.weights[0] = np.random.multivariate_normal(mean=self.x_0, cov=self.P_0)
-        for t in range(1, T):
+        self.weights[0] = np.random.multivariate_normal(mean=self.x_0, cov=self.P_0)  # initialise the first weights from the prior...
+        for t in range(1, T):  # ... then let them evolve as a random walk
             self.weights[t] = self.weights[t - 1] + np.random.multivariate_normal(mean=self.noise_mean, cov=self.Q[t - 1])
 
     def rvs(self, inputs, times):
-        """Given the input features and their time points, create responses from the dynamic GLM weights."""
+        """
+        Given the input features and their time points, create responses from the dynamic GLM weights for each trial.
+
+        This is for generative test purposes.
+        """
         outputs = []
         for input, t in zip(inputs, times):
             if input.shape[0] == 0:
@@ -71,12 +75,11 @@ class Dynamic_GLM(GibbsSampling):
             else:
                 # find the distinct sets of features, how often they exist, and how to put the answers back in place
                 types, inverses, counts = np.unique(input, return_inverse=True, return_counts=True, axis=0)
-                # draw responses
                 output = np.append(input, np.empty((input.shape[0], 1)), axis=1)
                 for i, (type, c) in enumerate(zip(types, counts)):
                     temp = np.random.rand(c) < 1 / (1 + np.exp(- np.sum(self.weights[t] * type)))
-                    output[inverses == i, -1] = temp
+                    output[inverses == i, -1] = temp  # put the responses back in the right places
                 outputs.append(output)
 
         return outputs
 
@@ -88,10 +91,9 @@ class Dynamic_GLM(GibbsSampling):
         # I could possibly save the 1 / ..., since it's logged it's just - log (but the other half of the probs is an issue)
         probs[:, 1] = 1 / (1 + np.exp(- np.sum(self.weights[timepoint] * predictors, axis=1)))
         probs[:, 0] = 1 - probs[:, 1]
-        # probably not necessary, just fill everything with probs and then have some be 1 - out?
         out[~nans] = probs[np.arange(input.shape[0])[~nans], responses[~nans].astype(int)]
-        out = np.clip(out, np.spacing(1), 1 - np.spacing(1))
-        out[nans] = 1
+        out = np.clip(out, np.spacing(1), 1 - np.spacing(1))  # having an answer be exactly impossible is not good, clip so every probability stays inside (0, 1)
+        out[nans] = 1  # nans mark held-out cross-validation trials, every state generates them with prob. 1
 
         return np.log(out)
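Note on the `local_multivariate_normal_draw` docstring above: `np.linalg.cholesky` raises a `LinAlgError` on the all-zero matrix (it is only positive semi-definite), yet a zero covariance is a legitimate request here, since it just means a deterministic draw at the mean. Below is a minimal sketch of the behaviour the docstring describes; the fallback branch and the function name are assumptions for illustration, as the diff only shows the `try`:

```python
import numpy as np

def draw_with_zero_cov_fallback(x, sigma, normal):
    """Sketch: combine a pre-drawn standard Normal vector `normal`
    with the desired mean `x` and covariance `sigma`."""
    try:
        # L @ L.T == sigma, so x + L @ normal is an N(x, sigma) sample
        return x + np.linalg.cholesky(sigma).dot(normal)
    except np.linalg.LinAlgError:
        # Hypothetical fallback (not shown in the diff): a zero covariance
        # simply means a deterministic "draw" at the mean.
        if np.allclose(sigma, 0):
            return x
        raise

rng = np.random.default_rng(0)
z = rng.standard_normal(3)
print(draw_with_zero_cov_fallback(np.zeros(3), np.eye(3), z))        # ordinary draw
print(draw_with_zero_cov_fallback(np.ones(3), np.zeros((3, 3)), z))  # returns the mean
```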
@@ -204,7 +206,13 @@ class Dynamic_GLM(GibbsSampling):
         # self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
 
     def _get_statistics(self, data):
-        # TODO: improve
+        """
+        Take the data assigned to one state and collect the relevant statistics.
+
+        For every session in which the state is active, we want to know:
+        what types of predictors it encountered (types), and
+        how it responded to them (pseudo_counts, already transformed for the sampling scheme).
+        """
         summary_statistics = [[], [], []]
         times = []
         if isinstance(data, np.ndarray):
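To make the new `_get_statistics` docstring concrete, here is a hedged sketch of the per-session statistics it describes. The function name and the exact transform are assumptions, since the diff shows only the docstring; with `PyPolyaGamma` imported, the usual transformed response for grouped Bernoulli data is kappa = successes - trials / 2 (Polson-Scott-Windle augmentation), which may or may not match this code's `pseudo_counts` exactly:

```python
import numpy as np

def session_statistics(session_data):
    """Hypothetical per-session statistics for one state, mirroring the docstring:
    unique predictor rows (types) and a transformed response per row."""
    predictors, responses = session_data[:, :-1], session_data[:, -1]
    # same np.unique pattern as in rvs above
    types, inverses, counts = np.unique(predictors, return_inverse=True,
                                        return_counts=True, axis=0)
    successes = np.array([responses[inverses == i].sum() for i in range(len(types))])
    pseudo_counts = successes - counts / 2  # kappa, assumed Polya-Gamma transform
    return types, counts, pseudo_counts

# three trials, two of which share the same predictor vector
data = np.array([[1., 0., 1.],
                 [1., 0., 0.],
                 [0., 1., 1.]])
types, counts, kappa = session_statistics(data)
print(types)   # unique predictor rows encountered in the session
print(counts)  # how often each row occurred
print(kappa)   # transformed responses, one per row
```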