summaryrefslogtreecommitdiff
path: root/caffe2/operators/filler_op.h
diff options
context:
space:
mode:
authorKevin Wilfong <kevinwilfong@fb.com>2017-08-16 12:48:53 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-08-16 13:05:11 -0700
commit1f47a80e8846fa367de36e7fe58b9463678adf5f (patch)
tree1b06217000e4625c333a01e7d5eb6072913d4dd0 /caffe2/operators/filler_op.h
parent30616ee309ae4fed777927dc5b0c5620e207bbf9 (diff)
downloadpytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.tar.gz
pytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.tar.bz2
pytorch-1f47a80e8846fa367de36e7fe58b9463678adf5f.zip
Caffe2: diagonal fill op
Summary: Caffe2: diagonal fill op Reviewed By: panshen1 Differential Revision: D4775640 fbshipit-source-id: bb388ffe223e6b153d4cde1fdad6f84a2bb65b0f
Diffstat (limited to 'caffe2/operators/filler_op.h')
-rw-r--r--caffe2/operators/filler_op.h101
1 files changed, 101 insertions, 0 deletions
diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h
index 7564f0423f..d17bd0e6f5 100644
--- a/caffe2/operators/filler_op.h
+++ b/caffe2/operators/filler_op.h
@@ -295,6 +295,107 @@ class ConstantFillOp final : public FillerOp<Context> {
bool (ConstantFillOp::*body_)(Tensor<Context>* output);
};
+template <class Context>
+class DiagonalFillOp final : public FillerOp<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ DiagonalFillOp(const OperatorDef& operator_def, Workspace* ws)
+ : FillerOp<Context>(operator_def, ws) {
+ TensorProto_DataType dtype =
+ static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
+ "dtype", TensorProto_DataType_FLOAT));
+
+ if (!OperatorBase::HasArgument("dtype") &&
+ OperatorBase::HasArgument("value")) {
+ // If 'dtype' is not provided, infer type based on the type of 'value'
+ // Currently, single argument contains either float, int64 or bytes
+ if (OperatorBase::HasSingleArgumentOfType<float>("value")) {
+ dtype = TensorProto_DataType_FLOAT;
+ } else if (OperatorBase::HasSingleArgumentOfType<int64_t>("value")) {
+ dtype = TensorProto_DataType_INT64;
+ } else {
+ CAFFE_THROW("Argument 'value' is of unexpected type");
+ }
+ VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
+ << "the same as that of argument 'value': " << dtype;
+ }
+
+ switch (dtype) {
+ case TensorProto_DataType_FLOAT:
+ body_ = &DiagonalFillOp::FillWithType<float>;
+ break;
+ case TensorProto_DataType_DOUBLE:
+ body_ = &DiagonalFillOp::FillWithType<double>;
+ break;
+ case TensorProto_DataType_BOOL:
+ body_ = &DiagonalFillOp::FillWithType<bool>;
+ break;
+ case TensorProto_DataType_INT8:
+ body_ = &DiagonalFillOp::FillWithType<int8_t>;
+ break;
+ case TensorProto_DataType_INT16:
+ body_ = &DiagonalFillOp::FillWithType<int16_t>;
+ break;
+ case TensorProto_DataType_INT32:
+ body_ = &DiagonalFillOp::FillWithType<int>;
+ break;
+ case TensorProto_DataType_INT64:
+ body_ = &DiagonalFillOp::FillWithType<int64_t>;
+ break;
+ case TensorProto_DataType_UINT8:
+ body_ = &DiagonalFillOp::FillWithType<uint8_t>;
+ break;
+ case TensorProto_DataType_UINT16:
+ body_ = &DiagonalFillOp::FillWithType<uint16_t>;
+ break;
+ case TensorProto_DataType_UNDEFINED:
+ CAFFE_THROW("Cannot have undefined 'dtype' argument");
+ default:
+ CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
+ }
+ }
+
+ bool Fill(Tensor<Context>* output) override {
+ return (this->*body_)(output);
+ }
+
+ template <typename T>
+ bool FillWithType(Tensor<Context>* output);
+
+ private:
+ void VerifyOutputShape(Tensor<Context>* output) {
+ CAFFE_ENFORCE(output->ndim() >= 2, "Input shape must be >= 2D");
+ }
+
+ TIndex GetStepSize(Tensor<Context>* output) {
+ TIndex step;
+ if (output->ndim() == 2) {
+ step = output->dim(1) + 1;
+ } else {
+ TIndex prev_i = output->dim(0);
+ for (auto i : output->dims()) {
+ if (i != prev_i) {
+ CAFFE_THROW("All dimensions of input must be of equal length");
+ }
+ }
+ vector<TIndex> cumprod(output->ndim());
+ auto dims = output->dims();
+ std::partial_sum(
+ dims.begin(),
+ dims.end() - 1,
+ cumprod.begin(),
+ std::multiplies<TIndex>());
+ step = 1 +
+ std::accumulate(
+ cumprod.begin(), cumprod.end(), static_cast<TIndex>(0));
+ VLOG(0) << step;
+ }
+ return step;
+ }
+
+ bool (DiagonalFillOp::*body_)(Tensor<Context>* output);
+};
+
template <typename T, class Context>
class GaussianFillOp final : public FillerOp<Context> {
public: