summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h')
-rw-r--r--inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h723
1 files changed, 76 insertions, 647 deletions
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h
index 39a379855..d4351f2b4 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_params.h
@@ -17,9 +17,8 @@
#pragma once
#include <string>
-#include <cstddef>
#include <memory>
-#include <map>
+#include <cstddef>
#include "common_types.h"
#include "common_tools.h"
#include "tensor_type.h"
@@ -90,6 +89,13 @@ namespace kernel_selector
uint32_t axisBatch : 1;
uint32_t axisXYF : 1;
} argm;
+ struct idxsel_t
+ {
+ uint32_t axisX : 1;
+ uint32_t axisY : 1;
+ uint32_t axisFeature : 1;
+ uint32_t axisBatch : 1;
+ } idxsel;
struct norm_t
{
uint32_t across : 1;
@@ -110,11 +116,12 @@ namespace kernel_selector
uint32_t floor : 1;
uint32_t max_with_argmax : 1;
uint32_t ceil : 1;
+ uint32_t bilinear : 1;
uint32_t fixedKenrelDivider : 1;
uint32_t dynamicKenrelDivider : 1;
uint32_t dynamicKenrelDividerWithPadding : 1;
} pooling;
- struct conv_t
+ struct conv_t
{
uint32_t split : 1;
uint32_t dilation : 1;
@@ -124,7 +131,7 @@ namespace kernel_selector
uint32_t calibration : 1;
} conv;
struct fc_t {} fc;
- struct softmax_t
+ struct softmax_t
{
uint32_t dimX : 1;
uint32_t dimY : 1;
@@ -198,6 +205,7 @@ namespace kernel_selector
uint32_t uint16 : 1;
uint32_t int32 : 1;
uint32_t uint32 : 1;
+ uint32_t int64 : 1;
uint32_t F16 : 1;
uint32_t F32 : 1;
} val;
@@ -215,559 +223,71 @@ namespace kernel_selector
uint32_t weightsOutputLayout;
};
- void EnableInputDataType(Datatype dt)
- {
- switch (dt)
- {
- case Datatype::INT8:
- key.inputType.val.int8 = 1;
- break;
- case Datatype::UINT8:
- key.inputType.val.uint8 = 1;
- break;
- case Datatype::INT16:
- key.inputType.val.int16 = 1;
- break;
- case Datatype::UINT16:
- key.inputType.val.uint16 = 1;
- break;
- case Datatype::INT32:
- key.inputType.val.int32 = 1;
- break;
- case Datatype::UINT32:
- key.inputType.val.uint32 = 1;
- break;
- case Datatype::F16:
- key.inputType.val.F16 = 1;
- break;
- case Datatype::F32:
- key.inputType.val.F32 = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableAllInputDataType()
- {
- key.inputType.raw = 0xffffffff;
- }
-
- void EnableOutputDataType(Datatype dt)
- {
- switch (dt)
- {
- case Datatype::INT8:
- key.outputType.val.int8 = 1;
- break;
- case Datatype::UINT8:
- key.outputType.val.uint8 = 1;
- break;
- case Datatype::INT16:
- key.outputType.val.int16 = 1;
- break;
- case Datatype::UINT16:
- key.outputType.val.uint16 = 1;
- break;
- case Datatype::INT32:
- key.outputType.val.int32 = 1;
- break;
- case Datatype::UINT32:
- key.outputType.val.uint32 = 1;
- break;
- case Datatype::F16:
- key.outputType.val.F16 = 1;
- break;
- case Datatype::F32:
- key.outputType.val.F32 = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableAllOutputDataType()
- {
- key.outputType.raw = 0xffffffff;
- }
-
- void EnableInputWeightsType(WeightsType wt)
- {
- switch (wt)
- {
- case WeightsType::F16:
- key.inputWeightsType.val.F16 = 1;
- break;
- case WeightsType::F32:
- key.inputWeightsType.val.F32 = 1;
- break;
- case WeightsType::INT8:
- key.inputWeightsType.val.int8 = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableAllInputWeightsType()
- {
- key.inputWeightsType.raw = 0xffffffff;
- }
-
- void EnableOutputWeightsType(WeightsType wt)
- {
- switch (wt)
- {
- case WeightsType::F16:
- key.outputWeightsType.val.F16 = 1;
- break;
- case WeightsType::F32:
- key.outputWeightsType.val.F32 = 1;
- break;
- case WeightsType::INT8:
- key.outputWeightsType.val.int8 = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableAllOutputWeightsType()
- {
- key.outputWeightsType.raw = 0xffffffff;
- }
-
- void EnableFP16Emulation()
- {
- key.restrict.val.FP16Emulation = 1;
- }
-
- void EnableDifferentTypes()
- {
- key.restrict.val.different_types = 1;
- }
-
- void EnableInputLayout(DataLayout l)
- {
- key.inputLayout |= (1 << l);
- }
-
- void EnableAllInputLayout()
- {
- key.inputLayout = 0xffffffff;
- }
-
- void EnableOutputLayout(DataLayout l)
- {
- key.outputLayout |= (1 << l);
- }
-
- void EnableAllOutputLayout()
- {
- key.outputLayout = 0xffffffff;
- }
-
- void EnableInputWeightsLayout(WeightsLayout l)
- {
- key.weightsInputLayout |= (1 << l);
- }
-
- void EnableAllInputWeightsLayout()
- {
- key.weightsInputLayout = 0xffffffff;
- }
-
- void EnableOutputWeightsLayout(WeightsLayout l)
- {
- key.weightsOutputLayout |= (1 << l);
- }
-
- void EnableAllOutputWeightsLayout()
- {
- key.weightsOutputLayout = 0xffffffff;
- }
-
- void EnableTensorOffset()
- {
- key.restrict.val.offset = 1;
- }
-
- void EnableTensorPitches()
- {
- key.restrict.val.pitches = 1;
- }
-
- void EnableBatching()
- {
- key.restrict.val.batching = 1;
- }
-
- void EnableGradient()
- {
- key.restrict.val.gradient = 1;
- }
-
- void EnableSubGroup()
- {
- key.machineInfo.val.subgroup = 1;
- }
-
- void EnableSubGroupShort()
- {
- key.machineInfo.val.subgroupShort = 1;
- }
-
- void EnableNonBiasTerm()
- {
- key.restrict.val.nonBias = 1;
- }
-
- void EnableBiasPerFeature()
- {
- key.restrict.val.biasPerFeatureMap = 1;
- }
-
- void EnableBiasPerOutput()
- {
- key.restrict.val.biasPerOutput = 1;
- }
-
- void EnableActivationAdditionalParamsAsInput()
- {
- key.restrict.val.activationAdditionalParamsAsInput = 1;
- }
-
- void EnableMomentum()
- {
- key.restrict.val.momentum = 1;
- }
-
- void EnableLRNMode(LRNMode m)
- {
- switch (m)
- {
- case LRNMode::ACROSS_CHANNEL:
- key.restrict.val.dedicated.norm.across = 1;
- break;
- case LRNMode::WITHIN_CHANNEL:
- key.restrict.val.dedicated.norm.within = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableLookUpTableAxis(LookUpTableAxis m)
- {
- switch (m)
- {
- case kernel_selector::LookUpTableAxis::BATCH:
- key.restrict.val.dedicated.lookt.axisBatch = 1;
- break;
- case kernel_selector::LookUpTableAxis::FEATURE:
- key.restrict.val.dedicated.lookt.axisFeature = 1;
- break;
- case kernel_selector::LookUpTableAxis::X:
- key.restrict.val.dedicated.lookt.axisX = 1;
- break;
- case kernel_selector::LookUpTableAxis::Y:
- key.restrict.val.dedicated.lookt.axisY = 1;
- break;
- case kernel_selector::LookUpTableAxis::XYF:
- key.restrict.val.dedicated.lookt.axisXYF = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableNormalizeMode(NormalizeMode m)
- {
- switch (m)
- {
- case NormalizeMode::ACROSS_SPATIAL:
- key.restrict.val.dedicated.norm.across = 1;
- break;
- case NormalizeMode::WITHIN_SPATIAL:
- key.restrict.val.dedicated.norm.within = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableMVNMode(MVNMode m)
- {
- switch (m)
- {
- case MVNMode::ACROSS_CHANNELS:
- key.restrict.val.dedicated.mvn.across = 1;
- break;
- case MVNMode::WITHIN_CHANNELS:
- key.restrict.val.dedicated.mvn.within = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableMVNNormalizeVariance()
- {
- key.restrict.val.dedicated.mvn.normalize_variance = 1;
- }
-
- void EnableLRNKernelDividerMode(KernelDividerMode m)
- {
- switch (m)
- {
- case KernelDividerMode::FIXED:
- key.restrict.val.dedicated.norm.fixedKenrelDivider = 1;
- break;
- case KernelDividerMode::DYNAMIC:
- key.restrict.val.dedicated.norm.dynamicKenrelDivider = 1;
- break;
- default:
- break;
- }
- }
-
- void EnablePoolKernelDividerMode(KernelDividerMode m)
- {
- switch (m)
- {
- case KernelDividerMode::FIXED:
- key.restrict.val.dedicated.pooling.fixedKenrelDivider = 1;
- break;
- case KernelDividerMode::DYNAMIC:
- key.restrict.val.dedicated.pooling.dynamicKenrelDivider = 1;
- break;
- case KernelDividerMode::DYNAMIC_WITH_PADDING:
- key.restrict.val.dedicated.pooling.dynamicKenrelDividerWithPadding = 1;
- break;
- default:
- break;
- }
- }
-
- void EnablePoolType(PoolType t)
- {
- switch (t)
- {
- case PoolType::MAX:
- key.restrict.val.dedicated.pooling.max = 1;
- break;
- case PoolType::AVG:
- key.restrict.val.dedicated.pooling.avg = 1;
- break;
- case PoolType::MAX_WITH_ARGMAX:
- key.restrict.val.dedicated.pooling.max_with_argmax = 1;
- break;
- default:
- break;
- }
- }
-
- void EnablePoolRemainder(PoolRemainder r)
- {
- switch (r)
- {
- case PoolRemainder::FLOOR:
- key.restrict.val.dedicated.pooling.floor = 1;
- break;
- case PoolRemainder::CEIL:
- key.restrict.val.dedicated.pooling.ceil = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableSplitSupport()
- {
- key.restrict.val.dedicated.conv.split = 1;
- }
-
- void EnableDilation()
- {
- key.restrict.val.dedicated.conv.dilation = 1;
- }
-
- void EnableDepthwiseSeparableOpt()
- {
- key.restrict.val.dedicated.conv.depthwiseSeparableOpt = 1;
- }
-
- void EnableTranspose()
- {
- key.restrict.val.dedicated.conv.transposed = 1;
- }
-
- void EnableInt8Quantization()
- {
- key.restrict.val.dedicated.conv.quantization = 1;
- }
-
- void EnableOutputCalibration()
- {
- key.restrict.val.dedicated.conv.calibration = 1;
- }
-
- void EnableWinogradReorder()
- {
- key.restrict.val.dedicated.reorder.winograd = 1;
- }
-
- void EnableSoftmaxDim(SoftmaxDim d)
- {
- switch (d)
- {
- case SoftmaxDim::X:
- key.restrict.val.dedicated.softmax.dimX = 1;
- break;
- case SoftmaxDim::Y:
- key.restrict.val.dedicated.softmax.dimY = 1;
- break;
- case SoftmaxDim::FEATURE:
- key.restrict.val.dedicated.softmax.dimFeature = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableConcatAxis(ConcatAxis a)
- {
- switch (a)
- {
- case ConcatAxis::X:
- key.restrict.val.dedicated.concat.axisX = 1;
- break;
- case ConcatAxis::Y:
- key.restrict.val.dedicated.concat.axisY = 1;
- break;
- case ConcatAxis::FEATURE:
- key.restrict.val.dedicated.concat.axisFeature = 1;
- break;
- case ConcatAxis::BATCH:
- key.restrict.val.dedicated.concat.axisBatch = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableUpSamplingSampleType(SampleType a)
- {
- switch (a)
- {
- case SampleType::NEAREST:
- key.restrict.val.dedicated.upsample.nearest = 1;
- break;
- case SampleType::BILINEAR:
- key.restrict.val.dedicated.upsample.bilinear = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableLSTMGEMMBias() {
- key.restrict.val.dedicated.lstm_gemm.bias = 1;
- }
-
- void EnableLSTMGEMMHidden() {
- key.restrict.val.dedicated.lstm_gemm.hidden = 1;
- }
-
- void EnableLSTMEltCell() {
- key.restrict.val.dedicated.lstm_elt.cell = 1;
- }
-
-
- void EnableConcatKernelPerInput()
- {
- key.restrict.val.dedicated.concat.kernelPerInput = 1;
- }
-
- void DisableTuning()
- {
- key.enableTuning = 0;
- }
-
- void EnableConcatOneKernel()
- {
- key.restrict.val.dedicated.concat.oneKernel = 1;
- }
-
- void EnableArgMaxMinAxis(ArgMaxMinAxis a)
- {
- switch (a)
- {
- case ArgMaxMinAxis::X:
- key.restrict.val.dedicated.argm.axisX = 1;
- break;
- case ArgMaxMinAxis::Y:
- key.restrict.val.dedicated.argm.axisY = 1;
- break;
- case ArgMaxMinAxis::FEATURE:
- key.restrict.val.dedicated.argm.axisFeature = 1;
- break;
- case ArgMaxMinAxis::BATCH:
- key.restrict.val.dedicated.argm.axisBatch = 1;
- break;
- case ArgMaxMinAxis::XYF:
- key.restrict.val.dedicated.argm.axisXYF = 1;
- break;
- default:
- break;
- }
- }
-
- void EnableLookUpTableIndicesFormat(Datatype a)
- {
- if (a == Datatype::F32)
- key.restrict.val.dedicated.lookt.indicesF32 = 1;
- else
- key.restrict.val.dedicated.lookt.indicesOther = 1;
- }
-
- bool Support(const ParamsKey& k) const
- {
- return
- ((key.restrict.raw & k.key.restrict.raw) == k.key.restrict.raw) && // check if this kernel supports this params
- ((key.machineInfo.raw & k.key.machineInfo.raw) == key.machineInfo.raw) && // check if machine supports this kernel
- ((key.inputType.raw & k.key.inputType.raw) == k.key.inputType.raw) &&
- ((key.outputType.raw & k.key.outputType.raw) == k.key.outputType.raw) &&
- ((key.inputWeightsType.raw & k.key.inputWeightsType.raw) == k.key.inputWeightsType.raw) &&
- ((key.outputWeightsType.raw & k.key.outputWeightsType.raw) == k.key.outputWeightsType.raw) &&
- ((key.inputLayout & k.key.inputLayout) != 0 || key.inputLayout == k.key.inputLayout) &&
- ((key.outputLayout & k.key.outputLayout) != 0 || key.outputLayout == k.key.outputLayout) &&
- ((key.weightsInputLayout & k.key.weightsInputLayout) != 0 || key.weightsInputLayout == k.key.weightsInputLayout) &&
- ((key.weightsOutputLayout & k.key.weightsOutputLayout) != 0 || key.weightsOutputLayout == k.key.weightsOutputLayout);
- }
-
+ void EnableInputDataType(Datatype dt);
+ void EnableAllInputDataType();
+ void EnableOutputDataType(Datatype dt);
+ void EnableAllOutputDataType();
+ void EnableInputWeightsType(WeightsType wt);
+ void EnableAllInputWeightsType();
+ void EnableOutputWeightsType(WeightsType wt);
+ void EnableAllOutputWeightsType();
+ void EnableFP16Emulation() { key.restrict.val.FP16Emulation = 1; }
+ void EnableDifferentTypes() { key.restrict.val.different_types = 1; }
+ void EnableInputLayout(DataLayout l) { key.inputLayout |= (1 << l); }
+ void EnableAllInputLayout() { key.inputLayout = 0xffffffff; }
+ void EnableOutputLayout(DataLayout l) { key.outputLayout |= (1 << l); }
+ void EnableAllOutputLayout() { key.outputLayout = 0xffffffff; }
+ void EnableInputWeightsLayout(WeightsLayout l) { key.weightsInputLayout |= (1 << l); }
+ void EnableAllInputWeightsLayout() { key.weightsInputLayout = 0xffffffff; }
+ void EnableOutputWeightsLayout(WeightsLayout l) { key.weightsOutputLayout |= (1 << l); }
+ void EnableAllOutputWeightsLayout() { key.weightsOutputLayout = 0xffffffff; }
+ void EnableTensorOffset() { key.restrict.val.offset = 1; }
+ void EnableTensorPitches() { key.restrict.val.pitches = 1; }
+ void EnableBatching() { key.restrict.val.batching = 1; }
+ void EnableGradient() { key.restrict.val.gradient = 1; }
+ void EnableSubGroup() { key.machineInfo.val.subgroup = 1; }
+ void EnableSubGroupShort() { key.machineInfo.val.subgroupShort = 1; }
+ void EnableNonBiasTerm() { key.restrict.val.nonBias = 1; }
+ void EnableBiasPerFeature() { key.restrict.val.biasPerFeatureMap = 1; }
+ void EnableBiasPerOutput() { key.restrict.val.biasPerOutput = 1; }
+ void EnableActivationAdditionalParamsAsInput() { key.restrict.val.activationAdditionalParamsAsInput = 1; }
+ void EnableMomentum() { key.restrict.val.momentum = 1; }
+ void EnableLRNMode(LRNMode m);
+ void EnableLookUpTableAxis(LookUpTableAxis m);
+ void EnableNormalizeMode(NormalizeMode m);
+ void EnableMVNMode(MVNMode m);
+ void EnableMVNNormalizeVariance();
+ void EnableLRNKernelDividerMode(KernelDividerMode m);
+ void EnablePoolKernelDividerMode(KernelDividerMode m);
+ void EnablePoolType(PoolType t);
+ void EnablePoolRemainder(PoolRemainder r);
+ void EnableSplitSupport() { key.restrict.val.dedicated.conv.split = 1; }
+ void EnableDilation() { key.restrict.val.dedicated.conv.dilation = 1; }
+ void EnableDepthwiseSeparableOpt() { key.restrict.val.dedicated.conv.depthwiseSeparableOpt = 1; }
+ void EnableTranspose() { key.restrict.val.dedicated.conv.transposed = 1; }
+ void EnableInt8Quantization() { key.restrict.val.dedicated.conv.quantization = 1; }
+ void EnableOutputCalibration() { key.restrict.val.dedicated.conv.calibration = 1; }
+ void EnableWinogradReorder() { key.restrict.val.dedicated.reorder.winograd = 1; }
+ void EnableSoftmaxDim(SoftmaxDim d);
+ void EnableConcatAxis(ConcatAxis a);
+ void EnableUpSamplingSampleType(SampleType a);
+ void EnableLSTMGEMMBias() { key.restrict.val.dedicated.lstm_gemm.bias = 1; }
+ void EnableLSTMGEMMHidden() { key.restrict.val.dedicated.lstm_gemm.hidden = 1; }
+ void EnableLSTMEltCell() { key.restrict.val.dedicated.lstm_elt.cell = 1; }
+ void EnableConcatKernelPerInput() { key.restrict.val.dedicated.concat.kernelPerInput = 1; }
+ void DisableTuning() { key.enableTuning = 0; }
+ void EnableConcatOneKernel() { key.restrict.val.dedicated.concat.oneKernel = 1; }
+ void EnableArgMaxMinAxis(ArgMaxMinAxis a);
+ void EnableLookUpTableIndicesFormat(Datatype a);
+ void EnableIndexSelectAxis(IndexSelectAxis a);
+ bool Support(const ParamsKey& k) const;
bool TuningSupport() const
{
if (key.enableTuning == 1)
return true;
return false;
}
-
- ParamsKey Merge(const ParamsKey& k) const
- {
- ParamsKey ret;
- ret.key.restrict.raw = key.restrict.raw | k.key.restrict.raw;
- ret.key.machineInfo.raw = key.machineInfo.raw | k.key.machineInfo.raw;
- ret.key.inputType.raw = key.inputType.raw | k.key.inputType.raw;
- ret.key.outputType.raw = key.outputType.raw | k.key.outputType.raw;
- ret.key.inputWeightsType.raw = key.inputWeightsType.raw | k.key.inputWeightsType.raw;
- ret.key.outputWeightsType.raw = key.outputWeightsType.raw | k.key.outputWeightsType.raw;
- ret.key.inputLayout = key.inputLayout | k.key.inputLayout;
- ret.key.outputLayout = key.outputLayout | k.key.outputLayout;
- ret.key.weightsInputLayout = key.weightsInputLayout | k.key.weightsInputLayout;
- ret.key.weightsOutputLayout = key.weightsOutputLayout | k.key.weightsOutputLayout;
- return ret;
- }
+ ParamsKey Merge(const ParamsKey& k) const;
private:
Key key;
@@ -783,6 +303,8 @@ namespace kernel_selector
bool bFP16Support = false;
bool bFP64Support = false;
bool bImageSupport = false;
+ bool bIMADSupport = false;
+ bool bIMMADSupport = false;
uint64_t maxWorkGroupSize = 0;
uint64_t maxLocalMemSize = 0;
uint64_t maxImage2dWidth = 0;
@@ -800,22 +322,7 @@ namespace kernel_selector
virtual ~Params() {}
KernelType GetType() const { return kType; }
- virtual ParamsKey GetParamsKey() const
- {
- ParamsKey k;
-
- if (engineInfo.bSubGroupSupport)
- {
- k.EnableSubGroup();
- }
-
- if (engineInfo.bSubGroupShortSupport)
- {
- k.EnableSubGroupShort();
- }
-
- return k;
- }
+ virtual ParamsKey GetParamsKey() const;
protected:
Params(KernelType kt, const std::string& id) : kType(kt), layerID(id) {}
@@ -842,69 +349,7 @@ namespace kernel_selector
bool gradient = false;
virtual std::string to_string() const;
-
- virtual ParamsKey GetParamsKey() const
- {
- ParamsKey k = Params::GetParamsKey();
-
- bool bBatching = false;
- bool bPitches = false;
- bool bOffests = false;
- bool bDifferentTypes = false;
- bool bFP16Used = (output.GetDType() == Datatype::F16);
-
- for (const auto& i : inputs)
- {
- k.EnableInputDataType(i.GetDType());
- k.EnableInputLayout(i.GetLayout());
-
- bBatching |= (i.Batch().v > 1);
- bPitches |= (i.PitchesDifferFromLogicalDims());
- bOffests |= (i.GetFirstElementOffset() != 0);
- bDifferentTypes |= (i.GetDType() != output.GetDType());
- bFP16Used |= (i.GetDType() == Datatype::F16);
- }
-
- k.EnableOutputDataType(output.GetDType());
- k.EnableOutputLayout(output.GetLayout());
-
- if (bBatching)
- {
- k.EnableBatching();
- }
-
- if (bPitches ||
- output.PitchesDifferFromLogicalDims())
- {
- k.EnableTensorPitches();
- }
-
- if (bDifferentTypes)
- {
- k.EnableDifferentTypes();
- }
-
- if (bOffests ||
- output.GetFirstElementOffset() != 0)
- {
- k.EnableTensorOffset();
- }
-
- if (!engineInfo.bFP16Support &&
- bFP16Used)
- {
- // I'm not sure it's the best idea, but we can live with it right now
- k.EnableFP16Emulation();
- }
-
- if (gradient)
- {
- k.EnableGradient();
- }
-
- return k;
- }
-
+ virtual ParamsKey GetParamsKey() const;
protected:
base_params(KernelType kt) : Params(kt, ""), inputs(1){}
@@ -942,23 +387,7 @@ namespace kernel_selector
TuningParams tuningParams;
- virtual ParamsKey GetSupportedKey() const
- {
- ParamsKey k;
-
- for (auto l : inputLayouts)
- {
- k.EnableInputLayout(l);
- }
-
- for (auto l : outputLayouts)
- {
- k.EnableOutputLayout(l);
- }
-
- return k;
- }
-
+ virtual ParamsKey GetSupportedKey() const;
protected:
optional_params(KernelType kt) : kType(kt) {}
KernelType kType;