diff options
Diffstat (limited to 'caffe2/operators/tile_op.h')
-rw-r--r-- | caffe2/operators/tile_op.h | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/caffe2/operators/tile_op.h b/caffe2/operators/tile_op.h index ad0b924e40..72cd56da80 100644 --- a/caffe2/operators/tile_op.h +++ b/caffe2/operators/tile_op.h @@ -74,13 +74,12 @@ class TileOp final : public Operator<Context> { } const auto& X = Input(0); - auto* Y = Output(0); const int axis = X.canonical_axis_index(axis_); // reshape output to be input tiled along the axis std::vector<std::int64_t> Y_dims = X.sizes().vec(); Y_dims[axis] *= tiles_; - Y->Resize(Y_dims); + auto* Y = Output(0, Y_dims, at::dtype<T>()); // size up to (and not including) axis const int outer_size = X.size_to_dim(axis); @@ -179,14 +178,13 @@ class TileGradientOp final : public Operator<Context> { } const auto& dY = Input(0); - auto* dX = Output(0); const int axis = dY.canonical_axis_index(axis_); // reshape output to be input "untiled" along the axis std::vector<std::int64_t> X_dims = dY.sizes().vec(); CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0); X_dims[axis] /= tiles_; - dX->Resize(X_dims); + auto* dX = Output(0, X_dims, at::dtype<T>()); // size up to (and not including) axis const int outer_size = dX->size_to_dim(axis); |