diff --git a/pyhsmm/internals/hmm_messages.h b/pyhsmm/internals/hmm_messages.h index eb2b67342c30d22fb1e8790ffdb65dfab4bbd5b3..f10e9b6233a365ff5bf4bdd56b756090ed952e37 100644 --- a/pyhsmm/internals/hmm_messages.h +++ b/pyhsmm/internals/hmm_messages.h @@ -93,7 +93,7 @@ namespace hmm NPArray<Type> ealphal(alphal,T,M); NPArray<Type> eexpected_states(expected_states,T,M); - NPArray<Type> eexpected_transcounts(expected_transcounts,M,M); + NPArray<Type> eexpected_transcounts(expected_transcounts,hetero ? (T-1)*M : M,M); #ifdef HMM_TEMPS_ON_HEAP Array<Type,Dynamic,Dynamic> pair(M,M); @@ -111,7 +111,7 @@ namespace hmm pair.colwise() += ealphal.row(t).transpose().array(); pair.rowwise() += ebetal.row(t+1) + eaBl.row(t+1); - eexpected_transcounts += pair.exp(); + eexpected_transcounts.block(t*M*hetero,0,M,M) += pair.exp(); eexpected_states.row(t) += (ealphal.row(t) + ebetal.row(t) - log_normalizer).exp(); } eexpected_states.row(T-1) += (ealphal.row(T-1) - log_normalizer).exp();