summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiaomeng Yang <yangxm@fb.com>2019-03-12 11:54:29 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-12 12:16:30 -0700
commitf229521154e3b2093b9088fb21d48bedf3551c6b (patch)
treefc31290ab7691d4b168dd8f0f93575c671fe9617
parent54b33503ec022b39173f08edd8136d01d058dea0 (diff)
downloadpytorch-f229521154e3b2093b9088fb21d48bedf3551c6b.tar.gz
pytorch-f229521154e3b2093b9088fb21d48bedf3551c6b.tar.bz2
pytorch-f229521154e3b2093b9088fb21d48bedf3551c6b.zip
Optimize TileOp (#17290)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17290 Optimize TileOp Reviewed By: wesolwsk Differential Revision: D14145844 fbshipit-source-id: 1571fa0512218dbc48080592ede4e23903be85dd
-rw-r--r--caffe2/operators/tile_op.cc122
-rw-r--r--caffe2/operators/tile_op.cu155
-rw-r--r--caffe2/operators/tile_op.h267
3 files changed, 315 insertions, 229 deletions
diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc
index 52f6c10d1b..11d096e899 100644
--- a/caffe2/operators/tile_op.cc
+++ b/caffe2/operators/tile_op.cc
@@ -1,33 +1,103 @@
#include "caffe2/operators/tile_op.h"
+#include <string>
+
namespace caffe2 {
+template <>
+bool TileOp<CPUContext>::RunOnDevice() { // CPU-specific entry point: hands Input(0) to DispatchHelper with the element types listed below.
+ return DispatchHelper<
+ TensorTypes<std::int32_t, std::int64_t, float, double, std::string>>:: // std::string is listed in addition to the four numeric types; its DoRunWithType specialization follows in this file
+ call(this, Input(0)); // presumably selects DoRunWithType<T> from Input(0)'s dtype -- confirm against DispatchHelper
+}
+
+template <>
+template <>
+bool TileOp<CPUContext>::DoRunWithType<std::string>() { // std::string path: tiles Input(0) along axis_ into Output(0) using TypeMeta-driven item copies instead of typed pointers.
+ if (InputSize() > 1) {
+ // We potentially have tiles and/or axis specified as inputs
+ // as well. We will check for them in that order. In other words:
+ // InputSize() == 2: tiles is specified
+ // InputSize() == 3: tiles is specified and axis.
+ // Anything specified as input will override the arguments
+ CAFFE_ENFORCE(
+ Input(1).dim() == 1 && Input(1).numel() == 1,
+ "Input `tiles` should be a vector of size 1.");
+ tiles_ = GetArgFromTensor(Input(1)); // overwrites the member, so the value read here replaces the constructor-time argument
+ if (InputSize() > 2) {
+ CAFFE_ENFORCE(
+ Input(2).dim() == 1 && Input(2).numel() == 1,
+ "Input `axis` should be a vector of size 1.");
+ axis_ = GetArgFromTensor(Input(2));
+ } else {
+ CAFFE_ENFORCE( // tiles came from an input, so axis must still be present as an argument
+ OperatorBase::HasArgument("axis"),
+ "Argument `axis` is missing and was not specified as input.");
+ }
+ } else {
+ CAFFE_ENFORCE( // no extra inputs: both `tiles` and `axis` are required as arguments
+ OperatorBase::HasArgument("tiles"),
+ "Argument `tiles` is missing and was not specified as input.");
+ CAFFE_ENFORCE(
+ OperatorBase::HasArgument("axis"),
+ "Argument `axis` is missing and was not specified as input.");
+ }
+
+ const auto& X = Input(0);
+ auto* Y = Output(0);
+ const int axis = X.canonical_axis_index(axis_); // presumably maps a negative axis_ to its non-negative index -- confirm
+
+ // reshape output to be input tiled along the axis
+ std::vector<std::int64_t> Y_dims = X.sizes().vec();
+ Y_dims[axis] *= tiles_;
+ Y->Resize(Y_dims);
+
+ // size up to (and not including) axis
+ const int outer_size = X.size_to_dim(axis); // NOTE(review): held in int -- confirm the product cannot exceed INT_MAX
+ // size from axis up
+ const int inner_size = X.size_from_dim(axis);
+
+ const TypeMeta& meta = X.dtype();
+ const int item_size = X.itemsize();
+ const char* X_ptr = reinterpret_cast<const char*>(X.raw_data());
+ char* Y_ptr = reinterpret_cast<char*>(Y->raw_mutable_data(meta)); // byte pointers so both cursors can be advanced by inner_size * item_size
+ for (int i = 0; i < outer_size; ++i) {
+ for (int t = 0; t < tiles_; ++t) {
+ context_.CopyItemsSameDevice(meta, inner_size, X_ptr, Y_ptr); // the same source slice is copied tiles_ times in a row
+ Y_ptr += inner_size * item_size;
+ }
+ X_ptr += inner_size * item_size; // move to the next outer slice only after all of its copies are written
+ }
+ return true;
+}
+
REGISTER_CPU_OPERATOR(Tile, TileOp<CPUContext>);
-REGISTER_CPU_OPERATOR(TileGradient, TileGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(TileGradient, TileGradientOp<CPUContext>);
OPERATOR_SCHEMA(Tile)
.NumInputs(1, 3)
.NumOutputs(1)
- .TensorInferenceFunction(
- [](const OperatorDef& def, const vector<TensorShape>& in) {
- vector<TensorShape> out(1);
- out[0] = TensorShape(in[0]);
- ArgumentHelper helper(def);
-
- auto tiles = helper.GetSingleArgument<int32_t>("tiles", 1);
- auto axis = helper.GetSingleArgument<int32_t>("axis", 0);
- if (in.size() > 1) {
- // Tile or axis is specified as input; we can't determine
- // the size
- out[0].set_unknown_shape(true);
- } else {
- const auto canonical_axis =
- canonical_axis_index_(axis, out[0].dims().size());
- out[0].set_dims(
- canonical_axis, out[0].dims().Get(canonical_axis) * tiles);
- }
- return out;
- })
+ .TensorInferenceFunction([](const OperatorDef& def,
+ const std::vector<TensorShape>& in) {
+ std::vector<TensorShape> out(1);
+ out[0] = TensorShape(in[0]);
+ ArgumentHelper helper(def);
+ const std::int32_t tiles =
+ helper.GetSingleArgument<std::int32_t>("tiles", 1);
+ const std::int32_t axis =
+ helper.GetSingleArgument<std::int32_t>("axis", 0);
+ if (in.size() > 1) {
+ // Tile or axis is specified as input; we can't determine
+ // the size
+ out[0].set_unknown_shape(true);
+ } else {
+ const auto canonical_axis =
+ canonical_axis_index_(axis, out[0].dims().size());
+ out[0].set_dims(
+ canonical_axis, out[0].dims().Get(canonical_axis) * tiles);
+ }
+ return out;
+ })
.SetDoc(R"DOC(
Constructs a tensor by tiling a given tensor along a specified axis. This operation creates a new tensor by replicating the input tensor a number of times specified by the `tiles` argument along the `axis` dimension. The output tensor's `axis` dimension has $(X.dims(axis) * tiles)$ elements.
@@ -97,12 +167,14 @@ Y:
OPERATOR_SCHEMA(TileGradient).NumInputs(1, 3).NumOutputs(1);
+namespace {
+
class GetTileGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
+ std::vector<OperatorDef> GetGradientDefs() override {
// Check whether the tiles/axis information was
// passed through input arguments
- vector<std::string> g_inputs({GO(0)});
+ std::vector<std::string> g_inputs({GO(0)});
if (Def().input_size() > 1) {
g_inputs.push_back(I(1));
}
@@ -110,10 +182,12 @@ class GetTileGradient : public GradientMakerBase {
g_inputs.push_back(I(2));
}
return SingleGradientDef(
- "TileGradient", "", g_inputs, vector<string>{GI(0)});
+ "TileGradient", "", g_inputs, std::vector<std::string>{GI(0)});
}
};
+} // namespace
+
REGISTER_GRADIENT(Tile, GetTileGradient);
} // namespace caffe2
diff --git a/caffe2/operators/tile_op.cu b/caffe2/operators/tile_op.cu
index b8f78201b3..a0a7e294e7 100644
--- a/caffe2/operators/tile_op.cu
+++ b/caffe2/operators/tile_op.cu
@@ -1,93 +1,102 @@
-#include <cub/block/block_reduce.cuh>
+#include "caffe2/operators/tile_op.h"
+
+#include <array>
#include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/tile_op.h"
+#include "caffe2/utils/math.h"
namespace caffe2 {
+
namespace {
+
template <typename T>
-__global__ void TileCopyKernel(
- int outer_dim,
- int inner_dim,
- int tiles,
- const T* input_data,
- T* output_data) {
- CUDA_1D_KERNEL_LOOP(index, outer_dim * inner_dim * tiles) {
- int col = index % inner_dim;
- int row = index / (inner_dim * tiles);
- output_data[index] = input_data[row * inner_dim + col];
+__global__ void TileCopyCUDAKernel( // Writes at most one output element per thread: Y[x] = X[r * inner_size + c].
+ const int total_size,
+ const int inner_size,
+ const int tiles,
+ const T* X,
+ T* Y) {
+ const int x = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; // flat output index; assumes a block size of CAFFE_CUDA_NUM_THREADS, which is what the DoTile launch in this file uses
+ if (x < total_size) { // threads past the end of the output do nothing
+ const int r = x / inner_size / tiles; // index of the outer slice
+ const int c = x % inner_size; // offset inside the inner slice
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+ Y[x] = __ldg(X + r * inner_size + c); // same element as the #else branch, read through __ldg
+#else
+ Y[x] = X[r * inner_size + c];
+#endif
}
}
-template <typename T>
-__global__ void TileGradientAxpyKernel(
- int outer_dim,
- int inner_dim,
- int tiles,
- const T* input_data,
- T* output_data) {
- typedef cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS> BlockReduce;
-
- for (int idx = blockIdx.x; idx < outer_dim * inner_dim; idx += gridDim.x) {
- int i = idx / inner_dim;
- int j = idx % inner_dim;
- T* output_ptr = output_data + inner_dim * i;
+} // namespace
- T x = 0.0;
- for (int t = threadIdx.x; t < tiles; t += blockDim.x) {
- const T* input_ptr = input_data + (i * tiles + t) * inner_dim;
- x += input_ptr[j];
- }
- __shared__ typename BlockReduce::TempStorage temp_storage;
- T totx = BlockReduce(temp_storage).Sum(x);
- if (threadIdx.x == 0) {
- output_ptr[j] = totx;
- }
- __syncthreads();
- }
+template <>
+template <typename T>
+bool TileOp<CUDAContext>::DoTile( // CUDA tiling: launches TileCopyCUDAKernel on the context's stream, sized for one thread per output element.
+ const int outer_size,
+ const int inner_size,
+ const T* X,
+ T* Y) {
+ const std::int64_t total_size = static_cast<std::int64_t>(outer_size) * // each factor is widened first, so the product is formed in 64-bit
+ static_cast<std::int64_t>(tiles_) * static_cast<std::int64_t>(inner_size);
+ const int M = math::DivUp<std::int64_t>(total_size, CAFFE_CUDA_NUM_THREADS); // block count; presumably a rounded-up division -- confirm against math::DivUp
+ TileCopyCUDAKernel<T>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ total_size, inner_size, tiles_, X, Y); // NOTE(review): total_size is int64 here but the kernel parameter is `const int` -- confirm it always fits
+ return true;
}
-} // namespace
template <>
-void TileOp<CUDAContext>::DoTile(
- const TypeMeta& meta,
- int item_size,
- int outer_dim,
- int inner_dim,
- const char* input_data,
- char* output_data) {
- TileCopyKernel<float>
- <<<std::min(outer_dim * inner_dim * tiles_, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- outer_dim,
- inner_dim,
- tiles_,
- reinterpret_cast<const float*>(input_data),
- reinterpret_cast<float*>(output_data));
+template <typename T>
+bool TileGradientOp<CUDAContext>::DoTileGradient( // Generic CUDA gradient: reduces dY over its tile dimension into dX with math::ReduceSum.
+ const int outer_size,
+ const int inner_size,
+ const T* dY,
+ T* dX) {
+ const std::array<int, 3> dY_dims = {outer_size, tiles_, inner_size}; // dY described as [outer, tiles, inner]
+ const std::array<int, 3> dX_dims = {outer_size, 1, inner_size}; // same shape with the middle extent set to 1
+ math::ReduceSum<T, CUDAContext>(
+ 3, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_); // T(1) is presumably the scale applied to the sum -- confirm against math::ReduceSum
+ return true;
}
template <>
-void TileGradientOp<float, CUDAContext>::DoTileGradient(
- const TypeMeta& meta,
- int item_size,
- int outer_dim,
- int inner_dim,
- const char* input_data,
- char* output_data) {
- TileGradientAxpyKernel<float><<<
- std::min(outer_dim * inner_dim, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- outer_dim,
- inner_dim,
- tiles_,
- reinterpret_cast<const float*>(input_data),
- reinterpret_cast<float*>(output_data));
+template <>
+bool TileGradientOp<CUDAContext>::DoTileGradient<float>( // float-only CUDA gradient: ReduceSum when inner_size == 1, otherwise a batched GEMM against a vector of ones.
+ const int outer_size,
+ const int inner_size,
+ const float* dY,
+ float* dX) {
+ if (inner_size == 1) {
+ const std::array<int, 2> dY_dims = {outer_size, tiles_}; // with a single inner element, dY is described as [outer, tiles]
+ const std::array<int, 2> dX_dims = {outer_size, 1};
+ math::ReduceSum<float, CUDAContext>(
+ 2, dY_dims.data(), dX_dims.data(), 1.0f, dY, dX, &context_);
+ } else {
+ ReinitializeTensor(&ones_, tiles_, at::dtype<float>().device(CUDA)); // ones_ is the member scratch tensor, requested as tiles_ floats on CUDA
+ math::Set<float, CUDAContext>(
+ tiles_, 1.0f, ones_.template mutable_data<float>(), &context_); // fill the scratch tensor with 1.0f
+ math::GemmStridedBatched<float, CUDAContext>( // NOTE(review): argument roles noted below are inferred from the call shape -- confirm against math::GemmStridedBatched
+ CblasTrans,
+ CblasNoTrans,
+ outer_size, // presumably the batch count: one product per outer slice
+ inner_size,
+ 1,
+ tiles_,
+ 1.0f,
+ dY,
+ tiles_ * inner_size, // presumably the per-batch stride of dY: one outer slice holds tiles_ * inner_size values
+ ones_.template data<float>(),
+ 0, // presumably a stride of 0 so every batch reuses the same ones vector
+ 0.0f,
+ dX,
+ inner_size, // presumably the per-batch stride of dX
+ &context_);
+ }
+ return true;
}
REGISTER_CUDA_OPERATOR(Tile, TileOp<CUDAContext>);
-REGISTER_CUDA_OPERATOR(TileGradient, TileGradientOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(TileGradient, TileGradientOp<CUDAContext>);
+
} // namespace caffe2
diff --git a/caffe2/operators/tile_op.h b/caffe2/operators/tile_op.h
index df33bc6461..71f0bde2d7 100644
--- a/caffe2/operators/tile_op.h
+++ b/caffe2/operators/tile_op.h
@@ -1,29 +1,40 @@
#ifndef CAFFE2_OPERATORS_TILE_OP_H_
#define CAFFE2_OPERATORS_TILE_OP_H_
+#include <array>
+#include <string>
+#include <type_traits>
+#include <vector>
+
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
+#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Copy a Blob n times along a specified axis.
template <class Context>
-class TileOp : public Operator<Context> {
+class TileOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
+
template <class... Args>
explicit TileOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
- tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
- axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
- ~TileOp() {}
+ OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1), // takes the place of the removed GetSingleArgument initializer; presumably still reads argument "tiles" with default 1 -- confirm against the OP_SINGLE_ARG macro
+ OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {} // likewise for "axis" with default 0; the empty user-declared destructor is dropped
bool RunOnDevice() override {
- const auto& input = Input(0);
- std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
+ return DispatchHelper< // generic entry point: forwards Input(0) to DispatchHelper; presumably this ends up in DoRunWithType<T> -- confirm against DispatchHelper
+ TensorTypes<std::int32_t, std::int64_t, float, double>>:: // numeric types only here; the CPUContext specialization in tile_op.cc also lists std::string
+ call(this, Input(0));
+ }
+
+ template <typename T>
+ bool DoRunWithType() {
if (InputSize() > 1) {
// We potentially have tiles and/or axis specified as inputs
// as well. We will check for them in that order. In other words:
@@ -33,25 +44,12 @@ class TileOp : public Operator<Context> {
CAFFE_ENFORCE(
Input(1).dim() == 1 && Input(1).numel() == 1,
"Input `tiles` should be a vector of size 1.");
-
- const auto& input1 = Input(1);
- context_.CopyItemsToCPU(
- input1.dtype(),
- 1,
- static_cast<const char*>(input1.raw_data()),
- &(temp_params[0]));
-
+ tiles_ = GetArgFromTensor(Input(1));
if (InputSize() > 2) {
CAFFE_ENFORCE(
Input(2).dim() == 1 && Input(2).numel() == 1,
"Input `axis` should be a vector of size 1.");
-
- const auto& input2 = Input(2);
- context_.CopyItemsToCPU(
- input2.dtype(),
- 1,
- static_cast<const char*>(input2.raw_data()),
- &(temp_params[1]));
+ axis_ = GetArgFromTensor(Input(2));
} else {
CAFFE_ENFORCE(
OperatorBase::HasArgument("axis"),
@@ -66,79 +64,82 @@ class TileOp : public Operator<Context> {
"Argument `axis` is missing and was not specified as input.");
}
- tiles_ = temp_params[0];
- axis_ = temp_params[1];
-
- auto* output = Output(0);
- const auto axis = input.canonical_axis_index(axis_);
+ const auto& X = Input(0);
+ auto* Y = Output(0);
+ const int axis = X.canonical_axis_index(axis_);
// reshape output to be input tiled along the axis
- vector<int64_t> output_dims(input.sizes().vec());
- output_dims[axis_] = output_dims[axis_] * tiles_;
- output->Resize(output_dims);
+ std::vector<std::int64_t> Y_dims = X.sizes().vec();
+ Y_dims[axis] *= tiles_;
+ Y->Resize(Y_dims);
// size up to (and not including) axis
- const auto outer_dim = input.size_to_dim(axis);
+ const int outer_size = X.size_to_dim(axis);
// size from axis up
- const auto inner_dim = input.size_from_dim(axis);
+ const int inner_size = X.size_from_dim(axis);
- /**
- * How this works:
- * Imagine a 2D tensor (matrix) of size 3x10, tiled 2 times.
- * - Tiling along axis 0 (row) means copying the entire 3x10 Matrix 2
- * times. outer_dim = 0, inner_dim = 30.
- * - Tiling along axis 1 (column) means copying each row 2 times, then
- * proceed to the next row, until the end. outer_dim = 3, inner_dim = 10.
- */
- const char* input_data = static_cast<const char*>(input.raw_data());
- char* output_data =
- static_cast<char*>(output->raw_mutable_data(input.dtype()));
-
- DoTile(
- input.dtype(),
- input.itemsize(),
- outer_dim,
- inner_dim,
- input_data,
- output_data);
-
- return true;
+ const T* X_data = X.template data<T>();
+ T* Y_data = Y->template mutable_data<T>();
+ return DoTile<T>(outer_size, inner_size, X_data, Y_data);
}
private:
- void DoTile(
- const TypeMeta& meta,
- int item_size,
- int outer_dim,
- int inner_dim,
- const char* input_data,
- char* output_data) {
- for (auto i = 0; i < outer_dim; ++i) {
- for (auto t = 0; t < tiles_; ++t) {
- context_.CopyItemsSameDevice(meta, inner_dim, input_data, output_data);
- output_data += inner_dim * item_size;
+ std::int32_t GetArgFromTensor(const Tensor& tensor) { // Reads one int32 or int64 value out of `tensor` through context_.CopyToCPU and returns it as int32.
+ CAFFE_ENFORCE(
+ tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>()); // any other element type is rejected before a read is attempted
+ std::int32_t val = -1; // placeholder; one of the two branches below overwrites it once the check above has passed
+ if (tensor.IsType<std::int32_t>()) {
+ context_.template CopyToCPU<std::int32_t>(
+ 1, tensor.data<std::int32_t>(), &val); // copies a single element into the local
+ } else if (tensor.IsType<std::int64_t>()) {
+ std::int64_t val_int64;
+ context_.template CopyToCPU<std::int64_t>(
+ 1, tensor.data<std::int64_t>(), &val_int64);
+ val = static_cast<std::int32_t>(val_int64); // NOTE(review): narrowing without a range check -- values outside int32 are not detected here
+ }
+ return val;
+ }
+
+ template <typename T>
+ bool DoTile(const int outer_size, const int inner_size, const T* X, T* Y) { // Typed tiling via Eigen array maps: each outer slice of X is written tiles_ times in a row into Y.
+ if (inner_size == 1) {
+ EigenArrayMap<T> Y_arr(Y, tiles_, outer_size); // Y mapped as tiles_ rows by outer_size columns; presumably column-major, so one column is tiles_ consecutive values -- confirm in eigen_utils.h
+ for (int i = 0; i < outer_size; ++i) {
+ Y_arr.col(i) = X[i]; // whole column set to the single scalar X[i]
+ }
+ } else {
+ ConstEigenArrayMap<T> X_arr(X, inner_size, outer_size); // X mapped as inner_size rows by outer_size columns
+ for (int i = 0; i < outer_size; ++i) {
+ EigenArrayMap<T>(Y + i * tiles_ * inner_size, inner_size, tiles_) // the tiles_ * inner_size values of Y belonging to outer slice i
+ .colwise() = X_arr.col(i); // every column of that region receives column i of X
}
- input_data += inner_dim * item_size;
}
+ return true;
}
- int32_t tiles_;
- int32_t axis_;
+ std::int32_t tiles_;
+ std::int32_t axis_;
};
-template <typename T, class Context>
-class TileGradientOp : public Operator<Context> {
+template <class Context>
+class TileGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
+
template <class... Args>
explicit TileGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
- tiles_(this->template GetSingleArgument<int32_t>("tiles", 1)),
- axis_(this->template GetSingleArgument<int32_t>("axis", 0)) {}
- ~TileGradientOp() {}
+ OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1), // takes the place of the removed GetSingleArgument initializer; presumably still reads argument "tiles" with default 1 -- confirm against the OP_SINGLE_ARG macro
+ OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {} // likewise for "axis" with default 0; the empty user-declared destructor is dropped
bool RunOnDevice() override {
- std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
+ return DispatchHelper< // forwards Input(0) to DispatchHelper; presumably this ends up in DoRunWithType<T> -- confirm against DispatchHelper
+ TensorTypes<std::int32_t, std::int64_t, float, double>>:: // element types accepted for the gradient input
+ call(this, Input(0));
+ }
+
+ template <typename T>
+ bool DoRunWithType() {
if (InputSize() > 1) {
// We potentially have tiles and/or axis specified as inputs
// as well. We will check for them in that order. In other words:
@@ -148,25 +149,12 @@ class TileGradientOp : public Operator<Context> {
CAFFE_ENFORCE(
Input(1).dim() == 1 && Input(1).numel() == 1,
"Input `tiles` should be a vector of size 1.");
-
- const auto& input1 = Input(1);
- context_.CopyItemsToCPU(
- input1.dtype(),
- 1,
- static_cast<const char*>(input1.raw_data()),
- &(temp_params[0]));
-
+ tiles_ = GetArgFromTensor(Input(1));
if (InputSize() > 2) {
CAFFE_ENFORCE(
Input(2).dim() == 1 && Input(2).numel() == 1,
"Input `axis` should be a vector of size 1.");
-
- const auto& input2 = Input(2);
- context_.CopyItemsToCPU(
- input2.dtype(),
- 1,
- static_cast<const char*>(input2.raw_data()),
- &(temp_params[1]));
+ axis_ = GetArgFromTensor(Input(2));
} else {
CAFFE_ENFORCE(
OperatorBase::HasArgument("axis"),
@@ -181,22 +169,20 @@ class TileGradientOp : public Operator<Context> {
"Argument `axis` is missing and was not specified as input.");
}
- tiles_ = temp_params[0];
- axis_ = temp_params[1];
-
- const auto& input = Input(0);
- auto* output = Output(0);
- const auto axis = input.canonical_axis_index(axis_);
+ const auto& dY = Input(0);
+ auto* dX = Output(0);
+ const int axis = dY.canonical_axis_index(axis_);
// reshape output to be input "untiled" along the axis
- vector<int64_t> output_dims(input.sizes().vec());
- output_dims[axis_] = output_dims[axis_] / tiles_;
- output->Resize(output_dims);
+ std::vector<std::int64_t> X_dims = dY.sizes().vec();
+ CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0);
+ X_dims[axis] /= tiles_;
+ dX->Resize(X_dims);
// size up to (and not including) axis
- const auto outer_dim = output->size_to_dim(axis);
+ const int outer_size = dX->size_to_dim(axis);
// size from axis up
- const auto inner_dim = output->size_from_dim(axis);
+ const int inner_size = dX->size_from_dim(axis);
/**
* How this works:
@@ -208,47 +194,64 @@ class TileGradientOp : public Operator<Context> {
* So the output gradient should be the matrix multiplication result
* of input gradient (gradient of tiled tensor output) and X.
*/
- const char* input_data = static_cast<const char*>(input.raw_data());
- char* output_data =
- static_cast<char*>(output->raw_mutable_data(input.dtype()));
-
- DoTileGradient(
- input.dtype(),
- input.itemsize(),
- outer_dim,
- inner_dim,
- input_data,
- output_data);
-
- return true;
+ const T* dY_data = dY.template data<T>();
+ T* dX_data = dX->template mutable_data<T>();
+ return DoTileGradient<T>(outer_size, inner_size, dY_data, dX_data);
}
private:
- void DoTileGradient(
- const TypeMeta& meta,
- int item_size,
- int outer_dim,
- int inner_dim,
- const char* input_data,
- char* output_data) {
- for (auto i = 0; i < outer_dim; ++i) {
- context_.CopyItemsSameDevice(meta, inner_dim, input_data, output_data);
- input_data += inner_dim * item_size;
- for (auto t = 1; t < tiles_; ++t) {
- math::Axpy<T, Context>(
- inner_dim,
- T(1),
- reinterpret_cast<const T*>(input_data),
- reinterpret_cast<T*>(output_data),
- &context_);
- input_data += inner_dim * item_size;
+ std::int32_t GetArgFromTensor(const Tensor& tensor) { // Reads one int32 or int64 value out of `tensor` through context_.CopyToCPU and returns it as int32.
+ CAFFE_ENFORCE(
+ tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>()); // any other element type is rejected before a read is attempted
+ std::int32_t val = -1; // placeholder; one of the two branches below overwrites it once the check above has passed
+ if (tensor.IsType<std::int32_t>()) {
+ context_.template CopyToCPU<std::int32_t>(
+ 1, tensor.data<std::int32_t>(), &val); // copies a single element into the local
+ } else if (tensor.IsType<std::int64_t>()) {
+ std::int64_t val_int64;
+ context_.template CopyToCPU<std::int64_t>(
+ 1, tensor.data<std::int64_t>(), &val_int64);
+ val = static_cast<std::int32_t>(val_int64); // NOTE(review): narrowing without a range check -- values outside int32 are not detected here
+ }
+ return val;
+ }
+
+ template <typename T>
+ bool DoTileGradient( // Generic gradient: for every outer slice, sums its tiles_ consecutive inner_size-long pieces of dY into dX.
+ const int outer_size,
+ const int inner_size,
+ const T* dY,
+ T* dX) {
+ if (inner_size == 1) {
+ const std::array<int, 2> dY_dims = {outer_size, tiles_}; // with a single inner element, dY is described as [outer, tiles]
+ const std::array<int, 2> dX_dims = {outer_size, 1};
+ math::ReduceSum<T, Context>(
+ 2, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_); // T(1) is presumably the scale applied to the sum -- confirm against math::ReduceSum
+ } else {
+ math::CopyMatrix<T, Context>( // seeds dX with the first piece of each outer slice; argument roles are presumably (rows, cols, src, src stride, dst, dst stride) -- confirm against math::CopyMatrix
+ outer_size,
+ inner_size,
+ dY,
+ inner_size * tiles_,
+ dX,
+ inner_size,
+ &context_);
+ for (int i = 0; i < outer_size; ++i) {
+ const T* dY_ptr = dY + i * tiles_ * inner_size; // start of outer slice i in dY
+ T* dX_ptr = dX + i * inner_size; // start of outer slice i in dX
+ for (int j = 1; j < tiles_; ++j) { // starts at 1: piece 0 was already placed by the copy above
+ math::Add<T, Context>(
+ inner_size, dX_ptr, dY_ptr + j * inner_size, dX_ptr, &context_); // dX_ptr is passed as both an input and the output, accumulating piece j
+ }
}
- output_data += inner_dim * item_size;
}
+ return true;
}
- int32_t tiles_;
- int32_t axis_;
+ std::int32_t tiles_;
+ std::int32_t axis_;
+
+ Tensor ones_;
};
} // namespace caffe2