Skip to content
Snippets Groups Projects
Commit d28f3ba4 authored by Matthew Johnson's avatar Matthew Johnson
Browse files

fix copy_sample for initial state objects, fixes #75

parent 5c747381
No related branches found
No related tags found
No related merge requests found
from __future__ import division
import numpy as np
import copy
import pyhsmm
from pyhsmm.util.general import top_eigenvector
......@@ -41,6 +42,11 @@ class UniformInitialState(object):
def clear_caches(self):
pass
def copy_sample(self, new_model):
new = copy.copy(self)
new.model = new_model
return new
class HMMInitialState(Categorical):
def __init__(self,model,init_state_concentration=None,pi_0=None):
self.model = model
......@@ -73,7 +79,6 @@ class HMMInitialState(Categorical):
def clear_caches(self):
pass
def meanfieldupdate(self,expected_initial_states_list):
super(HMMInitialState,self).meanfieldupdate(None,expected_initial_states_list)
......@@ -85,6 +90,10 @@ class HMMInitialState(Categorical):
super(HMMInitialState,self).max_likelihood(
data=samples,weights=expected_states_list)
def copy_sample(self, new_model):
new = copy.deepcopy(self)
new.model = new_model
return new
class StartInZero(GibbsSampling,MaxLikelihood):
def __init__(self,num_states,**kwargs):
......@@ -100,6 +109,11 @@ class StartInZero(GibbsSampling,MaxLikelihood):
def max_likelihood(*args,**kwargs):
pass
def copy_sample(self, new_model):
new = copy.copy(self)
new.model = new_model
return new
class HSMMInitialState(HMMInitialState):
@property
def steady_state_distribution(self):
......@@ -112,4 +126,3 @@ class HSMMInitialState(HMMInitialState):
def clear_caches(self):
self._steady_state_distribution = None
......@@ -472,7 +472,7 @@ class _HMMGibbsSampling(_HMMBase,ModelGibbsSampling):
new = copy.copy(self)
new.obs_distns = [o.copy_sample() for o in self.obs_distns]
new.trans_distn = self.trans_distn.copy_sample()
new.init_state_distn = self.init_state_distn.copy_sample()
new.init_state_distn = self.init_state_distn.copy_sample(new)
new.states_list = [s.copy_sample(new) for s in self.states_list]
return new
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment