diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-02-28 09:50:19 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-28 09:53:18 -0800 |
commit | a9395ce259a27abf141e6add7f5d220bd15d03eb (patch) | |
tree | 094958056059f15a6c1afd04ae739e4a33051abe /caffe2 | |
parent | 9bcceb75b515c318a9f761fffa3acb1acc955fe3 (diff) | |
download | pytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.tar.gz pytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.tar.bz2 pytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.zip |
refactor caffe2 operator constructors - 9/9 (#17090)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17090
clangr codemod
Reviewed By: ezyang
Differential Revision: D14078550
fbshipit-source-id: 68e6de4298e55ce83039b7806c1a275c4d6593c8
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/operators/utility_ops.h | 52 | ||||
-rw-r--r-- | caffe2/operators/utility_ops_cudnn.cc | 6 | ||||
-rw-r--r-- | caffe2/operators/variable_length_sequence_padding.h | 7 | ||||
-rw-r--r-- | caffe2/operators/weighted_multi_sampling_op.h | 5 | ||||
-rw-r--r-- | caffe2/operators/weighted_sample_op.h | 5 | ||||
-rw-r--r-- | caffe2/operators/while_op.h | 2 | ||||
-rw-r--r-- | caffe2/operators/workspace_ops.cc | 2 |
7 files changed, 46 insertions, 33 deletions
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index 469e97a015..a77ee9a2a2 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -20,8 +20,9 @@ template <class Context> class NanCheckOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - NanCheckOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) {} + template <class... Args> + explicit NanCheckOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...) {} bool RunOnDevice() override; @@ -46,8 +47,9 @@ class WallClockTimeOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - WallClockTimeOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) {} + template <class... Args> + explicit WallClockTimeOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...) {} bool RunOnDevice() override { int64_t nanoseconds = static_cast<long int>( @@ -70,7 +72,7 @@ class PrintOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_DISPATCH_HELPER; - PrintOp(const OperatorDef& operator_def, Workspace* ws) + explicit PrintOp(const OperatorDef& operator_def, Workspace* ws) : Operator<Context>(operator_def, ws), tensor_printer_( operator_def.input(0), @@ -395,8 +397,9 @@ class WeightedSumGradientOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - WeightedSumGradientOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws), + template <class... Args> + explicit WeightedSumGradientOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), grad_on_w_(this->template GetSingleArgument<bool>("grad_on_w", false)) { } @@ -597,8 +600,9 @@ class ScatterAssignOp : public Operator<Context> { USE_OPERATOR_CONTEXT_FUNCTIONS; virtual ~ScatterAssignOp() {} - ScatterAssignOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws), + template <class... Args> + explicit ScatterAssignOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), runners_({{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT}, &ScatterAssignOp::DoRun<int32_t, float>}, {{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16}, @@ -871,8 +875,9 @@ template <class Context> class LengthsToWeightsOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - LengthsToWeightsOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws), + template <class... Args> + explicit LengthsToWeightsOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), power_(this->template GetSingleArgument<float>("power", 0.5)) {} bool RunOnDevice() override { @@ -1149,8 +1154,9 @@ class LengthsGatherOp : public Operator<Context> { template <typename T, class Context> class AccumulateHistogramOp : public Operator<Context> { public: - AccumulateHistogramOp(const OperatorDef& def, Workspace* ws) - : Operator<Context>(def, ws), + template <class... Args> + explicit AccumulateHistogramOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), lower_bound_( this->template GetSingleArgument<float>("lower_bound", 0.0)), upper_bound_( @@ -1288,8 +1294,9 @@ class RangeOp : public Operator<Context> { class ThrowExceptionOp : public Operator<CPUContext> { public: - ThrowExceptionOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CPUContext>(operator_def, ws), + template <class... Args> + explicit ThrowExceptionOp(Args&&... args) + : Operator<CPUContext>(std::forward<Args>(args)...), message_(GetSingleArgument<std::string>( "message", "Exception from ThrowExceptionOp")) {} @@ -1304,8 +1311,9 @@ class ThrowExceptionOp : public Operator<CPUContext> { class ThrowChildThreadExceptionOp : public Operator<CPUContext> { public: - ThrowChildThreadExceptionOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CPUContext>(operator_def, ws), + template <class... Args> + explicit ThrowChildThreadExceptionOp(Args&&... args) + : Operator<CPUContext>(std::forward<Args>(args)...), message_(GetSingleArgument<std::string>( "message", "Exception from ThrowChildThreadExceptionOp")) {} @@ -1323,8 +1331,9 @@ class ThrowChildThreadExceptionOp : public Operator<CPUContext> { class LogFatalOp : public Operator<CPUContext> { public: - LogFatalOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CPUContext>(operator_def, ws), + template <class... Args> + explicit LogFatalOp(Args&&... args) + : Operator<CPUContext>(std::forward<Args>(args)...), message_(GetSingleArgument<std::string>( "message", "Logging from LogFatalOp")) {} @@ -1340,8 +1349,9 @@ class LogFatalOp : public Operator<CPUContext> { class FailOp : public Operator<CPUContext> { public: - FailOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CPUContext>(operator_def, ws) {} + template <class... Args> + explicit FailOp(Args&&... args) + : Operator<CPUContext>(std::forward<Args>(args)...) {} bool RunOnDevice() override { return false; diff --git a/caffe2/operators/utility_ops_cudnn.cc b/caffe2/operators/utility_ops_cudnn.cc index 7fe4503b2b..c04ad3292c 100644 --- a/caffe2/operators/utility_ops_cudnn.cc +++ b/caffe2/operators/utility_ops_cudnn.cc @@ -12,8 +12,10 @@ class CuDNNWeightedSumOp : public Operator<CUDAContext> { public: USE_OPERATOR_FUNCTIONS(CUDAContext); - CuDNNWeightedSumOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) { + template <class... Args> + explicit CuDNNWeightedSumOp(Args&&... args) + : Operator<CUDAContext>(std::forward<Args>(args)...), + cudnn_wrapper_(&context_) { CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE(cudnnCreateOpTensorDescriptor(&add_desc_)); // Both float and at::Half require opTensorCompType to be CUDNN_DATA_FLOAT. diff --git a/caffe2/operators/variable_length_sequence_padding.h b/caffe2/operators/variable_length_sequence_padding.h index 9e4c9da83e..f86964d639 100644 --- a/caffe2/operators/variable_length_sequence_padding.h +++ b/caffe2/operators/variable_length_sequence_padding.h @@ -29,10 +29,9 @@ void VariableLengthSequencePadding( template <typename T, typename Context> class VariableLengthSequencePaddingOp : public Operator<Context> { public: - VariableLengthSequencePaddingOp( - const OperatorDef& operator_def, - Workspace* ws) - : Operator<Context>(operator_def, ws) {} + template <class... Args> + explicit VariableLengthSequencePaddingOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { diff --git a/caffe2/operators/weighted_multi_sampling_op.h b/caffe2/operators/weighted_multi_sampling_op.h index 968c518ca0..1de0fd1137 100644 --- a/caffe2/operators/weighted_multi_sampling_op.h +++ b/caffe2/operators/weighted_multi_sampling_op.h @@ -9,8 +9,9 @@ class WeightedMultiSamplingOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - WeightedMultiSamplingOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws), + template <class... Args> + explicit WeightedMultiSamplingOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), num_samples_( this->template GetSingleArgument<int64_t>("num_samples", 0)) { CAFFE_ENFORCE_GE(num_samples_, 0); diff --git a/caffe2/operators/weighted_sample_op.h b/caffe2/operators/weighted_sample_op.h index 1474a9aee2..361d20ef5b 100644 --- a/caffe2/operators/weighted_sample_op.h +++ b/caffe2/operators/weighted_sample_op.h @@ -13,8 +13,9 @@ namespace caffe2 { template <typename T, class Context> class WeightedSampleOp final : public Operator<Context> { public: - WeightedSampleOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) {} + template <class... Args> + explicit WeightedSampleOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...) {} USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/caffe2/operators/while_op.h b/caffe2/operators/while_op.h index 66869ddfc4..445e4e4983 100644 --- a/caffe2/operators/while_op.h +++ b/caffe2/operators/while_op.h @@ -10,7 +10,7 @@ namespace caffe2 { template <class Context> class WhileOp final : public Operator<Context> { public: - WhileOp(const OperatorDef& operator_def, Workspace* ws) + explicit WhileOp(const OperatorDef& operator_def, Workspace* ws) : Operator<Context>(operator_def, ws) { CAFFE_ENFORCE( this->template HasSingleArgumentOfType<NetDef>("loop_net"), diff --git a/caffe2/operators/workspace_ops.cc b/caffe2/operators/workspace_ops.cc index e3345f0152..488c7f408b 100644 --- a/caffe2/operators/workspace_ops.cc +++ b/caffe2/operators/workspace_ops.cc @@ -6,7 +6,7 @@ namespace { class GetAllBlobNamesOp final : public Operator<CPUContext> { public: - GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws) + explicit GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws) : Operator<CPUContext>(operator_def, ws), include_shared_(GetSingleArgument<int>("include_shared", true)), ws_(ws) {} |