diff --git a/pyhsmm/internals/hmm_messages_interface.pyx b/pyhsmm/internals/hmm_messages_interface.pyx
index 7575cb39e428db7170860ae633b650b7316691f7..64e680d84c9131eb9a8ea2558f960b96b189643d 100644
--- a/pyhsmm/internals/hmm_messages_interface.pyx
+++ b/pyhsmm/internals/hmm_messages_interface.pyx
@@ -102,13 +102,17 @@ def expected_statistics_log(
         np.ndarray[floating,ndim=2,mode='c'] alphal not None,
         np.ndarray[floating,ndim=2,mode='c'] betal not None,
         np.ndarray[floating,ndim=2,mode='c'] expected_states not None,
-        np.ndarray[floating,ndim=2,mode='c'] expected_transcounts not None,
+	expected_transcounts not None,
         ):
 
     cdef hmmc[floating] ref
     cdef bool hetero = log_trans_potential.ndim == 3
     cdef floating[:,:,::1] _A = log_trans_potential if hetero \
         else np.expand_dims(log_trans_potential, 0)
+    cdef floating[:,:,::1] _expected_transcounts = \
+    	 expected_transcounts if hetero \
+         else np.expand_dims(expected_transcounts, 0)
+
 
     cdef floating log_normalizer = ref.expected_statistics_log(
             hetero,
@@ -118,7 +122,7 @@ def expected_statistics_log(
             &alphal[0,0],
             &betal[0,0],
             &expected_states[0,0],
-            &expected_transcounts[0,0])
+            &_expected_transcounts[0,0,0])
     return expected_states, expected_transcounts, log_normalizer
 
 def messages_forwards_normalized(