diff --git a/pyhsmm/internals/hmm_messages_interface.pyx b/pyhsmm/internals/hmm_messages_interface.pyx index 7575cb39e428db7170860ae633b650b7316691f7..64e680d84c9131eb9a8ea2558f960b96b189643d 100644 --- a/pyhsmm/internals/hmm_messages_interface.pyx +++ b/pyhsmm/internals/hmm_messages_interface.pyx @@ -102,13 +102,17 @@ def expected_statistics_log( np.ndarray[floating,ndim=2,mode='c'] alphal not None, np.ndarray[floating,ndim=2,mode='c'] betal not None, np.ndarray[floating,ndim=2,mode='c'] expected_states not None, - np.ndarray[floating,ndim=2,mode='c'] expected_transcounts not None, + expected_transcounts not None, ): cdef hmmc[floating] ref cdef bool hetero = log_trans_potential.ndim == 3 cdef floating[:,:,::1] _A = log_trans_potential if hetero \ else np.expand_dims(log_trans_potential, 0) + cdef floating[:,:,::1] _expected_transcounts = \ + expected_transcounts if hetero \ + else np.expand_dims(expected_transcounts, 0) + cdef floating log_normalizer = ref.expected_statistics_log( hetero, @@ -118,7 +122,7 @@ def expected_statistics_log( &alphal[0,0], &betal[0,0], &expected_states[0,0], - &expected_transcounts[0,0]) + &_expected_transcounts[0,0,0]) return expected_states, expected_transcounts, log_normalizer def messages_forwards_normalized(