summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/ops/TensorArrayScatter.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/ops/TensorArrayScatter.py')
-rw-r--r--model-optimizer/extensions/ops/TensorArrayScatter.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/model-optimizer/extensions/ops/TensorArrayScatter.py b/model-optimizer/extensions/ops/TensorArrayScatter.py
index 349f78830..cb30e87ec 100644
--- a/model-optimizer/extensions/ops/TensorArrayScatter.py
+++ b/model-optimizer/extensions/ops/TensorArrayScatter.py
@@ -41,8 +41,12 @@ class TensorArrayScatter(Op):
flow_in = node.in_node(3)
ta_node = Node(node.graph, str(handle.value))
- if ta_node.has_valid('element_shape'):
- assert match_shapes(ta_node['element_shape'], value.shape[1:])
+ if ta_node.has_valid('element_shape') and len(ta_node.element_shape) > 0:
+ assert match_shapes(ta_node['element_shape'], value.shape[1:]), \
+ 'Shapes are not compatible: {} and {}'.format(ta_node['element_shape'], value.shape[1:])
+ else:
+ ta_node['element_shape'] = value.shape[1:]
+
# Assign element_shape anyway, because the original element_shape can contain -1
ta_node['element_shape'] = value.shape[1:]
#TODO: add smart check that indices and value.shape[0] is compatible