From ecb947bad0cd42aa2b92124265140c19ea101b15 Mon Sep 17 00:00:00 2001 From: Scott Linderman <slinderman@seas.harvard.edu> Date: Mon, 1 Feb 2016 13:34:53 -0500 Subject: [PATCH] updating expected statistics to work with fixed stateseqs --- pyhsmm/internals/hmm_states.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py index 5e9453a..328ac8a 100644 --- a/pyhsmm/internals/hmm_states.py +++ b/pyhsmm/internals/hmm_states.py @@ -497,12 +497,21 @@ class HMMStatesPython(_StatesBase): return linear_term - (new_normalizer - old_normalizer) def _expected_statistics(self,trans_potential,init_potential,likelihood_log_potential): - alphal = self._messages_forwards_log(trans_potential,init_potential, - likelihood_log_potential) - betal = self._messages_backwards_log(trans_potential,likelihood_log_potential) - expected_states, expected_transcounts, normalizer = \ - self._expected_statistics_from_messages(trans_potential,likelihood_log_potential,alphal,betal) - assert not np.isinf(expected_states).any() + if self.fixed_stateseq: + expected_states = np.zeros((self.T, self.num_states)) + expected_states[np.arange(self.T), self.stateseq] = 1.0 + + expected_transcounts = np.zeros((self.T-1, self.num_states, self.num_states)) + expected_transcounts[np.arange(self.T-1), self.stateseq[:-1], self.stateseq[1:]] = 1.0 + + normalizer = 0 + else: + alphal = self._messages_forwards_log(trans_potential,init_potential, + likelihood_log_potential) + betal = self._messages_backwards_log(trans_potential,likelihood_log_potential) + expected_states, expected_transcounts, normalizer = \ + self._expected_statistics_from_messages(trans_potential,likelihood_log_potential,alphal,betal) + assert not np.isinf(expected_states).any() return expected_states, expected_transcounts, normalizer @staticmethod -- GitLab