From 15091fd9e8a12d324dc2e83b7d3b6ebc533e7ba1 Mon Sep 17 00:00:00 2001 From: Scott Linderman <slinderman@seas.harvard.edu> Date: Mon, 18 Jan 2016 18:18:48 -0500 Subject: [PATCH] updating expected stats to work with time-varying transition matrices --- pyhsmm/internals/hmm_messages_interface.pyx | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyhsmm/internals/hmm_messages_interface.pyx b/pyhsmm/internals/hmm_messages_interface.pyx index 7575cb3..64e680d 100644 --- a/pyhsmm/internals/hmm_messages_interface.pyx +++ b/pyhsmm/internals/hmm_messages_interface.pyx @@ -102,13 +102,17 @@ def expected_statistics_log( np.ndarray[floating,ndim=2,mode='c'] alphal not None, np.ndarray[floating,ndim=2,mode='c'] betal not None, np.ndarray[floating,ndim=2,mode='c'] expected_states not None, - np.ndarray[floating,ndim=2,mode='c'] expected_transcounts not None, + expected_transcounts not None, ): cdef hmmc[floating] ref cdef bool hetero = log_trans_potential.ndim == 3 cdef floating[:,:,::1] _A = log_trans_potential if hetero \ else np.expand_dims(log_trans_potential, 0) + cdef floating[:,:,::1] _expected_transcounts = \ + expected_transcounts if hetero \ + else np.expand_dims(expected_transcounts, 0) + cdef floating log_normalizer = ref.expected_statistics_log( hetero, @@ -118,7 +122,7 @@ def expected_statistics_log( &alphal[0,0], &betal[0,0], &expected_states[0,0], - &expected_transcounts[0,0]) + &_expected_transcounts[0,0,0]) return expected_states, expected_transcounts, log_normalizer def messages_forwards_normalized( -- GitLab