author     Summer Deng <summerdeng@fb.com>  2019-04-06 21:50:28 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-06 21:53:08 -0700
commit     907b4c5890a9a9e5508114ffaf8e898ef69a2cb2 (patch)
tree       3ebadb0143f6c9a30994e7e43bf006fd08f48e9d /caffe2
parent     dbd9971dd2c4fad595592c0dde5cc8f3fc1d54a1 (diff)
fix bug when falling back to acc32 when weight is prepacked (#18974)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18974

When the weight is prepacked and the packed blob doesn't contain a prepacked weight for acc32, we shouldn't fall back to acc32.

Reviewed By: bddppq

Differential Revision: D14814067

fbshipit-source-id: aec917322de695e283f0aca1e930c5603d196404
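For orientation before reading the diff: the heart of the fix is the condition under which the acc16 operator may fall back to acc32 at all. Below is a minimal standalone sketch of that condition; the struct and function names are made up for illustration (the real check lives in ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_ and in ConvDNNLowPPackWeightOp).

// Standalone sketch (no Caffe2 dependencies) of the corrected fallback
// condition; PackedWeightSketch and can_fallback_to_acc32 are illustrative
// stand-ins, not names from the patch.
#include <memory>

struct PackedWeightSketch {      // models Int8ConvDNNLowPPackedWeightBlob
  std::shared_ptr<int> W;        // acc32 packed matrix (may be null)
  std::shared_ptr<int> W_acc16;  // acc16 packed matrix
};

// Falling back to 32-bit accumulation is only possible when the layout is
// NHWC and either the filter arrives unpacked (it can be repacked for acc32)
// or the prepacked blob also carries the acc32 matrix W. Previously an
// acc16-only prepacked blob could still trigger the fallback and then have
// no acc32 weight to run with.
bool can_fallback_to_acc32(bool is_nhwc, const PackedWeightSketch* prepacked) {
  return is_nhwc && (prepacked == nullptr || prepacked->W != nullptr);
}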
Diffstat (limited to 'caffe2')
-rw-r--r--  caffe2/quantization/server/conv_dnnlowp_acc16_op.cc                 | 160
-rw-r--r--  caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py            |  18
-rw-r--r--  caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py  |  22
-rw-r--r--  caffe2/quantization/server/fbgemm_pack_op.cc                        | 135
4 files changed, 213 insertions, 122 deletions
diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
index 454be17480..3bce76003f 100644
--- a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
+++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
@@ -9,6 +9,7 @@
#include <omp.h>
#endif
+#include "caffe2/core/logging.h"
#include "dnnlowp_op.h"
#include "dnnlowp_partition.h"
#include "fbgemm_pack_op.h"
@@ -17,7 +18,6 @@
C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
-
// Thresholds to fallback to 32-bit accumulation when 16-bit accumulation
// doesn't provide performance benefits.
C10_DEFINE_double(
@@ -62,35 +62,8 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
template <bool ReluFused>
bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
if (fallback_to_32_bit_accumulation_) {
- return true;
- }
-
- if (!BaseType::GetQuantizationParameters_()) {
- return false;
- }
-
- if (!Wq_acc16_packed_ &&
- this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
- CAFFE_ENFORCE_EQ(
- this->order_,
- StorageOrder::NHWC,
- "Pre-packed weight only works with NHWC layout");
- // If the input is already packed
- const auto& packed_filter =
- this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
- Wq_outlier_ = packed_filter.W_outlier;
- Wq_acc16_packed_ = packed_filter.W_acc16;
-
- if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
- LOG(WARNING)
- << "nbits_in_non_outlier in packed weight "
- << packed_filter.nbits_in_non_outlier
- << " doesn't match with nbits_in_non_outlier specified in operator "
- << nbits_in_non_outlier_;
- }
-
- first_invocation_ = false;
- return true;
+ // Short cut if we already know we are falling back to acc32
+ return BaseType::GetQuantizationParameters_();
}
int kernel_dim = this->KernelDim_();
@@ -98,7 +71,17 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
int num_out_channels = filter.dim32(0);
// Check if we should fallback to 32-bit accumulation
- if (this->order_ == StorageOrder::NHWC) {
+ // We should do this before GetQuantizationParameters_ to make sure
+ // GetQuantizationParameters_ initialize things like Wq_packed_ for acc32
+ // properly.
+
+ // We can't fallback if layout is not NHWC or
+ // if weight is prepacked and the prepacked weight doesn't have acc32.
+ bool can_fallback_to_32_bit_accumulation =
+ this->order_ == StorageOrder::NHWC &&
+ (!this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) ||
+ this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).W);
+ if (can_fallback_to_32_bit_accumulation) {
const Tensor& X = InputTensorCPU_(INPUT);
int N = X.dim32(0);
@@ -121,31 +104,71 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
}
if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
- LOG(INFO) << "M " << N * output_image_size
- << " of Conv layer with weight blob "
- << this->debug_def().input(1) << " is smaller than threshold "
- << FLAGS_caffe2_dnnlowp_acc16_m_threshold
- << " . Falling back to acc32";
+ C10_LOG_FIRST_N(INFO, 10)
+ << "M " << N * output_image_size << " of Conv layer with weight blob "
+ << this->debug_def().input(FILTER) << " is smaller than threshold "
+ << FLAGS_caffe2_dnnlowp_acc16_m_threshold
+ << " . Falling back to acc32";
+ fallback_to_32_bit_accumulation_ = true;
+ }
+ if (!fallback_to_32_bit_accumulation_ &&
+ num_out_channels / group_ < acc16_n_threshold) {
+ C10_LOG_FIRST_N(INFO, 10)
+ << "N " << num_out_channels / group_
+ << " of Conv layer with weight blob "
+ << this->debug_def().input(FILTER) << " is smaller than threshold "
+ << acc16_n_threshold << " . Falling back to acc32";
fallback_to_32_bit_accumulation_ = true;
- return true;
}
- if (num_out_channels / group_ < acc16_n_threshold) {
- LOG(INFO) << "N " << num_out_channels / group_
- << " of Conv layer with weight blob "
- << this->debug_def().input(1) << " is smaller than threshold "
- << acc16_n_threshold << " . Falling back to acc32";
+ if (!fallback_to_32_bit_accumulation_ && kernel_dim < acc16_k_threshold) {
+ C10_LOG_FIRST_N(INFO, 10)
+ << "K " << kernel_dim << " of Conv layer with weight blob "
+ << this->debug_def().input(FILTER) << " is smaller than threshold "
+ << acc16_k_threshold << " . Falling back to acc32";
fallback_to_32_bit_accumulation_ = true;
- return true;
}
- if (kernel_dim < acc16_k_threshold) {
- LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob "
- << this->debug_def().input(1) << " is smaller than threshold "
- << acc16_k_threshold << " . Falling back to acc32";
+ if (!fallback_to_32_bit_accumulation_ &&
+ this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
+ !this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
+ .W_acc16) {
+ C10_LOG_FIRST_N(INFO, 10)
+ << "Falling back to acc32 because packed weight for acc16 is not "
+ "available";
fallback_to_32_bit_accumulation_ = true;
- return true;
}
}
+ if (!BaseType::GetQuantizationParameters_()) {
+ return false;
+ }
+
+ if (fallback_to_32_bit_accumulation_) {
+ return true;
+ }
+
+ if (!Wq_acc16_packed_ &&
+ this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+ CAFFE_ENFORCE_EQ(
+ this->order_,
+ StorageOrder::NHWC,
+ "Pre-packed weight only works with NHWC layout");
+ // If the input is already packed
+ const auto& packed_filter =
+ this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+ Wq_outlier_ = packed_filter.W_outlier;
+ Wq_acc16_packed_ = packed_filter.W_acc16;
+
+ if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+ C10_LOG_FIRST_N(WARNING, 10)
+ << "nbits_in_non_outlier in packed weight "
+ << packed_filter.nbits_in_non_outlier
+ << " doesn't match with nbits_in_non_outlier specified in operator "
+ << nbits_in_non_outlier_;
+ }
+ first_invocation_ = false;
+ return true;
+ }
+
// Separate out outliers
if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
nbits_in_non_outlier_ < 8) {
@@ -159,20 +182,25 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
W_quantized_));
int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];
- LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
- << this->debug_def().input(1) << " is "
- << static_cast<float>(outlier_cnt) / W_quantized_.size();
- LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
- << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
-
- if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
- FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
- LOG(INFO) << "Density of outliers is higher than threshold "
- << FLAGS_caffe2_dnnlowp_acc16_density_threshold
- << " . Falling back to acc32";
+ C10_LOG_FIRST_N(INFO, 10)
+ << "Proportion of outlier for Conv layer with weight blob "
+ << this->debug_def().input(FILTER) << " is "
+ << static_cast<float>(outlier_cnt) / W_quantized_.size();
+ C10_LOG_FIRST_N(INFO, 10)
+ << "nbits_in_non_outlier " << nbits_in_non_outlier_
+ << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
+
+ if (can_fallback_to_32_bit_accumulation &&
+ static_cast<float>(outlier_cnt) / W_quantized_.size() >
+ FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+ C10_LOG_FIRST_N(INFO, 10)
+ << "Density of outliers is higher than threshold "
+ << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+ << " . Falling back to acc32";
fallback_to_32_bit_accumulation_ = true;
Wq_outlier_.reset();
- return true;
+ // We need to call GetQuantizationParameters_ again to pack for acc32
+ return BaseType::GetQuantizationParameters_();
}
}
@@ -193,8 +221,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
static int log_occurences = 0;
if (log_occurences < 32) {
++log_occurences;
- LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
- << " falls back to slow path because " << reason;
+ C10_LOG_FIRST_N(WARNING, 10)
+ << "Conv with weight " << this->debug_def().input(FILTER)
+ << " falls back to slow path because " << reason;
}
}
}
@@ -202,8 +231,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
static int log_occurences = 0;
if (log_occurences < 32) {
++log_occurences;
- LOG(WARNING) << "Outlier-aware quantization only supports "
- "NHWC layout";
+ C10_LOG_FIRST_N(WARNING, 10)
+ << "Outlier-aware quantization only supports "
+ "NHWC layout";
}
}
first_invocation_ = false;
@@ -359,7 +389,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
static int log_occurences = 0;
if (log_occurences < 32) {
++log_occurences;
- LOG(WARNING)
+ C10_LOG_FIRST_N(WARNING, 10)
<< "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
"we're falling back to a slow path because of NCHW layout";
}
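The hunks above reorder GetQuantizationParameters_() so that every fallback decision is made before the base class runs, and the base call happens exactly once. A condensed standalone sketch of that control flow follows, with the decision inputs flattened into booleans; the function and parameter names are illustrative, not the real method.

// Condensed sketch of the restructured flow (illustrative only).
#include <functional>

bool get_quantization_params_sketch(
    bool can_fallback,                 // NHWC and an acc32 weight is available
    bool below_m_n_or_k_threshold,     // any of the M/N/K threshold checks
    bool acc16_packed_weight_missing,  // prepacked blob has no W_acc16
    bool& fallback,                    // fallback_to_32_bit_accumulation_
    const std::function<bool()>& base  // BaseType::GetQuantizationParameters_
) {
  // 1. Decide on fallback first, so the base call below sets up the acc32
  //    state (e.g. Wq_packed_) when we do fall back.
  if (can_fallback && (below_m_n_or_k_threshold || acc16_packed_weight_missing)) {
    fallback = true;
  }
  // 2. Always run the base quantization/packing step exactly once.
  if (!base()) {
    return false;
  }
  // 3. If we fell back, the acc32 path is now fully initialized.
  if (fallback) {
    return true;
  }
  // 4. Otherwise continue with acc16-specific work: reuse the prepacked
  //    W_acc16/W_outlier, or extract outliers and pack for acc16 here
  //    (the outlier-density check may still trigger a late fallback).
  return true;
}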
diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
index 1da2f3120d..1ddf2ced86 100644
--- a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
+++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
@@ -7,15 +7,19 @@ import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, dyndep, utils, workspace
from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
- check_quantized_results_close,
- generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
from hypothesis import assume, given
dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+ [
+ "caffe2",
+ "--caffe2_omp_num_threads=11",
+ # Increase this threshold to test acc16 with randomly generated data
+ "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+ ]
+)
class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -254,9 +258,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
W_min = -100
W_max = W_min + 255
W = (
- np.random.rand(
- output_channels, kernel, kernel, input_channels_per_group
- )
+ np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
* 4
- 2
+ W_min
diff --git a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
index 44f7aad3b8..d542126c13 100644
--- a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
+++ b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
@@ -7,15 +7,19 @@ import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, dyndep, utils, workspace
from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
- check_quantized_results_close,
- generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
from hypothesis import assume, given
dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+ [
+ "caffe2",
+ "--caffe2_omp_num_threads=11",
+ # Increase this threshold to test acc16 with randomly generated data
+ "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+ ]
+)
class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -224,9 +228,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
W_min = -100
W_max = W_min + 255
W = (
- np.random.rand(
- output_channels, kernel, kernel, input_channels_per_group
- )
+ np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
* 4
- 2
+ W_min
@@ -237,9 +239,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
for g in range(group):
W[g * output_channels_per_group, 0, 0, 0] = W_min
W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
- W[
- g * output_channels_per_group : (g + 1) * output_channels_per_group,
- ] += g
+ W[g * output_channels_per_group : (g + 1) * output_channels_per_group,] += g
if order == "NCHW":
X = utils.NHWC2NCHW(X)
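Both test files raise --caffe2_dnnlowp_acc16_density_threshold to 0.9 so that randomly generated weights, which tend to contain many outliers, still exercise the acc16 path rather than silently falling back to acc32. The check that flag feeds is roughly the following standalone sketch; the helper name is made up for illustration.

// Sketch of the outlier-density fallback check controlled by
// --caffe2_dnnlowp_acc16_density_threshold (illustrative helper).
#include <cstddef>

bool density_triggers_acc32_fallback(
    int outlier_cnt,            // e.g. Wq_outlier_->ColPtr()[num_out_channels]
    std::size_t weight_size,    // e.g. W_quantized_.size()
    double density_threshold) {
  return static_cast<double>(outlier_cnt) / weight_size > density_threshold;
}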
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
index 704d4e1fa9..9e98b0194a 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.cc
+++ b/caffe2/quantization/server/fbgemm_pack_op.cc
@@ -5,6 +5,9 @@
#include "caffe2_dnnlowp_utils.h"
C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
+C10_DECLARE_double(caffe2_dnnlowp_acc16_density_threshold);
+C10_DECLARE_int32(caffe2_dnnlowp_acc16_n_threshold);
+C10_DECLARE_int32(caffe2_dnnlowp_acc16_k_threshold);
namespace caffe2 {
@@ -422,9 +425,44 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
ComputeColumnOffsets(
kernel_dim, M, W_quantized.data(), Y->qparams, *Y->column_offsets);
+ // Check if we should fallback to 32-bit accumulation.
+ // This check is only meaningful when engine is DNNLOWP_ACC16.
+ bool fallback_to_32_bit_accumulation = false;
+ if (nbits_in_non_outlier_ == 0) {
+ LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we "
+ "fallback to acc32";
+ fallback_to_32_bit_accumulation = true;
+ }
+ // In Skylake, acc16 is not faster when N or K is smaller than 128
+ // FIXME : code duplication with conv_dnnlowp_acc16_op.cc
+ constexpr int SKYLAKE_ACC16_N_THRESHOLD_MIN = 128,
+ SKYLAKE_ACC16_K_THRESHOLD_MIN = 128;
+ int acc16_n_threshold = FLAGS_caffe2_dnnlowp_acc16_n_threshold;
+ if (caffe2::GetCpuId().avx512f() &&
+ acc16_n_threshold < SKYLAKE_ACC16_N_THRESHOLD_MIN) {
+ acc16_n_threshold = SKYLAKE_ACC16_N_THRESHOLD_MIN;
+ }
+ int acc16_k_threshold = FLAGS_caffe2_dnnlowp_acc16_k_threshold;
+ if (caffe2::GetCpuId().avx512f() &&
+ acc16_k_threshold < SKYLAKE_ACC16_K_THRESHOLD_MIN) {
+ acc16_k_threshold = SKYLAKE_ACC16_K_THRESHOLD_MIN;
+ }
+ if (!fallback_to_32_bit_accumulation && M / group_ < acc16_n_threshold) {
+ LOG(INFO) << "N " << M / group_ << " of weight blob "
+ << this->debug_def().input(0) << " is smaller than threshold "
+ << acc16_n_threshold << " . Falling back to acc32";
+ fallback_to_32_bit_accumulation = true;
+ }
+ if (!fallback_to_32_bit_accumulation && kernel_dim < acc16_k_threshold) {
+ LOG(INFO) << "K " << kernel_dim << " of weight blob "
+ << this->debug_def().input(0) << " is smaller than threshold "
+ << acc16_k_threshold << " . Falling back to acc32";
+ fallback_to_32_bit_accumulation = true;
+ }
+
// When nbits_in_non_outlier == 0, we fall back to acc32
if (this->debug_def().engine() == "DNNLOWP_ACC16" &&
- nbits_in_non_outlier_ > 0) {
+ !fallback_to_32_bit_accumulation) {
if (nbits_in_non_outlier_ < 8) {
Y->W_outlier.reset(ExtractOutlierMatrix(
group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized));
@@ -434,45 +472,66 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
<< this->debug_def().input(0) << " is "
<< static_cast<float>(outlier_cnt) / W_quantized.size();
LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_;
+
+ if (static_cast<float>(outlier_cnt) / W_quantized.size() >
+ FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+ LOG(INFO) << "Density of outliers is higher than threshold "
+ << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+ << " . Falling back to acc32";
+ fallback_to_32_bit_accumulation = true;
+ }
}
- Y->nbits_in_non_outlier = nbits_in_non_outlier_;
- Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
- fbgemm::matrix_op_t::Transpose,
- group_ * kernel_dim,
- M / group_,
- W_quantized.data(),
- kernel_dim,
- nullptr, // pmat
- group_));
- } else if (TakeDepthWise3x3FastPath_()) {
- Y->W_depthwise_3x3.reset(
- new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
- } else if (TakeDepthWise3x3x3FastPath_()) {
- Y->W_depthwise_3x3x3.reset(
- new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
- } else if (TakeGConvFastPath_()) {
- fbgemm::conv_param_t<> conv_p(
- 1,
- group_ * C_per_group,
- M,
- {1, 1},
- group_,
- {this->kernel_[0], this->kernel_[1]},
- {this->stride_[0], this->stride_[1]},
- {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
-
- Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
- fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
- } else {
- Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
- fbgemm::matrix_op_t::Transpose,
- group_ * kernel_dim,
- M / group_,
- W_quantized.data(),
- kernel_dim,
- nullptr, // pmat
- group_));
+ if (!fallback_to_32_bit_accumulation) {
+ Y->nbits_in_non_outlier = nbits_in_non_outlier_;
+ Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+ fbgemm::matrix_op_t::Transpose,
+ group_ * kernel_dim,
+ M / group_,
+ W_quantized.data(),
+ kernel_dim,
+ nullptr, // pmat
+ group_));
+ }
+ }
+
+ if (fallback_to_32_bit_accumulation) {
+ Y->W_acc16.reset();
+ Y->W_outlier.reset();
+ }
+
+ if (this->debug_def().engine() != "DNNLOWP_ACC16" ||
+ fallback_to_32_bit_accumulation) {
+ // acc32
+ if (TakeDepthWise3x3FastPath_()) {
+ Y->W_depthwise_3x3.reset(
+ new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
+ } else if (TakeDepthWise3x3x3FastPath_()) {
+ Y->W_depthwise_3x3x3.reset(
+ new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
+ } else if (TakeGConvFastPath_()) {
+ fbgemm::conv_param_t<> conv_p(
+ 1,
+ group_ * C_per_group,
+ M,
+ {1, 1},
+ group_,
+ {this->kernel_[0], this->kernel_[1]},
+ {this->stride_[0], this->stride_[1]},
+ {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+ Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
+ fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
+ } else {
+ Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
+ fbgemm::matrix_op_t::Transpose,
+ group_ * kernel_dim,
+ M / group_,
+ W_quantized.data(),
+ kernel_dim,
+ nullptr, // pmat
+ group_));
+ }
}
if (InputSize() >= 2) {
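For completeness, the packing operator above now makes the same N/K decision (plus the nbits_in_non_outlier == 0 case) before choosing which matrices to pack, bumping the thresholds to 128 on AVX-512 (Skylake-class) machines where acc16 isn't profitable below that size. The standalone sketch below condenses that decision; the struct and function names are illustrative, not from the patch.

// Standalone sketch of the pack-time fallback decision in
// ConvDNNLowPPackWeightOp (names below are illustrative).
#include <algorithm>

struct PackDecision {
  bool pack_acc16;  // pack W_acc16 (and outliers) for the DNNLOWP_ACC16 engine
  bool pack_acc32;  // pack the regular acc32 matrix (W or a fast-path variant)
};

PackDecision decide_packing(
    bool engine_is_acc16,
    bool has_avx512,           // Skylake-class CPU
    int nbits_in_non_outlier,
    int n_per_group,           // M / group_
    int kernel_dim,
    int n_threshold,           // FLAGS_caffe2_dnnlowp_acc16_n_threshold
    int k_threshold) {         // FLAGS_caffe2_dnnlowp_acc16_k_threshold
  // On Skylake, acc16 is not faster when N or K is below 128, so raise the
  // thresholds to at least that value (mirrors the constants in the diff).
  if (has_avx512) {
    n_threshold = std::max(n_threshold, 128);
    k_threshold = std::max(k_threshold, 128);
  }
  const bool fallback =
      nbits_in_non_outlier == 0 ||   // everything is an outlier
      n_per_group < n_threshold ||
      kernel_dim < k_threshold;
  // The density check runs later, after outliers are extracted; if it also
  // trips, the acc16 matrices are reset and only the acc32 matrix is packed.
  return PackDecision{engine_is_acc16 && !fallback,
                      !engine_is_acc16 || fallback};
}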