diff options
author | Kevin Wilfong <kevinwilfong@fb.com> | 2017-08-16 12:48:53 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-08-16 13:05:11 -0700 |
commit | 1f47a80e8846fa367de36e7fe58b9463678adf5f (patch) | |
tree | 1b06217000e4625c333a01e7d5eb6072913d4dd0 /caffe2/operators/filler_op.h | |
parent | 30616ee309ae4fed777927dc5b0c5620e207bbf9 (diff) | |
download | pytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.tar.gz pytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.tar.bz2 pytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.zip |
Caffe2: diagonal fill op
Summary: Caffe2: diagonal fill op
Reviewed By: panshen1
Differential Revision: D4775640
fbshipit-source-id: bb388ffe223e6b153d4cde1fdad6f84a2bb65b0f
Diffstat (limited to 'caffe2/operators/filler_op.h')
-rw-r--r-- | caffe2/operators/filler_op.h | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h index 7564f0423f..d17bd0e6f5 100644 --- a/caffe2/operators/filler_op.h +++ b/caffe2/operators/filler_op.h @@ -295,6 +295,107 @@ class ConstantFillOp final : public FillerOp<Context> { bool (ConstantFillOp::*body_)(Tensor<Context>* output); }; +template <class Context> +class DiagonalFillOp final : public FillerOp<Context> { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + DiagonalFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp<Context>(operator_def, ws) { + TensorProto_DataType dtype = + static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>( + "dtype", TensorProto_DataType_FLOAT)); + + if (!OperatorBase::HasArgument("dtype") && + OperatorBase::HasArgument("value")) { + // If 'dtype' is not provided, infer type based on the type of 'value' + // Currently, single argument contains either float, int64 or bytes + if (OperatorBase::HasSingleArgumentOfType<float>("value")) { + dtype = TensorProto_DataType_FLOAT; + } else if (OperatorBase::HasSingleArgumentOfType<int64_t>("value")) { + dtype = TensorProto_DataType_INT64; + } else { + CAFFE_THROW("Argument 'value' is of unexpected type"); + } + VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is " + << "the same as that of argument 'value': " << dtype; + } + + switch (dtype) { + case TensorProto_DataType_FLOAT: + body_ = &DiagonalFillOp::FillWithType<float>; + break; + case TensorProto_DataType_DOUBLE: + body_ = &DiagonalFillOp::FillWithType<double>; + break; + case TensorProto_DataType_BOOL: + body_ = &DiagonalFillOp::FillWithType<bool>; + break; + case TensorProto_DataType_INT8: + body_ = &DiagonalFillOp::FillWithType<int8_t>; + break; + case TensorProto_DataType_INT16: + body_ = &DiagonalFillOp::FillWithType<int16_t>; + break; + case TensorProto_DataType_INT32: + body_ = &DiagonalFillOp::FillWithType<int>; + break; + case TensorProto_DataType_INT64: + body_ = &DiagonalFillOp::FillWithType<int64_t>; + break; + case TensorProto_DataType_UINT8: + body_ = &DiagonalFillOp::FillWithType<uint8_t>; + break; + case TensorProto_DataType_UINT16: + body_ = &DiagonalFillOp::FillWithType<uint16_t>; + break; + case TensorProto_DataType_UNDEFINED: + CAFFE_THROW("Cannot have undefined 'dtype' argument"); + default: + CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype); + } + } + + bool Fill(Tensor<Context>* output) override { + return (this->*body_)(output); + } + + template <typename T> + bool FillWithType(Tensor<Context>* output); + + private: + void VerifyOutputShape(Tensor<Context>* output) { + CAFFE_ENFORCE(output->ndim() >= 2, "Input shape must be >= 2D"); + } + + TIndex GetStepSize(Tensor<Context>* output) { + TIndex step; + if (output->ndim() == 2) { + step = output->dim(1) + 1; + } else { + TIndex prev_i = output->dim(0); + for (auto i : output->dims()) { + if (i != prev_i) { + CAFFE_THROW("All dimensions of input must be of equal length"); + } + } + vector<TIndex> cumprod(output->ndim()); + auto dims = output->dims(); + std::partial_sum( + dims.begin(), + dims.end() - 1, + cumprod.begin(), + std::multiplies<TIndex>()); + step = 1 + + std::accumulate( + cumprod.begin(), cumprod.end(), static_cast<TIndex>(0)); + VLOG(0) << step; + } + return step; + } + + bool (DiagonalFillOp::*body_)(Tensor<Context>* output); +}; + template <typename T, class Context> class GaussianFillOp final : public FillerOp<Context> { public: |