diff --git a/pyhsmm/models.py b/pyhsmm/models.py index 2ae260525071aac3608366b14431fc3ca487b45a..eb7c4fbd957ecb679f1c5ce1094518b86e6718f7 100644 --- a/pyhsmm/models.py +++ b/pyhsmm/models.py @@ -109,7 +109,8 @@ class _HMMBase(Model): return sum(s.log_likelihood() for s in self.states_list) def predict(self,seed_data,timesteps,**kwargs): - full_data = np.vstack((seed_data,np.nan*np.ones((timesteps,seed_data.shape[1])))) + padshape = (timesteps, seed_data.shape[1]) if seed_data.ndim == 2 else timesteps + full_data = np.concatenate((seed_data,np.nan*np.ones(padshape))) self.add_data(full_data,**kwargs) s = self.states_list.pop() s.resample() # fills in states diff --git a/setup.py b/setup.py index 1540eddaeb722941a3ea3adfe30f3c288db29b01..39682e0c1c2d953b7352cc2737c32b9734c937cd 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,7 @@ names = ['.'.join(os.path.split(p)) for p in paths] ext_modules = [ Extension( name, sources=[path + '.cpp'], - include_dirs=[os.path.join('deps')], + include_dirs=['deps'], extra_compile_args=['-O3','-std=c++11','-DNDEBUG','-w','-DHMM_TEMPS_ON_HEAP']) for name, path in zip(names,paths)]