From d28f3ba4dc813e48deb31495ad5f9c092796b76c Mon Sep 17 00:00:00 2001 From: Matthew Johnson <mattjj@csail.mit.edu> Date: Sat, 6 May 2017 12:48:04 -0700 Subject: [PATCH] fix copy_sample for initial state objects, fixes #75 --- pyhsmm/internals/initial_state.py | 17 +++++++++++++++-- pyhsmm/models.py | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pyhsmm/internals/initial_state.py b/pyhsmm/internals/initial_state.py index c0514a6..ca0bba8 100644 --- a/pyhsmm/internals/initial_state.py +++ b/pyhsmm/internals/initial_state.py @@ -1,5 +1,6 @@ from __future__ import division import numpy as np +import copy import pyhsmm from pyhsmm.util.general import top_eigenvector @@ -41,6 +42,11 @@ class UniformInitialState(object): def clear_caches(self): pass + def copy_sample(self, new_model): + new = copy.copy(self) + new.model = new_model + return new + class HMMInitialState(Categorical): def __init__(self,model,init_state_concentration=None,pi_0=None): self.model = model @@ -73,7 +79,6 @@ class HMMInitialState(Categorical): def clear_caches(self): pass - def meanfieldupdate(self,expected_initial_states_list): super(HMMInitialState,self).meanfieldupdate(None,expected_initial_states_list) @@ -85,6 +90,10 @@ class HMMInitialState(Categorical): super(HMMInitialState,self).max_likelihood( data=samples,weights=expected_states_list) + def copy_sample(self, new_model): + new = copy.deepcopy(self) + new.model = new_model + return new class StartInZero(GibbsSampling,MaxLikelihood): def __init__(self,num_states,**kwargs): @@ -100,6 +109,11 @@ class StartInZero(GibbsSampling,MaxLikelihood): def max_likelihood(*args,**kwargs): pass + def copy_sample(self, new_model): + new = copy.copy(self) + new.model = new_model + return new + class HSMMInitialState(HMMInitialState): @property def steady_state_distribution(self): @@ -112,4 +126,3 @@ class HSMMInitialState(HMMInitialState): def clear_caches(self): self._steady_state_distribution = None - diff --git a/pyhsmm/models.py b/pyhsmm/models.py index 63c8aad..ee98bbe 100644 --- a/pyhsmm/models.py +++ b/pyhsmm/models.py @@ -472,7 +472,7 @@ class _HMMGibbsSampling(_HMMBase,ModelGibbsSampling): new = copy.copy(self) new.obs_distns = [o.copy_sample() for o in self.obs_distns] new.trans_distn = self.trans_distn.copy_sample() - new.init_state_distn = self.init_state_distn.copy_sample() + new.init_state_distn = self.init_state_distn.copy_sample(new) new.states_list = [s.copy_sample(new) for s in self.states_list] return new -- GitLab