author     Xiaomeng Yang <yangxm@fb.com>  2019-03-12 11:54:29 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-03-12 12:16:30 -0700
commit     f229521154e3b2093b9088fb21d48bedf3551c6b (patch)
tree       fc31290ab7691d4b168dd8f0f93575c671fe9617
parent     54b33503ec022b39173f08edd8136d01d058dea0 (diff)
Optimize TileOp (#17290)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17290
Optimize TileOp
Reviewed By: wesolwsk
Differential Revision: D14145844
fbshipit-source-id: 1571fa0512218dbc48080592ede4e23903be85dd
-rw-r--r--  caffe2/operators/tile_op.cc | 122
-rw-r--r--  caffe2/operators/tile_op.cu | 155
-rw-r--r--  caffe2/operators/tile_op.h  | 267
3 files changed, 315 insertions(+), 229 deletions(-)
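
For context before the diff: `Tile` constructs its output by replicating the input along `axis`, so the output's `axis` dimension has `X.dims(axis) * tiles` elements. Both the old and the new code reduce this to a 2D view: `outer_size` is the product of the dimensions before `axis`, `inner_size` the product of the dimensions from `axis` onward, and each contiguous inner block is written `tiles` times. The sketch below is a minimal plain-C++ reference for that copy pattern, not the Caffe2 code; the name `tile_reference` and the use of `float`/`std::vector` are illustrative only.

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of Tile: treat X as (outer_size, inner_size) and
// repeat each contiguous inner block `tiles` times in the output, which is
// laid out as (outer_size, tiles, inner_size). Illustrative sketch only.
std::vector<float> tile_reference(const std::vector<float>& X,
                                  std::int64_t outer_size,
                                  std::int64_t inner_size,
                                  std::int64_t tiles) {
  std::vector<float> Y(outer_size * tiles * inner_size);
  for (std::int64_t i = 0; i < outer_size; ++i) {
    for (std::int64_t t = 0; t < tiles; ++t) {
      for (std::int64_t j = 0; j < inner_size; ++j) {
        Y[(i * tiles + t) * inner_size + j] = X[i * inner_size + j];
      }
    }
  }
  return Y;
}
```

With the 3x10 example from the original tile_op.h comment, tiling 2 times along axis 1 gives outer_size = 3 and inner_size = 10, so each row is written twice before moving to the next row; tiling along axis 0 gives outer_size = 1 and inner_size = 30, so the whole matrix is written twice.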
diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc index 52f6c10d1b..11d096e899 100644 --- a/caffe2/operators/tile_op.cc +++ b/caffe2/operators/tile_op.cc @@ -1,33 +1,103 @@ #include "caffe2/operators/tile_op.h" +#include <string> + namespace caffe2 { +template <> +bool TileOp<CPUContext>::RunOnDevice() { + return DispatchHelper< + TensorTypes<std::int32_t, std::int64_t, float, double, std::string>>:: + call(this, Input(0)); +} + +template <> +template <> +bool TileOp<CPUContext>::DoRunWithType<std::string>() { + if (InputSize() > 1) { + // We potentially have tiles and/or axis specified as inputs + // as well. We will check for them in that order. In other words: + // InputSize() == 2: tiles is specified + // InputSize() == 3: tiles is specified and axis. + // Anything specified as input will override the arguments + CAFFE_ENFORCE( + Input(1).dim() == 1 && Input(1).numel() == 1, + "Input `tiles` should be a vector of size 1."); + tiles_ = GetArgFromTensor(Input(1)); + if (InputSize() > 2) { + CAFFE_ENFORCE( + Input(2).dim() == 1 && Input(2).numel() == 1, + "Input `axis` should be a vector of size 1."); + axis_ = GetArgFromTensor(Input(2)); + } else { + CAFFE_ENFORCE( + OperatorBase::HasArgument("axis"), + "Argument `axis` is missing and was not specified as input."); + } + } else { + CAFFE_ENFORCE( + OperatorBase::HasArgument("tiles"), + "Argument `tiles` is missing and was not specified as input."); + CAFFE_ENFORCE( + OperatorBase::HasArgument("axis"), + "Argument `axis` is missing and was not specified as input."); + } + + const auto& X = Input(0); + auto* Y = Output(0); + const int axis = X.canonical_axis_index(axis_); + + // reshape output to be input tiled along the axis + std::vector<std::int64_t> Y_dims = X.sizes().vec(); + Y_dims[axis] *= tiles_; + Y->Resize(Y_dims); + + // size up to (and not including) axis + const int outer_size = X.size_to_dim(axis); + // size from axis up + const int inner_size = X.size_from_dim(axis); + + const TypeMeta& meta = X.dtype(); + const int item_size = X.itemsize(); + const char* X_ptr = reinterpret_cast<const char*>(X.raw_data()); + char* Y_ptr = reinterpret_cast<char*>(Y->raw_mutable_data(meta)); + for (int i = 0; i < outer_size; ++i) { + for (int t = 0; t < tiles_; ++t) { + context_.CopyItemsSameDevice(meta, inner_size, X_ptr, Y_ptr); + Y_ptr += inner_size * item_size; + } + X_ptr += inner_size * item_size; + } + return true; +} + REGISTER_CPU_OPERATOR(Tile, TileOp<CPUContext>); -REGISTER_CPU_OPERATOR(TileGradient, TileGradientOp<float, CPUContext>); +REGISTER_CPU_OPERATOR(TileGradient, TileGradientOp<CPUContext>); OPERATOR_SCHEMA(Tile) .NumInputs(1, 3) .NumOutputs(1) - .TensorInferenceFunction( - [](const OperatorDef& def, const vector<TensorShape>& in) { - vector<TensorShape> out(1); - out[0] = TensorShape(in[0]); - ArgumentHelper helper(def); - - auto tiles = helper.GetSingleArgument<int32_t>("tiles", 1); - auto axis = helper.GetSingleArgument<int32_t>("axis", 0); - if (in.size() > 1) { - // Tile or axis is specified as input; we can't determine - // the size - out[0].set_unknown_shape(true); - } else { - const auto canonical_axis = - canonical_axis_index_(axis, out[0].dims().size()); - out[0].set_dims( - canonical_axis, out[0].dims().Get(canonical_axis) * tiles); - } - return out; - }) + .TensorInferenceFunction([](const OperatorDef& def, + const std::vector<TensorShape>& in) { + std::vector<TensorShape> out(1); + out[0] = TensorShape(in[0]); + ArgumentHelper helper(def); + const std::int32_t tiles = + 
helper.GetSingleArgument<std::int32_t>("tiles", 1); + const std::int32_t axis = + helper.GetSingleArgument<std::int32_t>("axis", 0); + if (in.size() > 1) { + // Tile or axis is specified as input; we can't determine + // the size + out[0].set_unknown_shape(true); + } else { + const auto canonical_axis = + canonical_axis_index_(axis, out[0].dims().size()); + out[0].set_dims( + canonical_axis, out[0].dims().Get(canonical_axis) * tiles); + } + return out; + }) .SetDoc(R"DOC( Constructs a tensor by tiling a given tensor along a specified axis. This operation creates a new tensor by replicating the input tensor a number of times specified by the `tiles` argument along the `axis` dimension. The output tensor's `axis` dimension has $(X.dims(axis) * tiles)$ elements. @@ -97,12 +167,14 @@ Y: OPERATOR_SCHEMA(TileGradient).NumInputs(1, 3).NumOutputs(1); +namespace { + class GetTileGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; - vector<OperatorDef> GetGradientDefs() override { + std::vector<OperatorDef> GetGradientDefs() override { // Check whether the tiles/axis information was // passed through input arguments - vector<std::string> g_inputs({GO(0)}); + std::vector<std::string> g_inputs({GO(0)}); if (Def().input_size() > 1) { g_inputs.push_back(I(1)); } @@ -110,10 +182,12 @@ class GetTileGradient : public GradientMakerBase { g_inputs.push_back(I(2)); } return SingleGradientDef( - "TileGradient", "", g_inputs, vector<string>{GI(0)}); + "TileGradient", "", g_inputs, std::vector<std::string>{GI(0)}); } }; +} // namespace + REGISTER_GRADIENT(Tile, GetTileGradient); } // namespace caffe2 diff --git a/caffe2/operators/tile_op.cu b/caffe2/operators/tile_op.cu index b8f78201b3..a0a7e294e7 100644 --- a/caffe2/operators/tile_op.cu +++ b/caffe2/operators/tile_op.cu @@ -1,93 +1,102 @@ -#include <cub/block/block_reduce.cuh> +#include "caffe2/operators/tile_op.h" + +#include <array> #include "caffe2/core/context_gpu.h" -#include "caffe2/operators/tile_op.h" +#include "caffe2/utils/math.h" namespace caffe2 { + namespace { + template <typename T> -__global__ void TileCopyKernel( - int outer_dim, - int inner_dim, - int tiles, - const T* input_data, - T* output_data) { - CUDA_1D_KERNEL_LOOP(index, outer_dim * inner_dim * tiles) { - int col = index % inner_dim; - int row = index / (inner_dim * tiles); - output_data[index] = input_data[row * inner_dim + col]; +__global__ void TileCopyCUDAKernel( + const int total_size, + const int inner_size, + const int tiles, + const T* X, + T* Y) { + const int x = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (x < total_size) { + const int r = x / inner_size / tiles; + const int c = x % inner_size; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[x] = __ldg(X + r * inner_size + c); +#else + Y[x] = X[r * inner_size + c]; +#endif } } -template <typename T> -__global__ void TileGradientAxpyKernel( - int outer_dim, - int inner_dim, - int tiles, - const T* input_data, - T* output_data) { - typedef cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS> BlockReduce; - - for (int idx = blockIdx.x; idx < outer_dim * inner_dim; idx += gridDim.x) { - int i = idx / inner_dim; - int j = idx % inner_dim; - T* output_ptr = output_data + inner_dim * i; +} // namespace - T x = 0.0; - for (int t = threadIdx.x; t < tiles; t += blockDim.x) { - const T* input_ptr = input_data + (i * tiles + t) * inner_dim; - x += input_ptr[j]; - } - __shared__ typename BlockReduce::TempStorage temp_storage; - T totx = BlockReduce(temp_storage).Sum(x); - if (threadIdx.x == 
0) { - output_ptr[j] = totx; - } - __syncthreads(); - } +template <> +template <typename T> +bool TileOp<CUDAContext>::DoTile( + const int outer_size, + const int inner_size, + const T* X, + T* Y) { + const std::int64_t total_size = static_cast<std::int64_t>(outer_size) * + static_cast<std::int64_t>(tiles_) * static_cast<std::int64_t>(inner_size); + const int M = math::DivUp<std::int64_t>(total_size, CAFFE_CUDA_NUM_THREADS); + TileCopyCUDAKernel<T> + <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>( + total_size, inner_size, tiles_, X, Y); + return true; } -} // namespace template <> -void TileOp<CUDAContext>::DoTile( - const TypeMeta& meta, - int item_size, - int outer_dim, - int inner_dim, - const char* input_data, - char* output_data) { - TileCopyKernel<float> - <<<std::min(outer_dim * inner_dim * tiles_, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - outer_dim, - inner_dim, - tiles_, - reinterpret_cast<const float*>(input_data), - reinterpret_cast<float*>(output_data)); +template <typename T> +bool TileGradientOp<CUDAContext>::DoTileGradient( + const int outer_size, + const int inner_size, + const T* dY, + T* dX) { + const std::array<int, 3> dY_dims = {outer_size, tiles_, inner_size}; + const std::array<int, 3> dX_dims = {outer_size, 1, inner_size}; + math::ReduceSum<T, CUDAContext>( + 3, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_); + return true; } template <> -void TileGradientOp<float, CUDAContext>::DoTileGradient( - const TypeMeta& meta, - int item_size, - int outer_dim, - int inner_dim, - const char* input_data, - char* output_data) { - TileGradientAxpyKernel<float><<< - std::min(outer_dim * inner_dim, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - outer_dim, - inner_dim, - tiles_, - reinterpret_cast<const float*>(input_data), - reinterpret_cast<float*>(output_data)); +template <> +bool TileGradientOp<CUDAContext>::DoTileGradient<float>( + const int outer_size, + const int inner_size, + const float* dY, + float* dX) { + if (inner_size == 1) { + const std::array<int, 2> dY_dims = {outer_size, tiles_}; + const std::array<int, 2> dX_dims = {outer_size, 1}; + math::ReduceSum<float, CUDAContext>( + 2, dY_dims.data(), dX_dims.data(), 1.0f, dY, dX, &context_); + } else { + ReinitializeTensor(&ones_, tiles_, at::dtype<float>().device(CUDA)); + math::Set<float, CUDAContext>( + tiles_, 1.0f, ones_.template mutable_data<float>(), &context_); + math::GemmStridedBatched<float, CUDAContext>( + CblasTrans, + CblasNoTrans, + outer_size, + inner_size, + 1, + tiles_, + 1.0f, + dY, + tiles_ * inner_size, + ones_.template data<float>(), + 0, + 0.0f, + dX, + inner_size, + &context_); + } + return true; } REGISTER_CUDA_OPERATOR(Tile, TileOp<CUDAContext>); -REGISTER_CUDA_OPERATOR(TileGradient, TileGradientOp<float, CUDAContext>); +REGISTER_CUDA_OPERATOR(TileGradient, TileGradientOp<CUDAContext>); + } // namespace caffe2 diff --git a/caffe2/operators/tile_op.h b/caffe2/operators/tile_op.h index df33bc6461..71f0bde2d7 100644 --- a/caffe2/operators/tile_op.h +++ b/caffe2/operators/tile_op.h @@ -1,29 +1,40 @@ #ifndef CAFFE2_OPERATORS_TILE_OP_H_ #define CAFFE2_OPERATORS_TILE_OP_H_ +#include <array> +#include <string> +#include <type_traits> +#include <vector> + #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" +#include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { // Copy a Blob n 
times along a specified axis. template <class Context> -class TileOp : public Operator<Context> { +class TileOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + template <class... Args> explicit TileOp(Args&&... args) : Operator<Context>(std::forward<Args>(args)...), - tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)), - axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {} - ~TileOp() {} + OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1), + OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {} bool RunOnDevice() override { - const auto& input = Input(0); - std::array<int32_t, 2> temp_params = {{tiles_, axis_}}; + return DispatchHelper< + TensorTypes<std::int32_t, std::int64_t, float, double>>:: + call(this, Input(0)); + } + + template <typename T> + bool DoRunWithType() { if (InputSize() > 1) { // We potentially have tiles and/or axis specified as inputs // as well. We will check for them in that order. In other words: @@ -33,25 +44,12 @@ class TileOp : public Operator<Context> { CAFFE_ENFORCE( Input(1).dim() == 1 && Input(1).numel() == 1, "Input `tiles` should be a vector of size 1."); - - const auto& input1 = Input(1); - context_.CopyItemsToCPU( - input1.dtype(), - 1, - static_cast<const char*>(input1.raw_data()), - &(temp_params[0])); - + tiles_ = GetArgFromTensor(Input(1)); if (InputSize() > 2) { CAFFE_ENFORCE( Input(2).dim() == 1 && Input(2).numel() == 1, "Input `axis` should be a vector of size 1."); - - const auto& input2 = Input(2); - context_.CopyItemsToCPU( - input2.dtype(), - 1, - static_cast<const char*>(input2.raw_data()), - &(temp_params[1])); + axis_ = GetArgFromTensor(Input(2)); } else { CAFFE_ENFORCE( OperatorBase::HasArgument("axis"), @@ -66,79 +64,82 @@ class TileOp : public Operator<Context> { "Argument `axis` is missing and was not specified as input."); } - tiles_ = temp_params[0]; - axis_ = temp_params[1]; - - auto* output = Output(0); - const auto axis = input.canonical_axis_index(axis_); + const auto& X = Input(0); + auto* Y = Output(0); + const int axis = X.canonical_axis_index(axis_); // reshape output to be input tiled along the axis - vector<int64_t> output_dims(input.sizes().vec()); - output_dims[axis_] = output_dims[axis_] * tiles_; - output->Resize(output_dims); + std::vector<std::int64_t> Y_dims = X.sizes().vec(); + Y_dims[axis] *= tiles_; + Y->Resize(Y_dims); // size up to (and not including) axis - const auto outer_dim = input.size_to_dim(axis); + const int outer_size = X.size_to_dim(axis); // size from axis up - const auto inner_dim = input.size_from_dim(axis); + const int inner_size = X.size_from_dim(axis); - /** - * How this works: - * Imagine a 2D tensor (matrix) of size 3x10, tiled 2 times. - * - Tiling along axis 0 (row) means copying the entire 3x10 Matrix 2 - * times. outer_dim = 0, inner_dim = 30. - * - Tiling along axis 1 (column) means copying each row 2 times, then - * proceed to the next row, until the end. outer_dim = 3, inner_dim = 10. 
- */ - const char* input_data = static_cast<const char*>(input.raw_data()); - char* output_data = - static_cast<char*>(output->raw_mutable_data(input.dtype())); - - DoTile( - input.dtype(), - input.itemsize(), - outer_dim, - inner_dim, - input_data, - output_data); - - return true; + const T* X_data = X.template data<T>(); + T* Y_data = Y->template mutable_data<T>(); + return DoTile<T>(outer_size, inner_size, X_data, Y_data); } private: - void DoTile( - const TypeMeta& meta, - int item_size, - int outer_dim, - int inner_dim, - const char* input_data, - char* output_data) { - for (auto i = 0; i < outer_dim; ++i) { - for (auto t = 0; t < tiles_; ++t) { - context_.CopyItemsSameDevice(meta, inner_dim, input_data, output_data); - output_data += inner_dim * item_size; + std::int32_t GetArgFromTensor(const Tensor& tensor) { + CAFFE_ENFORCE( + tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>()); + std::int32_t val = -1; + if (tensor.IsType<std::int32_t>()) { + context_.template CopyToCPU<std::int32_t>( + 1, tensor.data<std::int32_t>(), &val); + } else if (tensor.IsType<std::int64_t>()) { + std::int64_t val_int64; + context_.template CopyToCPU<std::int64_t>( + 1, tensor.data<std::int64_t>(), &val_int64); + val = static_cast<std::int32_t>(val_int64); + } + return val; + } + + template <typename T> + bool DoTile(const int outer_size, const int inner_size, const T* X, T* Y) { + if (inner_size == 1) { + EigenArrayMap<T> Y_arr(Y, tiles_, outer_size); + for (int i = 0; i < outer_size; ++i) { + Y_arr.col(i) = X[i]; + } + } else { + ConstEigenArrayMap<T> X_arr(X, inner_size, outer_size); + for (int i = 0; i < outer_size; ++i) { + EigenArrayMap<T>(Y + i * tiles_ * inner_size, inner_size, tiles_) + .colwise() = X_arr.col(i); } - input_data += inner_dim * item_size; } + return true; } - int32_t tiles_; - int32_t axis_; + std::int32_t tiles_; + std::int32_t axis_; }; -template <typename T, class Context> -class TileGradientOp : public Operator<Context> { +template <class Context> +class TileGradientOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + template <class... Args> explicit TileGradientOp(Args&&... args) : Operator<Context>(std::forward<Args>(args)...), - tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)), - axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {} - ~TileGradientOp() {} + OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1), + OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {} bool RunOnDevice() override { - std::array<int32_t, 2> temp_params = {{tiles_, axis_}}; + return DispatchHelper< + TensorTypes<std::int32_t, std::int64_t, float, double>>:: + call(this, Input(0)); + } + + template <typename T> + bool DoRunWithType() { if (InputSize() > 1) { // We potentially have tiles and/or axis specified as inputs // as well. We will check for them in that order. 
In other words: @@ -148,25 +149,12 @@ class TileGradientOp : public Operator<Context> { CAFFE_ENFORCE( Input(1).dim() == 1 && Input(1).numel() == 1, "Input `tiles` should be a vector of size 1."); - - const auto& input1 = Input(1); - context_.CopyItemsToCPU( - input1.dtype(), - 1, - static_cast<const char*>(input1.raw_data()), - &(temp_params[0])); - + tiles_ = GetArgFromTensor(Input(1)); if (InputSize() > 2) { CAFFE_ENFORCE( Input(2).dim() == 1 && Input(2).numel() == 1, "Input `axis` should be a vector of size 1."); - - const auto& input2 = Input(2); - context_.CopyItemsToCPU( - input2.dtype(), - 1, - static_cast<const char*>(input2.raw_data()), - &(temp_params[1])); + axis_ = GetArgFromTensor(Input(2)); } else { CAFFE_ENFORCE( OperatorBase::HasArgument("axis"), @@ -181,22 +169,20 @@ class TileGradientOp : public Operator<Context> { "Argument `axis` is missing and was not specified as input."); } - tiles_ = temp_params[0]; - axis_ = temp_params[1]; - - const auto& input = Input(0); - auto* output = Output(0); - const auto axis = input.canonical_axis_index(axis_); + const auto& dY = Input(0); + auto* dX = Output(0); + const int axis = dY.canonical_axis_index(axis_); // reshape output to be input "untiled" along the axis - vector<int64_t> output_dims(input.sizes().vec()); - output_dims[axis_] = output_dims[axis_] / tiles_; - output->Resize(output_dims); + std::vector<std::int64_t> X_dims = dY.sizes().vec(); + CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0); + X_dims[axis] /= tiles_; + dX->Resize(X_dims); // size up to (and not including) axis - const auto outer_dim = output->size_to_dim(axis); + const int outer_size = dX->size_to_dim(axis); // size from axis up - const auto inner_dim = output->size_from_dim(axis); + const int inner_size = dX->size_from_dim(axis); /** * How this works: @@ -208,47 +194,64 @@ class TileGradientOp : public Operator<Context> { * So the output gradient should be the matrix multipication result * of input gradient (gradient of tiled tensor output) and X. 
*/ - const char* input_data = static_cast<const char*>(input.raw_data()); - char* output_data = - static_cast<char*>(output->raw_mutable_data(input.dtype())); - - DoTileGradient( - input.dtype(), - input.itemsize(), - outer_dim, - inner_dim, - input_data, - output_data); - - return true; + const T* dY_data = dY.template data<T>(); + T* dX_data = dX->template mutable_data<T>(); + return DoTileGradient<T>(outer_size, inner_size, dY_data, dX_data); } private: - void DoTileGradient( - const TypeMeta& meta, - int item_size, - int outer_dim, - int inner_dim, - const char* input_data, - char* output_data) { - for (auto i = 0; i < outer_dim; ++i) { - context_.CopyItemsSameDevice(meta, inner_dim, input_data, output_data); - input_data += inner_dim * item_size; - for (auto t = 1; t < tiles_; ++t) { - math::Axpy<T, Context>( - inner_dim, - T(1), - reinterpret_cast<const T*>(input_data), - reinterpret_cast<T*>(output_data), - &context_); - input_data += inner_dim * item_size; + std::int32_t GetArgFromTensor(const Tensor& tensor) { + CAFFE_ENFORCE( + tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>()); + std::int32_t val = -1; + if (tensor.IsType<std::int32_t>()) { + context_.template CopyToCPU<std::int32_t>( + 1, tensor.data<std::int32_t>(), &val); + } else if (tensor.IsType<std::int64_t>()) { + std::int64_t val_int64; + context_.template CopyToCPU<std::int64_t>( + 1, tensor.data<std::int64_t>(), &val_int64); + val = static_cast<std::int32_t>(val_int64); + } + return val; + } + + template <typename T> + bool DoTileGradient( + const int outer_size, + const int inner_size, + const T* dY, + T* dX) { + if (inner_size == 1) { + const std::array<int, 2> dY_dims = {outer_size, tiles_}; + const std::array<int, 2> dX_dims = {outer_size, 1}; + math::ReduceSum<T, Context>( + 2, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_); + } else { + math::CopyMatrix<T, Context>( + outer_size, + inner_size, + dY, + inner_size * tiles_, + dX, + inner_size, + &context_); + for (int i = 0; i < outer_size; ++i) { + const T* dY_ptr = dY + i * tiles_ * inner_size; + T* dX_ptr = dX + i * inner_size; + for (int j = 1; j < tiles_; ++j) { + math::Add<T, Context>( + inner_size, dX_ptr, dY_ptr + j * inner_size, dX_ptr, &context_); + } } - output_data += inner_dim * item_size; } + return true; } - int32_t tiles_; - int32_t axis_; + std::int32_t tiles_; + std::int32_t axis_; + + Tensor ones_; }; } // namespace caffe2 |
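
A few reference sketches for the rewritten code follow; all are plain C++ illustrations of what the patch computes, not the patch itself, and every function name in them is made up for the example. First, the CPU forward path: `DoTile` now maps the buffers with Eigen instead of copying raw bytes. When `inner_size == 1` the output is viewed as a column-major `tiles_ x outer_size` array and column `i` is filled with the scalar `X[i]`; otherwise each column of the `inner_size x outer_size` input view is broadcast across `tiles_` consecutive output columns. The `inner_size == 1` case, spelled out without Eigen:

```cpp
#include <cstddef>
#include <vector>

// Equivalent of: EigenArrayMap<T> Y_arr(Y, tiles, outer); Y_arr.col(i) = X[i];
// Eigen maps are column-major, so column i of the (tiles, outer) view is the
// contiguous range Y[i * tiles, (i + 1) * tiles), and every entry in it is
// the scalar X[i]. Expects X.size() >= outer. Illustrative sketch only.
void tile_inner_size_one(const std::vector<float>& X,
                         int outer,
                         int tiles,
                         std::vector<float>& Y) {
  Y.resize(static_cast<std::size_t>(outer) * tiles);
  for (int i = 0; i < outer; ++i) {
    for (int t = 0; t < tiles; ++t) {
      Y[i * tiles + t] = X[i];
    }
  }
}
```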
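The CUDA forward path replaces the grid-stride `CUDA_1D_KERNEL_LOOP` kernel with one thread per output element (grid size `DivUp(total_size, CAFFE_CUDA_NUM_THREADS)`) and recovers each source element with two integer divisions. The host-side check below verifies that index math against the straightforward three-loop tiling; it is an illustrative test, not part of the patch.

```cpp
#include <cassert>
#include <vector>

// Host-side check of the TileCopyCUDAKernel index math: for an output index
// x laid out as (outer, tiles, inner),
//   r = x / inner / tiles  recovers the outer index,
//   c = x % inner          recovers the inner index,
// so Y[x] = X[r * inner + c] reproduces the tiled copy.
int main() {
  const int outer = 3, tiles = 2, inner = 5;
  std::vector<int> X(outer * inner);
  for (int i = 0; i < outer * inner; ++i) X[i] = i;

  std::vector<int> Y(outer * tiles * inner);
  for (int x = 0; x < outer * tiles * inner; ++x) {
    const int r = x / inner / tiles;
    const int c = x % inner;
    Y[x] = X[r * inner + c];
  }

  // Compare against the straightforward three-loop tiling.
  for (int i = 0; i < outer; ++i) {
    for (int t = 0; t < tiles; ++t) {
      for (int j = 0; j < inner; ++j) {
        assert(Y[(i * tiles + t) * inner + j] == X[i * inner + j]);
      }
    }
  }
  return 0;
}
```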
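`TileGradient` is the reverse copy: viewing `dY` as `(outer_size, tiles_, inner_size)`, the gradient of the input is the sum over the middle dimension, `dX[i][j] = sum_t dY[i][t][j]`. The generic CUDA path now expresses this as a single `math::ReduceSum` from dims `{outer, tiles, inner}` to `{outer, 1, inner}`, and the CPU path as a strided `math::CopyMatrix` of the first tile followed by `math::Add` of the remaining tiles. A plain reference loop, with the hypothetical name `tile_gradient_reference`, illustrative only:

```cpp
#include <cstdint>
#include <vector>

// Reference TileGradient: the gradient of a tiled tensor is the sum of its
// tile copies, i.e. a reduction over the middle dimension of dY viewed as
// (outer_size, tiles, inner_size). Illustrative sketch only.
std::vector<float> tile_gradient_reference(const std::vector<float>& dY,
                                           std::int64_t outer_size,
                                           std::int64_t inner_size,
                                           std::int64_t tiles) {
  std::vector<float> dX(outer_size * inner_size, 0.0f);
  for (std::int64_t i = 0; i < outer_size; ++i) {
    for (std::int64_t t = 0; t < tiles; ++t) {
      for (std::int64_t j = 0; j < inner_size; ++j) {
        dX[i * inner_size + j] += dY[(i * tiles + t) * inner_size + j];
      }
    }
  }
  return dX;
}
```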
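For `float` on CUDA the gradient gets a further specialization: each outer slice of `dY` is a `tiles_ x inner_size` matrix, and multiplying its transpose by a vector of ones sums over the tiles dimension, so the whole reduction becomes one `GemmStridedBatched` call (the `ones_` operand has batch stride 0, so the same vector is reused for every slice; when `inner_size == 1` the code falls back to `ReduceSum`). The small check below confirms the matrix-vector-with-ones identity on made-up data; it calls no Caffe2 or cuBLAS API.

```cpp
#include <cassert>
#include <vector>

// Check that multiplying the transposed (tiles x inner) slice by a ones
// vector equals summing over tiles: dX[j] = sum_t dY[t * inner + j].
// Mirrors the GemmStridedBatched(CblasTrans, CblasNoTrans, ...) call in
// tile_op.cu for a single batch entry; illustrative only.
int main() {
  const int tiles = 3, inner = 4;
  std::vector<float> dY(tiles * inner);
  for (int i = 0; i < tiles * inner; ++i) dY[i] = static_cast<float>(i);
  const std::vector<float> ones(tiles, 1.0f);

  for (int j = 0; j < inner; ++j) {
    // GEMV view: row j of dY^T dotted with the ones vector...
    float via_gemv = 0.0f;
    for (int t = 0; t < tiles; ++t) via_gemv += dY[t * inner + j] * ones[t];
    // ...equals the plain sum over the tiles dimension.
    float via_sum = 0.0f;
    for (int t = 0; t < tiles; ++t) via_sum += dY[t * inner + j];
    assert(via_gemv == via_sum);
  }
  return 0;
}
```

The batched-GEMM formulation presumably dispatches to cuBLAS, whereas the old `TileGradientAxpyKernel` reduced each output element with a per-block cub reduction.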