From 0a1f804f38d8efc5d11212961689b8ff102abd43 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Thu, 15 Sep 2022 14:21:12 +0200
Subject: [PATCH] my changes

---
 pyhsmm/internals/hmm_states.py    |  10 +-
 pyhsmm/internals/hsmm_states.py   | 201 ++++++++++++++++++++++++++++--
 pyhsmm/internals/initial_state.py |   2 +-
 pyhsmm/internals/transitions.py   |   5 +-
 pyhsmm/models.py                  |  72 +++++++++--
 5 files changed, 260 insertions(+), 30 deletions(-)

diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py
index c5dfb7f..e6ff225 100644
--- a/pyhsmm/internals/hmm_states.py
+++ b/pyhsmm/internals/hmm_states.py
@@ -5,7 +5,7 @@ import numpy as np
 from numpy import newaxis as na
 import abc
 import copy
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
 from pyhsmm.util.stats import sample_discrete
 
 try:
@@ -21,11 +21,12 @@ from pyhsmm.util.general import rle
 
 class _StatesBase(with_metaclass(abc.ABCMeta, object)):
     def __init__(self,model,T=None,data=None,stateseq=None,
-            generate=True,initialize_from_prior=True, fixed_stateseq=False):
+            generate=True,initialize_from_prior=True, fixed_stateseq=False, timepoint=0):
         self.model = model
 
         self.T = T if T is not None else data.shape[0]
         self.data = data
+        self.timepoint = timepoint
 
         self.clear_caches()
 
@@ -100,7 +101,10 @@ class _StatesBase(with_metaclass(abc.ABCMeta, object)):
             aBl = self._aBl = np.empty((data.shape[0],self.num_states))
             for idx, obs_distn in enumerate(self.obs_distns):
-                aBl[:,idx] = obs_distn.log_likelihood(data).ravel()
+                if self.model.var_prior is None:
+                    aBl[:,idx] = obs_distn.log_likelihood(data).ravel()
+                else:
+                    aBl[:,idx] = obs_distn.log_likelihood(data, self.timepoint).ravel()
             aBl[np.isnan(aBl).any(1)] = 0.
         return self._aBl
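
Side note, not part of the patch: the else branch above assumes observation distributions whose log_likelihood accepts a timepoint argument; that side of the interface lives in this fork's distribution classes and is not shown here. A minimal sketch of what such a distribution could look like (the class name and internals are hypothetical, not pyhsmm or pybasicbayes API):

    import numpy as np

    class TimeVaryingGaussian1D(object):
        # hypothetical obs distn whose mean can drift across sessions/timepoints
        def __init__(self, means_per_timepoint, sigmasq):
            self.means = np.asarray(means_per_timepoint)  # one mean per timepoint
            self.sigmasq = sigmasq

        def log_likelihood(self, data, timepoint=0):
            # matches the patched call obs_distn.log_likelihood(data, self.timepoint)
            mu = self.means[timepoint]
            return -0.5 * ((np.asarray(data) - mu) ** 2 / self.sigmasq
                           + np.log(2 * np.pi * self.sigmasq))
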
diff --git a/pyhsmm/internals/hsmm_states.py b/pyhsmm/internals/hsmm_states.py
index b06b831..6a0505e 100644
--- a/pyhsmm/internals/hsmm_states.py
+++ b/pyhsmm/internals/hsmm_states.py
@@ -2,10 +2,13 @@ from __future__ import division
 from builtins import range, map
 import numpy as np
 from numpy import newaxis as na
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
 from pyhsmm.util.stats import sample_discrete
 from pyhsmm.util.general import rle, rcumsum, cumsum
+import random
+import time
+import pickle
 
 from . import hmm_states
 from .hmm_states import _StatesBase, _SeparateTransMixin, \
@@ -14,7 +17,7 @@ from .hmm_states import _StatesBase, _SeparateTransMixin, \
 
 class HSMMStatesPython(_StatesBase):
     def __init__(self,model,right_censoring=True,left_censoring=False,trunc=None,
-            stateseq=None,**kwargs):
+            stateseq=None, timepoint=0,**kwargs):
         self.right_censoring = right_censoring
         self.left_censoring = left_censoring
         self.trunc = trunc
@@ -23,7 +26,7 @@ class HSMMStatesPython(_StatesBase):
             self._kwargs,trunc=trunc,
             left_censoring=left_censoring,right_censoring=right_censoring)
 
-        super(HSMMStatesPython,self).__init__(model,stateseq=stateseq,**kwargs)
+        super(HSMMStatesPython,self).__init__(model,stateseq=stateseq,timepoint=timepoint,**kwargs)
 
     ### properties for the outside world
 
@@ -501,13 +504,41 @@ class HSMMStatesEigen(HSMMStatesPython):
         # NOTE: np.maximum calls are because the C++ code doesn't do
         # np.logaddexp(-inf,-inf) = -inf, it likes nans instead
        from pyhsmm.internals.hsmm_messages_interface import messages_backwards_log
-        betal, betastarl = messages_backwards_log(
-            np.maximum(self.trans_matrix,1e-50),self.aBl,np.maximum(self.aDl,-1000000),
-            self.aDsl,np.empty_like(self.aBl),np.empty_like(self.aBl),
-            self.right_censoring,self.trunc if self.trunc is not None else self.T)
+        # aBl are emission log probs, aDl are duration log likelihoods,
+        # aDsl are log probs of the duration survival function
+        # print()
+        # a = time.time()
+        # !!! trans_mat is not logged !!!
+        betal, betastarl = messages_backwards_log(np.maximum(self.trans_matrix, 1e-50), self.aBl,
+                                                  np.maximum(self.aDl, -1000000), self.aDsl,
+                                                  np.empty_like(self.aBl), np.empty_like(self.aBl),
+                                                  self.right_censoring, self.trunc if self.trunc is not None else self.T)
+        # print(time.time() - a)
         assert not np.isnan(betal).any()
         assert not np.isnan(betastarl).any()
 
+        """print('abl')  # proof that betastarl[-1] == self.aBl[-1]
+        print(betastarl[-1] - self.aBl[-1])"""
+
+        # code to check that my python message passer works:
+        # a = time.time()
+        # my_betal, my_betastarl = self.my_messages_backwards_log(
+        #     np.maximum(self.trans_matrix,1e-50), self.aBl, np.maximum(self.aDl,-1000000), self.aDsl)
+        # print(time.time() - a)
+
+        # print('diffs')
+        # print(np.sum(np.abs(betal - my_betal)))
+        # print(np.sum(np.abs(betastarl - my_betastarl)))
+        #
+        # my_var = 52
+        # print(betal[-my_var])
+        # print(betastarl[-my_var])
+        #
+        # print(my_betal[-my_var])
+        # print(my_betastarl[-my_var])
+
         if not self.left_censoring:
             self._normalizer = logsumexp(np.log(self.pi_0) + betastarl[0])
         else:
@@ -515,18 +546,67 @@ class HSMMStatesEigen(HSMMStatesPython):
 
         return betal, betastarl
 
+    def my_messages_backwards_log(self, trans_mat, log_obs, log_durs, log_survivals):
+        betal = np.zeros_like(log_obs)  # the zeros here are important because of the += later
+        betastarl = np.zeros_like(log_obs)
+        betal[-1] = 1
+        betastarl[-1] = np.exp(log_obs[-1])  # empirical fact, might be a censoring term -> I get it myself too; it comes from marginalising over all durations via the survival function
+        S = betal.shape[1]
+        T = betal.shape[0]
+
+        # print(np.exp(log_durs[1]) + np.exp(log_durs[0]) + np.exp(log_survivals[1]))  # this is how it works
+
+        for t in range(T - 1, 0, -1):
+            for i in range(S):
+                temp = 0
+                for d in range(1, T - t + 1):
+                    temp += betal[t + d - 1, i] * np.exp(log_durs[d - 1, i]) * np.prod(np.exp(log_obs[t:t + d, i]))  # durs might start at 1
+
+                temp += np.exp(log_survivals[T - t - 1, i]) * np.prod(np.exp(log_obs[t:, i]))  # depends on the coding of the survival function
+                betastarl[t, i] = temp
+
+            for i in range(S):
+                for j in range(S):
+                    betal[t - 1, i] += betastarl[t, j] * trans_mat[i, j]
+
+        # one last pass, to fill the remaining row (t = 0) of betastarl
+        t = 0
+        for i in range(S):
+            temp = 0
+            for d in range(1, T - t + 1):
+                temp += betal[t + d - 1, i] * np.exp(log_durs[d - 1, i]) * np.prod(np.exp(log_obs[t:t + d, i]))  # durs might start at 1
+
+            temp += np.exp(log_survivals[T - t - 1, i]) * np.prod(np.exp(log_obs[t:, i]))  # depends on the coding of the survival function
+            betastarl[t, i] = temp
+
+        return np.log(betal), np.log(betastarl)
+
     def messages_backwards_python(self):
         return super(HSMMStatesEigen,self).messages_backwards()
 
     def sample_forwards(self,betal,betastarl):
-        from pyhsmm.internals.hsmm_messages_interface import sample_forwards_log
+        # from pyhsmm.internals.hsmm_messages_interface import sample_forwards_log
         if self.left_censoring:
             raise NotImplementedError
         caBl = np.vstack((np.zeros(betal.shape[1]),np.cumsum(self.aBl[:-1],axis=0)))
-        self.stateseq = sample_forwards_log(
-            self.trans_matrix,caBl,self.aDl,self.pi_0,betal,betastarl,
-            np.empty(betal.shape[0],dtype='int32'))
-        assert not (0 == self.stateseq).all()
+
+        # self.stateseq = sample_forwards_log(
+        #     self.trans_matrix,caBl,self.aDl,self.pi_0,betal,betastarl,
+        #     np.empty(betal.shape[0],dtype='int32'))
+
+        self.stateseq = sample_forwards_log_mypy(
+            self.trans_matrix, caBl, self.aDl, self.pi_0, betal, betastarl,
+            np.empty(betal.shape[0], dtype='int32'))
+        # assert not (0 == self.stateseq).all()
 
     def sample_forwards_python(self,betal,betastarl):
         return super(HSMMStatesEigen,self).sample_forwards(betal,betastarl)
@@ -994,7 +1074,6 @@ def hsmm_messages_backwards_log(
     betal[-1] = 0.
     for t in range(T-1,-1,-1):
         cB, offset = cumulative_obs_potentials(t)
-        dp = dur_potentials(t)
         betastarl[t] = logsumexp(
             betal[t:t+cB.shape[0]] + cB + dur_potentials(t), axis=0)
         betastarl[t] -= offset
@@ -1131,3 +1210,99 @@ def hsmm_maximizing_assignment(
 
     return stateseq
+
+
+def sample_forwards_log_myc(trans_matrix, caBl, aDl, pi_0, betal, betastarl, irrev):
+    t = 0
+    T, limit_n = betal.shape
+    nextstate_distr = pi_0
+    stateseq = np.empty(T, dtype=np.int32)
+
+    while t < T:
+        logdomain = betastarl[t] - betastarl[t].max()
+        nextstate_distr = nextstate_distr * np.exp(logdomain)
+        if (nextstate_distr == 0.).all():
+            print("Warning: all-zero posterior state belief, following likelihood")
+            nextstate_distr = np.exp(logdomain)
+
+        # pyhsmm's sample_discrete doesn't need normalized weights (it samples against
+        # the unnormalized cumsum); np.random.choice does, hence the explicit
+        # normalization here -- seems to work fine
+        state = np.random.choice(limit_n, p=nextstate_distr / nextstate_distr.sum())
+
+        durprob = random.random()
+        dur = 0
+        # NOTE: t + dur + 1 can reach T here, one past the last row of caBl;
+        # sample_forwards_log_mypy below avoids this by stopping at T - 1
+        while durprob > 0. and t + dur < T:
+            p_d_prior = np.exp(aDl[dur, state])
+            if 0.0 == p_d_prior:
+                dur += 1
+                continue
+
+            p_d = p_d_prior * np.exp(caBl[t+dur+1, state] - caBl[t, state] + betal[t+dur, state] - betastarl[t, state])
+            durprob -= p_d
+            dur += 1
+
+        stateseq[t: t + dur] = state
+        t += dur
+        nextstate_distr = trans_matrix[state]
+
+    return stateseq
+
+
+def sample_forwards_log_mypy(trans_matrix, caBl, aDl, pi_0, betal, betastarl, irrev):
+    # Does Matt rely on there being no impossible observations, since he uses the cumulative obs?
+    T, limit_n = betal.shape
+    stateseq = np.empty(T, dtype=np.int32)
+
+    assert np.allclose(np.sum(pi_0), 1)
+    assert np.allclose(np.sum(trans_matrix, axis=1), np.ones(limit_n))  # not quite clear why this allclose is necessary here
+
+    total_dur = 0
+    current_state = None
+    # TODO: do I end the sequence at the right time?
+    while total_dur < T:
+        # this was previously done in a try/except, with the except hitting if the attempted
+        # normalisation fails. Now we do it always, which should give more precision (mostly
+        # negligible); also, the except clause modified betastarl, causing a bug later on
+        temp_beta_potential = betastarl[total_dur] - np.max(betastarl[total_dur])
+        state_dist_un = trans_matrix[current_state] * np.exp(temp_beta_potential) if total_dur != 0 else pi_0 * np.exp(temp_beta_potential)
+        state_dist = state_dist_un / np.sum(state_dist_un)
+
+        next_state = np.random.choice(limit_n, p=state_dist)
+
+        # old duration-drawing code, not as efficient since it always generates the entire probability vector
+        # temp = np.zeros(T - total_dur)  # generate one extra entry, for 1 - P(staying within T)
+        # for d in range(T - total_dur - 1):
+        #     temp[d] += aDl[d, next_state]
+        #     temp[d] += caBl[total_dur + d + 1, next_state] - caBl[total_dur, next_state]
+        #     temp[d] += betal[total_dur + d, next_state]  # no + 1 here, B goes from 1 (-> 0) to T (-> T-1)
+        # temp[:-1] -= betastarl[total_dur, next_state]
+        # temp[:-1] = np.exp(temp[:-1])  # TODO: but weird that we have to do -2 here, why skip 2?
+        # temp[-1] = max(0, 1 - np.sum(temp))  # cap at 0, don't let negativity intrude; TODO: don't let the diff get too big
+        # sample_dur = np.random.choice(T - total_dur, p=temp) + 1
+        # this bit of code actually does the same as random.choice, pretty neat:
+        # # d = 0
+        # # durprob = np.random.rand()
+        # # while durprob > 0.:
+        # #     durprob -= temp[d]
+        # #     d += 1
+        # # sample_dur = d
+
+        d = 0
+        durprob = np.random.rand()
+        while durprob > 0. and total_dur + d < T - 1:
+            p_d_prior = np.exp(aDl[d, next_state])
+            if 0.0 == p_d_prior:
+                d += 1
+                continue
+
+            p_d = p_d_prior * np.exp(caBl[total_dur + d + 1, next_state]
+                                     - caBl[total_dur, next_state] + betal[total_dur + d, next_state]
+                                     - betastarl[total_dur, next_state])
+            durprob -= p_d
+            d += 1
+        sample_dur = d
+        if total_dur + d == T - 1 and durprob > 0.:
+            sample_dur += 1
+
+        stateseq[total_dur: min(total_dur + sample_dur, T)] = next_state  # I really want to go till the end of the array, so no self.T - 1 in the min
+        total_dur += sample_dur
+        current_state = next_state
+
+    return stateseq
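
Side note, not part of the patch: both samplers above lean on the cumulative-likelihood trick, where caBl[t + d] - caBl[t] recovers the summed emission log-likelihood of a state over the length-d window starting at t. A toy check of that identity, with caBl built exactly as in sample_forwards (all numbers made up):

    import numpy as np

    np.random.seed(0)
    aBl = np.log(np.random.rand(6, 3))  # toy emission log-likelihoods, T = 6, S = 3
    caBl = np.vstack((np.zeros(aBl.shape[1]), np.cumsum(aBl[:-1], axis=0)))
    t, d, s = 1, 3, 0
    assert np.isclose(caBl[t + d, s] - caBl[t, s], aBl[t:t + d, s].sum())
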
diff --git a/pyhsmm/internals/initial_state.py b/pyhsmm/internals/initial_state.py
index ca0bba8..10feb0c 100644
--- a/pyhsmm/internals/initial_state.py
+++ b/pyhsmm/internals/initial_state.py
@@ -8,7 +8,7 @@ from pyhsmm.basic.abstractions import GibbsSampling, MaxLikelihood
 from pyhsmm.basic.distributions import Categorical
 
 class UniformInitialState(object):
-    def __init__(self,model):
+    def __init__(self,model,init_state_concentration=None,pi_0=None):  # Changed: swallow these arguments
         self.model = model
 
     @property
diff --git a/pyhsmm/internals/transitions.py b/pyhsmm/internals/transitions.py
index 5921c66..9e335ce 100644
--- a/pyhsmm/internals/transitions.py
+++ b/pyhsmm/internals/transitions.py
@@ -213,10 +213,12 @@ class _HSMMTransitionsGibbs(_HSMMTransitionsBase,_HMMTransitionsGibbs):
 
         if trans_counts.sum() > 0:
             froms = trans_counts.sum(1)
-            self_trans = [np.random.geometric(1-A_ii,size=n).sum() if n > 0 else 0
+            # big change here, SAB!!!
+            self_trans = [min(50000, np.random.geometric(1-A_ii,size=n).sum() if n > 0 else 0)
                     for A_ii, n in zip(self.full_trans_matrix.diagonal(),froms)]
             trans_counts += np.diag(self_trans)
 
+        #print("trans_counts {}".format(trans_counts.max()))
         return trans_counts
 
 class _HSMMTransitionsMaxLikelihood(_HSMMTransitionsBase,_HMMTransitionsMaxLikelihood):
@@ -581,4 +583,3 @@ class _DATruncHDPHSMMTransitionsSVI(_DATruncHDPHMMTransitionsSVI,_HSMMTransition
 
 class DATruncHDPHSMMTransitions(_DATruncHDPHSMMTransitionsSVI):
     pass
-
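
Side note, not part of the patch: the capped line above, run in isolation. Each state's unobservable self-transitions are imputed as a sum of geometric draws, one per observed departure, and the min(50000, ...) keeps a diagonal entry close to 1 from producing astronomically large imputed counts (the numbers below are toy values):

    import numpy as np

    np.random.seed(0)
    diag = np.array([0.9, 0.999])  # toy self-transition probabilities A_ii
    froms = np.array([10, 3])      # observed departures from each state
    self_trans = [min(50000, np.random.geometric(1 - A_ii, size=n).sum() if n > 0 else 0)
                  for A_ii, n in zip(diag, froms)]
    print(np.diag(self_trans))     # this is what gets added to trans_counts
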
diff --git a/pyhsmm/models.py b/pyhsmm/models.py
index e3b1b9b..bf5a5ff 100644
--- a/pyhsmm/models.py
+++ b/pyhsmm/models.py
@@ -11,7 +11,8 @@ import matplotlib.pyplot as plt
 from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
 from matplotlib import cm
 from warnings import warn
-from scipy.special import logsumexp
+from scipy.misc import logsumexp
+from scipy.stats import invgamma, chi2
 
 from pyhsmm.basic.abstractions import Model, ModelGibbsSampling, \
     ModelEM, ModelMAPEM, ModelMeanField, ModelMeanFieldSVI, ModelParallelTempering
@@ -30,16 +31,19 @@ from pybasicbayes.distributions.gaussian import Gaussian
 
 class _HMMBase(Model):
     _states_class = hmm_states.HMMStatesPython
     _trans_class = transitions.HMMTransitions
-    _trans_conc_class = transitions.HMMTransitionsConc
-    _init_state_class = initial_state.HMMInitialState
+    _trans_conc_class = transitions.HMMTransitionsConc  # conc stands for concentration; here the concentration is also inferred
+    _init_state_class = initial_state.HMMInitialState  # Changed HMMInitialState -> UniformInitialState
 
     def __init__(self,
             obs_distns,
             trans_distn=None,
             alpha=None,alpha_a_0=None,alpha_b_0=None,trans_matrix=None,
-            init_state_distn=None,init_state_concentration=None,pi_0=None):
+            init_state_distn=None,init_state_concentration=None,pi_0=None,
+            var_prior=None):
         self.obs_distns = obs_distns
         self.states_list = []
+        self.timepoint = -1
+        self.var_prior = var_prior
 
         if trans_distn is not None:
             self.trans_distn = trans_distn
@@ -159,6 +163,31 @@
 
         return outs
 
+    def delete_data(self):
+        for s in self.states_list:
+            s.data = None
+
+    def delete_obs_data(self):
+        for o in self.obs_distns:
+            o.sigma_k = None
+            o.sigma_k_k_minus = None
+            o.gain_save = None
+            o.x_hat_k = None
+            o.x_hat_k_k_minus = None
+            o.R = None
+            o.pseudo_obs = None
+            o.H = None
+            o.pseudo_Q = None
+
+    def delete_dur_data(self):
+        for d in self.dur_distns:
+            d.r_support = None
+            d.r_probs = None
+            d.rho_0 = None
+            d.rho_mf = None
+            d.p_save = d.p
+            d._fixedr_distns = None  # can't access p of the overall dist if I delete this; but if need be I could probably delete the prior of these dists
+
     @property
     def stateseqs(self):
         return [s.stateseq for s in self.states_list]
@@ -192,8 +221,8 @@ class _HMMBase(Model):
         for s in self.states_list:
             for state in s.stateseq:
                 canonical_ids[state]
-        return list(map(operator.itemgetter(0),
-            sorted(canonical_ids.items(),key=operator.itemgetter(1))))
+        return map(operator.itemgetter(0),
+            sorted(canonical_ids.items(),key=operator.itemgetter(1)))
 
     @property
     def state_usages(self):
@@ -448,8 +477,24 @@
             self.resample_init_state_distn()
 
     def resample_obs_distns(self):
-        for state, distn in enumerate(self.obs_distns):
-            distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+        if self.var_prior is None:
+            for state, distn in enumerate(self.obs_distns):
+                distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+        else:
+            psi_diffs = []
+            for state, distn in enumerate(self.obs_distns):
+                distn.resample([s.data[s.stateseq == state] for s in self.states_list])
+                psi_diffs.append(distn.psi_diff_saves)
+
+            if self.var_prior == 'uniform':
+                # resample sigma squared of the dynamic distributions
+                diffs = np.concatenate(psi_diffs)
+                fraction = np.sum(diffs ** 2) / 2
+                # sigmasq_states = invgamma.rvs(self.obs_distns[0].sigma_alpha + diffs.size / 2, scale=self.obs_distns[0].sigma_beta + fraction)
+                sigmasq_states = 2 * fraction / chi2(diffs.size - 1).rvs()  # uniform prior, see Gelman (p. 598, tau update)
+                for distn in self.obs_distns:
+                    distn.sigmasq_states = sigmasq_states
+
         self._clear_caches()
 
     @line_profiled
@@ -671,8 +716,12 @@
         # approximation!)
         assert data is None and len(self.states_list) > 0, 'Must have data to get BIC'
         if data is None:
-            return -2*sum(self.log_likelihood(s.data).sum() for s in self.states_list) + \
-                    self.num_parameters() * np.log(
-                        sum(s.data.shape[0] for s in self.states_list))
+            # return -2*sum(self.log_likelihood(s.data).sum() for s in self.states_list) + \
+            #         self.num_parameters() * np.log(
+            #             sum(s.data.shape[0] for s in self.states_list))
+            # I changed this! SAB
+            return -2*sum(s.log_likelihood().sum() for s in self.states_list) + \
+                    self.num_parameters * np.log(
+                        sum(s.data.shape[0] for s in self.states_list))
         else:
             return -2*self.log_likelihood(data) + self.num_parameters() * np.log(data.shape[0])
@@ -898,6 +947,7 @@ class _HSMMBase(_HMMBase):
 
     def add_data(self,data,stateseq=None,trunc=None,
             right_censoring=True,left_censoring=False,**kwargs):
+        self.timepoint += 1
         self.states_list.append(self._states_class(
             model=self,
             data=np.asarray(data),
@@ -905,6 +955,7 @@
             right_censoring=right_censoring,
             left_censoring=left_censoring,
             trunc=trunc,
+            timepoint=self.timepoint,
             **kwargs))
         return self.states_list[-1]
@@ -1393,4 +1444,3 @@ class WeakLimitHDPHSMMTruncatedIntNegBinSeparateTrans(
         _SeparateTransMixin,
         WeakLimitHDPHSMMTruncatedIntNegBin):
     _states_class = hsmm_inb_states.HSMMStatesTruncatedIntegerNegativeBinomialSeparateTrans
-
--
GitLab
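
Side note, not part of the patch: the variance update in resample_obs_distns, run in isolation with toy residuals (psi_diff_saves and sigmasq_states are attributes of this fork's distribution classes). With a uniform prior on sigma, the conditional posterior of sigma^2 is scaled inverse-chi-squared (the Gelman tau update the comment cites), and the chi2 draw below is equivalent to invgamma.rvs((n - 1) / 2, scale=fraction):

    import numpy as np
    from scipy.stats import chi2

    np.random.seed(0)
    diffs = np.random.randn(200)  # stand-in for the concatenated psi_diff_saves
    fraction = np.sum(diffs ** 2) / 2
    sigmasq = 2 * fraction / chi2(diffs.size - 1).rvs()  # shared across all obs_distns
    print(sigmasq)
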
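Side note, not part of the patch: the BIC the patched branch computes, pulled out as a standalone helper for reference. It assumes what the patch assumes, namely that each states object scores its own data via s.log_likelihood() and that num_parameters is a property rather than a method on this fork:

    import numpy as np

    def hmm_bic(model):
        # BIC = -2 * log-likelihood + (free parameters) * log(number of observations)
        n_obs = sum(s.data.shape[0] for s in model.states_list)
        loglik = sum(s.log_likelihood().sum() for s in model.states_list)
        return -2 * loglik + model.num_parameters * np.log(n_obs)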