summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2019-04-01 13:02:02 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-01 13:09:07 -0700
commit89e9b1cf8e869799b14e5673fa0631982059c671 (patch)
treed25eb3d1f970162725e74e2247962007d55e3534 /caffe2
parent90a5c569884ed7ef7d116daba2bdf999fa6d4a36 (diff)
downloadpytorch-89e9b1cf8e869799b14e5673fa0631982059c671.tar.gz
pytorch-89e9b1cf8e869799b14e5673fa0631982059c671.tar.bz2
pytorch-89e9b1cf8e869799b14e5673fa0631982059c671.zip
add ConvRelu schema (#18693)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18693 As title Reviewed By: protonu Differential Revision: D14662880 fbshipit-source-id: 3664faa660a04e1f528a413d2a1700b872c3c684
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/quantization/server/conv_dnnlowp_op.cc3
-rw-r--r--caffe2/quantization/server/conv_relu_op.cc7
2 files changed, 7 insertions, 3 deletions
diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc
index bb2aeeef21..d585af57d4 100644
--- a/caffe2/quantization/server/conv_dnnlowp_op.cc
+++ b/caffe2/quantization/server/conv_dnnlowp_op.cc
@@ -1534,9 +1534,6 @@ template class ConvDNNLowPOp<uint8_t, true>;
template class ConvDNNLowPOp<uint16_t, false>;
template class ConvDNNLowPOp<uint16_t, true>;
-OPERATOR_SCHEMA(ConvRelu).NumInputs(2, 3).NumOutputs(1).TensorInferenceFunction(
- ConvPoolOpBase<CPUContext>::TensorInferenceForConv);
-
REGISTER_CPU_OPERATOR_WITH_ENGINE(Conv, DNNLOWP, ConvDNNLowPOp<uint8_t, false>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
ConvRelu,
diff --git a/caffe2/quantization/server/conv_relu_op.cc b/caffe2/quantization/server/conv_relu_op.cc
index 66683894ea..f3511a8d99 100644
--- a/caffe2/quantization/server/conv_relu_op.cc
+++ b/caffe2/quantization/server/conv_relu_op.cc
@@ -64,6 +64,13 @@ bool ConvReluOp<T, Context>::RunOnDeviceWithOrderNHWC() {
return true;
}
+OPERATOR_SCHEMA(ConvRelu)
+ .NumInputs(2, 3)
+ .NumOutputs(1)
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
+ .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
+ ConvPoolOpBase<CPUContext>::CostInferenceForConv));
+
REGISTER_CPU_OPERATOR(ConvRelu, ConvReluOp<float, CPUContext>);
} // namespace caffe2