diff options
author | Alan Baker <alanbaker@google.com> | 2018-02-28 15:23:19 -0500 |
---|---|---|
committer | Steven Perron <stevenperron@google.com> | 2018-02-28 23:12:27 -0500 |
commit | ce5941a6425e2b0f8128d02e830a9609b3f18709 (patch) | |
tree | 08cea63b79aeed8efdde615cfbc3cf82a7dab6fd /source/opt | |
parent | bdaf8d56fbe3fa22ee699e33306ffd5f77b7762f (diff) | |
download | SPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.tar.gz SPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.tar.bz2 SPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.zip |
Fixes #1357. Support null constants better in folding
* getFloatConstantKind() now handles OpConstantNull
* PerformOperation() now handles OpConstantNull for vectors
* Fixed some instances where we would attempt to merge a division by 0
* added tests
Diffstat (limited to 'source/opt')
-rw-r--r-- | source/opt/constants.h | 23 | ||||
-rw-r--r-- | source/opt/folding_rules.cpp | 60 |
2 files changed, 62 insertions, 21 deletions
diff --git a/source/opt/constants.h b/source/opt/constants.h index cd3134b1..999dc52c 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -126,6 +126,18 @@ class ScalarConstant : public Constant { // Returns a const reference of the value of this constant in 32-bit words. virtual const std::vector<uint32_t>& words() const { return words_; } + // Returns true if the value is zero. + bool IsZero() const { + bool is_zero = true; + for (uint32_t v : words()) { + if (v != 0) { + is_zero = false; + break; + } + } + return is_zero; + } + protected: ScalarConstant(const Type* ty, const std::vector<uint32_t>& w) : Constant(ty), words_(w) {} @@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant { static_cast<uint64_t>(words()[0]); } - bool IsZero() const { - bool is_zero = true; - for (uint32_t v : words()) { - if (v != 0) { - is_zero = false; - break; - } - } - return is_zero; - } - // Make a copy of this IntConstant instance. std::unique_ptr<IntConstant> CopyIntConstant() const { return MakeUnique<IntConstant>(type_->AsInteger(), words_); diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index f94ba7b0..7e4dddba 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() { const analysis::Constant* negated_const = const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); - } else { + } else if (constants[1]->AsFloatConstant()) { id = Reciprocal(const_mgr, constants[1]); if (id == 0) return false; + } else { + // Don't fold a null constant. + return false; } inst->SetOpcode(SpvOpFMul); inst->SetInOperands( @@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() { }; } +// Returns true if |c| has a zero element. +bool HasZero(const analysis::Constant* c) { + if (c->AsNullConstant()) { + return true; + } + if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { + for (auto& comp : vec_const->GetComponents()) + if (HasZero(comp)) return true; + } else { + assert(c->AsScalarConstant()); + return c->AsScalarConstant()->IsZero(); + } + + return false; +} + // Performs |input1| |opcode| |input2| and returns the merged constant result // id. Returns 0 if the result is not a valid value. The input types must be // Float. @@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, FOLD_OP(*); break; case SpvOpFDiv: + if (HasZero(input2)) return 0; FOLD_OP(/); break; case SpvOpFAdd: @@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, const analysis::Type* ele_type = vector_type->element_type(); for (uint32_t i = 0; i != vector_type->element_count(); ++i) { uint32_t id = 0; - const analysis::Constant* input1_comp = - input1->AsVectorConstant()->GetComponents()[i]; - const analysis::Constant* input2_comp = - input2->AsVectorConstant()->GetComponents()[i]; + + const analysis::Constant* input1_comp = nullptr; + if (const analysis::VectorConstant* input1_vector = + input1->AsVectorConstant()) { + input1_comp = input1_vector->GetComponents()[i]; + } else { + assert(input1->AsNullConstant()); + input1_comp = const_mgr->GetConstant(ele_type, {}); + } + + const analysis::Constant* input2_comp = nullptr; + if (const analysis::VectorConstant* input2_vector = + input2->AsVectorConstant()) { + input2_comp = input2_vector->GetComponents()[i]; + } else { + assert(input2->AsNullConstant()); + input2_comp = const_mgr->GetConstant(ele_type, {}); + } + if (ele_type->AsFloat()) { id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, input2_comp); @@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() { std::vector<const analysis::Constant*> other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); - if (!const_input2) return false; + if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; // If the variable value is the second operand of the divide, multiply @@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() { if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; + if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() { std::vector<const analysis::Constant*> other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); - if (!const_input2) return false; + if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; @@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() { if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; + if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { return FloatConstantKind::Unknown; } - if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) { + assert(HasFloatingPoint(constant->type()) && "Unexpected constant type"); + + if (constant->AsNullConstant()) { + return FloatConstantKind::Zero; + } else if (const analysis::VectorConstant* vc = + constant->AsVectorConstant()) { const std::vector<const analysis::Constant*>& components = vc->GetComponents(); assert(!components.empty()); |