From 15091fd9e8a12d324dc2e83b7d3b6ebc533e7ba1 Mon Sep 17 00:00:00 2001
From: Scott Linderman <slinderman@seas.harvard.edu>
Date: Mon, 18 Jan 2016 18:18:48 -0500
Subject: [PATCH] updating expected stats to work with time-varying transition
 matrices

---
 pyhsmm/internals/hmm_messages_interface.pyx | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pyhsmm/internals/hmm_messages_interface.pyx b/pyhsmm/internals/hmm_messages_interface.pyx
index 7575cb3..64e680d 100644
--- a/pyhsmm/internals/hmm_messages_interface.pyx
+++ b/pyhsmm/internals/hmm_messages_interface.pyx
@@ -102,13 +102,17 @@ def expected_statistics_log(
         np.ndarray[floating,ndim=2,mode='c'] alphal not None,
         np.ndarray[floating,ndim=2,mode='c'] betal not None,
         np.ndarray[floating,ndim=2,mode='c'] expected_states not None,
-        np.ndarray[floating,ndim=2,mode='c'] expected_transcounts not None,
+	expected_transcounts not None,
         ):
 
     cdef hmmc[floating] ref
     cdef bool hetero = log_trans_potential.ndim == 3
     cdef floating[:,:,::1] _A = log_trans_potential if hetero \
         else np.expand_dims(log_trans_potential, 0)
+    cdef floating[:,:,::1] _expected_transcounts = \
+    	 expected_transcounts if hetero \
+         else np.expand_dims(expected_transcounts, 0)
+
 
     cdef floating log_normalizer = ref.expected_statistics_log(
             hetero,
@@ -118,7 +122,7 @@ def expected_statistics_log(
             &alphal[0,0],
             &betal[0,0],
             &expected_states[0,0],
-            &expected_transcounts[0,0])
+            &_expected_transcounts[0,0,0])
     return expected_states, expected_transcounts, log_normalizer
 
 def messages_forwards_normalized(
-- 
GitLab