author    | Yinghai Lu <yinghai@fb.com> | 2018-03-22 22:50:27 -0700
committer | GitHub <noreply@github.com> | 2018-03-22 22:50:27 -0700
commit    | 21918b94e4c4091b98e5d84a840bca92eb2e0d78 (patch)
tree      | 31f40c038b622aa548234caa6c5d45e752d0bc21
parent    | bbb7c722df9fc502ed5204f1bed60f0f2b5c04c1 (diff)
Add InheritOnnxSchema property to c2 op schema (#2366)
* Add InheritOnnxSchema property to c2 op schema
* Add onnx inherit for {Conv,MaxPool,AveragePool}{1D,2D,3D}
32 files changed, 115 insertions, 52 deletions
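In short, the patch gives `OpSchema` a stored ONNX schema name: `InheritOnnxSchema("X")` records the name when an operator schema is registered, and the new `onnx_schema()` accessor exposes it so an exporter can map a Caffe2 op to its ONNX counterpart. A minimal sketch of both sides of the API follows; the `MyOp` schema and the `HasOnnxCounterpart` helper are hypothetical illustrations, not part of this commit, while `OPERATOR_SCHEMA`, `OpSchemaRegistry`, and `onnx_schema()` are the real APIs touched here:

```cpp
// Sketch only. "MyOp" and HasOnnxCounterpart() are hypothetical;
// the registration/lookup calls mirror what this commit adds.
#include <string>
#include "caffe2/core/operator_schema.h"

namespace caffe2 {

// Registering a schema and tagging its ONNX counterpart, in the
// same style as the operators touched by this diff:
OPERATOR_SCHEMA(MyOp)
    .NumInputs(1)
    .NumOutputs(1)
    .InheritOnnxSchema("Relu"); // maps MyOp onto the ONNX Relu schema

// An exporter could then query the tag through the schema registry:
bool HasOnnxCounterpart(const std::string& op_type) {
  const OpSchema* schema = OpSchemaRegistry::Schema(op_type);
  return schema != nullptr && !schema->onnx_schema().empty();
}

} // namespace caffe2
```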
diff --git a/caffe2/core/operator_schema.cc b/caffe2/core/operator_schema.cc
index d202e7e20c..ae67f575de 100644
--- a/caffe2/core/operator_schema.cc
+++ b/caffe2/core/operator_schema.cc
@@ -15,7 +15,6 @@
  */
 #include "caffe2/core/operator_schema.h"
-
 #include "caffe2/core/logging.h"
 
 namespace caffe2 {
@@ -212,6 +211,11 @@ OpSchema& OpSchema::TensorInferenceFunction(
   return *this;
 }
 
+OpSchema& OpSchema::InheritOnnxSchema(const std::string& onnx_schema_name) {
+  onnx_schema_ = onnx_schema_name;
+  return *this;
+}
+
 OpSchema& OpSchema::IdenticalTypeAndShape() {
   return TensorInferenceFunction(
       [](const OperatorDef&, const vector<TensorShape>& input_types) {
diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h
index b24a7af7bb..7244a91ff0 100644
--- a/caffe2/core/operator_schema.h
+++ b/caffe2/core/operator_schema.h
@@ -23,6 +23,7 @@
 #include <ostream>
 #include <set>
 #include <vector>
+#include <unordered_map>
 
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
@@ -155,11 +156,18 @@ class OpSchema {
   typedef std::function<
       vector<TensorShape>(const OperatorDef&, const vector<TensorShape>&)>
       TensorInferenceFunctionType;
+
   /**
    * @brief Sets the tensor inference function, which is a std::function object
    * defined in operator_schema.h.
    */
   OpSchema& TensorInferenceFunction(TensorInferenceFunctionType function);
+
+  /**
+   * @brief Sets the corresponding onnx schema name
+   */
+  OpSchema& InheritOnnxSchema(const std::string& onnx_schema_name);
+
   /**
    * @brief Sets the tensor inference function to produce the same output as
    * the input.
@@ -175,7 +183,7 @@ class OpSchema {
    */
   inline vector<TensorShape> InferTensor(
       const OperatorDef& def,
-      const vector<TensorShape> input_type_shape) const {
+      const vector<TensorShape>& input_type_shape) const {
     return tensor_inference_function_(def, input_type_shape);
   }
 
@@ -284,6 +292,10 @@ class OpSchema {
    */
   int CalculateOutput(int num_input) const;
 
+  const std::string& onnx_schema() const {
+    return onnx_schema_;
+  }
+
   int min_input() const {
     return min_input_;
   }
@@ -355,6 +367,7 @@ class OpSchema {
  private:
   string file_;
   string doc_;
+  string onnx_schema_;
   std::vector<Argument> args_{};
   std::vector<std::pair<const char*, const char*>> input_desc_{};
   std::vector<std::pair<const char*, const char*>> output_desc_{};
diff --git a/caffe2/operators/abs_op.cc b/caffe2/operators/abs_op.cc
index a9b1d23a1d..8769b2fde7 100644
--- a/caffe2/operators/abs_op.cc
+++ b/caffe2/operators/abs_op.cc
@@ -59,7 +59,8 @@ Calculates the absolute value of the given input tensor, element-wise.
     .Output(
         0,
         "output",
-        "The absolute value of the input tensor computed element-wise");
+        "The absolute value of the input tensor computed element-wise")
+    .InheritOnnxSchema("Abs");
 
 OPERATOR_SCHEMA(AbsGradient).NumInputs(2).NumOutputs(1).IdenticalTypeAndShape();
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc
index be77155415..8acfa188e2 100644
--- a/caffe2/operators/batch_matmul_op.cc
+++ b/caffe2/operators/batch_matmul_op.cc
@@ -127,7 +127,8 @@ two diemnsional, it behaves like normal matrix multiplication.
         return vector<TensorShape>{
             CreateTensorShape(vector<TIndex>{new_dims}, in[0].data_type())};
       }
-      });
+      })
+      .InheritOnnxSchema("MatMul");
 
 class GetBatchMatMulGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/clip_op.cc b/caffe2/operators/clip_op.cc
index 1866aa377b..b7f2b2f56b 100644
--- a/caffe2/operators/clip_op.cc
+++ b/caffe2/operators/clip_op.cc
@@ -73,7 +73,8 @@ are the same.
         1,
         "output",
         "Output tensor (Tensor<float>) containing clipped"
-        "input elements");
+        "input elements")
+    .InheritOnnxSchema("Clip");
 
 OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1).AllowInplace({{1, 0}});
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index 200fe8de31..8e1b019158 100644
--- a/caffe2/operators/concat_split_op.cc
+++ b/caffe2/operators/concat_split_op.cc
@@ -49,7 +49,8 @@ Split a tensor into a list of tensors, along the specified 'axis'. The lengths
 of the split can be specified using argument 'split' or optional second input
 blob to the operator. Otherwise, the tensor is split to equal sized parts.
-)DOC");
+)DOC")
+    .InheritOnnxSchema("Split");
 
 namespace {
 OpSchema::Cost CostInferenceForConcat(
diff --git a/caffe2/operators/conv_op.cc b/caffe2/operators/conv_op.cc
index f53f46afb6..e9da05dedd 100644
--- a/caffe2/operators/conv_op.cc
+++ b/caffe2/operators/conv_op.cc
@@ -76,7 +76,8 @@ OPERATOR_SCHEMA(Conv)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
     .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
         ConvPoolOpBase<CPUContext>::CostInferenceForConv))
-    .FillUsing(ConvDocGenerator(""));
+    .FillUsing(ConvDocGenerator(""))
+    .InheritOnnxSchema("Conv");
 
 REGISTER_CPU_OPERATOR(Conv1D, ConvOp<float, CPUContext>);
 
@@ -84,7 +85,8 @@ OPERATOR_SCHEMA(Conv1D)
     .NumInputs(2, 3)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
-    .FillUsing(ConvDocGenerator("1D "));
+    .FillUsing(ConvDocGenerator("1D "))
+    .InheritOnnxSchema("Conv");
 
 REGISTER_CPU_OPERATOR(Conv2D, ConvOp<float, CPUContext>);
 
@@ -94,7 +96,8 @@ OPERATOR_SCHEMA(Conv2D)
     .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
         ConvPoolOpBase<CPUContext>::CostInferenceForConv))
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
-    .FillUsing(ConvDocGenerator("2D "));
+    .FillUsing(ConvDocGenerator("2D "))
+    .InheritOnnxSchema("Conv");
 
 REGISTER_CPU_OPERATOR(Conv3D, ConvOp<float, CPUContext>);
 
@@ -104,6 +107,7 @@ OPERATOR_SCHEMA(Conv3D)
     .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
         ConvPoolOpBase<CPUContext>::CostInferenceForConv))
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
-    .FillUsing(ConvDocGenerator("3D "));
+    .FillUsing(ConvDocGenerator("3D "))
+    .InheritOnnxSchema("Conv");
 
 } // namespace caffe2
diff --git a/caffe2/operators/conv_transpose_op.cc b/caffe2/operators/conv_transpose_op.cc
index 36b3c480d5..74563dc039 100644
--- a/caffe2/operators/conv_transpose_op.cc
+++ b/caffe2/operators/conv_transpose_op.cc
@@ -60,6 +60,7 @@ conv_transpose_op.h file, which is why they are separate files.
         "Y",
         "Output data blob that contains the result of the "
         "transposed convolution. The output dimensions are functions of the kernel"
-        " size, stride size, and pad lengths.");
+        " size, stride size, and pad lengths.")
+    .InheritOnnxSchema("ConvTranspose");
 
 } // namespace caffe2
diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc
index 04de934e9f..ac8102ae4e 100644
--- a/caffe2/operators/dropout_op.cc
+++ b/caffe2/operators/dropout_op.cc
@@ -109,7 +109,8 @@ the training phase, so during testing nothing needs to be done.
     .Output(
         1,
         "mask",
-        "The output mask. If is_test is nonzero, this output is not filled.");
+        "The output mask. If is_test is nonzero, this output is not filled.")
+    .InheritOnnxSchema("Dropout");
 
 OPERATOR_SCHEMA(DropoutGrad)
     .NumInputs(1, 2)
diff --git a/caffe2/operators/elementwise_op_schema.cc b/caffe2/operators/elementwise_op_schema.cc
index 5aeb3dbf7a..a67466d713 100644
--- a/caffe2/operators/elementwise_op_schema.cc
+++ b/caffe2/operators/elementwise_op_schema.cc
@@ -70,28 +70,32 @@ OPERATOR_SCHEMA(Add)
     .AllowInplace({{0, 0}, {1, 0}})
     .CostInferenceFunction(PointwiseCostInference<1>)
     .IdenticalTypeAndShapeOfInput(0)
-    .FillUsing(MathDocGenerator("addition"));
+    .FillUsing(MathDocGenerator("addition"))
+    .InheritOnnxSchema("Add");
 
 OPERATOR_SCHEMA(Sub)
     .NumInputs(2)
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
     .CostInferenceFunction(PointwiseCostInference<1>)
     .IdenticalTypeAndShapeOfInput(0)
-    .FillUsing(MathDocGenerator("subtraction"));
+    .FillUsing(MathDocGenerator("subtraction"))
+    .InheritOnnxSchema("Sub");
 
 OPERATOR_SCHEMA(Mul)
     .NumInputs(2)
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
     .CostInferenceFunction(PointwiseCostInference<1>)
     .IdenticalTypeAndShapeOfInput(0)
-    .FillUsing(MathDocGenerator("multiplication"));
+    .FillUsing(MathDocGenerator("multiplication"))
+    .InheritOnnxSchema("Mul");
 
 OPERATOR_SCHEMA(Div)
     .NumInputs(2)
     .NumOutputs(1)
     .AllowInplace({{0, 0}})
     .CostInferenceFunction(PointwiseCostInference<1>)
     .IdenticalTypeAndShapeOfInput(0)
-    .FillUsing(MathDocGenerator("division"));
+    .FillUsing(MathDocGenerator("division"))
+    .InheritOnnxSchema("Div");
 
 OPERATOR_SCHEMA(DivGradient).NumInputs(3).NumOutputs(2).AllowInplace({{0, 0}});
 
 OPERATOR_SCHEMA(SumReduceLike)
@@ -347,24 +351,26 @@ Both input operands should be of type `bool`.
   };
 }
 
-#define CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(name, symbol)       \
+#define CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(name, symbol, onnx_schema) \
   OPERATOR_SCHEMA(name)                                         \
       .NumInputs(2)                                             \
      .NumOutputs(1)                                             \
      .AllowInplace({{0, 0}})                                    \
-      .FillUsing(LogicalDocGenerator(symbol));                  \
+      .FillUsing(LogicalDocGenerator(symbol))                   \
+      .InheritOnnxSchema(onnx_schema);                          \
   SHOULD_NOT_DO_GRADIENT(name)
 
-CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Or, "or");
-CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(And, "and");
-CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Xor, "xor");
+CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Or, "or", "Or");
+CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(And, "and", "And");
+CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Xor, "xor", "Xor");
 
 OPERATOR_SCHEMA(Not)
     .NumInputs(1)
     .NumOutputs(1)
     .SetDoc(R"DOC(Performs element-wise negation.)DOC")
     .Input(0, "X", "Input tensor of type `bool`.")
-    .Output(0, "Y", "Output tensor of type `bool`.");
+    .Output(0, "Y", "Output tensor of type `bool`.")
+    .InheritOnnxSchema("Not");
 SHOULD_NOT_DO_GRADIENT(Not);
 
 } // namespace caffe2
diff --git a/caffe2/operators/elementwise_sum_op.cc b/caffe2/operators/elementwise_sum_op.cc
index e1165fe427..f975584942 100644
--- a/caffe2/operators/elementwise_sum_op.cc
+++ b/caffe2/operators/elementwise_sum_op.cc
@@ -45,5 +45,6 @@ place and results will be accumulated in input0. All inputs and outputs must
 have the same shape and data type.
 )DOC")
     .Input(0, "data_0", "First of the input tensors. Can be inplace.")
-    .Output(0, "sum", "Output tensor. Same dimension as inputs.");
+    .Output(0, "sum", "Output tensor. Same dimension as inputs.")
+    .InheritOnnxSchema("Sum");
 }
diff --git a/caffe2/operators/elu_op.cc b/caffe2/operators/elu_op.cc
index c3bb9e671b..6bbccac305 100644
--- a/caffe2/operators/elu_op.cc
+++ b/caffe2/operators/elu_op.cc
@@ -71,7 +71,8 @@ Elu takes one input data (Tensor<T>) and produces one output data
 )DOC")
     .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D input tensor");
+    .Output(0, "Y", "1D input tensor")
+    .InheritOnnxSchema("Elu");
 
 // Input: Y, dY, output: dX
 OPERATOR_SCHEMA(EluGradient)
diff --git a/caffe2/operators/exp_op.cc b/caffe2/operators/exp_op.cc
index dfd18668de..2af8495de7 100644
--- a/caffe2/operators/exp_op.cc
+++ b/caffe2/operators/exp_op.cc
@@ -46,7 +46,8 @@ and output blobs.
         0,
         "output",
         "The exponential of the input tensor computed "
-        "element-wise");
+        "element-wise")
+    .InheritOnnxSchema("Exp");
 
 class GetExpGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/expand_squeeze_dims_op.cc b/caffe2/operators/expand_squeeze_dims_op.cc
index e8029ed796..6ce3d1649e 100644
--- a/caffe2/operators/expand_squeeze_dims_op.cc
+++ b/caffe2/operators/expand_squeeze_dims_op.cc
@@ -104,7 +104,8 @@ This is the exact inverse operation of ExpandDims given the same `dims` arg.
           SqueezeOp<CPUContext>::ComputeDims(GetDimsVector(in[0]), dims);
       out[0] = CreateTensorShape(newDims, in[0].data_type());
       return out;
-    });
+    })
+    .InheritOnnxSchema("Squeeze");
 
 class GetSqueezeGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/flatten_op.cc b/caffe2/operators/flatten_op.cc
index 48f6573078..4c69f19c4d 100644
--- a/caffe2/operators/flatten_op.cc
+++ b/caffe2/operators/flatten_op.cc
@@ -60,7 +60,8 @@ Flattens the input tensor into a 2D matrix. If input tensor has shape
     .Arg(
         "axis",
         "(Default to 1) Indicate up to which input dimensions "
-        "(exclusive) should be flattened to the outer dimension of the output");
+        "(exclusive) should be flattened to the outer dimension of the output")
+    .InheritOnnxSchema("Flatten");
 
 class GetFlattenGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index 9212535a78..f4e393afc0 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -142,7 +142,8 @@ will throw errors.
         "A tensor that is coerced into a 2D blob of size (KxN) "
         "containing fully connected weight matrix")
     .Input(2, "b", "1D blob containing bias vector")
-    .Output(0, "Y", "2D output tensor");
+    .Output(0, "Y", "2D output tensor")
+    .InheritOnnxSchema("Gemm");
 
 OPERATOR_SCHEMA(FCGradient).NumInputs(3).NumOutputs(2, 3);
 OPERATOR_SCHEMA(FCTransposedGradient).NumInputs(3).NumOutputs(2, 3);
diff --git a/caffe2/operators/leaky_relu_op.cc b/caffe2/operators/leaky_relu_op.cc
index e9e748d490..6081097b8e 100644
--- a/caffe2/operators/leaky_relu_op.cc
+++ b/caffe2/operators/leaky_relu_op.cc
@@ -70,7 +70,8 @@ OPERATOR_SCHEMA(LeakyReluGradient)
     .NumInputs(2)
     .NumOutputs(1)
     .AllowInplace({{1, 0}})
-    .Arg("alpha", "Coefficient of leakage");
+    .Arg("alpha", "Coefficient of leakage")
+    .InheritOnnxSchema("LeakyRelu");
 
 class GetLeakyReluGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/log_op.cc b/caffe2/operators/log_op.cc
index 139ee51693..d2e7cc992c 100644
--- a/caffe2/operators/log_op.cc
+++ b/caffe2/operators/log_op.cc
@@ -47,7 +47,8 @@ and output blobs.
         0,
         "output",
         "The natural log of the input tensor computed "
-        "element-wise");
+        "element-wise")
+    .InheritOnnxSchema("Log");
 
 class GetLogGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc
index 480ab5bd8a..be04c27625 100644
--- a/caffe2/operators/minmax_ops.cc
+++ b/caffe2/operators/minmax_ops.cc
@@ -33,7 +33,8 @@ place and results will be accumulated in input0. All inputs and outputs must
 have the same shape and data type.
 )DOC")
     .Input(0, "data_0", "First of the input tensors. Can be inplace.")
-    .Output(0, "max", "Output tensor. Same dimension as inputs.");
+    .Output(0, "max", "Output tensor. Same dimension as inputs.")
+    .InheritOnnxSchema("Max");
 
 OPERATOR_SCHEMA(Min)
     .NumInputs(1, INT_MAX)
@@ -47,7 +48,8 @@ place and results will be accumulated in input0. All inputs and outputs must
 have the same shape and data type.
 )DOC")
     .Input(0, "data_0", "First of the input tensors. Can be inplace.")
-    .Output(0, "min", "Output tensor. Same dimension as inputs.");
+    .Output(0, "min", "Output tensor. Same dimension as inputs.")
+    .InheritOnnxSchema("Min");
 
 template <typename T, class Context>
 bool MaxOp<T, Context>::Compute() {
diff --git a/caffe2/operators/negative_op.cc b/caffe2/operators/negative_op.cc
index b83ac3047a..2bbe188f54 100644
--- a/caffe2/operators/negative_op.cc
+++ b/caffe2/operators/negative_op.cc
@@ -43,7 +43,8 @@ OPERATOR_SCHEMA(Negative)
 Computes the element-wise negative of the input.
 )DOC")
     .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D input tensor");
+    .Output(0, "Y", "1D input tensor")
+    .InheritOnnxSchema("Neg");
 
 class GetNegativeGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/operators/pool_op.cc b/caffe2/operators/pool_op.cc
index e8e9958d8c..dae6e919b1 100644
--- a/caffe2/operators/pool_op.cc
+++ b/caffe2/operators/pool_op.cc
@@ -814,7 +814,8 @@ OPERATOR_SCHEMA(AveragePool)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(AveragePoolDocGenerator(""));
+    .FillUsing(AveragePoolDocGenerator(""))
+    .InheritOnnxSchema("AveragePool");
 
 REGISTER_CPU_OPERATOR(
     AveragePool1D,
@@ -824,7 +825,8 @@ OPERATOR_SCHEMA(AveragePool1D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(AveragePoolDocGenerator("1D"));
+    .FillUsing(AveragePoolDocGenerator("1D"))
+    .InheritOnnxSchema("AveragePool");
 
 REGISTER_CPU_OPERATOR(
     AveragePool2D,
@@ -834,7 +836,8 @@ OPERATOR_SCHEMA(AveragePool2D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(AveragePoolDocGenerator("2D"));
+    .FillUsing(AveragePoolDocGenerator("2D"))
+    .InheritOnnxSchema("AveragePool");
 
 REGISTER_CPU_OPERATOR(
     AveragePool3D,
@@ -844,7 +847,8 @@ OPERATOR_SCHEMA(AveragePool3D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(AveragePoolDocGenerator("3D"));
+    .FillUsing(AveragePoolDocGenerator("3D"))
+    .InheritOnnxSchema("AveragePool");
 
 REGISTER_CPU_OPERATOR(MaxPool, PoolOp<float, CPUContext, MaxPool<float>>);
 
@@ -852,7 +856,8 @@ OPERATOR_SCHEMA(MaxPool)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(MaxPoolDocGenerator(""));
+    .FillUsing(MaxPoolDocGenerator(""))
+    .InheritOnnxSchema("MaxPool");
 
 REGISTER_CPU_OPERATOR(MaxPool1D, PoolOp<float, CPUContext, MaxPool<float>>);
 
@@ -860,7 +865,8 @@ OPERATOR_SCHEMA(MaxPool1D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(MaxPoolDocGenerator("1D"));
+    .FillUsing(MaxPoolDocGenerator("1D"))
+    .InheritOnnxSchema("MaxPool");
 
 REGISTER_CPU_OPERATOR(MaxPool2D, PoolOp<float, CPUContext, MaxPool<float>>);
 
@@ -868,7 +874,8 @@ OPERATOR_SCHEMA(MaxPool2D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(MaxPoolDocGenerator("2D"));
+    .FillUsing(MaxPoolDocGenerator("2D"))
+    .InheritOnnxSchema("MaxPool");
 
 REGISTER_CPU_OPERATOR(MaxPool3D, PoolOp<float, CPUContext, MaxPool<float>>);
 
@@ -876,5 +883,6 @@ OPERATOR_SCHEMA(MaxPool3D)
     .NumInputs(1)
     .NumOutputs(1)
     .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
-    .FillUsing(MaxPoolDocGenerator("3D"));
+    .FillUsing(MaxPoolDocGenerator("3D"))
+    .InheritOnnxSchema("MaxPool");
 
 } // namespace caffe2
diff --git a/caffe2/operators/prelu_op.cc b/caffe2/operators/prelu_op.cc
index a6aaca4efd..3f8bce4cb7 100644
--- a/caffe2/operators/prelu_op.cc
+++ b/caffe2/operators/prelu_op.cc
@@ -290,7 +290,8 @@ output data (Tensor<T>) where the function `f(x) = slope * x for x < 0`,
         "Slope",
         "1D slope tensor. If `Slope` is of size 1, the value is shared"
         "across different channels")
-    .Output(0, "Y", "1D input tensor");
+    .Output(0, "Y", "1D input tensor")
+    .InheritOnnxSchema("PRelu");
 
 // Input: Y, dY, output: dX
 OPERATOR_SCHEMA(PReluGradient).NumInputs(4).NumOutputs(2).SetDoc(R"DOC(
diff --git a/caffe2/operators/relu_op.cc b/caffe2/operators/relu_op.cc
index 0206b6f626..54dce8fb7b 100644
--- a/caffe2/operators/relu_op.cc
+++ b/caffe2/operators/relu_op.cc
@@ -96,7 +96,8 @@ Relu takes one input data (Tensor<T>) and produces one output data
 the tensor elementwise.
 )DOC")
     .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D input tensor");
+    .Output(0, "Y", "1D input tensor")
+    .InheritOnnxSchema("Relu");
 
 // Input: Y, dY, output: dX
 OPERATOR_SCHEMA(ReluGradient)
diff --git a/caffe2/operators/selu_op.cc b/caffe2/operators/selu_op.cc
index 6c3ae9983b..8ddda64137 100644
--- a/caffe2/operators/selu_op.cc
+++ b/caffe2/operators/selu_op.cc
@@ -56,7 +56,8 @@ is applied to the tensor elementwise.
         "scale",
         "(float) default to 1.0507~; affects the activation function itself.")
     .Input(0, "X", "input tensor")
-    .Output(0, "Y", "input tensor");
+    .Output(0, "Y", "input tensor")
+    .InheritOnnxSchema("Selu");
 
 // Input: Y, dY; output: dX
 OPERATOR_SCHEMA(SeluGradient)
diff --git a/caffe2/operators/sigmoid_op.cc b/caffe2/operators/sigmoid_op.cc
index cf4b62404d..15c2a3624b 100644
--- a/caffe2/operators/sigmoid_op.cc
+++ b/caffe2/operators/sigmoid_op.cc
@@ -63,7 +63,8 @@ Sigmoid takes one input data (Tensor<T>) and produces one output data
 tensor elementwise.
 )DOC")
     .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D output tensor");
+    .Output(0, "Y", "1D output tensor")
+    .InheritOnnxSchema("Sigmoid");
 
 // Input: Y, dY, output: dX
 OPERATOR_SCHEMA(SigmoidGradient)
     .NumInputs(2)
diff --git a/caffe2/operators/softmax_op.cc b/caffe2/operators/softmax_op.cc
index 32736f9a92..b803fec130 100644
--- a/caffe2/operators/softmax_op.cc
+++ b/caffe2/operators/softmax_op.cc
@@ -123,7 +123,8 @@ will throw errors.
         "The input tensor that's coerced into a 2D matrix of size (NxD) "
         "as described above.")
     .Output(0, "output", "The softmax normalized output values with the same "
-        "shape as input tensor.");
+        "shape as input tensor.")
+    .InheritOnnxSchema("Softmax");
 
 // Input: Y, dY. Output: dX
 OPERATOR_SCHEMA(SoftmaxGradient).NumInputs(2).NumOutputs(1);
diff --git a/caffe2/operators/softplus_op.cc b/caffe2/operators/softplus_op.cc
index c03e556ddf..71eff2ec06 100644
--- a/caffe2/operators/softplus_op.cc
+++ b/caffe2/operators/softplus_op.cc
@@ -66,7 +66,8 @@ Softplus takes one input data (Tensor<T>) and produces one output data
 the tensor elementwise.
 )DOC")
     .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D input tensor");
+    .Output(0, "Y", "1D input tensor")
+    .InheritOnnxSchema("Softplus");
 
 // Input: Y, dY, output: dX
 OPERATOR_SCHEMA(SoftplusGradient)
diff --git a/caffe2/operators/softsign_op.cc b/caffe2/operators/softsign_op.cc
index 750d6a53ea..b9835e0412 100644
--- a/caffe2/operators/softsign_op.cc
+++ b/caffe2/operators/softsign_op.cc
@@ -67,7 +67,8 @@ and output blobs.
         0,
         "output",
         "The softsign (x/1+|x|) values of the input tensor "
-        "computed element-wise");
+        "computed element-wise")
+    .InheritOnnxSchema("Softsign");
 
 OPERATOR_SCHEMA(SoftsignGradient)
     .NumInputs(2)
diff --git a/caffe2/operators/spatial_batch_norm_op.cc b/caffe2/operators/spatial_batch_norm_op.cc
index dcd25cd45b..94a4e8f7f2 100644
--- a/caffe2/operators/spatial_batch_norm_op.cc
+++ b/caffe2/operators/spatial_batch_norm_op.cc
@@ -308,6 +308,7 @@ Output case #2:
         4,
         "saved_var",
         "Saved variance used during training to speed up "
-        "gradient computation. Should not be used for testing.");
+        "gradient computation. Should not be used for testing.")
+    .InheritOnnxSchema("BatchNormalization");
 
 } // namespace caffe2
diff --git a/caffe2/operators/tanh_op.cc b/caffe2/operators/tanh_op.cc
index 1e945b3725..dd8a086989 100644
--- a/caffe2/operators/tanh_op.cc
+++ b/caffe2/operators/tanh_op.cc
@@ -69,7 +69,8 @@ and output blobs.
 )DOC")
     .Input(0, "input", "1-D input tensor")
     .Output(0, "output", "The hyperbolic tangent values of the input tensor "
-        "computed element-wise");
+        "computed element-wise")
+    .InheritOnnxSchema("Tanh");
 
 OPERATOR_SCHEMA(TanhGradient).NumInputs(2).NumOutputs(1).AllowInplace({{1, 0}});
diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc
index f31566fb56..1058aa731d 100644
--- a/caffe2/operators/tile_op.cc
+++ b/caffe2/operators/tile_op.cc
@@ -62,7 +62,8 @@ For example, tiling [[a b c d]] by tile=2, axis=0 produces
     .Output(
         0,
         "tiled_data",
-        "Tensor that will contain input replicated along the given axis.");
+        "Tensor that will contain input replicated along the given axis.")
+    .InheritOnnxSchema("Tile");
 
 OPERATOR_SCHEMA(TileGradient).NumInputs(1, 3).NumOutputs(1);
diff --git a/caffe2/operators/transpose_op.cc b/caffe2/operators/transpose_op.cc
index c4318b0a60..e56c820527 100644
--- a/caffe2/operators/transpose_op.cc
+++ b/caffe2/operators/transpose_op.cc
@@ -75,7 +75,8 @@ will be (2, 1, 3).
         "A list of integers. By default, reverse the dimensions, "
         "otherwise permute the axes according to the values given.")
     .Input(0, "data", "An input tensor.")
-    .Output(0, "transposed", "Transposed output.");
+    .Output(0, "transposed", "Transposed output.")
+    .InheritOnnxSchema("Transpose");
 
 class GetTransposeGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;