From 0d6fdeb15484392bafc3eb236b13f658c4f53ea1 Mon Sep 17 00:00:00 2001
From: Scott Linderman <slinderman@seas.harvard.edu>
Date: Mon, 18 Jan 2016 17:06:42 -0500
Subject: [PATCH] Adding a flag to specify whether a data objects state
 sequence should be resampled'

---
 pyhsmm/internals/hmm_states.py | 9 +++++++--
 pyhsmm/models.py               | 5 +++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py
index 2ab5362..b138531 100644
--- a/pyhsmm/internals/hmm_states.py
+++ b/pyhsmm/internals/hmm_states.py
@@ -20,7 +20,7 @@ class _StatesBase(object):
     __metaclass__ = abc.ABCMeta
 
     def __init__(self,model,T=None,data=None,stateseq=None,
-            generate=True,initialize_from_prior=True):
+            generate=True,initialize_from_prior=True, fixed_stateseq=False):
         self.model = model
 
         self.T = T if T is not None else data.shape[0]
@@ -28,6 +28,10 @@ class _StatesBase(object):
 
         self.clear_caches()
 
+        self.fixed_stateseq = fixed_stateseq
+        if fixed_stateseq:
+            assert stateseq is not None, "fixed_stateseq requires a stateseq to be supplied"
+
         if stateseq is not None:
             self.stateseq = np.array(stateseq,dtype=np.int32)
         elif generate:
@@ -353,7 +357,8 @@ class HMMStatesPython(_StatesBase):
         self.sample_backwards_normalized(alphan)
 
     def resample(self):
-        return self.resample_normalized()
+        if not self.fixed_stateseq:
+            return self.resample_normalized()
 
     @staticmethod
     def _sample_forwards_log(betal,trans_matrix,init_state_distn,log_likelihoods):
diff --git a/pyhsmm/models.py b/pyhsmm/models.py
index 5c044ff..9e10134 100644
--- a/pyhsmm/models.py
+++ b/pyhsmm/models.py
@@ -61,11 +61,12 @@ class _HMMBase(Model):
 
         self._clear_caches()
 
-    def add_data(self,data,stateseq=None,**kwargs):
+    def add_data(self,data,stateseq=None,fixed_stateseq=False,**kwargs):
         self.states_list.append(
                 self._states_class(
                     model=self,data=data,
-                    stateseq=stateseq,**kwargs))
+                    stateseq=stateseq, fixed_stateseq=fixed_stateseq,
+                    **kwargs))
         return self.states_list[-1]
 
     def generate(self,T,keep=True):
-- 
GitLab