From d28f3ba4dc813e48deb31495ad5f9c092796b76c Mon Sep 17 00:00:00 2001
From: Matthew Johnson <mattjj@csail.mit.edu>
Date: Sat, 6 May 2017 12:48:04 -0700
Subject: [PATCH] fix copy_sample for initial state objects, fixes #75

---
 pyhsmm/internals/initial_state.py | 17 +++++++++++++++--
 pyhsmm/models.py                  |  2 +-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/pyhsmm/internals/initial_state.py b/pyhsmm/internals/initial_state.py
index c0514a6..ca0bba8 100644
--- a/pyhsmm/internals/initial_state.py
+++ b/pyhsmm/internals/initial_state.py
@@ -1,5 +1,6 @@
 from __future__ import division
 import numpy as np
+import copy
 
 import pyhsmm
 from pyhsmm.util.general import top_eigenvector
@@ -41,6 +42,11 @@ class UniformInitialState(object):
     def clear_caches(self):
         pass
 
+    def copy_sample(self, new_model):
+        new = copy.copy(self)
+        new.model = new_model
+        return new
+
 class HMMInitialState(Categorical):
     def __init__(self,model,init_state_concentration=None,pi_0=None):
         self.model = model
@@ -73,7 +79,6 @@ class HMMInitialState(Categorical):
     def clear_caches(self):
         pass
 
-
     def meanfieldupdate(self,expected_initial_states_list):
         super(HMMInitialState,self).meanfieldupdate(None,expected_initial_states_list)
 
@@ -85,6 +90,10 @@ class HMMInitialState(Categorical):
         super(HMMInitialState,self).max_likelihood(
                 data=samples,weights=expected_states_list)
 
+    def copy_sample(self, new_model):
+        new = copy.deepcopy(self)
+        new.model = new_model
+        return new
 
 class StartInZero(GibbsSampling,MaxLikelihood):
     def __init__(self,num_states,**kwargs):
@@ -100,6 +109,11 @@ class StartInZero(GibbsSampling,MaxLikelihood):
     def max_likelihood(*args,**kwargs):
         pass
 
+    def copy_sample(self, new_model):
+        new = copy.copy(self)
+        new.model = new_model
+        return new
+
 class HSMMInitialState(HMMInitialState):
     @property
     def steady_state_distribution(self):
@@ -112,4 +126,3 @@ class HSMMInitialState(HMMInitialState):
 
     def clear_caches(self):
         self._steady_state_distribution = None
-
diff --git a/pyhsmm/models.py b/pyhsmm/models.py
index 63c8aad..ee98bbe 100644
--- a/pyhsmm/models.py
+++ b/pyhsmm/models.py
@@ -472,7 +472,7 @@ class _HMMGibbsSampling(_HMMBase,ModelGibbsSampling):
         new = copy.copy(self)
         new.obs_distns = [o.copy_sample() for o in self.obs_distns]
         new.trans_distn = self.trans_distn.copy_sample()
-        new.init_state_distn = self.init_state_distn.copy_sample()
+        new.init_state_distn = self.init_state_distn.copy_sample(new)
         new.states_list = [s.copy_sample(new) for s in self.states_list]
         return new
 
-- 
GitLab