diff options
Diffstat (limited to 'model-optimizer/extensions/ops/TensorArrayScatter.py')
-rw-r--r-- | model-optimizer/extensions/ops/TensorArrayScatter.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/model-optimizer/extensions/ops/TensorArrayScatter.py b/model-optimizer/extensions/ops/TensorArrayScatter.py index 349f78830..cb30e87ec 100644 --- a/model-optimizer/extensions/ops/TensorArrayScatter.py +++ b/model-optimizer/extensions/ops/TensorArrayScatter.py @@ -41,8 +41,12 @@ class TensorArrayScatter(Op): flow_in = node.in_node(3) ta_node = Node(node.graph, str(handle.value)) - if ta_node.has_valid('element_shape'): - assert match_shapes(ta_node['element_shape'], value.shape[1:]) + if ta_node.has_valid('element_shape') and len(ta_node.element_shape) > 0: + assert match_shapes(ta_node['element_shape'], value.shape[1:]), \ + 'Shapes are not compatible: {} and {}'.format(ta_node['element_shape'], value.shape[1:]) + else: + ta_node['element_shape'] = value.shape[1:] + # Assign element_shape anyway, because the original element_shape can contain -1 ta_node['element_shape'] = value.shape[1:] #TODO: add smart check that indices and value.shape[0] is compatible |