From 9b8b3f34fd2acd58d45ea0ba1de916bf0c633bf5 Mon Sep 17 00:00:00 2001 From: Matthew Johnson <mattjj@csail.mit.edu> Date: Thu, 11 Feb 2016 11:39:25 -0500 Subject: [PATCH] make HMM.predict handle 1D data --- pyhsmm/models.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyhsmm/models.py b/pyhsmm/models.py index 2ae2605..eb7c4fb 100644 --- a/pyhsmm/models.py +++ b/pyhsmm/models.py @@ -109,7 +109,8 @@ class _HMMBase(Model): return sum(s.log_likelihood() for s in self.states_list) def predict(self,seed_data,timesteps,**kwargs): - full_data = np.vstack((seed_data,np.nan*np.ones((timesteps,seed_data.shape[1])))) + padshape = (timesteps, seed_data.shape[1]) if seed_data.ndim == 2 else timesteps + full_data = np.concatenate((seed_data,np.nan*np.ones(padshape))) self.add_data(full_data,**kwargs) s = self.states_list.pop() s.resample() # fills in states diff --git a/setup.py b/setup.py index 1540edd..39682e0 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,7 @@ names = ['.'.join(os.path.split(p)) for p in paths] ext_modules = [ Extension( name, sources=[path + '.cpp'], - include_dirs=[os.path.join('deps')], + include_dirs=['deps'], extra_compile_args=['-O3','-std=c++11','-DNDEBUG','-w','-DHMM_TEMPS_ON_HEAP']) for name, path in zip(names,paths)] -- GitLab