diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/__init__.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f8fd26cf1025eb92b008300a18f3da6e57ba454
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import abstractions, distributions, models, util
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/abstractions.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/abstractions.py
new file mode 100644
index 0000000000000000000000000000000000000000..853cab010a8da9317cdb63b254be6a5c4fb60cbc
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/abstractions.py
@@ -0,0 +1,247 @@
+from __future__ import print_function
+from builtins import range
+from builtins import object
+import abc
+import numpy as np
+import copy
+
+import pybasicbayes
+from pybasicbayes.util.stats import combinedata
+from pybasicbayes.util.text import progprint_xrange
+from future.utils import with_metaclass
+
+# NOTE: data is always a (possibly masked) np.ndarray or list of (possibly
+# masked) np.ndarrays.
+
+# TODO figure out a data abstraction
+# TODO make an exponential family abc to reduce boilerplate
+
+################
+#  Base class  #
+################
+
+class Distribution(with_metaclass(abc.ABCMeta, object)):
+    @abc.abstractmethod
+    def rvs(self,size=[]):
+        'random variates (samples)'
+        pass
+
+    @abc.abstractmethod
+    def log_likelihood(self,x):
+        '''
+        log likelihood (either log probability mass function or log probability
+        density function) of x, which has the same type as the output of rvs()
+        '''
+        pass
+
+class BayesianDistribution(with_metaclass(abc.ABCMeta, Distribution)):
+    def empirical_bayes(self,data):
+        '''
+        (optional) set hyperparameters via empirical bayes
+        e.g. treat argument as a pseudo-dataset for exponential family
+        '''
+        raise NotImplementedError
+
+#########################################################
+#  Algorithm interfaces for inference in distributions  #
+#########################################################
+
+class GibbsSampling(with_metaclass(abc.ABCMeta, BayesianDistribution)):
+    @abc.abstractmethod
+    def resample(self,data=[]):
+        pass
+
+    def copy_sample(self):
+        '''
+        return an object copy suitable for making lists of posterior samples
+        (override this method to prevent copying shared structures into each sample)
+        '''
+        return copy.deepcopy(self)
+
+    def resample_and_copy(self):
+        self.resample()
+        return self.copy_sample()
+
+class MeanField(with_metaclass(abc.ABCMeta, BayesianDistribution)):
+    @abc.abstractmethod
+    def expected_log_likelihood(self,x):
+        pass
+
+    @abc.abstractmethod
+    def meanfieldupdate(self,data,weights):
+        pass
+
+    def get_vlb(self):
+        raise NotImplementedError
+
+class MeanFieldSVI(with_metaclass(abc.ABCMeta, BayesianDistribution)):
+    @abc.abstractmethod
+    def meanfield_sgdstep(self,expected_suff_stats,prob,stepsize):
+        pass
+
+class Collapsed(with_metaclass(abc.ABCMeta, BayesianDistribution)):
+    @abc.abstractmethod
+    def log_marginal_likelihood(self,data):
+        pass
+
+    def log_predictive(self,newdata,olddata):
+        return self.log_marginal_likelihood(combinedata((newdata,olddata))) \
+                    - self.log_marginal_likelihood(olddata)
+
+    def predictive(self,*args,**kwargs):
+        return np.exp(self.log_predictive(*args,**kwargs))
+
+class MaxLikelihood(with_metaclass(abc.ABCMeta, Distribution)):
+    @abc.abstractmethod
+    def max_likelihood(self,data,weights=None):
+        '''
+        sets the parameters to their maximum likelihood values given the
+        (weighted) data
+        '''
+        pass
+
+    @property
+    def num_parameters(self):
+        raise NotImplementedError
+
+class MAP(with_metaclass(abc.ABCMeta, BayesianDistribution)):
+    @abc.abstractmethod
+    def MAP(self,data,weights=None):
+        '''
+        sets the parameters to their MAP values given the (weighted) data
+        analogous to max_likelihood but includes hyperparameters
+        '''
+        pass
+
+class Tempering(BayesianDistribution):
+    @abc.abstractmethod
+    def log_likelihood(self,data,temperature=1.):
+        pass
+
+    @abc.abstractmethod
+    def resample(self,data,temperature=1.):
+        pass
+
+    def energy(self,data):
+        return -self.log_likelihood(data,temperature=1.)
+
+############
+#  Models  #
+############
+
+# a "model" is differentiated from a "distribution" in this code by latent state
+# over data: a model attaches a latent variable (like a label or state sequence)
+# to data, and so it 'holds onto' data. Hence the add_data method.
+
+class Model(with_metaclass(abc.ABCMeta, object)):
+    @abc.abstractmethod
+    def add_data(self,data):
+        pass
+
+    @abc.abstractmethod
+    def generate(self,keep=True,**kwargs):
+        '''
+        Like a distribution's rvs, but this also fills in latent state over
+        data and keeps references to the data.
+        '''
+        pass
+
+    def rvs(self,*args,**kwargs):
+        return self.generate(*args,keep=False,**kwargs)[0] # 0th component is data, not latent stuff
+
+##################################################
+#  Algorithm interfaces for inference in models  #
+##################################################
+
+class ModelGibbsSampling(with_metaclass(abc.ABCMeta, Model)):
+    @abc.abstractmethod
+    def resample_model(self): # TODO niter?
+        pass
+
+    def copy_sample(self):
+        '''
+        return an object copy suitable for making lists of posterior samples
+        (override this method to prevent copying shared structures into each sample)
+        '''
+        return copy.deepcopy(self)
+
+    def resample_and_copy(self):
+        self.resample_model()
+        return self.copy_sample()
+
+class ModelMeanField(with_metaclass(abc.ABCMeta, Model)):
+    @abc.abstractmethod
+    def meanfield_coordinate_descent_step(self):
+        # returns variational lower bound after update, if available
+        pass
+
+    def meanfield_coordinate_descent(self,tol=1e-1,maxiter=250,progprint=False,**kwargs):
+        # NOTE: doesn't re-initialize!
+        scores = []
+        step_iterator = range(maxiter) if not progprint else progprint_xrange(maxiter)
+        for itr in step_iterator:
+            scores.append(self.meanfield_coordinate_descent_step(**kwargs))
+            if scores[-1] is not None and len(scores) > 1:
+                if np.abs(scores[-1]-scores[-2]) < tol:
+                    return scores
+        print('WARNING: meanfield_coordinate_descent hit maxiter of %d' % maxiter)
+        return scores
+
+class ModelMeanFieldSVI(with_metaclass(abc.ABCMeta, Model)):
+    @abc.abstractmethod
+    def meanfield_sgdstep(self,minibatch,prob,stepsize):
+        pass
+
+class _EMBase(with_metaclass(abc.ABCMeta, Model)):
+    @abc.abstractmethod
+    def log_likelihood(self):
+        # returns a log likelihood number on attached data
+        pass
+
+    def _EM_fit(self,method,tol=1e-1,maxiter=100,progprint=False):
+        # NOTE: doesn't re-initialize!
+        likes = []
+        step_iterator = range(maxiter) if not progprint else progprint_xrange(maxiter)
+        for itr in step_iterator:
+            method()
+            likes.append(self.log_likelihood())
+            if len(likes) > 1:
+                if likes[-1]-likes[-2] < tol:
+                    return likes
+                elif likes[-1] < likes[-2]:
+                    # probably oscillation, do one more
+                    method()
+                    likes.append(self.log_likelihood())
+                    return likes
+        print('WARNING: EM_fit reached maxiter of %d' % maxiter)
+        return likes
+
+class ModelEM(with_metaclass(abc.ABCMeta, _EMBase)):
+    def EM_fit(self,tol=1e-1,maxiter=100):
+        return self._EM_fit(self.EM_step,tol=tol,maxiter=maxiter)
+
+    @abc.abstractmethod
+    def EM_step(self):
+        pass
+
+class ModelMAPEM(with_metaclass(abc.ABCMeta, _EMBase)):
+    def MAP_EM_fit(self,tol=1e-1,maxiter=100):
+        return self._EM_fit(self.MAP_EM_step,tol=tol,maxiter=maxiter)
+
+    @abc.abstractmethod
+    def MAP_EM_step(self):
+        pass
+
+class ModelParallelTempering(with_metaclass(abc.ABCMeta, Model)):
+    @abc.abstractproperty
+    def temperature(self):
+        pass
+
+    @abc.abstractproperty
+    def energy(self):
+        pass
+
+    @abc.abstractmethod
+    def swap_sample_with(self,other):
+        pass
+
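For orientation, here is a minimal sketch (not part of this package) of how a concrete distribution plugs into these interfaces: implementing `rvs`, `log_likelihood`, and `resample` satisfies the `GibbsSampling` ABC, and the inherited `resample_and_copy` mixin then works unchanged. The `BetaBernoulli` class and its hyperparameter names are illustrative only.

```python
import numpy as np
from pybasicbayes.abstractions import GibbsSampling

class BetaBernoulli(GibbsSampling):
    """Bernoulli likelihood with a Beta(alpha_0, beta_0) prior on p (illustrative)."""
    def __init__(self, alpha_0=1., beta_0=1.):
        self.alpha_0, self.beta_0 = alpha_0, beta_0
        self.resample()  # initialize p from the prior

    def rvs(self, size=[]):
        return (np.random.random(size) < self.p).astype(np.float64)

    def log_likelihood(self, x):
        x = np.asarray(x)
        return x * np.log(self.p) + (1 - x) * np.log(1 - self.p)

    def resample(self, data=[]):
        data = np.concatenate(data) if isinstance(data, list) and len(data) > 0 else np.asarray(data)
        heads = data.sum() if data.size else 0.
        self.p = np.random.beta(self.alpha_0 + heads, self.beta_0 + data.size - heads)

d = BetaBernoulli(2., 2.)
samples = [d.resample_and_copy() for _ in range(10)]  # draws from the prior here, since no data is passed to resample()
```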
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/__init__.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6401355d18af187680f1b65541d6b3c12851497
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/__init__.py
@@ -0,0 +1,13 @@
+from __future__ import absolute_import
+from .meta import *
+
+from .regression import *
+from .gaussian import *
+from .uniform import *
+
+from .binomial import *
+from .multinomial import *
+from .negativebinomial import *
+from .geometric import *
+from .poisson import *
+from .dynamic_glm import *
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/binomial.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/binomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..98eb59b195df61a912ae28f6153292a4e65165a4
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/binomial.py
@@ -0,0 +1,131 @@
+from __future__ import division
+from builtins import zip
+__all__ = ['Binomial']
+
+import numpy as np
+import scipy.stats as stats
+import scipy.special as special
+from warnings import warn
+
+from pybasicbayes.abstractions import GibbsSampling, MeanField, \
+    MeanFieldSVI
+
+
+class Binomial(GibbsSampling, MeanField, MeanFieldSVI):
+    '''
+    Models a Binomial likelihood and a Beta prior:
+
+        p ~ Beta(alpha_0, beta_0)
+        x | p ~ Binom(p,n)
+
+    where p is the success probability, alpha_0-1 is the prior number of
+    successes, beta_0-1 is the prior number of failures.
+
+    A special case of Multinomial where N is fixed and each observation counts
+    the number of successes and is in {0,1,...,N}.
+    '''
+    def __init__(self,alpha_0,beta_0,alpha_mf=None,beta_mf=None,p=None,n=None):
+        warn('this class is untested!')
+        assert n is not None
+
+        self.n = n
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+
+        self.alpha_mf = alpha_mf if alpha_mf is not None else alpha_0
+        self.beta_mf = beta_mf if beta_mf is not None else beta_0
+
+        if p is not None:
+            self.p = p
+        else:
+            self.resample()
+
+    def log_likelihood(self,x):
+        return stats.binom.logpmf(x,self.n,self.p)
+
+    def rvs(self,size=None):
+        return stats.binom.rvs(self.n,self.p,size=size)
+
+    @property
+    def natural_hypparam(self):
+        return np.array([self.alpha_0 - 1, self.beta_0 - 1])
+
+    @natural_hypparam.setter
+    def natural_hypparam(self,natparam):
+        self.alpha_0, self.beta_0 = natparam + 1
+
+    def _get_statistics(self,data):
+        if isinstance(data,np.ndarray):
+            data = data.ravel()
+            tot = data.sum()
+            return np.array([tot, self.n*data.shape[0] - tot])
+        else:
+            return sum(
+                (self._get_statistics(d) for d in data),
+                self._empty_statistics())
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            data, weights = data.ravel(), weights.ravel()
+            tot = weights.dot(data)
+            return np.array([tot, self.n*weights.sum() - tot])
+        else:
+            return sum(
+                (self._get_weighted_statistics(d,w) for d,w in zip(data,weights)),
+                self._empty_statistics())
+
+    def _empty_statistics(self):
+        return np.zeros(2)
+
+    ### Gibbs
+
+    def resample(self,data=[]):
+        alpha_n, beta_n = self.natural_hypparam + self._get_statistics(data) + 1
+        self.p = np.random.beta(alpha_n,beta_n)
+
+        # use Gibbs to initialize mean field
+        self.alpha_mf = self.p * (self.alpha_0 + self.beta_0)
+        self.beta_mf = (1-self.p) * (self.alpha_0 + self.beta_0)
+
+    ### Mean field and SVI
+
+    def meanfieldupdate(self,data,weights):
+        self.mf_natural_hypparam = \
+            self.natural_hypparam + self._get_weighted_statistics(data,weights)
+
+        # use mean field to initialize Gibbs
+        self.p = self.alpha_mf / (self.alpha_mf + self.beta_mf)
+
+    def meanfield_sgdstep(self,data,weights,minibatchprob,stepsize):
+        self.mf_natural_hypparam = \
+            (1-stepsize) * self.mf_natural_hypparam + stepsize * (
+                self.natural_hypparam
+                + 1./minibatchprob * self._get_weighted_statistics(data,weights))
+
+    @property
+    def mf_natural_hypparam(self):
+        return np.array([self.alpha_mf - 1, self.beta_mf - 1])
+
+    @mf_natural_hypparam.setter
+    def mf_natural_hypparam(self,natparam):
+        self.alpha_mf, self.beta_mf = natparam + 1
+
+    def expected_log_likelihood(self,x):
+        n = self.n
+        Elnp, Eln1mp = self._mf_expected_statistics()
+        return special.gammaln(n+1) - special.gammaln(x+1) - special.gammaln(n-x+1) \
+            + x*Elnp + (n-x)*Eln1mp
+
+    def _mf_expected_statistics(self):
+        return special.digamma([self.alpha_mf, self.beta_mf]) \
+            - special.digamma(self.alpha_mf + self.beta_mf)
+
+    def get_vlb(self):
+        Elnp, Eln1mp = self._mf_expected_statistics()
+        return (self.alpha_0 - self.alpha_mf)*Elnp \
+            + (self.beta_0 - self.beta_mf)*Eln1mp \
+            - (self._log_partition_function(self.alpha_0, self.beta_0)
+                - self._log_partition_function(self.alpha_mf,self.beta_mf))
+
+    def _log_partition_function(self,alpha,beta):
+        return special.betaln(alpha,beta)
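A short usage sketch for the `Binomial` class above (which itself warns that it is untested); the hyperparameter values and synthetic data are illustrative only.

```python
import numpy as np
from pybasicbayes.distributions.binomial import Binomial

dist = Binomial(alpha_0=2., beta_0=2., n=10)    # Beta(2, 2) prior, 10 trials per observation
data = np.random.binomial(10, 0.7, size=500)    # synthetic success counts in {0, ..., 10}

dist.resample(data)                             # Gibbs: p | data ~ Beta(alpha_0 + successes, beta_0 + failures)
print(dist.p)

dist.meanfieldupdate(data, np.ones(len(data)))  # mean field update with uniform weights
print(dist.alpha_mf, dist.beta_mf, dist.get_vlb())
```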
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_glm.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_glm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e1a38b0776c2b5283b5e17eeec004e9412f910e
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_glm.py
@@ -0,0 +1,279 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_GLM']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+import numpy as np
+from warnings import warn
+from pypolyagamma import PyPolyaGamma
+
+def local_multivariate_normal_draw(x, sigma, normal):
+    """
+    Cholesky doesn't like a 0 covariance matrix, but we want to allow it.
+
+    TODO: This might need changing if in practice we see plenty of 0 matrices
+    """
+    try:
+        return x + np.linalg.cholesky(sigma).dot(normal)
+    except np.linalg.LinAlgError:
+        if np.isclose(sigma, 0).all():
+            return x
+        else:
+            print("Weird covariance matrix")
+            quit()
+
+
+ppgsseed = 4
+if ppgsseed == 4:
+    print("Using default seed")
+ppgs = PyPolyaGamma(ppgsseed)
+class Dynamic_GLM(GibbsSampling):
+    """
+    This class enables a drifting input-output iHMM with a logistic link function.
+
+    States are thus dynamic GLMs, giving us more freedom as to the inputs we give the model.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights]
+    """
+
+    def __init__(self, n_inputs, T, prior_mean, P_0, Q, jumplimit=3, seed=4):
+
+        self.n_inputs = n_inputs
+        self.T = T
+        self.jumplimit = jumplimit
+        self.x_0 = prior_mean
+        self.P_0, self.Q = P_0, Q
+        self.psi_diff_saves = []
+        self.noise_mean = np.zeros(self.n_inputs)  # save this, so as to not keep creating it
+        self.identity = np.eye(self.n_inputs)  # not really needed, but kind of useful for state sampling, maybe delete TODO
+
+        # if seed == 4:
+        #     print("Using default seed")
+        # self.ppgs = PyPolyaGamma(seed)
+        self.weights = np.empty((self.T, self.n_inputs))  # one more spot for bias
+        self.weights[0] = np.random.multivariate_normal(mean=self.x_0, cov=self.P_0)
+        for t in range(1, T):
+            self.weights[t] = self.weights[t - 1] + np.random.multivariate_normal(mean=self.noise_mean, cov=self.Q[t - 1])
+
+    def rvs(self, inputs, times):
+        outputs = []
+        for input, t in zip(inputs, times):
+            if input.shape[0] == 0:
+                output = np.zeros((0, self.n_inputs + 1))
+            else:
+                types, inverses, counts = np.unique(input, return_inverse=1, return_counts=True, axis=0)
+                output = np.append(input, np.empty((input.shape[0], 1)), axis=1)
+                for i, (type, c) in enumerate(zip(types, counts)):
+                    temp = np.random.rand(c) < 1 / (1 + np.exp(- np.sum(self.weights[t] * type)))
+                    output[inverses == i, -1] = temp
+            outputs.append(output)
+        return outputs
+
+    def log_likelihood(self, input, timepoint):
+        predictors, responses = input[:, :-1], input[:, -1]
+        nans = np.isnan(responses)
+        probs = np.zeros((input.shape[0], 2))
+        out = np.zeros(input.shape[0])
+        # I could possibly save the 1 / ..., since it's logged it's just - log (but the other half of the probs is an issue)
+        probs[:, 1] = 1 / (1 + np.exp(- np.sum(self.weights[timepoint] * predictors, axis=1)))
+        probs[:, 0] = 1 - probs[:, 1]
+        # probably not necessary, just fill everything with probs and then have some be 1 - out?
+        out[~nans] = probs[np.arange(input.shape[0])[~nans], responses[~nans].astype(int)]
+        out = np.clip(out, np.spacing(1), 1 - np.spacing(1))
+        out[nans] = 1
+
+        return np.log(out)
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        # TODO: Clean up this mess; delete_obs_data always has to be called because of all the saved state
+        self.psi_diff_saves = []
+        summary_statistics, all_times = self._get_statistics(data)
+        types, pseudo_counts, counts = summary_statistics
+
+        # if state is never active, resample from prior, but without dynamic change
+        if len(counts) == 0:
+            self.weights = np.tile(np.random.multivariate_normal(mean=self.x_0, cov=self.P_0), (self.T, 1))
+            return
+
+        """compute Kalman filter parameters, also sort out which weight vector goes to which timepoint"""
+        timepoint_map = {}
+        total_types = 0
+        actual_obs_count = 0
+        change_points = []
+        prev_t = all_times[0] - 1
+        fake_times = []
+        for type, t in zip(types, all_times):
+            if t > prev_t + 1:
+                add_list = list(range(total_types, min(total_types + t - prev_t - 1, total_types + self.jumplimit)))
+                change_points += add_list
+                fake_times += add_list
+                for i, sub_t in enumerate(range(total_types, total_types + t - prev_t - 1)): # TODO: set up this loop better
+                    timepoint_map[prev_t + i + 1] = min(sub_t, total_types + self.jumplimit - 1)
+                total_types += min(t - prev_t - 1, self.jumplimit)
+            total_types += type.shape[0]
+            actual_obs_count += type.shape[0]
+            change_points.append(total_types - 1)
+            timepoint_map[t] = total_types - 1
+            prev_t = t
+
+        # print(total_types)
+        # print(actual_obs_count)
+        # print(change_points)
+        # print(fake_times)
+        # print(all_times)
+        # print(timepoint_map)
+        # return timepoint_map
+
+        self.pseudo_Q = np.zeros((total_types, self.n_inputs, self.n_inputs))
+        # TODO: is it okay to cut off last timepoint here?
+        for k in range(self.T):
+            if k in timepoint_map:
+                self.pseudo_Q[timepoint_map[k]] = self.Q[k]  # for every timepoint, map its variance onto the pseudo_Q
+
+        """sample pseudo obs"""
+        temp = np.empty(actual_obs_count)
+        psis = np.empty(actual_obs_count)
+        psi_count = 0
+        predictors = []
+        for type, time in zip(types, all_times):
+            for t in type:
+                psis[psi_count] = np.sum(self.weights[time] * t)
+                predictors.append(t)
+                psi_count += 1
+
+        ppgs.pgdrawv(np.concatenate(counts).astype(float), psis, temp)
+        self.R = np.zeros(total_types)
+        mask = np.ones(total_types, dtype=bool)
+        mask[fake_times] = False
+        self.R[mask] = 1 / temp
+        self.pseudo_obs = np.zeros(total_types)
+        self.pseudo_obs[mask] = np.concatenate(pseudo_counts) / temp
+        self.pseudo_obs = self.pseudo_obs.reshape(total_types, 1)
+        self.H = np.zeros((total_types, self.n_inputs, 1))
+        self.H[mask] = np.array(predictors).reshape(actual_obs_count, self.n_inputs, 1)
+
+        """compute means and sigmas by filtering"""
+        # if there is no obs, sigma_k = sigma_k_k_minus and x_hat_k = x_hat_k_k_minus (because R is infinite at that time)
+        self.compute_sigmas(total_types)
+        self.compute_means(total_types)
+
+        """sample states"""
+        self.weights.fill(0)
+        pseudo_weights = np.empty((total_types, self.n_inputs))
+        pseudo_weights[total_types - 1] = np.random.multivariate_normal(self.x_hat_k[total_types - 1], self.sigma_k[total_types - 1])
+
+        normals = np.random.standard_normal((total_types - 1, self.n_inputs))
+        for k in range(total_types - 2, -1, -1):  # normally -1, but we already did first sampling
+            if np.all(self.pseudo_Q[k] == 0):
+                pseudo_weights[k] = pseudo_weights[k + 1]
+            else:
+                updated_x = self.x_hat_k[k].copy()  # not sure whether copy is necessary here
+                updated_sigma = self.sigma_k[k].copy()
+
+                for m in range(self.n_inputs):
+                    epsilon = pseudo_weights[k + 1, m] - updated_x[m]
+                    state_R = updated_sigma[m, m] + self.pseudo_Q[k, m, m]
+
+                    updated_x += updated_sigma[:, m] * epsilon / state_R  # I don't think it's important, but probs we need the first column
+                    updated_sigma -= updated_sigma.dot(np.outer(self.identity[m], self.identity[m])).dot(updated_sigma) / state_R
+
+                pseudo_weights[k] = local_multivariate_normal_draw(updated_x, updated_sigma, normals[k])
+
+        for k in range(self.T):
+            if k in timepoint_map:
+                self.weights[k] = pseudo_weights[timepoint_map[k]]
+
+        """don't forget to sample before and after active times too"""
+        for k in range(all_times[0] - 1, -1, -1):
+            if k > all_times[0] - self.jumplimit - 1:
+                self.weights[k] = self.weights[k + 1] + np.random.multivariate_normal(self.noise_mean, self.Q[k])
+            else:
+                self.weights[k] = self.weights[k + 1]
+        for k in range(all_times[-1] + 1, self.T):
+            if k < min(all_times[-1] + 1 + self.jumplimit, self.T):
+                self.weights[k] = self.weights[k - 1] + np.random.multivariate_normal(self.noise_mean, self.Q[k])
+            else:
+                self.weights[k] = self.weights[k - 1]
+
+        return pseudo_weights
+        # TODO:
+        # self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        summary_statistics = [[], [], []]
+        times = []
+        if isinstance(data, np.ndarray):
+            warn('ndarray data is not supported here; pass a list of per-session arrays')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                clean_d = d[~np.isnan(d[:, -1])]
+                if len(clean_d) != 0:
+                    predictors, responses = clean_d[:, :-1], clean_d[:, -1]
+                    types, inverses, counts = np.unique(predictors, return_inverse=True, return_counts=True, axis=0)
+                    pseudo_counts = np.zeros(len(types))
+                    for j, c in enumerate(counts):
+                        mask = inverses == j
+                        pseudo_counts[j] = np.sum(responses[mask]) - c / 2
+                    summary_statistics[0].append(types)
+                    summary_statistics[1].append(pseudo_counts)
+                    summary_statistics[2].append(counts)
+                    times.append(i)
+
+        return summary_statistics, times
+
+    def compute_sigmas(self, T):
+        """Sigmas can be precomputed (without z), we do this here."""
+        # We rely on the fact that H.T.dot(sigma).dot(H) is just a number, no matrix inversion needed
+        # furthermore we use the fact that many matrices are identities, namely F and G
+        self.sigma_k = []  # we have to reset this for repeating this calculation later for the resampling (R changes)
+        self.sigma_k_k_minus = [self.P_0]
+        self.gain_save = []
+        for k in range(T):
+            if self.R[k] == 0:
+                self.gain_save.append(None)
+                self.sigma_k.append(self.sigma_k_k_minus[k])
+                self.sigma_k_k_minus.append(self.sigma_k[k] + self.pseudo_Q[k])
+            else:
+                sigma, H = self.sigma_k_k_minus[k], self.H[k]  # we will need this a lot, so shorten it
+                gain = sigma.dot(H).dot(1 / (H.T.dot(sigma).dot(H) + self.R[k]))
+                self.gain_save.append(gain)
+                self.sigma_k.append(sigma - gain.dot(H.T).dot(sigma))
+                self.sigma_k_k_minus.append(self.sigma_k[k] + self.pseudo_Q[k])
+
+    def compute_means(self, T):
+        """Compute the means, the estimates of the states."""
+        self.x_hat_k = []  # we have to reset this for repeating this calculation later for the resampling
+        self.x_hat_k_k_minus = [self.x_0]
+        for k in range(T):  # this will leave out last state which doesn't have observation
+            if self.gain_save[k] is None:
+                self.x_hat_k.append(self.x_hat_k_k_minus[k])
+                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: still no purpose
+            else:
+                x, H = self.x_hat_k_k_minus[k], self.H[k]  # we will need this a lot, so shorten it
+                self.x_hat_k.append(x + self.gain_save[k].dot(self.pseudo_obs[k] - H.T.dot(x)))
+                self.x_hat_k_k_minus.append(self.x_hat_k[k])  # TODO: doesn't really have a purpose if F is identity
+
+    def num_parameters(self):
+        return self.weights.size
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
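A usage sketch for `Dynamic_GLM`, assuming `pypolyagamma` is installed (it is imported at the top of this module); the sizes, priors, and noise covariances below are illustrative only.

```python
import numpy as np
from pybasicbayes.distributions.dynamic_glm import Dynamic_GLM

n, T = 3, 20
glm = Dynamic_GLM(n_inputs=n, T=T,
                  prior_mean=np.zeros(n), P_0=np.eye(n),
                  Q=np.tile(0.1 * np.eye(n), (T, 1, 1)), jumplimit=3)

# one design matrix per session; rvs appends a column of simulated binary responses
inputs = [np.random.randn(50, n) for _ in range(T)]
data = glm.rvs(inputs, times=range(T))

glm.resample(data)                          # Polya-Gamma augmentation, then Kalman filtering and backward sampling
print(glm.weights.shape)                    # (T, n_inputs): one weight vector per session
print(glm.log_likelihood(data[0], 0)[:5])   # per-trial log-likelihoods at session 0
```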
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbba312b67bcbdec2310614db0833e26a5029c47
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial.py
@@ -0,0 +1,286 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_Input_Categorical']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+from scipy import sparse
+import numpy as np
+from warnings import warn
+import time
+
+from pgmult.lda import StickbreakingDynamicTopicsLDA
+from pgmult.utils import psi_to_pi
+from scipy.stats import invgamma
+
+
+def enforce_limits(times, limit):
+    times = np.array(times)
+    diffs = np.zeros(len(times), dtype=np.int32)
+    diffs[1:] = np.diff(times)
+    diffs[diffs <= limit] = limit
+    diffs -= limit
+    diffs = np.cumsum(diffs)
+
+    diffs += times[0]
+    times -= diffs
+    return times
+
+
+assert np.array_equal(enforce_limits([0, 1, 2, 3, 4, 5], 3), [0, 1, 2, 3, 4, 5])
+assert np.array_equal(enforce_limits([0, 2, 4, 6, 8, 10, 12, 14, 15], 3), [0, 2, 4, 6, 8, 10, 12, 14, 15])
+assert np.array_equal(enforce_limits([0, 1, 2, 6, 10, 14], 3), [0, 1, 2, 5, 8, 11])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100], 3), [0, 1, 4, 7, 10])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102], 3), [0, 1, 4, 7, 10, 11, 12])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104], 3), [0, 1, 4, 7, 10, 11, 12, 14])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 1, 4, 7, 10, 11, 12, 14, 17])
+assert np.array_equal(enforce_limits([1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 3, 6, 9, 10, 11, 13, 16])
+
+
+def meh_time_info(all_times, sub_times, limit):
+    # Return time stamps for the needed counts, considering jumps
+    # Return list of mappings, to translate from this states timepoints to the overall timepoints
+    # Return first and last timepoint
+    times = []
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            times.append(i)
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter <= limit:
+                times.append(i)
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return times, maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya1, mya2, mya3, mya4 = [0, 1, 2, 3, 4, 5, 6, 9], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (6, 7), (6, 8), (7, 9)], 0, 9
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya1, mya2, mya3, mya4 = [3], [(0, 3)], 3, 3
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+
+
+def time_info(all_times, sub_times, limit):
+    # Return list of mappings, to translate from this states timepoints to the overall timepoints
+    # Return first and last timepoint
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter < limit:
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya2, mya3, mya4 = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (5, 6), (5, 7), (5, 8), (6, 9)], 0, 9
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya2, mya3, mya4 = [(0, 3)], 3, 3  # TODO: I don't think it matters whether first answer is [3] or [0]
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+# TODO: more tests
+
+class Dynamic_Input_Categorical(GibbsSampling):
+    """
+    This class enables a drifting input-output iHMM.
+
+    Everything is pretty much like in the input Categorical, but here we also
+    allow drift across sessions for the Categoricals. Similar to a dynamic topic
+    model (with one topic).
+
+    We'll have to use quite a few tricks:
+    - Instead of really resampling we'll instantiate a new dtm for every resampling and resample it a couple of times (time killer)
+    - We can't initialize from the prior all that easily. We'll just instantiate a constant Dirichlet from the prior;
+      once data is actually there, we'll use the real model
+    - this will also be a bit hard to copy... maybe really just use package for resampling, otherwise have all that stuff saved in a list of Categoricals?
+    Big problem: how to give log-likelihood for obs in sessions where this state was previously not present?
+    We don't know what value this thing should have there, could be anything...
+    -> if there is no data, we simply have to sample from prior. That is the will of Gibbs
+    That is: sample from prior for first 3 unaccounted sessions, then leave constant
+    But!: How to do this back in time? -> also Gibbs sample just from prior for three sessions
+
+    ! The first session must contain data for StickbreakingDynamicTopicsLDA
+
+    Idea: maybe save previously assigned betas (or psi's or whatever directly)
+    then initialize new dtm from there, so as to have to do less resampling
+    (this depends on what gets resampled first, hopefully the auxiliary variables first, then this saving should have an effect)
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    """
+
+    def __init__(self, n_inputs, n_outputs, T, sigmasq_states, jumplimit=3, n_resample=15):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        if self.n_outputs != 2:
+            warn('this requires adaptations in the row swapping code for the dynamic topic model')
+            # quit()
+        self.T = T
+        self.jumplimit = jumplimit
+        self.sigmasq_states = sigmasq_states
+        self.n_resample = n_resample
+        self.psi_diff_saves = []
+
+        single_weights = np.zeros((self.n_inputs, self.n_outputs))
+        for i in range(self.n_inputs):
+            single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+        self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+
+    def rvs(self, input):
+        print("Looks like simple copy from Input_Categorical, not useful, no dynamics")
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x, timepoint):
+        out = np.zeros(x.shape[0], dtype=np.double)
+        nans = np.isnan(x[:, -1])
+        err = np.seterr(divide='ignore')
+        out[~nans] = np.log(self.weights[timepoint])[tuple(x[~nans].T.astype(int))]
+        np.seterr(**err)
+        out[nans] = 1
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        self.psi_diff_saves = []
+        counts, all_times = self._get_statistics(data)
+
+        # if state is never active, resample from prior
+        if counts.sum() == 0:
+            single_weights = np.zeros((self.n_inputs, self.n_outputs))
+            for i in range(self.n_inputs):
+                single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+            self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+            return
+
+        fake_times = enforce_limits(all_times, self.jumplimit)
+        self.weights.fill(0)
+
+        for i in range(self.n_inputs):
+            if np.sum(counts[:, i]) == 0:
+                self.weights[:, i] = np.random.dirichlet(np.ones(self.n_outputs))
+            else:
+                temp = np.sum(counts[:, i], axis=1)
+                spec_times = np.where(temp)[0]
+                maps, first_non0, last_non0 = time_info(all_times, spec_times, self.jumplimit)
+                spec_fake_times = fake_times[spec_times]
+                # we shuffle the columns around, so as to have the timeout answer first, for a hopefully more consistent variance estimation
+
+                # dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i][..., [2, 0, 1]]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+                dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+
+                for _ in range(self.n_resample):
+                    dtm.resample()
+
+                # save for resampling sigma
+                self.psi_diff_saves.append(np.diff(dtm.psi, axis=0).ravel())
+
+                # put dtm weights in right places
+                for m in maps:
+                    # shuffle back
+                    # self.weights[m[1], i] = dtm.beta[m[0], :, 0][..., [1, 2, 0]]
+                    self.weights[m[1], i] = dtm.beta[m[0], :, 0]
+
+                sample = dtm.psi[0]
+                for j in range(min(self.jumplimit, first_non0)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    # self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)
+                if first_non0 > self.jumplimit:
+                    # self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)
+
+                sample = dtm.psi[-1]
+                for j in range(min(self.jumplimit, self.T - last_non0 - 1)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    # self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)
+                if self.T - last_non0 - 1 > self.jumplimit:
+                    # self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)[..., [1, 2, 0]]
+                    self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)
+
+        self.psi_diff_saves = np.concatenate(self.psi_diff_saves)
+        assert np.count_nonzero(np.sum(self.weights, axis=2)) == np.sum(self.weights, axis=2).size
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        counts = []
+        times = []
+        timepoint_count = np.empty((self.n_inputs, self.n_outputs), dtype=int)
+        if isinstance(data, np.ndarray):
+            warn('ndarray data is not supported here; pass a list of per-session arrays')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                clean_d = (d[~np.isnan(d[:, -1])]).astype(int)
+                if len(clean_d) != 0:
+                    timepoint_count[:] = 0
+                    for subd in clean_d:
+                        timepoint_count[subd[0], subd[1]] += 1
+                    counts.append(timepoint_count.copy())
+                    times.append(i)
+
+        return np.array(counts), times
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
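A usage sketch for `Dynamic_Input_Categorical`, assuming `pgmult` is installed (it is imported at the top of this module); sizes and hyperparameters are illustrative, and note the caveat in the class docstring that the first session must contain data.

```python
import numpy as np
from pybasicbayes.distributions.dynamic_multinomial import Dynamic_Input_Categorical

dist = Dynamic_Input_Categorical(n_inputs=2, n_outputs=2, T=5,
                                 sigmasq_states=0.01, jumplimit=3, n_resample=5)

# one (input, output) integer pair per trial, one array per session; session 0 contains data
data = [np.column_stack([np.random.randint(2, size=30),
                         np.random.randint(2, size=30)]).astype(float)
        for _ in range(5)]

dist.resample(data)                         # instantiates a StickbreakingDynamicTopicsLDA per input and resamples it
print(dist.weights.shape)                   # (T, n_inputs, n_outputs)
print(dist.log_likelihood(data[0], 0)[:5])  # per-trial log-likelihoods at session 0
```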
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial_in_progress.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial_in_progress.py
new file mode 100644
index 0000000000000000000000000000000000000000..46f2b0f991f31bee904b6668f6436702c5ccf304
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/dynamic_multinomial_in_progress.py
@@ -0,0 +1,265 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Dynamic_Input_Categorical']
+
+from pybasicbayes.abstractions import \
+    GibbsSampling
+
+from scipy import sparse
+import numpy as np
+from warnings import warn
+import time
+
+from pgmult.lda import StickbreakingDynamicTopicsLDA
+from pgmult.utils import psi_to_pi
+
+
+def enforce_limits(times, limit):
+    times = np.array(times)
+    diffs = np.zeros(len(times), dtype=np.int32)
+    diffs[1:] = np.diff(times)
+    diffs[diffs <= limit] = limit
+    diffs -= limit
+    diffs = np.cumsum(diffs)
+
+    diffs += times[0]
+    times -= diffs
+    return times
+
+
+assert np.array_equal(enforce_limits([0, 1, 2, 3, 4, 5], 3), [0, 1, 2, 3, 4, 5])
+assert np.array_equal(enforce_limits([0, 2, 4, 6, 8, 10, 12, 14, 15], 3), [0, 2, 4, 6, 8, 10, 12, 14, 15])
+assert np.array_equal(enforce_limits([0, 1, 2, 6, 10, 14], 3), [0, 1, 2, 5, 8, 11])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100], 3), [0, 1, 4, 7, 10])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102], 3), [0, 1, 4, 7, 10, 11, 12])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104], 3), [0, 1, 4, 7, 10, 11, 12, 14])
+assert np.array_equal(enforce_limits([0, 1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 1, 4, 7, 10, 11, 12, 14, 17])
+assert np.array_equal(enforce_limits([1, 8, 20, 100, 101, 102, 104, 110], 3), [0, 3, 6, 9, 10, 11, 13, 16])
+
+
+def meh_time_info(all_times, sub_times, limit):
+    # Return time stamps for the needed counts, considering jumps
+    # Return list of mappings, to translate from this states timepoints to the overall timepoints
+    # Return first and last timepoint
+    times = []
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            times.append(i)
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter <= limit:
+                times.append(i)
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return times, maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya1, mya2, mya3, mya4 = [0, 1, 2, 3, 4, 5, 6, 9], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (6, 7), (6, 8), (7, 9)], 0, 9
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya1, mya2, mya3, mya4 = [3], [(0, 3)], 3, 3
+o1, o2, o3, o4 = meh_time_info(all_times, sub_times, limit)
+assert mya1 == o1
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+
+
+def time_info(all_times, sub_times, limit):
+    # Return list of mappings, to translate from this states timepoints to the overall timepoints
+    # Return first and last timepoint
+    maps = []
+    first, last = all_times[sub_times[0]], all_times[sub_times[-1]]
+    jump_counter = 0
+    time_counter = 0
+    for i in range(first, last+1):
+        if i in all_times:
+            maps.append((time_counter, i))
+            jump_counter = 0
+            time_counter += 1
+        else:
+            jump_counter += 1
+            if jump_counter < limit:
+                maps.append((time_counter, i))
+                time_counter += 1
+            else:
+                maps.append((time_counter - 1, i))
+
+    return maps, first, last
+
+
+all_times, sub_times, limit = [0, 3, 9], [0, 2], 3
+mya2, mya3, mya4 = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (5, 6), (5, 7), (5, 8), (6, 9)], 0, 9
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+
+all_times, sub_times, limit = [0, 3, 9], [1], 3
+mya2, mya3, mya4 = [(0, 3)], 3, 3  # TODO: I don't think it matters whether first answer is [3] or [0]
+o2, o3, o4 = time_info(all_times, sub_times, limit)
+assert mya2 == o2
+assert mya3 == o3
+assert mya4 == o4
+# TODO: more tests
+
+class Dynamic_Input_Categorical(GibbsSampling):
+    """
+    This class enables a drifting input-output iHMM.
+
+    Everything is pretty much like in the input Categorical, but here we also
+    allow drift across sessions for the Categoricals. Similar to a dynamic topic
+    model (with one topic).
+
+    We'll have to use quite a few tricks:
+    - Instead of really resampling we'll instantiate a new dtm for every resampling and resample it a couple of times (time killer)
+    - We can't initialize from the prior all that easily. We'll just instantiate a constant Dirichlet from the prior;
+      once data is actually there, we'll use the real model
+    - this will also be a bit hard to copy... maybe really just use package for resampling, otherwise have all that stuff saved in a list of Categoricals?
+    Big problem: how to give log-likelihood for obs in sessions where this state was previously not present?
+    We don't know what value this thing should have there, could be anything...
+    -> if there is no data, we simply have to sample from prior. That is the will of Gibbs
+    That is: sample from prior for first 3 unaccounted sessions, then leave constant
+    But!: How to do this back in time? -> also Gibbs sample just from prior for three sessions
+
+    ! The first session must contain data for StickbreakingDynamicTopicsLDA
+
+    Idea: maybe save previously assigned betas (or psi's or whatever directly)
+    then initialize new dtm from there, so as to have to do less resampling
+    (this depends on what gets resampled first, hopefully the auxiliary variables first, then this saving should have an effect)
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        [weights, a vector encoding a finite, multidimensional pmf]
+    """
+
+    def __init__(self, n_inputs, n_outputs, T, sigmasq_states=0.01, jumplimit=3):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.T = T
+        self.jumplimit = jumplimit
+        self.sigmasq_states = sigmasq_states  # !!! this needs to be updated with gibbs estimate of sdev
+        self.psi_save = None
+
+        single_weights = np.zeros((self.n_inputs, self.n_outputs))
+        for i in range(self.n_inputs):
+            single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+        self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+
+    def rvs(self, input):
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x, timepoint):
+        out = np.zeros_like(x, dtype=np.double)
+        err = np.seterr(divide='ignore')
+        out = np.log(self.weights[timepoint])[tuple(x.T)]
+        np.seterr(**err)
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        counts, all_times = self._get_statistics(data)
+
+        # if state is never active, resample from prior
+        if counts.sum() == 0:
+            single_weights = np.zeros((self.n_inputs, self.n_outputs))
+            for i in range(self.n_inputs):
+                single_weights[i] = np.random.dirichlet(np.ones(self.n_outputs))
+            self.weights = np.tile(single_weights, (self.T, 1, 1))  # init to constant over timepoints
+            return
+
+        fake_times = enforce_limits(all_times, self.jumplimit)
+        self.weights.fill(0)
+
+        for i in range(self.n_inputs):
+            # TODO: case if this counts is always 0
+            if np.sum(counts[:, i]) == 0:
+                self.weights[:, i] = np.random.dirichlet(np.ones(self.n_outputs))
+            else:
+                temp = np.sum(counts[:, i], axis=1)
+                spec_times = np.where(temp)[0]
+                maps, first_non0, last_non0 = time_info(all_times, spec_times, self.jumplimit)
+
+                # represented_sessions = [m[1] for m in maps]  # these (add boundary samples) are sessions sampled, for later initialisation
+
+                spec_fake_times = fake_times[spec_times]
+                dtm = StickbreakingDynamicTopicsLDA(sparse.csr_matrix(counts[spec_times, i]), spec_fake_times, K=1, alpha_theta=1, sigmasq_states=self.sigmasq_states)
+
+                dtm.resample()
+
+                for m in maps:
+                    self.weights[m[1], i] = dtm.beta[m[0], :, 0]
+
+                sample = dtm.psi[0]
+                for j in range(min(self.jumplimit, first_non0)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    self.weights[first_non0 - j - 1, i] = psi_to_pi(sample.T)
+                if first_non0 > self.jumplimit:
+                    self.weights[:first_non0 - self.jumplimit, i] = psi_to_pi(sample.T)
+
+                sample = dtm.psi[-1]
+                for j in range(min(self.jumplimit, self.T - last_non0 - 1)):
+                    sample += np.random.normal(0, np.sqrt(self.sigmasq_states), size=self.n_outputs - 1)[:, None]  # is this the correct way to do this? not sure
+                    self.weights[last_non0 + j + 1, i] = psi_to_pi(sample.T)
+                if self.T - last_non0 - 1 > self.jumplimit:
+                    self.weights[last_non0 + self.jumplimit + 1:, i] = psi_to_pi(sample.T)
+
+        assert np.count_nonzero(np.sum(self.weights, axis=2)) == np.sum(self.weights, axis=2).size  # weird way to test whether everything is normalised?
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        counts = []
+        times = []
+        timepoint_count = np.empty((self.n_inputs, self.n_outputs), dtype=int)
+        if isinstance(data, np.ndarray):
+            warn('ndarray data is not supported here; pass a list of per-session arrays')
+            quit()
+            # assert len(data.shape) == 2
+            # for d in data:
+            #     counts[tuple(d)] += 1
+        else:
+            for i, d in enumerate(data):
+                if len(d) != 0:
+                    timepoint_count[:] = 0
+                    for subd in d:
+                        timepoint_count[subd[0], subd[1]] += 1
+                    counts.append(timepoint_count.copy())
+                    times.append(i)
+
+        return np.array(counts), times
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/gaussian.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/gaussian.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d045a5299756d68c3da4e2f3f7a85cbcfc4eb6
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/gaussian.py
@@ -0,0 +1,1518 @@
+from __future__ import division
+from builtins import map
+from builtins import zip
+from builtins import range
+from builtins import object
+__all__ = \
+    ['Gaussian', 'GaussianFixedMean', 'GaussianFixedCov', 'GaussianFixed',
+     'GaussianNonConj', 'DiagonalGaussian', 'DiagonalGaussianNonconjNIG',
+     'IsotropicGaussian', 'ScalarGaussianNIX', 'ScalarGaussianNonconjNIX',
+     'ScalarGaussianNonconjNIG', 'ScalarGaussianFixedvar']
+
+import numpy as np
+from numpy import newaxis as na
+from numpy.core.umath_tests import inner1d
+import scipy.linalg
+import scipy.stats as stats
+import scipy.special as special
+import copy
+
+from pybasicbayes.abstractions import GibbsSampling, MeanField, \
+    MeanFieldSVI, Collapsed, MaxLikelihood, MAP, Tempering
+from pybasicbayes.distributions.meta import _FixedParamsMixin
+from pybasicbayes.util.stats import sample_niw, invwishart_entropy, \
+    sample_invwishart, invwishart_log_partitionfunction, \
+    getdatasize, flattendata, getdatadimension, \
+    combinedata, multivariate_t_loglik, gi, niw_expectedstats
+
+weps = 1e-12
+
+
+class _GaussianBase(object):
+    @property
+    def params(self):
+        return dict(mu=self.mu, sigma=self.sigma)
+
+    @property
+    def D(self):
+        return self.mu.shape[0]
+
+    ### internals
+
+    def getsigma(self):
+        return self._sigma
+
+    def setsigma(self,sigma):
+        self._sigma = sigma
+        self._sigma_chol = None
+
+    sigma = property(getsigma,setsigma)
+
+    @property
+    def sigma_chol(self):
+        if not hasattr(self,'_sigma_chol') or self._sigma_chol is None:
+            self._sigma_chol = np.linalg.cholesky(self.sigma)
+        return self._sigma_chol
+
+    ### distribution stuff
+
+    def rvs(self,size=None):
+        size = 1 if size is None else size
+        size = size + (self.mu.shape[0],) if isinstance(size,tuple) \
+            else (size,self.mu.shape[0])
+        return self.mu + np.random.normal(size=size).dot(self.sigma_chol.T)
+
+    def log_likelihood(self,x):
+        try:
+            mu, D = self.mu, self.D
+            sigma_chol = self.sigma_chol
+            bads = np.isnan(np.atleast_2d(x)).any(axis=1)
+            x = np.nan_to_num(x).reshape((-1,D)) - mu
+            xs = scipy.linalg.solve_triangular(sigma_chol,x.T,lower=True)
+            out = -1./2. * inner1d(xs.T,xs.T) - D/2*np.log(2*np.pi) \
+                - np.log(sigma_chol.diagonal()).sum()
+            out[bads] = 0
+            return out
+        except np.linalg.LinAlgError:
+            # NOTE: degenerate distribution doesn't have a density
+            return np.repeat(-np.inf,x.shape[0])
+
+    ### plotting
+
+    # TODO making animations, this seems to generate an extra notebook figure
+
+    _scatterplot = None
+    _parameterplot = None
+
+    def plot(self,ax=None,data=None,indices=None,color='b',
+             plot_params=True,label='',alpha=1.,
+             update=False,draw=True):
+        import matplotlib.pyplot as plt
+        from pybasicbayes.util.plot import project_data, \
+                plot_gaussian_projection, plot_gaussian_2D
+        ax = ax if ax else plt.gca()
+        D = self.D
+        if data is not None:
+            data = flattendata(data)
+
+        if data is not None:
+            if D > 2:
+                plot_basis = np.random.RandomState(seed=0).randn(2,D)
+                data = project_data(data,plot_basis)
+            if update and self._scatterplot is not None:
+                self._scatterplot.set_offsets(data)
+                self._scatterplot.set_color(color)
+            else:
+                self._scatterplot = ax.scatter(
+                    data[:,0],data[:,1],marker='.',color=color)
+
+        if plot_params:
+            if D > 2:
+                plot_basis = np.random.RandomState(seed=0).randn(2,D)
+                self._parameterplot = \
+                    plot_gaussian_projection(
+                        self.mu,self.sigma,plot_basis,
+                        color=color,label=label,alpha=min(1-1e-3,alpha),
+                        ax=ax, artists=self._parameterplot if update else None)
+            else:
+                self._parameterplot = \
+                    plot_gaussian_2D(
+                        self.mu,self.sigma,color=color,label=label,
+                        alpha=min(1-1e-3,alpha), ax=ax,
+                        artists=self._parameterplot if update else None)
+
+        if draw:
+            plt.draw()
+
+        return [self._scatterplot] + list(self._parameterplot)
+
+    def to_json_dict(self):
+        D = self.mu.shape[0]
+        assert D == 2
+        U,s,_ = np.linalg.svd(self.sigma)
+        U /= np.linalg.det(U)
+        theta = np.arctan2(U[0,0],U[0,1])*180/np.pi
+        return {'x':self.mu[0],'y':self.mu[1],'rx':np.sqrt(s[0]),
+                'ry':np.sqrt(s[1]), 'theta':theta}
+
+
+class Gaussian(
+        _GaussianBase, GibbsSampling, MeanField, MeanFieldSVI,
+        Collapsed, MAP, MaxLikelihood):
+    '''
+    Multivariate Gaussian distribution class.
+
+    NOTE: Only works for 2 or more dimensions. For a scalar Gaussian, use a
+    scalar class.  Uses a conjugate Normal/Inverse-Wishart prior.
+
+    Hyperparameters mostly follow Gelman et al.'s notation in Bayesian Data
+    Analysis:
+        nu_0, sigma_0, mu_0, kappa_0
+
+    Parameters are mean and covariance matrix:
+        mu, sigma
+    '''
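+    # Illustrative usage (a sketch, not taken from the library's docs; `data` and
+    # `x` are placeholder (N,2) arrays):
+    #   prior = dict(mu_0=np.zeros(2), sigma_0=np.eye(2), kappa_0=0.05, nu_0=4.)
+    #   d = Gaussian(**prior)         # draws initial (mu, sigma) from the prior
+    #   d.resample(data)              # Gibbs step conditioned on the data
+    #   ll = d.log_likelihood(x)      # per-row log density under current (mu, sigma)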
+
+    def __init__(
+            self, mu=None, sigma=None,
+            mu_0=None, sigma_0=None, kappa_0=None, nu_0=None):
+        self.mu = mu
+        self.sigma = sigma
+
+        self.mu_0    = self.mu_mf    = mu_0
+        self.sigma_0 = self.sigma_mf = sigma_0
+        self.kappa_0 = self.kappa_mf = kappa_0
+        self.nu_0    = self.nu_mf    = nu_0
+
+        # NOTE: resampling will set mu_mf and sigma_mf if necessary
+        if mu is sigma is None \
+                and not any(_ is None for _ in (mu_0,sigma_0,kappa_0,nu_0)):
+            self.resample()  # initialize from prior
+        if mu is not None and sigma is not None \
+                and not any(_ is None for _ in (mu_0,sigma_0,kappa_0,nu_0)):
+            self.mu_mf = mu
+            self.sigma_mf = sigma * (self.nu_0 - self.mu_mf.shape[0] - 1)
+
+    @property
+    def hypparams(self):
+        return dict(
+            mu_0=self.mu_0,sigma_0=self.sigma_0,
+            kappa_0=self.kappa_0,nu_0=self.nu_0)
+
+    @property
+    def natural_hypparam(self):
+        return self._standard_to_natural(
+            self.mu_0,self.sigma_0,self.kappa_0,self.nu_0)
+
+    @natural_hypparam.setter
+    def natural_hypparam(self,natparam):
+        self.mu_0, self.sigma_0, self.kappa_0, self.nu_0 = \
+            self._natural_to_standard(natparam)
+
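+    # NOTE on the packing convention used by the methods below (read off the code,
+    # not from upstream docs): hyperparameters and sufficient statistics are both
+    # packed into (D+2)x(D+2) matrices so conjugate updates reduce to addition:
+    #   hypparam = [[ sigma_0 + kappa_0 mu_0 mu_0^T , kappa_0 mu_0 , 0            ],
+    #               [ kappa_0 mu_0^T                , kappa_0      , 0            ],
+    #               [ 0                             , 0            , nu_0 + D + 2 ]]
+    #   stats    = [[ sum_n x_n x_n^T               , sum_n x_n    , 0            ],
+    #               [ sum_n x_n^T                   , N            , 0            ],
+    #               [ 0                             , 0            , N            ]]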
+    def _standard_to_natural(self,mu_mf,sigma_mf,kappa_mf,nu_mf):
+        D = sigma_mf.shape[0]
+        out = np.zeros((D+2,D+2))
+        out[:D,:D] = sigma_mf + kappa_mf * np.outer(mu_mf,mu_mf)
+        out[:D,-2] = out[-2,:D] = kappa_mf * mu_mf
+        out[-2,-2] = kappa_mf
+        out[-1,-1] = nu_mf + 2 + D
+        return out
+
+    def _natural_to_standard(self,natparam):
+        D = natparam.shape[0]-2
+        A = natparam[:D,:D]
+        b = natparam[:D,-2]
+        c = natparam[-2,-2]
+        d = natparam[-1,-1]
+        return b/c, A - np.outer(b,b)/c, c, d - 2 - D
+
+    @property
+    def num_parameters(self):
+        D = self.D
+        return D + D*(D+1)/2  # D mean parameters plus D(D+1)/2 covariance parameters
+
+    @property
+    def D(self):
+        if self.mu is not None:
+            return self.mu.shape[0]
+        elif self.mu_0 is not None:
+            return self.mu_0.shape[0]
+
+    def _get_statistics(self,data,D=None):
+        if D is None:
+            D = self.D if self.D is not None else getdatadimension(data)
+        out = np.zeros((D+2,D+2))
+        if isinstance(data,np.ndarray):
+            out[:D,:D] = data.T.dot(data)
+            out[-2,:D] = out[:D,-2] = data.sum(0)
+            out[-2,-2] = out[-1,-1] = data.shape[0]
+            return out
+        else:
+            return sum(list(map(self._get_statistics,data)),out)
+
+    def _get_weighted_statistics(self,data,weights,D=None):
+        D = getdatadimension(data) if D is None else D
+        out = np.zeros((D+2,D+2))
+        if isinstance(data,np.ndarray):
+            out[:D,:D] = data.T.dot(weights[:,na]*data)
+            out[-2,:D] = out[:D,-2] = weights.dot(data)
+            out[-2,-2] = out[-1,-1] = weights.sum()
+            return out
+        else:
+            return sum(list(map(self._get_weighted_statistics,data,weights)),out)
+
+    def _get_empty_statistics(self, D):
+        out = np.zeros((D+2,D+2))
+        return out
+
+    def empirical_bayes(self,data):
+        self.natural_hypparam = self._get_statistics(data)
+        self.resample()  # initialize from prior given new hyperparameters
+        return self
+
+    @staticmethod
+    def _stats_ensure_array(stats):
+        if isinstance(stats, np.ndarray):
+            return stats
+        x, xxT, n = stats
+        D = x.shape[-1]
+        out = np.zeros((D+2,D+2))
+        out[:D,:D] = xxT
+        out[-2,:D] = out[:D,-2] = x
+        out[-2,-2] = out[-1,-1] = n
+        return out
+
+    ### Gibbs sampling
+
+    def resample(self,data=[]):
+        D = len(self.mu_0)
+        self.mu, self.sigma = \
+            sample_niw(*self._natural_to_standard(
+                self.natural_hypparam + self._get_statistics(data,D)))
+        # NOTE: next lines let Gibbs sampling initialize mean
+        nu = self.nu_mf if hasattr(self,'nu_mf') and self.nu_mf \
+            else self.nu_0
+        self.mu_mf, self.sigma_mf = self.mu, self.sigma * (nu - D - 1)
+        return self
+
+    def copy_sample(self):
+        new = copy.copy(self)
+        new.mu = self.mu.copy()
+        new.sigma = self.sigma.copy()
+        return new
+
+    ### Mean Field
+
+    def _resample_from_mf(self):
+        self.mu, self.sigma = \
+            sample_niw(*self._natural_to_standard(
+                self.mf_natural_hypparam))
+        return self
+
+    def meanfieldupdate(self, data=None, weights=None, stats=None):
+        assert (data is not None and weights is not None) ^ (stats is not None)
+        stats = self._stats_ensure_array(stats) if stats is not None else \
+            self._get_weighted_statistics(data, weights, self.mu_0.shape[0])
+        self.mf_natural_hypparam = \
+            self.natural_hypparam + stats
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        D = len(self.mu_0)
+        self.mf_natural_hypparam = \
+            (1-stepsize) * self.mf_natural_hypparam + stepsize * (
+                self.natural_hypparam
+                + 1./prob
+                * self._get_weighted_statistics(data,weights,D))
+
+    @property
+    def mf_natural_hypparam(self):
+        return self._standard_to_natural(
+            self.mu_mf,self.sigma_mf,self.kappa_mf,self.nu_mf)
+
+    @mf_natural_hypparam.setter
+    def mf_natural_hypparam(self,natparam):
+        self.mu_mf, self.sigma_mf, self.kappa_mf, self.nu_mf = \
+            self._natural_to_standard(natparam)
+        # NOTE: next line is for plotting
+        self.mu, self.sigma = \
+            self.mu_mf, self.sigma_mf/(self.nu_mf - self.mu_mf.shape[0] - 1)
+
+    @property
+    def sigma_mf(self):
+        return self._sigma_mf
+
+    @sigma_mf.setter
+    def sigma_mf(self,val):
+        self._sigma_mf = val
+        self._sigma_mf_chol = None
+
+    @property
+    def sigma_mf_chol(self):
+        if self._sigma_mf_chol is None:
+            self._sigma_mf_chol = np.linalg.cholesky(self.sigma_mf)
+        return self._sigma_mf_chol
+
+    def get_vlb(self):
+        D = len(self.mu_0)
+        loglmbdatilde = self._loglmbdatilde()
+
+        # see Eq. 10.77 in Bishop
+        q_entropy = -0.5 * (loglmbdatilde + D * (np.log(self.kappa_mf/(2*np.pi))-1)) \
+            + invwishart_entropy(self.sigma_mf,self.nu_mf)
+        # see Eq. 10.74 in Bishop, we aren't summing over K
+        p_avgengy = 0.5 * (D * np.log(self.kappa_0/(2*np.pi)) + loglmbdatilde
+            - D*self.kappa_0/self.kappa_mf - self.kappa_0*self.nu_mf*
+            np.dot(self.mu_mf -
+                self.mu_0,np.linalg.solve(self.sigma_mf,self.mu_mf - self.mu_0))) \
+            - invwishart_log_partitionfunction(self.sigma_0,self.nu_0) \
+            + (self.nu_0 - D - 1)/2*loglmbdatilde - 1/2*self.nu_mf \
+            * np.linalg.solve(self.sigma_mf,self.sigma_0).trace()
+
+        return p_avgengy + q_entropy
+
+    def expected_log_likelihood(self, x=None, stats=None):
+        assert (x is not None) ^ isinstance(stats, (tuple, np.ndarray))
+
+        if x is not None:
+            mu_n, kappa_n, nu_n = self.mu_mf, self.kappa_mf, self.nu_mf
+            D = len(mu_n)
+            x = np.reshape(x,(-1,D)) - mu_n  # x is now centered
+            xs = np.linalg.solve(self.sigma_mf_chol,x.T)
+
+            # see Eqs. 10.64, 10.67, and 10.71 in Bishop
+            return self._loglmbdatilde()/2 - D/(2*kappa_n) - nu_n/2 * \
+                inner1d(xs.T,xs.T) - D/2*np.log(2*np.pi)
+        else:
+            D = self.mu_mf.shape[0]
+
+            E_J, E_h, E_muJmuT, E_logdetJ = \
+                niw_expectedstats(
+                    self.nu_mf, self.sigma_mf, self.mu_mf, self.kappa_mf)
+
+            if isinstance(stats, np.ndarray):
+                parammat = np.zeros((D+2,D+2))
+                parammat[:D,:D] = E_J
+                parammat[:D,-2] = parammat[-2,:D] = -E_h
+                parammat[-2,-2] = E_muJmuT
+                parammat[-1,-1] = -E_logdetJ
+
+                contract = 'ij,nij->n' if stats.ndim == 3 else 'ij,ij->'
+                return -1./2*np.einsum(contract, parammat, stats) \
+                    - D/2.*np.log(2*np.pi)
+            else:
+                x, xxT, n = stats
+                c1, c2 = ('i,i->', 'ij,ij->') if x.ndim == 1 \
+                    else ('i,ni->n', 'ij,nij->n')
+
+                out = -1./2 * np.einsum(c2, E_J, xxT)
+                out += np.einsum(c1, E_h, x)
+                out += -n/2.*E_muJmuT
+                out += -D/2.*np.log(2*np.pi) + n/2.*E_logdetJ
+
+                return out
+
+    def _loglmbdatilde(self):
+        # see Eq. 10.65 in Bishop
+        D = len(self.mu_0)
+        chol = self.sigma_mf_chol
+        return special.digamma((self.nu_mf-np.arange(D))/2.).sum() \
+            + D*np.log(2) - 2*np.log(chol.diagonal()).sum()
+
+    ### Collapsed
+
+    def log_marginal_likelihood(self,data):
+        n, D = getdatasize(data), len(self.mu_0)
+        return self._log_partition_function(
+            *self._natural_to_standard(
+                self.natural_hypparam + self._get_statistics(data,D))) \
+            - self._log_partition_function(self.mu_0,self.sigma_0,self.kappa_0,self.nu_0) \
+            - n*D/2 * np.log(2*np.pi)
+
+    def _log_partition_function(self,mu,sigma,kappa,nu):
+        D = len(mu)
+        chol = np.linalg.cholesky(sigma)
+        return nu*D/2*np.log(2) + special.multigammaln(nu/2,D) + D/2*np.log(2*np.pi/kappa) \
+            - nu*np.log(chol.diagonal()).sum()
+
+    def log_predictive_studentt_datapoints(self,datapoints,olddata):
+        D = len(self.mu_0)
+        mu_n, sigma_n, kappa_n, nu_n = \
+            self._natural_to_standard(
+                self.natural_hypparam + self._get_statistics(olddata,D))
+        return multivariate_t_loglik(
+            datapoints,nu_n-D+1,mu_n,(kappa_n+1)/(kappa_n*(nu_n-D+1))*sigma_n)
+
+    def log_predictive_studentt(self,newdata,olddata):
+        newdata = np.atleast_2d(newdata)
+        return sum(self.log_predictive_studentt_datapoints(
+            d,combinedata((olddata,newdata[:i])))[0] for i,d in enumerate(newdata))
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        D = getdatadimension(data)
+        if weights is None:
+            statmat = self._get_statistics(data,D)
+        else:
+            statmat = self._get_weighted_statistics(data,weights,D)
+
+        n, x, xxt = statmat[-1,-1], statmat[-2,:D], statmat[:D,:D]
+
+        # this SVD is necessary to check if the max likelihood solution is
+        # degenerate, which can happen in the EM algorithm
+        if n < D or (np.linalg.svd(xxt,compute_uv=False) > 1e-6).sum() < D:
+            self.broken = True
+            self.mu = 99999999*np.ones(D)
+            self.sigma = np.eye(D)
+        else:
+            self.mu = x/n
+            self.sigma = xxt/n - np.outer(self.mu,self.mu)
+
+        return self
+
+    def MAP(self,data,weights=None):
+        D = getdatadimension(data)
+        # max likelihood with prior pseudocounts included in data
+        if weights is None:
+            statmat = self._get_statistics(data)
+        else:
+            statmat = self._get_weighted_statistics(data,weights)
+        statmat += self.natural_hypparam
+
+        n, x, xxt = statmat[-1,-1], statmat[-2,:D], statmat[:D,:D]
+
+        self.mu = x/n
+        self.sigma = xxt/n - np.outer(self.mu,self.mu)
+
+        return self
+
+
+class GaussianFixedMean(_GaussianBase, GibbsSampling, MaxLikelihood):
+    def __init__(self,mu=None,sigma=None,nu_0=None,lmbda_0=None):
+        self.sigma = sigma
+
+        self.mu = mu
+
+        self.nu_0 = nu_0
+        self.lmbda_0 = lmbda_0
+
+        if sigma is None and not any(_ is None for _ in (nu_0,lmbda_0)):
+            self.resample()  # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(nu_0=self.nu_0,lmbda_0=self.lmbda_0)
+
+    @property
+    def num_parameters(self):
+        D = len(self.mu)
+        return D*(D+1)/2
+
+    def _get_statistics(self,data):
+        n = getdatasize(data)
+        if n > 1e-4:
+            if isinstance(data,np.ndarray):
+                centered = data[gi(data)] - self.mu
+                sumsq = centered.T.dot(centered)
+                n = len(centered)
+            else:
+                sumsq = sum((d[gi(d)]-self.mu).T.dot(d[gi(d)]-self.mu) for d in data)
+        else:
+            sumsq = None
+        return n, sumsq
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            neff = weights.sum()
+            if neff > weps:
+                centered = data - self.mu
+                sumsq = centered.T.dot(weights[:,na]*centered)
+            else:
+                sumsq = None
+        else:
+            neff = sum(w.sum() for w in weights)
+            if neff > weps:
+                sumsq = sum((d-self.mu).T.dot(w[:,na]*(d-self.mu)) for w,d in zip(weights,data))
+            else:
+                sumsq = None
+
+        return neff, sumsq
+
+    def _posterior_hypparams(self,n,sumsq):
+        nu_0, lmbda_0 = self.nu_0, self.lmbda_0
+        if n > 1e-4:
+            nu_0 = nu_0 + n
+            sigma_n = self.lmbda_0 + sumsq
+            return sigma_n, nu_0
+        else:
+            return lmbda_0, nu_0
+
+    ### Gibbs sampling
+
+    def resample(self, data=[]):
+        self.sigma = sample_invwishart(*self._posterior_hypparams(
+            *self._get_statistics(data)))
+        return self
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        D = getdatadimension(data)
+        if weights is None:
+            n, sumsq = self._get_statistics(data)
+        else:
+            n, sumsq = self._get_weighted_statistics(data,weights)
+
+        if n < D or (np.linalg.svd(sumsq,compute_uv=False) > 1e-6).sum() < D:
+            # broken!
+            self.sigma = np.eye(D)*1e-9
+            self.broken = True
+        else:
+            self.sigma = sumsq/n
+
+        return self
+
+
+class GaussianFixedCov(_GaussianBase, GibbsSampling, MaxLikelihood):
+    # See Gelman's Bayesian Data Analysis notation around Eq. 3.18, p. 85
+    # in 2nd Edition. We replaced \Lambda_0 with sigma_0 since it is a prior
+    # *covariance* matrix rather than a precision matrix.
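+    #
+    # With sigma fixed, the conjugate posterior over mu is Gaussian with
+    #   precision_n = sigma_0^{-1} + n sigma^{-1}
+    #   mu_n        = precision_n^{-1} (sigma_0^{-1} mu_0 + n sigma^{-1} xbar),
+    # which is what _posterior_hypparams below computes (it returns mu_n together
+    # with the posterior precision).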
+    def __init__(self,mu=None,sigma=None,mu_0=None,sigma_0=None):
+        self.mu = mu
+
+        self.sigma = sigma
+
+        self.mu_0 = mu_0
+        self.sigma_0 = sigma_0
+
+        if mu is None and not any(_ is None for _ in (mu_0,sigma_0)):
+            self.resample()
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,sigma_0=self.sigma_0)
+
+    @property
+    def sigma_inv(self):
+        if not hasattr(self,'_sigma_inv'):
+            self._sigma_inv = np.linalg.inv(self.sigma)
+        return self._sigma_inv
+
+    @property
+    def sigma_inv_0(self):
+        if not hasattr(self,'_sigma_inv_0'):
+            self._sigma_inv_0 = np.linalg.inv(self.sigma_0)
+        return self._sigma_inv_0
+
+    @property
+    def num_parameters(self):
+        return len(self.mu)
+
+    def _get_statistics(self,data):
+        n = getdatasize(data)
+        if n > 0:
+            if isinstance(data,np.ndarray):
+                xbar = data.mean(0)
+            else:
+                xbar = sum(d.sum(0) for d in data) / n
+        else:
+            xbar = None
+
+        return n, xbar
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            neff = weights.sum()
+            if neff > weps:
+                xbar = weights.dot(data) / neff
+            else:
+                xbar = None
+        else:
+            neff = sum(w.sum() for w in weights)
+            if neff > weps:
+                xbar = sum(w.dot(d) for w,d in zip(weights,data)) / neff
+            else:
+                xbar = None
+
+        return neff, xbar
+
+    def _posterior_hypparams(self,n,xbar):
+        # NOTE: the update is carried out in terms of precisions (sigma_inv,
+        # sigma_inv_0); sigma_0 above is a covariance and is inverted lazily by
+        # the sigma_inv_0 property.
+        sigma_inv, mu_0, sigma_inv_0 = self.sigma_inv, self.mu_0, self.sigma_inv_0
+        if n > 0:
+            sigma_inv_n = n*sigma_inv + sigma_inv_0
+            mu_n = np.linalg.solve(
+                sigma_inv_n, sigma_inv_0.dot(mu_0) + n*sigma_inv.dot(xbar))
+            return mu_n, sigma_inv_n
+        else:
+            return mu_0, sigma_inv_0
+
+    ### Gibbs sampling
+
+    def resample(self,data=[]):
+        mu_n, sigma_n_inv = self._posterior_hypparams(*self._get_statistics(data))
+        D = len(mu_n)
+        L = np.linalg.cholesky(sigma_n_inv)
+        self.mu = scipy.linalg.solve_triangular(L,np.random.normal(size=D),lower=True) \
+            + mu_n
+        return self
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            n, xbar = self._get_statistics(data)
+        else:
+            n, xbar = self._get_weighted_statistics(data,weights)
+
+        self.mu = xbar
+        return self
+
+
+class GaussianFixed(_FixedParamsMixin, Gaussian):
+    def __init__(self,mu,sigma):
+        self.mu = mu
+        self.sigma = sigma
+
+class GaussianNonConj(_GaussianBase, GibbsSampling):
+    def __init__(self,mu=None,sigma=None,
+            mu_0=None,mu_lmbda_0=None,nu_0=None,sigma_lmbda_0=None):
+        self._sigma_distn = GaussianFixedMean(mu=mu,
+                nu_0=nu_0,lmbda_0=sigma_lmbda_0,sigma=sigma)
+        self._mu_distn = GaussianFixedCov(sigma=self._sigma_distn.sigma,
+                mu_0=mu_0, sigma_0=mu_lmbda_0,mu=mu)
+        self._sigma_distn.mu = self._mu_distn.mu
+
+    @property
+    def hypparams(self):
+        d = self._mu_distn.hypparams
+        d.update(**self._sigma_distn.hypparams)
+        return d
+
+    def _get_mu(self):
+        return self._mu_distn.mu
+
+    def _set_mu(self,val):
+        self._mu_distn.mu = val
+        self._sigma_distn.mu = val
+
+    mu = property(_get_mu,_set_mu)
+
+    def _get_sigma(self):
+        return self._sigma_distn.sigma
+
+    def _set_sigma(self,val):
+        self._sigma_distn.sigma = val
+        self._mu_distn.sigma = val
+
+    sigma = property(_get_sigma,_set_sigma)
+
+    ### Gibbs sampling
+
+    def resample(self,data=[],niter=1):
+        if getdatasize(data) == 0:
+            niter = 1
+
+        # TODO this is kinda dumb because it collects statistics over and over
+        # instead of updating them...
+        for itr in range(niter):
+            # resample mu
+            self._mu_distn.sigma = self._sigma_distn.sigma
+            self._mu_distn.resample(data)
+
+            # resample sigma
+            self._sigma_distn.mu = self._mu_distn.mu
+            self._sigma_distn.resample(data)
+
+        return self
+
+
+# TODO collapsed
+class DiagonalGaussian(_GaussianBase,GibbsSampling,MaxLikelihood,MeanField,Tempering):
+    '''
+    Product of normal-inverse-gamma priors over mu (mean vector) and sigmas
+    (vector of scalar variances).
+
+    The prior follows
+        sigmas     ~ InvGamma(alphas_0,betas_0) iid
+        mu | sigma ~ N(mu_0,1/nus_0 * diag(sigmas))
+
+    It allows placing different prior hyperparameters on different components.
+    '''
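+    # Illustrative usage (a sketch, not from upstream docs; `data` is a placeholder
+    # (N,3) array). Scalar hyperparameters are broadcast to length-D vectors:
+    #   d = DiagonalGaussian(mu_0=np.zeros(3), nus_0=0.1, alphas_0=1., betas_0=1.)
+    #   d.resample(data)   # independent conjugate updates for each coordinate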
+
+    def __init__(self,mu=None,sigmas=None,mu_0=None,nus_0=None,alphas_0=None,betas_0=None):
+        # all the s's refer to the fact that these are vectors of length
+        # len(mu_0) OR scalars
+        if mu_0 is not None:
+            D = mu_0.shape[0]
+            if nus_0 is not None and \
+                    (isinstance(nus_0,int) or isinstance(nus_0,float)):
+                nus_0 = nus_0*np.ones(D)
+            if alphas_0 is not None and \
+                    (isinstance(alphas_0,int) or isinstance(alphas_0,float)):
+                alphas_0 = alphas_0*np.ones(D)
+            if betas_0 is not None and \
+                    (isinstance(betas_0,int) or isinstance(betas_0,float)):
+                betas_0 = betas_0*np.ones(D)
+
+        self.mu_0 = self.mf_mu = mu_0
+        self.nus_0 = self.mf_nus = nus_0
+        self.alphas_0 = self.mf_alphas = alphas_0
+        self.betas_0 = self.mf_betas = betas_0
+
+        self.mu = mu
+        self.sigmas = sigmas
+
+        assert self.mu is None or (isinstance(self.mu,np.ndarray) and not isinstance(self.mu,np.ma.MaskedArray))
+        assert self.sigmas is None or (isinstance(self.sigmas,np.ndarray) and not isinstance(self.sigmas,np.ma.MaskedArray))
+
+        if mu is sigmas is None \
+                and not any(_ is None for _ in (mu_0,nus_0,alphas_0,betas_0)):
+            self.resample() # initialize from prior
+
+    ### the basics!
+
+    @property
+    def parameters(self):
+        return self.mu, self.sigmas
+
+    @parameters.setter
+    def parameters(self, mu_sigmas_tuple):
+        (mu,sigmas) = mu_sigmas_tuple
+        self.mu, self.sigmas = mu, sigmas
+
+    @property
+    def sigma(self):
+        return np.diag(self.sigmas)
+
+    @sigma.setter
+    def sigma(self,val):
+        val = np.array(val)
+        assert val.ndim in (1,2)
+        if val.ndim == 1:
+            self.sigmas = val
+        else:
+            self.sigmas = np.diag(val)
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,nus_0=self.nus_0,
+                alphas_0=self.alphas_0,betas_0=self.betas_0)
+
+    def rvs(self,size=None):
+        size = np.array(size if size is not None else 1,ndmin=1)
+        return np.sqrt(self.sigmas)*\
+            np.random.normal(size=np.concatenate((size,self.mu.shape))) + self.mu
+
+    def log_likelihood(self,x,temperature=1.):
+        mu, sigmas, D = self.mu, self.sigmas * temperature, self.mu.shape[0]
+        x = np.reshape(x,(-1,D))
+        Js = -1./(2*sigmas)
+        return (np.einsum('ij,ij,j->i',x,x,Js) - np.einsum('ij,j,j->i',x,2*mu,Js)) \
+            + (mu**2*Js - 1./2*np.log(2*np.pi*sigmas)).sum()
+
+    ### posterior updating stuff
+
+    @property
+    def natural_hypparam(self):
+        return self._standard_to_natural(self.alphas_0,self.betas_0,self.mu_0,self.nus_0)
+
+    @natural_hypparam.setter
+    def natural_hypparam(self,natparam):
+        self.alphas_0, self.betas_0, self.mu_0, self.nus_0 = \
+            self._natural_to_standard(natparam)
+
+    def _standard_to_natural(self,alphas,betas,mu,nus):
+        return np.array([2*betas + nus * mu**2, nus*mu, nus, 2*alphas])
+
+    def _natural_to_standard(self,natparam):
+        nus = natparam[2]
+        mu = natparam[1] / nus
+        alphas = natparam[3]/2.
+        betas = (natparam[0] - nus*mu**2) / 2.
+        return alphas, betas, mu, nus
+
+    def _get_statistics(self,data):
+        if isinstance(data,np.ndarray) and data.shape[0] > 0:
+            data = data[gi(data)]
+            ns = np.repeat(*data.shape)
+            return np.array([
+                np.einsum('ni,ni->i',data,data),
+                np.einsum('ni->i',data),
+                ns,
+                ns,
+                ])
+        else:
+            return sum((self._get_statistics(d) for d in data), self._empty_stats())
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            idx = ~np.isnan(data).any(1)
+            data = data[idx]
+            weights = weights[idx]
+            assert data.ndim == 2 and weights.ndim == 1 \
+                and data.shape[0] == weights.shape[0]
+            neff = np.repeat(weights.sum(),data.shape[1])
+            return np.array([weights.dot(data**2), weights.dot(data), neff, neff])
+        else:
+            return sum(
+                (self._get_weighted_statistics(d,w) for d, w in zip(data,weights)),
+                self._empty_stats())
+
+    def _empty_stats(self):
+        return np.zeros_like(self.natural_hypparam)
+
+    ### Gibbs sampling
+
+    def resample(self,data=[],temperature=1.,stats=None):
+        stats = self._get_statistics(data) if stats is None else stats
+
+        alphas_n, betas_n, mu_n, nus_n = self._natural_to_standard(
+            self.natural_hypparam + stats / temperature)
+
+        D = mu_n.shape[0]
+        self.sigmas = 1/np.random.gamma(alphas_n,scale=1/betas_n)
+        self.mu = np.sqrt(self.sigmas/nus_n)*np.random.randn(D) + mu_n
+
+        assert not np.isnan(self.mu).any()
+        assert not np.isnan(self.sigmas).any()
+
+        # NOTE: next line is to use Gibbs sampling to initialize mean field
+        self.mf_mu = self.mu
+
+        assert self.sigmas.ndim == 1
+        return self
+
+    def copy_sample(self):
+        new = copy.copy(self)
+        new.mu = self.mu.copy()
+        new.sigmas = self.sigmas.copy()
+        return new
+
+    ### max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        # NOTE: statistics are packed per-coordinate as [sum x^2, sum x, n, n]
+        if weights is None:
+            stats = self._get_statistics(data)
+        else:
+            stats = self._get_weighted_statistics(data,weights)
+        sumsq, sumx, n, _ = stats
+
+        self.mu = sumx/n
+        self.sigmas = sumsq/n - self.mu**2
+
+        return self
+
+    ### Mean Field
+
+    @property
+    def mf_natural_hypparam(self):
+        return self._standard_to_natural(self.mf_alphas,self.mf_betas,self.mf_mu,self.mf_nus)
+
+    @mf_natural_hypparam.setter
+    def mf_natural_hypparam(self,natparam):
+        self.mf_alphas, self.mf_betas, self.mf_mu, self.mf_nus = \
+            self._natural_to_standard(natparam)
+        # NOTE: this part is for plotting
+        self.mu = self.mf_mu
+        self.sigmas = np.where(self.mf_alphas > 1,self.mf_betas / (self.mf_alphas - 1),100000)
+
+    def meanfieldupdate(self,data,weights):
+        self.mf_natural_hypparam = \
+            self.natural_hypparam + self._get_weighted_statistics(data,weights)
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        self.mf_natural_hypparam = \
+            (1-stepsize) * self.mf_natural_hypparam + stepsize * (
+                self.natural_hypparam
+                + 1./prob * self._get_weighted_statistics(data,weights))
+
+    def get_vlb(self):
+        natparam_diff = self.natural_hypparam - self.mf_natural_hypparam
+        expected_stats = self._expected_statistics(
+            self.mf_alphas,self.mf_betas,self.mf_mu,self.mf_nus)
+        linear_term = sum(v1.dot(v2) for v1, v2 in zip(natparam_diff, expected_stats))
+
+        normalizer_term = \
+            self._log_Z(self.alphas_0,self.betas_0,self.mu_0,self.nus_0) \
+            - self._log_Z(self.mf_alphas,self.mf_betas,self.mf_mu,self.mf_nus)
+
+        return linear_term - normalizer_term - len(self.mf_mu)/2. * np.log(2*np.pi)
+
+    def expected_log_likelihood(self,x):
+        x = np.atleast_2d(x).reshape((-1,len(self.mf_mu)))
+        a,b,c,d = self._expected_statistics(
+            self.mf_alphas,self.mf_betas,self.mf_mu,self.mf_nus)
+        return (x**2).dot(a) + x.dot(b) + c.sum() + d.sum() \
+            - len(self.mf_mu)/2. * np.log(2*np.pi)
+
+    def _expected_statistics(self,alphas,betas,mu,nus):
+        return np.array([
+            -1./2 * alphas/betas,
+            mu * alphas/betas,
+            -1./2 * (1./nus + mu**2 * alphas/betas),
+            -1./2 * (np.log(betas) - special.digamma(alphas))])
+
+    def _log_Z(self,alphas,betas,mu,nus):
+        return (special.gammaln(alphas) - alphas*np.log(betas) - 1./2*np.log(nus)).sum()
+
+# TODO meanfield
+class DiagonalGaussianNonconjNIG(_GaussianBase,GibbsSampling):
+    '''
+    Product of normal priors over mu and product of gamma priors over sigmas.
+    Note that while the conjugate prior in DiagonalGaussian is of the form
+    p(mu,sigmas), this prior is of the form p(mu)p(sigmas). Therefore its
+    resample() update has to perform inner iterations.
+
+    The prior follows
+        mu     ~ N(mu_0,diag(sigmas_0))
+        sigmas ~ InvGamma(alpha_0,beta_0) iid
+    '''
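+    # Each resample() call therefore runs `niter` blocked-Gibbs sweeps over the two
+    # conditionals (read off the resample() code below):
+    #   mu     | sigmas, data ~ Normal    (conjugate given the current sigmas)
+    #   sigmas | mu, data     ~ InvGamma  (conjugate given the current mu)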
+
+    def __init__(self,mu=None,sigmas=None,mu_0=None,sigmas_0=None,alpha_0=None,beta_0=None,
+            niter=20):
+        self.mu_0, self.sigmas_0 = mu_0, sigmas_0
+        self.alpha_0, self.beta_0 = alpha_0, beta_0
+
+        self.niter = niter
+
+        if mu is None or sigmas is None:
+            self.resample()
+        else:
+            self.mu, self.sigmas = mu, sigmas
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,sigmas_0=self.sigmas_0,alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    # TODO next three methods are copied from DiagonalGaussian, factor them out
+
+    @property
+    def sigma(self):
+        return np.diag(self.sigmas)
+
+    def rvs(self,size=None):
+        size = np.array(size if size is not None else 1,ndmin=1)
+        return np.sqrt(self.sigmas)*\
+            np.random.normal(size=np.concatenate((size,self.mu.shape))) + self.mu
+
+    def log_likelihood(self,x):
+        mu, sigmas, D = self.mu, self.sigmas, self.mu.shape[0]
+        x = np.reshape(x,(-1,D))
+        Js = -1./(2*sigmas)
+        return (np.einsum('ij,ij,j->i',x,x,Js) - np.einsum('ij,j,j->i',x,2*mu,Js)) \
+            + (mu**2*Js - 1./2*np.log(2*np.pi*sigmas)).sum()
+
+
+    def resample(self,data=[]):
+        n, y, ysq = self._get_statistics(data)
+        if n == 0:
+            self.mu = np.sqrt(self.sigmas_0) * np.random.randn(self.mu_0.shape[0]) + self.mu_0
+            self.sigmas = 1./np.random.gamma(self.alpha_0,scale=1./self.beta_0)
+        else:
+            for itr in range(self.niter):
+                sigmas_n = 1./(1./self.sigmas_0 + n / self.sigmas)
+                mu_n = (self.mu_0 / self.sigmas_0 + y / self.sigmas) * sigmas_n
+                self.mu = np.sqrt(sigmas_n) * np.random.randn(mu_n.shape[0]) + mu_n
+
+                alphas_n = self.alpha_0 + 1./2*n
+                betas_n = self.beta_0 + 1./2*(ysq + n*self.mu**2 - 2*self.mu*y)
+                self.sigmas = 1./np.random.gamma(alphas_n,scale=1./betas_n)
+        return self
+
+    def _get_statistics(self,data):
+        # TODO don't forget to handle NaNs
+        assert isinstance(data,(list,np.ndarray)) and not isinstance(data,np.ma.MaskedArray)
+        if isinstance(data,np.ndarray):
+            data = data[gi(data)]
+            n = data.shape[0]
+            y = np.einsum('ni->i',data)
+            ysq = np.einsum('ni,ni->i',data,data)
+            return np.array([n,y,ysq],dtype=object)
+        else:
+            return sum((self._get_statistics(d) for d in data),self._empty_stats)
+
+    @property
+    def _empty_stats(self):
+        return np.array([0.,np.zeros_like(self.mu_0),np.zeros_like(self.mu_0)],
+                dtype=object)
+
+# TODO collapsed, meanfield, max_likelihood
+class IsotropicGaussian(GibbsSampling):
+    '''
+    Normal-Inverse-Gamma prior over mu (mean vector) and sigma (scalar
+    variance). Essentially, all coordinates of all observations inform the
+    variance.
+
+    The prior follows
+        sigma      ~ InvGamma(alpha_0,beta_0)
+        mu | sigma ~ N(mu_0,sigma/nu_0 * I)
+    '''
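+    # Conjugate update sketch (restating the standard result that
+    # _posterior_hypparams below implements), with n observations in D dimensions:
+    #   nu_n    = nu_0 + n D
+    #   mu_n    = (nu_0 mu_0 + n xbar) / (nu_0 + n)
+    #   alpha_n = alpha_0 + n D / 2
+    #   beta_n  = beta_0 + sumsq/2 + (n D nu_0)/(n D + nu_0) * ||xbar - mu_0||^2 / 2
+    # where sumsq is the total squared deviation of the data about xbar.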
+
+    def __init__(self,mu=None,sigma=None,mu_0=None,nu_0=None,alpha_0=None,beta_0=None):
+        self.mu = mu
+        self.sigma = sigma
+
+        self.mu_0 = mu_0
+        self.nu_0 = nu_0
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+
+        if mu is sigma is None and not any(_ is None for _ in (mu_0,nu_0,alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,nu_0=self.nu_0,alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    def rvs(self,size=None):
+        size = () if size is None else (tuple(size) if hasattr(size,'__len__') else (size,))
+        return np.sqrt(self.sigma)*np.random.normal(size=size+self.mu.shape) + self.mu
+
+    def log_likelihood(self,x):
+        mu, sigma, D = self.mu, self.sigma, self.mu.shape[0]
+        x = np.reshape(x,(-1,D))
+        return (-0.5*((x-mu)**2).sum(1)/sigma - D*np.log(np.sqrt(2*np.pi*sigma)))
+
+    def _posterior_hypparams(self,n,xbar,sumsq):
+        mu_0, nu_0, alpha_0, beta_0 = self.mu_0, self.nu_0, self.alpha_0, self.beta_0
+        D = mu_0.shape[0]
+        if n > 0:
+            nu_n = D*n + nu_0
+            alpha_n = alpha_0 + D*n/2
+            beta_n = beta_0 + 1/2*sumsq + (n*D*nu_0)/(n*D+nu_0) * 1/2 * ((xbar - mu_0)**2).sum()
+            mu_n = (n*xbar + nu_0*mu_0)/(n+nu_0)
+
+            return mu_n, nu_n, alpha_n, beta_n
+        else:
+            return mu_0, nu_0, alpha_0, beta_0
+
+    ### Gibbs sampling
+
+    def resample(self,data=[]):
+        mu_n, nu_n, alpha_n, beta_n = self._posterior_hypparams(
+            *self._get_statistics(data, D=self.mu_0.shape[0]))
+        D = mu_n.shape[0]
+        self.sigma = 1/np.random.gamma(alpha_n,scale=1/beta_n)
+        self.mu = np.sqrt(self.sigma/nu_n)*np.random.randn(D)+mu_n
+        return self
+
+    def _get_statistics(self,data, D=None):
+        n = getdatasize(data)
+        if n > 0:
+            D = D if D else getdatadimension(data)
+            if isinstance(data,np.ndarray):
+                assert (data.ndim == 1 and data.shape == (D,)) \
+                    or (data.ndim == 2 and data.shape[1] == D)
+                data = np.reshape(data,(-1,D))
+                xbar = data.mean(0)
+                sumsq = ((data-xbar)**2).sum()
+            else:
+                xbar = sum(np.reshape(d,(-1,D)).sum(0) for d in data) / n
+                sumsq = sum(((np.reshape(d,(-1,D)) - xbar)**2).sum() for d in data)
+        else:
+            xbar, sumsq = None, None
+        return n, xbar, sumsq
+
+
+class _ScalarGaussianBase(object):
+    @property
+    def params(self):
+        return dict(mu=self.mu,sigmasq=self.sigmasq)
+
+    def rvs(self,size=None):
+        return np.sqrt(self.sigmasq)*np.random.normal(size=size)+self.mu
+
+    def log_likelihood(self,x):
+        x = np.reshape(x,(-1,1))
+        return (-0.5*(x-self.mu)**2/self.sigmasq - np.log(np.sqrt(2*np.pi*self.sigmasq))).ravel()
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(mu=%f,sigmasq=%f)' % (self.mu,self.sigmasq)
+
+    def plot(self,data=None,indices=None,color='b',plot_params=True,label=None):
+        import matplotlib.pyplot as plt
+        data = np.concatenate(data) if data is not None else None
+        indices = np.concatenate(indices) if indices is not None else None
+
+        if data is not None:
+            assert indices is not None
+            plt.plot(indices,data,color=color,marker='x',linestyle='')
+
+        if plot_params:
+            assert indices is not None
+            if len(indices) > 1:
+                from util.general import rle
+                vals, lens = rle(np.diff(indices))
+                starts = np.concatenate(((0,),lens.cumsum()[:-1]))
+                for start, blocklen in zip(starts[vals == 1], lens[vals == 1]):
+                    plt.plot(indices[start:start+blocklen],
+                            np.repeat(self.mu,blocklen),color=color,linestyle='--')
+            else:
+                plt.plot(indices,[self.mu],color=color,marker='+')
+
+    ### mostly shared statistics gathering
+
+    def _get_statistics(self,data):
+        n = getdatasize(data)
+        if n > 0:
+            if isinstance(data,np.ndarray):
+                ybar = data.mean()
+                centered = data.ravel() - ybar
+                sumsqc = centered.dot(centered)
+            elif isinstance(data,list):
+                ybar = sum(d.sum() for d in data)/n
+                sumsqc = sum((d.ravel()-ybar).dot(d.ravel()-ybar) for d in data)
+            else:
+                ybar = data
+                sumsqc = 0
+        else:
+            ybar = None
+            sumsqc = None
+
+        return n, ybar, sumsqc
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            neff = weights.sum()
+            if neff > weps:
+                ybar = weights.dot(data.ravel()) / neff
+                centered = data.ravel() - ybar
+                sumsqc = centered.dot(weights*centered)
+            else:
+                ybar = None
+                sumsqc = None
+        elif isinstance(data,list):
+            neff = sum(w.sum() for w in weights)
+            if neff > weps:
+                ybar = sum(w.dot(d.ravel()) for d,w in zip(data,weights)) / neff
+                sumsqc = sum((d.ravel()-ybar).dot(w*(d.ravel()-ybar))
+                        for d,w in zip(data,weights))
+            else:
+                ybar = None
+                sumsqc = None
+        else:
+            ybar = data
+            sumsqc = 0
+
+        return neff, ybar, sumsqc
+
+    ### max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            n, ybar, sumsqc = self._get_statistics(data)
+        else:
+            n, ybar, sumsqc = self._get_weighted_statistics(data,weights)
+
+        if sumsqc > 0:
+            self.mu = ybar
+            self.sigmasq = sumsqc/n
+        else:
+            self.broken = True
+            self.mu = 999999999.
+            self.sigmasq = 1.
+
+        return self
+
+class ScalarGaussianNIX(_ScalarGaussianBase, GibbsSampling, Collapsed):
+    '''
+    Conjugate Normal-(Scaled-)Inverse-ChiSquared prior. (Another parameterization is the
+    Normal-Inverse-Gamma.)
+    '''
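+    # Illustrative usage (a sketch, not from upstream docs; `y` is a placeholder 1-D
+    # array of observations):
+    #   d = ScalarGaussianNIX(mu_0=0., kappa_0=0.1, sigmasq_0=1., nu_0=1.)
+    #   d.resample(y)                       # Gibbs step conditioned on y
+    #   lml = d.log_marginal_likelihood(y)  # collapsed evidence under the prior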
+    def __init__(self,mu=None,sigmasq=None,mu_0=None,kappa_0=None,sigmasq_0=None,nu_0=None):
+        self.mu = mu
+        self.sigmasq = sigmasq
+
+        self.mu_0 = mu_0
+        self.kappa_0 = kappa_0
+        self.sigmasq_0 = sigmasq_0
+        self.nu_0 = nu_0
+
+        if mu is sigmasq is None \
+                and not any(_ is None for _ in (mu_0,kappa_0,sigmasq_0,nu_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,kappa_0=self.kappa_0,
+                sigmasq_0=self.sigmasq_0,nu_0=self.nu_0)
+
+    def _posterior_hypparams(self,n,ybar,sumsqc):
+        mu_0, kappa_0, sigmasq_0, nu_0 = self.mu_0, self.kappa_0, self.sigmasq_0, self.nu_0
+        if n > 0:
+            kappa_n = kappa_0 + n
+            mu_n = (kappa_0 * mu_0 + n * ybar) / kappa_n
+            nu_n = nu_0 + n
+            sigmasq_n = 1/nu_n * (nu_0 * sigmasq_0 + sumsqc + kappa_0 * n / (kappa_0 + n) * (ybar - mu_0)**2)
+
+            return mu_n, kappa_n, sigmasq_n, nu_n
+        else:
+            return mu_0, kappa_0, sigmasq_0, nu_0
+
+    ### Gibbs sampling
+
+    def resample(self,data=[]):
+        mu_n, kappa_n, sigmasq_n, nu_n = self._posterior_hypparams(*self._get_statistics(data))
+        self.sigmasq = nu_n * sigmasq_n / np.random.chisquare(nu_n)
+        self.mu = np.sqrt(self.sigmasq / kappa_n) * np.random.randn() + mu_n
+        return self
+
+    ### Collapsed
+
+    def log_marginal_likelihood(self,data):
+        n = getdatasize(data)
+        kappa_0, sigmasq_0, nu_0 = self.kappa_0, self.sigmasq_0, self.nu_0
+        mu_n, kappa_n, sigmasq_n, nu_n = self._posterior_hypparams(*self._get_statistics(data))
+        return special.gammaln(nu_n/2) - special.gammaln(nu_0/2) \
+            + 0.5*(np.log(kappa_0) - np.log(kappa_n)
+                   + nu_0 * (np.log(nu_0) + np.log(sigmasq_0))
+                   - nu_n * (np.log(nu_n) + np.log(sigmasq_n))
+                   - n*np.log(np.pi))
+
+    def log_predictive_single(self,y,olddata):
+        # mostly for testing or speed
+        mu_n, kappa_n, sigmasq_n, nu_n = self._posterior_hypparams(*self._get_statistics(olddata))
+        return stats.t.logpdf(y,nu_n,loc=mu_n,scale=np.sqrt((1+kappa_n)*sigmasq_n/kappa_n))
+
+
+class ScalarGaussianNonconjNIX(_ScalarGaussianBase, GibbsSampling):
+    '''
+    Non-conjugate separate priors on mean and variance parameters, via
+    mu ~ Normal(mu_0,tausq_0)
+    sigmasq ~ (Scaled-)Inverse-ChiSquared(sigmasq_0,nu_0)
+    '''
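+    # resample() below runs `niter` inner sweeps, alternating the two conditional
+    # draws (read off the code):
+    #   mu      | sigmasq, data ~ Normal(mu_n, tausq_n)
+    #   sigmasq | mu, data      ~ nu_n * sigmasq_n / chi^2(nu_n)   (scaled inv-chi^2)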
+    def __init__(self,mu=None,sigmasq=None,mu_0=None,tausq_0=None,sigmasq_0=None,nu_0=None,
+            niter=1):
+        self.mu, self.sigmasq = mu, sigmasq
+        self.mu_0, self.tausq_0 = mu_0, tausq_0
+        self.sigmasq_0, self.nu_0 = sigmasq_0, nu_0
+
+        self.niter = niter
+
+        if mu is sigmasq is None \
+                and not any(_ is None for _ in (mu_0, tausq_0, sigmasq_0, nu_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,tausq_0=self.tausq_0,
+                sigmasq_0=self.sigmasq_0,nu_0=self.nu_0)
+
+    def resample(self,data=[],niter=None):
+        n = getdatasize(data)
+        niter = self.niter if niter is None else niter
+        if n > 0:
+            data = flattendata(data)
+            datasum = data[gi(data)].sum()
+            datasqsum = (data[gi(data)]**2).sum()
+            nu_n = self.nu_0 + n
+            for itr in range(niter):
+                # resample mean
+                tausq_n = 1/(1/self.tausq_0 + n/self.sigmasq)
+                mu_n = tausq_n*(self.mu_0/self.tausq_0 + datasum/self.sigmasq)
+                self.mu = np.sqrt(tausq_n)*np.random.normal() + mu_n
+                # resample variance
+                sigmasq_n = (self.nu_0*self.sigmasq_0 + (datasqsum + n*self.mu**2-2*datasum*self.mu))/nu_n
+                self.sigmasq = sigmasq_n*nu_n/np.random.chisquare(nu_n)
+        else:
+            self.mu = np.sqrt(self.tausq_0) * np.random.normal() + self.mu_0
+            self.sigmasq = self.sigmasq_0*self.nu_0/np.random.chisquare(self.nu_0)
+
+        return self
+
+class ScalarGaussianNonconjNIG(_ScalarGaussianBase, MeanField, MeanFieldSVI):
+    # NOTE: this is like ScalarGaussianNonconjNIX above, except the prior is
+    # parameterized in natural coordinates
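+    # Concretely (matching the code below): J is the prior precision on mu and h the
+    # precision-weighted mean, so mu ~ N(h_0/J_0, 1/J_0), while
+    # sigmasq ~ InvGamma(alpha_0, beta_0).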
+
+    def __init__(self,h_0,J_0,alpha_0,beta_0,
+            mu=None,sigmasq=None,
+            h_mf=None,J_mf=None,alpha_mf=None,beta_mf=None,niter=1):
+        self.h_0, self.J_0 = h_0, J_0
+        self.alpha_0, self.beta_0 = alpha_0, beta_0
+
+        self.h_mf = h_mf if h_mf is not None else J_0 * np.random.normal(h_0/J_0,1./np.sqrt(J_0))
+        self.J_mf = J_mf if J_mf is not None else J_0
+        self.alpha_mf = alpha_mf if alpha_mf is not None else alpha_0
+        self.beta_mf = beta_mf if beta_mf is not None else beta_0
+
+        self.niter = niter
+
+        self.mu = mu if mu is not None else np.random.normal(h_0/J_0,1./np.sqrt(J_0))
+        self.sigmasq = sigmasq if sigmasq is not None else 1./np.random.gamma(alpha_0,1./beta_0)
+
+    @property
+    def hypparams(self):
+        return dict(h_0=self.h_0,J_0=self.J_0,alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    @property
+    def _E_mu(self):
+        # E[mu], E[mu**2]
+        return self.h_mf / self.J_mf, 1./self.J_mf + (self.h_mf / self.J_mf)**2
+
+    @property
+    def _E_sigmasq(self):
+        # E[1/sigmasq], E[ln sigmasq]
+        return self.alpha_mf / self.beta_mf, \
+            np.log(self.beta_mf) - special.digamma(self.alpha_mf)
+
+    @property
+    def natural_hypparam(self):
+        return np.array([self.alpha_0,self.beta_0,self.h_0,self.J_0])
+
+    @natural_hypparam.setter
+    def natural_hypparam(self,natural_hypparam):
+        self.alpha_0, self.beta_0, self.h_0, self.J_0 = natural_hypparam
+
+    @property
+    def mf_natural_hypparam(self):
+        return np.array([self.alpha_mf,self.beta_mf,self.h_mf,self.J_mf])
+
+    @mf_natural_hypparam.setter
+    def mf_natural_hypparam(self,mf_natural_hypparam):
+        self.alpha_mf, self.beta_mf, self.h_mf, self.J_mf = mf_natural_hypparam
+        # set point estimates of (mu, sigmasq) for plotting and stuff
+        self.mu, self.sigmasq = self.h_mf / self.J_mf, self.beta_mf / (self.alpha_mf-1)
+
+    def _resample_from_mf(self):
+        self.mu, self.sigmasq = np.random.normal(self.h_mf/self.J_mf,np.sqrt(1./self.J_mf)), \
+            np.random.gamma(self.alpha_mf,1./self.beta_mf)
+        return self
+
+    def expected_log_likelihood(self,x):
+        (Emu, Emu2), (Esigmasqinv, Elnsigmasq) = self._E_mu, self._E_sigmasq
+        return -1./2 * Esigmasqinv * (x**2 + Emu2 - 2*x*Emu) \
+            - 1./2*Elnsigmasq - 1./2*np.log(2*np.pi)
+
+    def get_vlb(self):
+        # E[ln p(mu) / q(mu)] part
+        h_0, J_0, J_mf = self.h_0, self.J_0, self.J_mf
+        Emu, Emu2 = self._E_mu
+        p_mu_avgengy = -1./2*J_0*Emu2 + h_0*Emu \
+            - 1./2*(h_0**2/J_0) + 1./2*np.log(J_0) - 1./2*np.log(2*np.pi)
+        q_mu_entropy = 1./2*np.log(2*np.pi*np.e/J_mf)
+
+        # E[ln p(sigmasq) / q(sigmasq)] part
+        alpha_0, beta_0, alpha_mf, beta_mf = \
+            self.alpha_0, self.beta_0, self.alpha_mf, self.beta_mf
+        (Esigmasqinv, Elnsigmasq) = self._E_sigmasq
+        p_sigmasq_avgengy = (-alpha_0-1)*Elnsigmasq + (-beta_0)*Esigmasqinv \
+            - (special.gammaln(alpha_0) - alpha_0*np.log(beta_0))
+        q_sigmasq_entropy = alpha_mf + np.log(beta_mf) + special.gammaln(alpha_mf) \
+            - (1+alpha_mf)*special.digamma(alpha_mf)
+
+        return p_mu_avgengy + q_mu_entropy + p_sigmasq_avgengy + q_sigmasq_entropy
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        # like meanfieldupdate except we step the factors simultaneously
+
+        # NOTE: unlike the fully conjugate case, there are interaction terms, so
+        # we work on the destructured pieces
+        neff, y, ysq = self._get_weighted_statistics(data,weights)
+        Emu, _ = self._E_mu
+        Esigmasqinv, _ = self._E_sigmasq
+
+
+        # form new natural hyperparameters as if doing a batch update
+        alpha_new = self.alpha_0 + 1./prob * 1./2*neff
+        beta_new = self.beta_0 + 1./prob * 1./2*(ysq + neff*Emu**2 - 2*Emu*y)
+
+        h_new = self.h_0 + 1./prob * Esigmasqinv * y
+        J_new = self.J_0 + 1./prob * Esigmasqinv * neff
+
+
+        # take a step
+        self.alpha_mf = (1-stepsize)*self.alpha_mf + stepsize*alpha_new
+        self.beta_mf = (1-stepsize)*self.beta_mf + stepsize*beta_new
+
+        self.h_mf = (1-stepsize)*self.h_mf + stepsize*h_new
+        self.J_mf = (1-stepsize)*self.J_mf + stepsize*J_new
+
+        # calling this setter will set point estimates for (mu,sigmasq) for
+        # plotting and sampling and stuff
+        self.mf_natural_hypparam = (self.alpha_mf, self.beta_mf, self.h_mf, self.J_mf)
+
+        return self
+
+    def meanfieldupdate(self,data,weights,niter=None):
+        niter = niter if niter is not None else self.niter
+        neff, y, ysq = self._get_weighted_statistics(data,weights)
+        for itr in range(niter):
+            # update q(sigmasq)
+            Emu, _ = self._E_mu
+
+            self.alpha_mf = self.alpha_0 + 1./2*neff
+            self.beta_mf = self.beta_0 + 1./2*(ysq + neff*Emu**2 - 2*Emu*y)
+
+            # update q(mu)
+            Esigmasqinv, _ = self._E_sigmasq
+
+            self.h_mf = self.h_0 + Esigmasqinv * y
+            self.J_mf = self.J_0 + Esigmasqinv * neff
+
+        # calling this setter will set point estimates for (mu,sigmasq) for
+        # plotting and sampling and stuff
+        self.mf_natural_hypparam = \
+            (self.alpha_mf, self.beta_mf, self.h_mf, self.J_mf)
+
+        return self
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            neff = weights.sum()
+            y = weights.dot(data)
+            ysq = weights.dot(data**2)
+        else:
+            return sum(
+                self._get_weighted_statistics(d,w) for d,w in zip(data,weights))
+        return np.array([neff,y,ysq])
+
+
+class ScalarGaussianFixedvar(_ScalarGaussianBase, GibbsSampling):
+    '''
+    Conjugate normal prior on mean.
+    '''
+    def __init__(self,mu=None,sigmasq=None,mu_0=None,tausq_0=None):
+        self.mu = mu
+
+        self.sigmasq = sigmasq
+
+        self.mu_0 = mu_0
+        self.tausq_0 = tausq_0
+
+        if mu is None and not any(_ is None for _ in (mu_0,tausq_0)):
+            self.resample()  # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(mu_0=self.mu_0,tausq_0=self.tausq_0)
+
+    def _posterior_hypparams(self,n,xbar):
+        mu_0, tausq_0 = self.mu_0, self.tausq_0
+        sigmasq = self.sigmasq
+        if n > 0:
+            tausq_n = 1/(1/tausq_0 + n/sigmasq)
+            mu_n = (mu_0/tausq_0 + n*xbar/sigmasq)*tausq_n
+
+            return mu_n, tausq_n
+        else:
+            return mu_0, tausq_0
+
+    def resample(self,data=[]):
+        mu_n, tausq_n = self._posterior_hypparams(*self._get_statistics(data))
+        self.mu = np.sqrt(tausq_n)*np.random.randn()+mu_n
+        return self
+
+    def _get_statistics(self,data):
+        n = getdatasize(data)
+        if n > 0:
+            if isinstance(data,np.ndarray):
+                xbar = data.mean()
+            else:
+                xbar = sum(d.sum() for d in data)/n
+        else:
+            xbar = None
+        return n, xbar
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            neff = weights.sum()
+        else:
+            neff = sum(w.sum() for w in weights)
+
+        if neff > weps:
+            if isinstance(data,np.ndarray):
+                xbar = data.dot(weights) / neff
+            else:
+                xbar = sum(w.dot(d) for d,w in zip(data,weights)) / neff
+        else:
+            xbar = None
+
+        return neff, xbar
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            _, xbar = self._get_statistics(data)
+        else:
+            _, xbar = self._get_weighted_statistics(data,weights)
+
+        self.mu = xbar
+        return self
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/geometric.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..88413003c887396e869bfba8240b5b325c6897aa
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/geometric.py
@@ -0,0 +1,147 @@
+from __future__ import division
+from builtins import zip
+__all__ = ['Geometric']
+
+import numpy as np
+import scipy.stats as stats
+import scipy.special as special
+from warnings import warn
+
+from pybasicbayes.abstractions import GibbsSampling, MeanField, \
+    Collapsed, MaxLikelihood
+
+
+class Geometric(GibbsSampling, MeanField, Collapsed, MaxLikelihood):
+    '''
+    Geometric distribution with a conjugate beta prior.
+    The support is {1,2,3,...}.
+
+    Hyperparameters:
+        alpha_0, beta_0
+
+    Parameter is the success probability:
+        p
+    '''
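+    # Conjugate update sketch (this is what _posterior_hypparams below computes):
+    # given observations x_1..x_n from {1,2,...},
+    #   p | x ~ Beta(alpha_0 + n, beta_0 + sum_i (x_i - 1)),
+    # i.e. one success per observation and x_i - 1 failures preceding it.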
+    def __init__(self,alpha_0=None,beta_0=None,p=None):
+        self.p = p
+
+        self.alpha_0 = self.mf_alpha_0 = alpha_0
+        self.beta_0 = self.mf_beta_0 = beta_0
+
+        if p is None and not any(_ is None for _ in (alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def params(self):
+        return dict(p=self.p)
+
+    @property
+    def hypparams(self):
+        return dict(alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    def _posterior_hypparams(self,n,tot):
+        return self.alpha_0 + n, self.beta_0 + tot
+
+    def log_likelihood(self,x):
+        x = np.array(x,ndmin=1)
+        raw = np.empty(x.shape)
+        raw[x>0] = (x[x>0]-1.)*np.log(1.-self.p) + np.log(self.p)
+        raw[x<1] = -np.inf
+        return raw if isinstance(x,np.ndarray) else raw[0]
+
+    def log_sf(self,x):
+        return stats.geom.logsf(x,self.p)
+
+    def pmf(self,x):
+        return stats.geom.pmf(x,self.p)
+
+    def rvs(self,size=None):
+        return np.random.geometric(self.p,size=size)
+
+    def _get_statistics(self,data):
+        if isinstance(data,np.ndarray):
+            n = data.shape[0]
+            tot = data.sum() - n
+        elif isinstance(data,list):
+            n = sum(d.shape[0] for d in data)
+            tot = sum(d.sum() for d in data) - n
+        else:
+            assert np.isscalar(data)
+            n = 1
+            tot = data-1
+        return n, tot
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            n = weights.sum()
+            tot = weights.dot(data) - n
+        elif isinstance(data,list):
+            n = sum(w.sum() for w in weights)
+            tot = sum(w.dot(d) for w,d in zip(weights,data)) - n
+        else:
+            assert np.isscalar(data) and np.isscalar(weights)
+            n = weights
+            tot = weights*data - weights
+
+        return n, tot
+
+    ### Gibbs sampling
+
+    def resample(self,data=[]):
+        self.p = np.random.beta(*self._posterior_hypparams(*self._get_statistics(data)))
+
+        # initialize mean field
+        self.alpha_mf = self.p*(self.alpha_0+self.beta_0)
+        self.beta_mf = (1-self.p)*(self.alpha_0+self.beta_0)
+
+        return self
+
+    ### mean field
+
+    def meanfieldupdate(self,data,weights,stats=None):
+        warn('untested')
+        n, tot = self._get_weighted_statistics(data,weights) if stats is None else stats
+        self.alpha_mf = self.alpha_0 + n
+        self.beta_mf = self.beta_0 + tot
+
+        # initialize Gibbs
+        self.p = self.alpha_mf / (self.alpha_mf + self.beta_mf)
+
+    def get_vlb(self):
+        warn('untested')
+        Elnp, Eln1mp = self._expected_statistics(self.alpha_mf,self.beta_mf)
+        return (self.alpha_0 - self.alpha_mf)*Elnp \
+                + (self.beta_0 - self.beta_mf)*Eln1mp \
+                - (self._log_partition_function(self.alpha_0,self.beta_0)
+                        - self._log_partition_function(self.alpha_mf,self.beta_mf))
+
+    def expected_log_likelihood(self,x):
+        warn('untested')
+        Elnp, Eln1mp = self._expected_statistics(self.alpha_mf,self.beta_mf)
+        return (x-1)*Eln1mp + Elnp
+
+    def _expected_statistics(self,alpha,beta):
+        warn('untested')
+        Elnp = special.digamma(alpha) - special.digamma(alpha+beta)
+        Eln1mp = special.digamma(beta) - special.digamma(alpha+beta)
+        return Elnp, Eln1mp
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            n, tot = self._get_statistics(data)
+        else:
+            n, tot = self._get_weighted_statistics(data,weights)
+
+        self.p = n/(n+tot)  # n successes out of n+tot total Bernoulli trials
+        return self
+
+    ### Collapsed
+
+    def log_marginal_likelihood(self,data):
+        return self._log_partition_function(*self._posterior_hypparams(*self._get_statistics(data))) \
+            - self._log_partition_function(self.alpha_0,self.beta_0)
+
+    def _log_partition_function(self,alpha,beta):
+        return special.betaln(alpha,beta)
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/meta.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2ee51943b75e91e0c5ac553592f8731802494ba
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/meta.py
@@ -0,0 +1,121 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['_FixedParamsMixin', 'ProductDistribution']
+
+import numpy as np
+
+from pybasicbayes.abstractions import Distribution, \
+    GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood
+from pybasicbayes.util.stats import atleast_2d
+
+
+class _FixedParamsMixin(Distribution):
+    @property
+    def num_parameters(self):
+        return 0
+
+    def resample(self, *args, **kwargs):
+        return self
+
+    def meanfieldupdate(self, *args, **kwargs):
+        return self
+
+    def get_vlb(self):
+        return 0.
+
+    def copy_sample(self):
+        return self
+
+
+class ProductDistribution(
+        GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood):
+    def __init__(self, distns, slices=None):
+        self._distns = distns
+        self._slices = slices if slices is not None else \
+            [slice(i, i+1) for i in range(len(distns))]
+
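+    # Usage sketch (hypothetical component distributions, not from upstream docs):
+    #   pd = ProductDistribution([Gaussian(...), Geometric(...)],
+    #                            slices=[slice(0,2), slice(2,3)])
+    # treats columns 0-1 of each data row as Gaussian and column 2 as Geometric;
+    # if `slices` is omitted, each component distribution gets a single column.
+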
+    @property
+    def params(self):
+        return {idx:distn.params for idx, distn in enumerate(self._distns)}
+
+    @property
+    def hypparams(self):
+        return {idx:distn.hypparams for idx, distn in enumerate(self._distns)}
+
+    @property
+    def num_parameters(self):
+        return sum(d.num_parameters for d in self._distns)
+
+    def rvs(self,size=[]):
+        return np.concatenate(
+            [atleast_2d(distn.rvs(size=size))
+             for distn in self._distns],axis=-1)
+
+    def log_likelihood(self,x):
+        return sum(
+            distn.log_likelihood(x[...,sl])
+            for distn,sl in zip(self._distns,self._slices))
+
+    ### Gibbs
+
+    def resample(self,data=[]):
+        assert isinstance(data,(np.ndarray,list))
+        if isinstance(data,np.ndarray):
+            for distn,sl in zip(self._distns,self._slices):
+                distn.resample(data[...,sl])
+        else:
+            for distn,sl in zip(self._distns,self._slices):
+                distn.resample([d[...,sl] for d in data])
+        return self
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        assert isinstance(data,(np.ndarray,list))
+        if isinstance(data,np.ndarray):
+            for distn,sl in zip(self._distns,self._slices):
+                distn.max_likelihood(data[...,sl],weights=weights)
+        else:
+            for distn,sl in zip(self._distns,self._slices):
+                distn.max_likelihood([d[...,sl] for d in data],weights=weights)
+        return self
+
+    ### Mean field
+
+    def get_vlb(self):
+        return sum(distn.get_vlb() for distn in self._distns)
+
+    def expected_log_likelihood(self,x):
+        return np.sum(
+            [distn.expected_log_likelihood(x[...,sl])
+             for distn,sl in zip(self._distns,self._slices)], axis=0).ravel()
+
+    def meanfieldupdate(self,data,weights,**kwargs):
+        assert isinstance(data,(np.ndarray,list))
+        if isinstance(data,np.ndarray):
+            for distn,sl in zip(self._distns,self._slices):
+                distn.meanfieldupdate(data[...,sl],weights)
+        else:
+            for distn,sl in zip(self._distns,self._slices):
+                distn.meanfieldupdate(
+                    [d[...,sl] for d in data],weights=weights)
+        return self
+
+    def _resample_from_mf(self):
+        for distn in self._distns:
+            distn._resample_from_mf()
+
+    ### SVI
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        assert isinstance(data,(np.ndarray,list))
+        if isinstance(data,np.ndarray):
+            for distn,sl in zip(self._distns,self._slices):
+                distn.meanfield_sgdstep(
+                    data[...,sl],weights,prob,stepsize)
+        else:
+            for distn,sl in zip(self._distns,self._slices):
+                distn.meanfield_sgdstep(
+                    [d[...,sl] for d in data],weights,prob,stepsize)
+        return self
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/multinomial.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/multinomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..477675be59a274e5ee3018fb7e899b49fe4d3577
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/multinomial.py
@@ -0,0 +1,649 @@
+from __future__ import division
+from builtins import zip
+from builtins import map
+from builtins import range
+import copy
+__all__ = ['Categorical', 'CategoricalAndConcentration', 'Multinomial',
+           'MultinomialAndConcentration', 'GammaCompoundDirichlet', 'CRP',
+           'Input_Categorical', 'Input_Categorical_Normal']
+
+from pybasicbayes.distributions import gaussian
+import numpy as np
+from warnings import warn
+import scipy.stats as stats
+import scipy.special as special
+
+from pybasicbayes.abstractions import \
+    GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood, MAP
+
+from pybasicbayes.util.stats import sample_discrete
+
+try:
+    from pybasicbayes.util.cstats import sample_crp_tablecounts
+except ImportError:
+    warn('using slow sample_crp_tablecounts')
+    from pybasicbayes.util.stats import sample_crp_tablecounts
+
+
+class Categorical(GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood, MAP):
+    '''
+    This class represents a categorical distribution over labels, where the
+    parameter is weights and the prior is a Dirichlet distribution.
+    For example, if K == 3, then five samples may look like
+        [0,1,0,2,1]
+    Each entry is the label of a sample, like the outcome of die rolls. In other
+    words, generated data or data passed to log_likelihood are indices, not
+    indicator variables!  (But when 'weighted data' is passed, like in mean
+    field or weighted max likelihood, the weights are over indicator
+    variables...)
+
+    This class can be used as a weak limit approximation for a DP, particularly by
+    calling __init__ with alpha_0 and K arguments, in which case the prior will be
+    a symmetric Dirichlet with K components and parameter alpha_0/K; K is then the
+    weak limit approximation parameter.
+
+    Hyperparameters:
+        alphav_0 (vector) OR alpha_0 (scalar) and K
+
+    Parameters:
+        weights, a vector encoding a finite pmf
+    '''
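+    # Illustrative usage sketch (added for exposition; the hyperparameter values are
+    # arbitrary assumptions):
+    #   >>> c = Categorical(alpha_0=2., K=3)          # symmetric Dir(2/3, 2/3, 2/3) prior
+    #   >>> data = np.array([0, 1, 0, 2, 1])          # labels, not indicator vectors
+    #   >>> c.resample(data)                          # Gibbs step given the label counts
+    #   >>> c.meanfieldupdate(None, np.eye(3)[data])  # mean field weights are indicators
+    #   >>> c.log_likelihood(np.array([0, 2]))        # log pmf at those labels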
+    def __init__(self,weights=None,alpha_0=None,K=None,alphav_0=None,alpha_mf=None):
+        self.K = K
+        self.alpha_0 = alpha_0
+        self.alphav_0 = alphav_0
+
+        self._alpha_mf = alpha_mf if alpha_mf is not None else self.alphav_0
+
+        self.weights = weights
+
+        if weights is None and self.alphav_0 is not None:
+            self.resample()  # initialize from prior
+
+    def _get_alpha_0(self):
+        return self._alpha_0
+
+    def _set_alpha_0(self,alpha_0):
+        self._alpha_0 = alpha_0
+        if not any(_ is None for _ in (self.K, self._alpha_0)):
+            self.alphav_0 = np.repeat(self._alpha_0/self.K,self.K)
+
+    alpha_0 = property(_get_alpha_0,_set_alpha_0)
+
+    def _get_alphav_0(self):
+        return self._alphav_0 if hasattr(self,'_alphav_0') else None
+
+    def _set_alphav_0(self,alphav_0):
+        if alphav_0 is not None:
+            self._alphav_0 = alphav_0
+            self.K = len(alphav_0)
+
+    alphav_0 = property(_get_alphav_0,_set_alphav_0)
+
+    @property
+    def params(self):
+        return dict(weights=self.weights)
+
+    @property
+    def hypparams(self):
+        return dict(alphav_0=self.alphav_0)
+
+    @property
+    def num_parameters(self):
+        return len(self.weights)
+
+    def rvs(self,size=None):
+        return sample_discrete(self.weights,size)
+
+    def log_likelihood(self,x):
+        out = np.zeros_like(x, dtype=np.double)
+        nanidx = np.isnan(x)
+        err = np.seterr(divide='ignore')
+        out[~nanidx] = np.log(self.weights)[list(x[~nanidx])]  # log(0) can happen, no warning
+        np.seterr(**err)
+        return out
+
+    ### Gibbs sampling
+
+    def resample(self,data=[],counts=None):
+        counts = self._get_statistics(data) if counts is None else counts
+        try:
+            self.weights = np.random.dirichlet(self.alphav_0 + counts)
+        except ZeroDivisionError as e:
+            # print("ZeroDivisionError {}".format(e))
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+        except ValueError as e:
+            # print("ValueError {}".format(e))
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+        if np.isnan(self.weights).any():
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+        np.clip(self.weights, np.spacing(1.), np.inf, out=self.weights)
+        # NOTE: next line is so we can use Gibbs sampling to initialize mean field
+        self._alpha_mf = self.weights * self.alphav_0.sum()
+        assert (self._alpha_mf >= 0.).all()
+        return self
+
+    def _get_statistics(self,data,K=None):
+        K = K if K else self.K
+        if isinstance(data,np.ndarray) or \
+                (isinstance(data,list) and len(data) > 0
+                 and not isinstance(data[0],(np.ndarray,list))):
+            counts = np.bincount(data,minlength=K)
+        else:
+            counts = sum(np.bincount(d,minlength=K) for d in data)
+        return counts
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(weights,np.ndarray):
+            assert weights.ndim in (1,2)
+            if data is None or weights.ndim == 2:
+                # when weights is 2D or data is None, the weights are expected
+                # indicators and data is just a placeholder; nominally data
+                # should be np.arange(K)[na,:].repeat(N,axis=0)
+                counts = np.atleast_2d(weights).sum(0)
+            else:
+                # when weights is 1D, data is indices and we do a weighted
+                # bincount
+                counts = np.bincount(data,weights,minlength=self.K)
+        else:
+            if len(weights) == 0:
+                counts = np.zeros(self.K,dtype=int)
+            else:
+                data = data if data else [None]*len(weights)
+                counts = sum(self._get_weighted_statistics(d,w)
+                             for d, w in zip(data,weights))
+        return counts
+
+    ### Mean Field
+
+    def meanfieldupdate(self,data,weights):
+        # update
+        self._alpha_mf = self.alphav_0 + self._get_weighted_statistics(data,weights)
+        self.weights = self._alpha_mf / self._alpha_mf.sum()  # for plotting
+        assert (self._alpha_mf > 0.).all()
+        return self
+
+    def get_vlb(self):
+        # return avg energy plus entropy, our contribution to the vlb
+        # see Eq. 10.66 in Bishop
+        logpitilde = self.expected_log_likelihood()  # default is on np.arange(self.K)
+        q_entropy = -1* (
+            (logpitilde*(self._alpha_mf-1)).sum()
+            + special.gammaln(self._alpha_mf.sum()) - special.gammaln(self._alpha_mf).sum())
+        p_avgengy = special.gammaln(self.alphav_0.sum()) - special.gammaln(self.alphav_0).sum() \
+            + ((self.alphav_0-1)*logpitilde).sum()
+
+        return p_avgengy + q_entropy
+
+    def expected_log_likelihood(self,x=None):
+        # usually called when np.all(x == np.arange(self.K))
+        x = x if x is not None else slice(None)
+        return special.digamma(self._alpha_mf[x]) - special.digamma(self._alpha_mf.sum())
+
+    ### Mean Field SGD
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        self._alpha_mf = \
+            (1-stepsize) * self._alpha_mf + stepsize * (
+                self.alphav_0
+                + 1./prob * self._get_weighted_statistics(data,weights))
+        self.weights = self._alpha_mf / self._alpha_mf.sum()  # for plotting
+        return self
+
+    def _resample_from_mf(self):
+        self.weights = np.random.dirichlet(self._alpha_mf)
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            counts = self._get_statistics(data)
+        else:
+            counts = self._get_weighted_statistics(data,weights)
+        self.weights = counts/counts.sum()
+        return self
+
+    def MAP(self,data,weights=None):
+        if weights is None:
+            counts = self._get_statistics(data)
+        else:
+            counts = self._get_weighted_statistics(data,weights)
+        counts += self.alphav_0
+        self.weights = counts/counts.sum()
+        return self
+
+
+class CategoricalAndConcentration(Categorical):
+    '''
+    Categorical with resampling of the symmetric Dirichlet concentration
+    parameter.
+
+        concentration ~ Gamma(a_0,b_0)
+
+    The Dirichlet prior over pi is then
+
+        pi ~ Dir(concentration/K)
+    '''
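+    # Illustrative usage sketch (added for exposition; a_0, b_0, and K values are
+    # arbitrary assumptions):
+    #   >>> d = CategoricalAndConcentration(a_0=1., b_0=1., K=5)
+    #   >>> d.resample(np.array([0, 0, 1, 4, 4, 4]))  # resample concentration, then weights
+    #   >>> d.params                                  # dict with alpha_0 and weights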
+    def __init__(self,a_0,b_0,K,alpha_0=None,weights=None):
+        self.alpha_0_obj = GammaCompoundDirichlet(a_0=a_0,b_0=b_0,K=K,concentration=alpha_0)
+        super(CategoricalAndConcentration,self).__init__(alpha_0=self.alpha_0,
+                K=K,weights=weights)
+
+    def _get_alpha_0(self):
+        return self.alpha_0_obj.concentration
+
+    def _set_alpha_0(self,alpha_0):
+        self.alpha_0_obj.concentration = alpha_0
+        self.alphav_0 = np.repeat(alpha_0/self.K,self.K)
+
+    alpha_0 = property(_get_alpha_0, _set_alpha_0)
+
+    @property
+    def params(self):
+        return dict(alpha_0=self.alpha_0,weights=self.weights)
+
+    @property
+    def hypparams(self):
+        return dict(a_0=self.a_0,b_0=self.b_0,K=self.K)
+
+    def resample(self,data=[]):
+        counts = self._get_statistics(data,self.K)
+        self.alpha_0_obj.resample(counts)
+        self.alpha_0 = self.alpha_0  # for the effect on alphav_0
+        return super(CategoricalAndConcentration,self).resample(data)
+
+    def resample_just_weights(self,data=[]):
+        return super(CategoricalAndConcentration,self).resample(data)
+
+    def meanfieldupdate(self,*args,**kwargs): # TODO
+        warn('MeanField not implemented for %s; concentration parameter will stay fixed' % type(self).__name__)
+        return super(CategoricalAndConcentration,self).meanfieldupdate(*args,**kwargs)
+
+    def max_likelihood(self,*args,**kwargs):
+        raise NotImplementedError
+
+
+class Multinomial(Categorical):
+    '''
+    Like Categorical but the data are counts, so _get_statistics is overridden
+    (though _get_weighted_statistics can stay the same!). log_likelihood also
+    changes since, just like for the binomial special case, we sum over all
+    possible orderings.
+
+    For example, if K == 3, then a sample with n=5 might be
+        array([2,2,1])
+
+    Equivalently, a Poisson process conditioned on the total number of points emitted.
+    '''
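+    # Illustrative usage sketch (added for exposition; the hyperparameters are arbitrary
+    # assumptions):
+    #   >>> m = Multinomial(alpha_0=2., K=3, N=5)
+    #   >>> counts = m.rvs(size=4)       # (4, 3) rows of counts, each summing to 5
+    #   >>> m.resample(counts)           # Gibbs step on the summed counts
+    #   >>> m.log_likelihood(counts)     # one log probability per row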
+    def __init__(self,weights=None,alpha_0=None,K=None,alphav_0=None,alpha_mf=None,
+                 N=1):
+        self.N = N
+        super(Multinomial, self).__init__(weights,alpha_0,K,alphav_0,alpha_mf)
+
+    def log_likelihood(self,x):
+        assert isinstance(x,np.ndarray) and x.ndim == 2 and x.shape[1] == self.K
+        return np.where(x,x*np.log(self.weights),0.).sum(1) \
+            + special.gammaln(x.sum(1)+1) - special.gammaln(x+1).sum(1)
+
+    def rvs(self,size=None,N=None):
+        N = N if N else self.N
+        return np.random.multinomial(N, self.weights, size=size)
+
+    def _get_statistics(self,data,K=None):
+        K = K if K else self.K
+        if isinstance(data,np.ndarray):
+            return np.atleast_2d(data).sum(0)
+        else:
+            if len(data) == 0:
+                return np.zeros(K,dtype=int)
+            return np.concatenate(data).sum(0)
+
+    def expected_log_likelihood(self,x=None):
+        if x is not None and (not x.ndim == 2 or not np.all(x == np.eye(x.shape[0]))):
+            raise NotImplementedError # TODO nontrivial expected log likelihood
+        return super(Multinomial,self).expected_log_likelihood()
+
+
+class MultinomialAndConcentration(CategoricalAndConcentration,Multinomial):
+    pass
+
+
+class CRP(GibbsSampling):
+    '''
+    concentration ~ Gamma(a_0,b_0) [b_0 is inverse scale, inverse of numpy scale arg]
+    rvs ~ CRP(concentration)
+
+    This class models CRPs. The parameter is the concentration parameter (proportional
+    to probability of starting a new table given some number of customers in the
+    restaurant), which has a Gamma prior.
+    '''
+
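+    # Illustrative usage sketch (added for exposition; a_0, b_0 values are arbitrary
+    # assumptions):
+    #   >>> crp = CRP(a_0=1., b_0=1.)      # concentration drawn from its Gamma prior
+    #   >>> tables = crp.rvs(20)           # table counts after seating 20 customers
+    #   >>> crp.log_likelihood(tables)     # log prob of that seating arrangement
+    #   >>> crp.resample([tables])         # resample concentration given the data
+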
+    def __init__(self,a_0,b_0,concentration=None):
+        self.a_0 = a_0
+        self.b_0 = b_0
+
+        if concentration is not None:
+            self.concentration = concentration
+        else:
+            self.resample(niter=1)
+
+    @property
+    def params(self):
+        return dict(concentration=self.concentration)
+
+    @property
+    def hypparams(self):
+        return dict(a_0=self.a_0,b_0=self.b_0)
+
+    def rvs(self,customer_counts):
+        # could replace this with one of the faster C versions I have lying
+        # around, but at least the Python version is clearer
+        assert isinstance(customer_counts,list) or isinstance(customer_counts,int)
+        if isinstance(customer_counts,int):
+            customer_counts = [customer_counts]
+
+        restaurants = []
+        for num in customer_counts:
+            # a CRP with num customers
+            tables = []
+            for c in range(num):
+                newidx = sample_discrete(np.array(tables + [self.concentration]))
+                if newidx == len(tables):
+                    tables += [1]
+                else:
+                    tables[newidx] += 1
+
+            restaurants.append(tables)
+
+        return restaurants if len(restaurants) > 1 else restaurants[0]
+
+    def log_likelihood(self,restaurants):
+        assert isinstance(restaurants,list) and len(restaurants) > 0
+        if not isinstance(restaurants[0],list): restaurants=[restaurants]
+
+        likes = []
+        for counts in restaurants:
+            counts = np.array([c for c in counts if c > 0])    # remove zero counts b/c of gammaln
+            K = len(counts) # number of tables
+            N = sum(counts) # number of customers
+            likes.append(K*np.log(self.concentration) + np.sum(special.gammaln(counts)) +
+                            special.gammaln(self.concentration) -
+                            special.gammaln(N+self.concentration))
+
+        return np.asarray(likes) if len(likes) > 1 else likes[0]
+
+    def resample(self,data=[],niter=50):
+        for itr in range(niter):
+            a_n, b_n = self._posterior_hypparams(*self._get_statistics(data))
+            try:
+                self.concentration = np.random.gamma(a_n,scale=1./b_n)
+            except Exception:
+                print("have to apply weird bug fix in /apps/conda/sbruijns/envs/hdp_pg_env/lib/python3.4/site-packages/pybasicbayes/distributions/multinomial")
+                self.concentration += 0.00001
+                a_n, b_n = self._posterior_hypparams(*self._get_statistics(data))
+                self.concentration = np.random.gamma(a_n,scale=1./b_n)
+
+    def _posterior_hypparams(self,sample_numbers,total_num_distinct):
+        # NOTE: this is a stochastic function: it samples auxiliary variables
+        if total_num_distinct > 0:
+            sample_numbers = np.array(sample_numbers)
+            sample_numbers = sample_numbers[sample_numbers > 0]
+
+            wvec = np.random.beta(self.concentration+1,sample_numbers)
+            svec = np.array(stats.bernoulli.rvs(sample_numbers/(sample_numbers+self.concentration)))
+            return self.a_0 + total_num_distinct-svec.sum(), (self.b_0 - np.log(wvec).sum())
+        else:
+            return self.a_0, self.b_0
+
+    def _get_statistics(self,data):
+        assert isinstance(data,list)
+        if len(data) == 0:
+            sample_numbers = 0
+            total_num_distinct = 0
+        else:
+            if isinstance(data[0],list):
+                sample_numbers = np.array(list(map(sum,data)))
+                total_num_distinct = sum(map(len,data))
+            else:
+                sample_numbers = np.array(sum(data))
+                total_num_distinct = len(data)
+
+        return sample_numbers, total_num_distinct
+
+
+class GammaCompoundDirichlet(CRP):
+    # TODO this class is a bit ugly
+    '''
+    Implements a Gamma(a_0,b_0) prior over the finite Dirichlet concentration
+    parameter. The concentration is scaled according to the weak-limit sequence.
+
+    For each set of counts i, the model is
+        concentration ~ Gamma(a_0,b_0)
+        pi_i ~ Dir(concentration/K)
+        data_i ~ Multinomial(pi_i)
+
+    K is a free parameter: with big enough K (relative to the size of the
+    sampled data) everything starts to act like a DP; K is just the size of the
+    mesh projection.
+    '''
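+    # Illustrative usage sketch (added for exposition; K, a_0, b_0 values are arbitrary
+    # assumptions):
+    #   >>> d = GammaCompoundDirichlet(K=10, a_0=1., b_0=1.)
+    #   >>> counts = d.rvs(sample_counts=[30, 50])   # (2, 10) rows of multinomial counts
+    #   >>> d.resample(counts)                       # resample the shared concentration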
+    def __init__(self,K,a_0,b_0,concentration=None):
+        self.K = K
+        super(GammaCompoundDirichlet,self).__init__(a_0=a_0,b_0=b_0,
+                concentration=concentration)
+
+    @property
+    def params(self):
+        return dict(concentration=self.concentration)
+
+    @property
+    def hypparams(self):
+        return dict(a_0=self.a_0,b_0=self.b_0,K=self.K)
+
+    def rvs(self, sample_counts=None, size=None):
+        if sample_counts is None:
+            sample_counts = size
+        if isinstance(sample_counts,int):
+            sample_counts = [sample_counts]
+        out = np.empty((len(sample_counts),self.K),dtype=int)
+        for idx,c in enumerate(sample_counts):
+            out[idx] = np.random.multinomial(c,
+                np.random.dirichlet(np.repeat(self.concentration/self.K,self.K)))
+        return out if out.shape[0] > 1 else out[0]
+
+    def resample(self,data=[],niter=50,weighted_cols=None):
+        if weighted_cols is not None:
+            self.weighted_cols = weighted_cols
+        else:
+            self.weighted_cols = np.ones(self.K)
+
+        # all this is to check if data is empty
+        if isinstance(data,np.ndarray):
+            size = data.sum()
+        elif isinstance(data,list):
+            size = sum(d.sum() for d in data)
+        else:
+            assert data == 0
+            size = 0
+
+        if size > 0:
+            return super(GammaCompoundDirichlet,self).resample(data,niter=niter)
+        else:
+            return super(GammaCompoundDirichlet,self).resample(data,niter=1)
+
+    def _get_statistics(self,data):
+        # NOTE: this is a stochastic function: it samples auxiliary variables
+        counts = np.array(data,ndmin=2,order='C')
+
+        # sample m's, which sample an inverse of the weak limit projection
+        if counts.sum() == 0:
+            return 0, 0
+        else:
+            m = sample_crp_tablecounts(self.concentration,counts,self.weighted_cols)
+            return counts.sum(1), m.sum()
+
+    def _get_statistics_python(self,data):
+        counts = np.array(data,ndmin=2)
+
+        # sample m's
+        if counts.sum() == 0:
+            return 0, 0
+        else:
+            m = 0
+            for (i,j), n in np.ndenumerate(counts):
+                m += (np.random.rand(n) < self.concentration*self.K*self.weighted_cols[j] \
+                        / (np.arange(n)+self.concentration*self.K*self.weighted_cols[j])).sum()
+            return counts.sum(1), m
+
+
+class Input_Categorical():
+    '''
+    This class enables an input-output HMM structure: the observation distribution
+    we sample from is dictated by the (integer-valued) input at each time step.
+    Parameters are weights and the prior is a Dirichlet distribution.
+    Inspired by Matt's other distributions.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        weights, an (n_inputs x n_outputs) matrix encoding one finite pmf per input
+    '''
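+    # Illustrative usage sketch (added for exposition; the sizes are arbitrary
+    # assumptions):
+    #   >>> d = Input_Categorical(n_inputs=2, n_outputs=3)
+    #   >>> data = d.rvs(np.array([0, 1, 0, 1, 1]))  # (5, 2) rows of (input, output)
+    #   >>> d.resample(data)                         # Gibbs step on the per-input pmfs
+    #   >>> d.log_likelihood(data)                   # one log probability per row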
+    def __init__(self, n_inputs, n_outputs, modulate=100):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.modulate = modulate  # high modulate (the first data column is taken modulo this value) means no effect; modulate = stimulus set size means no conditioning
+        #print("Input data is taken mod {}".format(self.modulate))
+        self.weights = np.zeros((n_inputs, n_outputs))
+
+        self.resample()  # initialize from prior
+
+    def rvs(self, input):
+        input = copy.copy(input)
+        input = input % self.modulate
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros_like(input)
+        for t, c in zip(types, counts):
+            temp = np.random.choice(self.n_outputs, c, p=self.weights[t])
+            output[input == t] = temp
+        return np.array((input, output)).T
+
+    def log_likelihood(self, x):
+        x = copy.copy(x)
+        x[:, 0] %= self.modulate
+        out = np.zeros_like(x, dtype=np.double)
+        err = np.seterr(divide='ignore')
+        out = np.log(self.weights)[tuple(x.T)]
+        np.seterr(**err)
+        return out
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        counts = self._get_statistics(data)
+
+        for i in range(self.n_inputs):
+            self.weights[i] = np.random.dirichlet(np.ones(self.n_outputs) + counts[i])
+
+    def _get_statistics(self, data):
+        # TODO: improve
+        data = copy.copy(data)
+        counts = np.zeros_like(self.weights, dtype=np.int32)
+        if isinstance(data, np.ndarray):
+            assert len(data.shape) == 2
+            data[:, 0] %= self.modulate
+            for d in data:
+                counts[tuple(d)] += 1
+        else:
+            for d in data:
+                d[:, 0] %= self.modulate
+                for subd in d:
+                    counts[tuple(subd)] += 1
+        return counts
+
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
+
+
+class Input_Categorical_Normal():
+    '''
+    This class enables an input-output HMM structure: the observation distribution
+    we sample from is dictated by the (integer-valued) input at each time step.
+    Here, one output comes from a Categorical and the other from a Normal.
+    Parameters are weights and the prior is a Dirichlet distribution.
+    Inspired by Matt's other distributions.
+
+    Hyperparameters:
+
+        TODO
+
+    Parameters:
+        weights, an (n_inputs x n_outputs) matrix of pmfs, plus one Gaussian per input
+    '''
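+    # Illustrative usage sketch (added for exposition; the Gaussian hyperparameters are
+    # assumptions about this package's Gaussian class, and the sizes are arbitrary):
+    #   >>> hyps = dict(mu_0=np.zeros(1), sigma_0=np.eye(1), kappa_0=1., nu_0=3.)
+    #   >>> d = Input_Categorical_Normal(n_inputs=2, n_outputs=3, normal_hypparams=hyps)
+    #   >>> data = d.rvs(np.array([0, 1, 0, 1]))  # rows of (input, category, continuous value)
+    #   >>> d.resample(data)                      # update the pmfs and the per-input Gaussians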
+    def __init__(self, n_inputs, n_outputs, normal_hypparams):
+
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs  # could be made different for different inputs
+        self.cats = Input_Categorical(n_inputs, n_outputs)
+        self.normals = [gaussian.Gaussian(**normal_hypparams) for state in range(n_inputs)]
+
+    def rvs(self, input):
+        types, counts = np.unique(input, return_counts=True)
+        output = np.zeros((len(input), 2))
+        output[:, 0] = self.cats.rvs(input)[:, 1]
+
+        for t, c in zip(types, counts):
+            temp = self.normals[t].rvs(c)
+            output[input == t, 1] = temp[:, 0]
+
+        return np.concatenate((np.array(input)[None].T, output), axis=1)
+
+    def log_likelihood(self, x):
+        out = self.cats.log_likelihood(x[:, [0, 1]].astype(int))
+
+        # Filter out invalid reaction times (negative or 0 RTs are encoded as -inf here)
+        bad_nums = x[:, 2] == -np.inf
+        x[bad_nums, 2] = 0
+
+        means = np.zeros(x.shape[0])
+        stds = np.zeros(x.shape[0])
+        normal_ll = np.zeros(x.shape[0])
+        for i, n in enumerate(self.normals):
+            mask = x[:, 0] == i
+            # don't consider negative RTs in ll calculation
+            means[mask], stds[mask] = n.mu, np.sqrt(n.sigma[0, 0])  # sigma is the (1,1) covariance; take sqrt for the std
+
+        # change of variables: log N(x | mu, std) = log N((x - mu)/std | 0, 1) - log(std)
+        normal_ll = stats.norm().logpdf((x[:, 2] - means) / stds) - np.log(stds)
+        normal_ll[bad_nums] = 0  # invalid reaction times contribute nothing
+
+        return out + normal_ll  # categorical plus Gaussian log likelihoods
+
+    # Gibbs sampling
+    def resample(self, data=[]):
+        if isinstance(data, np.ndarray):
+            pass
+        if isinstance(data, list):
+            data = np.concatenate(data)
+        self.cats.resample(data[:, [0, 1]].astype(int))
+
+        bad_nums = data[:, 2] == -np.inf
+        for i, n in enumerate(self.normals):
+            mask = data[:, 0] == i
+            n.resample(data=data[np.logical_and(mask, ~bad_nums), 2])
+
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        warn('ML not implemented')
+
+    def MAP(self,data,weights=None):
+        warn('MAP not implemented')
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/negativebinomial.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/negativebinomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..9336a2a61cc350eacd35fb5341dbaebd8cd04de8
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/negativebinomial.py
@@ -0,0 +1,681 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+from builtins import object
+__all__ = [
+    'NegativeBinomial', 'NegativeBinomialFixedR', 'NegativeBinomialIntegerR2',
+    'NegativeBinomialIntegerR', 'NegativeBinomialFixedRVariant',
+    'NegativeBinomialIntegerRVariant', 'NegativeBinomialIntegerR2Variant']
+
+import numpy as np
+from numpy import newaxis as na
+import scipy.special as special
+from scipy.special import logsumexp
+from warnings import warn
+
+from pybasicbayes.abstractions import Distribution, GibbsSampling, \
+    MeanField, MeanFieldSVI, MaxLikelihood
+from pybasicbayes.util.stats import getdatasize, flattendata, \
+    sample_discrete_from_log, sample_discrete, atleast_2d
+
+try:
+    from pybasicbayes.util.cstats import sample_crp_tablecounts
+except ImportError:
+    warn('using slow sample_crp_tablecounts')
+    from pybasicbayes.util.stats import sample_crp_tablecounts
+
+
+class _NegativeBinomialBase(Distribution):
+    '''
+    Negative Binomial distribution with a conjugate beta prior on p and a
+    separate gamma prior on r. The parameter r does not need to be an integer.
+    If r is an integer, then x ~ NegBin(r,p) is the same as
+    x = np.random.geometric(1-p,size=r).sum() - r
+    where r is subtracted to make the geometric support be {0,1,2,...}
+    Mean is r*p/(1-p), var is r*p/(1-p)**2
+
+    Uses the data augmentation sampling method from Zhou et al. ICML 2012
+
+    NOTE: the support is {0,1,2,...}.
+
+    Hyperparameters:
+        k_0, theta_0: r ~ Gamma(k, theta)
+                      or r = np.random.gamma(k,theta)
+        alpha_0, beta_0: p ~ Beta(alpha,beta)
+                      or p = np.random.beta(alpha,beta)
+
+    Parameters:
+        r
+        p
+    '''
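+    # Illustrative usage sketch for the concrete NegativeBinomial subclass below (added
+    # for exposition; the hyperparameter values are arbitrary assumptions):
+    #   >>> nb = NegativeBinomial(k_0=1., theta_0=1., alpha_0=2., beta_0=2.)
+    #   >>> x = nb.rvs(size=100)              # counts supported on {0,1,2,...}
+    #   >>> nb.resample(x)                    # augmented Gibbs step for (r, p)
+    #   >>> nb.log_likelihood(np.arange(5))   # log pmf at 0..4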
+    def __init__(self,r=None,p=None,k_0=None,theta_0=None,alpha_0=None,beta_0=None):
+        self.r = r
+        self.p = p
+
+        self.k_0 = k_0
+        self.theta_0 = theta_0
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+
+        if r is p is None and not any(_ is None for _ in (k_0,theta_0,alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def params(self):
+        return dict(r=self.r,p=self.p)
+
+    @property
+    def hypparams(self):
+        return dict(k_0=self.k_0,theta_0=self.theta_0,
+                alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    def log_likelihood(self,x,r=None,p=None):
+        r = r if r is not None else self.r
+        p = p if p is not None else self.p
+        x = np.array(x,ndmin=1)
+
+        if self.p > 0:
+            xnn = x[x >= 0]
+            raw = np.empty(x.shape)
+            raw[x>=0] = special.gammaln(r + xnn) - special.gammaln(r) \
+                    - special.gammaln(xnn+1) + r*np.log(1-p) + xnn*np.log(p)
+            raw[x<0] = -np.inf
+            return raw if isinstance(x,np.ndarray) else raw[0]
+        else:
+            raw = np.log(np.zeros(x.shape))
+            raw[x == 0] = 0.
+            return raw if isinstance(x,np.ndarray) else raw[0]
+
+    def log_sf(self,x):
+        scalar = not isinstance(x,np.ndarray)
+        x = np.atleast_1d(x)
+        errs = np.seterr(divide='ignore')
+        ret = np.log(special.betainc(x+1,self.r,self.p))
+        np.seterr(**errs)
+        ret[x < 0] = np.log(1.)
+        if scalar:
+            return ret[0]
+        else:
+            return ret
+
+    def rvs(self,size=None):
+        return np.random.poisson(np.random.gamma(self.r,self.p/(1-self.p),size=size))
+
+class NegativeBinomial(_NegativeBinomialBase, GibbsSampling):
+    def resample(self,data=[],niter=20):
+        if getdatasize(data) == 0:
+            self.p = np.random.beta(self.alpha_0,self.beta_0)
+            self.r = np.random.gamma(self.k_0,self.theta_0)
+        else:
+            data = atleast_2d(flattendata(data))
+            N = len(data)
+            for itr in range(niter):
+                ### resample r
+                msum = sample_crp_tablecounts(self.r,data).sum()
+                self.r = np.random.gamma(self.k_0 + msum, 1/(1/self.theta_0 - N*np.log(1-self.p)))
+                ### resample p
+                self.p = np.random.beta(self.alpha_0 + data.sum(), self.beta_0 + N*self.r)
+        return self
+
+    def resample_python(self,data=[],niter=20):
+        if getdatasize(data) == 0:
+            self.p = np.random.beta(self.alpha_0,self.beta_0)
+            self.r = np.random.gamma(self.k_0,self.theta_0)
+        else:
+            data = flattendata(data)
+            N = len(data)
+            for itr in range(niter):
+                ### resample r
+                msum = 0.
+                for n in data:
+                    msum += (np.random.rand(n) < self.r/(np.arange(n)+self.r)).sum()
+                self.r = np.random.gamma(self.k_0 + msum, 1/(1/self.theta_0 - N*np.log(1-self.p)))
+                ### resample p
+                self.p = np.random.beta(self.alpha_0 + data.sum(), self.beta_0 + N*self.r)
+        return self
+
+    ### OLD unused alternatives
+
+    def resample_logseriesaug(self,data=[],niter=20):
+        # an alternative algorithm, kind of opaque and no advantages...
+        if getdatasize(data) == 0:
+            self.p = np.random.beta(self.alpha_0,self.beta_0)
+            self.r = np.random.gamma(self.k_0,self.theta_0)
+        else:
+            data = flattendata(data)
+            N = data.shape[0]
+            logF = self.logF
+            L_i = np.zeros(N)
+            data_nz = data[data > 0]
+            for itr in range(niter):
+                logR = np.arange(1,logF.shape[1]+1)*np.log(self.r) + logF
+                L_i[data > 0] = sample_discrete_from_log(logR[data_nz-1,:data_nz.max()],axis=1)+1
+                self.r = np.random.gamma(self.k_0 + L_i.sum(), 1/(1/self.theta_0 - np.log(1-self.p)*N))
+                self.p = np.random.beta(self.alpha_0 + data.sum(), self.beta_0 + N*self.r)
+        return self
+
+    @classmethod
+    def _set_up_logF(cls):
+        if not hasattr(cls,'logF'):
+            # actually indexes logF[0,0] to correspond to log(F(1,1)) in Zhou
+            # paper, but keeps track of that alignment with the other code!
+            # especially arange(1,...), only using nonzero data and shifting it
+            SIZE = 500
+
+            logF = -np.inf * np.ones((SIZE,SIZE))
+            logF[0,0] = 0.
+            for m in range(1,logF.shape[0]):
+                prevrow = np.exp(logF[m-1] - logF[m-1].max())
+                logF[m] = np.log(np.convolve(prevrow,[0,m,1],'same')) + logF[m-1].max()
+            cls.logF = logF
+
+
+class NegativeBinomialFixedR(_NegativeBinomialBase, GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood):
+    def __init__(self,r=None,p=None,alpha_0=None,beta_0=None,alpha_mf=None,beta_mf=None):
+        self.p = p
+
+        self.r = r
+
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+
+        if p is None and not any(_ is None for _ in (alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+        if not any(_ is None for _ in (alpha_mf,beta_mf)):
+            self.alpha_mf = alpha_mf
+            self.beta_mf = beta_mf
+
+    @property
+    def hypparams(self):
+        return dict(alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    @property
+    def natural_hypparam(self):
+        return np.array([self.alpha_0,self.beta_0]) - 1
+
+    @natural_hypparam.setter
+    def natural_hypparam(self,natparam):
+        self.alpha_0, self.beta_0 = natparam + 1
+
+    ### Mean Field
+
+    def _resample_from_mf(self):
+        self.p = np.random.beta(self.alpha_mf,self.beta_mf)
+        return self
+
+    def meanfieldupdate(self,data,weights):
+        self.alpha_mf, self.beta_mf = \
+                self._posterior_hypparams(*self._get_weighted_statistics(data,weights))
+        self.p = self.alpha_mf / (self.alpha_mf + self.beta_mf)
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        alpha_new, beta_new = \
+                self._posterior_hypparams(*(
+                    1./prob * self._get_weighted_statistics(data,weights)))
+        self.alpha_mf = (1-stepsize)*self.alpha_mf + stepsize*alpha_new
+        self.beta_mf = (1-stepsize)*self.beta_mf + stepsize*beta_new
+        self.p = self.alpha_mf / (self.alpha_mf + self.beta_mf)
+
+    def get_vlb(self):
+        Elnp, Eln1mp = self._mf_expected_statistics()
+        p_avgengy = (self.alpha_0-1)*Elnp + (self.beta_0-1)*Eln1mp \
+                - (special.gammaln(self.alpha_0) + special.gammaln(self.beta_0)
+                        - special.gammaln(self.alpha_0 + self.beta_0))
+        q_entropy = special.betaln(self.alpha_mf,self.beta_mf) \
+                - (self.alpha_mf-1)*special.digamma(self.alpha_mf) \
+                - (self.beta_mf-1)*special.digamma(self.beta_mf) \
+                + (self.alpha_mf+self.beta_mf-2)*special.digamma(self.alpha_mf+self.beta_mf)
+        return p_avgengy + q_entropy
+
+    def _mf_expected_statistics(self):
+        Elnp, Eln1mp = special.digamma([self.alpha_mf,self.beta_mf]) \
+                        - special.digamma(self.alpha_mf + self.beta_mf)
+        return Elnp, Eln1mp
+
+    def expected_log_likelihood(self,x):
+        Elnp, Eln1mp = self._mf_expected_statistics()
+        x = np.atleast_1d(x)
+        errs = np.seterr(invalid='ignore')
+        out = x*Elnp + self.r*Eln1mp + self._log_base_measure(x,self.r)
+        np.seterr(**errs)
+        out[np.isnan(out)] = -np.inf
+        return out if out.shape[0] > 1 else out[0]
+
+    @staticmethod
+    def _log_base_measure(x,r):
+        return special.gammaln(x+r) - special.gammaln(x+1) - special.gammaln(r)
+
+    ### Gibbs
+
+    def resample(self,data=[]):
+        self.p = np.random.beta(*self._posterior_hypparams(*self._get_statistics(data)))
+        # set mean field params to something reasonable for initialization
+        fakedata = self.rvs(10)
+        self.alpha_mf, self.beta_mf = self._posterior_hypparams(*self._get_statistics(fakedata))
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            n, tot = self._get_statistics(data)
+        else:
+            n, tot = self._get_weighted_statistics(data,weights)
+
+        self.p = (tot/n) / (self.r + tot/n)
+        return self
+
+    ### Statistics and posterior hypparams
+
+    def _get_statistics(self,data):
+        if getdatasize(data) == 0:
+            n, tot = 0, 0
+        elif isinstance(data,np.ndarray):
+            assert np.all(data >= 0)
+            data = np.atleast_1d(data)
+            n, tot = data.shape[0], data.sum()
+        elif isinstance(data,list):
+            assert all(np.all(d >= 0) for d in data)
+            n = sum(d.shape[0] for d in data)
+            tot = sum(d.sum() for d in data)
+        else:
+            assert np.isscalar(data)
+            n = 1
+            tot = data
+
+        return np.array([n, tot])
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(weights,np.ndarray):
+            assert np.all(data >= 0) and data.ndim == 1
+            n, tot = weights.sum(), weights.dot(data)
+        else:
+            assert all(np.all(d >= 0) for d in data)
+            n = sum(w.sum() for w in weights)
+            tot = sum(w.dot(d) for d,w in zip(data,weights))
+
+        return np.array([n, tot])
+
+    def _posterior_hypparams(self,n,tot):
+        return np.array([self.alpha_0 + tot, self.beta_0 + n*self.r])
+
+class NegativeBinomialIntegerR2(_NegativeBinomialBase,MeanField,MeanFieldSVI,GibbsSampling):
+    # NOTE: this class should replace NegativeBinomialFixedR completely...
+    _fixedr_class = NegativeBinomialFixedR
+
+    def __init__(self,alpha_0=None,beta_0=None,alphas_0=None,betas_0=None,
+            r_support=None,r_probs=None,r_discrete_distn=None,
+            r=None,ps=None):
+
+        assert (r_discrete_distn is not None) ^ (r_support is not None and r_probs is not None)
+        if r_discrete_distn is not None:
+            r_support, = np.where(r_discrete_distn)
+            r_probs = r_discrete_distn[r_support]
+            r_support += 1
+        self.r_support = np.asarray(r_support)
+        self.rho_0 = self.rho_mf = np.log(r_probs)
+
+        assert (alpha_0 is not None and  beta_0 is not None) \
+                ^ (alphas_0 is not None and betas_0 is not None)
+        alphas_0 = alphas_0 if alphas_0 is not None else [alpha_0]*len(r_support)
+        betas_0 = betas_0 if betas_0 is not None else [beta_0]*len(r_support)
+        ps = ps if ps is not None else [None]*len(r_support)
+        self._fixedr_distns = \
+            [self._fixedr_class(r=r,p=p,alpha_0=alpha_0,beta_0=beta_0)
+                    for r,p,alpha_0,beta_0 in zip(r_support,ps,alphas_0,betas_0)]
+
+        # for init
+        self.ridx = sample_discrete(r_probs)
+        self.r = r_support[self.ridx]
+
+    def __repr__(self):
+        return 'NB(r=%d,p=%0.3f)' % (self.r,self.p)
+
+    @property
+    def alphas_0(self):
+        return np.array([d.alpha_0 for d in self._fixedr_distns]) \
+                if len(self._fixedr_distns) > 0 else None
+
+    @property
+    def betas_0(self):
+        return np.array([d.beta_0 for d in self._fixedr_distns]) \
+                if len(self._fixedr_distns) > 0 else None
+
+    @property
+    def p(self):
+        return self._fixedr_distns[self.ridx].p
+
+    @p.setter
+    def p(self,val):
+        self._fixedr_distns[self.ridx].p = val
+
+    def _resample_from_mf(self):
+        self._resample_r_from_mf()
+        self._resample_p_from_mf()
+
+    def _resample_r_from_mf(self):
+        lognorm = logsumexp(self.rho_mf)
+        self.ridx = sample_discrete(np.exp(self.rho_mf - lognorm))
+        self.r = self.r_support[self.ridx]
+
+    def _resample_p_from_mf(self):
+        d = self._fixedr_distns[self.ridx]
+        self.p = np.random.beta(d.alpha_mf,d.beta_mf)
+
+    def get_vlb(self):
+        return self._r_vlb() + sum(np.exp(rho)*d.get_vlb()
+                for rho,d in zip(self.rho_mf,self._fixedr_distns))
+
+    def _r_vlb(self):
+        return np.exp(self.rho_mf).dot(self.rho_0) \
+                - np.exp(self.rho_mf).dot(self.rho_mf)
+
+    def num_parameters(self):
+        return 2
+
+    def meanfieldupdate(self,data,weights):
+        for d in self._fixedr_distns:
+            d.meanfieldupdate(data,weights)
+        self._update_rho_mf(data,weights)
+        # everything below here is for plotting
+        ridx = self.rho_mf.argmax()
+        d = self._fixedr_distns[ridx]
+        self.r = d.r
+        self.p = d.alpha_mf / (d.alpha_mf + d.beta_mf)
+
+    def _update_rho_mf(self,data,weights):
+        self.rho_mf = self.rho_0.copy()
+        for idx, d in enumerate(self._fixedr_distns):
+            n, tot = d._get_weighted_statistics(data,weights)
+            Elnp, Eln1mp = d._mf_expected_statistics()
+            self.rho_mf[idx] += (d.alpha_0-1+tot)*Elnp + (d.beta_0-1+n*d.r)*Eln1mp
+            if isinstance(data,np.ndarray):
+                self.rho_mf[idx] += weights.dot(d._log_base_measure(data,d.r))
+            else:
+                self.rho_mf[idx] += sum(w.dot(d._log_base_measure(dt,d.r))
+                        for dt,w in zip(data,weights))
+
+    def expected_log_likelihood(self,x):
+        lognorm = logsumexp(self.rho_mf)
+        return sum(np.exp(rho-lognorm)*d.expected_log_likelihood(x)
+                for rho,d in zip(self.rho_mf,self._fixedr_distns))
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        rho_mf_orig = self.rho_mf.copy()
+        if isinstance(data,np.ndarray):
+            self._update_rho_mf(data,prob*weights)
+        else:
+            self._update_rho_mf(data,[w*prob for w in weights])
+        rho_mf_new = self.rho_mf
+
+        for d in self._fixedr_distns:
+            d.meanfield_sgdstep(data,weights,prob,stepsize)
+
+        self.rho_mf = (1-stepsize)*rho_mf_orig + stepsize*rho_mf_new
+
+        # for plotting
+        ridx = self.rho_mf.argmax()
+        d = self._fixedr_distns[ridx]
+        self.r = d.r
+        self.p = d.alpha_mf / (d.alpha_mf + d.beta_mf)
+
+    def resample(self,data=[]):
+        self._resample_r(data) # marginalizes out p values
+        self._resample_p(data) # resample p given sampled r
+        return self
+
+    def _resample_r(self,data):
+        self.ridx = sample_discrete(
+                self._posterior_hypparams(self._get_statistics(data)))
+        self.r = self.r_support[self.ridx]
+        return self
+
+    def _resample_p(self,data):
+        self._fixedr_distns[self.ridx].resample(data)
+        return self
+
+    def _get_statistics(self,data=[]):
+        n, tot = self._fixedr_distns[0]._get_statistics(data)
+        if n > 0:
+            data = flattendata(data)
+            alphas_n, betas_n = self.alphas_0 + tot, self.betas_0 + self.r_support*n
+            log_marg_likelihoods = \
+                    special.betaln(alphas_n, betas_n) \
+                        - special.betaln(self.alphas_0, self.betas_0) \
+                    + (special.gammaln(data[:,na]+self.r_support)
+                        - special.gammaln(data[:,na]+1) \
+                        - special.gammaln(self.r_support)).sum(0)
+        else:
+            log_marg_likelihoods = np.zeros_like(self.r_support)
+        return log_marg_likelihoods
+
+    def _posterior_hypparams(self,log_marg_likelihoods):
+        log_posterior_discrete = self.rho_0 + log_marg_likelihoods
+        return np.exp(log_posterior_discrete - log_posterior_discrete.max())
+
+class NegativeBinomialIntegerR(NegativeBinomialFixedR, GibbsSampling, MaxLikelihood):
+    '''
+    Nonconjugate Discrete+Beta prior
+    r_discrete_distribution is an array where index i is p(r=i+1)
+    '''
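+    # Illustrative usage sketch (added for exposition; the support and hyperparameters
+    # are arbitrary assumptions):
+    #   >>> nb = NegativeBinomialIntegerR(r_discrete_distn=np.ones(10)/10.,
+    #   ...                               alpha_0=2., beta_0=2.)   # uniform prior on r in {1,...,10}
+    #   >>> nb.resample(np.random.poisson(3, size=50))             # samples r and p jointly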
+    def __init__(self,r_discrete_distn=None,r_support=None,
+            alpha_0=None,beta_0=None,r=None,p=None):
+        self.r_support = r_support
+        self.r_discrete_distn = r_discrete_distn
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+        self.r = r
+        self.p = p
+
+        if r is p is None \
+                and not any(_ is None for _ in (r_discrete_distn,alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def hypparams(self):
+        return dict(r_discrete_distn=self.r_discrete_distn,
+                alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    def get_r_discrete_distn(self):
+        return self._r_discrete_distn
+
+    def set_r_discrete_distn(self,r_discrete_distn):
+        if r_discrete_distn is not None:
+            r_discrete_distn = np.asarray(r_discrete_distn,dtype=float)
+            r_support, = np.where(r_discrete_distn)
+            r_probs = r_discrete_distn[r_support]
+            r_probs /= r_probs.sum()
+            r_support += 1 # r_probs[0] corresponds to r=1
+
+            self.r_support = r_support
+            self.r_probs = r_probs
+            self._r_discrete_distn = r_discrete_distn
+
+    r_discrete_distn = property(get_r_discrete_distn,set_r_discrete_distn)
+
+    def rvs(self,size=None):
+        out = np.random.geometric(1-self.p,size=size)-1
+        for i in range(self.r-1):
+            out += np.random.geometric(1-self.p,size=size)-1
+        return out
+
+    def resample(self,data=[]):
+        alpha_n, betas_n, posterior_discrete = self._posterior_hypparams(
+                *self._get_statistics(data))
+
+        r_idx = sample_discrete(posterior_discrete)
+        self.r = self.r_support[r_idx]
+        self.p = np.random.beta(alpha_n, betas_n[r_idx])
+
+    # NOTE: this class has a conjugate prior even though it's not in the
+    # exponential family, so I wrote _get_statistics and _get_weighted_statistics
+    # (which integrate out p) for the resample() and meanfield_update() methods,
+    # though these aren't statistics in the exponential family sense
+
+    def _get_statistics(self,data):
+        # NOTE: since this isn't really in exponential family, this method needs
+        # to look at hyperparameters. form posterior hyperparameters for the p
+        # parameters here so we can integrate them out and get the r statistics
+        n, tot = super(NegativeBinomialIntegerR,self)._get_statistics(data)
+        if n > 0:
+            alpha_n, betas_n = self.alpha_0 + tot, self.beta_0 + self.r_support*n
+            data = flattendata(data)
+            log_marg_likelihoods = \
+                    special.betaln(alpha_n, betas_n) \
+                        - special.betaln(self.alpha_0, self.beta_0) \
+                    + (special.gammaln(data[:,na]+self.r_support)
+                        - special.gammaln(data[:,na]+1) \
+                        - special.gammaln(self.r_support)).sum(0)
+        else:
+            log_marg_likelihoods = np.zeros_like(self.r_support)
+
+        return n, tot, log_marg_likelihoods
+
+    def _get_weighted_statistics(self,data,weights):
+        n, tot = super(NegativeBinomialIntegerR,self)._get_weighted_statistics(data,weights)
+        if n > 0:
+            alpha_n, betas_n = self.alpha_0 + tot, self.beta_0 + self.r_support*n
+            data, weights = flattendata(data), flattendata(weights)
+            log_marg_likelihoods = \
+                    special.betaln(alpha_n, betas_n) \
+                        - special.betaln(self.alpha_0, self.beta_0) \
+                    + (special.gammaln(data[:,na]+self.r_support)
+                        - special.gammaln(data[:,na]+1) \
+                        - special.gammaln(self.r_support)).dot(weights)
+        else:
+            log_marg_likelihoods = np.zeros_like(self.r_support)
+
+        return n, tot, log_marg_likelihoods
+
+    def _posterior_hypparams(self,n,tot,log_marg_likelihoods):
+        alpha_n = self.alpha_0 + tot
+        betas_n = self.beta_0 + n*self.r_support
+        log_posterior_discrete = np.log(self.r_probs) + log_marg_likelihoods
+        posterior_discrete = np.exp(log_posterior_discrete - log_posterior_discrete.max())
+        return alpha_n, betas_n, posterior_discrete
+
+    def max_likelihood(self,data,weights=None,stats=None):
+        if stats is not None:
+            n, tot = stats
+        elif weights is None:
+            n, tot = super(NegativeBinomialIntegerR,self)._get_statistics(data)
+        else:
+            n, tot = super(NegativeBinomialIntegerR,self)._get_weighted_statistics(data,weights)
+
+        if n > 1:
+            rs = self.r_support
+            ps = self._max_likelihood_ps(n,tot,rs)
+
+            # TODO TODO this isn't right for weighted data: do weighted sums
+            if isinstance(data,np.ndarray):
+                likelihoods = np.array([self.log_likelihood(data,r=r,p=p).sum()
+                                            for r,p in zip(rs,ps)])
+            else:
+                likelihoods = np.array([sum(self.log_likelihood(d,r=r,p=p).sum()
+                                            for d in data) for r,p in zip(rs,ps)])
+
+            argmax = likelihoods.argmax()
+            self.r = self.r_support[argmax]
+            self.p = ps[argmax]
+        return self
+
+    def _log_base_measure(self,data):
+        return [(special.gammaln(r+data) - special.gammaln(r) - special.gammaln(data+1)).sum()
+                for r in self.r_support]
+
+    def _max_likelihood_ps(self,n,tot,rs):
+        ps = (tot/n) / (rs + tot/n)
+        assert (ps >= 0).all()
+        return ps
+
+class _StartAtRMixin(object):
+    def log_likelihood(self,x,**kwargs):
+        r = kwargs['r'] if 'r' in kwargs else self.r
+        return super(_StartAtRMixin,self).log_likelihood(x-r,**kwargs)
+
+    def log_sf(self,x,**kwargs):
+        return super(_StartAtRMixin,self).log_sf(x-self.r,**kwargs)
+
+    def expected_log_likelihood(self,x,**kwargs):
+        r = kwargs['r'] if 'r' in kwargs else self.r
+        return super(_StartAtRMixin,self).expected_log_likelihood(x-r,**kwargs)
+
+    def rvs(self,size=[]):
+        return super(_StartAtRMixin,self).rvs(size)+self.r
+
+class NegativeBinomialFixedRVariant(_StartAtRMixin,NegativeBinomialFixedR):
+    def _get_statistics(self,data):
+        n, tot = super(NegativeBinomialFixedRVariant,self)._get_statistics(data)
+        n, tot = n, tot-n*self.r
+        assert tot >= 0
+        return np.array([n, tot])
+
+    def _get_weighted_statistics(self,data,weights):
+        n, tot = super(NegativeBinomialFixedRVariant,self)._get_weighted_statistics(data,weights)
+        n, tot = n, tot-n*self.r
+        assert tot >= 0
+        return np.array([n, tot])
+
+class NegativeBinomialIntegerRVariant(NegativeBinomialIntegerR):
+    def resample(self,data=[]):
+        n, alpha_n, posterior_discrete, r_support = self._posterior_hypparams(
+                *self._get_statistics(data)) # NOTE: pass out r_support b/c feasible subset
+        self.r = r_support[sample_discrete(posterior_discrete)]
+        self.p = np.random.beta(alpha_n - n*self.r, self.beta_0 + n*self.r)
+
+    def _get_statistics(self,data):
+        n = getdatasize(data)
+        if n > 0:
+            data = flattendata(data)
+            feasible = self.r_support <= data.min()
+            assert np.any(feasible)
+            r_support = self.r_support[feasible]
+            normalizers = (special.gammaln(data[:,na]) - special.gammaln(data[:,na]-r_support+1)
+                    - special.gammaln(r_support)).sum(0)
+            return n, data.sum(), normalizers, feasible
+        else:
+            return n, None, None, None
+
+    def _posterior_hypparams(self,n,tot,normalizers,feasible):
+        if n == 0:
+            return n, self.alpha_0, self.r_probs, self.r_support
+        else:
+            r_probs = self.r_probs[feasible]
+            r_support = self.r_support[feasible]
+            log_marg_likelihoods = special.betaln(self.alpha_0 + tot - n*r_support,
+                                                        self.beta_0 + r_support*n) \
+                                    - special.betaln(self.alpha_0, self.beta_0) \
+                                    + normalizers
+            log_marg_probs = np.log(r_probs) + log_marg_likelihoods
+            log_marg_probs -= log_marg_probs.max()
+            marg_probs = np.exp(log_marg_probs)
+
+            return n, self.alpha_0 + tot, marg_probs, r_support
+
+    def _max_likelihood_ps(self,n,tot,rs):
+        ps = 1-(rs*n)/tot
+        assert (ps >= 0).all()
+        return ps
+
+    def rvs(self,size=[]):
+        return super(NegativeBinomialIntegerRVariant,self).rvs(size) + self.r
+
+class NegativeBinomialIntegerR2Variant(NegativeBinomialIntegerR2):
+    _fixedr_class = NegativeBinomialFixedRVariant
+
+    def _update_rho_mf(self,data,weights):
+        self.rho_mf = self.rho_0.copy()
+        for idx, d in enumerate(self._fixedr_distns):
+            n, tot = d._get_weighted_statistics(data,weights)
+            Elnp, Eln1mp = d._mf_expected_statistics()
+            self.rho_mf[idx] += (d.alpha_0-1+tot)*Elnp + (d.beta_0-1+n*d.r)*Eln1mp
+            self.rho_mf_temp = self.rho_mf.copy()
+
+            # NOTE: this method only needs to override parent in the base measure
+            # part, i.e. data -> data-r
+            if isinstance(data,np.ndarray):
+                self.rho_mf[idx] += weights.dot(d._log_base_measure(data-d.r,d.r))
+            else:
+                self.rho_mf[idx] += sum(w.dot(d._log_base_measure(dt-d.r,d.r))
+                        for dt,w in zip(data,weights))
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/poisson.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/poisson.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee6099e6d993990bc745b0b1de502930f0c7c0c7
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/poisson.py
@@ -0,0 +1,186 @@
+from __future__ import division
+from builtins import zip
+__all__ = ['Poisson']
+import numpy as np
+import scipy.stats as stats
+import scipy.special as special
+
+from pybasicbayes.abstractions import GibbsSampling, Collapsed, \
+    MaxLikelihood, MeanField, MeanFieldSVI
+
+
+class Poisson(GibbsSampling, Collapsed, MaxLikelihood, MeanField, MeanFieldSVI):
+    '''
+    Poisson distribution with a conjugate Gamma prior.
+
+    NOTE: the support is {0,1,2,...}
+
+    Hyperparameters (following Wikipedia's notation):
+        alpha_0, beta_0
+
+    Parameter is the mean/variance parameter:
+        lmbda
+    '''
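+    # Illustrative usage sketch (added for exposition; the hyperparameter values are
+    # arbitrary assumptions):
+    #   >>> p = Poisson(alpha_0=3., beta_0=1.)     # lmbda drawn from Gamma(3, 1)
+    #   >>> x = p.rvs(size=100)
+    #   >>> p.resample(x)                          # conjugate Gibbs update of lmbda
+    #   >>> p.meanfieldupdate(x, np.ones(len(x)))  # variational update of the Gamma factor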
+    def __init__(self,lmbda=None,alpha_0=None,beta_0=None,mf_alpha_0=None,mf_beta_0=None):
+        self.lmbda = lmbda
+
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+        self.mf_alpha_0 = mf_alpha_0 if mf_alpha_0 is not None else alpha_0
+        self.mf_beta_0 = mf_beta_0 if mf_beta_0 is not None else beta_0
+
+        if lmbda is None and not any(_ is None for _ in (alpha_0,beta_0)):
+            self.resample() # initialize from prior
+
+    @property
+    def params(self):
+        return dict(lmbda=self.lmbda)
+
+    @property
+    def hypparams(self):
+        return dict(alpha_0=self.alpha_0,beta_0=self.beta_0)
+
+    def log_sf(self,x):
+        return stats.poisson.logsf(x,self.lmbda)
+
+    def _posterior_hypparams(self,n,tot):
+        return self.alpha_0 + tot, self.beta_0 + n
+
+    def rvs(self,size=None):
+        return np.random.poisson(self.lmbda,size=size)
+
+    def log_likelihood(self,x):
+        lmbda = self.lmbda
+        x_orig = x
+        x = np.array(x,ndmin=1)
+        raw = np.empty(x.shape)
+        raw[x>=0] = -lmbda + x[x>=0]*np.log(lmbda) - special.gammaln(x[x>=0]+1)
+        raw[x<0] = -np.inf
+        # return a scalar for scalar input; x itself was promoted to an array above
+        return raw if isinstance(x_orig,np.ndarray) else raw[0]
+
+    def _get_statistics(self,data):
+        if isinstance(data,np.ndarray):
+            n = data.shape[0]
+            tot = data.sum()
+        elif isinstance(data,list):
+            n = sum(d.shape[0] for d in data)
+            tot = sum(d.sum() for d in data)
+        else:
+            assert np.isscalar(data)
+            n = 1
+            tot = data
+
+        return n, tot
+
+    def _get_weighted_statistics(self,data,weights):
+        if isinstance(data,np.ndarray):
+            n = weights.sum()
+            tot = weights.dot(data)
+        elif isinstance(data,list):
+            n = sum(w.sum() for w in weights)
+            tot = sum(w.dot(d) for w,d in zip(weights,data))
+        else:
+            assert np.isscalar(data) and np.isscalar(weights)
+            n = weights
+            tot = weights*data
+
+        return np.array([n, tot])
+
+    ### Gibbs Sampling
+
+    def resample(self,data=[],stats=None):
+        stats = self._get_statistics(data) if stats is None else stats
+        alpha_n, beta_n = self._posterior_hypparams(*stats)
+        self.lmbda = np.random.gamma(alpha_n,1/beta_n)
+
+        # next line is for mean field initialization
+        self.mf_alpha_0, self.mf_beta_0 = self.lmbda * self.beta_0, self.beta_0
+
+        return self
+
+    ### Mean Field
+
+    def _resample_from_mf(self):
+        mf_alpha_0, mf_beta_0 = self._natural_to_standard(self.mf_natural_hypparam)
+        self.lmbda = np.random.gamma(mf_alpha_0, 1./mf_beta_0)
+
+    def meanfieldupdate(self,data,weights):
+        self.mf_natural_hypparam = \
+                self.natural_hypparam + self._get_weighted_statistics(data,weights)
+        self.lmbda = self.mf_alpha_0 / self.mf_beta_0
+
+    def meanfield_sgdstep(self,data,weights,prob,stepsize):
+        self.mf_natural_hypparam = \
+                (1-stepsize) * self.mf_natural_hypparam + stepsize * (
+                        self.natural_hypparam
+                        + 1./prob * self._get_weighted_statistics(data,weights))
+
+    def get_vlb(self):
+        return (self.natural_hypparam - self.mf_natural_hypparam).dot(self._mf_expected_statistics) \
+                - (self._log_partition_fn(self.alpha_0,self.beta_0)
+                        - self._log_partition_fn(self.mf_alpha_0,self.mf_beta_0))
+
+    def expected_log_likelihood(self,x):
+        Emlmbda, Elnlmbda = self._mf_expected_statistics
+        return -special.gammaln(x+1) + Elnlmbda * x + Emlmbda
+
+    @property
+    def _mf_expected_statistics(self):
+        alpha, beta = self.mf_alpha_0, self.mf_beta_0
+        return np.array([-alpha/beta, special.digamma(alpha) - np.log(beta)])
+
+
+    @property
+    def natural_hypparam(self):
+        return self._standard_to_natural(self.alpha_0,self.beta_0)
+
+    @property
+    def mf_natural_hypparam(self):
+        return self._standard_to_natural(self.mf_alpha_0,self.mf_beta_0)
+
+    @mf_natural_hypparam.setter
+    def mf_natural_hypparam(self,natparam):
+        self.mf_alpha_0, self.mf_beta_0 = self._natural_to_standard(natparam)
+
+
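+    # The Gamma prior is stored in natural form as (beta, alpha - 1); these pair
+    # with the sufficient statistics (n, sum(x)) from _get_weighted_statistics,
+    # so the conjugate update is just element-wise addition of the two arrays.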
+    def _standard_to_natural(self,alpha,beta):
+        return np.array([beta, alpha-1])
+
+    def _natural_to_standard(self,natparam):
+        return natparam[1]+1, natparam[0]
+
+    ### Collapsed
+
+    def log_marginal_likelihood(self,data):
+        return self._log_partition_fn(*self._posterior_hypparams(*self._get_statistics(data))) \
+                - self._log_partition_fn(self.alpha_0,self.beta_0) \
+                - self._get_sum_of_gammas(data)
+
+    def _log_partition_fn(self,alpha,beta):
+        return special.gammaln(alpha) - alpha * np.log(beta)
+
+    def _get_sum_of_gammas(self,data):
+        if isinstance(data,np.ndarray):
+            return special.gammaln(data+1).sum()
+        elif isinstance(data,list):
+            return sum(special.gammaln(d+1).sum() for d in data)
+        else:
+            assert np.isscalar(data)
+            return special.gammaln(data+1)
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None):
+        if weights is None:
+            n, tot = self._get_statistics(data)
+        else:
+            n, tot = self._get_weighted_statistics(data,weights)
+
+        if n > 1e-2:
+            self.lmbda = tot/n
+            assert self.lmbda > 0
+        else:
+            self.broken = True
+            self.lmbda = 999999
+
+        return self
+
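+
+# A minimal usage sketch (not part of the original library; hyperparameter values
+# are arbitrary): fit the conjugate Gamma-Poisson model by Gibbs sampling and
+# compare the posterior-mean rate against the empirical mean of the data.
+if __name__ == '__main__':
+    np.random.seed(0)
+    truth = Poisson(lmbda=3.0, alpha_0=1., beta_0=1.)
+    data = truth.rvs(size=500)
+
+    model = Poisson(alpha_0=1., beta_0=1.)
+    lmbda_samples = [model.resample(data).lmbda for _ in range(100)]
+    print('posterior mean of lmbda: %.3f (empirical mean %.3f)'
+          % (np.mean(lmbda_samples), data.mean()))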
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/regression.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..b524061ee1a0f4ea1618164c61f143db695f4cc7
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/regression.py
@@ -0,0 +1,1206 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+__all__ = ['Regression', 'RegressionNonconj', 'ARDRegression',
+           'AutoRegression', 'ARDAutoRegression', 'DiagonalRegression',
+           'RobustRegression', 'RobustAutoRegression']
+
+from warnings import warn
+
+import numpy as np
+from numpy import newaxis as na
+
+from scipy.linalg import solve_triangular
+from scipy.special import gammaln, digamma, polygamma
+
+from pybasicbayes.abstractions import GibbsSampling, MaxLikelihood, \
+    MeanField, MeanFieldSVI
+from pybasicbayes.util.stats import sample_gaussian, sample_mniw, \
+    sample_invwishart, getdatasize, mniw_expectedstats, mniw_log_partitionfunction, \
+    sample_invgamma, update_param
+
+from pybasicbayes.util.general import blockarray, inv_psd, cumsum, \
+    all_none, any_none, AR_striding, objarray
+
+
+class Regression(GibbsSampling, MeanField, MaxLikelihood):
+    def __init__(
+            self, nu_0=None,S_0=None,M_0=None,K_0=None,
+            affine=False,
+            A=None,sigma=None):
+        self.affine = affine
+
+        self._check_shapes(A, sigma, nu_0, S_0, M_0, K_0)
+
+        self.A = A
+        self.sigma = sigma
+
+        have_hypers = not any_none(nu_0,S_0,M_0,K_0)
+
+        if have_hypers:
+            self.natural_hypparam = self.mf_natural_hypparam = \
+                self._standard_to_natural(nu_0,S_0,M_0,K_0)
+
+        if A is sigma is None and have_hypers:
+            self.resample()  # initialize from prior
+
+    @staticmethod
+    def _check_shapes(A, sigma, nu, S, M, K):
+        is_2d = lambda x: isinstance(x, np.ndarray) and x.ndim == 2
+        not_none = lambda x: x is not None
+        assert all(map(is_2d, filter(not_none, [A, sigma, S, M, K]))), 'Matrices must be 2D'
+
+        get_dim = lambda x, i: x.shape[i] if x is not None else None
+        get_dim_list = lambda pairs: filter(not_none, map(get_dim, *zip(*pairs)))
+        is_consistent = lambda dimlist: len(set(dimlist)) == 1
+        dims_agree = lambda pairs: is_consistent(get_dim_list(pairs))
+        assert dims_agree([(A, 1), (M, 1), (K, 0), (K, 1)]), 'Input dimensions not consistent'
+        assert dims_agree([(A, 0), (sigma, 0), (sigma, 1), (S, 0), (S, 1), (M, 0)]), \
+            'Output dimensions not consistent'
+
+    @property
+    def parameters(self):
+        return (self.A, self.sigma)
+
+    @parameters.setter
+    def parameters(self, A_sigma_tuple):
+        (A,sigma) = A_sigma_tuple
+        self.A = A
+        self.sigma = sigma
+
+    @property
+    def D_in(self):
+        # NOTE: D_in includes the extra affine coordinate
+        mat = self.A if self.A is not None else self.natural_hypparam[1]
+        return mat.shape[1]
+
+    @property
+    def D_out(self):
+        mat = self.A if self.A is not None else self.natural_hypparam[1]
+        return mat.shape[0]
+
+    ### converting between natural and standard parameters
+
+    @staticmethod
+    def _standard_to_natural(nu,S,M,K):
+        Kinv = inv_psd(K)
+        A = S + M.dot(Kinv).dot(M.T)
+        B = M.dot(Kinv)
+        C = Kinv
+        d = nu
+        return np.array([A,B,C,d])
+
+    @staticmethod
+    def _natural_to_standard(natparam):
+        A,B,C,d = natparam   # natparam is roughly (yyT, yxT, xxT, n)
+        nu = d
+        Kinv = C
+        K = inv_psd(Kinv)
+        # M = B.dot(K)
+        M = np.linalg.solve(Kinv, B.T).T
+        # This subtraction seems unstable!
+        # It does not necessarily return a PSD matrix
+        S = A - M.dot(B.T)
+
+        # numerical padding here...
+        K += 1e-8*np.eye(K.shape[0])
+        S += 1e-8*np.eye(S.shape[0])
+        assert np.all(0 < np.linalg.eigvalsh(S))
+        assert np.all(0 < np.linalg.eigvalsh(K))
+
+        # standard is degrees of freedom, mean of sigma (ish), mean of A, cov of rows of A
+        return nu, S, M, K
+
+    ### getting statistics
+
+    # NOTE: stats object arrays depend on the last element being a scalar,
+    # otherwise numpy will try to create a dense array and fail
+
+    def _get_statistics(self,data):
+        assert isinstance(data, (list, tuple, np.ndarray))
+        if isinstance(data,list):
+            return sum((self._get_statistics(d) for d in data),
+                       self._empty_statistics())
+        elif isinstance(data, tuple):
+            x, y = data
+            bad = np.isnan(x).any(1) | np.isnan(y).any(1)
+            x, y = x[~bad], y[~bad]
+
+            n, D = y.shape
+
+            xxT, yxT, yyT = \
+                x.T.dot(x), y.T.dot(x), y.T.dot(y)
+
+            if self.affine:
+                x, y = x.sum(0), y.sum(0)
+                xxT = blockarray([[xxT,x[:,na]],[x[na,:],np.atleast_2d(n)]])
+                yxT = np.hstack((yxT,y[:,na]))
+
+            return np.array([yyT, yxT, xxT, n])
+        else:
+            # data passed in like np.hstack((x, y))
+            data = data[~np.isnan(data).any(1)]
+            n, D = data.shape[0], self.D_out
+
+            statmat = data.T.dot(data)
+            xxT, yxT, yyT = \
+                statmat[:-D,:-D], statmat[-D:,:-D], statmat[-D:,-D:]
+
+            if self.affine:
+                xy = data.sum(0)
+                x, y = xy[:-D], xy[-D:]
+                xxT = blockarray([[xxT,x[:,na]],[x[na,:],np.atleast_2d(n)]])
+                yxT = np.hstack((yxT,y[:,na]))
+
+            return np.array([yyT, yxT, xxT, n])
+
+    def _get_weighted_statistics(self,data,weights):
+        assert isinstance(data, (list, tuple, np.ndarray))
+        if isinstance(data,list):
+            return sum((self._get_weighted_statistics(d,w) for d,w in zip(data,weights)),
+                       self._empty_statistics())
+        elif isinstance(data, tuple):
+            x, y = data
+            bad = np.isnan(x).any(1) | np.isnan(y).any(1)
+            x, y, weights = x[~bad], y[~bad], weights[~bad]
+
+            n, D = weights.sum(), y.shape[1]
+            wx = weights[:,na]*x
+
+            xxT, yxT, yyT = \
+                x.T.dot(wx), y.T.dot(wx), y.T.dot(weights[:,na]*y)
+
+            if self.affine:
+                x, y = weights.dot(x), weights.dot(y)
+                xxT = blockarray([[xxT,x[:,na]],[x[na,:],np.atleast_2d(n)]])
+                yxT = np.hstack((yxT,y[:,na]))
+
+            return np.array([yyT, yxT, xxT, n])
+        else:
+            # data passed in like np.hstack((x, y))
+            gi = ~np.isnan(data).any(1)
+            data, weights = data[gi], weights[gi]
+            n, D = weights.sum(), self.D_out
+
+            statmat = data.T.dot(weights[:,na]*data)
+            xxT, yxT, yyT = \
+                statmat[:-D,:-D], statmat[-D:,:-D], statmat[-D:,-D:]
+
+            if self.affine:
+                xy = weights.dot(data)
+                x, y = xy[:-D], xy[-D:]
+                xxT = blockarray([[xxT,x[:,na]],[x[na,:],np.atleast_2d(n)]])
+                yxT = np.hstack((yxT,y[:,na]))
+
+            return np.array([yyT, yxT, xxT, n])
+
+    def _empty_statistics(self):
+        D_in, D_out = self.D_in, self.D_out
+        return np.array(
+            [np.zeros((D_out,D_out)), np.zeros((D_out,D_in)),
+             np.zeros((D_in,D_in)),0])
+
+    @staticmethod
+    def _stats_ensure_array(stats):
+        if isinstance(stats, np.ndarray):
+            return stats
+        affine = len(stats) > 4
+
+        yyT, yxT, xxT, n = stats[-4:]
+        if affine:
+            y, x = stats[:2]
+            yxT = np.hstack((yxT, y[:,None]))
+            xxT = blockarray([[xxT, x[:,None]], [x[None,:], 1.]])
+
+        return np.array([yyT, yxT, xxT, n])
+
+    ### distribution
+
+    def log_likelihood(self,xy):
+        assert isinstance(xy, (tuple,np.ndarray))
+        A, sigma, D = self.A, self.sigma, self.D_out
+        x, y = (xy[:,:-D], xy[:,-D:]) if isinstance(xy,np.ndarray) else xy
+
+        if self.affine:
+            A, b = A[:,:-1], A[:,-1]
+
+        sigma_inv, L = inv_psd(sigma, return_chol=True)
+        parammat = -1./2 * blockarray([
+            [A.T.dot(sigma_inv).dot(A), -A.T.dot(sigma_inv)],
+            [-sigma_inv.dot(A), sigma_inv]])
+
+        contract = 'ni,ni->n' if x.ndim == 2 else 'i,i->'
+        if isinstance(xy, np.ndarray):
+            out = np.einsum(contract,xy.dot(parammat),xy)
+        else:
+            out = np.einsum(contract,x.dot(parammat[:-D,:-D]),x)
+            out += np.einsum(contract,y.dot(parammat[-D:,-D:]),y)
+            out += 2*np.einsum(contract,x.dot(parammat[:-D,-D:]),y)
+
+        out -= D/2*np.log(2*np.pi) + np.log(np.diag(L)).sum()
+
+        if self.affine:
+            out += y.dot(sigma_inv).dot(b)
+            out -= x.dot(A.T).dot(sigma_inv).dot(b)
+            out -= 1./2*b.dot(sigma_inv).dot(b)
+
+        return out
+
+    def predict(self, x):
+        A, sigma = self.A, self.sigma
+
+        if self.affine:
+            A, b = A[:, :-1], A[:, -1]
+            y = x.dot(A.T) + b.T
+        else:
+            y = x.dot(A.T)
+
+        return y
+
+    def rvs(self,x=None,size=1,return_xy=True):
+        A, sigma = self.A, self.sigma
+
+        if self.affine:
+            A, b = A[:,:-1], A[:,-1]
+
+        x = np.random.normal(size=(size,A.shape[1])) if x is None else x
+        y = self.predict(x)
+        y += np.random.normal(size=(x.shape[0], self.D_out)) \
+            .dot(np.linalg.cholesky(sigma).T)
+
+        return np.hstack((x,y)) if return_xy else y
+
+    ### Gibbs sampling
+
+    def resample(self,data=[],stats=None):
+        stats = self._get_statistics(data) if stats is None else stats
+        self.A, self.sigma = sample_mniw(
+            *self._natural_to_standard(self.natural_hypparam + stats))
+        self._initialize_mean_field()
+
+    ### Max likelihood
+
+    def max_likelihood(self,data,weights=None,stats=None):
+        if stats is None:
+            stats = self._get_statistics(data) if weights is None \
+                else self._get_weighted_statistics(data,weights)
+
+        yyT, yxT, xxT, n = stats
+
+        if n > 0:
+            try:
+                self.A = np.linalg.solve(xxT, yxT.T).T
+                self.sigma = (yyT - self.A.dot(yxT.T))/n
+
+                def symmetrize(A):
+                    return (A + A.T)/2.
+                self.sigma = 1e-10*np.eye(self.D_out) \
+                    + symmetrize(self.sigma)  # numerical
+            except np.linalg.LinAlgError:
+                self.broken = True
+        else:
+            self.broken = True
+
+        assert np.allclose(self.sigma,self.sigma.T)
+        assert np.all(np.linalg.eigvalsh(self.sigma) > 0.)
+
+        self._initialize_mean_field()
+
+        return self
+
+    ### Mean Field
+
+    def meanfieldupdate(self, data=None, weights=None, stats=None):
+        assert (data is not None and weights is not None) ^ (stats is not None)
+        stats = self._stats_ensure_array(stats) if stats is not None \
+            else self._get_weighted_statistics(data, weights)
+        self.mf_natural_hypparam = self.natural_hypparam + stats
+        self._set_params_from_mf()
+
+    def meanfield_sgdstep(self, data, weights, prob, stepsize, stats=None):
+        if stats is None:
+            stats = self._get_weighted_statistics(data, weights)
+        self.mf_natural_hypparam = \
+            (1-stepsize) * self.mf_natural_hypparam + stepsize \
+            * (self.natural_hypparam + 1./prob * stats)
+        self._set_params_from_mf()
+
+    def meanfield_expectedstats(self):
+        from pybasicbayes.util.stats import mniw_expectedstats
+        return mniw_expectedstats(
+                *self._natural_to_standard(self.mf_natural_hypparam))
+
+    def expected_log_likelihood(self, xy=None, stats=None):
+        # TODO test values, test for the affine case
+        assert isinstance(xy, (tuple, np.ndarray)) ^ isinstance(stats, tuple)
+
+        D = self.D_out
+        E_Sigmainv, E_Sigmainv_A, E_AT_Sigmainv_A, E_logdetSigmainv = \
+            mniw_expectedstats(
+                *self._natural_to_standard(self.mf_natural_hypparam))
+
+        if self.affine:
+            E_Sigmainv_A, E_Sigmainv_b = \
+                E_Sigmainv_A[:,:-1], E_Sigmainv_A[:,-1]
+            E_AT_Sigmainv_A, E_AT_Sigmainv_b, E_bT_Sigmainv_b = \
+                E_AT_Sigmainv_A[:-1,:-1], E_AT_Sigmainv_A[:-1,-1], \
+                E_AT_Sigmainv_A[-1,-1]
+
+        if xy is not None:
+            x, y = (xy[:,:-D], xy[:,-D:]) if isinstance(xy, np.ndarray) \
+                else xy
+
+            parammat = -1./2 * blockarray([
+                [E_AT_Sigmainv_A, -E_Sigmainv_A.T],
+                [-E_Sigmainv_A, E_Sigmainv]])
+
+            contract = 'ni,ni->n' if x.ndim == 2 else 'i,i->'
+            if isinstance(xy, np.ndarray):
+                out = np.einsum('ni,ni->n', xy.dot(parammat), xy)
+            else:
+                out = np.einsum(contract,x.dot(parammat[:-D,:-D]),x)
+                out += np.einsum(contract,y.dot(parammat[-D:,-D:]),y)
+                out += 2*np.einsum(contract,x.dot(parammat[:-D,-D:]),y)
+
+            out += -D/2*np.log(2*np.pi) + 1./2*E_logdetSigmainv
+
+            if self.affine:
+                out += y.dot(E_Sigmainv_b)
+                out -= x.dot(E_AT_Sigmainv_b)
+                out -= 1./2 * E_bT_Sigmainv_b
+        else:
+            if self.affine:
+                Ey, Ex = stats[:2]
+            yyT, yxT, xxT, n = stats[-4:]
+
+            contract = 'ij,nij->n' if yyT.ndim == 3 else 'ij,ij->'
+
+            out = -1./2 * np.einsum(contract, E_AT_Sigmainv_A, xxT)
+            out += np.einsum(contract, E_Sigmainv_A, yxT)
+            out += -1./2 * np.einsum(contract, E_Sigmainv, yyT)
+            out += -D/2*np.log(2*np.pi) + n/2.*E_logdetSigmainv
+
+            if self.affine:
+                out += Ey.dot(E_Sigmainv_b)
+                out -= Ex.dot(E_AT_Sigmainv_b)
+                out -= 1./2 * E_bT_Sigmainv_b
+
+        return out
+
+    def get_vlb(self):
+        E_Sigmainv, E_Sigmainv_A, E_AT_Sigmainv_A, E_logdetSigmainv = \
+            mniw_expectedstats(*self._natural_to_standard(self.mf_natural_hypparam))
+        A, B, C, d = self.natural_hypparam - self.mf_natural_hypparam
+        bilinear_term = -1./2 * np.trace(A.dot(E_Sigmainv)) \
+            + np.trace(B.T.dot(E_Sigmainv_A)) \
+            - 1./2 * np.trace(C.dot(E_AT_Sigmainv_A)) \
+            + 1./2 * d * E_logdetSigmainv
+
+        # log normalizer term
+        Z = mniw_log_partitionfunction(*self._natural_to_standard(
+            self.natural_hypparam))
+        Z_mf = mniw_log_partitionfunction(*self._natural_to_standard(
+            self.mf_natural_hypparam))
+
+        return bilinear_term - (Z - Z_mf)
+
+    def resample_from_mf(self):
+        self.A, self.sigma = sample_mniw(
+            *self._natural_to_standard(self.mf_natural_hypparam))
+
+    def _set_params_from_mf(self):
+        nu, S, M, K = self._natural_to_standard(self.mf_natural_hypparam)
+        self.A, self.sigma = M, S / nu
+
+    def _initialize_mean_field(self):
+        if hasattr(self, 'natural_hypparam'):
+            A, Sigma = self.A, self.sigma
+            nu, S, M, K = self._natural_to_standard(self.natural_hypparam)
+            self.mf_natural_hypparam = self._standard_to_natural(
+                nu, nu*Sigma, A, K)
+
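+# A minimal usage sketch (not from the original source; hyperparameter values are
+# arbitrary): Gibbs sampling for the conjugate MNIW regression on (x, y) tuples.
+#
+#     D_in, D_out = 3, 2
+#     reg = Regression(nu_0=D_out + 2, S_0=np.eye(D_out),
+#                      M_0=np.zeros((D_out, D_in)), K_0=np.eye(D_in))
+#     x = np.random.randn(500, D_in)
+#     y = x.dot(np.random.randn(D_out, D_in).T) + 0.1 * np.random.randn(500, D_out)
+#     for _ in range(50):
+#         reg.resample(data=(x, y))      # draws (A, sigma) from the posterior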
+
+class RegressionNonconj(Regression):
+    def __init__(self, M_0, Sigma_0, nu_0, S_0,
+                 A=None, sigma=None, affine=False, niter=10):
+        self.A = A
+        self.sigma = sigma
+        self.affine = affine
+
+        self.h_0 = np.linalg.solve(Sigma_0, M_0.ravel()).reshape(M_0.shape)
+        self.J_0 = np.linalg.inv(Sigma_0)
+        self.nu_0 = nu_0
+        self.S_0 = S_0
+
+        self.niter = niter
+
+        if all_none(A,sigma):
+            self.resample()  # initialize from prior
+
+    ### Gibbs
+
+    def resample(self,data=[],niter=None):
+        niter = niter if niter else self.niter
+        if getdatasize(data) == 0:
+            self.A = sample_gaussian(J=self.J_0,h=self.h_0.ravel())\
+                .reshape(self.h_0.shape)
+            self.sigma = sample_invwishart(self.S_0,self.nu_0)
+        else:
+            yyT, yxT, xxT, n = self._get_statistics(data)
+            for itr in range(niter):
+                self._resample_A(xxT, yxT, self.sigma)
+                self._resample_sigma(xxT, yxT, yyT, n, self.A)
+
+    def _resample_A(self, xxT, yxT, sigma):
+        sigmainv = np.linalg.inv(sigma)
+        J = self.J_0 + np.kron(sigmainv, xxT)
+        h = self.h_0 + sigmainv.dot(yxT)
+        self.A = sample_gaussian(J=J,h=h.ravel()).reshape(h.shape)
+
+    def _resample_sigma(self, xxT, yxT, yyT, n, A):
+        S = self.S_0 + yyT - yxT.dot(A.T) - A.dot(yxT.T) + A.dot(xxT).dot(A.T)
+        nu = self.nu_0 + n
+        self.sigma = sample_invwishart(S, nu)
+
+
+class ARDRegression(Regression):
+    def __init__(
+            self, a,b,nu_0,S_0,M_0,
+            blocksizes=None,K_0=None,niter=10,**kwargs):
+        blocksizes = np.ones(M_0.shape[1],dtype=np.int64) \
+            if blocksizes is None else blocksizes
+        self.niter = niter
+        self.blocksizes = np.array(blocksizes)
+        self.starts = cumsum(blocksizes,strict=True)
+        self.stops = cumsum(blocksizes,strict=False)
+
+        self.a = np.repeat(a,len(blocksizes))
+        self.b = np.repeat(b,len(blocksizes))
+
+        self.nu_0 = nu_0
+        self.S_0 = S_0
+        self.M_0 = M_0
+
+        if K_0 is None:
+            self.resample_K()
+        else:
+            self.K_0 = K_0
+
+        super(ARDRegression,self).__init__(
+            K_0=self.K_0,nu_0=nu_0,S_0=S_0,M_0=M_0,**kwargs)
+
+    def resample(self,data=[],stats=None):
+        if len(data) > 0 or stats is not None:
+            stats = self._get_statistics(data) if stats is None else stats
+            for itr in range(self.niter):
+                self.A, self.sigma = \
+                    sample_mniw(*self._natural_to_standard(
+                        self.natural_hypparam + stats))
+
+                mat = self.M_0 - self.A
+                self.resample_K(1./2*np.einsum(
+                    'ij,ij->j',mat,np.linalg.solve(self.sigma,mat)))
+        else:
+            self.resample_K()
+            super(ARDRegression,self).resample()
+
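+    # ARD step: each block of columns of A shares an inverse-gamma-distributed
+    # prior variance; `diag` carries the per-column quadratic terms
+    # 1/2 * (M_0 - A)^T Sigma^{-1} (M_0 - A) used in the conjugate update of K_0.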
+    def resample_K(self,diag=None):
+        if diag is None:
+            a, b = self.a, self.b
+        else:
+            sums = [diag[start:stop].sum()
+                    for start,stop in zip(self.starts,self.stops)]
+            a = self.a + self.D_out*self.blocksizes/2.
+            b = self.b + np.array(sums)
+
+        ks = 1./np.random.gamma(a,scale=1./b)
+        self.K_0 = np.diag(np.repeat(ks,self.blocksizes))
+
+        self.natural_hypparam = self._standard_to_natural(
+            self.nu_0,self.S_0,self.M_0,self.K_0)
+
+    @property
+    def parameters(self):
+        return (self.A, self.sigma, self.K_0)
+
+    @parameters.setter
+    def parameters(self, A_sigma_K_0_tuple1):
+        (A,sigma,K_0) = A_sigma_K_0_tuple1
+        self.A = A
+        self.sigma = sigma
+        self.K_0 = K_0
+
+
+class DiagonalRegression(Regression, MeanFieldSVI):
+    """
+    Special case of the regression class in which the observations
+    have diagonal Gaussian noise and, potentially, missing data.
+    """
+
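+    # A minimal usage sketch (not from the original source; values arbitrary):
+    # fitting with partially observed outputs via a boolean mask.
+    #
+    #     reg = DiagonalRegression(D_out=5, D_in=2)
+    #     x = np.random.randn(200, 2)
+    #     y = x.dot(np.random.randn(5, 2).T) + 0.1 * np.random.randn(200, 5)
+    #     mask = np.random.rand(200, 5) < 0.8     # True where y is observed
+    #     reg.resample(np.hstack((x, y)), mask=mask)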
+    def __init__(self, D_out, D_in, mu_0=None, Sigma_0=None, alpha_0=3.0, beta_0=2.0,
+                 A=None, sigmasq=None, niter=1):
+
+        self._D_out = D_out
+        self._D_in = D_in
+        self.A = A
+        self.sigmasq_flat = sigmasq
+        self.affine = False # We do not yet support affine
+
+        mu_0 = np.zeros(D_in) if mu_0 is None else mu_0
+        Sigma_0 = np.eye(D_in) if Sigma_0 is None else Sigma_0
+        assert mu_0.shape == (D_in,)
+        assert Sigma_0.shape == (D_in, D_in)
+        self.h_0 = np.linalg.solve(Sigma_0, mu_0)
+        self.J_0 = np.linalg.inv(Sigma_0)
+        self.alpha_0 = alpha_0
+        self.beta_0 = beta_0
+
+        self.niter = niter
+
+        if any_none(A, sigmasq):
+            self.A = np.zeros((D_out, D_in))
+            self.sigmasq_flat = np.ones((D_out,))
+            self.resample(data=None)  # initialize from prior
+
+        # Store the natural parameters and expose the standard versions as properties
+        self.mf_J_A = np.array([self.J_0.copy() for _ in range(D_out)])
+        # Initializing with mean zero is pathological. Break symmetry by starting with sampled A.
+        # self.mf_h_A = np.array([self.h_0.copy() for _ in range(D_out)])
+        self.mf_h_A = np.array([Jd.dot(Ad) for Jd,Ad in zip(self.mf_J_A, self.A)])
+
+        self.mf_alpha = self.alpha_0 * np.ones(D_out)
+        self.mf_beta = self.alpha_0 * self.sigmasq_flat
+
+        # Cache the standard parameters for A as well
+        self._mf_A_cache = {}
+
+        # Store the natural hypparams.  These correspond to the suff. stats
+        # (y^2, yxT, xxT, n)
+        # self.natural_hypparam = (2 * self.beta_0, self.h_0, self.J_0, 1.0)
+
+    @property
+    def D_out(self):
+        return self._D_out
+
+    @property
+    def D_in(self):
+        return self._D_in
+
+    @property
+    def sigma(self):
+        return np.diag(self.sigmasq_flat)
+
+    @property
+    def mf_expectations(self):
+        # Look for expectations in the cache
+        if ("mf_E_A" not in self._mf_A_cache) or \
+            ("mf_E_AAT" not in self._mf_A_cache):
+            mf_Sigma_A = \
+                np.array([np.linalg.inv(Jd) for Jd in self.mf_J_A])
+
+            self._mf_A_cache["mf_E_A"] = \
+                np.array([np.dot(Sd, hd)
+                          for Sd,hd in zip(mf_Sigma_A, self.mf_h_A)])
+
+            self._mf_A_cache["mf_E_AAT"] = \
+                np.array([Sd + np.outer(md,md)
+                          for Sd,md in zip(mf_Sigma_A, self._mf_A_cache["mf_E_A"])])
+
+        mf_E_A = self._mf_A_cache["mf_E_A"]
+        mf_E_AAT = self._mf_A_cache["mf_E_AAT"]
+
+        # Set the invgamma meanfield expectation
+        from scipy.special import digamma
+        mf_E_sigmasq_inv = self.mf_alpha / self.mf_beta
+        mf_E_log_sigmasq = np.log(self.mf_beta) - digamma(self.mf_alpha)
+
+        return mf_E_A, mf_E_AAT, mf_E_sigmasq_inv, mf_E_log_sigmasq
+
+    # TODO: This is a bit ugly... Return stats in the form expected by PyLDS
+    def meanfield_expectedstats(self):
+        mf_E_A, mf_E_AAT, mf_E_sigmasq_inv, mf_E_log_sigmasq = self.mf_expectations
+        E_Sigmainv = np.diag(mf_E_sigmasq_inv)
+        E_Sigmainv_A  = mf_E_A * mf_E_sigmasq_inv[:,None]
+        E_AT_Sigmainv_A = np.sum(mf_E_sigmasq_inv[:,None,None] * mf_E_AAT, axis=0)
+        E_logdetSigmainv = -np.sum(mf_E_log_sigmasq)
+        return E_Sigmainv, E_Sigmainv_A, E_AT_Sigmainv_A, E_logdetSigmainv
+
+    def log_likelihood(self, xy, mask=None):
+        if isinstance(xy, tuple):
+            x,y = xy
+        else:
+            x,y = xy[:,:self.D_in], xy[:,self.D_in:]
+            assert y.shape[1] == self.D_out
+
+        if mask is None:
+            mask = np.ones_like(y)
+        else:
+            assert mask.shape == y.shape
+
+        sqerr = -0.5 * (y-x.dot(self.A.T))**2 * mask
+        ll = np.sum(sqerr / self.sigmasq_flat, axis=1)
+
+        # Add normalizer
+        ll += np.sum(-0.5*np.log(2*np.pi*self.sigmasq_flat) * mask, axis=1)
+
+        return ll
+
+    def _get_statistics(self, data, D_out=None, D_in=None, mask=None):
+        D_out = self.D_out if D_out is None else D_out
+        D_in = self.D_in if D_in is None else D_in
+        if data is None:
+            return (np.zeros((D_out,)),
+                    np.zeros((D_out, D_in)),
+                    np.zeros((D_out, D_in, D_in)),
+                    np.zeros((D_out,)))
+
+        # Make sure data is a list
+        if not isinstance(data, list):
+            datas = [data]
+        else:
+            datas = data
+
+        # Make sure mask is also a list if given
+        if mask is not None:
+            if not isinstance(mask, list):
+                masks = [mask]
+            else:
+                masks = mask
+        else:
+            masks = [None] * len(datas)
+
+        # Sum sufficient statistics from each dataset
+        ysq = np.zeros(D_out)
+        yxT = np.zeros((D_out, D_in))
+        xxT = np.zeros((D_out, D_in, D_in))
+        n = np.zeros(D_out)
+
+        for data, mask in zip(datas, masks):
+            # Handle tuples or hstack-ed arrays
+            if isinstance(data, tuple):
+                x, y = data
+            else:
+                x, y = data[:,:D_in], data[:, D_in:]
+            assert x.shape[1] == D_in
+            assert y.shape[1] == D_out
+
+            if mask is None:
+                mask = np.ones_like(y, dtype=bool)
+
+            ysq += np.sum(y**2 * mask, axis=0)
+            yxT += (y*mask).T.dot(x)
+            xxT += np.array([(x * mask[:,d][:,None]).T.dot(x)
+                            for d in range(D_out)])
+            n += np.sum(mask, axis=0)
+        return ysq, yxT, xxT, n
+
+    @staticmethod
+    def _stats_ensure_array(stats):
+        ysq, yxT, xxT, n = stats
+
+        if yxT.ndim != 2:
+            raise Exception("yxT.shape must be (D_out, D_in)")
+        D_out, D_in = yxT.shape
+
+        # If ysq is D_out x D_out, just take the diagonal
+        if ysq.ndim == 1:
+            assert ysq.shape == (D_out,)
+        elif ysq.ndim == 2:
+            assert ysq.shape == (D_out, D_out)
+            ysq = np.diag(ysq)
+        else:
+            raise Exception("ysq.shape must be (D_out,) or (D_out, D_out)")
+
+        # Make sure xxT is D_out x D_in x D_in
+        if xxT.ndim == 2:
+            assert xxT.shape == (D_in, D_in)
+            xxT = np.tile(xxT[None,:,:], (D_out, 1, 1))
+        elif xxT.ndim == 3:
+            assert xxT.shape == (D_out, D_in, D_in)
+        else:
+            raise Exception("xxT.shape must be (D_in, D_in) or (D_out, D_in, D_in)")
+
+        # Make sure n is of shape (D_out,)
+        if np.isscalar(n):
+            n = n * np.ones(D_out)
+        elif n.ndim == 1:
+            assert n.shape == (D_out,)
+        else:
+            raise Exception("n must be a scalar or an array of shape (D_out,)")
+
+        return objarray([ysq, yxT, xxT, n])
+
+    ### Gibbs
+    def resample(self, data, stats=None, mask=None, niter=None):
+        """
+        Introduce a mask that allows for missing data
+        """
+        stats = self._get_statistics(data, mask=mask) if stats is None else stats
+        stats = self._stats_ensure_array(stats)
+
+        niter = niter if niter else self.niter
+        for itr in range(niter):
+            self._resample_A(stats)
+            self._resample_sigma(stats)
+
+    def _resample_A(self, stats):
+
+        _, yxT, xxT, _ = stats
+
+        # Sample each row of W
+        for d in range(self.D_out):
+            # Get sufficient statistics from the data
+            Jd = self.J_0 + xxT[d] / self.sigmasq_flat[d]
+            hd = self.h_0 + yxT[d] / self.sigmasq_flat[d]
+            self.A[d] = sample_gaussian(J=Jd, h=hd)
+
+    def _resample_sigma(self, stats):
+        ysq, yxT, xxT, n = stats
+        AAT = np.array([np.outer(a,a) for a in self.A])
+
+        alpha = self.alpha_0 + n / 2.0
+
+        beta = self.beta_0
+        beta += 0.5 * ysq
+        beta += -1.0 * np.sum(yxT * self.A, axis=1)
+        beta += 0.5 * np.sum(AAT * xxT, axis=(1,2))
+
+        self.sigmasq_flat = np.reshape(sample_invgamma(alpha, beta), (self.D_out,))
+
+    ### Max likelihood
+    def max_likelihood(self,data, weights=None, stats=None, mask=None):
+        if stats is None:
+            stats = self._get_statistics(data, mask=mask)
+        stats = self._stats_ensure_array(stats)
+
+        ysq, yxT, xxT, n = stats
+
+        self.A = np.array([
+            np.linalg.solve(self.J_0 + xxTd, self.h_0 + yxTd)
+            for xxTd, yxTd in zip(xxT, yxT)
+        ])
+
+        alpha = self.alpha_0 + n / 2.0
+        beta = self.beta_0
+        beta += 0.5 * ysq
+        beta += -1.0 * np.sum(yxT * self.A, axis=1)
+        AAT = np.array([np.outer(ad, ad) for ad in self.A])
+        beta += 0.5 * np.sum(AAT * xxT, axis=(1, 2))
+
+        self.sigmasq_flat = beta / (alpha + 1.0)
+        assert np.all(self.sigmasq_flat >= 0)
+
+    ### Mean Field
+    def meanfieldupdate(self, data=None, weights=None, stats=None, mask=None):
+        assert weights is None, "Not supporting weighted data, just masked data."
+        if stats is None:
+            stats = self._get_statistics(data, mask=mask)
+        stats = self._stats_ensure_array(stats)
+
+        self._meanfieldupdate_A(stats)
+        self._meanfieldupdate_sigma(stats)
+
+        # Update A and sigmasq_flat
+        A, _, sigmasq_inv, _ = self.mf_expectations
+        self.A = A.copy()
+        self.sigmasq_flat = 1. / sigmasq_inv
+
+    def _meanfieldupdate_A(self, stats, prob=1.0, stepsize=1.0):
+        E_sigmasq_inv = self.mf_alpha / self.mf_beta
+        _, E_yxT, E_xxT, _ = stats  / prob
+
+        # Update statistics each row of A
+        for d in range(self.D_out):
+            Jd = self.J_0 + (E_xxT[d] * E_sigmasq_inv[d])
+            hd = self.h_0 + (E_yxT[d] * E_sigmasq_inv[d])
+
+            # Update the mean field natural parameters
+            self.mf_J_A[d] = update_param(self.mf_J_A[d], Jd, stepsize)
+            self.mf_h_A[d] = update_param(self.mf_h_A[d], hd, stepsize)
+
+        # Clear the cache
+        self._mf_A_cache = {}
+
+    def _meanfieldupdate_sigma(self, stats, prob=1.0, stepsize=1.0):
+        E_ysq, E_yxT, E_xxT, E_n = stats / prob
+        E_A, E_AAT, _, _ = self.mf_expectations
+
+        alpha = self.alpha_0 + E_n / 2.0
+
+        beta = self.beta_0
+        beta += 0.5 * E_ysq
+        beta += -1.0 * np.sum(E_yxT * E_A, axis=1)
+        beta += 0.5 * np.sum(E_AAT * E_xxT, axis=(1,2))
+
+        # Set the invgamma meanfield parameters
+        self.mf_alpha = update_param(self.mf_alpha, alpha, stepsize)
+        self.mf_beta = update_param(self.mf_beta, beta, stepsize)
+
+    def get_vlb(self):
+        # TODO: Implement this!
+        return 0
+
+
+    def expected_log_likelihood(self, xy=None, stats=None, mask=None):
+        if xy is not None:
+            if isinstance(xy, tuple):
+                x, y = xy
+            else:
+                x, y = xy[:, :self.D_in], xy[:, self.D_in:]
+                assert y.shape[1] == self.D_out
+
+            E_ysq = y**2
+            E_yxT = y[:,:,None] * x[:,None,:]
+            E_xxT = x[:,:,None] * x[:,None,:]
+            E_n = np.ones_like(y) if mask is None else mask
+
+        elif stats is not None:
+            E_ysq, E_yxT, E_xxT, E_n = stats
+            T = E_ysq.shape[0]
+            assert E_ysq.shape == (T,self.D_out)
+            assert E_yxT.shape == (T,self.D_out,self.D_in)
+
+            if E_xxT.shape == (T, self.D_in, self.D_in):
+                E_xxT = E_xxT[:, None, :, :]
+            else:
+                assert E_xxT.shape == (T,self.D_out,self.D_in,self.D_in)
+
+            if E_n.shape == (T,):
+                E_n = E_n[:,None]
+            else:
+                assert E_n.shape == (T,self.D_out)
+
+        E_A, E_AAT, E_sigmasq_inv, E_log_sigmasq = self.mf_expectations
+
+        sqerr = -0.5 * E_ysq
+        sqerr += 1.0 * np.sum(E_yxT * E_A, axis=2)
+        sqerr += -0.5 * np.sum(E_xxT * E_AAT, axis=(2,3))
+
+        # Compute expected log likelihood
+        ell = np.sum(sqerr * E_sigmasq_inv, axis=1)
+        ell += np.sum(-0.5 * E_n * (E_log_sigmasq + np.log(2 * np.pi)), axis=1)
+
+        return ell
+
+    def resample_from_mf(self):
+        for d in range(self.D_out):
+            self.A[d] = sample_gaussian(J=self.mf_J_A[d], h=self.mf_h_A[d])
+        self.sigmasq_flat = sample_invgamma(self.mf_alpha, self.mf_beta) * np.ones(self.D_out)
+
+    def _initialize_mean_field(self):
+        A, sigmasq = self.A, self.sigmasq_flat
+
+        # Set mean field params such that A and sigmasq are the mean
+        self.mf_alpha = self.alpha_0
+        self.mf_beta = self.alpha_0 * sigmasq
+
+        self.mf_J_A = np.array([self.J_0.copy() for _ in range(self.D_out)])
+        self.mf_h_A = np.array([Jd.dot(Ad) for Jd, Ad in zip(self.mf_J_A, A)])
+
+    ### SVI
+    def meanfield_sgdstep(self, data, weights, prob, stepsize, stats=None, mask=None):
+        assert weights is None, "Not supporting weighted datapoints (just masked data)"
+        if stats is None:
+            stats = self._get_statistics(data, mask=mask)
+        stats = self._stats_ensure_array(stats)
+
+        self._meanfieldupdate_A(stats, prob=prob, stepsize=stepsize)
+        self._meanfieldupdate_sigma(stats, prob=prob, stepsize=stepsize)
+
+
+class RobustRegression(Regression):
+    """
+    Regression with multivariate-t distributed noise.
+
+        y | x ~ t(Ax + b, \Sigma, \nu)
+
+    where \nu >= 1 is the degrees of freedom.
+
+    This is equivalent to the model,
+
+        \tau ~ Gamma(\nu/2,  \nu/2)
+        y | x, \tau ~ N(Ax + b, \Sigma / \tau)
+
+    To perform inference in this model, we will introduce
+    auxiliary variables tau (precisions).  With these, we
+    can compute sufficient statistics scaled by \tau and
+    use the standard regression object to
+    update A, b, Sigma | x, y, \tau.
+
+    The degrees of freedom parameter \nu is updated with a
+    Metropolis-Hastings random-walk step under a Gamma(alpha, beta)
+    prior (see _resample_nu). We could also experiment with updating
+    \nu under an uninformative prior, e.g. p(\nu) \propto \nu^{-2},
+    which is equivalent to a flat prior on \nu^{-1}.
+    """
+    def __init__(
+            self, nu_0=None,S_0=None, M_0=None, K_0=None, affine=False,
+            A=None, sigma=None, nu=None):
+
+        # Default to a somewhat intermediate value of nu
+        self.nu = self.default_nu = nu if nu is not None else 4.0
+
+        super(RobustRegression, self).__init__(
+            nu_0=nu_0, S_0=S_0, M_0=M_0, K_0=K_0, affine=affine, A=A, sigma=sigma)
+
+    def log_likelihood(self,xy):
+        assert isinstance(xy, (tuple, np.ndarray))
+        sigma, D, nu = self.sigma, self.D_out, self.nu
+        x, y = (xy[:,:-D], xy[:,-D:]) if isinstance(xy,np.ndarray) else xy
+
+        sigma_inv, L = inv_psd(sigma, return_chol=True)
+        r = y - self.predict(x)
+        z = sigma_inv.dot(r.T).T
+
+        out = -0.5 * (nu + D) * np.log(1.0 + (r * z).sum(1) / nu)
+        out += gammaln((nu + D) / 2.0) - gammaln(nu / 2.0) - D / 2.0 * np.log(nu) \
+            - D / 2.0 * np.log(np.pi) - np.log(np.diag(L)).sum()
+
+        return out
+
+    def rvs(self,x=None,size=1,return_xy=True):
+        A, sigma, nu, D = self.A, self.sigma, self.nu, self.D_out
+
+        if self.affine:
+            A, b = A[:,:-1], A[:,-1]
+
+        x = np.random.normal(size=(size, A.shape[1])) if x is None else x
+        N = x.shape[0]
+        mu = self.predict(x)
+
+        # Sample precisions and t-distributed residuals
+        tau = np.random.gamma(nu / 2.0, 2.0 / nu, size=(N,))
+        resid = np.random.randn(N, D).dot(np.linalg.cholesky(sigma).T)
+        resid /= np.sqrt(tau[:, None])
+
+        y = mu + resid
+        return np.hstack((x,y)) if return_xy else y
+
+    def _get_statistics(self, data):
+        raise Exception("RobustRegression needs scaled statistics.")
+
+    def _get_scaled_statistics(self, data, precisions):
+        assert isinstance(data, (list, tuple, np.ndarray))
+        if isinstance(data,list):
+            return sum((self._get_scaled_statistics(d, p) for d, p in zip(data, precisions)),
+                       self._empty_statistics())
+
+        elif isinstance(data, tuple):
+            x, y = data
+            bad = np.isnan(x).any(1) | np.isnan(y).any(1)
+            x, y = x[~bad], y[~bad]
+            precisions = precisions[~bad]
+            sqrt_prec = np.sqrt(precisions)
+            n, D = y.shape
+
+            if self.affine:
+                x = np.column_stack((x, np.ones(n)))
+
+            # Scale by the precision
+            # xs = x * sqrt_prec[:, na]
+            # ys = y * sqrt_prec[:, na]
+            xs = x * np.tile(sqrt_prec[:, None], (1, x.shape[1]))
+            ys = y * np.tile(sqrt_prec[:, None], (1, D))
+
+            xxT, yxT, yyT = xs.T.dot(xs), ys.T.dot(xs), ys.T.dot(ys)
+            return np.array([yyT, yxT, xxT, n])
+
+        else:
+            # data passed in like np.hstack((x, y))
+            # x, y = data[:,:-self.D_out], data[:,-self.D_out:]
+            # return self._get_scaled_statistics((x, y), precisions)
+            bad = np.isnan(data).any(1)
+            data = data[~bad]
+            precisions = precisions[~bad]
+            n, D = data.shape[0], self.D_out
+
+            # This tile call is suboptimal but without it we can hit issues
+            # with strided data, as in autoregressive models.
+            scaled_data = data * np.tile(precisions[:,None], (1, data.shape[1]))
+            statmat = scaled_data.T.dot(data)
+
+            xxT, yxT, yyT = \
+                statmat[:-D,:-D], statmat[-D:,:-D], statmat[-D:,-D:]
+
+            if self.affine:
+                xy = scaled_data.sum(0)
+                x, y = xy[:-D], xy[-D:]
+                xxT = blockarray([[xxT,     x[:,na]],
+                                  [x[na,:], np.atleast_2d(precisions.sum())]])
+                yxT = np.hstack((yxT, y[:,na]))
+
+            return np.array([yyT, yxT, xxT, n])
+
+    def resample(self, data=[], stats=None):
+        assert stats is None, \
+            "We only support RobustRegression.resample() with data, not stats."
+
+        # First sample auxiliary variables for each data point
+        tau = self._resample_precision(data)
+
+        # Compute statistics, scaling by tau, and resample as in standard Regression
+        stats = self._get_scaled_statistics(data, tau)
+        super(RobustRegression, self).resample(stats=stats)
+
+        # Resample degrees of freedom \nu
+        self._resample_nu(tau)
+
+    def _resample_precision(self, data):
+        assert isinstance(data, (list, tuple, np.ndarray))
+        if isinstance(data, list):
+            return [self._resample_precision(d) for d in data]
+
+        elif isinstance(data, tuple):
+            x, y = data
+
+        else:
+            x, y = data[:, :-self.D_out], data[:, -self.D_out:]
+
+        assert x.ndim == y.ndim == 2
+        assert x.shape[0] == y.shape[0]
+        assert x.shape[1] == (self.D_in - 1 if self.affine else self.D_in)
+        assert y.shape[1] == self.D_out
+        N = x.shape[0]
+
+        # Weed out the nan's
+        bad = np.any(np.isnan(x), axis=1) | np.any(np.isnan(y), axis=1)
+
+        # Compute posterior params of gamma distribution
+        a_post = self.nu / 2.0 + self.D_out / 2.0
+
+        r = y - self.predict(x)
+        sigma_inv = inv_psd(self.sigma)
+        z = sigma_inv.dot(r.T).T
+        b_post = self.nu / 2.0 + (r * z).sum(1) / 2.0
+
+        assert np.isscalar(a_post) and b_post.shape == (N,)
+        tau = np.nan * np.ones(N)
+        tau[~bad] = np.random.gamma(a_post, 1./b_post[~bad])
+
+        return tau
+        
+    def _resample_nu(self, tau, N_steps=100, prop_std=0.1, alpha=1, beta=1):
+        """
+        Update the degree of freedom parameter with 
+        Metropolis-Hastings. Assume a prior nu ~ Ga(alpha, beta) 
+        and use a random-walk proposal nu' ~ N(nu, prop_std^2). Proposals
+        at or below a small positive threshold are rejected outright.
+        """
+        # Convert tau to a list of arrays
+        taus = [tau] if isinstance(tau, np.ndarray) else tau
+
+        N = 0
+        E_tau = 0
+        E_logtau = 0
+        for tau in taus:
+            bad = ~np.isfinite(tau)
+            N += np.sum(~bad)
+            E_tau += np.sum(tau[~bad])
+            E_logtau += np.sum(np.log(tau[~bad]))
+
+        if N > 0:
+            E_tau /= N
+            E_logtau /= N
+        
+        # Compute the log prior, likelihood, and posterior
+        lprior = lambda nu: (alpha - 1) * np.log(nu) - beta * nu  # log Ga(alpha, beta) prior
+        ll = lambda nu: N * (nu/2 * np.log(nu/2)  - gammaln(nu/2) + (nu/2 - 1) * E_logtau - nu/2 * E_tau)
+        lp = lambda nu: ll(nu) + lprior(nu)
+
+        lp_curr = lp(self.nu)
+        for step in range(N_steps):
+            # Symmetric proposal
+            nu_new = self.nu + prop_std * np.random.randn()
+            if nu_new < 1e-3:
+                # Reject if too small
+                continue
+
+            # Accept / reject based on likelihoods
+            lp_new = lp(nu_new)
+            if np.log(np.random.rand()) < lp_new - lp_curr:
+                self.nu = nu_new
+                lp_curr = lp_new
+        
+    # Not supporting MLE or mean field for now
+    def max_likelihood(self,data,weights=None,stats=None):
+        raise NotImplementedError
+
+    def meanfieldupdate(self, data=None, weights=None, stats=None):
+        raise NotImplementedError
+
+    def meanfield_sgdstep(self, data, weights, prob, stepsize, stats=None):
+        raise NotImplementedError
+
+    def meanfield_expectedstats(self):
+        raise NotImplementedError
+
+    def expected_log_likelihood(self, xy=None, stats=None):
+        raise NotImplementedError
+
+    def get_vlb(self):
+        raise NotImplementedError
+
+    def resample_from_mf(self):
+        raise NotImplementedError
+
+    def _set_params_from_mf(self):
+        raise NotImplementedError
+
+    def _initialize_mean_field(self):
+        pass
+
+
+class _ARMixin(object):
+    @property
+    def nlags(self):
+        if not self.affine:
+            return self.D_in // self.D_out
+        else:
+            return (self.D_in - 1) // self.D_out
+
+    @property
+    def D(self):
+        return self.D_out
+
+    def predict(self, x):
+        return super(_ARMixin,self).predict(np.atleast_2d(x))
+
+    def rvs(self,lagged_data):
+        return super(_ARMixin,self).rvs(
+                x=np.atleast_2d(lagged_data.ravel()),return_xy=False)
+
+    def _get_statistics(self,data):
+        return super(_ARMixin,self)._get_statistics(
+                data=self._ensure_strided(data))
+
+    def _get_weighted_statistics(self,data,weights):
+        return super(_ARMixin,self)._get_weighted_statistics(
+                data=self._ensure_strided(data),weights=weights)
+
+    def log_likelihood(self,xy):
+        return super(_ARMixin,self).log_likelihood(self._ensure_strided(xy))
+
+    def _ensure_strided(self,data):
+        if isinstance(data,np.ndarray):
+            if data.shape[1] != self.D*(self.nlags+1):
+                data = AR_striding(data,self.nlags)
+            return data
+        else:
+            return [self._ensure_strided(d) for d in data]
+
+
+class AutoRegression(_ARMixin,Regression):
+    pass
+
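+# A minimal usage sketch (not from the original source; values arbitrary): an
+# order-2 autoregression over D=2 dimensional observations. _ARMixin strides a
+# raw (T, D) sequence into rows of D*(nlags+1) columns via AR_striding before
+# computing sufficient statistics, so raw sequences can be passed directly.
+#
+#     D, nlags = 2, 2
+#     ar = AutoRegression(nu_0=D + 2, S_0=np.eye(D),
+#                         M_0=np.zeros((D, D * nlags)), K_0=np.eye(D * nlags))
+#     data = np.random.randn(1000, D)
+#     ar.resample(data=data)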
+
+class ARDAutoRegression(_ARMixin,ARDRegression):
+    def __init__(self,M_0,**kwargs):
+        blocksizes = [M_0.shape[0]]*(M_0.shape[1] // M_0.shape[0]) \
+                + ([1] if M_0.shape[1] % M_0.shape[0] and M_0.shape[0] != 1 else [])
+        super(ARDAutoRegression,self).__init__(
+                M_0=M_0,blocksizes=blocksizes,**kwargs)
+
+
+class RobustAutoRegression(_ARMixin, RobustRegression):
+    pass
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/uniform.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/uniform.py
new file mode 100644
index 0000000000000000000000000000000000000000..9104120c1b02fa61c3dc49f53494e5658fa5ce56
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/distributions/uniform.py
@@ -0,0 +1,137 @@
+from __future__ import division
+from builtins import map
+from builtins import range
+__all__ = ['UniformOneSided', 'Uniform']
+
+import numpy as np
+
+from pybasicbayes.abstractions import GibbsSampling
+from pybasicbayes.util.stats import sample_pareto
+from pybasicbayes.util.general import any_none
+
+
+class UniformOneSided(GibbsSampling):
+    '''
+    Models a uniform distribution over [low,high] for a parameter high.
+    Low is a fixed hyperparameter (hence "OneSided"). See the Uniform class for
+    the two-sided version.
+
+    Likelihood is x ~ U[low,high]
+    Prior is high ~ Pareto(x_m,alpha) following Wikipedia's notation
+
+    Hyperparameters:
+        x_m, alpha, low
+
+    Parameters:
+        high
+    '''
+    def __init__(self,high=None,x_m=None,alpha=None,low=0.):
+        self.high = high
+
+        self.x_m = x_m
+        self.alpha = alpha
+        self.low = low
+
+        have_hypers = x_m is not None and alpha is not None
+        if high is None and have_hypers:
+            self.resample()  # initialize from prior
+
+    @property
+    def params(self):
+        return {'high':self.high}
+
+    @property
+    def hypparams(self):
+        return dict(x_m=self.x_m,alpha=self.alpha,low=self.low)
+
+    def log_likelihood(self,x):
+        x = np.atleast_1d(x)
+        raw = np.where(
+            (self.low <= x) & (x < self.high),
+            -np.log(self.high - self.low),-np.inf)
+        return raw if isinstance(x,np.ndarray) else raw[0]
+
+    def rvs(self,size=[]):
+        return np.random.uniform(low=self.low,high=self.high,size=size)
+
+    def resample(self,data=[]):
+        self.high = sample_pareto(
+            *self._posterior_hypparams(*self._get_statistics(data)))
+        return self
+
+    def _get_statistics(self,data):
+        if isinstance(data,np.ndarray):
+            n = data.shape[0]
+            datamax = data.max()
+        else:
+            n = sum(d.shape[0] for d in data)
+            datamax = \
+                max(d.max() for d in data) if n > 0 else -np.inf
+        return n, datamax
+
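+    # Conjugate update: with a Pareto(x_m, alpha) prior on `high` and n uniform
+    # observations whose maximum is `datamax`, the posterior over `high` is
+    # Pareto(max(datamax, x_m), alpha + n).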
+    def _posterior_hypparams(self,n,datamax):
+        return max(datamax,self.x_m), n + self.alpha
+
+
+class Uniform(UniformOneSided):
+    '''
+    Models a uniform distribution over [low,high] for parameters low and high.
+    The prior is non-conjugate (though it's conditionally conjugate over one
+    parameter at a time).
+
+    Likelihood is x ~ U[low,high]
+    Prior is -low ~ Pareto(x_m_low,alpha_low)-2*x_m_low
+             high ~ Pareto(x_m_high,alpha_high)
+
+    Hyperparameters:
+        x_m_low, alpha_low
+        x_m_high, alpha_high
+
+    Parameters:
+        low, high
+    '''
+    def __init__(
+            self,low=None,high=None,
+            x_m_low=None,alpha_low=None,x_m_high=None,alpha_high=None):
+        self.low = low
+        self.high = high
+
+        self.x_m_low = x_m_low
+        self.alpha_low = alpha_low
+        self.x_m_high = x_m_high
+        self.alpha_high = alpha_high
+
+        have_hypers = not any_none(x_m_low,alpha_low,x_m_high,alpha_high)
+        if low is high is None and have_hypers:
+            self.resample()  # initialize from prior
+
+    @property
+    def params(self):
+        return dict(low=self.low,high=self.high)
+
+    @property
+    def hypparams(self):
+        return dict(
+            x_m_low=self.x_m_low,alpha_low=self.alpha_low,
+            x_m_high=self.x_m_high,alpha_high=self.alpha_high)
+
+    def resample(self,data=[],niter=5):
+        if len(data) == 0:
+            self.low = -sample_pareto(-self.x_m_low,self.alpha_low)
+            self.high = sample_pareto(self.x_m_high,self.alpha_high)
+        else:
+            for itr in range(niter):
+                # resample high, fixing low
+                self.x_m, self.alpha = self.x_m_high, self.alpha_high
+                super(Uniform,self).resample(data)
+                # tricky: flip data and resample 'high' again
+                self.x_m, self.alpha = -self.x_m_low, self.alpha_low
+                self.low, self.high = self.high, self.low
+                super(Uniform,self).resample(self._flip_data(data))
+                self.low, self.high = self.x_m_low - self.high, self.low
+
+    def _flip_data(self,data):
+        if isinstance(data,np.ndarray):
+            return self.x_m_low - data
+        else:
+            return list(map(self._flip_data,data))
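+
+
+# A minimal usage sketch (not part of the original library; hyperparameter values
+# are arbitrary): infer the upper bound of a uniform with a Pareto prior on `high`.
+if __name__ == '__main__':
+    d = UniformOneSided(x_m=1., alpha=1., low=0.)
+    data = np.random.uniform(0., 7., size=1000)
+    highs = [d.resample(data).high for _ in range(50)]
+    print('posterior mean of high: %.3f (data max %.3f)'
+          % (np.mean(highs), data.max()))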
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/__init__.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..838c912bb3d21af7d0c1d159edadb92044bb8250
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/__init__.py
@@ -0,0 +1,2 @@
+from .mixture import Labels, CRPLabels, Mixture, MixtureDistribution, CollapsedMixture, CRPMixture
+from .factor_analysis import FactorAnalysis
\ No newline at end of file
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/factor_analysis.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/factor_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8eba488ea5a7a3b921357da1d93d5f35be1f6e
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/factor_analysis.py
@@ -0,0 +1,336 @@
+"""
+Probabilistic factor analysis to perform dimensionality reduction on mouse images.
+With the probabilistic approach, we can handle missing data in the images.
+Technically this holds for data that is missing at random, but we can try it
+out on images where we treat cable pixels as missing, even though they
+won't be missing at random. This should give us a model-based way to fill in
+pixels, and hopefully a more robust way to estimate principal components for modeling.
+"""
+import abc
+import numpy as np
+
+from pybasicbayes.abstractions import Model, \
+    ModelGibbsSampling, ModelMeanField, ModelMeanFieldSVI, ModelEM
+from pybasicbayes.util.stats import sample_gaussian
+from pybasicbayes.util.general import objarray
+
+from pybasicbayes.distributions import DiagonalRegression
+
+from pybasicbayes.util.profiling import line_profiled
+PROFILING = True
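+
+# A minimal usage sketch (an assumption, not shown in this section: the
+# FactorAnalysis model defined later in this module follows the usual
+# pybasicbayes ModelGibbsSampling interface with add_data() and resample_model()):
+#
+#     fa = FactorAnalysis(D_obs=100, D_latent=10)
+#     fa.add_data(X, mask=mask)          # mask marks observed pixels
+#     for _ in range(100):
+#         fa.resample_model()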
+
+class FactorAnalysisStates(object):
+    """
+    Wrapper for the latent states of a factor analysis model
+    """
+    def __init__(self, model, data, mask=None, **kwargs):
+        self.model = model
+        self.X = data
+        if mask is None:
+            mask = np.ones_like(data, dtype=bool)
+        self.mask = mask
+        assert data.shape == mask.shape and mask.dtype == bool
+        assert self.X.shape[1] == self.D_obs
+
+        # Initialize latent states
+        self.N = self.X.shape[0]
+        self.Z = np.random.randn(self.N, self.D_latent)
+
+    @property
+    def D_obs(self):
+        return self.model.D_obs
+
+    @property
+    def D_latent(self):
+        return self.model.D_latent
+
+    @property
+    def W(self):
+        return self.model.W
+
+    @property
+    def mean(self):
+        return self.model.mean
+
+    @property
+    def sigmasq(self):
+        return self.model.sigmasq
+
+    @property
+    def regression(self):
+        return self.model.regression
+
+    def log_likelihood(self):
+        # mu = np.dot(self.Z, self.W.T)
+        # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
+
+        # Compute the marginal likelihood, integrating out z
+        mu_x = self.mean
+        Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
+
+        from scipy.stats import multivariate_normal
+        if not np.all(self.mask):
+            # Find the patterns of missing data
+            missing_patterns = np.unique(self.mask, axis=0)
+
+            # Evaluate the likelihood for each missing pattern
+            lls = np.zeros(self.N)
+            for pat in missing_patterns:
+                inds = np.all(self.mask == pat, axis=1)
+                lls[inds] = \
+                    multivariate_normal(mu_x[pat], Sigma_x[np.ix_(pat, pat)])\
+                    .logpdf(self.X[np.ix_(inds, pat)])
+
+        else:
+            lls = multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
+
+        return lls
+
+    ## Gibbs
+    def resample(self):
+        W, sigmasq = self.W, self.sigmasq
+        J0 = np.eye(self.D_latent)
+        h0 = np.zeros(self.D_latent)
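+        # The conditional posterior of each z_n is Gaussian in information form:
+        #   J_post = I + W^T diag(mask_n / sigmasq) W
+        #   h_post = W^T diag(mask_n / sigmasq) (x_n - mean)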
+
+        # Sample each latent embedding
+        for n in range(self.N):
+            Jobs = self.mask[n] / sigmasq
+            Jpost = J0 + (W * Jobs[:, None]).T.dot(W)
+            hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(W)
+            self.Z[n] = sample_gaussian(J=Jpost, h=hpost)
+
+    ## Mean field
+    def E_step(self):
+        W = self.W
+        WWT = np.array([np.outer(wd,wd) for wd in W])
+        sigmasq_inv = 1./self.sigmasq
+        self._meanfieldupdate(W, WWT, sigmasq_inv)
+
+        # Copy over the expected states to Z
+        self.Z = self.E_Z
+
+    def meanfieldupdate(self):
+        E_W, E_WWT, E_sigmasq_inv, _ = self.regression.mf_expectations
+        self._meanfieldupdate(E_W, E_WWT, E_sigmasq_inv)
+
+        # Copy over the expected states to Z
+        self.Z = self.E_Z
+
+    def _meanfieldupdate(self, E_W, E_WWT, E_sigmasq_inv):
+        N, D_obs, D_lat = self.N, self.D_obs, self.D_latent
+        E_WWT_vec = E_WWT.reshape(D_obs, -1)
+
+        J0 = np.eye(D_lat)
+        h0 = np.zeros(D_lat)
+
+        # Get expectations for the latent embedding of these datapoints
+        self.E_Z = np.zeros((N, D_lat))
+        self.E_ZZT = np.zeros((N, D_lat, D_lat))
+
+        for n in range(N):
+            Jobs = self.mask[n] * E_sigmasq_inv
+            # Faster than Jpost = J0 + np.sum(E_WWT * Jobs[:,None,None], axis=0)
+            Jpost = J0 + (np.dot(Jobs, E_WWT_vec)).reshape((D_lat, D_lat))
+            hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(E_W)
+
+            # Get the expectations for this set of indices
+            Sigma_post = np.linalg.inv(Jpost)
+            self.E_Z[n] = Sigma_post.dot(hpost)
+            self.E_ZZT[n] = Sigma_post + np.outer(self.E_Z[n], self.E_Z[n])
+
+        self._set_expected_stats()
+
+    def _set_expected_stats(self):
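+        # Sufficient statistics for the DiagonalRegression update, one entry per
+        # observed dimension d: (sum_n x_nd^2, sum_n x_nd E[z_n]^T, sum_n E[z_n z_n^T], n_d),
+        # with x centered by the model mean and masked (missing) entries left out of each sum.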
+        D_lat = self.D_latent
+        Xc = self.X - self.mean
+        E_Xsq = np.sum(Xc**2 * self.mask, axis=0)
+        E_XZT = (Xc * self.mask).T.dot(self.E_Z)
+        E_ZZT_vec = self.E_ZZT.reshape((self.E_ZZT.shape[0], D_lat ** 2))
+        E_ZZT = np.array([np.dot(self.mask[:, d], E_ZZT_vec).reshape((D_lat, D_lat))
+                          for d in range(self.D_obs)])
+        n = np.sum(self.mask, axis=0)
+
+        self.E_emission_stats = objarray([E_Xsq, E_XZT, E_ZZT, n])
+
+    def resample_from_mf(self):
+        for n in range(self.N):
+            mu_n = self.E_Z[n]
+            Sigma_n = self.E_ZZT[n] - np.outer(mu_n, mu_n)
+            self.Z[n] = sample_gaussian(mu=mu_n, Sigma=Sigma_n)
+
+    def expected_log_likelihood(self):
+        E_W, E_WWT, E_sigmasq_inv, E_log_sigmasq = self.regression.mf_expectations
+        E_Xsq, E_XZT, E_ZZT, n = self.E_emission_stats
+
+        ll = -0.5 * np.log(2 * np.pi) - 0.5 * np.sum(E_log_sigmasq * self.mask)
+        ll += -0.5 * np.sum(E_Xsq * E_sigmasq_inv)
+        ll += -0.5 * np.sum(-2 * E_XZT * E_W * E_sigmasq_inv[:,None])
+        ll += -0.5 * np.sum(E_WWT * E_ZZT * E_sigmasq_inv[:,None,None])
+        return ll
+
+
+class _FactorAnalysisBase(Model):
+    __metaclass__ = abc.ABCMeta
+    _states_class = FactorAnalysisStates
+
+    def __init__(self, D_obs, D_latent,
+                 W=None, sigmasq=None,
+                 sigmasq_W_0=1.0, mu_W_0=0.0,
+                 alpha_0=3.0, beta_0=2.0):
+
+        self.D_obs, self.D_latent = D_obs, D_latent
+
+        # The weights and variances are encapsulated in a DiagonalRegression class
+        self.regression = \
+            DiagonalRegression(
+                self.D_obs, self.D_latent,
+                mu_0=mu_W_0 * np.ones(self.D_latent),
+                Sigma_0=sigmasq_W_0 * np.eye(self.D_latent),
+                alpha_0=alpha_0, beta_0=beta_0,
+                A=W, sigmasq=sigmasq)
+
+        # Handle the mean separately since DiagonalRegression doesn't support affine :-/
+        self.mean = np.zeros(D_obs)
+
+        self.data_list = []
+
+    @property
+    def W(self):
+        return self.regression.A
+
+    @property
+    def sigmasq(self):
+        return self.regression.sigmasq_flat
+
+    def set_empirical_mean(self):
+        self.mean = np.zeros(self.D_obs)
+        for n in range(self.D_obs):
+            self.mean[n] = np.concatenate([d.X[d.mask[:,n] == 1, n] for d in self.data_list]).mean()
+
+    def add_data(self, data, mask=None, **kwargs):
+        self.data_list.append(self._states_class(self, data, mask=mask, **kwargs))
+        return self.data_list[-1]
+
+    def generate(self, keep=True, N=1, mask=None, **kwargs):
+        # Sample from the factor analysis model
+        W, sigmasq = self.W, self.sigmasq
+        Z = np.random.randn(N, self.D_latent)
+        X = self.mean + np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs)
+
+        data = self._states_class(self, X, mask=mask, **kwargs)
+        data.Z = Z
+        if keep:
+            self.data_list.append(data)
+        return data.X, data.Z
+
+    def _log_likelihoods(self, x, mask=None, **kwargs):
+        self.add_data(x, mask=mask, **kwargs)
+        states = self.data_list.pop()
+        return states.log_likelihood()
+
+    def log_likelihood(self):
+        return sum([d.log_likelihood().sum() for d in self.data_list])
+
+    def log_probability(self):
+        lp = 0
+
+        # Prior
+        # lp += (-self.alpha_0-1) * np.log(self.sigmasq) - self.beta_0 / self.sigmasq
+        lp += -0.5 * np.sum(self.W**2)
+        lp += -0.5 * sum(np.sum(d.Z**2) for d in self.data_list)
+        lp += self.log_likelihood()
+        return lp
+
+
+class _FactorAnalysisGibbs(_FactorAnalysisBase, ModelGibbsSampling):
+    __metaclass__ = abc.ABCMeta
+
+    def resample_model(self):
+        for data in self.data_list:
+            data.resample()
+
+        Zs = np.vstack([d.Z for d in self.data_list])
+        Xs = np.vstack([d.X for d in self.data_list])
+        mask = np.vstack([d.mask for d in self.data_list])
+        self.regression.resample((Zs, Xs), mask=mask)
+
+
+class _FactorAnalysisEM(_FactorAnalysisBase, ModelEM):
+
+    def _null_stats(self):
+        return objarray(
+            [np.zeros(self.D_obs),
+             np.zeros((self.D_obs, self.D_latent)),
+             np.zeros((self.D_obs, self.D_latent, self.D_latent)),
+             np.zeros(self.D_obs)])
+
+    def EM_step(self):
+        for data in self.data_list:
+            data.E_step()
+
+        stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list])
+        self.regression.max_likelihood(data=None, weights=None, stats=stats)
+        assert np.all(np.isfinite(self.sigmasq))
+
+
+class _FactorAnalysisMeanField(_FactorAnalysisBase, ModelMeanField, ModelMeanFieldSVI):
+    __metaclass__ = abc.ABCMeta
+
+    def _null_stats(self):
+        return objarray(
+            [np.zeros(self.D_obs),
+             np.zeros((self.D_obs, self.D_latent)),
+             np.zeros((self.D_obs, self.D_latent, self.D_latent)),
+             np.zeros(self.D_obs)])
+
+    def meanfield_coordinate_descent_step(self):
+        for data in self.data_list:
+            data.meanfieldupdate()
+
+        stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list])
+        self.regression.meanfieldupdate(stats=stats)
+
+    def meanfield_sgdstep(self, minibatch, prob, stepsize, masks=None):
+        assert stepsize > 0 and stepsize <= 1
+
+        states_list = self._get_mb_states_list(minibatch, masks)
+        for s in states_list:
+            s.meanfieldupdate()
+
+        # Compute the sufficient statistics of the latent parameters
+        self.regression.meanfield_sgdstep(
+            data=None, weights=None, prob=prob, stepsize=stepsize,
+            stats=(sum(s.E_emission_stats for s in states_list)))
+
+        # Compute the expected log likelihood for this minibatch
+        return sum([s.expected_log_likelihood() for s in states_list])
+
+    def _get_mb_states_list(self, minibatch, masks):
+        minibatch = minibatch if isinstance(minibatch, list) else [minibatch]
+        masks = [None] * len(minibatch) if masks is None else \
+            (masks if isinstance(masks, list) else [masks])
+
+        def get_states(data, mask):
+            self.add_data(data, mask=mask)
+            return self.data_list.pop()
+
+        return [get_states(data, mask) for data, mask in zip(minibatch, masks)]
+
+    def resample_from_mf(self):
+        for data in self.data_list:
+            data.resample_from_mf()
+        self.regression.resample_from_mf()
+
+    def expected_log_likelihood(self):
+        ell = 0
+        for data in self.data_list:
+            ell += data.expected_log_likelihood()
+        return ell
+
+    def initialize_meanfield(self):
+        self.regression._initialize_mean_field()
+
+
+class FactorAnalysis(_FactorAnalysisGibbs, _FactorAnalysisEM, _FactorAnalysisMeanField):
+    pass
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/mixture.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/mixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..55292c644336f0dcb8c3185f9a0a03bb1d6c5730
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/mixture.py
@@ -0,0 +1,841 @@
+from __future__ import division
+from __future__ import absolute_import
+from builtins import zip
+from builtins import range
+from builtins import object
+import numpy as np
+import numpy.ma as ma  # used by CRPLabels for masked label arrays
+from functools import reduce
+from future.utils import with_metaclass
+na = np.newaxis
+import scipy.special as special
+import abc, copy
+from warnings import warn
+from scipy.special import logsumexp
+
+from pybasicbayes.abstractions import ModelGibbsSampling, ModelMeanField, ModelEM
+from pybasicbayes.abstractions import Distribution, GibbsSampling, MeanField, Collapsed, \
+        MeanFieldSVI, MaxLikelihood, ModelParallelTempering
+from pybasicbayes.distributions import Categorical, CategoricalAndConcentration
+from pybasicbayes.util.stats import getdatasize, sample_discrete_from_log, sample_discrete
+
+
+#############################
+#  internal labels classes  #
+#############################
+
+class Labels(object):
+    def __init__(self,model,data=None,N=None,z=None,
+            initialize_from_prior=True):
+        assert data is not None or (N is not None and z is None)
+
+        self.model = model
+
+        if data is None:
+            self._generate(N)
+        else:
+            self.data = data
+
+            if z is not None:
+                self.z = z
+            elif initialize_from_prior:
+                self._generate(len(data))
+            else:
+                self.resample()
+
+    def _generate(self,N):
+        self.z = self.weights.rvs(N)
+
+    @property
+    def N(self):
+        return len(self.z)
+
+    @property
+    def components(self):
+        return self.model.components
+
+    @property
+    def weights(self):
+        return self.model.weights
+
+    def log_likelihood(self):
+        if not hasattr(self,'_normalizer') or self._normalizer is None:
+            scores = self._compute_scores()
+            self._normalizer = logsumexp(scores[~np.isnan(self.data).any(1)],axis=1).sum()
+        return self._normalizer
+
+    def _compute_scores(self):
+        data, K = self.data, len(self.components)
+        scores = np.empty((data.shape[0],K))
+        for idx, c in enumerate(self.components):
+            scores[:,idx] = c.log_likelihood(data)
+        scores += self.weights.log_likelihood(np.arange(K))
+        scores[np.isnan(data).any(1)] = 0. # missing data
+        return scores
+
+    def clear_caches(self):
+        self._normalizer = None
+
+    ### Gibbs sampling
+
+    def resample(self):
+        scores = self._compute_scores()
+        self.z, lognorms = sample_discrete_from_log(scores,axis=1,return_lognorms=True)
+        self._normalizer = lognorms[~np.isnan(self.data).any(1)].sum()
+
+    def copy_sample(self):
+        new = copy.copy(self)
+        new.z = self.z.copy()
+        return new
+
+    ### Mean Field
+
+    def meanfieldupdate(self):
+        data, N, K = self.data, self.data.shape[0], len(self.components)
+
+        # update, see Eq. 10.67 in Bishop
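+        # i.e. log r_nk = E_q[log pi_k] + E_q[log p(x_n | theta_k)] + const,
+        # normalized so that each row of r sums to one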
+        component_scores = np.empty((N,K))
+        for idx, c in enumerate(self.components):
+            component_scores[:,idx] = c.expected_log_likelihood(data)
+        component_scores = np.nan_to_num(component_scores)
+
+        logpitilde = self.weights.expected_log_likelihood(np.arange(len(self.components)))
+        logr = logpitilde + component_scores
+
+        self.r = np.exp(logr - logr.max(1)[:,na])
+        self.r /= self.r.sum(1)[:,na]
+
+        # for plotting
+        self.z = self.r.argmax(1)
+
+    def get_vlb(self):
+        # return avg energy plus entropy, our contribution to the mean field
+        # variational lower bound
+        errs = np.seterr(invalid='ignore',divide='ignore')
+        prod = self.r*np.log(self.r)
+        prod[np.isnan(prod)] = 0. # 0 * -inf = 0.
+        np.seterr(**errs)
+
+        logpitilde = self.weights.expected_log_likelihood(np.arange(len(self.components)))
+
+        q_entropy = -prod.sum()
+        p_avgengy = (self.r*logpitilde).sum()
+
+        return p_avgengy + q_entropy
+
+    ### EM
+
+    def E_step(self):
+        data, N, K = self.data, self.data.shape[0], len(self.components)
+
+        self.expectations = np.empty((N,K))
+        for idx, c in enumerate(self.components):
+            self.expectations[:,idx] = c.log_likelihood(data)
+        self.expectations = np.nan_to_num(self.expectations)
+
+        self.expectations += self.weights.log_likelihood(np.arange(K))
+
+        self.expectations -= self.expectations.max(1)[:,na]
+        np.exp(self.expectations,out=self.expectations)
+        self.expectations /= self.expectations.sum(1)[:,na]
+
+        self.z = self.expectations.argmax(1)
+
+
+class CRPLabels(object):
+    def __init__(self,model,alpha_0,obs_distn,data=None,N=None):
+        assert (data is not None) ^ (N is not None)
+        self.alpha_0 = alpha_0
+        self.obs_distn = obs_distn
+        self.model = model
+
+        if data is None:
+            # generating
+            self._generate(N)
+        else:
+            self.data = data
+            self._generate(data.shape[0])
+            self.resample() # one resampling step
+
+    def _generate(self,N):
+        # run a CRP forwards
+        alpha_0 = self.alpha_0
+        self.z = np.zeros(N,dtype=np.int32)
+        for n in range(N):
+            self.z[n] = sample_discrete(np.concatenate((np.bincount(self.z[:n]),(alpha_0,))))
+
+    def resample(self):
+        al, o = np.log(self.alpha_0), self.obs_distn
+        self.z = ma.masked_array(self.z,mask=np.zeros(self.z.shape))
+        model = self.model
+
+        for n in np.random.permutation(self.data.shape[0]):
+            # mask out n
+            self.z.mask[n] = True
+
+            # form the scores and sample them
+            ks = list(model._get_occupied())
+            scores = np.array([
+                np.log(model._get_counts(k))+ o.log_predictive(self.data[n],model._get_data_withlabel(k)) \
+                        for k in ks] + [al + o.log_marginal_likelihood(self.data[n])])
+
+            # sample; note: the mask gets fixed by assigning into the array
+            idx = sample_discrete_from_log(scores)
+            if idx == scores.shape[0]-1:
+                self.z[n] = self._new_label(ks)
+            else:
+                self.z[n] = ks[idx]
+
+    def _new_label(self,ks):
+        # return a label that isn't already used...
+        newlabel = np.random.randint(low=0,high=5*max(ks))
+        while newlabel in ks:
+            newlabel = np.random.randint(low=0,high=5*max(ks))
+        return newlabel
+
+
+    def _get_counts(self,k):
+        return np.sum(self.z == k)
+
+    def _get_data_withlabel(self,k):
+        return self.data[self.z == k]
+
+    def _get_occupied(self):
+        if ma.is_masked(self.z):
+            return set(self.z[~self.z.mask])
+        else:
+            return set(self.z)
+
+
+###################
+#  model classes  #
+###################
+
+class Mixture(ModelGibbsSampling, ModelMeanField, ModelEM, ModelParallelTempering):
+    '''
+    This class is for mixtures of other distributions.
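+
+    A rough usage sketch (the Gaussian hyperparameters and the data shape are
+    illustrative assumptions, not requirements of this class):
+
+        import numpy as np
+        from pybasicbayes.distributions import Gaussian
+        from pybasicbayes.models import Mixture
+
+        data = np.random.randn(1000, 2)
+        components = [Gaussian(mu_0=np.zeros(2), sigma_0=np.eye(2),
+                               kappa_0=0.05, nu_0=5) for _ in range(4)]
+        model = Mixture(components=components, alpha_0=4.0)
+        model.add_data(data)
+        for _ in range(100):
+            model.resample_model()  # or meanfield_coordinate_descent_step() / EM_step()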
+    '''
+    _labels_class = Labels
+
+    def __init__(self,components,alpha_0=None,a_0=None,b_0=None,weights=None,weights_obj=None):
+        assert len(components) > 0
+        assert (alpha_0 is not None) ^ (a_0 is not None and b_0 is not None) \
+                ^ (weights_obj is not None)
+
+        self.components = components
+
+        if alpha_0 is not None:
+            self.weights = Categorical(alpha_0=alpha_0,K=len(components),weights=weights)
+        elif weights_obj is not None:
+            self.weights = weights_obj
+        else:
+            self.weights = CategoricalAndConcentration(
+                    a_0=a_0,b_0=b_0,K=len(components),weights=weights)
+
+        self.labels_list = []
+
+    def add_data(self,data,**kwargs):
+        self.labels_list.append(self._labels_class(data=np.asarray(data),model=self,**kwargs))
+        return self.labels_list[-1]
+
+    @property
+    def N(self):
+        return len(self.components)
+
+    def generate(self,N,keep=True):
+        templabels = self._labels_class(model=self,N=N)
+
+        out = np.empty(self.components[0].rvs(N).shape)
+        counts = np.bincount(templabels.z,minlength=self.N)
+        for idx,(c,count) in enumerate(zip(self.components,counts)):
+            out[templabels.z == idx,...] = c.rvs(count)
+
+        perm = np.random.permutation(N)
+        out = out[perm]
+        templabels.z = templabels.z[perm]
+
+        if keep:
+            templabels.data = out
+            self.labels_list.append(templabels)
+
+        return out, templabels.z
+
+    def _clear_caches(self):
+        for l in self.labels_list:
+            l.clear_caches()
+
+    def _log_likelihoods(self,x):
+        # NOTE: nans propagate as nans
+        x = np.asarray(x)
+        K = len(self.components)
+        vals = np.empty((x.shape[0],K))
+        for idx, c in enumerate(self.components):
+            vals[:,idx] = c.log_likelihood(x)
+        vals += self.weights.log_likelihood(np.arange(K))
+        return logsumexp(vals,axis=1)
+
+    def log_likelihood(self,x=None):
+        if x is None:
+            return sum(l.log_likelihood() for l in self.labels_list)
+        else:
+            assert isinstance(x,(np.ndarray,list))
+            if isinstance(x,list):
+                return sum(self.log_likelihood(d) for d in x)
+            else:
+                self.add_data(x)
+                return self.labels_list.pop().log_likelihood()
+
+    ### parallel tempering
+
+    @property
+    def temperature(self):
+        return self._temperature if hasattr(self,'_temperature') else 1.
+
+    @temperature.setter
+    def temperature(self,T):
+        self._temperature = T
+
+    @property
+    def energy(self):
+        energy = 0.
+        for l in self.labels_list:
+            for label, datum in zip(l.z,l.data):
+                energy += self.components[label].energy(datum)
+        return energy
+
+    def swap_sample_with(self,other):
+        self.components, other.components = other.components, self.components
+        self.weights, other.weights = other.weights, self.weights
+
+        for l1, l2 in zip(self.labels_list,other.labels_list):
+            l1.z, l2.z = l2.z, l1.z
+
+    ### Gibbs sampling
+
+    def resample_model(self,num_procs=0,components_jobs=0):
+        self.resample_components(num_procs=components_jobs)
+        self.resample_weights()
+        self.resample_labels(num_procs=num_procs)
+
+    def resample_weights(self):
+        self.weights.resample([l.z for l in self.labels_list])
+        self._clear_caches()
+
+    def resample_components(self,num_procs=0):
+        if num_procs == 0:
+            for idx, c in enumerate(self.components):
+                c.resample(data=[l.data[l.z == idx] for l in self.labels_list])
+        else:
+            self._resample_components_joblib(num_procs)
+        self._clear_caches()
+
+    def resample_labels(self,num_procs=0):
+        if num_procs == 0:
+            for l in self.labels_list:
+                l.resample()
+        else:
+            self._resample_labels_joblib(num_procs)
+
+    def copy_sample(self):
+        new = copy.copy(self)
+        new.components = [c.copy_sample() for c in self.components]
+        new.weights = self.weights.copy_sample()
+        new.labels_list = [l.copy_sample() for l in self.labels_list]
+        for l in new.labels_list:
+            l.model = new
+        return new
+
+    def _resample_components_joblib(self,num_procs):
+        from joblib import Parallel, delayed
+        from . import parallel_mixture
+
+        parallel_mixture.model = self
+        parallel_mixture.labels_list = self.labels_list
+
+        if len(self.components) > 0:
+            params = Parallel(n_jobs=num_procs,backend='multiprocessing')\
+                    (delayed(parallel_mixture._get_sampled_component_params)(idx)
+                            for idx in range(len(self.components)))
+
+        for c, p in zip(self.components,params):
+            c.parameters = p
+
+    def _resample_labels_joblib(self,num_procs):
+        from joblib import Parallel, delayed
+        from . import parallel_mixture
+
+        if len(self.labels_list) > 0:
+            parallel_mixture.model = self
+
+            raw = Parallel(n_jobs=num_procs,backend='multiprocessing')\
+                    (delayed(parallel_mixture._get_sampled_labels)(idx)
+                            for idx in range(len(self.labels_list)))
+
+            for l, (z,normalizer) in zip(self.labels_list,raw):
+                l.z, l._normalizer = z, normalizer
+
+
+    ### Mean Field
+
+    def meanfield_coordinate_descent_step(self):
+        assert all(isinstance(c,MeanField) for c in self.components), \
+                'Components must implement MeanField'
+        assert len(self.labels_list) > 0, 'Must have data to run MeanField'
+
+        self._meanfield_update_sweep()
+        return self._vlb()
+
+    def _meanfield_update_sweep(self):
+        # NOTE: to interleave mean field steps with Gibbs sampling steps, label
+        # updates need to come first, otherwise the sampled updates will be
+        # ignored and the model will essentially stay where it was the last time
+        # mean field updates were run
+        # TODO fix that, seed with sample from variational distribution
+        self.meanfield_update_labels()
+        self.meanfield_update_parameters()
+
+    def meanfield_update_labels(self):
+        for l in self.labels_list:
+            l.meanfieldupdate()
+
+    def meanfield_update_parameters(self):
+        self.meanfield_update_components()
+        self.meanfield_update_weights()
+
+    def meanfield_update_weights(self):
+        self.weights.meanfieldupdate(None,[l.r for l in self.labels_list])
+        self._clear_caches()
+
+    def meanfield_update_components(self):
+        for idx, c in enumerate(self.components):
+            c.meanfieldupdate([l.data for l in self.labels_list],
+                    [l.r[:,idx] for l in self.labels_list])
+        self._clear_caches()
+
+    def _vlb(self):
+        vlb = 0.
+        vlb += sum(l.get_vlb() for l in self.labels_list)
+        vlb += self.weights.get_vlb()
+        vlb += sum(c.get_vlb() for c in self.components)
+        for l in self.labels_list:
+            vlb += np.sum([r.dot(c.expected_log_likelihood(l.data))
+                                for c,r in zip(self.components, l.r.T)])
+
+        # add in symmetry factor (if we're actually symmetric)
+        if len(set(type(c) for c in self.components)) == 1:
+            vlb += special.gammaln(len(self.components)+1)
+
+        return vlb
+
+    ### SVI
+
+    def meanfield_sgdstep(self,minibatch,prob,stepsize,**kwargs):
+        minibatch = minibatch if isinstance(minibatch,list) else [minibatch]
+        mb_labels_list = []
+        for data in minibatch:
+            self.add_data(data,z=np.empty(data.shape[0]),**kwargs) # NOTE: dummy
+            mb_labels_list.append(self.labels_list.pop())
+
+        for l in mb_labels_list:
+            l.meanfieldupdate()
+
+        self._meanfield_sgdstep_parameters(mb_labels_list,prob,stepsize)
+
+    def _meanfield_sgdstep_parameters(self,mb_labels_list,prob,stepsize):
+        self._meanfield_sgdstep_components(mb_labels_list,prob,stepsize)
+        self._meanfield_sgdstep_weights(mb_labels_list,prob,stepsize)
+
+    def _meanfield_sgdstep_components(self,mb_labels_list,prob,stepsize):
+        for idx, c in enumerate(self.components):
+            c.meanfield_sgdstep(
+                    [l.data for l in mb_labels_list],
+                    [l.r[:,idx] for l in mb_labels_list],
+                    prob,stepsize)
+
+    def _meanfield_sgdstep_weights(self,mb_labels_list,prob,stepsize):
+        self.weights.meanfield_sgdstep(
+                None,[l.r for l in mb_labels_list],
+                prob,stepsize)
+
+    ### EM
+
+    def EM_step(self):
+        # assert all(isinstance(c,MaxLikelihood) for c in self.components), \
+        #         'Components must implement MaxLikelihood'
+        assert len(self.labels_list) > 0, 'Must have data to run EM'
+
+        ## E step
+        for l in self.labels_list:
+            l.E_step()
+
+        ## M step
+        # component parameters
+        for idx, c in enumerate(self.components):
+            c.max_likelihood([l.data for l in self.labels_list],
+                    [l.expectations[:,idx] for l in self.labels_list])
+
+        # mixture weights
+        self.weights.max_likelihood(None,
+                [l.expectations for l in self.labels_list])
+
+    @property
+    def num_parameters(self):
+        # NOTE: scikit.learn's gmm.py doesn't count the weights in the number of
+        # parameters, but I don't know why they wouldn't. Some convention?
+        return sum(c.num_parameters for c in self.components) + self.weights.num_parameters
+
+    def BIC(self,data=None):
+        '''
+        BIC on the passed data.
+        If passed data is None (default), calculates BIC on the model's assigned data.
+        '''
+        # NOTE: in principle this method computes the BIC only after finding the
+        # maximum likelihood parameters (or, of course, an EM fixed-point as an
+        # approximation!)
+        if data is None:
+            assert len(self.labels_list) > 0, \
+                    "If not passing in data, the class must already have it. Use the method add_data()"
+            return -2*sum(self.log_likelihood(l.data) for l in self.labels_list) + \
+                        self.num_parameters * np.log(sum(l.data.shape[0] for l in self.labels_list))
+        else:
+            return -2*self.log_likelihood(data) + self.num_parameters * np.log(data.shape[0])
+
+    def AIC(self):
+        # NOTE: in principle this method computes the AIC only after finding the
+        # maximum likelihood parameters (or, of course, an EM fixed-point as an
+        # approximation!)
+        assert len(self.labels_list) > 0, 'Must have data to get AIC'
+        return 2*self.num_parameters - 2*sum(self.log_likelihood(l.data) for l in self.labels_list)
+
+    ### Misc.
+
+    @property
+    def used_labels(self):
+        if len(self.labels_list) > 0:
+            label_usages = sum(np.bincount(l.z,minlength=self.N) for l in self.labels_list)
+            used_labels, = np.where(label_usages > 0)
+        else:
+            used_labels = np.argsort(self.weights.weights)[-1:-11:-1]
+        return used_labels
+
+    def plot(self,color=None,legend=False,alpha=None,update=False,draw=True):
+        import matplotlib.pyplot as plt
+        from matplotlib import cm
+        artists = []
+
+        ### get colors
+        cmap = cm.get_cmap()
+        if color is None:
+            label_colors = dict((idx,cmap(v))
+                for idx, v in enumerate(np.linspace(0,1,self.N,endpoint=True)))
+        else:
+            label_colors = dict((idx,color) for idx in range(self.N))
+
+        ### plot data scatter
+        for l in self.labels_list:
+            colorseq = [label_colors[label] for label in l.z]
+            if update and hasattr(l,'_data_scatter'):
+                l._data_scatter.set_offsets(l.data[:,:2])
+                l._data_scatter.set_color(colorseq)
+            else:
+                l._data_scatter = plt.scatter(l.data[:,0],l.data[:,1],c=colorseq,s=5)
+            artists.append(l._data_scatter)
+
+        ### plot parameters
+        axis = plt.axis()
+        for label, (c, w) in enumerate(zip(self.components,self.weights.weights)):
+            artists.extend(
+                c.plot(
+                    color=label_colors[label],
+                    label='%d' % label,
+                    alpha=min(0.25,1.-(1.-w)**2)/0.25 if alpha is None else alpha,
+                    update=update,draw=False))
+        plt.axis(axis)
+
+        ### add legend
+        if legend and color is None:
+            plt.legend(
+                [plt.Rectangle((0,0),1,1,fc=c)
+                    for i,c in label_colors.items() if i in self.used_labels],
+                [i for i in label_colors if i in self.used_labels],
+                loc='best', ncol=2)
+
+        if draw: plt.draw()
+        return artists
+
+
+    def to_json_dict(self):
+        assert len(self.labels_list) == 1
+        data = self.labels_list[0].data
+        z = self.labels_list[0].z
+        assert data.ndim == 2 and data.shape[1] == 2
+
+        return  {
+                    'points':[{'x':x,'y':y,'label':int(label)} for x,y,label in zip(data[:,0],data[:,1],z)],
+                    'ellipses':[dict(list(c.to_json_dict().items()) + [('label',i)])
+                        for i,c in enumerate(self.components) if i in z]
+                }
+
+    def predictive_likelihoods(self,test_data,forecast_horizons):
+        likes = self._log_likelihoods(test_data)
+        return [likes[k:] for k in forecast_horizons]
+
+    def block_predictive_likelihoods(self,test_data,blocklens):
+        csums = np.cumsum(self._log_likelihoods(test_data))
+        outs = []
+        for k in blocklens:
+            outs.append(csums[k:] - csums[:-k])
+        return outs
+
+
+class MixtureDistribution(Mixture, GibbsSampling, MeanField, MeanFieldSVI, Distribution):
+    '''
+    This makes a Mixture act like a Distribution for use in other models
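+
+    resample(data) runs `niter` Gibbs sweeps of the internal mixture on the passed
+    data and then discards the temporary labels, so only the component and weight
+    parameters persist between calls.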
+    '''
+
+    def __init__(self,niter=1,**kwargs):
+        self.niter = niter
+        super(MixtureDistribution,self).__init__(**kwargs)
+
+    @property
+    def params(self):
+        return dict(weights=self.weights.params,components=[c.params for c in self.components])
+
+    @property
+    def hypparams(self):
+        return dict(weights=self.weights.hypparams,components=[c.hypparams for c in self.components])
+
+    def energy(self,data):
+        # TODO TODO this function is horrible
+        assert data.ndim == 1
+
+        if np.isnan(data).any():
+            return 0.
+
+        from pybasicbayes.util.stats import sample_discrete
+        likes = np.array([c.log_likelihood(data) for c in self.components]).reshape((-1,))
+        likes += np.log(self.weights.weights)
+        label = sample_discrete(np.exp(likes - likes.max()))
+
+        return self.components[label].energy(data)
+
+    def log_likelihood(self,x):
+        return self._log_likelihoods(x)
+
+    def resample(self,data):
+        # doesn't keep a reference to the data like a model would
+        assert isinstance(data,list) or isinstance(data,np.ndarray)
+
+        if getdatasize(data) > 0:
+            if not isinstance(data,np.ndarray):
+                data = np.concatenate(data)
+
+            self.add_data(data,initialize_from_prior=False)
+
+            for itr in range(self.niter):
+                self.resample_model()
+
+            self.labels_list.pop()
+        else:
+            self.resample_model()
+
+    def max_likelihood(self,data,weights=None):
+        if weights is not None:
+            raise NotImplementedError
+        assert isinstance(data,list) or isinstance(data,np.ndarray)
+        if isinstance(data,list):
+            data = np.concatenate(data)
+
+        if getdatasize(data) > 0:
+            self.add_data(data)
+            self.EM_fit()
+            self.labels_list = []
+
+    def get_vlb(self):
+        from warnings import warn
+        warn('Pretty sure this is missing a term, VLB is wrong but updates are fine') # TODO
+        vlb = 0.
+        # vlb += self._labels_vlb # TODO this part is wrong! we need weights passed in again
+        vlb += self.weights.get_vlb()
+        vlb += sum(c.get_vlb() for c in self.components)
+        return vlb
+
+    def expected_log_likelihood(self,x):
+        lognorm = logsumexp(self.weights._alpha_mf)
+        return sum(np.exp(a - lognorm) * c.expected_log_likelihood(x)
+                for a, c in zip(self.weights._alpha_mf, self.components))
+
+    def meanfieldupdate(self,data,weights,**kwargs):
+        # NOTE: difference from parent's method is the inclusion of weights
+        if not isinstance(data,(list,tuple)):
+            data = [data]
+            weights = [weights]
+        old_labels = self.labels_list
+        self.labels_list = []
+
+        for d in data:
+            self.add_data(d,z=np.empty(d.shape[0])) # NOTE: dummy
+
+        self.meanfield_update_labels()
+        for l, w in zip(self.labels_list,weights):
+            l.r *= w[:,na] # here's where the weights are used
+        self.meanfield_update_parameters()
+
+        # self._labels_vlb = sum(l.get_vlb() for l in self.labels_list) # TODO hack
+
+        self.labels_list = old_labels
+
+    def meanfield_sgdstep(self,minibatch,weights,prob,stepsize):
+        # NOTE: difference from parent's method is the inclusion of weights
+        if not isinstance(minibatch,list):
+            minibatch = [minibatch]
+            weights = [weights]
+        mb_labels_list = []
+        for data in minibatch:
+            self.add_data(data,z=np.empty(data.shape[0])) # NOTE: dummy
+            mb_labels_list.append(self.labels_list.pop())
+
+        for l, w in zip(mb_labels_list,weights):
+            l.meanfieldupdate()
+            l.r *= w[:,na] # here's where weights are used
+
+        self._meanfield_sgdstep_parameters(mb_labels_list,prob,stepsize)
+
+    def plot(self,data=[],color='b',label='',plot_params=True,indices=None):
+        # TODO handle indices for 1D
+        if not isinstance(data,list):
+            data = [data]
+        for d in data:
+            self.add_data(d)
+
+        for l in self.labels_list:
+            l.E_step() # sets l.z to MAP estimates
+            for label, o in enumerate(self.components):
+                if label in l.z:
+                    o.plot(color=color,label=label,
+                            data=l.data[l.z == label] if l.data is not None else None)
+
+        for d in data:
+            self.labels_list.pop()
+
+
+class CollapsedMixture(with_metaclass(abc.ABCMeta, ModelGibbsSampling)):
+    def _get_counts(self,k):
+        return sum(l._get_counts(k) for l in self.labels_list)
+
+    def _get_data_withlabel(self,k):
+        return [l._get_data_withlabel(k) for l in self.labels_list]
+
+    def _get_occupied(self):
+        return reduce(set.union,(l._get_occupied() for l in self.labels_list),set([]))
+
+    def plot(self):
+        import matplotlib.pyplot as plt
+        from matplotlib import cm
+        plt.figure()
+        cmap = cm.get_cmap()
+        used_labels = self._get_occupied()
+        num_labels = len(used_labels)
+
+        label_colors = {}
+        for idx,label in enumerate(used_labels):
+            label_colors[label] = idx/(num_labels-1. if num_labels > 1 else 1.)
+
+        for subfigidx,l in enumerate(self.labels_list):
+            plt.subplot(len(self.labels_list),1,1+subfigidx)
+            # TODO assuming data is 2D
+            for label in used_labels:
+                if label in l.z:
+                    plt.plot(l.data[l.z==label,0],l.data[l.z==label,1],
+                            color=cmap(label_colors[label]),ls='None',marker='x')
+
+
+class CRPMixture(CollapsedMixture):
+    _labels_class = CRPLabels
+
+    def __init__(self,alpha_0,obs_distn):
+        assert isinstance(obs_distn,Collapsed)
+        self.obs_distn = obs_distn
+        self.alpha_0 = alpha_0
+
+        self.labels_list = []
+
+    def add_data(self,data):
+        assert len(self.labels_list) == 0
+        self.labels_list.append(self._labels_class(model=self,data=np.asarray(data),
+            alpha_0=self.alpha_0,obs_distn=self.obs_distn))
+        return self.labels_list[-1]
+
+    def resample_model(self):
+        for l in self.labels_list:
+            l.resample()
+
+    def generate(self,N,keep=True):
+        warn('not fully implemented')
+        # TODO only works if there's no other data in the model; o/w need to add
+        # existing data to obs resample. should be an easy update.
+        # templabels needs to pay attention to its own counts as well as model
+        # counts
+        assert len(self.labels_list) == 0
+
+        templabels = self._labels_class(model=self,alpha_0=self.alpha_0,obs_distn=self.obs_distn,N=N)
+
+        counts = np.bincount(templabels.z)
+        out = np.empty(self.obs_distn.rvs(N).shape)
+        for idx, count in enumerate(counts):
+            self.obs_distn.resample()
+            out[templabels.z == idx,...] = self.obs_distn.rvs(count)
+
+        perm = np.random.permutation(N)
+        out = out[perm]
+        templabels.z = templabels.z[perm]
+
+        if keep:
+            templabels.data = out
+            self.labels_list.append(templabels)
+
+        return out, templabels.z
+
+    def log_likelihood(self,x, K_extra=1):
+        """
+        Estimate the log likelihood with samples from the model. Draw K_extra
+        components that are not populated by the current model in order to
+        create a truncated approximate mixture model.
+        """
+        x = np.asarray(x)
+        ks = self._get_occupied()
+        K = len(ks)
+        K_total = K + K_extra
+
+        # Sample observation distributions given current labels
+        obs_distns = []
+        for k in range(K):
+            o = copy.deepcopy(self.obs_distn)
+            o.resample(data=self._get_data_withlabel(k))
+            obs_distns.append(o)
+
+        # Sample extra observation distributions from prior
+        for k in range(K_extra):
+            o = copy.deepcopy(self.obs_distn)
+            o.resample()
+            obs_distns.append(o)
+
+        # Sample a set of weights
+        weights = Categorical(alpha_0=self.alpha_0,
+                              K=K_total,
+                              weights=None)
+
+        assert len(self.labels_list) == 1
+        weights.resample(data=self.labels_list[0].z)
+
+        # Now compute the log likelihood
+        vals = np.empty((x.shape[0],K_total))
+        for k in range(K_total):
+            vals[:,k] = obs_distns[k].log_likelihood(x)
+
+        vals += weights.log_likelihood(np.arange(K_total))
+        assert not np.isnan(vals).any()
+        return logsumexp(vals,axis=1).sum()
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/parallel_mixture.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/parallel_mixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..8607ecd68a0734474b1d5f2e3ad4d4b0d4792141
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/models/parallel_mixture.py
@@ -0,0 +1,15 @@
+from __future__ import division
+import numpy as np
+
+model = None
+labels_list = None
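+
+# These module-level globals are set by Mixture._resample_labels_joblib and
+# Mixture._resample_components_joblib in the parent process before the joblib
+# calls; with the 'multiprocessing' backend the worker processes inherit them.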
+
+def _get_sampled_labels(idx):
+    model.add_data(model.labels_list[idx].data,initialize_from_prior=False)
+    l = model.labels_list.pop()
+    return l.z, l._normalizer
+
+def _get_sampled_component_params(idx):
+    model.components[idx].resample([l.data[l.z == idx] for l in labels_list])
+    return model.components[idx].parameters
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/testing/__init__.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/testing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/testing/mixins.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/testing/mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4d22c40714d1cb7d8f17a60668ab807c5c8e29
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/testing/mixins.py
@@ -0,0 +1,262 @@
+from __future__ import division
+from builtins import zip
+from builtins import range
+from builtins import object
+import numpy as np
+import abc, os
+
+from nose.plugins.attrib import attr
+
+import pybasicbayes
+from pybasicbayes.util import testing
+from future.utils import with_metaclass
+
+class DistributionTester(with_metaclass(abc.ABCMeta, object)):
+    @abc.abstractproperty
+    def distribution_class(self):
+        pass
+
+    @abc.abstractproperty
+    def hyperparameter_settings(self):
+        pass
+
+class BasicTester(DistributionTester):
+    @property
+    def basic_data_size(self):
+        return 1000
+
+    def loglike_lists_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.hyperparameter_settings):
+            yield self.check_loglike_lists, setting_idx, hypparam_dict
+
+    def check_loglike_lists(self,setting_idx,hypparam_dict):
+        dist = self.distribution_class(**hypparam_dict)
+        data = dist.rvs(size=self.basic_data_size)
+
+        l1 = dist.log_likelihood(data).sum()
+        l2 = sum(dist.log_likelihood(d) for d in np.array_split(data,self.basic_data_size))
+
+        assert np.isclose(l1,l2)
+
+    def stats_lists_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.hyperparameter_settings):
+            yield self.check_stats_lists, setting_idx, hypparam_dict
+
+    def check_stats_lists(self,setting_idx,hypparam_dict):
+        dist = self.distribution_class(**hypparam_dict)
+        data = dist.rvs(size=self.basic_data_size)
+
+        if hasattr(dist,'_get_statistics'):
+            s1 = dist._get_statistics(data)
+            s2 = dist._get_statistics([d for d in np.array_split(data,self.basic_data_size)])
+
+            self._check_stats(s1,s2)
+
+    def _check_stats(self,s1,s2):
+        if isinstance(s1,np.ndarray):
+            if s1.dtype == object:
+                assert all(np.allclose(t1,t2) for t1, t2 in zip(s1,s2))
+            else:
+                assert np.allclose(s1,s2)
+        elif isinstance(s1,tuple):
+            assert all(np.allclose(ss1,ss2) for ss1,ss2 in zip(s1,s2))
+
+    def missing_data_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.hyperparameter_settings):
+            yield self.check_missing_data_stats, setting_idx, hypparam_dict
+
+    def check_missing_data_stats(self,setting_idx,hypparam_dict):
+        dist = self.distribution_class(**hypparam_dict)
+        data = dist.rvs(size=self.basic_data_size)
+
+        if isinstance(data,np.ndarray):
+            data[np.random.randint(2,size=data.shape[0]) == 1] = np.nan
+
+            s1 = dist._get_statistics(data)
+            s2 = dist._get_statistics(data[~np.isnan(data).any(1)])
+
+            self._check_stats(s1,s2)
+
+class BigDataGibbsTester(with_metaclass(abc.ABCMeta, DistributionTester)):
+    @abc.abstractmethod
+    def params_close(self,distn1,distn2):
+        pass
+
+    @property
+    def big_data_size(self):
+        return 20000
+
+    @property
+    def big_data_repeats_per_setting(self):
+        return 1
+
+    @property
+    def big_data_hyperparameter_settings(self):
+        return self.hyperparameter_settings
+
+    @attr('random')
+    def big_data_Gibbs_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.big_data_hyperparameter_settings):
+            for i in range(self.big_data_repeats_per_setting):
+                yield self.check_big_data_Gibbs, setting_idx, hypparam_dict
+
+    def check_big_data_Gibbs(self,setting_idx,hypparam_dict):
+        d1 = self.distribution_class(**hypparam_dict)
+        d2 = self.distribution_class(**hypparam_dict)
+
+        data = d1.rvs(size=self.big_data_size)
+        d2.resample(data)
+
+        assert self.params_close(d1,d2)
+
+class MaxLikelihoodTester(with_metaclass(abc.ABCMeta, DistributionTester)):
+    @abc.abstractmethod
+    def params_close(self,distn1,distn2):
+        pass
+
+
+    @property
+    def big_data_size(self):
+        return 20000
+
+    @property
+    def big_data_repeats_per_setting(self):
+        return 1
+
+    @property
+    def big_data_hyperparameter_settings(self):
+        return self.hyperparameter_settings
+
+
+    def maxlike_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.big_data_hyperparameter_settings):
+            for i in range(self.big_data_repeats_per_setting):
+                yield self.check_maxlike, setting_idx, hypparam_dict
+
+    def check_maxlike(self,setting_idx,hypparam_dict):
+        d1 = self.distribution_class(**hypparam_dict)
+        d2 = self.distribution_class(**hypparam_dict)
+
+        data = d1.rvs(size=self.big_data_size)
+        d2.max_likelihood(data)
+
+        assert self.params_close(d1,d2)
+
+class GewekeGibbsTester(with_metaclass(abc.ABCMeta, DistributionTester)):
+    @abc.abstractmethod
+    def geweke_statistics(self,distn,data):
+        pass
+
+
+    @property
+    def geweke_nsamples(self):
+        return 30000
+
+    @property
+    def geweke_data_size(self):
+        return 1 # NOTE: more data usually means slower mixing
+
+    @property
+    def geweke_ntrials(self):
+        return 3
+
+    @property
+    def geweke_pval(self):
+        return 0.05
+
+    @property
+    def geweke_hyperparameter_settings(self):
+        return self.hyperparameter_settings
+
+    def geweke_numerical_slice(self,distn,setting_idx):
+        return slice(None)
+
+    @property
+    def resample_kwargs(self):
+        return {}
+
+    @property
+    def geweke_resample_kwargs(self):
+        return self.resample_kwargs
+
+    @property
+    def geweke_num_statistic_fails_to_tolerate(self):
+        return 1
+
+
+    @attr('slow', 'random')
+    def geweke_tests(self):
+        for setting_idx, hypparam_dict in enumerate(self.geweke_hyperparameter_settings):
+            yield self.check_geweke, setting_idx, hypparam_dict
+
+    def geweke_figure_filepath(self,setting_idx):
+        return os.path.join(os.path.dirname(__file__),'figures',
+                            self.__class__.__name__,'setting_%d.pdf' % setting_idx)
+
+    def check_geweke(self,setting_idx,hypparam_dict):
+        import os
+        from matplotlib import pyplot as plt
+        plt.ioff()
+        fig = plt.figure()
+        figpath = self.geweke_figure_filepath(setting_idx)
+        mkdir(os.path.dirname(figpath))
+
+        nsamples, data_size, ntrials = self.geweke_nsamples, \
+                self.geweke_data_size, self.geweke_ntrials
+
+        d = self.distribution_class(**hypparam_dict)
+        sample_dim = np.atleast_1d(self.geweke_statistics(d,d.rvs(size=10))).shape[0]
+
+        num_statistic_fails = 0
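+        # Geweke-style check: statistics of (parameters, data) drawn by forward
+        # simulation from the prior should match, in distribution, statistics from
+        # the successive-conditional chain that alternates resampling parameters
+        # given data and redrawing data given parameters.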
+        for trial in range(ntrials):
+            # collect forward-generated statistics
+            forward_statistics = np.squeeze(np.empty((nsamples,sample_dim)))
+            for i in range(nsamples):
+                d = self.distribution_class(**hypparam_dict)
+                data = d.rvs(size=data_size)
+                forward_statistics[i] = self.geweke_statistics(d,data)
+
+            # collect gibbs-generated statistics
+            gibbs_statistics = np.squeeze(np.empty((nsamples,sample_dim)))
+            d = self.distribution_class(**hypparam_dict)
+            data = d.rvs(size=data_size)
+            for i in range(nsamples):
+                d.resample(data,**self.geweke_resample_kwargs)
+                data = d.rvs(size=data_size)
+                gibbs_statistics[i] = self.geweke_statistics(d,data)
+
+            testing.populations_eq_quantile_plot(forward_statistics,gibbs_statistics,fig=fig)
+            try:
+                sl = self.geweke_numerical_slice(d,setting_idx)
+                testing.assert_populations_eq_moments(
+                        forward_statistics[...,sl],gibbs_statistics[...,sl],
+                        pval=self.geweke_pval)
+            except AssertionError:
+                datapath = os.path.join(os.path.dirname(__file__),'figures',
+                        self.__class__.__name__,'setting_%d_trial_%d.npz' % (setting_idx,trial))
+                np.savez(datapath,fwd=forward_statistics,gibbs=gibbs_statistics)
+                example_violating_means = forward_statistics.mean(0), gibbs_statistics.mean(0)
+                num_statistic_fails += 1
+
+        plt.savefig(figpath)
+
+        assert num_statistic_fails <= self.geweke_num_statistic_fails_to_tolerate, \
+                'Geweke MAY have failed, check FIGURES in %s (e.g. %s vs %s)' \
+                % ((os.path.dirname(figpath),) + example_violating_means)
+
+
+##########
+#  misc  #
+##########
+
+def mkdir(path):
+    # from
+    # http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python
+    import errno
+    try:
+        os.makedirs(path)
+    except OSError as exc:
+        if exc.errno == errno.EEXIST and os.path.isdir(path):
+            pass
+        else: raise
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/__init__.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed6fb6df75d06313f4d24356855e82c8f3e7c651
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+__all__ = ['general','plot','stats','text']
+from . import general, plot, stats, text
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cstats.cpython-37m-x86_64-linux-gnu.so b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cstats.cpython-37m-x86_64-linux-gnu.so
new file mode 100755
index 0000000000000000000000000000000000000000..0d0e15f9b43a5a4573a59861fc7ca12496eeb175
Binary files /dev/null and b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cstats.cpython-37m-x86_64-linux-gnu.so differ
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cyutil.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cyutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17086dfc3e11955ed8206517b21b4790519e1f0
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/cyutil.py
@@ -0,0 +1,112 @@
+from builtins import map
+from builtins import str
+import Cython.Build
+from Cython.Build.Dependencies import *
+
+# NOTE: mostly a copy of cython's create_extension_list except for the lines
+# surrounded by "begin matt added" / "end matt added"
+def create_extension_list(patterns, exclude=[], ctx=None, aliases=None, quiet=False, language=None,
+                          exclude_failures=False):
+    if not isinstance(patterns, (list, tuple)):
+        patterns = [patterns]
+    explicit_modules = set([m.name for m in patterns if isinstance(m, Extension)])
+    seen = set()
+    deps = create_dependency_tree(ctx, quiet=quiet)
+    to_exclude = set()
+    if not isinstance(exclude, list):
+        exclude = [exclude]
+    for pattern in exclude:
+        to_exclude.update(list(map(os.path.abspath, extended_iglob(pattern))))
+
+    module_list = []
+    for pattern in patterns:
+        if isinstance(pattern, str):
+            filepattern = pattern
+            template = None
+            name = '*'
+            base = None
+            exn_type = Extension
+            ext_language = language
+        elif isinstance(pattern, Extension):
+            for filepattern in pattern.sources:
+                if os.path.splitext(filepattern)[1] in ('.py', '.pyx'):
+                    break
+            else:
+                # ignore non-cython modules
+                module_list.append(pattern)
+                continue
+            template = pattern
+            name = template.name
+            base = DistutilsInfo(exn=template)
+            exn_type = template.__class__
+            ext_language = None  # do not override whatever the Extension says
+        else:
+            raise TypeError(pattern)
+
+        for file in extended_iglob(filepattern):
+            if os.path.abspath(file) in to_exclude:
+                continue
+            pkg = deps.package(file)
+            if '*' in name:
+                module_name = deps.fully_qualified_name(file)
+                if module_name in explicit_modules:
+                    continue
+            else:
+                module_name = name
+
+            if module_name not in seen:
+                try:
+                    kwds = deps.distutils_info(file, aliases, base).values
+                except Exception:
+                    if exclude_failures:
+                        continue
+                    raise
+                if base is not None:
+                    for key, value in list(base.values.items()):
+                        if key not in kwds:
+                            kwds[key] = value
+
+                sources = [file]
+                if template is not None:
+                    sources += [m for m in template.sources if m != filepattern]
+                if 'sources' in kwds:
+                    # allow users to add .c files etc.
+                    for source in kwds['sources']:
+                        source = encode_filename_in_py2(source)
+                        if source not in sources:
+                            sources.append(source)
+                    del kwds['sources']
+                if 'depends' in kwds:
+                    depends = resolve_depends(kwds['depends'], (kwds.get('include_dirs') or []) + [find_root_package_dir(file)])
+                    if template is not None:
+                        # Always include everything from the template.
+                        depends = list(set(template.depends).union(set(depends)))
+                    kwds['depends'] = depends
+
+                if ext_language and 'language' not in kwds:
+                    kwds['language'] = ext_language
+
+                # NOTE: begin matt added
+                if 'name' in kwds:
+                    module_name = str(kwds['name'])
+                    del kwds['name']
+                else:
+                    module_name = os.path.splitext(file)[0].replace('/','.')
+                # NOTE: end matt added
+                module_list.append(exn_type(
+                        name=module_name,
+                        sources=sources,
+                        **kwds))
+                m = module_list[-1]
+                seen.add(name)
+    return module_list
+
+true_cythonize = Cython.Build.cythonize
+true_create_extension_list = Cython.Build.Dependencies.create_extension_list
+
+def cythonize(*args,**kwargs):
+    Cython.Build.Dependencies.create_extension_list = create_extension_list
+    out = true_cythonize(*args,**kwargs)
+    Cython.Build.Dependencies.create_extension_list = true_create_extension_list
+    return out
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/general.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/general.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a7c8ef30d4a241fb0a2b37bc0a67d4e64e28f51
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/general.py
@@ -0,0 +1,328 @@
+from __future__ import division
+from future import standard_library
+standard_library.install_aliases()
+from builtins import next
+from builtins import zip
+from builtins import range
+import sys
+import numpy as np
+from numpy.lib.stride_tricks import as_strided as ast
+import scipy.linalg
+import scipy.linalg.lapack as lapack
+import copy, collections, os, shutil, hashlib
+from contextlib import closing
+from itertools import chain, count
+from functools import reduce
+from urllib.request import urlopen  # py2.7 covered by standard_library.install_aliases()
+
+
+def blockarray(*args,**kwargs):
+    return np.array(np.bmat(*args,**kwargs),copy=False)
+
+def interleave(*iterables):
+    return list(chain.from_iterable(zip(*iterables)))
+
+def joindicts(dicts):
+    # stuff on right clobbers stuff on left
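+    # e.g. joindicts([{'a': 1}, {'a': 2, 'b': 3}]) gives {'a': 2, 'b': 3}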
+    return reduce(lambda x,y: dict(x,**y), dicts, {})
+
+def one_vs_all(stuff):
+    stuffset = set(stuff)
+    for thing in stuff:
+        yield thing, stuffset - set([thing])
+
+def rle(stateseq):
+    pos, = np.where(np.diff(stateseq) != 0)
+    pos = np.concatenate(([0],pos+1,[len(stateseq)]))
+    return stateseq[pos[:-1]], np.diff(pos)
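+
+# e.g. rle(np.array([1,1,2,2,2,3])) returns (array([1,2,3]), array([2,3,1]));
+# irle below inverts it, so irle(*rle(s)) reconstructs s (as a float array)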
+
+def irle(vals,lens):
+    out = np.empty(np.sum(lens))
+    for v,l,start in zip(vals,lens,np.concatenate(((0,),np.cumsum(lens)[:-1]))):
+        out[start:start+l] = v
+    return out
+
+def ibincount(counts):
+    'returns an array a such that counts = np.bincount(a)'
+    return np.repeat(np.arange(counts.shape[0]),counts)
+
+def cumsum(v,strict=False):
+    if not strict:
+        return np.cumsum(v,axis=0)
+    else:
+        out = np.zeros_like(v)
+        out[1:] = np.cumsum(v[:-1],axis=0)
+        return out
+
+def rcumsum(v,strict=False):
+    if not strict:
+        return np.cumsum(v[::-1],axis=0)[::-1]
+    else:
+        out = np.zeros_like(v)
+        out[:-1] = np.cumsum(v[-1:0:-1],axis=0)[::-1]
+        return out
+
+def delta_like(v,i):
+    out = np.zeros_like(v)
+    out[i] = 1
+    return out
+
+def deepcopy(obj):
+    return copy.deepcopy(obj)
+
+def nice_indices(arr):
+    '''
+    takes an array like [1,1,5,5,5,999,1,1]
+    and maps to something like [0,0,1,1,1,2,0,0]
+    modifies original in place as well as returns a ref
+    '''
+    # surprisingly, this is slower for very small (and very large) inputs:
+    # u,f,i = np.unique(arr,return_index=True,return_inverse=True)
+    # arr[:] = np.arange(u.shape[0])[np.argsort(f)][i]
+    ids = collections.defaultdict(count().__next__)
+    for idx,x in enumerate(arr):
+        arr[idx] = ids[x]
+    return arr
+
+def ndargmax(arr):
+    return np.unravel_index(np.argmax(np.ravel(arr)),arr.shape)
+
+def match_by_overlap(a,b):
+    assert a.ndim == b.ndim == 1 and a.shape[0] == b.shape[0]
+    ais, bjs = list(set(a)), list(set(b))
+    scores = np.zeros((len(ais),len(bjs)))
+    for i,ai in enumerate(ais):
+        for j,bj in enumerate(bjs):
+            scores[i,j] = np.dot(np.array(a==ai,dtype=float),b==bj)
+
+    flip = len(bjs) > len(ais)
+
+    if flip:
+        ais, bjs = bjs, ais
+        scores = scores.T
+
+    matching = []
+    while scores.size > 0:
+        i,j = ndargmax(scores)
+        matching.append((ais[i],bjs[j]))
+        scores = np.delete(np.delete(scores,i,0),j,1)
+        ais = np.delete(ais,i)
+        bjs = np.delete(bjs,j)
+
+    return matching if not flip else [(x,y) for y,x in matching]
+
+def hamming_error(a,b):
+    return (a!=b).sum()
+
+def scoreatpercentile(data,per,axis=0):
+    'like the function in scipy.stats but with an axis argument and works on arrays'
+    a = np.sort(data,axis=axis)
+    idx = per/100. * (data.shape[axis]-1)
+
+    if idx % 1 == 0:
+        return a[tuple(slice(None) if ii != axis else int(idx) for ii in range(a.ndim))]
+    else:
+        lowerweight = 1-(idx % 1)
+        upperweight = (idx % 1)
+        idx = int(np.floor(idx))
+        return lowerweight * a[tuple(slice(None) if ii != axis else idx for ii in range(a.ndim))] \
+                + upperweight * a[tuple(slice(None) if ii != axis else idx+1 for ii in range(a.ndim))]
+
+def stateseq_hamming_error(sampledstates,truestates):
+    sampledstates = np.array(sampledstates,ndmin=2).copy()
+
+    errors = np.zeros(sampledstates.shape[0])
+    for idx,s in enumerate(sampledstates):
+        # match labels by maximum overlap
+        matching = match_by_overlap(s,truestates)
+        s2 = s.copy()
+        for i,j in matching:
+            s2[s==i] = j
+        errors[idx] = hamming_error(s2,truestates)
+
+    return errors if errors.shape[0] > 1 else errors[0]
+
+def _sieve(stream):
+    # just for fun; doesn't work over a few hundred
+    val = next(stream)
+    yield val
+    for x in (x for x in _sieve(stream) if x % val != 0):
+        yield x
+
+def primes():
+    return _sieve(count(2))
+
+def top_eigenvector(A,niter=1000,force_iteration=False):
+    '''
+    assuming the LEFT invariant subspace of A corresponding to the LEFT
+    eigenvalue of largest modulus has geometric multiplicity of 1 (trivial
+    Jordan block), returns the vector at the intersection of that eigenspace and
+    the simplex
+
+    A should probably be a ROW-stochastic matrix
+
+    probably uses power iteration
+    '''
+    n = A.shape[0]
+    np.seterr(invalid='raise',divide='raise')
+    if n <= 25 and not force_iteration:
+        x = np.repeat(1./n,n)
+        x = np.linalg.matrix_power(A.T,niter).dot(x)
+        x /= x.sum()
+        return x
+    else:
+        x1 = np.repeat(1./n,n)
+        x2 = x1.copy()
+        for itr in range(niter):
+            np.dot(A.T,x1,out=x2)
+            x2 /= x2.sum()
+            x1,x2 = x2,x1
+            if np.linalg.norm(x1-x2) < 1e-8:
+                break
+        return x1
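+
+# typical use (illustrative): for a row-stochastic transition matrix P,
+# top_eigenvector(P) approximates the stationary distribution pi satisfying
+# pi = pi P with pi.sum() == 1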
+
+def engine_global_namespace(f):
+    # see IPython.parallel.util.interactive; it's copied here so as to avoid
+    # extra imports/dependences elsewhere, and to provide a slightly clearer
+    # name
+    f.__module__ = '__main__'
+    return f
+
+def block_view(a,block_shape):
+    shape = (a.shape[0]//block_shape[0],a.shape[1]//block_shape[1]) + block_shape
+    strides = (a.strides[0]*block_shape[0],a.strides[1]*block_shape[1]) + a.strides
+    return ast(a,shape=shape,strides=strides)
+
+def AR_striding(data,nlags):
+    data = np.asarray(data)
+    if not data.flags.c_contiguous:
+        data = data.copy(order='C')
+    if data.ndim == 1:
+        data = np.reshape(data,(-1,1))
+    sz = data.dtype.itemsize
+    return ast(
+            data,
+            shape=(data.shape[0]-nlags,data.shape[1]*(nlags+1)),
+            strides=(data.shape[1]*sz,sz))
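+
+# AR_striding lays out lagged data for autoregression: for data of shape (T,D) it
+# returns a (T-nlags, D*(nlags+1)) strided view whose row t concatenates
+# data[t], data[t+1], ..., data[t+nlags]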
+
+def count_transitions(stateseq,minlength=None):
+    if minlength is None:
+        minlength = stateseq.max() + 1
+    out = np.zeros((minlength,minlength),dtype=np.int32)
+    for a,b in zip(stateseq[:-1],stateseq[1:]):
+        out[a,b] += 1
+    return out
+
+### SGD
+
+def sgd_steps(tau,kappa):
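+    # Robbins-Monro stepsize schedule rho_t = (t + tau)**(-kappa); kappa in
+    # (0.5, 1] gives sum(rho_t) = inf while sum(rho_t**2) < inf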
+    assert 0.5 < kappa <= 1 and tau >= 0
+    for t in count(1):
+        yield (t+tau)**(-kappa)
+
+def hold_out(datalist,frac):
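+    # randomly split datalist into (kept, held_out), holding out ceil(frac * N) items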
+    N = len(datalist)
+    perm = np.random.permutation(N)
+    split = int(np.ceil(frac * N))
+    return [datalist[i] for i in perm[split:]], [datalist[i] for i in perm[:split]]
+
+def sgd_passes(tau,kappa,datalist,minibatchsize=1,npasses=1):
+    N = len(datalist)
+
+    for superitr in range(npasses):
+        if minibatchsize == 1:
+            perm = np.random.permutation(N)
+            for idx, rho_t in zip(perm,sgd_steps(tau,kappa)):
+                yield datalist[idx], rho_t
+        else:
+            minibatch_indices = np.array_split(np.random.permutation(N),N/minibatchsize)
+            for indices, rho_t in zip(minibatch_indices,sgd_steps(tau,kappa)):
+                yield [datalist[idx] for idx in indices], rho_t
+
+def sgd_sampling(tau,kappa,datalist,minibatchsize=1):
+    N = len(datalist)
+    if minibatchsize == 1:
+        for rho_t in sgd_steps(tau,kappa):
+            minibatch_index = np.random.choice(N)
+            yield datalist[minibatch_index], rho_t
+    else:
+        for rho_t in sgd_steps(tau,kappa):
+            minibatch_indices = np.random.choice(N,size=minibatchsize,replace=False)
+            yield [datalist[idx] for idx in minibatch_indices], rho_t
+
+# TODO should probably eliminate this function
+def minibatchsize(lst):
+    return float(sum(d.shape[0] for d in lst))
+
+### misc
+
+def random_subset(lst,sz):
+    perm = np.random.permutation(len(lst))
+    return [lst[perm[idx]] for idx in range(sz)]
+
+def get_file(remote_url,local_path):
+    if not os.path.isfile(local_path):
+        with closing(urlopen(remote_url)) as remotefile:
+            with open(local_path,'wb') as localfile:
+                shutil.copyfileobj(remotefile,localfile)
+
+def list_split(lst,num):
+    assert num > 0
+    return [lst[start::num] for start in range(num)]
+
+def ndarrayhash(v):
+    assert isinstance(v,np.ndarray)
+    return hashlib.sha1(v).hexdigest()
+
+### numerical linear algebra
+
+def inv_psd(A, return_chol=False):
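+    # invert a positive definite matrix via its Cholesky factor; dpotri fills only
+    # the lower triangle, so the result is symmetrized afterwards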
+    L = np.linalg.cholesky(A)
+    Ainv = lapack.dpotri(L, lower=True)[0]
+    copy_lower_to_upper(Ainv)
+    # if not np.allclose(Ainv, np.linalg.inv(A), rtol=1e-5, atol=1e-5):
+    #     import ipdb; ipdb.set_trace()
+    if return_chol:
+        return Ainv, L
+    else:
+        return Ainv
+
+def solve_psd(A,b,chol=None,lower=True,overwrite_b=False,overwrite_A=False):
+    if chol is None:
+        return lapack.dposv(A,b,overwrite_b=overwrite_b,overwrite_a=overwrite_A)[1]
+    else:
+        return lapack.dpotrs(chol,b,lower,overwrite_b)[0]
+
+def copy_lower_to_upper(A):
+    A += np.tril(A,k=-1).T
+
+
+# NOTE: existing numpy object array construction acts a bit weird, e.g.
+# np.array([randn(3,4),randn(3,5)]) vs np.array([randn(3,4),randn(5,3)])
+# this wrapper class is just meant to ensure that when ndarrays of objects are
+# constructed the construction doesn't "recurse" as in the first example
+class ObjArray(np.ndarray):
+    def __new__(cls,lst):
+        if isinstance(lst,(np.ndarray,float,int)):
+            return lst
+        else:
+            return np.ndarray.__new__(cls,len(lst),dtype=object)
+
+    def __init__(self,lst):
+        if not isinstance(lst,(np.ndarray,float,int)):
+            for i, elt in enumerate(lst):
+                self[i] = self.__class__(elt)
+
+# Here's an alternative to ObjArray: just construct an obj array from a list
+def objarray(lst):
+    a = np.empty(len(lst), dtype=object)
+    for i,o in enumerate(lst):
+        a[i] = o
+    return a
+
+def all_none(*args):
+    return all(_ is None for _ in args)
+
+def any_none(*args):
+    return any(_ is None for _ in args)
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/plot.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb6e774a30d2d0d269a2914fac4dae00c6f5581
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/plot.py
@@ -0,0 +1,68 @@
+from __future__ import division
+from builtins import range
+import numpy as np
+from matplotlib import pyplot as plt
+
+def plot_gaussian_2D(mu, lmbda, color='b', centermarker=True,label='',alpha=1.,ax=None,artists=None):
+    '''
+    Plots mean and cov ellipsoid into current axes. Must be 2D. lmbda is a covariance matrix.
+    '''
+    assert len(mu) == 2
+    ax = ax if ax else plt.gca()
+
+    # TODO if update alpha=0. and our previous alpha is 0., we don't need to
+    # dirty the artist
+
+    t = np.hstack([np.arange(0,2*np.pi,0.01),0])
+    circle = np.vstack([np.sin(t),np.cos(t)])
+    ellipse = np.dot(np.linalg.cholesky(lmbda),circle)
+
+    if artists is None:
+        point = ax.scatter([mu[0]],[mu[1]],marker='D',color=color,s=4,alpha=alpha) \
+                if centermarker else None
+        line, = ax.plot(ellipse[0,:] + mu[0], ellipse[1,:] + mu[1],linestyle='-',
+                linewidth=2,color=color,label=label,alpha=alpha)
+    else:
+        line, point = artists
+        if centermarker:
+            point.set_offsets(np.atleast_2d(mu))
+            point.set_alpha(alpha)
+            point.set_color(color)
+        line.set_xdata(ellipse[0,:] + mu[0])
+        line.set_ydata(ellipse[1,:] + mu[1])
+        line.set_alpha(alpha)
+        line.set_color(color)
+
+    return (line, point) if point else (line,)
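+
+# typical use (illustrative): artists = plot_gaussian_2D(mu, Sigma); on subsequent
+# frames pass artists=artists to update the existing line/point in place instead
+# of drawing new ones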
+
+
+def plot_gaussian_projection(mu, lmbda, vecs, **kwargs):
+    '''
+    Plots an n-dimensional Gaussian projected onto 2D vecs, where vecs is a matrix
+    whose two columns are a subset of some orthonormal basis (e.g. from PCA on samples).
+    '''
+    return plot_gaussian_2D(project_data(mu,vecs),project_ellipsoid(lmbda,vecs),**kwargs)
+
+
+def pca_project_data(data,num_components=2):
+    # convenience combination of the next two functions
+    return project_data(data,pca(data,num_components=num_components))
+
+
+def pca(data,num_components=2):
+    U,s,Vh = np.linalg.svd(data - np.mean(data,axis=0))
+    return Vh.T[:,:num_components]
+
+
+def project_data(data,vecs):
+    return np.dot(data,vecs.T)
+
+
+def project_ellipsoid(ellipsoid,vecs):
+    # vecs is a matrix whose columns are a subset of an orthonormal basis
+    # ellipsoid is a pos def matrix
+    return np.dot(vecs,np.dot(ellipsoid,vecs.T))
+
+
+def subplot_gridsize(num):
+    return sorted(min([(x,int(np.ceil(num/x))) for x in range(1,int(np.floor(np.sqrt(num)))+1)],key=sum))
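+
+# e.g. subplot_gridsize(5) == [2, 3]: the most compact (smallest rows + cols)
+# grid with at least 5 cells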
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/profiling.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..6666b170421f5bd466f651a6754e151a085cfc75
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/profiling.py
@@ -0,0 +1,53 @@
+from __future__ import division
+from __future__ import print_function
+from future import standard_library
+standard_library.install_aliases()
+import numpy as np
+import sys, io, inspect, os, functools, time, collections
+
+### use @timed for really basic timing
+
+_timings = collections.defaultdict(list)
+
+def timed(func):
+    @functools.wraps(func)
+    def wrapped(*args,**kwargs):
+        tic = time.time()
+        out = func(*args,**kwargs)
+        _timings[func].append(time.time() - tic)
+        return out
+    return wrapped
+
+def show_timings(stream=None):
+    if stream is None:
+        stream = sys.stdout
+    if len(_timings) > 0:
+        results = [(inspect.getsourcefile(f),f.__name__,
+            len(vals),np.sum(vals),np.mean(vals),np.std(vals))
+            for f, vals in _timings.items()]
+        filename_lens = max(len(filename) for filename, _, _, _, _, _ in results)
+        name_lens = max(len(name) for _, name, _, _, _, _ in results)
+
+        fmt = '{:>%d} {:>%d} {:>10} {:>10} {:>10} {:>10}' % (filename_lens, name_lens)
+        print(fmt.format('file','name','ncalls','tottime','avg time','std dev'), file=stream)
+
+        fmt = '{:>%d} {:>%d} {:>10} {:>10.3} {:>10.3} {:>10.3}' % (filename_lens, name_lens)
+        print('\n'.join(fmt.format(*tup) for tup in sorted(results)), file=stream)
+
+### use @line_profiled for a thin wrapper around line_profiler
+
+try:
+    import line_profiler
+    _prof = line_profiler.LineProfiler()
+
+    def line_profiled(func):
+        mod = inspect.getmodule(func)
+        if 'PROFILING' in os.environ or (hasattr(mod,'PROFILING') and mod.PROFILING):
+            return _prof(func)
+        return func
+
+    def show_line_stats(stream=None):
+        _prof.print_stats(stream=stream)
+except ImportError:
+    line_profiled = lambda x: x
+
+    def show_line_stats(stream=None):
+        # no-op fallback so callers don't fail when line_profiler is unavailable
+        pass
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/stats.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a133d4ff2725023f9f61e92a52aa95498ddb402
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/stats.py
@@ -0,0 +1,365 @@
+from __future__ import division
+from __future__ import absolute_import
+from builtins import range
+import numpy as np
+from numpy.random import random
+na = np.newaxis
+import scipy.stats as stats
+import scipy.special as special
+import scipy.linalg
+from scipy.special import logsumexp
+try:
+    from numpy.core.umath_tests import inner1d
+except ImportError:
+    # inner1d is unavailable in recent numpy releases; this einsum is equivalent
+    def inner1d(a, b):
+        return np.einsum('...i,...i->...', a, b)
+
+from .general import any_none, blockarray
+
+### data abstraction
+
+# the data type is ndarrays OR lists of ndarrays
+# type Data = ndarray | [ndarray]
+
+def atleast_2d(data):
+    # NOTE: can't use np.atleast_2d because if it's 1D we want axis 1 to be the
+    # singleton and axis 0 to be the sequence index
+    if data.ndim == 1:
+        return data.reshape((-1,1))
+    return data
+
+def mask_data(data):
+    return np.ma.masked_array(
+        np.nan_to_num(data),np.isnan(data),fill_value=0.,hard_mask=True)
+
+def gi(data):
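+    # "good indices": boolean mask of the rows that contain no NaNs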
+    out = (np.isnan(atleast_2d(data)).sum(1) == 0).ravel()
+    return out if len(out) != 0 else None
+
+def getdatasize(data):
+    if isinstance(data,np.ma.masked_array):
+        return data.shape[0] - data.mask.reshape((data.shape[0],-1))[:,0].sum()
+    elif isinstance(data,np.ndarray):
+        if len(data) == 0:
+            return 0
+        return data[gi(data)].shape[0]
+    elif isinstance(data,list):
+        return sum(getdatasize(d) for d in data)
+    else:
+        # handle unboxed case for convenience
+        assert isinstance(data,int) or isinstance(data,float)
+        return 1
+
+def getdatadimension(data):
+    if isinstance(data,np.ndarray):
+        assert data.ndim > 1
+        return data.shape[1]
+    elif isinstance(data,list):
+        assert len(data) > 0
+        return getdatadimension(data[0])
+    else:
+        # handle unboxed case for convenience
+        assert isinstance(data,int) or isinstance(data,float)
+        return 1
+
+def combinedata(datas):
+    ret = []
+    for data in datas:
+        if isinstance(data,np.ma.masked_array):
+            ret.append(np.ma.compress_rows(data))
+        elif isinstance(data,np.ndarray):
+            ret.append(data)
+        elif isinstance(data,list):
+            ret.extend(combinedata(data))
+        else:
+            # handle unboxed case for convenience
+            assert isinstance(data,int) or isinstance(data,float)
+            ret.append(np.atleast_1d(data))
+    return ret
+
+def flattendata(data):
+    # data is either an array (possibly a maskedarray) or a list of arrays
+    if isinstance(data,np.ndarray):
+        return data
+    elif isinstance(data,list) or isinstance(data,tuple):
+        if any(isinstance(d,np.ma.MaskedArray) for d in data):
+            return np.concatenate([np.ma.compress_rows(d) for d in data])
+        else:
+            return np.concatenate(data)
+    else:
+        # handle unboxed case for convenience
+        assert isinstance(data,int) or isinstance(data,float)
+        return np.atleast_1d(data)
+
+### misc
+def update_param(oldv, newv, stepsize):
+    return oldv * (1 - stepsize) + newv * stepsize
+
+
+def cov(a):
+    # return np.cov(a,rowvar=0,bias=1)
+    mu = a.mean(0)
+    if isinstance(a,np.ma.MaskedArray):
+        return np.ma.dot(a.T,a)/a.count(0)[0] - np.ma.outer(mu,mu)
+    else:
+        return a.T.dot(a)/a.shape[0] - np.outer(mu,mu)
+
+def normal_cdf(x, mu=0.0, sigma=1.0):
+    z = (x - mu) / sigma
+    return 0.5 * special.erfc(-z / np.sqrt(2))
+
+
+### Sampling functions
+
+def sample_gaussian(mu=None,Sigma=None,J=None,h=None):
+    mean_params = mu is not None and Sigma is not None
+    info_params = J is not None and h is not None
+    assert mean_params or info_params
+
+    if not any_none(mu,Sigma):
+        return np.random.multivariate_normal(mu,Sigma)
+    else:
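+        # information form: J is the precision and h the potential, so the sample
+        # is J^{-1} h plus a zero-mean draw with covariance J^{-1}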
+        from scipy.linalg.lapack import dpotrs
+        L = np.linalg.cholesky(J)
+        x = np.random.randn(h.shape[0])
+        return scipy.linalg.solve_triangular(L,x,lower=True,trans='T') \
+            + dpotrs(L,h,lower=True)[0]
+
+def sample_truncated_gaussian(mu=0, sigma=1, lb=-np.inf, ub=np.inf):
+    """
+    Sample a truncated normal with the specified params. This
+    is not the most stable way but it works as long as the
+    truncation region is not too far from the mean.
+    """
+    # Broadcast arrays to be of the same shape
+    mu, sigma, lb, ub = np.broadcast_arrays(mu, sigma, lb, ub)
+    shp = mu.shape
+    if np.allclose(sigma, 0.0):
+        return mu
+
+    cdflb = normal_cdf(lb, mu, sigma)
+    cdfub = normal_cdf(ub, mu, sigma)
+
+    # Sample uniformly from the CDF
+    cdfsamples = cdflb + np.random.rand(*shp) * (cdfub-cdflb)
+
+    # Clip the CDF samples so that we can invert them
+    cdfsamples = np.clip(cdfsamples, 1e-15, 1-1e-15)
+    zs = -np.sqrt(2) * special.erfcinv(2 * cdfsamples)
+
+    # Transform the standard normal samples
+    xs = sigma * zs + mu
+    xs = np.clip(xs, lb, ub)
+
+    return xs
+
+def sample_discrete(distn,size=[],dtype=np.int32):
+    'samples from a one-dimensional finite pmf'
+    distn = np.atleast_1d(distn)
+    assert (distn >=0).all() and distn.ndim == 1
+    if (0 == distn).all():
+        return np.random.randint(distn.shape[0],size=size)
+    cumvals = np.cumsum(distn)
+    return np.sum(np.array(random(size))[...,na] * cumvals[-1] > cumvals, axis=-1,dtype=dtype)
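+
+# e.g. sample_discrete([0.2,0.5,0.3], size=10) draws 10 indices in {0,1,2}; the
+# weights need not sum to one since the cumulative sum is rescaled internally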
+
+def sample_discrete_from_log(p_log,return_lognorms=False,axis=0,dtype=np.int32):
+    'samples from an array of log probabilities along the specified axis'
+    lognorms = logsumexp(p_log,axis=axis)
+    cumvals = np.exp(p_log - np.expand_dims(lognorms,axis)).cumsum(axis)
+    thesize = np.array(p_log.shape)
+    thesize[axis] = 1
+    randvals = random(size=thesize) * \
+            np.reshape(cumvals[tuple(slice(None) if i != axis else -1
+                for i in range(p_log.ndim))],thesize)
+    samples = np.sum(randvals > cumvals,axis=axis,dtype=dtype)
+    if return_lognorms:
+        return samples, lognorms
+    else:
+        return samples
+
+def sample_markov(T,trans_matrix,init_state_distn):
+    out = np.empty(T,dtype=np.int32)
+    out[0] = sample_discrete(init_state_distn)
+    for t in range(1,T):
+        out[t] = sample_discrete(trans_matrix[out[t-1]])
+    return out
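+
+# simulates a length-T Markov chain: the initial state comes from init_state_distn
+# and each subsequent state is drawn from the row of trans_matrix indexed by the
+# previous state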
+
+def sample_invgamma(alpha, beta):
+    return 1./np.random.gamma(alpha, 1./beta)
+
+def niw_expectedstats(nu, S, m, kappa):
+    D = m.shape[0]
+
+    # TODO speed this up with cholesky of S
+    E_J = nu * np.linalg.inv(S)
+    E_h = nu * np.linalg.solve(S,m)
+    E_muJmuT = D/kappa + m.dot(E_h)
+    E_logdetSigmainv = special.digamma((nu-np.arange(D))/2.).sum() \
+        + D*np.log(2.) - np.linalg.slogdet(S)[1]
+
+    return E_J, E_h, E_muJmuT, E_logdetSigmainv
+
+
+def sample_niw(mu,lmbda,kappa,nu):
+    '''
+    Returns a sample (mu, Sigma) from the normal/inverse-Wishart distribution,
+    the conjugate prior for (simultaneously) unknown mean and unknown covariance
+    in a Gaussian likelihood model. Sigma is a covariance (not a precision) matrix.
+    '''
+    # code is based on Matlab's method
+    # reference: p. 87 in Gelman's Bayesian Data Analysis
+    assert nu > lmbda.shape[0] and kappa > 0
+
+    # first sample Sigma ~ IW(lmbda,nu)
+    lmbda = sample_invwishart(lmbda,nu)
+    # then sample mu | Lambda ~ N(mu, Lambda/kappa)
+    mu = np.random.multivariate_normal(mu,lmbda / kappa)
+
+    return mu, lmbda
+
+def sample_invwishart(S,nu):
+    # TODO make a version that returns the cholesky
+    # TODO allow passing in chol/cholinv of matrix parameter lmbda
+    # TODO lowmem! memoize! dchud (eigen?)
+    n = S.shape[0]
+    chol = np.linalg.cholesky(S)
+
+    if (nu <= 81+n) and (nu == np.round(nu)):
+        x = np.random.randn(int(nu),n)
+    else:
+        x = np.diag(np.sqrt(np.atleast_1d(stats.chi2.rvs(nu-np.arange(n)))))
+        x[np.triu_indices_from(x,1)] = np.random.randn(n*(n-1)//2)
+    R = np.linalg.qr(x,'r')
+    T = scipy.linalg.solve_triangular(R.T,chol.T,lower=True).T
+    return np.dot(T,T.T)
+
+def sample_wishart(sigma, nu):
+    n = sigma.shape[0]
+    chol = np.linalg.cholesky(sigma)
+
+    # use matlab's heuristic for choosing between the two different sampling schemes
+    if (nu <= 81+n) and (nu == round(nu)):
+        # direct
+        X = np.dot(chol,np.random.normal(size=(n,nu)))
+    else:
+        A = np.diag(np.sqrt(np.random.chisquare(nu - np.arange(n))))
+        A[np.tri(n,k=-1,dtype=bool)] = np.random.normal(size=n*(n-1)//2)
+        X = np.dot(chol,A)
+
+    return np.dot(X,X.T)
+
+def sample_mn(M, U=None, Uinv=None, V=None, Vinv=None):
+    assert (U is None) ^ (Uinv is None)
+    assert (V is None) ^ (Vinv is None)
+
+    G = np.random.normal(size=M.shape)
+
+    if U is not None:
+        G = np.dot(np.linalg.cholesky(U),G)
+    else:
+        G = np.linalg.solve(np.linalg.cholesky(Uinv).T,G)
+
+    if V is not None:
+        G = np.dot(G,np.linalg.cholesky(V).T)
+    else:
+        G = np.linalg.solve(np.linalg.cholesky(Vinv).T,G.T).T
+
+    return M + G
+
+def sample_mniw(nu, S, M, K=None, Kinv=None):
+    assert (K is None) ^ (Kinv is None)
+    Sigma = sample_invwishart(S,nu)
+    if K is not None:
+        return sample_mn(M=M,U=Sigma,V=K), Sigma
+    else:
+        return sample_mn(M=M,U=Sigma,Vinv=Kinv), Sigma
+
+def mniw_expectedstats(nu, S, M, K=None, Kinv=None):
+    # NOTE: could speed this up with chol factorizing S, not re-solving
+    assert (K is None) ^ (Kinv is None)
+    m = M.shape[0]
+    K = K if K is not None else np.linalg.inv(Kinv)
+
+    E_Sigmainv = nu*np.linalg.inv(S)
+    E_Sigmainv_A = nu*np.linalg.solve(S,M)
+    E_AT_Sigmainv_A = m*K + nu*M.T.dot(np.linalg.solve(S,M))
+    E_logdetSigmainv = special.digamma((nu-np.arange(m))/2.).sum() \
+        + m*np.log(2) - np.linalg.slogdet(S)[1]
+
+    return E_Sigmainv, E_Sigmainv_A, E_AT_Sigmainv_A, E_logdetSigmainv
+
+def mniw_log_partitionfunction(nu, S, M, K):
+    n = M.shape[0]
+    return n*nu/2*np.log(2) + special.multigammaln(nu/2., n) \
+        - nu/2*np.linalg.slogdet(S)[1] - n/2*np.linalg.slogdet(K)[1]
+
+def sample_pareto(x_m,alpha):
+    return x_m + np.random.pareto(alpha)
+
+def sample_crp_tablecounts(concentration,customers,colweights):
+    m = np.zeros_like(customers)
+    tot = customers.sum()
+    randseq = np.random.random(tot)
+
+    starts = np.empty_like(customers)
+    starts[0,0] = 0
+    starts.flat[1:] = np.cumsum(np.ravel(customers)[:customers.size-1])
+
+    for (i,j), n in np.ndenumerate(customers):
+        w = colweights[j]
+        for k in range(n):
+            m[i,j] += randseq[starts[i,j]+k] \
+                    < (concentration * w) / (k + concentration * w)
+
+    return m
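+
+# for each cell (i,j), m[i,j] counts the tables occupied when customers[i,j]
+# customers are seated by a Chinese restaurant process whose new-table probability
+# for the k-th customer is (concentration*colweights[j]) / (k + concentration*colweights[j])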
+
+### Entropy
+def invwishart_entropy(sigma,nu,chol=None):
+    D = sigma.shape[0]
+    chol = np.linalg.cholesky(sigma) if chol is None else chol
+    Elogdetlmbda = special.digamma((nu-np.arange(D))/2).sum() + D*np.log(2) - 2*np.log(chol.diagonal()).sum()
+    return invwishart_log_partitionfunction(sigma,nu,chol)-(nu-D-1)/2*Elogdetlmbda + nu*D/2
+
+def invwishart_log_partitionfunction(sigma,nu,chol=None):
+    # In Bishop B.79 notation, this is -log B(W, nu), where W = sigma^{-1}
+    D = sigma.shape[0]
+    chol = np.linalg.cholesky(sigma) if chol is None else chol
+    return -1*(nu*np.log(chol.diagonal()).sum() - (nu*D/2*np.log(2) + D*(D-1)/4*np.log(np.pi) \
+            + special.gammaln((nu-np.arange(D))/2).sum()))
+
+### Predictive
+
+def multivariate_t_loglik(y,nu,mu,lmbda):
+    # log density of a multivariate Student-t with dof nu, location mu, and scale matrix lmbda
+    d = len(mu)
+    yc = np.array(y-mu,ndmin=2)
+    L = np.linalg.cholesky(lmbda)
+    ys = scipy.linalg.solve_triangular(L,yc.T,overwrite_b=True,lower=True)
+    return scipy.special.gammaln((nu+d)/2.) - scipy.special.gammaln(nu/2.) \
+            - (d/2.)*np.log(nu*np.pi) - np.log(L.diagonal()).sum() \
+            - (nu+d)/2.*np.log1p(1./nu*inner1d(ys.T,ys.T))
+
+def beta_predictive(priorcounts,newcounts):
+    prior_nsuc, prior_nfail = priorcounts
+    nsuc, nfail = newcounts
+
+    numer = scipy.special.gammaln(np.array([nsuc+prior_nsuc,
+        nfail+prior_nfail, prior_nsuc+prior_nfail])).sum()
+    denom = scipy.special.gammaln(np.array([prior_nsuc, prior_nfail,
+        prior_nsuc+prior_nfail+nsuc+nfail])).sum()
+    return numer - denom
+
+### Statistical tests
+
+def two_sample_t_statistic(pop1, pop2):
+    pop1, pop2 = (flattendata(p) for p in (pop1, pop2))
+    t = (pop1.mean(0) - pop2.mean(0)) / np.sqrt(pop1.var(0)/pop1.shape[0] + pop2.var(0)/pop2.shape[0])
+    p = 2*stats.t.sf(np.abs(t),np.minimum(pop1.shape[0],pop2.shape[0]))
+    return t,p
+
+def f_statistic(pop1, pop2): # TODO test
+    pop1, pop2 = (flattendata(p) for p in (pop1, pop2))
+    var1, var2 = pop1.var(0), pop2.var(0)
+    n1, n2 = np.where(var1 >= var2, pop1.shape[0], pop2.shape[0]), \
+             np.where(var1 >= var2, pop2.shape[0], pop1.shape[0])
+    var1, var2 = np.maximum(var1,var2), np.minimum(var1,var2)
+    f = var1 / var2
+    p = stats.f.sf(f,n1,n2)
+    return f,p
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/testing.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..73f772595310039efe8b9567282b9733c43c22e4
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/testing.py
@@ -0,0 +1,99 @@
+from __future__ import division
+from __future__ import absolute_import
+from builtins import zip
+import numpy as np
+from numpy import newaxis as na
+
+from . import stats, general
+
+#########################
+#  statistical testing  #
+#########################
+
+### graphical
+
+def populations_eq_quantile_plot(pop1, pop2, fig=None, percentilecutoff=5):
+    import matplotlib.pyplot as plt
+
+    pop1, pop2 = stats.flattendata(pop1), stats.flattendata(pop2)
+    assert pop1.ndim == pop2.ndim == 1 or \
+            (pop1.ndim == pop2.ndim == 2 and pop1.shape[1] == pop2.shape[1]), \
+            'populations must have consistent dimensions'
+    D = pop1.shape[1] if pop1.ndim == 2 else 1
+
+    # we want to have the same number of samples
+    n1, n2 = pop1.shape[0], pop2.shape[0]
+    if n1 != n2:
+        # subsample, since interpolation is dangerous
+        if n1 < n2:
+            pop1, pop2 = pop2, pop1
+        np.random.shuffle(pop1)
+        pop1 = pop1[:pop2.shape[0]]
+
+    def plot_1d_scaled_quantiles(p1,p2,plot_midline=True):
+
+        # scaled quantiles so that multiple calls line up
+        p1.sort(), p2.sort() # NOTE: destructive! but that's cool
+        xmin,xmax = general.scoreatpercentile(p1,percentilecutoff), \
+                    general.scoreatpercentile(p1,100-percentilecutoff)
+        ymin,ymax = general.scoreatpercentile(p2,percentilecutoff), \
+                    general.scoreatpercentile(p2,100-percentilecutoff)
+        plt.plot((p1-xmin)/(xmax-xmin),(p2-ymin)/(ymax-ymin))
+
+        if plot_midline:
+            plt.plot((0,1),(0,1),'k--')
+        plt.axis((0,1,0,1))
+
+    if D == 1:
+        if fig is None:
+            plt.figure()
+        plot_1d_scaled_quantiles(pop1,pop2)
+    else:
+        if fig is None:
+            fig = plt.figure()
+
+        if not hasattr(fig,'_quantile_test_projs'):
+            firsttime = True
+            randprojs = np.random.randn(D,D)
+            randprojs /= np.sqrt(np.sum(randprojs**2,axis=1))[:,na]
+            projs = np.vstack((np.eye(D),randprojs))
+            fig._quantile_test_projs = projs
+        else:
+            firsttime = False
+            projs = fig._quantile_test_projs
+
+        ims1, ims2 = pop1.dot(projs.T), pop2.dot(projs.T)
+        for i, (im1, im2) in enumerate(zip(ims1.T,ims2.T)):
+            plt.subplot(2,D,i+1)
+            plot_1d_scaled_quantiles(im1,im2,plot_midline=firsttime)
+
+### numerical
+
+# NOTE: a random numerical test should be repeated at the OUTERMOST loop (with
+# exception catching) to see if its failures exceed the number expected
+# according to the specified pvalue (tests could be repeated via sample
+# bootstrapping inside the test, but that doesn't work reliably and random tests
+# should have no problem generating new randomness!)
+
+def assert_populations_eq(pop1, pop2):
+    # NOTE: call these as separate statements; each returns None, so 'and'-chaining
+    # would skip the second check
+    assert_populations_eq_moments(pop1,pop2)
+    assert_populations_eq_komolgorofsmirnov(pop1,pop2)
+
+def assert_populations_eq_moments(pop1, pop2, **kwargs):
+    # just first two moments implemented; others are hard to estimate anyway!
+    assert_populations_eq_means(pop1,pop2,**kwargs)
+    assert_populations_eq_variances(pop1,pop2,**kwargs)
+
+def assert_populations_eq_means(pop1, pop2, pval=0.05, msg=None):
+    _,p = stats.two_sample_t_statistic(pop1,pop2)
+    if np.any(p < pval):
+        raise AssertionError(msg or "population means might be different at %0.3f" % pval)
+
+def assert_populations_eq_variances(pop1, pop2, pval=0.05, msg=None):
+    _,p = stats.f_statistic(pop1, pop2)
+    if np.any(p < pval):
+        raise AssertionError(msg or "population variances might be different at %0.3f" % pval)
+
+def assert_populations_eq_komolgorofsmirnov(pop1, pop2, msg=None):
+    raise NotImplementedError # TODO
+
diff --git a/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/text.py b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcddd4a0b42304d6896c2e433d8c07ab714b185f
--- /dev/null
+++ b/build/lib.linux-x86_64-cpython-37/pybasicbayes/util/text.py
@@ -0,0 +1,59 @@
+from __future__ import print_function
+from builtins import range
+import numpy as np
+import sys, time
+
+# time.clock() is cpu time of current process
+# time.time() is wall time
+
+# TODO there are probably better progress bar libraries I could use
+
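+# rebind round so it always returns an int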
+round = (lambda x: lambda y: int(x(y)))(round)
+
+# NOTE: datetime.timedelta.__str__ doesn't allow formatting the number of digits
+def sec2str(seconds):
+    hours, rem = divmod(seconds,3600)
+    minutes, seconds = divmod(rem,60)
+    if hours > 0:
+        return '%02d:%02d:%02d' % (hours,minutes,round(seconds))
+    elif minutes > 0:
+        return '%02d:%02d' % (minutes,round(seconds))
+    else:
+        return '%0.2f' % seconds
+
+def progprint_xrange(*args,**kwargs):
+    xr = range(*args)
+    return progprint(xr,total=len(xr),**kwargs)
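+
+# typical use (illustrative):
+#   for itr in progprint_xrange(100):
+#       do_iteration()   # any per-iteration work
+# a dot is printed per iteration and, every `perline` iterations, the running
+# average time per iteration plus an ETA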
+
+def progprint(iterator,total=None,perline=25,show_times=True):
+    times = []
+    idx = 0
+    if total is not None:
+        numdigits = len('%d' % total)
+    for thing in iterator:
+        prev_time = time.time()
+        yield thing
+        times.append(time.time() - prev_time)
+        sys.stdout.write('.')
+        if (idx+1) % perline == 0:
+            if show_times:
+                avgtime = np.mean(times)
+                if total is not None:
+                    eta = sec2str(avgtime*(total-(idx+1)))
+                    sys.stdout.write((
+                        '  [ %%%dd/%%%dd, %%7.2fsec avg, ETA %%s ]\n'
+                                % (numdigits,numdigits)) % (idx+1,total,avgtime,eta))
+                else:
+                    sys.stdout.write('  [ %d done, %7.2fsec avg ]\n' % (idx+1,avgtime))
+            else:
+                if total is not None:
+                    sys.stdout.write(('  [ %%%dd/%%%dd ]\n' % (numdigits,numdigits) ) % (idx+1,total))
+                else:
+                    sys.stdout.write('  [ %d ]\n' % (idx+1))
+        idx += 1
+        sys.stdout.flush()
+    print('')
+    if show_times and len(times) > 0:
+        total = sec2str(seconds=np.sum(times))
+        print('%7.2fsec avg, %s total\n' % (np.mean(times),total))
+
diff --git a/build/temp.linux-x86_64-cpython-37/pybasicbayes/util/cstats.o b/build/temp.linux-x86_64-cpython-37/pybasicbayes/util/cstats.o
new file mode 100644
index 0000000000000000000000000000000000000000..9387a5d9c71dfce414545f1c543831d4491afd69
Binary files /dev/null and b/build/temp.linux-x86_64-cpython-37/pybasicbayes/util/cstats.o differ
diff --git a/pybasicbayes.egg-info/PKG-INFO b/pybasicbayes.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..ca4f4ded4e714bac85bdf176f18f1da59e5f52f1
--- /dev/null
+++ b/pybasicbayes.egg-info/PKG-INFO
@@ -0,0 +1,12 @@
+Metadata-Version: 2.1
+Name: pybasicbayes
+Version: 0.2.4
+Summary: Basic utilities for Bayesian inference
+Home-page: http://github.com/mattjj/pybasicbayes
+Author: Matthew James Johnson
+Author-email: mattjj@csail.mit.edu
+Keywords: bayesian,inference,mcmc,variational inference,mean field,vb
+Platform: ALL
+Classifier: Intended Audience :: Science/Research
+Classifier: Programming Language :: Python
+License-File: LICENSE-MIT
diff --git a/pybasicbayes.egg-info/SOURCES.txt b/pybasicbayes.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5d6133d6a35f50c046a1fd67342fe2ecbc13ce15
--- /dev/null
+++ b/pybasicbayes.egg-info/SOURCES.txt
@@ -0,0 +1,40 @@
+LICENSE-MIT
+MANIFEST.in
+README.md
+setup.py
+pybasicbayes/__init__.py
+pybasicbayes/abstractions.py
+pybasicbayes.egg-info/PKG-INFO
+pybasicbayes.egg-info/SOURCES.txt
+pybasicbayes.egg-info/dependency_links.txt
+pybasicbayes.egg-info/requires.txt
+pybasicbayes.egg-info/top_level.txt
+pybasicbayes/distributions/__init__.py
+pybasicbayes/distributions/binomial.py
+pybasicbayes/distributions/dynamic_glm.py
+pybasicbayes/distributions/dynamic_multinomial.py
+pybasicbayes/distributions/dynamic_multinomial_in_progress.py
+pybasicbayes/distributions/gaussian.py
+pybasicbayes/distributions/geometric.py
+pybasicbayes/distributions/meta.py
+pybasicbayes/distributions/multinomial.py
+pybasicbayes/distributions/negativebinomial.py
+pybasicbayes/distributions/poisson.py
+pybasicbayes/distributions/regression.py
+pybasicbayes/distributions/uniform.py
+pybasicbayes/models/__init__.py
+pybasicbayes/models/factor_analysis.py
+pybasicbayes/models/mixture.py
+pybasicbayes/models/parallel_mixture.py
+pybasicbayes/testing/__init__.py
+pybasicbayes/testing/mixins.py
+pybasicbayes/util/__init__.py
+pybasicbayes/util/cstats.c
+pybasicbayes/util/cstats.pyx
+pybasicbayes/util/cyutil.py
+pybasicbayes/util/general.py
+pybasicbayes/util/plot.py
+pybasicbayes/util/profiling.py
+pybasicbayes/util/stats.py
+pybasicbayes/util/testing.py
+pybasicbayes/util/text.py
\ No newline at end of file
diff --git a/pybasicbayes.egg-info/dependency_links.txt b/pybasicbayes.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/pybasicbayes.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/pybasicbayes.egg-info/requires.txt b/pybasicbayes.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..282c01825ff3431ee578440f385edc6925c77f55
--- /dev/null
+++ b/pybasicbayes.egg-info/requires.txt
@@ -0,0 +1,5 @@
+numpy
+scipy
+matplotlib
+nose
+future
diff --git a/pybasicbayes.egg-info/top_level.txt b/pybasicbayes.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c3d06054ef9ee2ef2afc5d34ca25aecef210321c
--- /dev/null
+++ b/pybasicbayes.egg-info/top_level.txt
@@ -0,0 +1 @@
+pybasicbayes
diff --git a/pybasicbayes/distributions/multinomial.py b/pybasicbayes/distributions/multinomial.py
index 477675be59a274e5ee3018fb7e899b49fe4d3577..f779c84ae757fc74bf01a29da14d55646cc6a1de 100644
--- a/pybasicbayes/distributions/multinomial.py
+++ b/pybasicbayes/distributions/multinomial.py
@@ -111,12 +111,12 @@ class Categorical(GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood, MAP):
             self.weights = np.random.dirichlet(self.alphav_0 + counts)
         except ZeroDivisionError as e:
             # print("ZeroDivisionError {}".format(e))
-            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts)
         except ValueError as e:
             # print("ValueError {}".format(e))
-            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts)
         if np.isnan(self.weights).any():
-            self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts)
+            self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts)
         np.clip(self.weights, np.spacing(1.), np.inf, out=self.weights)
         # NOTE: next line is so we can use Gibbs sampling to initialize mean field
         self._alpha_mf = self.weights * self.alphav_0.sum()