From 9b8b3f34fd2acd58d45ea0ba1de916bf0c633bf5 Mon Sep 17 00:00:00 2001
From: Matthew Johnson <mattjj@csail.mit.edu>
Date: Thu, 11 Feb 2016 11:39:25 -0500
Subject: [PATCH] make HMM.predict handle 1D data

---
 pyhsmm/models.py | 3 ++-
 setup.py         | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pyhsmm/models.py b/pyhsmm/models.py
index 2ae2605..eb7c4fb 100644
--- a/pyhsmm/models.py
+++ b/pyhsmm/models.py
@@ -109,7 +109,8 @@ class _HMMBase(Model):
             return sum(s.log_likelihood() for s in self.states_list)
 
     def predict(self,seed_data,timesteps,**kwargs):
-        full_data = np.vstack((seed_data,np.nan*np.ones((timesteps,seed_data.shape[1]))))
+        padshape = (timesteps, seed_data.shape[1]) if seed_data.ndim == 2 else timesteps
+        full_data = np.concatenate((seed_data,np.nan*np.ones(padshape)))
         self.add_data(full_data,**kwargs)
         s = self.states_list.pop()
         s.resample()  # fills in states
diff --git a/setup.py b/setup.py
index 1540edd..39682e0 100644
--- a/setup.py
+++ b/setup.py
@@ -84,7 +84,7 @@ names = ['.'.join(os.path.split(p)) for p in paths]
 ext_modules = [
     Extension(
         name, sources=[path + '.cpp'],
-        include_dirs=[os.path.join('deps')],
+        include_dirs=['deps'],
         extra_compile_args=['-O3','-std=c++11','-DNDEBUG','-w','-DHMM_TEMPS_ON_HEAP'])
     for name, path in zip(names,paths)]
 
-- 
GitLab