summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-02-28 09:50:19 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-28 09:53:18 -0800
commita9395ce259a27abf141e6add7f5d220bd15d03eb (patch)
tree094958056059f15a6c1afd04ae739e4a33051abe /caffe2
parent9bcceb75b515c318a9f761fffa3acb1acc955fe3 (diff)
downloadpytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.tar.gz
pytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.tar.bz2
pytorch-a9395ce259a27abf141e6add7f5d220bd15d03eb.zip
refactor caffe2 operator constructors - 9/9 (#17090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17090 clangr codemod Reviewed By: ezyang Differential Revision: D14078550 fbshipit-source-id: 68e6de4298e55ce83039b7806c1a275c4d6593c8
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/operators/utility_ops.h52
-rw-r--r--caffe2/operators/utility_ops_cudnn.cc6
-rw-r--r--caffe2/operators/variable_length_sequence_padding.h7
-rw-r--r--caffe2/operators/weighted_multi_sampling_op.h5
-rw-r--r--caffe2/operators/weighted_sample_op.h5
-rw-r--r--caffe2/operators/while_op.h2
-rw-r--r--caffe2/operators/workspace_ops.cc2
7 files changed, 46 insertions, 33 deletions
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 469e97a015..a77ee9a2a2 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -20,8 +20,9 @@ template <class Context>
class NanCheckOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- NanCheckOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit NanCheckOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
@@ -46,8 +47,9 @@ class WallClockTimeOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- WallClockTimeOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit WallClockTimeOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
int64_t nanoseconds = static_cast<long int>(
@@ -70,7 +72,7 @@ class PrintOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
- PrintOp(const OperatorDef& operator_def, Workspace* ws)
+ explicit PrintOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
tensor_printer_(
operator_def.input(0),
@@ -395,8 +397,9 @@ class WeightedSumGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- WeightedSumGradientOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit WeightedSumGradientOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
grad_on_w_(this->template GetSingleArgument<bool>("grad_on_w", false)) {
}
@@ -597,8 +600,9 @@ class ScatterAssignOp : public Operator<Context> {
USE_OPERATOR_CONTEXT_FUNCTIONS;
virtual ~ScatterAssignOp() {}
- ScatterAssignOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit ScatterAssignOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
runners_({{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT},
&ScatterAssignOp::DoRun<int32_t, float>},
{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16},
@@ -871,8 +875,9 @@ template <class Context>
class LengthsToWeightsOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- LengthsToWeightsOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit LengthsToWeightsOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
power_(this->template GetSingleArgument<float>("power", 0.5)) {}
bool RunOnDevice() override {
@@ -1149,8 +1154,9 @@ class LengthsGatherOp : public Operator<Context> {
template <typename T, class Context>
class AccumulateHistogramOp : public Operator<Context> {
public:
- AccumulateHistogramOp(const OperatorDef& def, Workspace* ws)
- : Operator<Context>(def, ws),
+ template <class... Args>
+ explicit AccumulateHistogramOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
lower_bound_(
this->template GetSingleArgument<float>("lower_bound", 0.0)),
upper_bound_(
@@ -1288,8 +1294,9 @@ class RangeOp : public Operator<Context> {
class ThrowExceptionOp : public Operator<CPUContext> {
public:
- ThrowExceptionOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit ThrowExceptionOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowExceptionOp")) {}
@@ -1304,8 +1311,9 @@ class ThrowExceptionOp : public Operator<CPUContext> {
class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
public:
- ThrowChildThreadExceptionOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit ThrowChildThreadExceptionOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowChildThreadExceptionOp")) {}
@@ -1323,8 +1331,9 @@ class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
class LogFatalOp : public Operator<CPUContext> {
public:
- LogFatalOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit LogFatalOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Logging from LogFatalOp")) {}
@@ -1340,8 +1349,9 @@ class LogFatalOp : public Operator<CPUContext> {
class FailOp : public Operator<CPUContext> {
public:
- FailOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws) {}
+ template <class... Args>
+ explicit FailOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return false;
diff --git a/caffe2/operators/utility_ops_cudnn.cc b/caffe2/operators/utility_ops_cudnn.cc
index 7fe4503b2b..c04ad3292c 100644
--- a/caffe2/operators/utility_ops_cudnn.cc
+++ b/caffe2/operators/utility_ops_cudnn.cc
@@ -12,8 +12,10 @@ class CuDNNWeightedSumOp : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
- CuDNNWeightedSumOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
+ template <class... Args>
+ explicit CuDNNWeightedSumOp(Args&&... args)
+ : Operator<CUDAContext>(std::forward<Args>(args)...),
+ cudnn_wrapper_(&context_) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(cudnnCreateOpTensorDescriptor(&add_desc_));
// Both float and at::Half require opTensorCompType to be CUDNN_DATA_FLOAT.
diff --git a/caffe2/operators/variable_length_sequence_padding.h b/caffe2/operators/variable_length_sequence_padding.h
index 9e4c9da83e..f86964d639 100644
--- a/caffe2/operators/variable_length_sequence_padding.h
+++ b/caffe2/operators/variable_length_sequence_padding.h
@@ -29,10 +29,9 @@ void VariableLengthSequencePadding(
template <typename T, typename Context>
class VariableLengthSequencePaddingOp : public Operator<Context> {
public:
- VariableLengthSequencePaddingOp(
- const OperatorDef& operator_def,
- Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit VariableLengthSequencePaddingOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
diff --git a/caffe2/operators/weighted_multi_sampling_op.h b/caffe2/operators/weighted_multi_sampling_op.h
index 968c518ca0..1de0fd1137 100644
--- a/caffe2/operators/weighted_multi_sampling_op.h
+++ b/caffe2/operators/weighted_multi_sampling_op.h
@@ -9,8 +9,9 @@ class WeightedMultiSamplingOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- WeightedMultiSamplingOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit WeightedMultiSamplingOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
num_samples_(
this->template GetSingleArgument<int64_t>("num_samples", 0)) {
CAFFE_ENFORCE_GE(num_samples_, 0);
diff --git a/caffe2/operators/weighted_sample_op.h b/caffe2/operators/weighted_sample_op.h
index 1474a9aee2..361d20ef5b 100644
--- a/caffe2/operators/weighted_sample_op.h
+++ b/caffe2/operators/weighted_sample_op.h
@@ -13,8 +13,9 @@ namespace caffe2 {
template <typename T, class Context>
class WeightedSampleOp final : public Operator<Context> {
public:
- WeightedSampleOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit WeightedSampleOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
diff --git a/caffe2/operators/while_op.h b/caffe2/operators/while_op.h
index 66869ddfc4..445e4e4983 100644
--- a/caffe2/operators/while_op.h
+++ b/caffe2/operators/while_op.h
@@ -10,7 +10,7 @@ namespace caffe2 {
template <class Context>
class WhileOp final : public Operator<Context> {
public:
- WhileOp(const OperatorDef& operator_def, Workspace* ws)
+ explicit WhileOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
CAFFE_ENFORCE(
this->template HasSingleArgumentOfType<NetDef>("loop_net"),
diff --git a/caffe2/operators/workspace_ops.cc b/caffe2/operators/workspace_ops.cc
index e3345f0152..488c7f408b 100644
--- a/caffe2/operators/workspace_ops.cc
+++ b/caffe2/operators/workspace_ops.cc
@@ -6,7 +6,7 @@ namespace {
class GetAllBlobNamesOp final : public Operator<CPUContext> {
public:
- GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
+ explicit GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
include_shared_(GetSingleArgument<int>("include_shared", true)),
ws_(ws) {}