From 333903b43765628911a76c08920fb39b9a4a87ef Mon Sep 17 00:00:00 2001 From: Scott Linderman <scott.linderman@gmail.com> Date: Sun, 31 Jan 2016 21:08:32 -0500 Subject: [PATCH] updating expected_statistics_log to compute (T-1) x M x M expected_trans_counts --- pyhsmm/internals/hmm_messages.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhsmm/internals/hmm_messages.h b/pyhsmm/internals/hmm_messages.h index eb2b673..f10e9b6 100644 --- a/pyhsmm/internals/hmm_messages.h +++ b/pyhsmm/internals/hmm_messages.h @@ -93,7 +93,7 @@ namespace hmm NPArray<Type> ealphal(alphal,T,M); NPArray<Type> eexpected_states(expected_states,T,M); - NPArray<Type> eexpected_transcounts(expected_transcounts,M,M); + NPArray<Type> eexpected_transcounts(expected_transcounts,hetero ? (T-1)*M : M,M); #ifdef HMM_TEMPS_ON_HEAP Array<Type,Dynamic,Dynamic> pair(M,M); @@ -111,7 +111,7 @@ namespace hmm pair.colwise() += ealphal.row(t).transpose().array(); pair.rowwise() += ebetal.row(t+1) + eaBl.row(t+1); - eexpected_transcounts += pair.exp(); + eexpected_transcounts.block(t*M*hetero,0,M,M) += pair.exp(); eexpected_states.row(t) += (ealphal.row(t) + ebetal.row(t) - log_normalizer).exp(); } eexpected_states.row(T-1) += (ealphal.row(T-1) - log_normalizer).exp(); -- GitLab