From 333903b43765628911a76c08920fb39b9a4a87ef Mon Sep 17 00:00:00 2001
From: Scott Linderman <scott.linderman@gmail.com>
Date: Sun, 31 Jan 2016 21:08:32 -0500
Subject: [PATCH] updating expected_statistics_log to compute (T-1) x M x M
 expected_trans_counts

---
 pyhsmm/internals/hmm_messages.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyhsmm/internals/hmm_messages.h b/pyhsmm/internals/hmm_messages.h
index eb2b673..f10e9b6 100644
--- a/pyhsmm/internals/hmm_messages.h
+++ b/pyhsmm/internals/hmm_messages.h
@@ -93,7 +93,7 @@ namespace hmm
         NPArray<Type> ealphal(alphal,T,M);
 
         NPArray<Type> eexpected_states(expected_states,T,M);
-        NPArray<Type> eexpected_transcounts(expected_transcounts,M,M);
+        NPArray<Type> eexpected_transcounts(expected_transcounts,hetero ? (T-1)*M : M,M);
 
 #ifdef HMM_TEMPS_ON_HEAP
         Array<Type,Dynamic,Dynamic> pair(M,M);
@@ -111,7 +111,7 @@ namespace hmm
             pair.colwise() += ealphal.row(t).transpose().array();
             pair.rowwise() += ebetal.row(t+1) + eaBl.row(t+1);
 
-            eexpected_transcounts += pair.exp();
+            eexpected_transcounts.block(t*M*hetero,0,M,M) += pair.exp();
             eexpected_states.row(t) += (ealphal.row(t) + ebetal.row(t) - log_normalizer).exp();
         }
         eexpected_states.row(T-1) += (ealphal.row(T-1) - log_normalizer).exp();
-- 
GitLab