diff options
Diffstat (limited to 'source')
-rw-r--r-- | source/opt/CMakeLists.txt | 2 | ||||
-rw-r--r-- | source/opt/basic_block.cpp | 9 | ||||
-rw-r--r-- | source/opt/basic_block.h | 4 | ||||
-rw-r--r-- | source/opt/cfg.h | 17 | ||||
-rw-r--r-- | source/opt/dominator_tree.cpp | 4 | ||||
-rw-r--r-- | source/opt/dominator_tree.h | 10 | ||||
-rw-r--r-- | source/opt/function.h | 22 | ||||
-rw-r--r-- | source/opt/ir_builder.h | 46 | ||||
-rw-r--r-- | source/opt/iterator.h | 8 | ||||
-rw-r--r-- | source/opt/loop_descriptor.cpp | 113 | ||||
-rw-r--r-- | source/opt/loop_descriptor.h | 50 | ||||
-rw-r--r-- | source/opt/loop_unswitch_pass.cpp | 908 | ||||
-rw-r--r-- | source/opt/loop_unswitch_pass.h | 43 | ||||
-rw-r--r-- | source/opt/loop_utils.cpp | 110 | ||||
-rw-r--r-- | source/opt/loop_utils.h | 43 | ||||
-rw-r--r-- | source/opt/mem_pass.cpp | 6 | ||||
-rw-r--r-- | source/opt/optimizer.cpp | 5 | ||||
-rw-r--r-- | source/opt/passes.h | 1 |
18 files changed, 1379 insertions, 22 deletions
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 01948515..854c9509 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -60,6 +60,7 @@ add_library(SPIRV-Tools-opt loop_descriptor.h loop_unroller.h loop_utils.h + loop_unswitch_pass.h make_unique.h mem_pass.h merge_return_pass.h @@ -132,6 +133,7 @@ add_library(SPIRV-Tools-opt loop_descriptor.cpp loop_utils.cpp loop_unroller.cpp + loop_unswitch_pass.cpp mem_pass.cpp merge_return_pass.cpp module.cpp diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp index b07696b1..65030eaa 100644 --- a/source/opt/basic_block.cpp +++ b/source/opt/basic_block.cpp @@ -14,6 +14,7 @@ #include "basic_block.h" #include "function.h" +#include "ir_context.h" #include "module.h" #include "reflect.h" @@ -89,6 +90,14 @@ Instruction* BasicBlock::GetLoopMergeInst() { return nullptr; } +void BasicBlock::KillAllInsts(bool killLabel) { + ForEachInst([killLabel](ir::Instruction* ip) { + if (killLabel || ip->opcode() != SpvOpLabel) { + ip->context()->KillInst(ip); + } + }); +} + void BasicBlock::ForEachSuccessorLabel( const std::function<void(const uint32_t)>& f) const { const auto br = &insts_.back(); diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index d0186e6e..5f3d393e 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -171,6 +171,10 @@ class BasicBlock { // Returns true if this basic block exits this function or aborts execution. bool IsReturnOrAbort() const { return ctail()->IsReturnOrAbort(); } + // Kill all instructions in this block. Whether or not to kill the label is + // indicated by |killLabel|. + void KillAllInsts(bool killLabel); + private: // The enclosing function. Function* function_; diff --git a/source/opt/cfg.h b/source/opt/cfg.h index 53dddd23..b21a273d 100644 --- a/source/opt/cfg.h +++ b/source/opt/cfg.h @@ -17,6 +17,7 @@ #include "basic_block.h" +#include <algorithm> #include <list> #include <unordered_map> #include <unordered_set> @@ -83,6 +84,22 @@ class CFG { AddEdges(blk); } + // Removes from the CFG any mapping for the basic block id |blk_id|. + void ForgetBlock(const ir::BasicBlock* blk) { + id2block_.erase(blk->id()); + label2preds_.erase(blk->id()); + blk->ForEachSuccessorLabel( + [blk, this](uint32_t succ_id) { RemoveEdge(blk->id(), succ_id); }); + } + + void RemoveEdge(uint32_t pred_blk_id, uint32_t succ_blk_id) { + auto pred_it = label2preds_.find(succ_blk_id); + if (pred_it == label2preds_.end()) return; + auto& preds_list = pred_it->second; + auto it = std::find(preds_list.begin(), preds_list.end(), pred_blk_id); + if (it != preds_list.end()) preds_list.erase(it); + } + // Registers |blk| to all of its successors. void AddEdges(ir::BasicBlock* blk); diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp index c22d7438..776adf4b 100644 --- a/source/opt/dominator_tree.cpp +++ b/source/opt/dominator_tree.cpp @@ -358,6 +358,10 @@ void DominatorTree::InitializeTree(const ir::Function* f, const ir::CFG& cfg) { second->children_.push_back(first); } + ResetDFNumbering(); +} + +void DominatorTree::ResetDFNumbering() { int index = 0; auto preFunc = [&index](const DominatorTreeNode* node) { const_cast<DominatorTreeNode*>(node)->dfs_num_pre_ = ++index; diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h index 5221eea1..39d5e029 100644 --- a/source/opt/dominator_tree.h +++ b/source/opt/dominator_tree.h @@ -15,6 +15,7 @@ #ifndef LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_ #define LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_ +#include <algorithm> #include <cstdint> #include <map> #include <utility> @@ -195,7 +196,9 @@ class DominatorTree { } // Returns true if the basic block id |a| is reachable by this tree. - bool ReachableFromRoots(uint32_t a) const; + bool ReachableFromRoots(uint32_t a) const { + return GetTreeNode(a) != nullptr; + } // Returns true if this tree is a post dominator tree. bool IsPostDominator() const { return postdominator_; } @@ -267,11 +270,14 @@ class DominatorTree { return &node_iter->second; } - private: // Adds the basic block |bb| to the tree structure if it doesn't already // exist. DominatorTreeNode* GetOrInsertNode(ir::BasicBlock* bb); + // Recomputes the DF numbering of the tree. + void ResetDFNumbering(); + + private: // Wrapper function which gets the list of pairs of each BasicBlocks to its // immediately dominating BasicBlock and stores the result in the the edges // parameter. diff --git a/source/opt/function.h b/source/opt/function.h index 0da62a8f..17e06376 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -59,6 +59,10 @@ class Function { inline void AddParameter(std::unique_ptr<Instruction> p); // Appends a basic block to this function. inline void AddBasicBlock(std::unique_ptr<BasicBlock> b); + // Appends a basic block to this function at the position |ip|. + inline void AddBasicBlock(std::unique_ptr<BasicBlock> b, iterator ip); + template <typename T> + inline void AddBasicBlocks(T begin, T end, iterator ip); // Saves the given function end instruction. inline void SetFunctionEnd(std::unique_ptr<Instruction> end_inst); @@ -73,6 +77,11 @@ class Function { // Returns function's return type id inline uint32_t type_id() const { return def_inst_->type_id(); } + // Returns the basic block container for this function. + const std::vector<std::unique_ptr<BasicBlock>>* GetBlocks() const { + return &blocks_; + } + // Returns the entry basic block for this function. const std::unique_ptr<BasicBlock>& entry() const { return blocks_.front(); } @@ -123,7 +132,18 @@ inline void Function::AddParameter(std::unique_ptr<Instruction> p) { } inline void Function::AddBasicBlock(std::unique_ptr<BasicBlock> b) { - blocks_.emplace_back(std::move(b)); + AddBasicBlock(std::move(b), end()); +} + +inline void Function::AddBasicBlock(std::unique_ptr<BasicBlock> b, + iterator ip) { + ip.InsertBefore(std::move(b)); +} + +template <typename T> +inline void Function::AddBasicBlocks(T src_begin, T src_end, iterator ip) { + blocks_.insert(ip.Get(), std::make_move_iterator(src_begin), + std::make_move_iterator(src_end)); } inline void Function::SetFunctionEnd(std::unique_ptr<Instruction> end_inst) { diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h index a1a1d1e0..aa722cbd 100644 --- a/source/opt/ir_builder.h +++ b/source/opt/ir_builder.h @@ -105,6 +105,44 @@ class InstructionBuilder { return AddInstruction(std::move(new_branch)); } + // Creates a new switch instruction and the associated selection merge + // instruction if requested. + // The id |selector_id| is the id of the selector instruction, must be of + // type int. + // The id |default_id| is the id of the default basic block to branch to. + // The vector |targets| is the pair of literal/branch id. + // The id |merge_id| is the id of the merge basic block for the selection + // merge instruction. If |merge_id| equals kInvalidId then no selection merge + // instruction will be created. + // The value |selection_control| is the selection control flag for the + // selection merge instruction. + // Note that the user must make sure the final basic block is + // well formed. + ir::Instruction* AddSwitch( + uint32_t selector_id, uint32_t default_id, + const std::vector<std::pair<std::vector<uint32_t>, uint32_t>>& targets, + uint32_t merge_id = kInvalidId, + uint32_t selection_control = SpvSelectionControlMaskNone) { + if (merge_id != kInvalidId) { + AddSelectionMerge(merge_id, selection_control); + } + std::vector<ir::Operand> operands; + operands.emplace_back( + ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {selector_id}}); + operands.emplace_back( + ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {default_id}}); + for (auto& target : targets) { + operands.emplace_back( + ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + target.first}); + operands.emplace_back(ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {target.second}}); + } + std::unique_ptr<ir::Instruction> new_switch( + new ir::Instruction(GetContext(), SpvOpSwitch, 0, 0, operands)); + return AddInstruction(std::move(new_switch)); + } + // Creates a phi instruction. // The id |type| must be the id of the phi instruction's type. // The vector |incomings| must be a sequence of pairs of <definition id, @@ -215,6 +253,14 @@ class InstructionBuilder { return AddInstruction(std::move(new_inst)); } + // Creates an unreachable instruction. + ir::Instruction* AddUnreachable() { + std::unique_ptr<ir::Instruction> select( + new ir::Instruction(GetContext(), SpvOpUnreachable, 0, 0, + std::initializer_list<ir::Operand>{})); + return AddInstruction(std::move(select)); + } + // Inserts the new instruction before the insertion point. ir::Instruction* AddInstruction(std::unique_ptr<ir::Instruction>&& insn) { ir::Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn)); diff --git a/source/opt/iterator.h b/source/opt/iterator.h index 52a8d864..d43dfbef 100644 --- a/source/opt/iterator.h +++ b/source/opt/iterator.h @@ -99,6 +99,14 @@ class UptrVectorIterator inline typename std::enable_if<!IsConstForMethod, UptrVectorIterator>::type Erase(); + // Returns the underlying iterator. + UnderlyingIterator Get() const { return iterator_; } + + // Returns a valid end iterator for the underlying container. + UptrVectorIterator End() const { + return UptrVectorIterator(container_, container_->end()); + } + private: UptrVector* container_; // The container we are manipulating. UnderlyingIterator iterator_; // The raw iterator from the container. diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp index 131363c5..60d94687 100644 --- a/source/opt/loop_descriptor.cpp +++ b/source/opt/loop_descriptor.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "opt/loop_descriptor.h" +#include <algorithm> #include <iostream> #include <type_traits> #include <utility> @@ -245,11 +246,10 @@ bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) { assert(bb->GetParent() && "The basic block does not belong to a function"); opt::DominatorAnalysis* dom_analysis = context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg()); - if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false; + if (dom_analysis->IsReachable(bb) && + !dom_analysis->Dominates(GetHeaderBlock(), bb)) + return false; - opt::PostDominatorAnalysis* postdom_analysis = - context_->GetPostDominatorAnalysis(bb->GetParent(), *context_->cfg()); - if (!postdom_analysis->Dominates(GetMergeBlock(), bb)) return false; return true; } @@ -378,6 +378,17 @@ void Loop::SetMergeBlock(BasicBlock* merge) { } } +void Loop::SetPreHeaderBlock(BasicBlock* preheader) { + assert(!IsInsideLoop(preheader) && "The preheader block is in the loop"); + assert(preheader->tail()->opcode() == SpvOpBranch && + "The preheader block does not unconditionally branch to the header " + "block"); + assert(preheader->tail()->GetSingleWordOperand(0) == GetHeaderBlock()->id() && + "The preheader block does not unconditionally branch to the header " + "block"); + loop_preheader_ = preheader; +} + void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const { ir::CFG* cfg = context_->cfg(); exit_blocks->clear(); @@ -412,6 +423,43 @@ void Loop::GetMergingBlocks( } } +namespace { + +static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) { + for (ir::Instruction& inst : *bb) { + if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst)) + return false; + } + + return true; +} + +} // namespace + +bool Loop::IsSafeToClone() const { + ir::CFG& cfg = *context_->cfg(); + + for (uint32_t bb_id : GetBlocks()) { + BasicBlock* bb = cfg.block(bb_id); + assert(bb); + if (!IsBasicBlockSafeToClone(context_, bb)) return false; + } + + // Look at the merge construct. + if (GetHeaderBlock()->GetLoopMergeInst()) { + std::unordered_set<uint32_t> blocks; + GetMergingBlocks(&blocks); + blocks.erase(GetMergeBlock()->id()); + for (uint32_t bb_id : blocks) { + BasicBlock* bb = cfg.block(bb_id); + assert(bb); + if (!IsBasicBlockSafeToClone(context_, bb)) return false; + } + } + + return true; +} + bool Loop::IsLCSSA() const { ir::CFG* cfg = context_->cfg(); opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); @@ -482,7 +530,8 @@ void Loop::ComputeLoopStructuredOrder( ordered_loop_blocks->push_back(loop_merge_); } -LoopDescriptor::LoopDescriptor(const Function* f) : loops_() { +LoopDescriptor::LoopDescriptor(const Function* f) + : loops_(), dummy_top_loop_(nullptr) { PopulateList(f); } @@ -503,6 +552,17 @@ void LoopDescriptor::PopulateList(const Function* f) { ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) { Instruction* merge_inst = node.bb_->GetLoopMergeInst(); if (merge_inst) { + bool all_backedge_unreachable = true; + for (uint32_t pid : context->cfg()->preds(node.bb_->id())) { + if (dom_analysis->IsReachable(pid) && + dom_analysis->Dominates(node.bb_->id(), pid)) { + all_backedge_unreachable = false; + break; + } + } + if (all_backedge_unreachable) + continue; // ignore this one, we actually never branch back. + // The id of the merge basic block of this loop. uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0); @@ -888,5 +948,48 @@ void LoopDescriptor::ClearLoops() { } loops_.clear(); } + +// Adds a new loop nest to the descriptor set. +ir::Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<ir::Loop> new_loop) { + ir::Loop* loop = new_loop.release(); + if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop); + // Iterate from inner to outer most loop, adding basic block to loop mapping + // as we go. + for (ir::Loop& current_loop : + make_range(iterator::begin(loop), iterator::end(nullptr))) { + loops_.push_back(¤t_loop); + for (uint32_t bb_id : current_loop.GetBlocks()) + basic_block_to_loop_.insert(std::make_pair(bb_id, ¤t_loop)); + } + + return loop; +} + +void LoopDescriptor::RemoveLoop(ir::Loop* loop) { + ir::Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_; + parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(), + parent->nested_loops_.end(), loop)); + std::for_each( + loop->nested_loops_.begin(), loop->nested_loops_.end(), + [loop](ir::Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); }); + parent->nested_loops_.insert(parent->nested_loops_.end(), + loop->nested_loops_.begin(), + loop->nested_loops_.end()); + for (uint32_t bb_id : loop->GetBlocks()) { + ir::Loop* l = FindLoopForBasicBlock(bb_id); + if (l == loop) { + SetBasicBlockToLoop(bb_id, l->GetParent()); + } else { + ForgetBasicBlock(bb_id); + } + } + + LoopContainerType::iterator it = + std::find(loops_.begin(), loops_.end(), loop); + assert(it != loops_.end()); + delete loop; + loops_.erase(it); +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h index d0421d64..05acce20 100644 --- a/source/opt/loop_descriptor.h +++ b/source/opt/loop_descriptor.h @@ -47,8 +47,8 @@ class Loop { using const_iterator = ChildrenList::const_iterator; using BasicBlockListTy = std::unordered_set<uint32_t>; - Loop() - : context_(nullptr), + explicit Loop(IRContext* context) + : context_(context), loop_header_(nullptr), loop_continue_(nullptr), loop_merge_(nullptr), @@ -59,6 +59,8 @@ class Loop { Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header, BasicBlock* continue_target, BasicBlock* merge_target); + ~Loop() {} + // Iterators over the immediate sub-loops. inline iterator begin() { return nested_loops_.begin(); } inline iterator end() { return nested_loops_.end(); } @@ -115,6 +117,11 @@ class Loop { // Returns the loop pre-header. inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; } + // Sets |preheader| as the loop preheader block. A preheader block must have + // the following properties: + // - |merge| must not be in the loop; + // - have an unconditional branch to the loop header. + void SetPreHeaderBlock(BasicBlock* preheader); // Returns the loop pre-header, if there is no suitable preheader it will be // created. @@ -190,7 +197,16 @@ class Loop { // Adds the Basic Block with |id| to this loop and its parents. void AddBasicBlock(uint32_t id) { for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { - loop_basic_blocks_.insert(id); + loop->loop_basic_blocks_.insert(id); + } + } + + // Removes the Basic Block id |bb_id| from this loop and its parents. + // It the user responsibility to make sure the removed block is not a merge, + // header or continue block. + void RemoveBasicBlock(uint32_t bb_id) { + for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { + loop->loop_basic_blocks_.erase(bb_id); } } @@ -264,6 +280,10 @@ class Loop { return true; } + // Checks if the loop contains any instruction that will prevent it from being + // cloned. If the loop is structured, the merge construct is also considered. + bool IsSafeToClone() const; + // Sets the parent loop of this loop, that is, a loop which contains this loop // as a nested child loop. inline void SetParent(Loop* parent) { parent_ = parent; } @@ -384,7 +404,7 @@ class LoopDescriptor { // Disable copy constructor, to avoid double-free on destruction. LoopDescriptor(const LoopDescriptor&) = delete; // Move constructor. - LoopDescriptor(LoopDescriptor&& other) { + LoopDescriptor(LoopDescriptor&& other) : dummy_top_loop_(nullptr) { // We need to take ownership of the Loop objects in the other // LoopDescriptor, to avoid double-free. loops_ = std::move(other.loops_); @@ -446,6 +466,28 @@ class LoopDescriptor { // for addition with AddLoop or MarkLoopForRemoval. void PostModificationCleanup(); + // Removes the basic block id |bb_id| from the block to loop mapping. + inline void ForgetBasicBlock(uint32_t bb_id) { + basic_block_to_loop_.erase(bb_id); + } + + // Adds the loop |new_loop| and all its nested loops to the descriptor set. + // The object takes ownership of all the loops. + ir::Loop* AddLoopNest(std::unique_ptr<ir::Loop> new_loop); + + // Remove the loop |loop|. + void RemoveLoop(ir::Loop* loop); + + void SetAsTopLoop(ir::Loop* loop) { + assert(std::find(dummy_top_loop_.begin(), dummy_top_loop_.end(), loop) == + dummy_top_loop_.end() && + "already registered"); + dummy_top_loop_.nested_loops_.push_back(loop); + } + + Loop* GetDummyRootLoop() { return &dummy_top_loop_; } + const Loop* GetDummyRootLoop() const { return &dummy_top_loop_; } + private: // TODO(dneto): This should be a vector of unique_ptr. But VisualStudio 2013 // is unable to compile it. diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp new file mode 100644 index 00000000..53f62995 --- /dev/null +++ b/source/opt/loop_unswitch_pass.cpp @@ -0,0 +1,908 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "loop_unswitch_pass.h" + +#include <functional> +#include <list> +#include <memory> +#include <type_traits> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "basic_block.h" +#include "dominator_tree.h" +#include "fold.h" +#include "function.h" +#include "instruction.h" +#include "ir_builder.h" +#include "ir_context.h" +#include "loop_descriptor.h" + +#include "loop_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +static const uint32_t kTypePointerStorageClassInIdx = 0; +static const uint32_t kBranchCondTrueLabIdInIdx = 1; +static const uint32_t kBranchCondFalseLabIdInIdx = 2; + +} // anonymous namespace + +namespace { + +// This class handle the unswitch procedure for a given loop. +// The unswitch will not happen if: +// - The loop has any instruction that will prevent it; +// - The loop invariant condition is not uniform. +class LoopUnswitch { + public: + LoopUnswitch(ir::IRContext* context, ir::Function* function, ir::Loop* loop, + ir::LoopDescriptor* loop_desc) + : function_(function), + loop_(loop), + loop_desc_(*loop_desc), + context_(context), + switch_block_(nullptr) {} + + // Returns true if the loop can be unswitched. + // Can be unswitch if: + // - The loop has no instructions that prevents it (such as barrier); + // - The loop has one conditional branch or switch that do not depends on the + // loop; + // - The loop invariant condition is uniform; + bool CanUnswitchLoop() { + if (switch_block_) return true; + if (loop_->IsSafeToClone()) return false; + + ir::CFG& cfg = *context_->cfg(); + + for (uint32_t bb_id : loop_->GetBlocks()) { + ir::BasicBlock* bb = cfg.block(bb_id); + if (bb->terminator()->IsBranch() && + bb->terminator()->opcode() != SpvOpBranch) { + if (IsConditionLoopInvariant(bb->terminator())) { + switch_block_ = bb; + break; + } + } + } + + return switch_block_; + } + + // Return the iterator to the basic block |bb|. + ir::Function::iterator FindBasicBlockPosition(ir::BasicBlock* bb_to_find) { + ir::Function::iterator it = std::find_if( + function_->begin(), function_->end(), + [bb_to_find](const ir::BasicBlock& bb) { return bb_to_find == &bb; }); + assert(it != function_->end() && "Basic Block not found"); + return it; + } + + // Creates a new basic block and insert it into the function |fn| at the + // position |ip|. This function preserves the def/use and instr to block + // managers. + ir::BasicBlock* CreateBasicBlock(ir::Function::iterator ip) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + ir::BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<ir::BasicBlock>( + new ir::BasicBlock(std::unique_ptr<ir::Instruction>(new ir::Instruction( + context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); + bb->SetParent(function_); + def_use_mgr->AnalyzeInstDef(bb->GetLabelInst()); + context_->set_instr_block(bb->GetLabelInst(), bb); + + return bb; + } + + // Unswitches |loop_|. + void PerformUnswitch() { + assert(CanUnswitchLoop() && + "Cannot unswitch if there is not constant condition"); + assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block"); + assert(loop_->IsLCSSA() && "This loop is not in LCSSA form"); + + ir::CFG& cfg = *context_->cfg(); + DominatorTree* dom_tree = + &context_->GetDominatorAnalysis(function_, *context_->cfg()) + ->GetDomTree(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + LoopUtils loop_utils(context_, loop_); + + ////////////////////////////////////////////////////////////////////////////// + // Step 1: Create the if merge block for structured modules. + // To do so, the |loop_| merge block will become the if's one and we + // create a merge for the loop. This will limit the amount of duplicated + // code the structured control flow imposes. + // For non structured program, the new loop will be connected to + // the old loop's exit blocks. + ////////////////////////////////////////////////////////////////////////////// + + // Get the merge block if it exists. + ir::BasicBlock* if_merge_block = loop_->GetMergeBlock(); + // The merge block is only created if the loop has a unique exit block. We + // have this guarantee for structured loops, for compute loop it will + // trivially help maintain both a structured-like form and LCSAA. + ir::BasicBlock* loop_merge_block = + if_merge_block + ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block)) + : nullptr; + if (loop_merge_block) { + // Add the instruction and update managers. + opt::InstructionBuilder builder( + context_, loop_merge_block, + ir::IRContext::kAnalysisDefUse | + ir::IRContext::kAnalysisInstrToBlockMapping); + builder.AddBranch(if_merge_block->id()); + builder.SetInsertPoint(&*loop_merge_block->begin()); + cfg.RegisterBlock(loop_merge_block); + def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst()); + // Update CFG. + if_merge_block->ForEachPhiInst( + [loop_merge_block, &builder, this](ir::Instruction* phi) { + ir::Instruction* cloned = phi->Clone(context_); + builder.AddInstruction(std::unique_ptr<ir::Instruction>(cloned)); + phi->SetInOperand(0, {cloned->result_id()}); + phi->SetInOperand(1, {loop_merge_block->id()}); + for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--) + phi->RemoveInOperand(j); + }); + // Copy the predecessor list (will get invalidated otherwise). + std::vector<uint32_t> preds = cfg.preds(if_merge_block->id()); + for (uint32_t pid : preds) { + if (pid == loop_merge_block->id()) continue; + ir::BasicBlock* p_bb = cfg.block(pid); + p_bb->ForEachSuccessorLabel( + [if_merge_block, loop_merge_block](uint32_t* id) { + if (*id == if_merge_block->id()) *id = loop_merge_block->id(); + }); + cfg.AddEdge(pid, loop_merge_block->id()); + } + cfg.RemoveNonExistingEdges(if_merge_block->id()); + // Update loop descriptor. + if (ir::Loop* ploop = loop_->GetParent()) { + ploop->AddBasicBlock(loop_merge_block); + loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop); + } + + // Update the dominator tree. + DominatorTreeNode* loop_merge_dtn = + dom_tree->GetOrInsertNode(loop_merge_block); + DominatorTreeNode* if_merge_block_dtn = + dom_tree->GetOrInsertNode(if_merge_block); + loop_merge_dtn->parent_ = if_merge_block_dtn->parent_; + loop_merge_dtn->children_.push_back(if_merge_block_dtn); + loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn); + if_merge_block_dtn->parent_->children_.erase(std::find( + if_merge_block_dtn->parent_->children_.begin(), + if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn)); + + loop_->SetMergeBlock(loop_merge_block); + } + + //////////////////////////////////////////////////////////////////////////// + // Step 2: Build a new preheader for |loop_|, use the old one + // for the constant branch. + //////////////////////////////////////////////////////////////////////////// + + ir::BasicBlock* if_block = loop_->GetPreHeaderBlock(); + // If this preheader is the parent loop header, + // we need to create a dedicated block for the if. + ir::BasicBlock* loop_pre_header = + CreateBasicBlock(++FindBasicBlockPosition(if_block)); + opt::InstructionBuilder(context_, loop_pre_header, + ir::IRContext::kAnalysisDefUse | + ir::IRContext::kAnalysisInstrToBlockMapping) + .AddBranch(loop_->GetHeaderBlock()->id()); + + if_block->tail()->SetInOperand(0, {loop_pre_header->id()}); + + // Update loop descriptor. + if (ir::Loop* ploop = loop_desc_[if_block]) { + ploop->AddBasicBlock(loop_pre_header); + loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop); + } + + // Update the CFG. + cfg.RegisterBlock(loop_pre_header); + def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst()); + cfg.AddEdge(if_block->id(), loop_pre_header->id()); + cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id()); + + loop_->GetHeaderBlock()->ForEachPhiInst( + [loop_pre_header, if_block](ir::Instruction* phi) { + phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) { + if (*id == if_block->id()) { + *id = loop_pre_header->id(); + } + }); + }); + loop_->SetPreHeaderBlock(loop_pre_header); + + // Update the dominator tree. + DominatorTreeNode* loop_pre_header_dtn = + dom_tree->GetOrInsertNode(loop_pre_header); + DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block); + loop_pre_header_dtn->parent_ = if_block_dtn; + assert( + if_block_dtn->children_.size() == 1 && + "A loop preheader should only have the header block as a child in the " + "dominator tree"); + loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]); + if_block_dtn->children_.clear(); + if_block_dtn->children_.push_back(loop_pre_header_dtn); + + // Make domination queries valid. + dom_tree->ResetDFNumbering(); + + // Compute an ordered list of basic block to clone: loop blocks + pre-header + // + merge block. + loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true); + + ///////////////////////////// + // Do the actual unswitch: // + // - Clone the loop // + // - Connect exits // + // - Specialize the loop // + ///////////////////////////// + + ir::Instruction* iv_condition = &*switch_block_->tail(); + SpvOp iv_opcode = iv_condition->opcode(); + ir::Instruction* condition = + def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]); + + analysis::ConstantManager* cst_mgr = context_->get_constant_mgr(); + const analysis::Type* cond_type = + context_->get_type_mgr()->GetType(condition->type_id()); + + // Build the list of value for which we need to clone and specialize the + // loop. + std::vector<std::pair<ir::Instruction*, ir::BasicBlock*>> constant_branch; + // Special case for the original loop + ir::Instruction* original_loop_constant_value; + ir::BasicBlock* original_loop_target; + if (iv_opcode == SpvOpBranchConditional) { + constant_branch.emplace_back( + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})), + nullptr); + original_loop_constant_value = + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1})); + } else { + // We are looking to take the default branch, so we can't provide a + // specific value. + original_loop_constant_value = nullptr; + for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) { + constant_branch.emplace_back( + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant( + cond_type, iv_condition->GetInOperand(i).words)), + nullptr); + } + } + + // Get the loop landing pads. + std::unordered_set<uint32_t> if_merging_blocks; + std::function<bool(uint32_t)> is_from_original_loop; + if (loop_->GetHeaderBlock()->GetLoopMergeInst()) { + if_merging_blocks.insert(if_merge_block->id()); + is_from_original_loop = [this](uint32_t id) { + return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id; + }; + } else { + loop_->GetExitBlocks(&if_merging_blocks); + is_from_original_loop = [this](uint32_t id) { + return loop_->IsInsideLoop(id); + }; + } + + for (auto& specialisation_pair : constant_branch) { + ir::Instruction* specialisation_value = specialisation_pair.first; + ////////////////////////////////////////////////////////// + // Step 3: Duplicate |loop_|. + ////////////////////////////////////////////////////////// + LoopUtils::LoopCloningResult clone_result; + + ir::Loop* cloned_loop = + loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_); + specialisation_pair.second = cloned_loop->GetPreHeaderBlock(); + + //////////////////////////////////// + // Step 4: Specialize the loop. // + //////////////////////////////////// + + { + std::unordered_set<uint32_t> dead_blocks; + std::unordered_set<uint32_t> unreachable_merges; + SimplifyLoop( + ir::make_range( + ir::UptrVectorIterator<ir::BasicBlock>( + &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()), + ir::UptrVectorIterator<ir::BasicBlock>( + &clone_result.cloned_bb_, clone_result.cloned_bb_.end())), + cloned_loop, condition, specialisation_value, &dead_blocks); + + // We tagged dead blocks, create the loop before we invalidate any basic + // block. + cloned_loop = + CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges); + CleanUpCFG( + ir::UptrVectorIterator<ir::BasicBlock>( + &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()), + dead_blocks, unreachable_merges); + + /////////////////////////////////////////////////////////// + // Step 5: Connect convergent edges to the landing pads. // + /////////////////////////////////////////////////////////// + + for (uint32_t merge_bb_id : if_merging_blocks) { + ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id); + // We are in LCSSA so we only care about phi instructions. + merge->ForEachPhiInst([is_from_original_loop, &dead_blocks, + &clone_result](ir::Instruction* phi) { + uint32_t num_in_operands = phi->NumInOperands(); + for (uint32_t i = 0; i < num_in_operands; i += 2) { + uint32_t pred = phi->GetSingleWordInOperand(i + 1); + if (is_from_original_loop(pred)) { + pred = clone_result.value_map_.at(pred); + if (!dead_blocks.count(pred)) { + uint32_t incoming_value_id = phi->GetSingleWordInOperand(i); + // Not all the incoming value are coming from the loop. + ValueMapTy::iterator new_value = + clone_result.value_map_.find(incoming_value_id); + if (new_value != clone_result.value_map_.end()) { + incoming_value_id = new_value->second; + } + phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}}); + phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}}); + } + } + } + }); + } + } + function_->AddBasicBlocks(clone_result.cloned_bb_.begin(), + clone_result.cloned_bb_.end(), + ++FindBasicBlockPosition(if_block)); + } + + // Same as above but specialize the existing loop + { + std::unordered_set<uint32_t> dead_blocks; + std::unordered_set<uint32_t> unreachable_merges; + SimplifyLoop(ir::make_range(function_->begin(), function_->end()), loop_, + condition, original_loop_constant_value, &dead_blocks); + + for (uint32_t merge_bb_id : if_merging_blocks) { + ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id); + // LCSSA, so we only care about phi instructions. + // If we the phi is reduced to a single incoming branch, do not + // propagate it to preserve LCSSA. + PatchPhis(merge, dead_blocks, true); + } + if (if_merge_block) { + bool has_live_pred = false; + for (uint32_t pid : cfg.preds(if_merge_block->id())) { + if (!dead_blocks.count(pid)) { + has_live_pred = true; + break; + } + } + if (!has_live_pred) unreachable_merges.insert(if_merge_block->id()); + } + original_loop_target = loop_->GetPreHeaderBlock(); + // We tagged dead blocks, prune the loop descriptor from any dead loops. + // After this call, |loop_| can be nullptr (i.e. the unswitch killed this + // loop). + loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges); + + CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges); + } + + ///////////////////////////////////// + // Finally: connect the new loops. // + ///////////////////////////////////// + + // Delete the old jump + context_->KillInst(&*if_block->tail()); + opt::InstructionBuilder builder(context_, if_block); + if (iv_opcode == SpvOpBranchConditional) { + assert(constant_branch.size() == 1); + builder.AddConditionalBranch( + condition->result_id(), original_loop_target->id(), + constant_branch[0].second->id(), + if_merge_block ? if_merge_block->id() : kInvalidId); + } else { + std::vector<std::pair<std::vector<uint32_t>, uint32_t>> targets; + for (auto& t : constant_branch) { + targets.emplace_back(t.first->GetInOperand(0).words, t.second->id()); + } + + builder.AddSwitch(condition->result_id(), original_loop_target->id(), + targets, + if_merge_block ? if_merge_block->id() : kInvalidId); + } + + switch_block_ = nullptr; + ordered_loop_blocks_.clear(); + + context_->InvalidateAnalysesExceptFor( + ir::IRContext::Analysis::kAnalysisLoopAnalysis); + } + + // Returns true if the unswitch killed the original |loop_|. + bool WasLoopKilled() const { return loop_ == nullptr; } + + private: + using ValueMapTy = std::unordered_map<uint32_t, uint32_t>; + using BlockMapTy = std::unordered_map<uint32_t, ir::BasicBlock*>; + + ir::Function* function_; + ir::Loop* loop_; + ir::LoopDescriptor& loop_desc_; + ir::IRContext* context_; + + ir::BasicBlock* switch_block_; + // Map between instructions and if they are dynamically uniform. + std::unordered_map<uint32_t, bool> dynamically_uniform_; + // The loop basic blocks in structured order. + std::vector<ir::BasicBlock*> ordered_loop_blocks_; + + // Returns the next usable id for the context. + uint32_t TakeNextId() { return context_->TakeNextId(); } + + // Patches |bb|'s phi instruction by removing incoming value from unexisting + // or tagged as dead branches. + void PatchPhis(ir::BasicBlock* bb, + const std::unordered_set<uint32_t>& dead_blocks, + bool preserve_phi) { + ir::CFG& cfg = *context_->cfg(); + + std::vector<ir::Instruction*> phi_to_kill; + const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id()); + auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) { + return dead_blocks.count(id) || + std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end(); + }; + bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi, + this](ir::Instruction* insn) { + uint32_t i = 0; + while (i < insn->NumInOperands()) { + uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1); + if (is_branch_dead(incoming_id)) { + // Remove the incoming block id operand. + insn->RemoveInOperand(i + 1); + // Remove the definition id operand. + insn->RemoveInOperand(i); + continue; + } + i += 2; + } + // If there is only 1 remaining edge, propagate the value and + // kill the instruction. + if (insn->NumInOperands() == 2 && !preserve_phi) { + phi_to_kill.push_back(insn); + context_->ReplaceAllUsesWith(insn->result_id(), + insn->GetSingleWordInOperand(0)); + } + }); + for (ir::Instruction* insn : phi_to_kill) { + context_->KillInst(insn); + } + } + + // Removes any block that is tagged as dead, if the block is in + // |unreachable_merges| then all block's instructions are replaced by a + // OpUnreachable. + void CleanUpCFG(ir::UptrVectorIterator<ir::BasicBlock> bb_it, + const std::unordered_set<uint32_t>& dead_blocks, + const std::unordered_set<uint32_t>& unreachable_merges) { + ir::CFG& cfg = *context_->cfg(); + + while (bb_it != bb_it.End()) { + ir::BasicBlock& bb = *bb_it; + + if (unreachable_merges.count(bb.id())) { + if (bb.begin() != bb.tail() || + bb.terminator()->opcode() != SpvOpUnreachable) { + // Make unreachable, but leave the label. + bb.KillAllInsts(false); + opt::InstructionBuilder(context_, &bb).AddUnreachable(); + cfg.RemoveNonExistingEdges(bb.id()); + } + ++bb_it; + } else if (dead_blocks.count(bb.id())) { + cfg.ForgetBlock(&bb); + // Kill this block. + bb.KillAllInsts(true); + bb_it = bb_it.Erase(); + } else { + cfg.RemoveNonExistingEdges(bb.id()); + ++bb_it; + } + } + } + + // Return true if |c_inst| is a Boolean constant and set |cond_val| with the + // value that |c_inst| + bool GetConstCondition(const ir::Instruction* c_inst, bool* cond_val) { + bool cond_is_const; + switch (c_inst->opcode()) { + case SpvOpConstantFalse: { + *cond_val = false; + cond_is_const = true; + } break; + case SpvOpConstantTrue: { + *cond_val = true; + cond_is_const = true; + } break; + default: { cond_is_const = false; } break; + } + return cond_is_const; + } + + // Simplifies |loop| assuming the instruction |to_version_insn| takes the + // value |cst_value|. |block_range| is an iterator range returning the loop + // basic blocks in a structured order (dominator first). + // The function will ignore basic blocks returned by |block_range| if they + // does not belong to the loop. + // The set |dead_blocks| will contain all the dead basic blocks. + // + // Requirements: + // - |loop| must be in the LCSSA form; + // - |cst_value| must be constant or null (to represent the default target + // of an OpSwitch). + void SimplifyLoop( + ir::IteratorRange<ir::UptrVectorIterator<ir::BasicBlock>> block_range, + ir::Loop* loop, ir::Instruction* to_version_insn, + ir::Instruction* cst_value, std::unordered_set<uint32_t>* dead_blocks) { + ir::CFG& cfg = *context_->cfg(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + std::function<bool(uint32_t)> ignore_node; + ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); }; + + std::vector<std::pair<ir::Instruction*, uint32_t>> use_list; + def_use_mgr->ForEachUse( + to_version_insn, [&use_list, &ignore_node, this]( + ir::Instruction* inst, uint32_t operand_index) { + ir::BasicBlock* bb = context_->get_instr_block(inst); + + if (!bb || ignore_node(bb->id())) { + // Out of the loop, the specialization does not apply any more. + return; + } + use_list.emplace_back(inst, operand_index); + }); + + // First pass: inject the specialized value into the loop (and only the + // loop). + for (auto use : use_list) { + ir::Instruction* inst = use.first; + uint32_t operand_index = use.second; + ir::BasicBlock* bb = context_->get_instr_block(inst); + + // If it is not a branch, simply inject the value. + if (!inst->IsBranch()) { + // To also handle switch, cst_value can be nullptr: this case + // means that we are looking to branch to the default target of + // the switch. We don't actually know its value so we don't touch + // it if it not a switch. + if (cst_value) { + inst->SetOperand(operand_index, {cst_value->result_id()}); + def_use_mgr->AnalyzeInstUse(inst); + } + } + + // The user is a branch, kill dead branches. + uint32_t live_target = 0; + std::unordered_set<uint32_t> dead_branches; + switch (inst->opcode()) { + case SpvOpBranchConditional: { + assert(cst_value && "No constant value to specialize !"); + bool branch_cond = false; + if (GetConstCondition(cst_value, &branch_cond)) { + uint32_t true_label = + inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx); + uint32_t false_label = + inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx); + live_target = branch_cond ? true_label : false_label; + uint32_t dead_target = !branch_cond ? true_label : false_label; + cfg.RemoveEdge(bb->id(), dead_target); + } + break; + } + case SpvOpSwitch: { + live_target = inst->GetSingleWordInOperand(1); + if (cst_value) { + if (!cst_value->IsConstant()) break; + const ir::Operand& cst = cst_value->GetInOperand(0); + for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) { + const ir::Operand& literal = inst->GetInOperand(i); + if (literal == cst) { + live_target = inst->GetSingleWordInOperand(i + 1); + break; + } + } + } + for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) { + uint32_t id = inst->GetSingleWordInOperand(i); + if (id != live_target) { + cfg.RemoveEdge(bb->id(), id); + } + } + } + default: + break; + } + if (live_target != 0) { + // Check for the presence of the merge block. + if (ir::Instruction* merge = bb->GetMergeInst()) + context_->KillInst(merge); + context_->KillInst(&*bb->tail()); + opt::InstructionBuilder builder( + context_, bb, + ir::IRContext::kAnalysisDefUse | + ir::IRContext::kAnalysisInstrToBlockMapping); + builder.AddBranch(live_target); + } + } + + // Go through the loop basic block and tag all blocks that are obviously + // dead. + std::unordered_set<uint32_t> visited; + for (ir::BasicBlock& bb : block_range) { + if (ignore_node(bb.id())) continue; + visited.insert(bb.id()); + + // Check if this block is dead, if so tag it as dead otherwise patch phi + // instructions. + bool has_live_pred = false; + for (uint32_t pid : cfg.preds(bb.id())) { + if (!dead_blocks->count(pid)) { + has_live_pred = true; + break; + } + } + if (!has_live_pred) { + dead_blocks->insert(bb.id()); + const ir::BasicBlock& cbb = bb; + // Patch the phis for any back-edge. + cbb.ForEachSuccessorLabel( + [dead_blocks, &visited, &cfg, this](uint32_t id) { + if (!visited.count(id) || dead_blocks->count(id)) return; + ir::BasicBlock* succ = cfg.block(id); + PatchPhis(succ, *dead_blocks, false); + }); + continue; + } + // Update the phi instructions, some incoming branch have/will disappear. + PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false); + } + } + + // Returns true if the header is not reachable or tagged as dead or if we + // never loop back. + bool IsLoopDead(ir::BasicBlock* header, ir::BasicBlock* latch, + const std::unordered_set<uint32_t>& dead_blocks) { + if (!header || dead_blocks.count(header->id())) return true; + if (!latch || dead_blocks.count(latch->id())) return true; + for (uint32_t pid : context_->cfg()->preds(header->id())) { + if (!dead_blocks.count(pid)) { + // Seems reachable. + return false; + } + } + return true; + } + + // Cleans the loop nest under |loop| and reflect changes to the loop + // descriptor. This will kill all descriptors that represent dead loops. + // If |loop_| is killed, it will be set to nullptr. + // Any merge blocks that become unreachable will be added to + // |unreachable_merges|. + // The function returns the pointer to |loop| or nullptr if the loop was + // killed. + ir::Loop* CleanLoopNest(ir::Loop* loop, + const std::unordered_set<uint32_t>& dead_blocks, + std::unordered_set<uint32_t>* unreachable_merges) { + // This represent the pair of dead loop and nearest alive parent (nullptr if + // no parent). + std::unordered_map<ir::Loop*, ir::Loop*> dead_loops; + auto get_parent = [&dead_loops](ir::Loop* l) -> ir::Loop* { + std::unordered_map<ir::Loop*, ir::Loop*>::iterator it = + dead_loops.find(l); + if (it != dead_loops.end()) return it->second; + return nullptr; + }; + + bool is_main_loop_dead = + IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks); + if (is_main_loop_dead) { + if (ir::Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) { + context_->KillInst(merge); + } + dead_loops[loop] = loop->GetParent(); + } else + dead_loops[loop] = loop; + // For each loop, check if we killed it. If we did, find a suitable parent + // for its children. + for (ir::Loop& sub_loop : + ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop), + opt::TreeDFIterator<ir::Loop>())) { + if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(), + dead_blocks)) { + if (ir::Instruction* merge = + sub_loop.GetHeaderBlock()->GetLoopMergeInst()) { + context_->KillInst(merge); + } + dead_loops[&sub_loop] = get_parent(&sub_loop); + } else { + // The loop is alive, check if its merge block is dead, if it is, tag it + // as required. + if (sub_loop.GetMergeBlock()) { + uint32_t merge_id = sub_loop.GetMergeBlock()->id(); + if (dead_blocks.count(merge_id)) { + unreachable_merges->insert(sub_loop.GetMergeBlock()->id()); + } + } + } + } + if (!is_main_loop_dead) dead_loops.erase(loop); + + // Remove dead blocks from live loops. + for (uint32_t bb_id : dead_blocks) { + ir::Loop* l = loop_desc_[bb_id]; + if (l) { + l->RemoveBasicBlock(bb_id); + loop_desc_.ForgetBasicBlock(bb_id); + } + } + + std::for_each( + dead_loops.begin(), dead_loops.end(), + [&loop, this]( + std::unordered_map<ir::Loop*, ir::Loop*>::iterator::reference it) { + if (it.first == loop) loop = nullptr; + loop_desc_.RemoveLoop(it.first); + }); + + return loop; + } + + // Returns true if |var| is dynamically uniform. + // Note: this is currently approximated as uniform. + bool IsDynamicallyUniform(ir::Instruction* var, const ir::BasicBlock* entry, + const DominatorTree& post_dom_tree) { + assert(post_dom_tree.IsPostDominator()); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + auto it = dynamically_uniform_.find(var->result_id()); + + if (it != dynamically_uniform_.end()) return it->second; + + analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr(); + + bool& is_uniform = dynamically_uniform_[var->result_id()]; + is_uniform = false; + + dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform, + [&is_uniform](const ir::Instruction&) { + is_uniform = true; + return false; + }); + if (is_uniform) { + return is_uniform; + } + + ir::BasicBlock* parent = context_->get_instr_block(var); + if (!parent) { + return is_uniform = true; + } + + if (!post_dom_tree.Dominates(parent->id(), entry->id())) { + return is_uniform = false; + } + if (var->opcode() == SpvOpLoad) { + const uint32_t PtrTypeId = + def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id(); + const ir::Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId); + uint32_t storage_class = + PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx); + if (storage_class != SpvStorageClassUniform && + storage_class != SpvStorageClassUniformConstant) { + return is_uniform = false; + } + } else { + if (!context_->IsCombinatorInstruction(var)) { + return is_uniform = false; + } + } + + return is_uniform = var->WhileEachInId([entry, &post_dom_tree, + this](const uint32_t* id) { + return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id), + entry, post_dom_tree); + }); + } + + // Returns true if |insn| is constant and dynamically uniform within the loop. + bool IsConditionLoopInvariant(ir::Instruction* insn) { + assert(insn->IsBranch()); + assert(insn->opcode() != SpvOpBranch); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + ir::Instruction* condition = + def_use_mgr->GetDef(insn->GetOperand(0).words[0]); + return !loop_->IsInsideLoop(condition) && + IsDynamicallyUniform( + condition, function_->entry().get(), + context_->GetPostDominatorAnalysis(function_, *context_->cfg()) + ->GetDomTree()); + } +}; + +} // namespace + +Pass::Status LoopUnswitchPass::Process(ir::IRContext* c) { + InitializeProcessing(c); + + bool modified = false; + ir::Module* module = c->module(); + + // Process each function in the module + for (ir::Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopUnswitchPass::ProcessFunction(ir::Function* f) { + bool modified = false; + std::unordered_set<ir::Loop*> processed_loop; + + ir::LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); + + bool loop_changed = true; + while (loop_changed) { + loop_changed = false; + for (ir::Loop& loop : + ir::make_range(++opt::TreeDFIterator<ir::Loop>( + loop_descriptor.GetDummyRootLoop()), + opt::TreeDFIterator<ir::Loop>())) { + if (processed_loop.count(&loop)) continue; + processed_loop.insert(&loop); + + LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor); + while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) { + if (!loop.IsLCSSA()) { + LoopUtils(context(), &loop).MakeLoopClosedSSA(); + } + modified = true; + loop_changed = true; + unswitcher.PerformUnswitch(); + } + if (loop_changed) break; + } + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_unswitch_pass.h b/source/opt/loop_unswitch_pass.h new file mode 100644 index 00000000..dbe58147 --- /dev/null +++ b/source/opt/loop_unswitch_pass.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ +#define LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ + +#include "opt/loop_descriptor.h" +#include "opt/pass.h" + +namespace spvtools { +namespace opt { + +// Implements the loop unswitch optimization. +// The loop unswitch hoists invariant "if" statements if the conditions are +// constant within the loop and clones the loop for each branch. +class LoopUnswitchPass : public Pass { + public: + const char* name() const override { return "loop-unswitch"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes have been made to the modue. + Pass::Status Process(ir::IRContext* context) override; + + private: + bool ProcessFunction(ir::Function* f); +}; + +} // namespace opt +} // namespace spvtools + +#endif // !LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ diff --git a/source/opt/loop_utils.cpp b/source/opt/loop_utils.cpp index 6c2a15f9..85326790 100644 --- a/source/opt/loop_utils.cpp +++ b/source/opt/loop_utils.cpp @@ -18,6 +18,7 @@ #include <unordered_set> #include <vector> +#include "cfa.h" #include "opt/cfg.h" #include "opt/ir_builder.h" #include "opt/ir_context.h" @@ -481,5 +482,114 @@ void LoopUtils::MakeLoopClosedSSA() { ir::IRContext::Analysis::kAnalysisLoopAnalysis); } +ir::Loop* LoopUtils::CloneLoop( + LoopCloningResult* cloning_result, + const std::vector<ir::BasicBlock*>& ordered_loop_blocks) const { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + std::unique_ptr<ir::Loop> new_loop = MakeUnique<ir::Loop>(context_); + if (loop_->HasParent()) new_loop->SetParent(loop_->GetParent()); + + ir::CFG& cfg = *context_->cfg(); + + // Clone and place blocks in a SPIR-V compliant order (dominators first). + for (ir::BasicBlock* old_bb : ordered_loop_blocks) { + // For each basic block in the loop, we clone it and register the mapping + // between old and new ids. + ir::BasicBlock* new_bb = old_bb->Clone(context_); + new_bb->SetParent(&function_); + new_bb->GetLabelInst()->SetResultId(context_->TakeNextId()); + def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst()); + context_->set_instr_block(new_bb->GetLabelInst(), new_bb); + cloning_result->cloned_bb_.emplace_back(new_bb); + + cloning_result->old_to_new_bb_[old_bb->id()] = new_bb; + cloning_result->new_to_old_bb_[new_bb->id()] = old_bb; + cloning_result->value_map_[old_bb->id()] = new_bb->id(); + + if (loop_->IsInsideLoop(old_bb)) new_loop->AddBasicBlock(new_bb); + + for (auto& inst : *new_bb) { + if (inst.HasResultId()) { + uint32_t old_result_id = inst.result_id(); + inst.SetResultId(context_->TakeNextId()); + cloning_result->value_map_[old_result_id] = inst.result_id(); + + // Only look at the defs for now, uses are not updated yet. + def_use_mgr->AnalyzeInstDef(&inst); + } + } + } + + // All instructions (including all labels) have been cloned, + // remap instruction operands id with the new ones. + for (std::unique_ptr<ir::BasicBlock>& bb_ref : cloning_result->cloned_bb_) { + ir::BasicBlock* bb = bb_ref.get(); + + for (ir::Instruction& insn : *bb) { + insn.ForEachInId([cloning_result](uint32_t* old_id) { + // If the operand is defined in the loop, remap the id. + auto id_it = cloning_result->value_map_.find(*old_id); + if (id_it != cloning_result->value_map_.end()) { + *old_id = id_it->second; + } + }); + // Only look at what the instruction uses. All defs are register, so all + // should be fine now. + def_use_mgr->AnalyzeInstUse(&insn); + context_->set_instr_block(&insn, bb); + } + cfg.RegisterBlock(bb); + } + + PopulateLoopNest(new_loop.get(), *cloning_result); + + return new_loop.release(); +} + +void LoopUtils::PopulateLoopNest( + ir::Loop* new_loop, const LoopCloningResult& cloning_result) const { + std::unordered_map<ir::Loop*, ir::Loop*> loop_mapping; + loop_mapping[loop_] = new_loop; + + if (loop_->HasParent()) loop_->GetParent()->AddNestedLoop(new_loop); + PopulateLoopDesc(new_loop, loop_, cloning_result); + + for (ir::Loop& sub_loop : + ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop_), + opt::TreeDFIterator<ir::Loop>())) { + ir::Loop* cloned = new ir::Loop(context_); + if (ir::Loop* parent = loop_mapping[sub_loop.GetParent()]) + parent->AddNestedLoop(cloned); + loop_mapping[&sub_loop] = cloned; + PopulateLoopDesc(cloned, &sub_loop, cloning_result); + } + + loop_desc_->AddLoopNest(std::unique_ptr<ir::Loop>(new_loop)); +} + +// Populates |new_loop| descriptor according to |old_loop|'s one. +void LoopUtils::PopulateLoopDesc( + ir::Loop* new_loop, ir::Loop* old_loop, + const LoopCloningResult& cloning_result) const { + for (uint32_t bb_id : old_loop->GetBlocks()) { + ir::BasicBlock* bb = cloning_result.old_to_new_bb_.at(bb_id); + new_loop->AddBasicBlock(bb); + } + new_loop->SetHeaderBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetHeaderBlock()->id())); + if (old_loop->GetLatchBlock()) + new_loop->SetLatchBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id())); + if (old_loop->GetMergeBlock()) { + ir::BasicBlock* bb = + cloning_result.old_to_new_bb_.at(old_loop->GetMergeBlock()->id()); + new_loop->SetMergeBlock(bb); + } + if (old_loop->GetPreHeaderBlock()) + new_loop->SetPreHeaderBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetPreHeaderBlock()->id())); +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h index 89e69367..0e77bb6b 100644 --- a/source/opt/loop_utils.h +++ b/source/opt/loop_utils.h @@ -17,15 +17,11 @@ #include <list> #include <memory> #include <vector> +#include "opt/ir_context.h" #include "opt/loop_descriptor.h" namespace spvtools { -namespace ir { -class Loop; -class IRContext; -} // namespace ir - namespace opt { // LoopUtils is used to encapsulte loop optimizations and from the passes which @@ -33,8 +29,25 @@ namespace opt { // or through a pass which is using this. class LoopUtils { public: + // Holds a auxiliary results of the loop cloning procedure. + struct LoopCloningResult { + using ValueMapTy = std::unordered_map<uint32_t, uint32_t>; + using BlockMapTy = std::unordered_map<uint32_t, ir::BasicBlock*>; + + // Mapping between the original loop ids and the new one. + ValueMapTy value_map_; + // Mapping between original loop blocks to the cloned one. + BlockMapTy old_to_new_bb_; + // Mapping between the cloned loop blocks to original one. + BlockMapTy new_to_old_bb_; + // List of cloned basic block. + std::vector<std::unique_ptr<ir::BasicBlock>> cloned_bb_; + }; + LoopUtils(ir::IRContext* context, ir::Loop* loop) : context_(context), + loop_desc_( + context->GetLoopDescriptor(loop->GetHeaderBlock()->GetParent())), loop_(loop), function_(*loop_->GetHeaderBlock()->GetParent()) {} @@ -72,6 +85,17 @@ class LoopUtils { // Preserves: CFG, def/use and instruction to block mapping. void CreateLoopDedicatedExits(); + // Clone |loop_| and remap its instructions. Newly created blocks + // will be added to the |cloning_result.cloned_bb_| list, correctly ordered to + // be inserted into a function. If the loop is structured, the merge construct + // will also be cloned. The function preserves the def/use, cfg and instr to + // block analyses. + // The cloned loop nest will be added to the loop descriptor and will have + // owner ship. + ir::Loop* CloneLoop( + LoopCloningResult* cloning_result, + const std::vector<ir::BasicBlock*>& ordered_loop_blocks) const; + // Perfom a partial unroll of |loop| by given |factor|. This will copy the // body of the loop |factor| times. So a |factor| of one would give a new loop // with the original body plus one unrolled copy body. @@ -103,8 +127,17 @@ class LoopUtils { private: ir::IRContext* context_; + ir::LoopDescriptor* loop_desc_; ir::Loop* loop_; ir::Function& function_; + + // Populates the loop nest of |new_loop| according to |loop_| nest. + void PopulateLoopNest(ir::Loop* new_loop, + const LoopCloningResult& cloning_result) const; + + // Populates |new_loop| descriptor according to |old_loop|'s one. + void PopulateLoopDesc(ir::Loop* new_loop, ir::Loop* old_loop, + const LoopCloningResult& cloning_result) const; }; } // namespace opt diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 39964777..86b84a0a 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -137,11 +137,7 @@ bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const { } void MemPass::KillAllInsts(ir::BasicBlock* bp, bool killLabel) { - bp->ForEachInst([this, killLabel](ir::Instruction* ip) { - if (killLabel || ip->opcode() != SpvOpLabel) { - context()->KillInst(ip); - } - }); + bp->KillAllInsts(killLabel); } bool MemPass::HasLoads(uint32_t varId) const { diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index dced5db5..c52e6435 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -356,6 +356,11 @@ Optimizer::PassToken CreateLoopInvariantCodeMotionPass() { return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::LICMPass>()); } +Optimizer::PassToken CreateLoopUnswitchPass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::LoopUnswitchPass>()); +} + Optimizer::PassToken CreateRedundancyEliminationPass() { return MakeUnique<Optimizer::PassToken::Impl>( MakeUnique<opt::RedundancyEliminationPass>()); diff --git a/source/opt/passes.h b/source/opt/passes.h index 9fb98aaf..f0fb289a 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -42,6 +42,7 @@ #include "local_single_store_elim_pass.h" #include "local_ssa_elim_pass.h" #include "loop_unroller.h" +#include "loop_unswitch_pass.h" #include "merge_return_pass.h" #include "null_pass.h" #include "private_to_local_pass.h" |