summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp')
-rw-r--r--inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp b/inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp
index 8b70355b3..dd9e99233 100644
--- a/inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp
+++ b/inference-engine/thirdparty/clDNN/api/CPP/lstm.hpp
@@ -126,8 +126,6 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
cldnn_lstm_offset_order offset_order;
// NOT SUPPORTED YET
- // /// @brief Number of directions default = 1, bidirectional = 2.
- // uint32_t num_directions;
// /// @brief Optional tensor specifying lengths of the sequences in a batch.
// /// If not specified - assumed all sequences in the batch to have length `seq_length`. It has shape `[batch_size]`.
// tensor sequence_lens;
@@ -185,6 +183,7 @@ struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_ge
/// @param input recurrent Primitive id containing recurrent data. It is required even for no hidden values.
/// @param input bias Primitive id containing bias data. Provide empty string if using lstm without bias.
/// @param input hidden Primitive id containing hidden data. Provide empty string if using lstm without hidden values.
+ /// @param direction default = 0, bidirectional = 1.
lstm_gemm(
const primitive_id& id,
const primitive_id& input,
@@ -192,6 +191,7 @@ struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_ge
const primitive_id& recurrent,
const primitive_id& bias = "",
const primitive_id& hidden = "",
+ const uint32_t direction = 0,
const padding& output_padding = padding()
)
: primitive_base(id, {input}, output_padding)
@@ -199,6 +199,7 @@ struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_ge
, recurrent(recurrent)
, bias(bias)
, hidden(hidden)
+ , direction(direction)
{
}
@@ -209,6 +210,7 @@ struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_ge
, recurrent(dto->recurrent)
, bias(dto->bias)
, hidden(dto->hidden)
+ , direction(dto->direction)
{
}
@@ -220,8 +222,8 @@ struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_ge
primitive_id bias;
/// @brief Primitive id containing the initial value of the hidden data.
primitive_id hidden;
- /// @brief Number of directions default = 1, bidirectional = 2.
- uint32_t num_directions;
+ /// @brief direction default = 0, bidirectional = 1.
+ uint32_t direction;
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
@@ -242,6 +244,7 @@ protected:
dto.recurrent = recurrent.c_str();
dto.bias = bias.c_str();
dto.hidden = hidden.c_str();
+ dto.direction = direction;
}
};
@@ -258,6 +261,7 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
/// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
/// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
/// @param offset_order. Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
+ /// @param direction default = 0, bidirectional = 1.
lstm_elt(
const primitive_id& id,
const primitive_id& input,
@@ -303,10 +307,6 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
std::vector<cldnn_activation_additional_params> activation_params;
/// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
cldnn_lstm_offset_order offset_order;
-
- // NOT SUPPORTED YET
- // /// @brief Number of directions default = 1, bidirectional = 2.
- // uint32_t num_directions;
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
{