diff options
author | Cheng,Penghui <penghui.cheng@intel.com> | 2019-01-11 12:48:57 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-11 12:53:06 -0800 |
commit | 926e718d5fe129f67a10eef7ef8ce754b25c1e1e (patch) | |
tree | 759265d7d43a1a580ebfe794b7bb54f1db848392 /caffe2/ideep | |
parent | 96ea2594d882fc177a6730c42cd3c266f9bb2a67 (diff) | |
download | pytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.tar.gz pytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.tar.bz2 pytorch-926e718d5fe129f67a10eef7ef8ce754b25c1e1e.zip |
Add/fallback some operators for mkl-dnn (#11696)
Summary:
Implementation LeakyRelu operator for mkl-dnn,the speed-up of a single operation is up to 10X on BDW.
Implementation rashape operator for mkl-dnn,it will resolve occasionally crash issue which use fallback reshape operator.
Implementation CreateBlobQueue and SafeEnqueueBlobs operators,it will resolve crash issue which use fallback operators.
Fallback CreateBlobsQueueDBOp,TensorProtosDBInput,CloseBlobsQueue operators.
Implement adam operator for mkl-dnn,the speed-up of a single operator is up to 6X on BDW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11696
Reviewed By: yinghai
Differential Revision: D10100438
Pulled By: wesolwsk
fbshipit-source-id: 0b6e06897cc11e0a8e349d80a870b1e72e47f10d
Diffstat (limited to 'caffe2/ideep')
-rw-r--r-- | caffe2/ideep/operators/adam_op.cc | 179 | ||||
-rw-r--r-- | caffe2/ideep/operators/operator_fallback_ideep.cc | 19 | ||||
-rw-r--r-- | caffe2/ideep/operators/operator_fallback_ideep.h | 2 | ||||
-rw-r--r-- | caffe2/ideep/operators/queue_ops.cc | 71 | ||||
-rw-r--r-- | caffe2/ideep/operators/relu_op.cc | 38 | ||||
-rw-r--r-- | caffe2/ideep/operators/reshape_op.cc | 121 |
6 files changed, 411 insertions, 19 deletions
diff --git a/caffe2/ideep/operators/adam_op.cc b/caffe2/ideep/operators/adam_op.cc new file mode 100644 index 0000000000..732cf74b49 --- /dev/null +++ b/caffe2/ideep/operators/adam_op.cc @@ -0,0 +1,179 @@ +#include <caffe2/ideep/ideep_utils.h> + +namespace caffe2 { + +void adam_ideep_update( + int N, + const float* g, + const float* m, + const float* v, + float* ng, + float* nm, + float* nv, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); + } +} + +void adam_ideep_compute( + int N, + const float* w, + const float* g, + const float* m, + const float* v, + float* nw, + float* nm, + float* nv, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); + } +} + +void adam_ideep_compute_output_grad( + int N, + const float* w, + const float* g, + const float* m, + const float* v, + float* nw, + float* nm, + float* nv, + float* ng, + float beta1, + float beta2, + float eps_hat, + float correction, + const float* lr) { + +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (auto i = 0; i < N; ++i) { + float gi = g[i]; + float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); + float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); + float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat); + nw[i] = w[i] + lr[0] * ngi; + } +} + +template <typename T> +class IDEEPAdamOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPAdamOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)), + beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)), + epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {} + bool RunOnDevice() override { + // Iter live on the CPU + CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU)); + const auto& params = Input(PARAM); + const auto& moment_1 = Input(MOMENT_1); + const auto& moment_2 = Input(MOMENT_2); + const auto& grad = Input(GRAD); + // TODO: Use itensor after 0-dim is supported. Now use CPU tensor. + const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU); + auto* out_params = Output(OUTPUT_PARAM); + auto* out_moment1 = Output(OUTPUT_MOMENT_1); + auto* out_moment2 = Output(OUTPUT_MOMENT_2); + + CAFFE_ENFORCE(lr.size() == 1); + CAFFE_ENFORCE(grad.get_nelems() == params.get_nelems()); + CAFFE_ENFORCE(grad.get_nelems() == moment_1.get_nelems()); + CAFFE_ENFORCE(grad.get_nelems() == moment_2.get_nelems()); + if (params != *out_params) + out_params->reinit(params.get_descriptor()); + if (moment_1 != *out_moment1) + out_moment1->reinit(moment_1.get_descriptor()); + if (moment_2 != *out_moment2) + out_moment2->reinit(moment_2.get_descriptor()); + const auto w = static_cast<float *>(params.get_data_handle()); + const auto g = static_cast<float *>(grad.get_data_handle()); + const auto m = static_cast<float *>(moment_1.get_data_handle()); + const auto v = static_cast<float *>(moment_2.get_data_handle()); + auto nw = static_cast<float *>(out_params->get_data_handle()); + auto nm = static_cast<float *>(out_moment1->get_data_handle()); + auto nv = static_cast<float *>(out_moment2->get_data_handle()); + const auto nlr = lr.template data<T>(); + const auto iter = + OperatorBase::Input<TensorCPU>(ITER, CPU).template data<int64_t>()[0]; + const auto t = iter + 1; + const auto correction = + std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t)); + if (OutputSize() == 3) { + adam_ideep_compute( + grad.get_nelems(), + w, + g, + m, + v, + nw, + nm, + nv, + beta1_, + beta2_, + epsilon_, + correction, + nlr); + } else { + auto* out_grad = Output(OUTPUT_GRAD); + if (grad != *out_grad) + out_grad->reinit(grad.get_descriptor()); + auto ng = static_cast<float *>(out_grad->get_data_handle()); + adam_ideep_compute_output_grad( + grad.get_nelems(), + w, + g, + m, + v, + nw, + nm, + nv, + ng, + beta1_, + beta2_, + epsilon_, + correction, + nlr); + } + + return true; + } + + protected: + T beta1_{0.9}; + T beta2_{0.999}; + T epsilon_{1e-8}; + INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER); + OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD); +}; + +REGISTER_IDEEP_OPERATOR(Adam, IDEEPAdamOp<float>); + +} // namespace caffe2 diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc index d078a56e30..6016923c5f 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.cc +++ b/caffe2/ideep/operators/operator_fallback_ideep.cc @@ -18,12 +18,10 @@ #include <caffe2/operators/flatten_op.h> #include <caffe2/operators/generate_proposals_op.h> #include <caffe2/operators/given_tensor_fill_op.h> -#include <caffe2/operators/leaky_relu_op.h> #include <caffe2/operators/load_save_op.h> #include <caffe2/operators/loss_op.h> #include <caffe2/operators/pad_op.h> #include <caffe2/operators/prelu_op.h> -#include <caffe2/operators/reshape_op.h> #include <caffe2/operators/roi_align_op.h> #include <caffe2/operators/roi_align_rotated_op.h> #include <caffe2/operators/scale_op.h> @@ -32,9 +30,10 @@ #include <caffe2/operators/transpose_op.h> #include <caffe2/operators/affine_channel_op.h> #include <caffe2/operators/stop_gradient.h> -#include <caffe2/sgd/adam_op.h> #include <caffe2/sgd/iter_op.h> #include <caffe2/sgd/learning_rate_op.h> +#include <caffe2/queue/queue_ops.h> +#include <caffe2/operators/tensor_protos_db_input.h> // can add more non-IDEEP operators if needed namespace caffe2 { @@ -52,9 +51,6 @@ REGISTER_IDEEP_OPERATOR( REGISTER_IDEEP_OPERATOR(Flatten, IDEEPFallbackOp<FlattenOp<CPUContext>>); REGISTER_IDEEP_OPERATOR(ResizeLike, IDEEPFallbackOp<ResizeLikeOp<CPUContext>>); REGISTER_IDEEP_OPERATOR(Transpose, IDEEPFallbackOp<TransposeOp<CPUContext>>); -REGISTER_IDEEP_OPERATOR( - Reshape, - IDEEPFallbackOp<ReshapeOp<float, CPUContext>, SkipIndices<1>>); // filter operators REGISTER_IDEEP_OPERATOR( @@ -109,7 +105,7 @@ REGISTER_IDEEP_OPERATOR( REGISTER_IDEEP_OPERATOR( PRelu, IDEEPFallbackOp<PReluOp<float, CPUContext>>); - + // ctc decoder operators REGISTER_IDEEP_OPERATOR( CTCGreedyDecoder, @@ -134,9 +130,6 @@ REGISTER_IDEEP_OPERATOR( LearningRate, IDEEPFallbackOp<LearningRateOp<float, CPUContext>>); REGISTER_IDEEP_OPERATOR( - LeakyRelu, - IDEEPFallbackOp<LeakyReluOp<float, CPUContext>>); -REGISTER_IDEEP_OPERATOR( Mul, IDEEPFallbackOp< BinaryElementwiseOp<NumericTypes, CPUContext, MulFunctor<CPUContext>>>); @@ -170,14 +163,12 @@ REGISTER_IDEEP_OPERATOR( ConvTransposeGradient, IDEEPFallbackOp<ConvTransposeGradientOp<float, CPUContext>>); REGISTER_IDEEP_OPERATOR( - LeakyReluGradient, - IDEEPFallbackOp<LeakyReluGradientOp<float, CPUContext>>); -REGISTER_IDEEP_OPERATOR( MulGradient, IDEEPFallbackOp<BinaryElementwiseGradientOp< NumericTypes, CPUContext, MulFunctor<CPUContext>>>); -REGISTER_IDEEP_OPERATOR(Adam, IDEEPFallbackOp<AdamOp<float, CPUContext>>); +REGISTER_IDEEP_OPERATOR(TensorProtosDBInput, IDEEPFallbackOp<TensorProtosDBInput<CPUContext>>); +REGISTER_IDEEP_OPERATOR(CloseBlobsQueue, IDEEPFallbackOp<CloseBlobsQueueOp<CPUContext>>); } // namespace caffe2 diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index 77001b94db..4372807e7a 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -111,7 +111,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { } } - if (!base_op_->Run()) { + if (!base_op_->Run(0)) { LOG(ERROR) << "Base op run failed in IDEEPFallbackOp. Def: " << ProtoDebugString(this->debug_def()); return false; diff --git a/caffe2/ideep/operators/queue_ops.cc b/caffe2/ideep/operators/queue_ops.cc new file mode 100644 index 0000000000..fb7887cdea --- /dev/null +++ b/caffe2/ideep/operators/queue_ops.cc @@ -0,0 +1,71 @@ +#include <caffe2/ideep/ideep_utils.h> +#include <caffe2/queue/blobs_queue.h> + +namespace caffe2 { + +class IDEEPCreateBlobsQueueOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPCreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + ws_(ws), + name(operator_def.output().Get(0)) {} + + bool RunOnDevice() override { + const auto capacity = GetSingleArgument("capacity", 1); + const auto numBlobs = GetSingleArgument("num_blobs", 1); + const auto enforceUniqueName = + GetSingleArgument("enforce_unique_name", false); + const auto fieldNames = + OperatorBase::template GetRepeatedArgument<std::string>("field_names"); + CAFFE_ENFORCE_EQ(this->OutputSize(), 1); + auto queuePtr = OperatorBase::Outputs()[0] + ->template GetMutable<std::shared_ptr<BlobsQueue>>(); + + CAFFE_ENFORCE(queuePtr); + *queuePtr = std::make_shared<BlobsQueue>( + ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames); + return true; + } + + private: + Workspace* ws_{nullptr}; + const std::string name; +}; + +class IDEEPSafeEnqueueBlobsOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPSafeEnqueueBlobsOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws) {} + + bool RunOnDevice() override { + auto queue = + OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>(); + CAFFE_ENFORCE(queue); + auto size = queue->getNumBlobs(); + CAFFE_ENFORCE( + OutputSize() == size + 1, + "Expected " + caffe2::to_string(size + 1) + ", " + + " got: " + caffe2::to_string(size)); + bool status = queue->blockingWrite(OperatorBase::Outputs()); + + auto st = OperatorBase::Output<TensorCPU>(1, CPU); + st->Resize(); + auto stat = st->template mutable_data<bool>(); + stat[0] = !status; + return true; + } +}; + +REGISTER_IDEEP_OPERATOR(CreateBlobsQueue, IDEEPCreateBlobsQueueOp); +SHOULD_NOT_DO_GRADIENT(IDEEPCreateBlobsQueueOp); + +REGISTER_IDEEP_OPERATOR(SafeEnqueueBlobs, IDEEPSafeEnqueueBlobsOp); +SHOULD_NOT_DO_GRADIENT(IDEEPSafeEnqueueBlobsOp); + +} // namespace caffe2 diff --git a/caffe2/ideep/operators/relu_op.cc b/caffe2/ideep/operators/relu_op.cc index 7f81d0ea71..7e591ff4d5 100644 --- a/caffe2/ideep/operators/relu_op.cc +++ b/caffe2/ideep/operators/relu_op.cc @@ -8,19 +8,33 @@ class IDEEPReluOp final : public IDEEPOperator { USE_IDEEP_OPERATOR_FUNCTIONS(); IDEEPReluOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPOperator(operator_def, ws) {} + : IDEEPOperator(operator_def, ws), alpha_(0.0) { + // Figure out the Relu descriptor. + if (operator_def.type().substr(0, 4) == "Relu") { + alpha_ = 0.0; + } else if (operator_def.type().substr(0, 9) == "LeakyRelu") { + if (HasArgument("alpha")) { + alpha_ = static_cast<float>( + OperatorBase::GetSingleArgument<float>("alpha", 0.01)); + } + } else { + LOG(FATAL) << "Unsupported Relu method: " << operator_def.type(); + } + } virtual ~IDEEPReluOp() {} bool RunOnDevice() override { const auto& X = Input(INPUT); auto* Y = Output(OUTPUT); - ideep::eltwise_forward::compute(X, *Y); + ideep::eltwise_forward::compute( + X, *Y, ialgo::eltwise_relu, iprop::forward_training, alpha_); return true; } private: + float alpha_; INPUT_TAGS(INPUT); OUTPUT_TAGS(OUTPUT); @@ -32,7 +46,19 @@ class IDEEPReluGradientOp final : public IDEEPOperator { USE_IDEEP_OPERATOR_FUNCTIONS(); IDEEPReluGradientOp(const OperatorDef& operator_def, Workspace* ws) - : IDEEPOperator(operator_def, ws) {} + : IDEEPOperator(operator_def, ws), alpha_(0.0) { + // Figure out the Relu descriptor. + if (operator_def.type().substr(0, 12) == "ReluGradient") { + alpha_ = 0.0; + } else if (operator_def.type().substr(0, 17) == "LeakyReluGradient") { + if (HasArgument("alpha")) { + alpha_ = static_cast<float>( + OperatorBase::GetSingleArgument<float>("alpha", 0.01)); + } + } else { + LOG(FATAL) << "Unsupported Relu method: " << operator_def.type(); + } + } virtual ~IDEEPReluGradientOp() {} bool RunOnDevice() override { @@ -40,12 +66,13 @@ class IDEEPReluGradientOp final : public IDEEPOperator { const auto& dY = Input(OUTPUT_GRAD); auto* dX = Output(INPUT_GRAD); - ideep::eltwise_backward::compute(Y, dY, *dX); + ideep::eltwise_backward::compute(Y, dY, *dX, ialgo::eltwise_relu, alpha_); return true; } private: + float alpha_; INPUT_TAGS(OUTPUT, OUTPUT_GRAD); OUTPUT_TAGS(INPUT_GRAD); @@ -54,4 +81,7 @@ class IDEEPReluGradientOp final : public IDEEPOperator { REGISTER_IDEEP_OPERATOR(Relu, IDEEPReluOp); REGISTER_IDEEP_OPERATOR(ReluGradient, IDEEPReluGradientOp); +REGISTER_IDEEP_OPERATOR(LeakyRelu, IDEEPReluOp); +REGISTER_IDEEP_OPERATOR(LeakyReluGradient, IDEEPReluGradientOp); + } // namespace caffe2 diff --git a/caffe2/ideep/operators/reshape_op.cc b/caffe2/ideep/operators/reshape_op.cc new file mode 100644 index 0000000000..6a63359940 --- /dev/null +++ b/caffe2/ideep/operators/reshape_op.cc @@ -0,0 +1,121 @@ +#include <caffe2/ideep/ideep_utils.h> + +namespace caffe2 { + +// Takes a shape and data tensor and reshapes it +class IDEEPReshapeOp final : public IDEEPOperator { + public: + USE_IDEEP_DEF_ALIASES(); + USE_IDEEP_OPERATOR_FUNCTIONS(); + + IDEEPReshapeOp(const OperatorDef& operator_def, Workspace* ws) + : IDEEPOperator(operator_def, ws), + new_shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {} + + bool RunOnDevice() override { + ideep::tensor::dims actual_new_shape = new_shape_; + if (InputSize() == 2) { + CAFFE_ENFORCE( + !OperatorBase::HasArgument("shape"), + "New shape is specified by the input blob, do not pass in " + "the argument `shape`."); + + // shape info live on CPU + auto& shape = OperatorBase::Input<TensorCPU>(1, CPU); + CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D"); + const int* shape_data = shape.template data<int>(); + + actual_new_shape.reserve(shape.size()); + actual_new_shape.assign(shape_data, shape_data + shape.size()); + } else { + CAFFE_ENFORCE( + OperatorBase::HasArgument("shape"), "Argument `shape` is missing."); + } + + auto& input = Input(0); + // Copy over the dimensions for those that are specified zero. + for (int i = 0; i < actual_new_shape.size() && i < input.ndims(); ++i) { + if (actual_new_shape[i] == 0) { + actual_new_shape[i] = input.get_dim(i); + } + } + + // Checks if the new shape is valid and fills in the missing dimension + // specified by -1. + // NOTE: At most one dimension can be -1. + auto total_size = input.get_nelems(); + int size = 1; + int unknown_idx = -1; + for (int i = 0; i < actual_new_shape.size(); ++i) { + const auto dim = actual_new_shape[i]; + if (dim == -1) { + CAFFE_ENFORCE( + unknown_idx == -1, + "Argument `shape` has more than one missing dimension."); + unknown_idx = i; + } else { + size *= dim; + } + } + if (size == 0 && total_size != 0) { + CAFFE_THROW( + "Can not reshape a non-zero size (", + total_size, + ") tensor to zero size."); + } + + if (unknown_idx != -1) { + CAFFE_ENFORCE_NE( + size, + 0, + "New shape at dim ", + unknown_idx, + " can not be inferred since new size is zero."); + CAFFE_ENFORCE( + total_size % size == 0, + "Argument `shape` does not agree with the input data.", + " (", + total_size, + " vs ", + size, + ")"); + actual_new_shape[unknown_idx] = total_size / size; + } else { + CAFFE_ENFORCE_EQ( + total_size, + size, + "Argument `shape` does not agree with the input data.", + " (", + total_size, + " != ", + size, + ")"); + } + + // Write the original shape to the second output. + // shape info live on CPU + TensorCPU* old_shape = OperatorBase::Output<TensorCPU>(1, CPU); + old_shape->Resize(input.ndims()); + int* old_shape_data = old_shape->template mutable_data<int>(); + for (int i = 0; i < input.ndims(); ++i) { + old_shape_data[i] = input.get_dim(i); + } + + auto* output = Output(0); + if (output != &input) { + // If we are not doing in-place computation, a copy is needed. + output->reinit_like(input); + ideep::direct_copy::compute(input, *output); + } + + output->reshape(actual_new_shape); + return true; + } + + private: + ideep::tensor::dims new_shape_; +}; + +REGISTER_IDEEP_OPERATOR(Reshape, IDEEPReshapeOp); + +} // namespace caffe2 |