From 0a1f804f38d8efc5d11212961689b8ff102abd43 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Thu, 15 Sep 2022 14:21:12 +0200
Subject: [PATCH] my changes

---
 pyhsmm/internals/hmm_states.py    |  10 +-
 pyhsmm/internals/hsmm_states.py   | 201 ++++++++++++++++++++++++++++--
 pyhsmm/internals/initial_state.py |   2 +-
 pyhsmm/internals/transitions.py   |   5 +-
 pyhsmm/models.py                  |  72 +++++++++--
 5 files changed, 260 insertions(+), 30 deletions(-)

diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py
index c5dfb7f..e6ff225 100644
--- a/pyhsmm/internals/hmm_states.py
+++ b/pyhsmm/internals/hmm_states.py
@@ -5,7 +5,7 @@ import numpy as np
 from numpy import newaxis as na
 import abc
 import copy
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
 from pyhsmm.util.stats import sample_discrete
 
 try:
@@ -21,11 +21,12 @@ from pyhsmm.util.general import rle
 
 class _StatesBase(with_metaclass(abc.ABCMeta, object)):
     def __init__(self,model,T=None,data=None,stateseq=None,
-            generate=True,initialize_from_prior=True, fixed_stateseq=False):
+            generate=True,initialize_from_prior=True, fixed_stateseq=False, timepoint=0):
         self.model = model
 
         self.T = T if T is not None else data.shape[0]
         self.data = data
+        self.timepoint = timepoint
 
         self.clear_caches()
 
@@ -100,7 +101,10 @@ class _StatesBase(with_metaclass(abc.ABCMeta, object)):
             aBl = self._aBl = np.empty((data.shape[0],self.num_states))
             for idx, obs_distn in enumerate(self.obs_distns):
-                aBl[:,idx] = obs_distn.log_likelihood(data).ravel()
+                if self.model.var_prior is None:
+                    aBl[:,idx] = obs_distn.log_likelihood(data).ravel()
+                else:
+                    aBl[:,idx] = obs_distn.log_likelihood(data, self.timepoint).ravel()
             aBl[np.isnan(aBl).any(1)] = 0.
         return self._aBl
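
Side note, not part of the patch: the else branch above assumes observation distributions whose log_likelihood accepts a timepoint argument; that side of the interface lives in this fork's distribution classes and is not shown here. A minimal sketch of what such a distribution could look like (the class name and internals are hypothetical, not pyhsmm or pybasicbayes API):

    import numpy as np

    class TimeVaryingGaussian1D(object):
        # hypothetical obs distn whose mean can drift across sessions/timepoints
        def __init__(self, means_per_timepoint, sigmasq):
            self.means = np.asarray(means_per_timepoint)  # one mean per timepoint
            self.sigmasq = sigmasq

        def log_likelihood(self, data, timepoint=0):
            # matches the patched call obs_distn.log_likelihood(data, self.timepoint)
            mu = self.means[timepoint]
            return -0.5 * ((np.asarray(data) - mu) ** 2 / self.sigmasq
                           + np.log(2 * np.pi * self.sigmasq))
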
diff --git a/pyhsmm/internals/hsmm_states.py b/pyhsmm/internals/hsmm_states.py
index b06b831..6a0505e 100644
--- a/pyhsmm/internals/hsmm_states.py
+++ b/pyhsmm/internals/hsmm_states.py
@@ -2,10 +2,13 @@ from __future__ import division
 from builtins import range, map
 import numpy as np
 from numpy import newaxis as na
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
 from pyhsmm.util.stats import sample_discrete
 from pyhsmm.util.general import rle, rcumsum, cumsum
+import random
+import time
+import pickle
 
 from . import hmm_states
 from .hmm_states import _StatesBase, _SeparateTransMixin, \
@@ -14,7 +17,7 @@ from .hmm_states import _StatesBase, _SeparateTransMixin, \
 
 class HSMMStatesPython(_StatesBase):
     def __init__(self,model,right_censoring=True,left_censoring=False,trunc=None,
-            stateseq=None,**kwargs):
+            stateseq=None, timepoint=0,**kwargs):
         self.right_censoring = right_censoring
         self.left_censoring = left_censoring
         self.trunc = trunc
@@ -23,7 +26,7 @@ class HSMMStatesPython(_StatesBase):
             self._kwargs,trunc=trunc,
             left_censoring=left_censoring,right_censoring=right_censoring)
 
-        super(HSMMStatesPython,self).__init__(model,stateseq=stateseq,**kwargs)
+        super(HSMMStatesPython,self).__init__(model,stateseq=stateseq,timepoint=timepoint,**kwargs)
 
     ### properties for the outside world
 
@@ -501,13 +504,41 @@ class HSMMStatesEigen(HSMMStatesPython):
         # NOTE: np.maximum calls are because the C++ code doesn't do
         # np.logaddexp(-inf,-inf) = -inf, it likes nans instead
        from pyhsmm.internals.hsmm_messages_interface import messages_backwards_log
-        betal, betastarl = messages_backwards_log(
-            np.maximum(self.trans_matrix,1e-50),self.aBl,np.maximum(self.aDl,-1000000),
-            self.aDsl,np.empty_like(self.aBl),np.empty_like(self.aBl),
-            self.right_censoring,self.trunc if self.trunc is not None else self.T)
+        # aBl are emission log probs, aDl are duration log likelihoods,
+        # aDsl are log probs of the duration survival function
+        # print()
+        # a = time.time()
+        # !!! trans_mat is not logged !!!
+        betal, betastarl = messages_backwards_log(np.maximum(self.trans_matrix, 1e-50), self.aBl,
+                                                  np.maximum(self.aDl, -1000000), self.aDsl,
+                                                  np.empty_like(self.aBl), np.empty_like(self.aBl),
+                                                  self.right_censoring, self.trunc if self.trunc is not None else self.T)
+        # print(time.time() - a)
         assert not np.isnan(betal).any()
         assert not np.isnan(betastarl).any()
 
+        """print('abl')  # proof that betastarl[-1] == self.aBl[-1]
+        print(betastarl[-1] - self.aBl[-1])"""
+
+        # code to check that my python message passer works:
+        # a = time.time()
+        # my_betal, my_betastarl = self.my_messages_backwards_log(
+        #     np.maximum(self.trans_matrix,1e-50), self.aBl, np.maximum(self.aDl,-1000000), self.aDsl)
+        # print(time.time() - a)
+
+        # print('diffs')
+        # print(np.sum(np.abs(betal - my_betal)))
+        # print(np.sum(np.abs(betastarl - my_betastarl)))
+        #
+        # my_var = 52
+        # print(betal[-my_var])
+        # print(betastarl[-my_var])
+        #
+        # print(my_betal[-my_var])
+        # print(my_betastarl[-my_var])
+
         if not self.left_censoring:
             self._normalizer = logsumexp(np.log(self.pi_0) + betastarl[0])
         else:
@@ -515,18 +546,67 @@ class HSMMStatesEigen(HSMMStatesPython):
 
         return betal, betastarl
 
+    def my_messages_backwards_log(self, trans_mat, log_obs, log_durs, log_survivals):
+        betal = np.zeros_like(log_obs)  # the zeros here are important because of the += later
+        betastarl = np.zeros_like(log_obs)
+        betal[-1] = 1
+        betastarl[-1] = np.exp(log_obs[-1])  # empirical fact, might be a censoring term -> I get it myself too; it comes from marginalising over all durations via the survival function
+        S = betal.shape[1]
+        T = betal.shape[0]
+
+        # print(np.exp(log_durs[1]) + np.exp(log_durs[0]) + np.exp(log_survivals[1]))  # this is how it works
+
+        for t in range(T - 1, 0, -1):
+            for i in range(S):
+                temp = 0
+                for d in range(1, T - t + 1):
+                    temp += betal[t + d - 1, i] * np.exp(log_durs[d - 1, i]) * np.prod(np.exp(log_obs[t:t + d, i]))  # durs might start at 1
+
+                temp += np.exp(log_survivals[T - t - 1, i]) * np.prod(np.exp(log_obs[t:, i]))  # depends on the coding of the survival function
+                betastarl[t, i] = temp
+
+            for i in range(S):
+                for j in range(S):
+                    betal[t - 1, i] += betastarl[t, j] * trans_mat[i, j]
+
+        # one last pass, to fill the remaining row (t = 0) of betastarl
+        t = 0
+        for i in range(S):
+            temp = 0
+            for d in range(1, T - t + 1):
+                temp += betal[t + d - 1, i] * np.exp(log_durs[d - 1, i]) * np.prod(np.exp(log_obs[t:t + d, i]))  # durs might start at 1
+
+            temp += np.exp(log_survivals[T - t - 1, i]) * np.prod(np.exp(log_obs[t:, i]))  # depends on the coding of the survival function
+            betastarl[t, i] = temp
+
+        return np.log(betal), np.log(betastarl)
+
     def messages_backwards_python(self):
         return super(HSMMStatesEigen,self).messages_backwards()
 
     def sample_forwards(self,betal,betastarl):
-        from pyhsmm.internals.hsmm_messages_interface import sample_forwards_log
+        # from pyhsmm.internals.hsmm_messages_interface import sample_forwards_log
         if self.left_censoring:
             raise NotImplementedError
         caBl = np.vstack((np.zeros(betal.shape[1]),np.cumsum(self.aBl[:-1],axis=0)))
-        self.stateseq = sample_forwards_log(
-            self.trans_matrix,caBl,self.aDl,self.pi_0,betal,betastarl,
-            np.empty(betal.shape[0],dtype='int32'))
-        assert not (0 == self.stateseq).all()
+
+        # self.stateseq = sample_forwards_log(
+        #     self.trans_matrix,caBl,self.aDl,self.pi_0,betal,betastarl,
+        #     np.empty(betal.shape[0],dtype='int32'))
+
+        self.stateseq = sample_forwards_log_mypy(
+            self.trans_matrix, caBl, self.aDl, self.pi_0, betal, betastarl,
+            np.empty(betal.shape[0], dtype='int32'))
+        # assert not (0 == self.stateseq).all()
 
     def sample_forwards_python(self,betal,betastarl):
         return super(HSMMStatesEigen,self).sample_forwards(betal,betastarl)
@@ -994,7 +1074,6 @@ def hsmm_messages_backwards_log(
     betal[-1] = 0.
     for t in range(T-1,-1,-1):
         cB, offset = cumulative_obs_potentials(t)
-        dp = dur_potentials(t)
         betastarl[t] = logsumexp(
             betal[t:t+cB.shape[0]] + cB + dur_potentials(t), axis=0)
         betastarl[t] -= offset
@@ -1131,3 +1210,99 @@ def hsmm_maximizing_assignment(
 
     return stateseq
+
+
+def sample_forwards_log_myc(trans_matrix, caBl, aDl, pi_0, betal, betastarl, irrev):
+    t = 0
+    T, limit_n = betal.shape
+    nextstate_distr = pi_0
+    stateseq = np.empty(T, dtype=np.int32)
+
+    while t < T:
+        logdomain = betastarl[t] - betastarl[t].max()
+        nextstate_distr = nextstate_distr * np.exp(logdomain)
+        if (nextstate_distr == 0.).all():
+            print("Warning: all-zero posterior state belief, following likelihood")
+            nextstate_distr = np.exp(logdomain)
+
+        # pyhsmm's sample_discrete doesn't need normalized weights (it samples against
+        # the unnormalized cumsum); np.random.choice does, hence the explicit
+        # normalization here -- seems to work fine
+        state = np.random.choice(limit_n, p=nextstate_distr / nextstate_distr.sum())
+
+        durprob = random.random()
+        dur = 0
+        # NOTE: t + dur + 1 can reach T here, one past the last row of caBl;
+        # sample_forwards_log_mypy below avoids this by stopping at T - 1
+        while durprob > 0. and t + dur < T:
+            p_d_prior = np.exp(aDl[dur, state])
+            if 0.0 == p_d_prior:
+                dur += 1
+                continue
+
+            p_d = p_d_prior * np.exp(caBl[t+dur+1, state] - caBl[t, state] + betal[t+dur, state] - betastarl[t, state])
+            durprob -= p_d
+            dur += 1
+
+        stateseq[t: t + dur] = state
+        t += dur
+        nextstate_distr = trans_matrix[state]
+
+    return stateseq
+
+
+def sample_forwards_log_mypy(trans_matrix, caBl, aDl, pi_0, betal, betastarl, irrev):
+    # Does Matt rely on there being no impossible observations, since he uses the cumulative obs?
+    T, limit_n = betal.shape
+    stateseq = np.empty(T, dtype=np.int32)
+
+    assert np.allclose(np.sum(pi_0), 1)
+    assert np.allclose(np.sum(trans_matrix, axis=1), np.ones(limit_n))  # not quite clear why this allclose is necessary here
+
+    total_dur = 0
+    current_state = None
+    # TODO: do I end the sequence at the right time?
+    while total_dur < T:
+        # this was previously done in a try/except, with the except hitting if the attempted
+        # normalisation fails. Now we do it always, which should give more precision (mostly
+        # negligible); also, the except clause modified betastarl, causing a bug later on
+        temp_beta_potential = betastarl[total_dur] - np.max(betastarl[total_dur])
+        state_dist_un = trans_matrix[current_state] * np.exp(temp_beta_potential) if total_dur != 0 else pi_0 * np.exp(temp_beta_potential)
+        state_dist = state_dist_un / np.sum(state_dist_un)
+
+        next_state = np.random.choice(limit_n, p=state_dist)
+
+        # old duration-drawing code, not as efficient since it always generates the entire probability vector
+        # temp = np.zeros(T - total_dur)  # generate one extra entry, for 1 - P(staying within T)
+        # for d in range(T - total_dur - 1):
+        #     temp[d] += aDl[d, next_state]
+        #     temp[d] += caBl[total_dur + d + 1, next_state] - caBl[total_dur, next_state]
+        #     temp[d] += betal[total_dur + d, next_state]  # no + 1 here, B goes from 1 (-> 0) to T (-> T-1)
+        # temp[:-1] -= betastarl[total_dur, next_state]
+        # temp[:-1] = np.exp(temp[:-1])  # TODO: but weird that we have to do -2 here, why skip 2?
+        # temp[-1] = max(0, 1 - np.sum(temp))  # cap at 0, don't let negativity intrude; TODO: don't let the diff get too big
+        # sample_dur = np.random.choice(T - total_dur, p=temp) + 1
+        # this bit of code actually does the same as random.choice, pretty neat:
+        # # d = 0
+        # # durprob = np.random.rand()
+        # # while durprob > 0.:
+        # #     durprob -= temp[d]
+        # #     d += 1
+        # # sample_dur = d
+
+        d = 0
+        durprob = np.random.rand()
+        while durprob > 0. and total_dur + d < T - 1:
+            p_d_prior = np.exp(aDl[d, next_state])
+            if 0.0 == p_d_prior:
+                d += 1
+                continue
+
+            p_d = p_d_prior * np.exp(caBl[total_dur + d + 1, next_state]
+                                     - caBl[total_dur, next_state] + betal[total_dur + d, next_state]
+                                     - betastarl[total_dur, next_state])
+            durprob -= p_d
+            d += 1
+        sample_dur = d
+        if total_dur + d == T - 1 and durprob > 0.:
+            sample_dur += 1
+
+        stateseq[total_dur: min(total_dur + sample_dur, T)] = next_state  # I really want to go till the end of the array, so no self.T - 1 in the min
+        total_dur += sample_dur
+        current_state = next_state
+
+    return stateseq
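
Side note, not part of the patch: both samplers above lean on the cumulative-likelihood trick, where caBl[t + d] - caBl[t] recovers the summed emission log-likelihood of a state over the length-d window starting at t. A toy check of that identity, with caBl built exactly as in sample_forwards (all numbers made up):

    import numpy as np

    np.random.seed(0)
    aBl = np.log(np.random.rand(6, 3))  # toy emission log-likelihoods, T = 6, S = 3
    caBl = np.vstack((np.zeros(aBl.shape[1]), np.cumsum(aBl[:-1], axis=0)))
    t, d, s = 1, 3, 0
    assert np.isclose(caBl[t + d, s] - caBl[t, s], aBl[t:t + d, s].sum())
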
diff --git a/pyhsmm/internals/initial_state.py b/pyhsmm/internals/initial_state.py
index ca0bba8..10feb0c 100644
--- a/pyhsmm/internals/initial_state.py
+++ b/pyhsmm/internals/initial_state.py
@@ -8,7 +8,7 @@ from pyhsmm.basic.abstractions import GibbsSampling, MaxLikelihood
 from pyhsmm.basic.distributions import Categorical
 
 class UniformInitialState(object):
-    def __init__(self,model):
+    def __init__(self,model,init_state_concentration=None,pi_0=None):  # Changed: swallow these arguments
         self.model = model
 
     @property
diff --git a/pyhsmm/internals/transitions.py b/pyhsmm/internals/transitions.py
index 5921c66..9e335ce 100644
--- a/pyhsmm/internals/transitions.py
+++ b/pyhsmm/internals/transitions.py
@@ -213,10 +213,12 @@ class _HSMMTransitionsGibbs(_HSMMTransitionsBase,_HMMTransitionsGibbs):
 
         if trans_counts.sum() > 0:
             froms = trans_counts.sum(1)
-            self_trans = [np.random.geometric(1-A_ii,size=n).sum() if n > 0 else 0
+            # big change here, SAB!!!
+            self_trans = [min(50000, np.random.geometric(1-A_ii,size=n).sum() if n > 0 else 0)
                     for A_ii, n in zip(self.full_trans_matrix.diagonal(),froms)]
             trans_counts += np.diag(self_trans)
 
+        #print("trans_counts {}".format(trans_counts.max()))
         return trans_counts
 
 class _HSMMTransitionsMaxLikelihood(_HSMMTransitionsBase,_HMMTransitionsMaxLikelihood):
@@ -581,4 +583,3 @@ class _DATruncHDPHSMMTransitionsSVI(_DATruncHDPHMMTransitionsSVI,_HSMMTransition
 
 class DATruncHDPHSMMTransitions(_DATruncHDPHSMMTransitionsSVI):
     pass
-
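
Side note, not part of the patch: the capped line above, run in isolation. Each state's unobservable self-transitions are imputed as a sum of geometric draws, one per observed departure, and the min(50000, ...) keeps a diagonal entry close to 1 from producing astronomically large imputed counts (the numbers below are toy values):

    import numpy as np

    np.random.seed(0)
    diag = np.array([0.9, 0.999])  # toy self-transition probabilities A_ii
    froms = np.array([10, 3])      # observed departures from each state
    self_trans = [min(50000, np.random.geometric(1 - A_ii, size=n).sum() if n > 0 else 0)
                  for A_ii, n in zip(diag, froms)]
    print(np.diag(self_trans))     # this is what gets added to trans_counts
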
diff --git a/pyhsmm/models.py b/pyhsmm/models.py
index e3b1b9b..bf5a5ff 100644
--- a/pyhsmm/models.py
+++ b/pyhsmm/models.py
@@ -11,7 +11,8 @@ import matplotlib.pyplot as plt
 from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
 from matplotlib import cm
 from warnings import warn
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
+from scipy.stats import invgamma, chi2
 
 from pyhsmm.basic.abstractions import Model, ModelGibbsSampling, \
     ModelEM, ModelMAPEM, ModelMeanField, ModelMeanFieldSVI, ModelParallelTempering
@@ -30,16 +31,19 @@ from pybasicbayes.distributions.gaussian import Gaussian
 
 class _HMMBase(Model):
     _states_class = hmm_states.HMMStatesPython
     _trans_class = transitions.HMMTransitions
-    _trans_conc_class = transitions.HMMTransitionsConc
-    _init_state_class = initial_state.HMMInitialState
+    _trans_conc_class = transitions.HMMTransitionsConc  # conc stands for concentration; here the concentration is also inferred
+    _init_state_class = initial_state.HMMInitialState  # Changed HMMInitialState -> UniformInitialState
 
     def __init__(self,
             obs_distns,
             trans_distn=None,
             alpha=None,alpha_a_0=None,alpha_b_0=None,trans_matrix=None,
-            init_state_distn=None,init_state_concentration=None,pi_0=None):
+            init_state_distn=None,init_state_concentration=None,pi_0=None,
+            var_prior=None):
         self.obs_distns = obs_distns
         self.states_list = []
+        self.timepoint = -1
+        self.var_prior = var_prior
 
         if trans_distn is not None:
             self.trans_distn = trans_distn
@@ -159,6 +163,31 @@
 
         return outs
 
+    def delete_data(self):
+        for s in self.states_list:
+            s.data = None
+
+    def delete_obs_data(self):
+        for o in self.obs_distns:
+            o.sigma_k = None
+            o.sigma_k_k_minus = None
+            o.gain_save = None
+            o.x_hat_k = None
+            o.x_hat_k_k_minus = None
+            o.R = None
+            o.pseudo_obs = None
+            o.H = None
+            o.pseudo_Q = None
+
+    def delete_dur_data(self):
+        for d in self.dur_distns:
+            d.r_support = None
+            d.r_probs = None
+            d.rho_0 = None
+            d.rho_mf = None
+            d.p_save = d.p
+            d._fixedr_distns = None  # can't access p of the overall dist if I delete this; but if need be I could probably delete the prior of these dists
+
     @property
     def stateseqs(self):
         return [s.stateseq for s in self.states_list]
@@ -192,8 +221,8 @@ class _HMMBase(Model):
         for s in self.states_list:
             for state in s.stateseq:
                 canonical_ids[state]
-        return list(map(operator.itemgetter(0),
-            sorted(canonical_ids.items(),key=operator.itemgetter(1))))
+        return map(operator.itemgetter(0),
+            sorted(canonical_ids.items(),key=operator.itemgetter(1)))
 
     @property
     def state_usages(self):
@@ -448,8 +477,24 @@
             self.resample_init_state_distn()
 
     def resample_obs_distns(self):
-        for state, distn in enumerate(self.obs_distns):
-            distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+        if self.var_prior is None:
+            for state, distn in enumerate(self.obs_distns):
+                distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+        else:
+            psi_diffs = []
+            for state, distn in enumerate(self.obs_distns):
+                distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+                psi_diffs.append(distn.psi_diff_saves)
+
+            if self.var_prior == 'uniform':
+                # resample sigma squared of the dynamic distributions
+                diffs = np.concatenate(psi_diffs)
+                fraction = np.sum(diffs ** 2) / 2
+                # sigmasq_states = invgamma.rvs(self.obs_distns[0].sigma_alpha + diffs.size / 2, scale=self.obs_distns[0].sigma_beta + fraction)
+                sigmasq_states = 2 * fraction / chi2(diffs.size - 1).rvs()  # uniform prior, see Gelman (p. 598, tau update)
+                for distn in self.obs_distns:
+                    distn.sigmasq_states = sigmasq_states
+
         self._clear_caches()
 
     @line_profiled
@@ -671,8 +716,12 @@
         # approximation!)
         assert data is None and len(self.states_list) > 0, 'Must have data to get BIC'
         if data is None:
-            return -2*sum(self.log_likelihood(s.data).sum() for s in self.states_list) + \
-                    self.num_parameters() * np.log(
-                        sum(s.data.shape[0] for s in self.states_list))
+            # return -2*sum(self.log_likelihood(s.data).sum() for s in self.states_list) + \
+            #         self.num_parameters() * np.log(
+            #             sum(s.data.shape[0] for s in self.states_list))
+            # I changed this! SAB
+            return -2*sum(s.log_likelihood().sum() for s in self.states_list) + \
+                    self.num_parameters * np.log(
+                        sum(s.data.shape[0] for s in self.states_list))
         else:
             return -2*self.log_likelihood(data) + self.num_parameters() * np.log(data.shape[0])
@@ -898,6 +947,7 @@ class _HSMMBase(_HMMBase):
 
     def add_data(self,data,stateseq=None,trunc=None,
             right_censoring=True,left_censoring=False,**kwargs):
+        self.timepoint += 1
         self.states_list.append(self._states_class(
             model=self,
             data=np.asarray(data),
@@ -905,6 +955,7 @@
             right_censoring=right_censoring,
             left_censoring=left_censoring,
             trunc=trunc,
+            timepoint=self.timepoint,
             **kwargs))
         return self.states_list[-1]
@@ -1393,4 +1444,3 @@ class WeakLimitHDPHSMMTruncatedIntNegBinSeparateTrans(
         _SeparateTransMixin,
         WeakLimitHDPHSMMTruncatedIntNegBin):
     _states_class = hsmm_inb_states.HSMMStatesTruncatedIntegerNegativeBinomialSeparateTrans
-
--
GitLab
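
Side note, not part of the patch: the variance update in resample_obs_distns, run in isolation with toy residuals (psi_diff_saves and sigmasq_states are attributes of this fork's distribution classes). With a uniform prior on sigma, the conditional posterior of sigma^2 is scaled inverse-chi-squared (the Gelman tau update the comment cites), and the chi2 draw below is equivalent to invgamma.rvs((n - 1) / 2, scale=fraction):

    import numpy as np
    from scipy.stats import chi2

    np.random.seed(0)
    diffs = np.random.randn(200)  # stand-in for the concatenated psi_diff_saves
    fraction = np.sum(diffs ** 2) / 2
    sigmasq = 2 * fraction / chi2(diffs.size - 1).rvs()  # shared across all obs_distns
    print(sigmasq)
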
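Side note, not part of the patch: the BIC the patched branch computes, pulled out as a standalone helper for reference. It assumes what the patch assumes, namely that each states object scores its own data via s.log_likelihood() and that num_parameters is a property rather than a method on this fork:

    import numpy as np

    def hmm_bic(model):
        # BIC = -2 * log-likelihood + (free parameters) * log(number of observations)
        n_obs = sum(s.data.shape[0] for s in model.states_list)
        loglik = sum(s.log_likelihood().sum() for s in model.states_list)
        return -2 * loglik + model.num_parameters * np.log(n_obs)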