Skip to content
Snippets Groups Projects
Commit 1b578620 authored by Scott Linderman's avatar Scott Linderman
Browse files

working around weird bug with broadcasting over strided arrays

parent f9406df0
No related branches found
No related tags found
No related merge requests found
......@@ -985,8 +985,8 @@ class RobustRegression(Regression):
n, D = y.shape
scaled_x = x * precisions[:, None]
scaled_y = y * precisions[:, None]
scaled_x = x * precisions[:, na]
scaled_y = y * precisions[:, na]
xxT = scaled_x.T.dot(x)
yxT = scaled_y.T.dot(x)
yyT = scaled_y.T.dot(y)
......@@ -1006,7 +1006,9 @@ class RobustRegression(Regression):
precisions = precisions[~bad]
n, D = data.shape[0], self.D_out
scaled_data = data * precisions[: None]
# This tile call is suboptimal but without it we can hit issues
# with strided data, as in autoregressive models.
scaled_data = data * np.tile(precisions[:,None], (1, data.shape[1]))
statmat = scaled_data.T.dot(data)
xxT, yxT, yyT = \
......@@ -1044,7 +1046,7 @@ class RobustRegression(Regression):
x, y = data
else:
x, y = data[:, :self.D_in], data[:, self.D_in:]
x, y = data[:, :-self.D_out], data[:, -self.D_out:]
assert x.ndim == y.ndim == 2
assert x.shape[0] == y.shape[0]
......@@ -1111,7 +1113,7 @@ class _ARMixin(object):
return self.D_out
def predict(self, x):
return super(_ARMixin,self).predict(np.atleast_2d(x.ravel()))
return super(_ARMixin,self).predict(np.atleast_2d(x))
def rvs(self,lagged_data):
return super(_ARMixin,self).rvs(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment