summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorMark Santaniello <marksan@fb.com>2019-04-15 23:40:21 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-15 23:43:13 -0700
commit20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7 (patch)
tree4eba94cb33204a3d6a744ea0157355f7569854f6 /caffe2
parentada10ad416f91602d46797b50b3faebcacc6e767 (diff)
downloadpytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.tar.gz
pytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.tar.bz2
pytorch-20fc7b6ec7cb7b43f66526abf9d5b4fa44ffefb7.zip
Avoid undefined symbol error when building AdIndexer LTO (#19009)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19009 Move the definition of `MulFunctor<>::Backward()` into a header file. Reviewed By: BIT-silence Differential Revision: D14823230 fbshipit-source-id: 1efaec01863fcc02dcbe7e788d376e72f8564501
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/operators/elementwise_mul_gradient_op.cc81
-rw-r--r--caffe2/operators/elementwise_mul_op.h81
2 files changed, 81 insertions, 81 deletions
diff --git a/caffe2/operators/elementwise_mul_gradient_op.cc b/caffe2/operators/elementwise_mul_gradient_op.cc
index cbc73660d6..5065504349 100644
--- a/caffe2/operators/elementwise_mul_gradient_op.cc
+++ b/caffe2/operators/elementwise_mul_gradient_op.cc
@@ -7,87 +7,6 @@
namespace caffe2 {
-namespace {
-
-template <typename TGrad, typename TIn>
-void ComputeMulGradient(
- const int ndim,
- const int* A_dims,
- const int* B_dims,
- const int* C_dims,
- const TGrad* dC,
- const TIn* A,
- const TIn* B,
- TGrad* dA,
- TGrad* dB,
- CPUContext* context) {
- const int A_size =
- std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
- const int B_size =
- std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
- const int C_size =
- std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
- math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
- math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
- std::vector<int> index(ndim, 0);
- for (int C_index = 0; C_index < C_size; ++C_index) {
- const int A_index =
- math::utils::GetIndexFromDims(ndim, A_dims, index.data());
- const int B_index =
- math::utils::GetIndexFromDims(ndim, B_dims, index.data());
- dA[A_index] += dC[C_index] * B[B_index];
- dB[B_index] += dC[C_index] * A[A_index];
- math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
- }
-}
-
-} // namespace
-
-template <>
-template <typename TGrad, typename TIn, typename TOut>
-bool MulFunctor<CPUContext>::Backward(
- const std::vector<int>& A_dims,
- const std::vector<int>& B_dims,
- const TGrad* dC,
- const TIn* A,
- const TIn* B,
- const TOut* /* C */,
- TGrad* dA,
- TGrad* dB,
- CPUContext* context) const {
- if (A_dims == B_dims) {
- const int size = std::accumulate(
- A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
- math::Mul(size, dC, B, dA, context);
- math::Mul(size, dC, A, dB, context);
- return true;
- }
- const int ndim = std::max(A_dims.size(), B_dims.size());
- std::vector<int> A_broadcast_dims(ndim);
- std::vector<int> B_broadcast_dims(ndim);
- std::vector<int> C_broadcast_dims(ndim);
- math::utils::ComputeBroadcastBinaryOpDims(
- A_dims.size(),
- A_dims.data(),
- B_dims.size(),
- B_dims.data(),
- A_broadcast_dims.data(),
- B_broadcast_dims.data(),
- C_broadcast_dims.data());
- ComputeMulGradient<TGrad, TIn>(
- ndim,
- A_broadcast_dims.data(),
- B_broadcast_dims.data(),
- C_broadcast_dims.data(),
- dC,
- A,
- B,
- dA,
- dB,
- context);
- return true;
-}
-
REGISTER_CPU_OPERATOR(
MulGradient,
BinaryElementwiseGradientOp<
diff --git a/caffe2/operators/elementwise_mul_op.h b/caffe2/operators/elementwise_mul_op.h
index f1c42edc48..6b31fe3684 100644
--- a/caffe2/operators/elementwise_mul_op.h
+++ b/caffe2/operators/elementwise_mul_op.h
@@ -8,6 +8,42 @@
namespace caffe2 {
+namespace {
+
+template <typename TGrad, typename TIn>
+void ComputeMulGradient(
+ const int ndim,
+ const int* A_dims,
+ const int* B_dims,
+ const int* C_dims,
+ const TGrad* dC,
+ const TIn* A,
+ const TIn* B,
+ TGrad* dA,
+ TGrad* dB,
+ CPUContext* context) {
+ const int A_size =
+ std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
+ const int B_size =
+ std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
+ const int C_size =
+ std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
+ math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
+ math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
+ std::vector<int> index(ndim, 0);
+ for (int C_index = 0; C_index < C_size; ++C_index) {
+ const int A_index =
+ math::utils::GetIndexFromDims(ndim, A_dims, index.data());
+ const int B_index =
+ math::utils::GetIndexFromDims(ndim, B_dims, index.data());
+ dA[A_index] += dC[C_index] * B[B_index];
+ dB[B_index] += dC[C_index] * A[A_index];
+ math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
+ }
+}
+
+} // namespace
+
template <class Context>
struct MulFunctor {
template <typename TIn, typename TOut>
@@ -43,6 +79,51 @@ struct MulFunctor {
Context* context) const;
};
+template <>
+template <typename TGrad, typename TIn, typename TOut>
+bool MulFunctor<CPUContext>::Backward(
+ const std::vector<int>& A_dims,
+ const std::vector<int>& B_dims,
+ const TGrad* dC,
+ const TIn* A,
+ const TIn* B,
+ const TOut* /* C */,
+ TGrad* dA,
+ TGrad* dB,
+ CPUContext* context) const {
+ if (A_dims == B_dims) {
+ const int size = std::accumulate(
+ A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
+ math::Mul(size, dC, B, dA, context);
+ math::Mul(size, dC, A, dB, context);
+ return true;
+ }
+ const int ndim = std::max(A_dims.size(), B_dims.size());
+ std::vector<int> A_broadcast_dims(ndim);
+ std::vector<int> B_broadcast_dims(ndim);
+ std::vector<int> C_broadcast_dims(ndim);
+ math::utils::ComputeBroadcastBinaryOpDims(
+ A_dims.size(),
+ A_dims.data(),
+ B_dims.size(),
+ B_dims.data(),
+ A_broadcast_dims.data(),
+ B_broadcast_dims.data(),
+ C_broadcast_dims.data());
+ ComputeMulGradient<TGrad, TIn>(
+ ndim,
+ A_broadcast_dims.data(),
+ B_broadcast_dims.data(),
+ C_broadcast_dims.data(),
+ dC,
+ A,
+ B,
+ dA,
+ dB,
+ context);
+ return true;
+}
+
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_