summaryrefslogtreecommitdiff
path: root/caffe2/operators/tile_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2/operators/tile_op.h')
-rw-r--r--caffe2/operators/tile_op.h6
1 files changed, 2 insertions, 4 deletions
diff --git a/caffe2/operators/tile_op.h b/caffe2/operators/tile_op.h
index ad0b924e40..72cd56da80 100644
--- a/caffe2/operators/tile_op.h
+++ b/caffe2/operators/tile_op.h
@@ -74,13 +74,12 @@ class TileOp final : public Operator<Context> {
}
const auto& X = Input(0);
- auto* Y = Output(0);
const int axis = X.canonical_axis_index(axis_);
// reshape output to be input tiled along the axis
std::vector<std::int64_t> Y_dims = X.sizes().vec();
Y_dims[axis] *= tiles_;
- Y->Resize(Y_dims);
+ auto* Y = Output(0, Y_dims, at::dtype<T>());
// size up to (and not including) axis
const int outer_size = X.size_to_dim(axis);
@@ -179,14 +178,13 @@ class TileGradientOp final : public Operator<Context> {
}
const auto& dY = Input(0);
- auto* dX = Output(0);
const int axis = dY.canonical_axis_index(axis_);
// reshape output to be input "untiled" along the axis
std::vector<std::int64_t> X_dims = dY.sizes().vec();
CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0);
X_dims[axis] /= tiles_;
- dX->Resize(X_dims);
+ auto* dX = Output(0, X_dims, at::dtype<T>());
// size up to (and not including) axis
const int outer_size = dX->size_to_dim(axis);