summaryrefslogtreecommitdiff
path: root/source/opt
diff options
context:
space:
mode:
authorAlan Baker <alanbaker@google.com>2018-02-28 15:23:19 -0500
committerSteven Perron <stevenperron@google.com>2018-02-28 23:12:27 -0500
commitce5941a6425e2b0f8128d02e830a9609b3f18709 (patch)
tree08cea63b79aeed8efdde615cfbc3cf82a7dab6fd /source/opt
parentbdaf8d56fbe3fa22ee699e33306ffd5f77b7762f (diff)
downloadSPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.tar.gz
SPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.tar.bz2
SPIRV-Tools-ce5941a6425e2b0f8128d02e830a9609b3f18709.zip
Fixes #1357. Support null constants better in folding
* getFloatConstantKind() now handles OpConstantNull * PerformOperation() now handles OpConstantNull for vectors * Fixed some instances where we would attempt to merge a division by 0 * added tests
Diffstat (limited to 'source/opt')
-rw-r--r--source/opt/constants.h23
-rw-r--r--source/opt/folding_rules.cpp60
2 files changed, 62 insertions, 21 deletions
diff --git a/source/opt/constants.h b/source/opt/constants.h
index cd3134b1..999dc52c 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -126,6 +126,18 @@ class ScalarConstant : public Constant {
// Returns a const reference of the value of this constant in 32-bit words.
virtual const std::vector<uint32_t>& words() const { return words_; }
+ // Returns true if the value is zero.
+ bool IsZero() const {
+ bool is_zero = true;
+ for (uint32_t v : words()) {
+ if (v != 0) {
+ is_zero = false;
+ break;
+ }
+ }
+ return is_zero;
+ }
+
protected:
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
: Constant(ty), words_(w) {}
@@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant {
static_cast<uint64_t>(words()[0]);
}
- bool IsZero() const {
- bool is_zero = true;
- for (uint32_t v : words()) {
- if (v != 0) {
- is_zero = false;
- break;
- }
- }
- return is_zero;
- }
-
// Make a copy of this IntConstant instance.
std::unique_ptr<IntConstant> CopyIntConstant() const {
return MakeUnique<IntConstant>(type_->AsInteger(), words_);
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index f94ba7b0..7e4dddba 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() {
const analysis::Constant* negated_const =
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
- } else {
+ } else if (constants[1]->AsFloatConstant()) {
id = Reciprocal(const_mgr, constants[1]);
if (id == 0) return false;
+ } else {
+ // Don't fold a null constant.
+ return false;
}
inst->SetOpcode(SpvOpFMul);
inst->SetInOperands(
@@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() {
};
}
+// Returns true if |c| has a zero element.
+bool HasZero(const analysis::Constant* c) {
+ if (c->AsNullConstant()) {
+ return true;
+ }
+ if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
+ for (auto& comp : vec_const->GetComponents())
+ if (HasZero(comp)) return true;
+ } else {
+ assert(c->AsScalarConstant());
+ return c->AsScalarConstant()->IsZero();
+ }
+
+ return false;
+}
+
// Performs |input1| |opcode| |input2| and returns the merged constant result
// id. Returns 0 if the result is not a valid value. The input types must be
// Float.
@@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
FOLD_OP(*);
break;
case SpvOpFDiv:
+ if (HasZero(input2)) return 0;
FOLD_OP(/);
break;
case SpvOpFAdd:
@@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
const analysis::Type* ele_type = vector_type->element_type();
for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
uint32_t id = 0;
- const analysis::Constant* input1_comp =
- input1->AsVectorConstant()->GetComponents()[i];
- const analysis::Constant* input2_comp =
- input2->AsVectorConstant()->GetComponents()[i];
+
+ const analysis::Constant* input1_comp = nullptr;
+ if (const analysis::VectorConstant* input1_vector =
+ input1->AsVectorConstant()) {
+ input1_comp = input1_vector->GetComponents()[i];
+ } else {
+ assert(input1->AsNullConstant());
+ input1_comp = const_mgr->GetConstant(ele_type, {});
+ }
+
+ const analysis::Constant* input2_comp = nullptr;
+ if (const analysis::VectorConstant* input2_vector =
+ input2->AsVectorConstant()) {
+ input2_comp = input2_vector->GetComponents()[i];
+ } else {
+ assert(input2->AsNullConstant());
+ input2_comp = const_mgr->GetConstant(ele_type, {});
+ }
+
if (ele_type->AsFloat()) {
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
input2_comp);
@@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() {
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants);
- if (!const_input2) return false;
+ if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr;
// If the variable value is the second operand of the divide, multiply
@@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() {
if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants);
- if (!const_input1) return false;
+ if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
@@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() {
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants);
- if (!const_input2) return false;
+ if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr;
@@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() {
if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants);
- if (!const_input1) return false;
+ if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
@@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
return FloatConstantKind::Unknown;
}
- if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
+ assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
+
+ if (constant->AsNullConstant()) {
+ return FloatConstantKind::Zero;
+ } else if (const analysis::VectorConstant* vc =
+ constant->AsVectorConstant()) {
const std::vector<const analysis::Constant*>& components =
vc->GetComponents();
assert(!components.empty());