From 89990a88a6bda5365d2403aa84b320ae21034081 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Thu, 15 Sep 2022 14:23:11 +0200
Subject: [PATCH] my changes

---
 pybasicbayes/distributions/__init__.py        |   1 +
 pybasicbayes/distributions/dynamic_glm.py     | 279 +++++++++++++++++
 .../distributions/dynamic_multinomial.py      | 286 ++++++++++++++++++
 .../dynamic_multinomial_in_progress.py        | 265 ++++++++++++++++
 pybasicbayes/distributions/gaussian.py        |   3 +-
 pybasicbayes/distributions/multinomial.py     | 176 ++++++++++-
 .../distributions/negativebinomial.py         |   5 +-
 7 files changed, 1009 insertions(+), 6 deletions(-)
 create mode 100644 pybasicbayes/distributions/dynamic_glm.py
 create mode 100644 pybasicbayes/distributions/dynamic_multinomial.py
 create mode 100644 pybasicbayes/distributions/dynamic_multinomial_in_progress.py

diff --git a/pybasicbayes/distributions/__init__.py b/pybasicbayes/distributions/__init__.py
index 2318d13..d640135 100644
--- a/pybasicbayes/distributions/__init__.py
+++ b/pybasicbayes/distributions/__init__.py
@@ -10,3 +10,4 @@ from .multinomial import *
 from .negativebinomial import *
 from .geometric import *
 from .poisson import *
+from .dynamic_glm import *
diff --git a/pybasicbayes/distributions/dynamic_glm.py b/pybasicbayes/distributions/dynamic_glm.py
new file mode 100644
index 0000000..3e1a38b
--- /dev/null
+++ b/pybasicbayes/distributions/dynamic_glm.py
@@ -0,0 +1,279 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_GLM']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+import numpy as np
+from warnings import warn
+from pypolyagamma import PyPolyaGamma
+
+def local_multivariate_normal_draw(x, sigma, normal):
+    """
+    Cholesky decomposition fails on an all-zero covariance matrix, but we want to allow that case.
+
+    TODO: This might need changing if in practice we see plenty of 0 matrices
+    """
+    try:
+        return x + np.linalg.cholesky(sigma).dot(normal)
+    except np.linalg.LinAlgError:
+        if np.isclose(sigma, 0).all():
+            return x
+        else:
+            print("Weird covariance matrix")
+            quit()
+
+
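+# module-level Polya-Gamma sampler, shared by all Dynamic_GLM instances (fixed seed)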
+ppgsseed = 4
+if ppgsseed == 4:
+    print("Using default seed")
+ppgs = PyPolyaGamma(ppgsseed)
+class Dynamic_GLM(GibbsSampling):
+    """
+    This class enables a drifting input-output iHMM with a logistic link function.
+
+    States are thus dynamic GLMs, giving us more freedom as to the inputs we give the model.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights]
+    """
+
+    def __init__(self, n_inputs, T, prior_mean, P_0, Q, jumplimit=3, seed=4):
+
+        self.n_inputs = n_inputs
+        self.T = T
+        self.jumplimit = jumplimit
+        self.x_0 = prior_mean
+        self.P_0, self.Q = P_0, Q
+        self.psi_diff_saves = []
+        self.noise_mean = np.zeros(self.n_inputs)  # save this, so as to not keep creating it
+        self.identity = np.eye(self.n_inputs)  # not really needed, but kinda useful for state sampling, maybe delete TODO
+
+        # if seed == 4:
+        #     print("Using default seed")
+        # self.ppgs = PyPolyaGamma(seed)
+        self.weights = np.empty((self.T, self.n_inputs))  # one more spot for bias
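+        # initialize from the prior: w_0 ~ N(x_0, P_0), then a Gaussian random walk w_t = w_{t-1} + N(0, Q_{t-1})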
+        self.weights[0] = np.random.multivariate_normal(mean=self.x_0, cov=self.P_0)
+        for t in range(1, T):
+            self.weights[t] = self.weights[t - 1] + np.random.multivariate_normal(mean=self.noise_mean, cov=self.Q[t - 1])
+
+    def rvs(self, inputs, times):
+        outputs = []
+        for input, t in zip(inputs, times):
+            if input.shape[0] == 0:
+                output = np.zeros((0, self.n_inputs + 1))
+            else:
+                types, inverses, counts = np.unique(input, return_inverse=True, return_counts=True, axis=0)
+                output = np.append(input, np.empty((input.shape[0], 1)), axis=1)
+                for i, (type, c) in enumerate(zip(types, counts)):
+                    temp = np.random.rand(c) < 1 / (1 + np.exp(- np.sum(self.weights[t] * type)))
+                    output[inverses == i, -1] = temp
+            outputs.append(output)
+        return outputs
+
+    def log_likelihood(self, input, timepoint):
+        predictors, responses = input[:, :-1], input[:, -1]
+        nans = np.isnan(responses)
+        probs = np.zeros((input.shape[0], 2))
+        out = np.zeros(input.shape[0])
+        # we could possibly save the 1 / ..., since after logging it's just -log (but the other half of probs would still be needed)
+        probs[:, 1] = 1 / (1 + np.exp(- np.sum(self.weights[timepoint] * predictors, axis=1)))
+        probs[:, 0] = 1 - probs[:, 1]
+        # probably not necessary, just fill everything with probs and then have some be 1 - out?
+        out[~nans] = probs[np.arange(input.shape[0])[~nans], responses[~nans].astype(int)]
+        out = np.clip(out, np.spacing(1), 1 - np.spacing(1))
+        out[nans] = 1
+
+        return np.log(out)
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        # TODO: Clean up this mess, I always have to call delete_obs_data because of all the saved state!
+        self.psi_diff_saves = []
+        summary_statistics, all_times = self._get_statistics(data)
+        types, pseudo_counts, counts = summary_statistics
+
+        # if state is never active, resample from prior, but without dynamic change
+        if len(counts) == 0:
+            self.weights = np.tile(np.random.multivariate_normal(mean=self.x_0, cov=self.P_0), (self.T, 1))
+            return
+
+        """compute Kalman filter parameters, also sort out which weight vector goes to which timepoint"""
+        timepoint_map = {}
+        total_types = 0
+        actual_obs_count = 0
+        change_points = []
+        prev_t = all_times[0] - 1
+        fake_times = []
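+        # walk through the observed times; for each gap insert at most jumplimit pseudo-timepoints
+        # (recorded in fake_times) and map any further skipped timepoints onto the last inserted one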
+        for type, t in zip(types, all_times):
+            if t > prev_t + 1:
+                add_list = list(range(total_types, min(total_types + t - prev_t - 1, total_types + self.jumplimit)))
+                change_points += add_list
+                fake_times += add_list
+                for i, sub_t in enumerate(range(total_types, total_types + t - prev_t - 1)): # TODO: set up this loop better
+                    timepoint_map[prev_t + i + 1] = min(sub_t, total_types + self.jumplimit - 1)
+                total_types += min(t - prev_t - 1, self.jumplimit)
+            total_types += type.shape[0]
+            actual_obs_count += type.shape[0]
+            change_points.append(total_types - 1)
+            timepoint_map[t] = total_types - 1
+            prev_t = t
+
+        # print(total_types)
+        # print(actual_obs_count)
+        # print(change_points)
+        # print(fake_times)
+        # print(all_times)
+        # print(timepoint_map)
+        # return timepoint_map
+
+        self.pseudo_Q = np.zeros((total_types, self.n_inputs, self.n_inputs))
+        # TODO: is it okay to cut off last timepoint here?
+        for k in range(self.T):
+            if k in timepoint_map:
+                self.pseudo_Q[timepoint_map[k]] = self.Q[k]  # for every timepoint, map its variance onto the pseudo_Q
+
+        """sample pseudo obs"""
+        temp = np.empty(actual_obs_count)
+        psis = np.empty(actual_obs_count)
+        psi_count = 0
+        predictors = []
+        for type, time in zip(types, all_times):
+            for t in type:
+                psis[psi_count] = np.sum(self.weights[time] * t)
+                predictors.append(t)
+                psi_count += 1
+
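+        # pgdrawv fills temp in place with draws temp[i] ~ PolyaGamma(counts[i], psis[i])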
+        ppgs.pgdrawv(np.concatenate(counts).astype(float), psis, temp)
+        self.R = np.zeros(total_types)
+        mask = np.ones(total_types, dtype=bool)
+        mask[fake_times] = False
+        self.R[mask] = 1 / temp
+        self.pseudo_obs = np.zeros(total_types)
+        self.pseudo_obs[mask] = np.concatenate(pseudo_counts) / temp
+        self.pseudo_obs = self.pseudo_obs.reshape(total_types, 1)
+        self.H = np.zeros((total_types, self.n_inputs, 1))
+        self.H[mask] = np.array(predictors).reshape(actual_obs_count, self.n_inputs, 1)
+
+        """compute means and sigmas by filtering"""
+        # if there are no observations, sigma_k = sigma_k_k_minus and x_hat_k = x_hat_k_k_minus (because R is infinite at that time)
+        self.compute_sigmas(total_types)
+        self.compute_means(total_types)
+
+        """sample states"""
+        self.weights.fill(0)
+        pseudo_weights = np.empty((total_types, self.n_inputs))
+        pseudo_weights[total_types - 1] = np.random.multivariate_normal(self.x_hat_k[total_types - 1], self.sigma_k[total_types - 1])
+
+        normals = np.random.standard_normal((total_types - 1, self.n_inputs))
+        for k in range(total_types - 2, -1, -1):  # normally -1, but we already did first sampling
+            if np.all(self.pseudo_Q[k] == 0):
+                pseudo_weights[k] = pseudo_weights[k + 1]
+            else:
+                updated_x = self.x_hat_k[k].copy()  # not sure whether copy is necessary here
+                updated_sigma = self.sigma_k[k].copy()
+
+                for m in range(self.n_inputs):
+                    epsilon = pseudo_weights[k + 1, m] - updated_x[m]
+                    state_R = updated_sigma[m, m] + self.pseudo_Q[k, m, m]
+
+                    updated_x += updated_sigma[:, m] * epsilon / state_R  # I don't think it's important, but we probably need the m-th column here
+                    updated_sigma -= updated_sigma.dot(np.outer(self.identity[m], self.identity[m])).dot(updated_sigma) / state_R
+
+                pseudo_weights[k] = local_multivariate_normal_draw(updated_x, updated_sigma, normals[k])
+
+        for k in range(self.T):
+            if k in timepoint_map:
+                self.weights[k] = pseudo_weights[timepoint_map[k]]
+
+        """don't forget to sample before and after active times too"""
+        for k in range(all_times[0] - 1, -1, -1):
+            if k > all_times[0] - self.jumplimit - 1:
+                self.weights[k] = self.weights[k + 1] + np.random.multivariate_normal(self.noise_mean, self.Q[k])
+            else:
+                self.weights[k] = self.weights[k + 1]
+        for k in range(all_times[-1] + 1, self.T):
+            if k < min(all_times[-1] + 1 + self.jumplimit, self.T):
+                self.weights[k] = self.weights[k - 1] + np.random.multivariate_normal(self.noise_mean, self.Q[k])
+            else:
+                self.weights[k] = self.weights[k - 1]
+
+        return pseudo_weights
+        # TODO:
+        # self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        summary_statistics = [[], [], []]
+        times = []
+        if isinstance(data, np.ndarray):
+            warn('Passing a plain ndarray is not supported; this code path is not implemented')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                clean_d = d[~np.isnan(d[:, -1])]
+                if len(clean_d) != 0:
+                    predictors, responses = clean_d[:, :-1], clean_d[:, -1]
+                    types, inverses, counts = np.unique(predictors, return_inverse=True, return_counts=True, axis=0)
+                    pseudo_counts = np.zeros(len(types))
+                    for j, c in enumerate(counts):
+                        mask = inverses == j
+                        pseudo_counts[j] = np.sum(responses[mask]) - c / 2
+                    summary_statistics[0].append(types)
+                    summary_statistics[1].append(pseudo_counts)
+                    summary_statistics[2].append(counts)
+                    times.append(i)
+
+        return summary_statistics, times
+
+    def compute_sigmas(self, T):
+        """Sigmas can be precomputed (without z), we do this here."""
+        # We rely on the fact that H.T.dot(sigma).dot(H) is just a number, no matrix inversion needed
+        # furthermore we use the fact that many matrices are identities, namely F and G
+        self.sigma_k = []  # we have to reset this for repeating this calculation later for the resampling (R changes)
+        self.sigma_k_k_minus = [self.P_0]
+        self.gain_save = []
+        for k in range(T):
+            if self.R[k] == 0:
+                self.gain_save.append(None)
+                self.sigma_k.append(self.sigma_k_k_minus[k])
+                self.sigma_k_k_minus.append(self.sigma_k[k] + self.pseudo_Q[k])
+            else:
+                sigma, H = self.sigma_k_k_minus[k], self.H[k]  # we will need this a lot, so shorten it
+                gain = sigma.dot(H).dot(1 / (H.T.dot(sigma).dot(H) + self.R[k]))
+                self.gain_save.append(gain)
+                self.sigma_k.append(sigma - gain.dot(H.T).dot(sigma))
+                self.sigma_k_k_minus.append(self.sigma_k[k] + self.pseudo_Q[k])
+
+    def compute_means(self, T):
+        """Compute the means, the estimates of the states."""
+        self.x_hat_k = []  # we have to reset this for repeating this calculation later for the resampling
+        self.x_hat_k_k_minus = [self.x_0]
+        for k in range(T):  # this will leave out last state which doesn't have observation
+            if self.gain_save[k] is None:
+                self.x_hat_k.append(self.x_hat_k_k_minus[k])
+                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: still no purpose
+            else:
+                x, H = self.x_hat_k_k_minus[k], self.H[k]  # we will need this a lot, so shorten it
+                self.x_hat_k.append(x + self.gain_save[k].dot(self.pseudo_obs[k] - H.T.dot(x)))
+                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: doesn't really have a purpose if F is identity
+
+    def num_parameters(self):
+        return self.weights.size
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
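
Aside: resample above implements the Polya-Gamma augmentation (Polson, Scott & Windle, 2013): conditioned on
omega ~ PG(n, psi), the logistic likelihood becomes Gaussian in psi, which is what makes the Kalman recursions in
compute_sigmas / compute_means applicable. A minimal sketch of the same pseudo-data construction, assuming
pypolyagamma is available (the toy numbers are illustrative only):

    import numpy as np
    from pypolyagamma import PyPolyaGamma

    ppg = PyPolyaGamma(0)
    n = np.array([5., 3.])        # trials per unique predictor (cf. counts)
    y = np.array([4., 1.])        # successes per unique predictor
    psi = np.array([0.3, -0.2])   # linear predictors w . x (cf. psis)

    omega = np.empty_like(psi)
    ppg.pgdrawv(n, psi, omega)    # omega[i] ~ PG(n[i], psi[i]), drawn in place

    kappa = y - n / 2             # cf. _get_statistics: sum(responses) - c / 2
    z = kappa / omega             # pseudo-observations (cf. self.pseudo_obs)
    R = 1. / omega                # pseudo-observation variances (cf. self.R)
    # given omega, z[i] ~ N(psi[i], R[i]): the model is linear-Gaussian,
    # so standard Kalman filtering/smoothing over the weights applies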
diff --git a/pybasicbayes/distributions/dynamic_multinomial.py b/pybasicbayes/distributions/dynamic_multinomial.py
new file mode 100644
index 0000000..dbba312
--- /dev/null
+++ b/pybasicbayes/distributions/dynamic_multinomial.py
@@ -0,0 +1,286 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_Input_Categorical']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+from scipy import sparse
+import numpy as np
+from warnings import warn
+import time
+
+from pgmult.lda import StickbreakingDynamicTopicsLDA
+from pgmult.utils import psi_to_pi
+from scipy.stats import invgamma
+
+
+def enforce_limits(times, limit):
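+    # Compress every gap between consecutive times that exceeds `limit` down to exactly
+    # `limit`, and shift the result so the first time maps to 0 (see the asserts below).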
+    times = np.array(times)
+    diffs = np.zeros(len(times), dtype=np.int32)
+    diffs[1:] = np.diff(times)
+    diffs[diffs <= limit] = limit
+    diffs -= limit
+    diffs = np.cumsum(diffs)
+
+    diffs += times[0]
+    times -= diffs
+    return times
+
+
+assert np.array_equal(enforce_limits([0, 1, 2, 3, 4, 5], 3), [0, 1, 2, 3, 4, 5])
+assert np.array_equal(enforce_limits([0, 2, 4, 6, 8, 10, 12, 14, 15], 3), [0, 2, 4, 6, 8, 10, 12, 14, 15])
+assert np.array_equal(enforce_limits([0, 1, 2, 6, 10, 14], 3), [0, 1, 2, 5, 8, 11])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100], 3), [0, 1, 4, 7, 10])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102], 3), [0, 1, 4, 7, 10, 11, 12])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104], 3), [0, 1, 4, 7, 10, 11, 12, 14])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 1, 4, 7, 10, 11, 12, 14, 17])
+assert np.array_equal(enforce_limits([1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 3, 6, 9, 10, 11, 13, 16])
+
+
+def meh_time_info(all_times, sub_times, limit):
+    # Return time stamps for the needed counts, considering jumps
+    # Return list of mappings, to translate from this state's timepoints to the overall timepoints
+    # Return first and last timepoint
+    times = []
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            times.append(i)
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter <= limit:
+                times.append(i)
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return times, maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya1, mya2, mya3, mya4 = [0, 1, 2, 3, 4, 5, 6, 9], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (6, 7), (6, 8), (7, 9)], 0, 9
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya1, mya2, mya3, mya4 = [3], [(0, 3)], 3, 3
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+
+
+def time_info(all_times, sub_times, limit):
+    # Return list of mappings, to translate from this state's timepoints to the overall timepoints
+    # Return first and last timepoint
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter < limit:
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya2, mya3, mya4 = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (5, 6), (5, 7), (5, 8), (6, 9)], 0, 9
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya2, mya3, mya4 = [(0, 3)], 3, 3  # TODO: I don't think it matters whether first answer is [3] or [0]
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+# TODO: more tests
+
+class Dynamic_Input_Categorical(GibbsSampling):
+    """
+    This class enables a drifting input output iHMM.
+
+    Everything is pretty much like in the input Categorical, but here we also
+    allow drift across sessions for the Categoricals. Similar to a dynamic topic
+    model (with one topic).
+
+    We'll have to trick quite a bit:
+    - Instead of truly resampling, we'll instantiate a new dtm for every resampling and resample it a couple of times (a time killer)
+    - We can't initialize from the prior all that easily. We'll just instantiate a constant Dirichlet from the prior;
+      once data is actually there, we'll use the real model
+    - this will also be a bit hard to copy... maybe really just use the package for resampling, otherwise keep all that state in a list of Categoricals?
+    Big problem: how to give a log-likelihood for observations in sessions where this state was previously not present?
+    We don't know what value this thing should take there, it could be anything...
+    -> if there is no data, we simply have to sample from the prior. That is the will of Gibbs.
+    That is: sample from the prior for the first 3 unaccounted sessions, then leave the weights constant
+    But: how to do this backwards in time? -> also Gibbs: sample from the prior for three sessions
+
+    ! The first session must contain data for StickbreakingDynamicTopicsLDA
+
+    Idea: maybe save previously assigned betas (or psi's or whatever directly)
+    then initialize new dtm from there, so as to have to do less resampling
+    (this depends on what gets resampled first, hopefully the auxiliary variables first, then this saving should have an effect)
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    """
+
+    def __init__(self, n_inputs, n_outputs, T, sigmasq_states, jumplimit=3, n_resample=15):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        if self.n_outputs != 2:
+            warn('this requires adaptations in the row swapping code for the dynamic topic model')
+            # quit()
+        self.T = T
+        self.jumplimit = jumplimit
+        self.sigmasq_states = sigmasq_states
+        self.n_resample = n_resample
+        self.psi_diff_saves = []
+
+        single_weights = np.zeros((self.n_inputs, self.n_outputs))
+        for i in range(self.n_inputs):
+            single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+        self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+
+    def rvs(self, input):
+        print("Looks like simple copy from Input_Categorical, not useful, no dynamics")
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x, timepoint):
+        out = np.zeros(x.shape[0], dtype=np.double)
+        nans = np.isnan(x[:, -1])
+        err = np.seterr(divide='ignore')
+        out[~nans] = np.log(self.weights[timepoint])[tuple(x[~nans].T.astype(int))]
+        np.seterr(**err)
+        out[nans] = 0  # missing responses contribute zero log-likelihood (out is already in log space)
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        self.psi_diff_saves = []
+        counts, all_times = self._get_statistics(data)
+
+        # if state is never active, resample from prior
+        if counts.sum() == 0:
+            single_weights = np.zeros((self.n_inputs, self.n_outputs))
+            for i in range(self.n_inputs):
+                single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+            self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+            return
+
+        fake_times = enforce_limits(all_times, self.jumplimit)
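+        # compress long gaps between active sessions so the dtm never has to bridge more than jumplimit steps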
+        self.weights.fill(0)
+
+        for i in range(self.n_inputs):
+            if np.sum(counts[:, i]) == 0:
+                self.weights[:, i] = np.random.dirichlet(np.ones(self.n_outputs))
+            else:
+                temp = np.sum(counts[:, i], axis=1)
+                spec_times = np.where(temp)[0]
+                maps, first_non0, last_non0 = time_info(all_times, spec_times, self.jumplimit)
+                spec_fake_times = fake_times[spec_times]
+                # we shuffle the columns around, so as to have the timeout answer first, for a hopefully more consistent variance estimation
+
+                # dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i][..., [2, 0, 1]]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+                dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+
+                for _ in range(self.n_resample):
+                    dtm.resample()
+
+                # save for resampling sigma
+                self.psi_diff_saves.append(np.diff(dtm.psi, axis=0).ravel())
+
+                # put dtm weights in right places
+                for m in maps:
+                    # shuffle back
+                    # self.weights[m[1], i] = dtm.beta[m[0], :, 0][..., [1, 2, 0]]
+                    self.weights[m[1], i] = dtm.beta[m[0], :, 0]
+
+                sample = dtm.psi[0]
+                for j in range(min(self.jumplimit, first_non0)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    # self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)
+                if first_non0 > self.jumplimit:
+                    # self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)
+
+                sample = dtm.psi[-1]
+                for j in range(min(self.jumplimit, self.T - last_non0 - 1)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    # self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)
+                if self.T - last_non0 - 1 > self.jumplimit:
+                    # self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)
+
+        self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
+        assert np.count_nonzero(np.sum(self.weights, axis=2)) == np.sum(self.weights, axis=2).size
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        counts = []
+        times = []
+        timepoint_count = np.empty((self.n_inputs, self.n_outputs), dtype=int)
+        if isinstance(data, np.ndarray):
+            warn('Passing a plain ndarray is not supported; this code path is not implemented')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                clean_d = (d[~np.isnan(d[:, -1])]).astype(int)
+                if len(clean_d) != 0:
+                    timepoint_count[:] = 0
+                    for subd in clean_d:
+                        timepoint_count[subd[0], subd[1]] += 1
+                    counts.append(timepoint_count.copy())
+                    times.append(i)
+
+        return np.array(counts), times
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
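
Aside: a minimal usage sketch for Dynamic_Input_Categorical, assuming pgmult is available (the toy data are
illustrative only). Each session is an array of (input, output) rows; a NaN output marks a missing observation
and is dropped by _get_statistics:

    import numpy as np

    dist = Dynamic_Input_Categorical(n_inputs=2, n_outputs=2, T=3, sigmasq_states=0.01)
    data = [np.array([[0., 1.], [0., 0.], [1., 1.]]),   # session 0
            np.array([[0., 1.], [1., 0.]]),             # session 1
            np.array([[1., np.nan]])]                   # session 2: no usable data
    dist.resample(data)                      # fits a one-topic dynamic topic model per input
    ll = dist.log_likelihood(data[0], timepoint=0)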
diff --git a/pybasicbayes/distributions/dynamic_multinomial_in_progress.py b/pybasicbayes/distributions/dynamic_multinomial_in_progress.py
new file mode 100644
index 0000000..46f2b0f
--- /dev/null
+++ b/pybasicbayes/distributions/dynamic_multinomial_in_progress.py
@@ -0,0 +1,265 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_Input_Categorical']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+from scipy import sparse
+import numpy as np
+from warnings import warn
+import time
+
+from pgmult.lda import StickbreakingDynamicTopicsLDA
+from pgmult.utils import psi_to_pi
+
+
+def enforce_limits(times, limit):
+    times = np.array(times)
+    diffs = np.zeros(len(times), dtype=np.int32)
+    diffs[1:] = np.diff(times)
+    diffs[diffs <= limit] = limit
+    diffs -= limit
+    diffs = np.cumsum(diffs)
+
+    diffs += times[0]
+    times -= diffs
+    return times
+
+
+assert np.array_equal(enforce_limits([0, 1, 2, 3, 4, 5], 3), [0, 1, 2, 3, 4, 5])
+assert np.array_equal(enforce_limits([0, 2, 4, 6, 8, 10, 12, 14, 15], 3), [0, 2, 4, 6, 8, 10, 12, 14, 15])
+assert np.array_equal(enforce_limits([0, 1, 2, 6, 10, 14], 3), [0, 1, 2, 5, 8, 11])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100], 3), [0, 1, 4, 7, 10])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102], 3), [0, 1, 4, 7, 10, 11, 12])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104], 3), [0, 1, 4, 7, 10, 11, 12, 14])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 1, 4, 7, 10, 11, 12, 14, 17])
+assert np.array_equal(enforce_limits([1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 3, 6, 9, 10, 11, 13, 16])
+
+
+def meh_time_info(all_times, sub_times, limit):
+    # Return time stamps for the needed counts, considering jumps
+    # Return list of mappings, to translate from this state's timepoints to the overall timepoints
+    # Return first and last timepoint
+    times = []
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            times.append(i)
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter <= limit:
+                times.append(i)
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return times, maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya1, mya2, mya3, mya4 = [0, 1, 2, 3, 4, 5, 6, 9], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (6, 7), (6, 8), (7, 9)], 0, 9
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya1, mya2, mya3, mya4 = [3], [(0, 3)], 3, 3
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+
+
+def time_info(all_times, sub_times, limit):
+    # Return list of mappings, to translate from this state's timepoints to the overall timepoints
+    # Return first and last timepoint
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter < limit:
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya2, mya3, mya4 = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (5, 6), (5, 7), (5, 8), (6, 9)], 0, 9
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya2, mya3, mya4 = [(0, 3)], 3, 3  # TODO: I don't think it matters whether first answer is [3] or [0]
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+# TODO: more tests
+
+class Dynamic_Input_Categorical(GibbsSampling):
+    """
+    This class enables a drifting input output iHMM.
+
+    Everything is pretty much like in the input Categorical, but here we also
+    allow drift across sessions for the Categoricals. Similar to a dynamic topic
+    model (with one topic).
+
+    We'll have to trick quite a bit:
+    - Instead of truly resampling, we'll instantiate a new dtm for every resampling and resample it a couple of times (a time killer)
+    - We can't initialize from the prior all that easily. We'll just instantiate a constant Dirichlet from the prior;
+      once data is actually there, we'll use the real model
+    - this will also be a bit hard to copy... maybe really just use the package for resampling, otherwise keep all that state in a list of Categoricals?
+    Big problem: how to give a log-likelihood for observations in sessions where this state was previously not present?
+    We don't know what value this thing should take there, it could be anything...
+    -> if there is no data, we simply have to sample from the prior. That is the will of Gibbs.
+    That is: sample from the prior for the first 3 unaccounted sessions, then leave the weights constant
+    But: how to do this backwards in time? -> also Gibbs: sample from the prior for three sessions
+
+    ! The first session must contain data for StickbreakingDynamicTopicsLDA
+
+    Idea: maybe save previously assigned betas (or psi's or whatever directly)
+    then initialize new dtm from there, so as to have to do less resampling
+    (this depends on what gets resampled first, hopefully the auxiliary variables first, then this saving should have an effect)
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    """
+
+    def __init__(self, n_inputs, n_outputs, T, sigmasq_states=0.01, jumplimit=3):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.T = T
+        self.jumplimit = jumplimit
+        self.sigmasq_states = sigmasq_states  # !!! this needs to be updated with gibbs estimate of sdev
+        self.psi_save = None
+
+        single_weights = np.zeros((self.n_inputs, self.n_outputs))
+        for i in range(self.n_inputs):
+            single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+        self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+
+    def rvs(self, input):
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x, timepoint):
+        out = np.zeros_like(x, dtype=np.double)
+        err = np.seterr(divide='ignore')
+        out = np.log(self.weights[timepoint])[tuple(x.T)]
+        np.seterr(**err)
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        counts, all_times = self._get_statistics(data)
+
+        # if state is never active, resample from prior
+        if counts.sum() == 0:
+            single_weights = np.zeros((self.n_inputs, self.n_outputs))
+            for i in range(self.n_inputs):
+                single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+            self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+            return
+
+        fake_times = enforce_limits(all_times, self.jumplimit)
+        self.weights.fill(0)
+
+        for i in range(self.n_inputs):
+            # TODO: handle the case where these counts are all 0
+            if np.sum(counts[:, i]) == 0:
+                self.weights[:, i] = np.random.dirichlet(np.ones(self.n_outputs))
+            else:
+                temp = np.sum(counts[:, i], axis=1)
+                spec_times = np.where(temp)[0]
+                maps, first_non0, last_non0 = time_info(all_times, spec_times, self.jumplimit)
+
+                # represented_sessions = [m[1] for m in maps]  # these (add boundary samples) are sessions sampled, for later initialisation
+
+                spec_fake_times = fake_times[spec_times]
+                dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+
+                dtm.resample()
+
+                for m in maps:
+                    self.weights[m[1], i] = dtm.beta[m[0], :, 0]
+
+                sample = dtm.psi[0]
+                for j in range(min(self.jumplimit, first_non0)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)
+                if first_non0 > self.jumplimit:
+                    self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)
+
+                sample = dtm.psi[-1]
+                for j in range(min(self.jumplimit, self.T - last_non0 - 1)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)
+                if self.T - last_non0 - 1 > self.jumplimit:
+                    self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)
+
+        assert np.count_nonzero(np.sum(self.weights, axis=2)) == np.sum(self.weights, axis=2).size  # weird way to test whether everything is normalised?
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        counts = []
+        times = []
+        timepoint_count = np.empty((self.n_inputs, self.n_outputs), dtype=int)
+        if isinstance(data, np.ndarray):
+            warn('Passing a plain ndarray is not supported; this code path is not implemented')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                if len(d) != 0:
+                    timepoint_count[:] = 0
+                    for subd in d:
+                        timepoint_count[subd[0], subd[1]] += 1
+                    counts.append(timepoint_count.copy())
+                    times.append(i)
+
+        return np.array(counts), times
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
diff --git a/pybasicbayes/distributions/gaussian.py b/pybasicbayes/distributions/gaussian.py
index f283986..79d045a 100644
--- a/pybasicbayes/distributions/gaussian.py
+++ b/pybasicbayes/distributions/gaussian.py
@@ -348,7 +348,6 @@ class Gaussian(
             + (self.nu_0 - D - 1)/2*loglmbdatilde - 1/2*self.nu_mf \
             * np.linalg.solve(self.sigma_mf,self.sigma_0).trace()
 
-
         return p_avgengy + q_entropy
 
     def expected_log_likelihood(self, x=None, stats=None):
@@ -557,7 +556,7 @@ class GaussianFixedMean(_GaussianBase, GibbsSampling, MaxLikelihood):
 class GaussianFixedCov(_GaussianBase, GibbsSampling, MaxLikelihood):
     # See Gelman's Bayesian Data Analysis notation around Eq. 3.18, p. 85
     # in 2nd Edition. We replaced \Lambda_0 with sigma_0 since it is a prior
-    # *covariance* matrix rather than a precision matrix. 
+    # *covariance* matrix rather than a precision matrix.
     def __init__(self,mu=None,sigma=None,mu_0=None,sigma_0=None):
         self.mu = mu
 
diff --git a/pybasicbayes/distributions/multinomial.py b/pybasicbayes/distributions/multinomial.py
index 4b119fc..477675b 100644
--- a/pybasicbayes/distributions/multinomial.py
+++ b/pybasicbayes/distributions/multinomial.py
@@ -2,9 +2,12 @@ from __future__ import division
 from builtins import zip
 from builtins import map
 from builtins import range
+import copy
 __all__ = ['Categorical', 'CategoricalAndConcentration', 'Multinomial',
-           'MultinomialAndConcentration', 'GammaCompoundDirichlet', 'CRP']
+           'MultinomialAndConcentration', 'GammaCompoundDirichlet', 'CRP',
+           'Input_Categorical', 'Input_Categorical_Normal']
 
+from pybasicbayes.distributions import gaussian
 import numpy as np
 from warnings import warn
 import scipy.stats as stats
@@ -104,7 +107,16 @@ class Categorical(GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood, MAP):
 
     def resample(self,data=[],counts=None):
         counts = self._get_statistics(data) if counts is None else counts
-        self.weights = np.random.dirichlet(self.alphav_0 + counts)
+        try:
+            self.weights = np.random.dirichlet(self.alphav_0 + counts)
+        except (ZeroDivisionError, ValueError):
+            # degenerate parameters can make the Dirichlet draw fail; add a small jitter
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+        if np.isnan(self.weights).any():
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
         np.clip(self.weights, np.spacing(1.), np.inf, out=self.weights)
         # NOTE: next line is so we can use Gibbs sampling to initialize mean field
         self._alpha_mf = self.weights * self.alphav_0.sum()
@@ -363,7 +375,13 @@ class CRP(GibbsSampling):
     def resample(self,data=[],niter=50):
         for itr in range(niter):
             a_n, b_n = self._posterior_hypparams(*self._get_statistics(data))
-            self.concentration = np.random.gamma(a_n,scale=1./b_n)
+            try:
+                self.concentration = np.random.gamma(a_n,scale=1./b_n)
+            except Exception:
+                print("have to apply weird bug fix in /apps/conda/sbruijns/envs/hdp_pg_env/lib/python3.4/site-packages/pybasicbayes/distributions/multinomial")
+                self.concentration += 0.00001
+                a_n, b_n = self._posterior_hypparams(*self._get_statistics(data))
+                self.concentration = np.random.gamma(a_n,scale=1./b_n)
 
     def _posterior_hypparams(self,sample_numbers,total_num_distinct):
         # NOTE: this is a stochastic function: it samples auxiliary variables
@@ -477,3 +495,155 @@ class GammaCompoundDirichlet(CRP):
                         / (np.arange(n)+self.concentration*self.K*self.weighted_cols[j])).sum()
             return counts.sum(1), m
 
+
+class Input_Categorical():
+    '''
+    This class enables an input output HMM structure, where the observation distribution
+    that we sample from is dictated by the input at that time step, so this input must be an
+    integer. Parameters are weights and the prior is a Dirichlet distribution.
+    Inspired by Matt's other distributions.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    '''
+    def __init__(self, n_inputs, n_outputs, modulate=100):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.modulate = modulate  # the first data column is taken modulo this value; a high value means no effect, a value equal to the stimulus set size means no conditioning
+        #print("Input data is taken mod {}".format(self.modulate))
+        self.weights = np.zeros((n_inputs, n_outputs))
+
+        self.resample()  # initialize from prior
+
+    def rvs(self, input):
+        input = copy.copy(input)
+        input = input % self.modulate
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x):
+        x = copy.copy(x)
+        x[:, 0] %= self.modulate
+        out = np.zeros_like(x, dtype=np.double)
+        err = np.seterr(divide='ignore')
+        out = np.log(self.weights)[tuple(x.T)]
+        np.seterr(**err)
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        counts = self._get_statistics(data)
+
+        for i in range(self.n_inputs):
+            self.weights[i] = np.random.dirichlet(np.ones(self.n_outputs) + counts[i])
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        data = copy.copy(data)
+        counts = np.zeros_like(self.weights, dtype=np.int32)
+        if isinstance(data, np.ndarray):
+            assert len(data.shape) == 2
+            data[:, 0] %= self.modulate
+            for d in data:
+                counts[tuple(d)] += 1
+        else:
+            for d in data:
+                d[:, 0] %= self.modulate
+                for subd in d:
+                    counts[tuple(subd)] += 1
+        return counts
+
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
+
+
+class Input_Categorical_Normal():
+    '''
+    This class enables an input output HMM structure, where the observation distribution
+    that we sample from is dictated by the input at that time step, so this input must be an
+    integer.
+    Here, one output is from a Categorical, the other from a Normal.
+    Parameters are weights and the prior is a Dirichlet distribution.
+    Inspired by Matt's other distributions.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    '''
+    def __init__(self, n_inputs, n_outputs, normal_hypparams):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.cats = Input_Categorical(n_inputs, n_outputs)
+        self.normals = [gaussian.Gaussian(**normal_hypparams) for state in range(n_inputs)]
+
+    def rvs(self, input):
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros((len(input), 2))
+        output[:, 0] = self.cats.rvs(input)[:, 1]
+
+        for t, c in zip(types, counts):
+            temp = self.normals[t].rvs(c)
+            output[input == t, 1] = temp[:, 0]
+
+        return np.concatenate((np.array(input)[None].T, output), axis=1)
+
+    def log_likelihood(self, x):
+        out = self.cats.log_likelihood(x[:, [0, 1]].astype(int))
+
+        # Filter out negative or 0 reaction times (marked as -inf here)
+        bad_nums = x[:, 2] == -np.inf
+        x[bad_nums, 2] = 0
+
+        means = np.zeros(x.shape[0])
+        stds = np.zeros(x.shape[0])
+        normal_ll = np.zeros(x.shape[0])
+        for i, n in enumerate(self.normals):
+            mask = x[:, 0] == i
+            # don't consider negative RTs in ll calculation
+            means[mask], stds[mask] = n.mu, n.sigma[0]
+
+        # log N(x | mu, sigma) = logpdf((x - mu) / sigma) - log(sigma)
+        normal_ll = stats.norm().logpdf((x[:, 2] - means) / stds) - np.log(stds)
+        normal_ll[bad_nums] = 0
+
+        # note: only the categorical term is returned; normal_ll is currently unused
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        if isinstance(data, list):
+            data = np.concatenate(data)
+        self.cats.resample(data[:, [0, 1]].astype(int))
+
+        bad_nums = data[:, 2] == -np.inf
+        for i, n in enumerate(self.normals):
+            mask = data[:, 0] == i
+            n.resample(data=data[np.logical_and(mask, ~bad_nums), 2])
+
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
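
Aside: a minimal usage sketch for the new Input_Categorical (toy numbers, illustrative only). Data are rows of
(input, output) integers; with the default modulate=100, inputs below 100 are used as-is:

    import numpy as np

    dist = Input_Categorical(n_inputs=3, n_outputs=2)
    data = np.array([[0, 1], [0, 0], [2, 1]])
    dist.resample(data)                        # per-input Dirichlet(1 + counts) update
    ll = dist.log_likelihood(data)             # one log-probability per row
    draws = dist.rvs(np.array([0, 1, 2, 2]))   # sample outputs for the given inputs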
diff --git a/pybasicbayes/distributions/negativebinomial.py b/pybasicbayes/distributions/negativebinomial.py
index 7caed4e..9ed66c8 100644
--- a/pybasicbayes/distributions/negativebinomial.py
+++ b/pybasicbayes/distributions/negativebinomial.py
@@ -11,7 +11,7 @@ __all__ = [
 import numpy as np
 from numpy import newaxis as na
 import scipy.special as special
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
 from warnings import warn
 
 from pybasicbayes.abstractions import Distribution, GibbsSampling, \
@@ -371,6 +371,9 @@ class NegativeBinomialIntegerR2(_NegativeBinomialBase,MeanField,MeanFieldSVI,Gib
         return np.exp(self.rho_mf).dot(self.rho_0) \
                 - np.exp(self.rho_mf).dot(self.rho_mf)
 
+    def num_parameters(self):
+        return 2
+
     def meanfieldupdate(self,data,weights):
         for d in self._fixedr_distns:
             d.meanfieldupdate(data,weights)
-- 
GitLab