diff options
author | Mark Santaniello <marksan@fb.com> | 2019-04-15 23:40:21 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-15 23:43:13 -0700 |
commit | 20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7 (patch) | |
tree | 4eba94cb33204a3d6a744ea0157355f7569854f6 /caffe2 | |
parent | ada10ad416f91602d46797b50b3faebcacc6e767 (diff) | |
download | pytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.tar.gz pytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.tar.bz2 pytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.zip |
Avoid undefined symbol error when building AdIndexer LTO (#19009)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19009
Move the definition of `MulFunctor<>::Backward()` into a header file.
Reviewed By: BIT-silence
Differential Revision: D14823230
fbshipit-source-id: 1efaec01863fcc02dcbe7e788d376e72f8564501
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/operators/elementwise_mul_gradient_op.cc | 81 | ||||
-rw-r--r-- | caffe2/operators/elementwise_mul_op.h | 81 |
2 files changed, 81 insertions, 81 deletions
diff --git a/caffe2/operators/elementwise_mul_gradient_op.cc b/caffe2/operators/elementwise_mul_gradient_op.cc
index cbc73660d6..5065504349 100644
--- a/caffe2/operators/elementwise_mul_gradient_op.cc
+++ b/caffe2/operators/elementwise_mul_gradient_op.cc
@@ -7,87 +7,6 @@ namespace caffe2 {
 
-namespace {
-
-template <typename TGrad, typename TIn>
-void ComputeMulGradient(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    const int* C_dims,
-    const TGrad* dC,
-    const TIn* A,
-    const TIn* B,
-    TGrad* dA,
-    TGrad* dB,
-    CPUContext* context) {
-  const int A_size =
-      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
-  const int B_size =
-      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
-  const int C_size =
-      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
-  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
-  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
-  std::vector<int> index(ndim, 0);
-  for (int C_index = 0; C_index < C_size; ++C_index) {
-    const int A_index =
-        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
-    const int B_index =
-        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
-    dA[A_index] += dC[C_index] * B[B_index];
-    dB[B_index] += dC[C_index] * A[A_index];
-    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
-  }
-}
-
-} // namespace
-
-template <>
-template <typename TGrad, typename TIn, typename TOut>
-bool MulFunctor<CPUContext>::Backward(
-    const std::vector<int>& A_dims,
-    const std::vector<int>& B_dims,
-    const TGrad* dC,
-    const TIn* A,
-    const TIn* B,
-    const TOut* /* C */,
-    TGrad* dA,
-    TGrad* dB,
-    CPUContext* context) const {
-  if (A_dims == B_dims) {
-    const int size = std::accumulate(
-        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
-    math::Mul(size, dC, B, dA, context);
-    math::Mul(size, dC, A, dB, context);
-    return true;
-  }
-  const int ndim = std::max(A_dims.size(), B_dims.size());
-  std::vector<int> A_broadcast_dims(ndim);
-  std::vector<int> B_broadcast_dims(ndim);
-  std::vector<int> C_broadcast_dims(ndim);
-  math::utils::ComputeBroadcastBinaryOpDims(
-      A_dims.size(),
-      A_dims.data(),
-      B_dims.size(),
-      B_dims.data(),
-      A_broadcast_dims.data(),
-      B_broadcast_dims.data(),
-      C_broadcast_dims.data());
-  ComputeMulGradient<TGrad, TIn>(
-      ndim,
-      A_broadcast_dims.data(),
-      B_broadcast_dims.data(),
-      C_broadcast_dims.data(),
-      dC,
-      A,
-      B,
-      dA,
-      dB,
-      context);
-  return true;
-}
-
 REGISTER_CPU_OPERATOR(
     MulGradient,
     BinaryElementwiseGradientOp<
diff --git a/caffe2/operators/elementwise_mul_op.h b/caffe2/operators/elementwise_mul_op.h
index f1c42edc48..6b31fe3684 100644
--- a/caffe2/operators/elementwise_mul_op.h
+++ b/caffe2/operators/elementwise_mul_op.h
@@ -8,6 +8,42 @@ namespace caffe2 {
 
+namespace {
+
+template <typename TGrad, typename TIn>
+void ComputeMulGradient(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    const int* C_dims,
+    const TGrad* dC,
+    const TIn* A,
+    const TIn* B,
+    TGrad* dA,
+    TGrad* dB,
+    CPUContext* context) {
+  const int A_size =
+      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
+  const int B_size =
+      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
+  const int C_size =
+      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
+  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
+  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
+  std::vector<int> index(ndim, 0);
+  for (int C_index = 0; C_index < C_size; ++C_index) {
+    const int A_index =
+        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
+    const int B_index =
+        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
+    dA[A_index] += dC[C_index] * B[B_index];
+    dB[B_index] += dC[C_index] * A[A_index];
+    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
+  }
+}
+
+} // namespace
+
 template <class Context>
 struct MulFunctor {
   template <typename TIn, typename TOut>
@@ -43,6 +79,51 @@ struct MulFunctor {
       Context* context) const;
 };
 
+template <>
+template <typename TGrad, typename TIn, typename TOut>
+bool MulFunctor<CPUContext>::Backward(
+    const std::vector<int>& A_dims,
+    const std::vector<int>& B_dims,
+    const TGrad* dC,
+    const TIn* A,
+    const TIn* B,
+    const TOut* /* C */,
+    TGrad* dA,
+    TGrad* dB,
+    CPUContext* context) const {
+  if (A_dims == B_dims) {
+    const int size = std::accumulate(
+        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
+    math::Mul(size, dC, B, dA, context);
+    math::Mul(size, dC, A, dB, context);
+    return true;
+  }
+  const int ndim = std::max(A_dims.size(), B_dims.size());
+  std::vector<int> A_broadcast_dims(ndim);
+  std::vector<int> B_broadcast_dims(ndim);
+  std::vector<int> C_broadcast_dims(ndim);
+  math::utils::ComputeBroadcastBinaryOpDims(
+      A_dims.size(),
+      A_dims.data(),
+      B_dims.size(),
+      B_dims.data(),
+      A_broadcast_dims.data(),
+      B_broadcast_dims.data(),
+      C_broadcast_dims.data());
+  ComputeMulGradient<TGrad, TIn>(
+      ndim,
+      A_broadcast_dims.data(),
+      B_broadcast_dims.data(),
+      C_broadcast_dims.data(),
+      dC,
+      A,
+      B,
+      dA,
+      dB,
+      context);
+  return true;
+}
+
 } // namespace caffe2
 
 #endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_