diff --git a/pyhsmm/internals/hmm_states.py b/pyhsmm/internals/hmm_states.py index 5e9453a9305f1fadb5a82bc02cebfd77fdb44114..328ac8ae75904533ca5d14ae16bde37e2fa64bd1 100644 --- a/pyhsmm/internals/hmm_states.py +++ b/pyhsmm/internals/hmm_states.py @@ -497,12 +497,21 @@ class HMMStatesPython(_StatesBase): return linear_term - (new_normalizer - old_normalizer) def _expected_statistics(self,trans_potential,init_potential,likelihood_log_potential): - alphal = self._messages_forwards_log(trans_potential,init_potential, - likelihood_log_potential) - betal = self._messages_backwards_log(trans_potential,likelihood_log_potential) - expected_states, expected_transcounts, normalizer = \ - self._expected_statistics_from_messages(trans_potential,likelihood_log_potential,alphal,betal) - assert not np.isinf(expected_states).any() + if self.fixed_stateseq: + expected_states = np.zeros((self.T, self.num_states)) + expected_states[np.arange(self.T), self.stateseq] = 1.0 + + expected_transcounts = np.zeros((self.T-1, self.num_states, self.num_states)) + expected_transcounts[np.arange(self.T-1), self.stateseq[:-1], self.stateseq[1:]] = 1.0 + + normalizer = 0 + else: + alphal = self._messages_forwards_log(trans_potential,init_potential, + likelihood_log_potential) + betal = self._messages_backwards_log(trans_potential,likelihood_log_potential) + expected_states, expected_transcounts, normalizer = \ + self._expected_statistics_from_messages(trans_potential,likelihood_log_potential,alphal,betal) + assert not np.isinf(expected_states).any() return expected_states, expected_transcounts, normalizer @staticmethod