summaryrefslogtreecommitdiff
path: root/caffe2/ideep
diff options
context:
space:
mode:
authorCheng,Penghui <penghui.cheng@intel.com>2019-01-11 12:48:57 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-11 12:53:06 -0800
commit926e718d5fe129f67a10eef7ef8ce754b25c1e1e (patch)
tree759265d7d43a1a580ebfe794b7bb54f1db848392 /caffe2/ideep
parent96ea2594d882fc177a6730c42cd3c266f9bb2a67 (diff)
downloadpytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.tar.gz
pytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.tar.bz2
pytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.zip
Add/fallback some operators for mkl-dnn (#11696)
Summary: Implementation LeakyRelu operator for mkl-dnn,the speed-up of a single operation is up to 10X on BDW. Implementation rashape operator for mkl-dnn,it will resolve occasionally crash issue which use fallback reshape operator. Implementation CreateBlobQueue and SafeEnqueueBlobs operators,it will resolve crash issue which use fallback operators. Fallback CreateBlobsQueueDBOp,TensorProtosDBInput,CloseBlobsQueue operators. Implement adam operator for mkl-dnn,the speed-up of a single operator is up to 6X on BDW. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11696 Reviewed By: yinghai Differential Revision: D10100438 Pulled By: wesolwsk fbshipit-source-id: 0b6e06897cc11e0a8e349d80a870b1e72e47f10d
Diffstat (limited to 'caffe2/ideep')
-rw-r--r--caffe2/ideep/operators/adam_op.cc179
-rw-r--r--caffe2/ideep/operators/operator_fallback_ideep.cc19
-rw-r--r--caffe2/ideep/operators/operator_fallback_ideep.h2
-rw-r--r--caffe2/ideep/operators/queue_ops.cc71
-rw-r--r--caffe2/ideep/operators/relu_op.cc38
-rw-r--r--caffe2/ideep/operators/reshape_op.cc121
6 files changed, 411 insertions, 19 deletions
diff --git a/caffe2/ideep/operators/adam_op.cc b/caffe2/ideep/operators/adam_op.cc
new file mode 100644
index 0000000000..732cf74b49
--- /dev/null
+++ b/caffe2/ideep/operators/adam_op.cc
@@ -0,0 +1,179 @@
+#include <caffe2/ideep/ideep_utils.h>
+
+namespace caffe2 {
+
+void adam_ideep_update(
+ int N,
+ const float* g,
+ const float* m,
+ const float* v,
+ float* ng,
+ float* nm,
+ float* nv,
+ float beta1,
+ float beta2,
+ float eps_hat,
+ float correction,
+ const float* lr) {
+#ifdef _OPENMP
+ #pragma omp parallel for schedule(static)
+#endif
+ for (auto i = 0; i < N; ++i) {
+ float gi = g[i];
+ float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
+ float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
+ ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
+ }
+}
+
+void adam_ideep_compute(
+ int N,
+ const float* w,
+ const float* g,
+ const float* m,
+ const float* v,
+ float* nw,
+ float* nm,
+ float* nv,
+ float beta1,
+ float beta2,
+ float eps_hat,
+ float correction,
+ const float* lr) {
+#ifdef _OPENMP
+ #pragma omp parallel for schedule(static)
+#endif
+ for (auto i = 0; i < N; ++i) {
+ float gi = g[i];
+ float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
+ float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
+ nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
+ }
+}
+
+void adam_ideep_compute_output_grad(
+ int N,
+ const float* w,
+ const float* g,
+ const float* m,
+ const float* v,
+ float* nw,
+ float* nm,
+ float* nv,
+ float* ng,
+ float beta1,
+ float beta2,
+ float eps_hat,
+ float correction,
+ const float* lr) {
+
+#ifdef _OPENMP
+ #pragma omp parallel for schedule(static)
+#endif
+ for (auto i = 0; i < N; ++i) {
+ float gi = g[i];
+ float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
+ float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
+ float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
+ nw[i] = w[i] + lr[0] * ngi;
+ }
+}
+
+template <typename T>
+class IDEEPAdamOp final : public IDEEPOperator {
+ public:
+ USE_IDEEP_DEF_ALIASES();
+ USE_IDEEP_OPERATOR_FUNCTIONS();
+
+ IDEEPAdamOp(const OperatorDef& operator_def, Workspace* ws)
+ : IDEEPOperator(operator_def, ws),
+ beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
+ beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
+ epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
+ bool RunOnDevice() override {
+ // Iter live on the CPU
+ CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
+ const auto& params = Input(PARAM);
+ const auto& moment_1 = Input(MOMENT_1);
+ const auto& moment_2 = Input(MOMENT_2);
+ const auto& grad = Input(GRAD);
+ // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
+ const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
+ auto* out_params = Output(OUTPUT_PARAM);
+ auto* out_moment1 = Output(OUTPUT_MOMENT_1);
+ auto* out_moment2 = Output(OUTPUT_MOMENT_2);
+
+ CAFFE_ENFORCE(lr.size() == 1);
+ CAFFE_ENFORCE(grad.get_nelems() == params.get_nelems());
+ CAFFE_ENFORCE(grad.get_nelems() == moment_1.get_nelems());
+ CAFFE_ENFORCE(grad.get_nelems() == moment_2.get_nelems());
+ if (params != *out_params)
+ out_params->reinit(params.get_descriptor());
+ if (moment_1 != *out_moment1)
+ out_moment1->reinit(moment_1.get_descriptor());
+ if (moment_2 != *out_moment2)
+ out_moment2->reinit(moment_2.get_descriptor());
+ const auto w = static_cast<float *>(params.get_data_handle());
+ const auto g = static_cast<float *>(grad.get_data_handle());
+ const auto m = static_cast<float *>(moment_1.get_data_handle());
+ const auto v = static_cast<float *>(moment_2.get_data_handle());
+ auto nw = static_cast<float *>(out_params->get_data_handle());
+ auto nm = static_cast<float *>(out_moment1->get_data_handle());
+ auto nv = static_cast<float *>(out_moment2->get_data_handle());
+ const auto nlr = lr.template data<T>();
+ const auto iter =
+ OperatorBase::Input<TensorCPU>(ITER, CPU).template data<int64_t>()[0];
+ const auto t = iter + 1;
+ const auto correction =
+ std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
+ if (OutputSize() == 3) {
+ adam_ideep_compute(
+ grad.get_nelems(),
+ w,
+ g,
+ m,
+ v,
+ nw,
+ nm,
+ nv,
+ beta1_,
+ beta2_,
+ epsilon_,
+ correction,
+ nlr);
+ } else {
+ auto* out_grad = Output(OUTPUT_GRAD);
+ if (grad != *out_grad)
+ out_grad->reinit(grad.get_descriptor());
+ auto ng = static_cast<float *>(out_grad->get_data_handle());
+ adam_ideep_compute_output_grad(
+ grad.get_nelems(),
+ w,
+ g,
+ m,
+ v,
+ nw,
+ nm,
+ nv,
+ ng,
+ beta1_,
+ beta2_,
+ epsilon_,
+ correction,
+ nlr);
+ }
+
+ return true;
+ }
+
+ protected:
+ T beta1_{0.9};
+ T beta2_{0.999};
+ T epsilon_{1e-8};
+ INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
+ OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
+};
+
+REGISTER_IDEEP_OPERATOR(Adam, IDEEPAdamOp<float>);
+
+} // namespace caffe2
diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc
index d078a56e30..6016923c5f 100644
--- a/caffe2/ideep/operators/operator_fallback_ideep.cc
+++ b/caffe2/ideep/operators/operator_fallback_ideep.cc
@@ -18,12 +18,10 @@
#include <caffe2/operators/flatten_op.h>
#include <caffe2/operators/generate_proposals_op.h>
#include <caffe2/operators/given_tensor_fill_op.h>
-#include <caffe2/operators/leaky_relu_op.h>
#include <caffe2/operators/load_save_op.h>
#include <caffe2/operators/loss_op.h>
#include <caffe2/operators/pad_op.h>
#include <caffe2/operators/prelu_op.h>
-#include <caffe2/operators/reshape_op.h>
#include <caffe2/operators/roi_align_op.h>
#include <caffe2/operators/roi_align_rotated_op.h>
#include <caffe2/operators/scale_op.h>
@@ -32,9 +30,10 @@
#include <caffe2/operators/transpose_op.h>
#include <caffe2/operators/affine_channel_op.h>
#include <caffe2/operators/stop_gradient.h>
-#include <caffe2/sgd/adam_op.h>
#include <caffe2/sgd/iter_op.h>
#include <caffe2/sgd/learning_rate_op.h>
+#include <caffe2/queue/queue_ops.h>
+#include <caffe2/operators/tensor_protos_db_input.h>
// can add more non-IDEEP operators if needed
namespace caffe2 {
@@ -52,9 +51,6 @@ REGISTER_IDEEP_OPERATOR(
REGISTER_IDEEP_OPERATOR(Flatten, IDEEPFallbackOp<FlattenOp<CPUContext>>);
REGISTER_IDEEP_OPERATOR(ResizeLike, IDEEPFallbackOp<ResizeLikeOp<CPUContext>>);
REGISTER_IDEEP_OPERATOR(Transpose, IDEEPFallbackOp<TransposeOp<CPUContext>>);
-REGISTER_IDEEP_OPERATOR(
- Reshape,
- IDEEPFallbackOp<ReshapeOp<float, CPUContext>, SkipIndices<1>>);
// filter operators
REGISTER_IDEEP_OPERATOR(
@@ -109,7 +105,7 @@ REGISTER_IDEEP_OPERATOR(
REGISTER_IDEEP_OPERATOR(
PRelu,
IDEEPFallbackOp<PReluOp<float, CPUContext>>);
-
+
// ctc decoder operators
REGISTER_IDEEP_OPERATOR(
CTCGreedyDecoder,
@@ -134,9 +130,6 @@ REGISTER_IDEEP_OPERATOR(
LearningRate,
IDEEPFallbackOp<LearningRateOp<float, CPUContext>>);
REGISTER_IDEEP_OPERATOR(
- LeakyRelu,
- IDEEPFallbackOp<LeakyReluOp<float, CPUContext>>);
-REGISTER_IDEEP_OPERATOR(
Mul,
IDEEPFallbackOp<
BinaryElementwiseOp<NumericTypes, CPUContext, MulFunctor<CPUContext>>>);
@@ -170,14 +163,12 @@ REGISTER_IDEEP_OPERATOR(
ConvTransposeGradient,
IDEEPFallbackOp<ConvTransposeGradientOp<float, CPUContext>>);
REGISTER_IDEEP_OPERATOR(
- LeakyReluGradient,
- IDEEPFallbackOp<LeakyReluGradientOp<float, CPUContext>>);
-REGISTER_IDEEP_OPERATOR(
MulGradient,
IDEEPFallbackOp<BinaryElementwiseGradientOp<
NumericTypes,
CPUContext,
MulFunctor<CPUContext>>>);
-REGISTER_IDEEP_OPERATOR(Adam, IDEEPFallbackOp<AdamOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(TensorProtosDBInput, IDEEPFallbackOp<TensorProtosDBInput<CPUContext>>);
+REGISTER_IDEEP_OPERATOR(CloseBlobsQueue, IDEEPFallbackOp<CloseBlobsQueueOp<CPUContext>>);
} // namespace caffe2
diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h
index 77001b94db..4372807e7a 100644
--- a/caffe2/ideep/operators/operator_fallback_ideep.h
+++ b/caffe2/ideep/operators/operator_fallback_ideep.h
@@ -111,7 +111,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
}
}
- if (!base_op_->Run()) {
+ if (!base_op_->Run(0)) {
LOG(ERROR) << "Base op run failed in IDEEPFallbackOp. Def: "
<< ProtoDebugString(this->debug_def());
return false;
diff --git a/caffe2/ideep/operators/queue_ops.cc b/caffe2/ideep/operators/queue_ops.cc
new file mode 100644
index 0000000000..fb7887cdea
--- /dev/null
+++ b/caffe2/ideep/operators/queue_ops.cc
@@ -0,0 +1,71 @@
+#include <caffe2/ideep/ideep_utils.h>
+#include <caffe2/queue/blobs_queue.h>
+
+namespace caffe2 {
+
+class IDEEPCreateBlobsQueueOp final : public IDEEPOperator {
+ public:
+ USE_IDEEP_DEF_ALIASES();
+ USE_IDEEP_OPERATOR_FUNCTIONS();
+
+ IDEEPCreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
+ : IDEEPOperator(operator_def, ws),
+ ws_(ws),
+ name(operator_def.output().Get(0)) {}
+
+ bool RunOnDevice() override {
+ const auto capacity = GetSingleArgument("capacity", 1);
+ const auto numBlobs = GetSingleArgument("num_blobs", 1);
+ const auto enforceUniqueName =
+ GetSingleArgument("enforce_unique_name", false);
+ const auto fieldNames =
+ OperatorBase::template GetRepeatedArgument<std::string>("field_names");
+ CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
+ auto queuePtr = OperatorBase::Outputs()[0]
+ ->template GetMutable<std::shared_ptr<BlobsQueue>>();
+
+ CAFFE_ENFORCE(queuePtr);
+ *queuePtr = std::make_shared<BlobsQueue>(
+ ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
+ return true;
+ }
+
+ private:
+ Workspace* ws_{nullptr};
+ const std::string name;
+};
+
+class IDEEPSafeEnqueueBlobsOp final : public IDEEPOperator {
+ public:
+ USE_IDEEP_DEF_ALIASES();
+ USE_IDEEP_OPERATOR_FUNCTIONS();
+
+ IDEEPSafeEnqueueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
+ : IDEEPOperator(operator_def, ws) {}
+
+ bool RunOnDevice() override {
+ auto queue =
+ OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
+ CAFFE_ENFORCE(queue);
+ auto size = queue->getNumBlobs();
+ CAFFE_ENFORCE(
+ OutputSize() == size + 1,
+ "Expected " + caffe2::to_string(size + 1) + ", " +
+ " got: " + caffe2::to_string(size));
+ bool status = queue->blockingWrite(OperatorBase::Outputs());
+
+ auto st = OperatorBase::Output<TensorCPU>(1, CPU);
+ st->Resize();
+ auto stat = st->template mutable_data<bool>();
+ stat[0] = !status;
+ return true;
+ }
+};
+
+REGISTER_IDEEP_OPERATOR(CreateBlobsQueue, IDEEPCreateBlobsQueueOp);
+SHOULD_NOT_DO_GRADIENT(IDEEPCreateBlobsQueueOp);
+
+REGISTER_IDEEP_OPERATOR(SafeEnqueueBlobs, IDEEPSafeEnqueueBlobsOp);
+SHOULD_NOT_DO_GRADIENT(IDEEPSafeEnqueueBlobsOp);
+
+} // namespace caffe2
diff --git a/caffe2/ideep/operators/relu_op.cc b/caffe2/ideep/operators/relu_op.cc
index 7f81d0ea71..7e591ff4d5 100644
--- a/caffe2/ideep/operators/relu_op.cc
+++ b/caffe2/ideep/operators/relu_op.cc
@@ -8,19 +8,33 @@ class IDEEPReluOp final : public IDEEPOperator {
USE_IDEEP_OPERATOR_FUNCTIONS();
IDEEPReluOp(const OperatorDef& operator_def, Workspace* ws)
- : IDEEPOperator(operator_def, ws) {}
+ : IDEEPOperator(operator_def, ws), alpha_(0.0) {
+ // Figure out the Relu descriptor.
+ if (operator_def.type().substr(0, 4) == "Relu") {
+ alpha_ = 0.0;
+ } else if (operator_def.type().substr(0, 9) == "LeakyRelu") {
+ if (HasArgument("alpha")) {
+ alpha_ = static_cast<float>(
+ OperatorBase::GetSingleArgument<float>("alpha", 0.01));
+ }
+ } else {
+ LOG(FATAL) << "Unsupported Relu method: " << operator_def.type();
+ }
+ }
virtual ~IDEEPReluOp() {}
bool RunOnDevice() override {
const auto& X = Input(INPUT);
auto* Y = Output(OUTPUT);
- ideep::eltwise_forward::compute(X, *Y);
+ ideep::eltwise_forward::compute(
+ X, *Y, ialgo::eltwise_relu, iprop::forward_training, alpha_);
return true;
}
private:
+ float alpha_;
INPUT_TAGS(INPUT);
OUTPUT_TAGS(OUTPUT);
@@ -32,7 +46,19 @@ class IDEEPReluGradientOp final : public IDEEPOperator {
USE_IDEEP_OPERATOR_FUNCTIONS();
IDEEPReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
- : IDEEPOperator(operator_def, ws) {}
+ : IDEEPOperator(operator_def, ws), alpha_(0.0) {
+ // Figure out the Relu descriptor.
+ if (operator_def.type().substr(0, 12) == "ReluGradient") {
+ alpha_ = 0.0;
+ } else if (operator_def.type().substr(0, 17) == "LeakyReluGradient") {
+ if (HasArgument("alpha")) {
+ alpha_ = static_cast<float>(
+ OperatorBase::GetSingleArgument<float>("alpha", 0.01));
+ }
+ } else {
+ LOG(FATAL) << "Unsupported Relu method: " << operator_def.type();
+ }
+ }
virtual ~IDEEPReluGradientOp() {}
bool RunOnDevice() override {
@@ -40,12 +66,13 @@ class IDEEPReluGradientOp final : public IDEEPOperator {
const auto& dY = Input(OUTPUT_GRAD);
auto* dX = Output(INPUT_GRAD);
- ideep::eltwise_backward::compute(Y, dY, *dX);
+ ideep::eltwise_backward::compute(Y, dY, *dX, ialgo::eltwise_relu, alpha_);
return true;
}
private:
+ float alpha_;
INPUT_TAGS(OUTPUT, OUTPUT_GRAD);
OUTPUT_TAGS(INPUT_GRAD);
@@ -54,4 +81,7 @@ class IDEEPReluGradientOp final : public IDEEPOperator {
REGISTER_IDEEP_OPERATOR(Relu, IDEEPReluOp);
REGISTER_IDEEP_OPERATOR(ReluGradient, IDEEPReluGradientOp);
+REGISTER_IDEEP_OPERATOR(LeakyRelu, IDEEPReluOp);
+REGISTER_IDEEP_OPERATOR(LeakyReluGradient, IDEEPReluGradientOp);
+
} // namespace caffe2
diff --git a/caffe2/ideep/operators/reshape_op.cc b/caffe2/ideep/operators/reshape_op.cc
new file mode 100644
index 0000000000..6a63359940
--- /dev/null
+++ b/caffe2/ideep/operators/reshape_op.cc
@@ -0,0 +1,121 @@
+#include <caffe2/ideep/ideep_utils.h>
+
+namespace caffe2 {
+
+// Takes a shape and data tensor and reshapes it
+class IDEEPReshapeOp final : public IDEEPOperator {
+ public:
+ USE_IDEEP_DEF_ALIASES();
+ USE_IDEEP_OPERATOR_FUNCTIONS();
+
+ IDEEPReshapeOp(const OperatorDef& operator_def, Workspace* ws)
+ : IDEEPOperator(operator_def, ws),
+ new_shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {}
+
+ bool RunOnDevice() override {
+ ideep::tensor::dims actual_new_shape = new_shape_;
+ if (InputSize() == 2) {
+ CAFFE_ENFORCE(
+ !OperatorBase::HasArgument("shape"),
+ "New shape is specified by the input blob, do not pass in "
+ "the argument `shape`.");
+
+ // shape info live on CPU
+ auto& shape = OperatorBase::Input<TensorCPU>(1, CPU);
+ CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D");
+ const int* shape_data = shape.template data<int>();
+
+ actual_new_shape.reserve(shape.size());
+ actual_new_shape.assign(shape_data, shape_data + shape.size());
+ } else {
+ CAFFE_ENFORCE(
+ OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
+ }
+
+ auto& input = Input(0);
+ // Copy over the dimensions for those that are specified zero.
+ for (int i = 0; i < actual_new_shape.size() && i < input.ndims(); ++i) {
+ if (actual_new_shape[i] == 0) {
+ actual_new_shape[i] = input.get_dim(i);
+ }
+ }
+
+ // Checks if the new shape is valid and fills in the missing dimension
+ // specified by -1.
+ // NOTE: At most one dimension can be -1.
+ auto total_size = input.get_nelems();
+ int size = 1;
+ int unknown_idx = -1;
+ for (int i = 0; i < actual_new_shape.size(); ++i) {
+ const auto dim = actual_new_shape[i];
+ if (dim == -1) {
+ CAFFE_ENFORCE(
+ unknown_idx == -1,
+ "Argument `shape` has more than one missing dimension.");
+ unknown_idx = i;
+ } else {
+ size *= dim;
+ }
+ }
+ if (size == 0 && total_size != 0) {
+ CAFFE_THROW(
+ "Can not reshape a non-zero size (",
+ total_size,
+ ") tensor to zero size.");
+ }
+
+ if (unknown_idx != -1) {
+ CAFFE_ENFORCE_NE(
+ size,
+ 0,
+ "New shape at dim ",
+ unknown_idx,
+ " can not be inferred since new size is zero.");
+ CAFFE_ENFORCE(
+ total_size % size == 0,
+ "Argument `shape` does not agree with the input data.",
+ " (",
+ total_size,
+ " vs ",
+ size,
+ ")");
+ actual_new_shape[unknown_idx] = total_size / size;
+ } else {
+ CAFFE_ENFORCE_EQ(
+ total_size,
+ size,
+ "Argument `shape` does not agree with the input data.",
+ " (",
+ total_size,
+ " != ",
+ size,
+ ")");
+ }
+
+ // Write the original shape to the second output.
+ // shape info live on CPU
+ TensorCPU* old_shape = OperatorBase::Output<TensorCPU>(1, CPU);
+ old_shape->Resize(input.ndims());
+ int* old_shape_data = old_shape->template mutable_data<int>();
+ for (int i = 0; i < input.ndims(); ++i) {
+ old_shape_data[i] = input.get_dim(i);
+ }
+
+ auto* output = Output(0);
+ if (output != &input) {
+ // If we are not doing in-place computation, a copy is needed.
+ output->reinit_like(input);
+ ideep::direct_copy::compute(input, *output);
+ }
+
+ output->reshape(actual_new_shape);
+ return true;
+ }
+
+ private:
+ ideep::tensor::dims new_shape_;
+};
+
+REGISTER_IDEEP_OPERATOR(Reshape, IDEEPReshapeOp);
+
+} // namespace caffe2