From 0d6fdeb15484392bafc3eb236b13f658c4f53ea1 Mon Sep 17 00:00:00 2001 From: Scott Linderman <slinderman@seas.harvard.edu> Date: Mon, 18 Jan 2016 17:06:42 -0500 Subject: [PATCH] Adding a flag to specify whether a data objects state sequence should be resampled' --- pyhsmm/internals/hmm_states.py | 9 +++++++-- pyhsmm/models.py | 5 +++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py index 2ab5362..b138531 100644 --- a/pyhsmm/internals/hmm_states.py +++ b/pyhsmm/internals/hmm_states.py @@ -20,7 +20,7 @@ class _StatesBase(object): __metaclass__ = abc.ABCMeta def __init__(self,model,T=None,data=None,stateseq=None, - generate=True,initialize_from_prior=True): + generate=True,initialize_from_prior=True, fixed_stateseq=False): self.model = model self.T = T if T is not None else data.shape[0] @@ -28,6 +28,10 @@ class _StatesBase(object): self.clear_caches() + self.fixed_stateseq = fixed_stateseq + if fixed_stateseq: + assert stateseq is not None, "fixed_stateseq requires a stateseq to be supplied" + if stateseq is not None: self.stateseq = np.array(stateseq,dtype=np.int32) elif generate: @@ -353,7 +357,8 @@ class HMMStatesPython(_StatesBase): self.sample_backwards_normalized(alphan) def resample(self): - return self.resample_normalized() + if not self.fixed_stateseq: + return self.resample_normalized() @staticmethod def _sample_forwards_log(betal,trans_matrix,init_state_distn,log_likelihoods): diff --git a/pyhsmm/models.py b/pyhsmm/models.py index 5c044ff..9e10134 100644 --- a/pyhsmm/models.py +++ b/pyhsmm/models.py @@ -61,11 +61,12 @@ class _HMMBase(Model): self._clear_caches() - def add_data(self,data,stateseq=None,**kwargs): + def add_data(self,data,stateseq=None,fixed_stateseq=False,**kwargs): self.states_list.append( self._states_class( model=self,data=data, - stateseq=stateseq,**kwargs)) + stateseq=stateseq, fixed_stateseq=fixed_stateseq, + **kwargs)) return self.states_list[-1] def generate(self,T,keep=True): -- GitLab