diff options
author | Dejan Mircevski <deki@google.com> | 2016-01-15 11:25:11 -0500 |
---|---|---|
committer | David Neto <dneto@google.com> | 2016-01-20 17:00:58 -0500 |
commit | 961f5dc54408b1516b9e4d85ee1fb4891b903f61 (patch) | |
tree | fc7417697b032295a346e7dcbd6c92ac9ca1b941 | |
parent | 383c83729e608c3ce815d7d19f44fc5b729092c5 (diff) | |
download | SPIRV-Tools-961f5dc54408b1516b9e4d85ee1fb4891b903f61.tar.gz SPIRV-Tools-961f5dc54408b1516b9e4d85ee1fb4891b903f61.tar.bz2 SPIRV-Tools-961f5dc54408b1516b9e4d85ee1fb4891b903f61.zip |
Track uses and defs during parsing.
Replace two other, imperfect mechanisms for use-def tracking.
Use ValidationState_t::entry_points to track entry points.
Concentrate undefined-ID diagnostics in a single place.
Move validate_types.h content into validate.h due to increased
inter-dependency.
Track uses of all IDs: TYPE_ID, SCOPE_ID, ...
Also update some blurbs.
Fix entry-point accumulation and move it outside ProcessIds().
Remove validate_types.h from CMakeLists.txt.
Blurb for spvIsIdType.
Remove redundant diagnostics for undefined IDs.
Join "can not" and reformat.
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | source/operand.cpp | 14 | ||||
-rw-r--r-- | source/operand.h | 3 | ||||
-rw-r--r-- | source/validate.cpp | 116 | ||||
-rw-r--r-- | source/validate.h | 279 | ||||
-rw-r--r-- | source/validate_cfg.cpp | 1 | ||||
-rw-r--r-- | source/validate_id.cpp | 1258 | ||||
-rw-r--r-- | source/validate_instruction.cpp | 1 | ||||
-rw-r--r-- | source/validate_layout.cpp | 1 | ||||
-rw-r--r-- | source/validate_passes.h | 2 | ||||
-rw-r--r-- | source/validate_ssa.cpp | 2 | ||||
-rw-r--r-- | source/validate_types.cpp | 31 | ||||
-rw-r--r-- | source/validate_types.h | 241 | ||||
-rw-r--r-- | test/ValidateID.cpp | 49 |
14 files changed, 879 insertions, 1120 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index fa778c31..035987ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -121,7 +121,6 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/source/text.h ${CMAKE_CURRENT_SOURCE_DIR}/source/text_handler.h ${CMAKE_CURRENT_SOURCE_DIR}/source/validate.h - ${CMAKE_CURRENT_SOURCE_DIR}/source/validate_types.h ${CMAKE_CURRENT_SOURCE_DIR}/source/assembly_grammar.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/disassemble.cpp diff --git a/source/operand.cpp b/source/operand.cpp index 864a1d65..7eae14e8 100644 --- a/source/operand.cpp +++ b/source/operand.cpp @@ -1367,3 +1367,17 @@ spv_operand_pattern_t spvAlternatePatternFollowingImmediate( // No result-id found, so just expect CIVs. return {SPV_OPERAND_TYPE_OPTIONAL_CIV}; } + +bool spvIsIdType(spv_operand_type_t type) { + switch (type) { + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + return true; + default: + return false; + } + return false; +} diff --git a/source/operand.h b/source/operand.h index 96c0bc98..3e27c971 100644 --- a/source/operand.h +++ b/source/operand.h @@ -125,4 +125,7 @@ spv_operand_type_t spvTakeFirstMatchableOperand(spv_operand_pattern_t* pattern); spv_operand_pattern_t spvAlternatePatternFollowingImmediate( const spv_operand_pattern_t& pattern); +// Is the operand an ID? +bool spvIsIdType(spv_operand_type_t type); + #endif // LIBSPIRV_OPERAND_H_ diff --git a/source/validate.cpp b/source/validate.cpp index f06ca2c5..a09eb12c 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -25,7 +25,6 @@ // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. #include "validate.h" -#include "validate_types.h" #include "validate_passes.h" #include "binary.h" @@ -206,81 +205,20 @@ spv_result_t spvValidateBasic(const spv_instruction_t* pInsts, } #endif -spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, - const uint64_t count, const uint32_t bound, - const spv_opcode_table opcodeTable, - const spv_operand_table operandTable, - const spv_ext_inst_table extInstTable, - spv_position position, - spv_diagnostic* pDiagnostic) { - std::vector<spv_id_info_t> idUses; - std::vector<spv_id_info_t> idDefs; - - for (uint64_t instIndex = 0; instIndex < count; ++instIndex) { - const uint32_t* words = pInsts[instIndex].words.data(); - SpvOp opcode; - spvOpcodeSplit(words[0], nullptr, &opcode); - - spv_opcode_desc opcodeEntry = nullptr; - if (spvOpcodeTableValueLookup(opcodeTable, opcode, &opcodeEntry)) { - DIAGNOSTIC << "Invalid Opcode '" << opcode << "'."; - return SPV_ERROR_INVALID_BINARY; - } - - spv_operand_desc operandEntry = nullptr; - position->index++; // NOTE: Account for Opcode word - for (uint16_t index = 1; index < pInsts[instIndex].words.size(); - ++index, position->index++) { - const uint32_t word = words[index]; - - spv_operand_type_t type = spvBinaryOperandInfo( - word, index, opcodeEntry, operandTable, &operandEntry); - - if (SPV_OPERAND_TYPE_RESULT_ID == type || SPV_OPERAND_TYPE_ID == type) { - if (0 == word) { - DIAGNOSTIC << "Invalid ID of '0' is not allowed."; - return SPV_ERROR_INVALID_ID; - } - if (bound < word) { - DIAGNOSTIC << "Invalid ID '" << word << "' exceeds the bound '" - << bound << "'."; - return SPV_ERROR_INVALID_ID; - } - } - - if (SPV_OPERAND_TYPE_RESULT_ID == type) { - idDefs.push_back( - {word, opcodeEntry->opcode, &pInsts[instIndex], *position}); - } - - if (SPV_OPERAND_TYPE_ID == type) { - idUses.push_back({word, opcodeEntry->opcode, nullptr, *position}); - } - } - } - - // NOTE: Error on redefined ID - for (size_t outerIndex = 0; outerIndex < idDefs.size(); ++outerIndex) { - for (size_t innerIndex = 0; innerIndex < idDefs.size(); ++innerIndex) { - if (outerIndex == innerIndex) { - continue; - } - if (idDefs[outerIndex].id == idDefs[innerIndex].id) { - DIAGNOSTIC << "Multiply defined ID '" << idDefs[outerIndex].id << "'."; - return SPV_ERROR_INVALID_ID; - } - } +spv_result_t spvValidateIDs( + const spv_instruction_t* pInsts, const uint64_t count, + const spv_opcode_table opcodeTable, const spv_operand_table operandTable, + const spv_ext_inst_table extInstTable, const ValidationState_t& state, + spv_position position, spv_diagnostic* pDiagnostic) { + auto undefd = state.usedefs().FindUsesWithoutDefs(); + for (auto id : undefd) { + DIAGNOSTIC << "Undefined ID: " << id; } - - // NOTE: Validate ID usage, including use of undefined ID's position->index = SPV_INDEX_INSTRUCTION; - if (spvValidateInstructionIDs(pInsts, count, idUses.data(), idUses.size(), - idDefs.data(), idDefs.size(), opcodeTable, - operandTable, extInstTable, position, - pDiagnostic)) - return SPV_ERROR_INVALID_ID; - - return SPV_SUCCESS; + spvCheckReturn(spvValidateInstructionIDs(pInsts, count, opcodeTable, + operandTable, extInstTable, state, + position, pDiagnostic)); + return undefd.empty() ? SPV_SUCCESS : SPV_ERROR_INVALID_ID; } namespace { @@ -332,14 +270,28 @@ void DebugInstructionPass(ValidationState_t& _, } } +// Collects use-def info about an instruction's IDs. +void ProcessIds(ValidationState_t& _, const spv_parsed_instruction_t& inst) { + if (inst.result_id) { + _.usedefs().AddDef( + {inst.result_id, inst.opcode, + std::vector<uint32_t>(inst.words, inst.words + inst.num_words)}); + } + for (auto op = inst.operands; op != inst.operands + inst.num_operands; ++op) { + if (spvIsIdType(op->type)) + _.usedefs().AddUse(inst.words[op->offset]); + } +} + spv_result_t ProcessInstruction(void* user_data, - const spv_parsed_instruction_t* inst) { + const spv_parsed_instruction_t* inst) { ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data)); _.incrementInstructionCount(); + if (inst->opcode == SpvOpEntryPoint) _.entry_points().push_back(inst->words[2]); DebugInstructionPass(_, inst); - // TODO(umar): Perform data rules pass + ProcessIds(_, *inst); spvCheckReturn(ModuleLayoutPass(_, inst)); spvCheckReturn(CfgPass(_, inst)); spvCheckReturn(SsaPass(_, inst)); @@ -376,8 +328,8 @@ spv_result_t spvValidate(const spv_const_context context, ProcessInstruction, pDiagnostic)); // TODO(umar): Add validation checks which require the parsing of the entire - // module. Use the information from the processInstructions pass to make - // the checks. + // module. Use the information from the ProcessInstruction pass to make the + // checks. if (vstate.unresolvedForwardIdCount() > 0) { stringstream ss; @@ -408,10 +360,10 @@ spv_result_t spvValidate(const spv_const_context context, if (spvIsInBitfield(SPV_VALIDATE_ID_BIT, options)) { position.index = SPV_INDEX_INSTRUCTION; - spvCheckReturn( - spvValidateIDs(instructions.data(), instructions.size(), header.bound, - context->opcode_table, context->operand_table, - context->ext_inst_table, &position, pDiagnostic)); + spvCheckReturn(spvValidateIDs(instructions.data(), instructions.size(), + context->opcode_table, context->operand_table, + context->ext_inst_table, vstate, &position, + pDiagnostic)); } return SPV_SUCCESS; diff --git a/source/validate.h b/source/validate.h index 86483d0a..f2b5c64a 100644 --- a/source/validate.h +++ b/source/validate.h @@ -27,42 +27,290 @@ #ifndef LIBSPIRV_VALIDATE_H_ #define LIBSPIRV_VALIDATE_H_ -#include "instruction.h" +#include <algorithm> +#include <map> +#include <string> +#include <unordered_set> +#include <utility> +#include <vector> + #include "libspirv/libspirv.h" + +#include "binary.h" +#include "diagnostic.h" +#include "instruction.h" #include "table.h" // Structures +// Info about a result ID. typedef struct spv_id_info_t { + // Id value. uint32_t id; + // Opcode of the instruction defining the id. SpvOp opcode; - const spv_instruction_t* inst; - spv_position_t position; + // Binary words of the instruction defining the id. + std::vector<uint32_t> words; } spv_id_info_t; +namespace libspirv { + +// This enum represents the sections of a SPIRV module. See section 2.4 +// of the SPIRV spec for additional details of the order. The enumerant values +// are in the same order as the vector returned by GetModuleOrder +enum ModuleLayoutSection { + kLayoutCapabilities, // < Section 2.4 #1 + kLayoutExtensions, // < Section 2.4 #2 + kLayoutExtInstImport, // < Section 2.4 #3 + kLayoutMemoryModel, // < Section 2.4 #4 + kLayoutEntryPoint, // < Section 2.4 #5 + kLayoutExecutionMode, // < Section 2.4 #6 + kLayoutDebug1, // < Section 2.4 #7 > 1 + kLayoutDebug2, // < Section 2.4 #7 > 2 + kLayoutAnnotations, // < Section 2.4 #8 + kLayoutTypes, // < Section 2.4 #9 + kLayoutFunctionDeclarations, // < Section 2.4 #10 + kLayoutFunctionDefinitions // < Section 2.4 #11 +}; + +enum class FunctionDecl { + kFunctionDeclUnknown, // < Unknown function declaration + kFunctionDeclDeclaration, // < Function declaration + kFunctionDeclDefinition // < Function definition +}; + +class ValidationState_t; + +// This class manages all function declaration and definitions in a module. It +// handles the state and id information while parsing a function in the SPIR-V +// binary. +// +// NOTE: This class is designed to be a Structure of Arrays. Therefore each +// member variable is a vector whose elements represent the values for the +// corresponding function in a SPIR-V module. Variables that are not vector +// types are used to manage the state while parsing the function. +class Functions { + public: + Functions(ValidationState_t& module); + + // Registers the function in the module. Subsequent instructions will be + // called against this function + spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id, + uint32_t function_control, + uint32_t function_type_id); + + // Registers a function parameter in the current function + spv_result_t RegisterFunctionParameter(uint32_t id, uint32_t type_id); + + // Register a function end instruction + spv_result_t RegisterFunctionEnd(); + + // Sets the declaration type of the current function + spv_result_t RegisterSetFunctionDeclType(FunctionDecl type); + + // Registers a block in the current function. Subsequent block instructions + // will target this block + // @param id The ID of the label of the block + spv_result_t RegisterBlock(uint32_t id); + + // Registers a variable in the current block + spv_result_t RegisterBlockVariable(uint32_t type_id, uint32_t id, + SpvStorageClass storage, uint32_t init_id); + + spv_result_t RegisterBlockLoopMerge(uint32_t merge_id, uint32_t continue_id, + SpvLoopControlMask control); + + spv_result_t RegisterBlockSelectionMerge(uint32_t merge_id, + SpvSelectionControlMask control); + + // Registers the end of the block + spv_result_t RegisterBlockEnd(); + + // Returns the number of blocks in the current function being parsed + size_t get_block_count() const; + + // Retuns true if called after a function instruction but before the + // function end instruction + bool in_function_body() const; + + // Returns true if called after a label instruction but before a branch + // instruction + bool in_block() const; + + libspirv::DiagnosticStream diag(spv_result_t error_code) const; + + private: + // Parent module + ValidationState_t& module_; + + // Funciton IDs in a module + std::vector<uint32_t> id_; + + // OpTypeFunction IDs of each of the id_ functions + std::vector<uint32_t> type_id_; + + // The type of declaration of each function + std::vector<FunctionDecl> declaration_type_; + + // TODO(umar): Probably needs better abstractions + // The beginning of the block of functions + std::vector<std::vector<uint32_t>> block_ids_; + + // The variable IDs of the functions + std::vector<std::vector<uint32_t>> variable_ids_; + + // The function parameter ids of the functions + std::vector<std::vector<uint32_t>> parameter_ids_; + + // NOTE: See correspoding getter functions + bool in_function_; + bool in_block_; +}; + +class ValidationState_t { + public: + ValidationState_t(spv_diagnostic* diagnostic, uint32_t options); + + // Forward declares the id in the module + spv_result_t forwardDeclareId(uint32_t id); + + // Removes a forward declared ID if it has been defined + spv_result_t removeIfForwardDeclared(uint32_t id); + + // Assigns a name to an ID + void assignNameToId(uint32_t id, std::string name); + + // Returns a string representation of the ID in the format <id>[Name] where + // the <id> is the numeric valid of the id and the Name is a name assigned by + // the OpName instruction + std::string getIdName(uint32_t id) const; + + // Returns the number of ID which have been forward referenced but not defined + size_t unresolvedForwardIdCount() const; + + // Returns a list of unresolved forward ids. + std::vector<uint32_t> unresolvedForwardIds() const; + + // Returns true if the id has been defined + bool isDefinedId(uint32_t id) const; + + // Returns true if an spv_validate_options_t option is enabled in the + // validation instruction + bool is_enabled(spv_validate_options_t flag) const; + + // Increments the instruction count. Used for diagnostic + int incrementInstructionCount(); + + // Returns the current layout section which is being processed + ModuleLayoutSection getLayoutSection() const; + + // Increments the module_layout_order_section_ + void progressToNextLayoutSectionOrder(); + + // Determines if the op instruction is part of the current section + bool isOpcodeInCurrentLayoutSection(SpvOp op); + + libspirv::DiagnosticStream diag(spv_result_t error_code) const; + + // Returns the function states + Functions& get_functions(); + + // Retuns true if the called after a function instruction but before the + // function end instruction + bool in_function_body() const; + + // Returns true if called after a label instruction but before a branch + // instruction + bool in_block() const; + + // Keeps track of ID definitions and uses. + class UseDefTracker { + public: + void AddDef(const spv_id_info_t& def) { defs_.push_back(def); } + + void AddUse(uint32_t id) { uses_.insert(id); } + + // Finds id's def, if it exists. If found, returns <true, def>. Otherwise, + // returns <false, something>. + std::pair<bool, spv_id_info_t> FindDef(uint32_t id) const { + auto found = + std::find_if(defs_.cbegin(), defs_.cend(), + [id](const spv_id_info_t& e) { return e.id == id; }); + if (found == defs_.cend()) { + return std::make_pair(false, spv_id_info_t{}); + } else { + return std::make_pair(true, *found); + } + } + + // Returns uses of IDs lacking defs. + std::unordered_set<uint32_t> FindUsesWithoutDefs() const { + auto diff = uses_; + for (const auto d : defs_) diff.erase(d.id); + return diff; + } + + private: + std::unordered_set<uint32_t> uses_; + std::vector<spv_id_info_t> defs_; + }; + + UseDefTracker& usedefs() { return usedefs_; } + const UseDefTracker& usedefs() const { return usedefs_; } + + std::vector<uint32_t>& entry_points() { return entry_points_; } + const std::vector<uint32_t>& entry_points() const { return entry_points_; } + + private: + spv_diagnostic* diagnostic_; + // Tracks the number of instructions evaluated by the validator + int instruction_counter_; + + // IDs which have been forward declared but have not been defined + std::unordered_set<uint32_t> unresolved_forward_ids_; + + // Validation options to determine the passes to execute + uint32_t validation_flags_; + + std::map<uint32_t, std::string> operand_names_; + + // The section of the code being processed + ModuleLayoutSection current_layout_section_; + + Functions module_functions_; + + std::vector<SpvCapability> module_capabilities_; + + // Definitions and uses of all the IDs in the module. + UseDefTracker usedefs_; + + // IDs that are entry points, ie, arguments to OpEntryPoint. + std::vector<uint32_t> entry_points_; +}; + +} // namespace libspirv + // Functions /// @brief Validate the ID usage of the instruction stream /// /// @param[in] pInsts stream of instructions /// @param[in] instCount number of instructions -/// @param[in] pIdUses stream of ID uses -/// @param[in] idUsesCount number of ID uses -/// @param[in] pIdDefs stream of ID uses -/// @param[in] idDefsCount number of ID uses /// @param[in] opcodeTable table of specified Opcodes /// @param[in] operandTable table of specified operands +/// @param[in] usedefs use-def info from module parsing /// @param[in,out] position current position in the stream /// @param[out] pDiag contains diagnostic on failure /// /// @return result code -spv_result_t spvValidateInstructionIDs( - const spv_instruction_t* pInsts, const uint64_t instCount, - const spv_id_info_t* pIdUses, const uint64_t idUsesCount, - const spv_id_info_t* pIdDefs, const uint64_t idDefsCount, - const spv_opcode_table opcodeTable, const spv_operand_table operandTable, - const spv_ext_inst_table extInstTable, spv_position position, - spv_diagnostic* pDiag); +spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, + const uint64_t instCount, + const spv_opcode_table opcodeTable, + const spv_operand_table operandTable, + const spv_ext_inst_table extInstTable, + const libspirv::ValidationState_t& state, + spv_position position, + spv_diagnostic* pDiag); /// @brief Validate the ID's within a SPIR-V binary /// @@ -82,4 +330,7 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, const spv_ext_inst_table extInstTable, spv_position position, spv_diagnostic* pDiagnostic); +#define spvCheckReturn(expression) \ + if (spv_result_t error = (expression)) return error; + #endif // LIBSPIRV_VALIDATE_H_ diff --git a/source/validate_cfg.cpp b/source/validate_cfg.cpp index 6fa349e0..50dc2b61 100644 --- a/source/validate_cfg.cpp +++ b/source/validate_cfg.cpp @@ -25,7 +25,6 @@ // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. #include "validate_passes.h" -#include "validate_types.h" namespace libspirv { diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 31dd7a84..253ba83a 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -24,8 +24,7 @@ // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. -#include <assert.h> - +#include <cassert> #include <iostream> #include <unordered_map> #include <vector> @@ -41,15 +40,18 @@ action; \ } +using UseDefTracker = libspirv::ValidationState_t::UseDefTracker; + namespace { + class idUsage { public: idUsage(const spv_opcode_table opcodeTableArg, const spv_operand_table operandTableArg, - const spv_ext_inst_table extInstTableArg, const spv_id_info_t* pIdUses, - const uint64_t idUsesCount, const spv_id_info_t* pIdDefs, - const uint64_t idDefsCount, const spv_instruction_t* pInsts, - const uint64_t instCountArg, spv_position positionArg, + const spv_ext_inst_table extInstTableArg, + const spv_instruction_t* pInsts, const uint64_t instCountArg, + const UseDefTracker& usedefs, + const std::vector<uint32_t>& entry_points, spv_position positionArg, spv_diagnostic* pDiagnosticArg) : opcodeTable(opcodeTableArg), operandTable(operandTableArg), @@ -57,54 +59,15 @@ class idUsage { firstInst(pInsts), instCount(instCountArg), position(positionArg), - pDiagnostic(pDiagnosticArg) { - for (uint64_t idUsesIndex = 0; idUsesIndex < idUsesCount; ++idUsesIndex) { - idUses[pIdUses[idUsesIndex].id].push_back(pIdUses[idUsesIndex]); - } - for (uint64_t idDefsIndex = 0; idDefsIndex < idDefsCount; ++idDefsIndex) { - idDefs[pIdDefs[idDefsIndex].id] = pIdDefs[idDefsIndex]; - } - } + pDiagnostic(pDiagnosticArg), + usedefs_(usedefs), + entry_points_(entry_points) {} bool isValid(const spv_instruction_t* inst); template <SpvOp> bool isValid(const spv_instruction_t* inst, const spv_opcode_desc); - std::unordered_map<uint32_t, spv_id_info_t>::iterator find( - const uint32_t& id) { - return idDefs.find(id); - } - std::unordered_map<uint32_t, spv_id_info_t>::const_iterator find( - const uint32_t& id) const { - return idDefs.find(id); - } - - bool found(std::unordered_map<uint32_t, spv_id_info_t>::iterator item) { - return idDefs.end() != item; - } - bool found(std::unordered_map<uint32_t, spv_id_info_t>::const_iterator item) { - return idDefs.end() != item; - } - - std::unordered_map<uint32_t, std::vector<spv_id_info_t>>::iterator findUses( - const uint32_t& id) { - return idUses.find(id); - } - std::unordered_map<uint32_t, std::vector<spv_id_info_t>>::const_iterator - findUses(const uint32_t& id) const { - return idUses.find(id); - } - - bool foundUses( - std::unordered_map<uint32_t, std::vector<spv_id_info_t>>::iterator item) { - return idUses.end() != item; - } - bool foundUses(std::unordered_map< - uint32_t, std::vector<spv_id_info_t>>::const_iterator item) { - return idUses.end() != item; - } - private: const spv_opcode_table opcodeTable; const spv_operand_table operandTable; @@ -113,8 +76,8 @@ class idUsage { const uint64_t instCount; spv_position position; spv_diagnostic* pDiagnostic; - std::unordered_map<uint32_t, std::vector<spv_id_info_t>> idUses; - std::unordered_map<uint32_t, spv_id_info_t> idDefs; + UseDefTracker usedefs_; + std::vector<uint32_t> entry_points_; }; #define DIAG(INDEX) \ @@ -128,42 +91,26 @@ bool idUsage::isValid<SpvOpUndef>(const spv_instruction_t *inst, assert(0 && "Unimplemented!"); return false; } -#endif - -template <> -bool idUsage::isValid<SpvOpName>(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto targetIndex = 1; - auto target = find(inst->words[targetIndex]); - spvCheck(!found(target), DIAG(targetIndex) << "OpName Target <id> '" - << inst->words[targetIndex] - << "' is not defined."; - return false); - return true; -} +#endif // 0 template <> bool idUsage::isValid<SpvOpMemberName>(const spv_instruction_t* inst, const spv_opcode_desc) { auto typeIndex = 1; - auto type = find(inst->words[typeIndex]); - spvCheck(!found(type), DIAG(typeIndex) << "OpMemberName Type <id> '" - << inst->words[typeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeStruct != type->second.opcode, - DIAG(typeIndex) << "OpMemberName Type <id> '" - << inst->words[typeIndex] - << "' is not a struct type."; - return false); + auto type = usedefs_.FindDef(inst->words[typeIndex]); + if (!type.first || SpvOpTypeStruct != type.second.opcode) { + DIAG(typeIndex) << "OpMemberName Type <id> '" << inst->words[typeIndex] + << "' is not a struct type."; + return false; + } auto memberIndex = 2; auto member = inst->words[memberIndex]; - auto memberCount = (uint32_t)(type->second.inst->words.size() - 2); + auto memberCount = (uint32_t)(type.second.words.size() - 2); spvCheck(memberCount <= member, DIAG(memberIndex) << "OpMemberName Member <id> '" << inst->words[memberIndex] << "' index is larger than Type <id> '" - << type->second.id << "'s member count."; + << type.second.id << "'s member count."; return false); return true; } @@ -172,27 +119,12 @@ template <> bool idUsage::isValid<SpvOpLine>(const spv_instruction_t* inst, const spv_opcode_desc) { auto fileIndex = 1; - auto file = find(inst->words[fileIndex]); - spvCheck(!found(file), DIAG(fileIndex) << "OpLine Target <id> '" - << inst->words[fileIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpString != file->second.opcode, - DIAG(fileIndex) << "OpLine Target <id> '" << inst->words[fileIndex] - << "' is not an OpString."; - return false); - return true; -} - -template <> -bool idUsage::isValid<SpvOpDecorate>(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto targetIndex = 1; - auto target = find(inst->words[targetIndex]); - spvCheck(!found(target), DIAG(targetIndex) << "OpDecorate Target <id> '" - << inst->words[targetIndex] - << "' is not defined."; - return false); + auto file = usedefs_.FindDef(inst->words[fileIndex]); + if (!file.first || SpvOpString != file.second.opcode) { + DIAG(fileIndex) << "OpLine Target <id> '" << inst->words[fileIndex] + << "' is not an OpString."; + return false; + } return true; } @@ -200,20 +132,16 @@ template <> bool idUsage::isValid<SpvOpMemberDecorate>(const spv_instruction_t* inst, const spv_opcode_desc) { auto structTypeIndex = 1; - auto structType = find(inst->words[structTypeIndex]); - spvCheck(!found(structType), DIAG(structTypeIndex) - << "OpMemberDecorate Structure type <id> '" - << inst->words[structTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeStruct != structType->second.inst->opcode, - DIAG(structTypeIndex) << "OpMemberDecorate Structure type <id> '" - << inst->words[structTypeIndex] - << "' is not a struct type."; - return false); + auto structType = usedefs_.FindDef(inst->words[structTypeIndex]); + if (!structType.first || SpvOpTypeStruct != structType.second.opcode) { + DIAG(structTypeIndex) << "OpMemberDecorate Structure type <id> '" + << inst->words[structTypeIndex] + << "' is not a struct type."; + return false; + } auto memberIndex = 2; auto member = inst->words[memberIndex]; - auto memberCount = (uint32_t)(structType->second.inst->words.size() - 2); + auto memberCount = static_cast<uint32_t>(structType.second.words.size() - 2); spvCheck(memberCount < member, DIAG(memberIndex) << "OpMemberDecorate Structure type <id> '" << inst->words[memberIndex] @@ -226,26 +154,13 @@ template <> bool idUsage::isValid<SpvOpGroupDecorate>(const spv_instruction_t* inst, const spv_opcode_desc) { auto decorationGroupIndex = 1; - auto decorationGroup = find(inst->words[decorationGroupIndex]); - spvCheck(!found(decorationGroup), - DIAG(decorationGroupIndex) - << "OpGroupDecorate Decoration group <id> '" - << inst->words[decorationGroupIndex] << "' is not defined."; - return false); - spvCheck(SpvOpDecorationGroup != decorationGroup->second.opcode, - DIAG(decorationGroupIndex) - << "OpGroupDecorate Decoration group <id> '" - << inst->words[decorationGroupIndex] - << "' is not a decoration group."; - return false); - for (size_t targetIndex = 2; targetIndex < inst->words.size(); - ++targetIndex) { - auto target = find(inst->words[targetIndex]); - spvCheck(!found(target), DIAG(targetIndex) - << "OpGroupDecorate Target <id> '" - << inst->words[targetIndex] - << "' is not defined."; - return false); + auto decorationGroup = usedefs_.FindDef(inst->words[decorationGroupIndex]); + if (!decorationGroup.first || + SpvOpDecorationGroup != decorationGroup.second.opcode) { + DIAG(decorationGroupIndex) << "OpGroupDecorate Decoration group <id> '" + << inst->words[decorationGroupIndex] + << "' is not a decoration group."; + return false; } return true; } @@ -254,45 +169,41 @@ bool idUsage::isValid<SpvOpGroupDecorate>(const spv_instruction_t* inst, template <> bool idUsage::isValid<SpvOpGroupMemberDecorate>( const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif +#endif // 0 #if 0 template <> bool idUsage::isValid<SpvOpExtInst>(const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif +#endif // 0 template <> bool idUsage::isValid<SpvOpEntryPoint>(const spv_instruction_t* inst, const spv_opcode_desc) { auto entryPointIndex = 2; - auto entryPoint = find(inst->words[entryPointIndex]); - spvCheck(!found(entryPoint), DIAG(entryPointIndex) - << "OpEntryPoint Entry Point <id> '" - << inst->words[entryPointIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpFunction != entryPoint->second.opcode, - DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" - << inst->words[entryPointIndex] - << "' is not a function."; - return false); + auto entryPoint = usedefs_.FindDef(inst->words[entryPointIndex]); + if (!entryPoint.first || SpvOpFunction != entryPoint.second.opcode) { + DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" + << inst->words[entryPointIndex] + << "' is not a function."; + return false; + } // TODO: Check the entry point signature is void main(void), may be subject // to change - auto entryPointType = find(entryPoint->second.inst->words[4]); - spvCheck(!found(entryPointType), assert(0 && "Unreachable!")); - spvCheck(3 != entryPointType->second.inst->words.size(), - DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" - << inst->words[entryPointIndex] - << "'s function parameter count is not zero."; - return false); - auto returnType = find(entryPoint->second.inst->words[1]); - spvCheck(!found(returnType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypeVoid != returnType->second.opcode, - DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" - << inst->words[entryPointIndex] - << "'s function return type is not void."; - return false); + auto entryPointType = usedefs_.FindDef(entryPoint.second.words[4]); + if (!entryPointType.first || 3 != entryPointType.second.words.size()) { + DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" + << inst->words[entryPointIndex] + << "'s function parameter count is not zero."; + return false; + } + auto returnType = usedefs_.FindDef(entryPoint.second.words[1]); + if (!returnType.first || SpvOpTypeVoid != returnType.second.opcode) { + DIAG(entryPointIndex) << "OpEntryPoint Entry Point <id> '" + << inst->words[entryPointIndex] + << "'s function return type is not void."; + return false; + } return true; } @@ -300,26 +211,16 @@ template <> bool idUsage::isValid<SpvOpExecutionMode>(const spv_instruction_t* inst, const spv_opcode_desc) { auto entryPointIndex = 1; - auto entryPoint = find(inst->words[entryPointIndex]); - spvCheck(!found(entryPoint), DIAG(entryPointIndex) - << "OpExecutionMode Entry Point <id> '" - << inst->words[entryPointIndex] - << "' is not defined."; - return false); - auto entryPointUses = findUses(inst->words[entryPointIndex]); - spvCheck(!foundUses(entryPointUses), assert(0 && "Unreachable!")); - bool foundEntryPointUse = false; - for (auto use : entryPointUses->second) { - if (SpvOpEntryPoint == use.opcode) { - foundEntryPointUse = true; - } + auto entryPointID = inst->words[entryPointIndex]; + auto found = + std::find(entry_points_.cbegin(), entry_points_.cend(), entryPointID); + if (found == entry_points_.cend()) { + DIAG(entryPointIndex) << "OpExecutionMode Entry Point <id> '" + << inst->words[entryPointIndex] + << "' is not the Entry Point " + "operand of an OpEntryPoint."; + return false; } - spvCheck(!foundEntryPointUse, DIAG(entryPointIndex) - << "OpExecutionMode Entry Point <id> '" - << inst->words[entryPointIndex] - << "' is not the Entry Point " - "operand of an OpEntryPoint."; - return false); return true; } @@ -327,17 +228,14 @@ template <> bool idUsage::isValid<SpvOpTypeVector>(const spv_instruction_t* inst, const spv_opcode_desc) { auto componentIndex = 2; - auto componentType = find(inst->words[componentIndex]); - spvCheck(!found(componentType), DIAG(componentIndex) - << "OpTypeVector Component Type <id> '" - << inst->words[componentIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsScalarType(componentType->second.opcode), - DIAG(componentIndex) << "OpTypeVector Component Type <id> '" - << inst->words[componentIndex] - << "' is not a scalar type."; - return false); + auto componentType = usedefs_.FindDef(inst->words[componentIndex]); + if (!componentType.first || + !spvOpcodeIsScalarType(componentType.second.opcode)) { + DIAG(componentIndex) << "OpTypeVector Component Type <id> '" + << inst->words[componentIndex] + << "' is not a scalar type."; + return false; + } return true; } @@ -345,17 +243,13 @@ template <> bool idUsage::isValid<SpvOpTypeMatrix>(const spv_instruction_t* inst, const spv_opcode_desc) { auto columnTypeIndex = 2; - auto columnType = find(inst->words[columnTypeIndex]); - spvCheck(!found(columnType), DIAG(columnTypeIndex) - << "OpTypeMatrix Column Type <id> '" - << inst->words[columnTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeVector != columnType->second.opcode, - DIAG(columnTypeIndex) << "OpTypeMatrix Column Type <id> '" - << inst->words[columnTypeIndex] - << "' is not a vector."; - return false); + auto columnType = usedefs_.FindDef(inst->words[columnTypeIndex]); + if (!columnType.first || SpvOpTypeVector != columnType.second.opcode) { + DIAG(columnTypeIndex) << "OpTypeMatrix Column Type <id> '" + << inst->words[columnTypeIndex] + << "' is not a vector."; + return false; + } return true; } @@ -370,54 +264,41 @@ template <> bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst, const spv_opcode_desc) { auto elementTypeIndex = 2; - auto elementType = find(inst->words[elementTypeIndex]); - spvCheck(!found(elementType), DIAG(elementTypeIndex) - << "OpTypeArray Element Type <id> '" - << inst->words[elementTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeGeneratesType(elementType->second.opcode), - DIAG(elementTypeIndex) << "OpTypeArray Element Type <id> '" - << inst->words[elementTypeIndex] - << "' is not a type."; - return false); + auto elementType = usedefs_.FindDef(inst->words[elementTypeIndex]); + if (!elementType.first || + !spvOpcodeGeneratesType(elementType.second.opcode)) { + DIAG(elementTypeIndex) << "OpTypeArray Element Type <id> '" + << inst->words[elementTypeIndex] + << "' is not a type."; + return false; + } auto lengthIndex = 3; - auto length = find(inst->words[lengthIndex]); - spvCheck(!found(length), DIAG(lengthIndex) << "OpTypeArray Length <id> '" - << inst->words[lengthIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpConstant != length->second.opcode && - SpvOpSpecConstant != length->second.opcode, - DIAG(lengthIndex) << "OpTypeArray Length <id> '" - << inst->words[lengthIndex] - << "' is not a scalar constant type."; - return false); + auto length = usedefs_.FindDef(inst->words[lengthIndex]); + if (!length.first || (SpvOpConstant != length.second.opcode && + SpvOpSpecConstant != length.second.opcode)) { + DIAG(lengthIndex) << "OpTypeArray Length <id> '" << inst->words[lengthIndex] + << "' is not a scalar constant type."; + return false; + } // NOTE: Check the initialiser value of the constant - auto constInst = length->second.inst; + auto constInst = length.second.words; auto constResultTypeIndex = 1; - auto constResultType = find(constInst->words[constResultTypeIndex]); - spvCheck(!found(constResultType), DIAG(lengthIndex) - << "OpTypeArray Length <id> '" - << inst->words[constResultTypeIndex] - << "' result type is not defined."; - return false); - spvCheck(SpvOpTypeInt != constResultType->second.opcode, - DIAG(lengthIndex) << "OpTypeArray Length <id> '" - << inst->words[lengthIndex] - << "' is not a constant integer type."; - return false); - if (4 == constInst->words.size()) { - spvCheck(1 > constInst->words[3], DIAG(lengthIndex) - << "OpTypeArray Length <id> '" - << inst->words[lengthIndex] - << "' value must be at least 1."; + auto constResultType = usedefs_.FindDef(constInst[constResultTypeIndex]); + if (!constResultType.first || SpvOpTypeInt != constResultType.second.opcode) { + DIAG(lengthIndex) << "OpTypeArray Length <id> '" << inst->words[lengthIndex] + << "' is not a constant integer type."; + return false; + } + if (4 == constInst.size()) { + spvCheck(1 > constInst[3], DIAG(lengthIndex) + << "OpTypeArray Length <id> '" + << inst->words[lengthIndex] + << "' value must be at least 1."; return false); - } else if (5 == constInst->words.size()) { - uint64_t value = - constInst->words[3] | ((uint64_t)constInst->words[4]) << 32; - bool signedness = constResultType->second.inst->words[3] != 0; + } else if (5 == constInst.size()) { + uint64_t value = constInst[3] | ((uint64_t)constInst[4]) << 32; + bool signedness = constResultType.second.words[3] != 0; if (signedness) { spvCheck(1 > (int64_t)value, DIAG(lengthIndex) << "OpTypeArray Length <id> '" @@ -438,17 +319,14 @@ template <> bool idUsage::isValid<SpvOpTypeRuntimeArray>(const spv_instruction_t* inst, const spv_opcode_desc) { auto elementTypeIndex = 2; - auto elementType = find(inst->words[elementTypeIndex]); - spvCheck(!found(elementType), DIAG(elementTypeIndex) - << "OpTypeRuntimeArray Element Type <id> '" - << inst->words[elementTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeGeneratesType(elementType->second.opcode), - DIAG(elementTypeIndex) << "OpTypeRuntimeArray Element Type <id> '" - << inst->words[elementTypeIndex] - << "' is not a type."; - return false); + auto elementType = usedefs_.FindDef(inst->words[elementTypeIndex]); + if (!elementType.first || + !spvOpcodeGeneratesType(elementType.second.opcode)) { + DIAG(elementTypeIndex) << "OpTypeRuntimeArray Element Type <id> '" + << inst->words[elementTypeIndex] + << "' is not a type."; + return false; + } return true; } @@ -457,17 +335,14 @@ bool idUsage::isValid<SpvOpTypeStruct>(const spv_instruction_t* inst, const spv_opcode_desc) { for (size_t memberTypeIndex = 2; memberTypeIndex < inst->words.size(); ++memberTypeIndex) { - auto memberType = find(inst->words[memberTypeIndex]); - spvCheck(!found(memberType), DIAG(memberTypeIndex) - << "OpTypeStruct Member Type <id> '" - << inst->words[memberTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeGeneratesType(memberType->second.opcode), - DIAG(memberTypeIndex) << "OpTypeStruct Member Type <id> '" - << inst->words[memberTypeIndex] - << "' is not a type."; - return false); + auto memberType = usedefs_.FindDef(inst->words[memberTypeIndex]); + if (!memberType.first || + !spvOpcodeGeneratesType(memberType.second.opcode)) { + DIAG(memberTypeIndex) << "OpTypeStruct Member Type <id> '" + << inst->words[memberTypeIndex] + << "' is not a type."; + return false; + } } return true; } @@ -476,15 +351,12 @@ template <> bool idUsage::isValid<SpvOpTypePointer>(const spv_instruction_t* inst, const spv_opcode_desc) { auto typeIndex = 3; - auto type = find(inst->words[typeIndex]); - spvCheck(!found(type), DIAG(typeIndex) << "OpTypePointer Type <id> '" - << inst->words[typeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeGeneratesType(type->second.opcode), - DIAG(typeIndex) << "OpTypePointer Type <id> '" - << inst->words[typeIndex] << "' is not a type."; - return false); + auto type = usedefs_.FindDef(inst->words[typeIndex]); + if (!type.first || !spvOpcodeGeneratesType(type.second.opcode)) { + DIAG(typeIndex) << "OpTypePointer Type <id> '" << inst->words[typeIndex] + << "' is not a type."; + return false; + } return true; } @@ -492,30 +364,20 @@ template <> bool idUsage::isValid<SpvOpTypeFunction>(const spv_instruction_t* inst, const spv_opcode_desc) { auto returnTypeIndex = 2; - auto returnType = find(inst->words[returnTypeIndex]); - spvCheck(!found(returnType), DIAG(returnTypeIndex) - << "OpTypeFunction Return Type <id> '" - << inst->words[returnTypeIndex] - << "' is not defined"; - return false); - spvCheck(!spvOpcodeGeneratesType(returnType->second.opcode), - DIAG(returnTypeIndex) << "OpTypeFunction Return Type <id> '" - << inst->words[returnTypeIndex] - << "' is not a type."; - return false); + auto returnType = usedefs_.FindDef(inst->words[returnTypeIndex]); + if (!returnType.first || !spvOpcodeGeneratesType(returnType.second.opcode)) { + DIAG(returnTypeIndex) << "OpTypeFunction Return Type <id> '" + << inst->words[returnTypeIndex] << "' is not a type."; + return false; + } for (size_t paramTypeIndex = 3; paramTypeIndex < inst->words.size(); ++paramTypeIndex) { - auto paramType = find(inst->words[paramTypeIndex]); - spvCheck(!found(paramType), DIAG(paramTypeIndex) - << "OpTypeFunction Parameter Type <id> '" - << inst->words[paramTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeGeneratesType(paramType->second.opcode), - DIAG(paramTypeIndex) << "OpTypeFunction Parameter Type <id> '" - << inst->words[paramTypeIndex] - << "' is not a type."; - return false); + auto paramType = usedefs_.FindDef(inst->words[paramTypeIndex]); + if (!paramType.first || !spvOpcodeGeneratesType(paramType.second.opcode)) { + DIAG(paramTypeIndex) << "OpTypeFunction Parameter Type <id> '" + << inst->words[paramTypeIndex] << "' is not a type."; + return false; + } } return true; } @@ -531,17 +393,13 @@ template <> bool idUsage::isValid<SpvOpConstantTrue>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstantTrue Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeBool != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpConstantTrue Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypeBool != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpConstantTrue Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a boolean type."; + return false; + } return true; } @@ -549,17 +407,13 @@ template <> bool idUsage::isValid<SpvOpConstantFalse>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstantFalse Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeBool != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpConstantFalse Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypeBool != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpConstantFalse Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a boolean type."; + return false; + } return true; } @@ -567,18 +421,13 @@ template <> bool idUsage::isValid<SpvOpConstant>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstant Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsScalarType(resultType->second.opcode), - DIAG(resultTypeIndex) - << "OpConstant Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a scalar integer or floating point type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || !spvOpcodeIsScalarType(resultType.second.opcode)) { + DIAG(resultTypeIndex) + << "OpConstant Result Type <id> '" << inst->words[resultTypeIndex] + << "' is not a scalar integer or floating point type."; + return false; + } return true; } @@ -586,185 +435,172 @@ template <> bool idUsage::isValid<SpvOpConstantComposite>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstantComposite Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsComposite(resultType->second.opcode), - DIAG(resultTypeIndex) << "OpConstantComposite Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a composite type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || !spvOpcodeIsComposite(resultType.second.opcode)) { + DIAG(resultTypeIndex) << "OpConstantComposite Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a composite type."; + return false; + } auto constituentCount = inst->words.size() - 3; - switch (resultType->second.opcode) { + switch (resultType.second.opcode) { case SpvOpTypeVector: { - auto componentCount = resultType->second.inst->words[3]; + auto componentCount = resultType.second.words[3]; spvCheck( componentCount != constituentCount, // TODO: Output ID's on diagnostic DIAG(inst->words.size() - 1) << "OpConstantComposite Constituent <id> count does not match " "Result Type <id> '" - << resultType->second.id << "'s vector component count."; + << resultType.second.id << "'s vector component count."; return false); - auto componentType = find(resultType->second.inst->words[2]); - spvCheck(!found(componentType), assert(0 && "Unreachable!")); + auto componentType = usedefs_.FindDef(resultType.second.words[2]); + assert(componentType.first); for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); constituentIndex++) { - auto constituent = find(inst->words[constituentIndex]); - spvCheck(!found(constituent), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeIsConstant(constituent->second.opcode), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not a constant."; - return false); - auto constituentResultType = find(constituent->second.inst->words[1]); - spvCheck(!found(constituentResultType), assert(0 && "Unreachable!")); - spvCheck(componentType->second.opcode != - constituentResultType->second.opcode, - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] - << "'s type does not match Result Type <id> '" - << resultType->second.id << "'s vector element type."; - return false); + auto constituent = usedefs_.FindDef(inst->words[constituentIndex]); + if (!constituent.first || + !spvOpcodeIsConstant(constituent.second.opcode)) { + DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" + << inst->words[constituentIndex] + << "' is not a constant."; + return false; + } + auto constituentResultType = + usedefs_.FindDef(constituent.second.words[1]); + if (!constituentResultType.first || + componentType.second.opcode != + constituentResultType.second.opcode) { + DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" + << inst->words[constituentIndex] + << "'s type does not match Result Type <id> '" + << resultType.second.id + << "'s vector element type."; + return false; + } } } break; case SpvOpTypeMatrix: { - auto columnCount = resultType->second.inst->words[3]; + auto columnCount = resultType.second.words[3]; spvCheck( columnCount != constituentCount, // TODO: Output ID's on diagnostic DIAG(inst->words.size() - 1) << "OpConstantComposite Constituent <id> count does not match " "Result Type <id> '" - << resultType->second.id << "'s matrix column count."; + << resultType.second.id << "'s matrix column count."; return false); - auto columnType = find(resultType->second.inst->words[2]); - spvCheck(!found(columnType), assert(0 && "Unreachable!")); - auto componentCount = columnType->second.inst->words[3]; - auto componentType = find(columnType->second.inst->words[2]); - spvCheck(!found(componentType), assert(0 && "Unreachable!")); + auto columnType = usedefs_.FindDef(resultType.second.words[2]); + assert(columnType.first); + auto componentCount = columnType.second.words[3]; + auto componentType = usedefs_.FindDef(columnType.second.words[2]); + assert(componentType.first); for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); constituentIndex++) { - auto constituent = find(inst->words[constituentIndex]); - spvCheck(!found(constituent), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not defined."; - return false); - spvCheck(SpvOpConstantComposite != constituent->second.opcode, - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] - << "' is not a constant composite."; - return false); - auto vector = find(constituent->second.inst->words[1]); - spvCheck(!found(vector), assert(0 && "Unreachable!")); - spvCheck(columnType->second.opcode != vector->second.opcode, + auto constituent = usedefs_.FindDef(inst->words[constituentIndex]); + if (!constituent.first || + SpvOpConstantComposite != constituent.second.opcode) { + DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" + << inst->words[constituentIndex] + << "' is not a constant composite."; + return false; + } + auto vector = usedefs_.FindDef(constituent.second.words[1]); + assert(vector.first); + spvCheck(columnType.second.opcode != vector.second.opcode, DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" << inst->words[constituentIndex] << "' type does not match Result Type <id> '" - << resultType->second.id << "'s matrix column type."; + << resultType.second.id << "'s matrix column type."; return false); - auto vectorComponentType = find(vector->second.inst->words[2]); - spvCheck(!found(vectorComponentType), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeAreTypesEqual(componentType->second.inst, - vectorComponentType->second.inst), + auto vectorComponentType = usedefs_.FindDef(vector.second.words[2]); + assert(vectorComponentType.first); + spvCheck(componentType.second.id != vectorComponentType.second.id, DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" << inst->words[constituentIndex] << "' component type does not match Result Type <id> '" - << resultType->second.id + << resultType.second.id << "'s matrix column component type."; return false); spvCheck( - componentCount != vector->second.inst->words[3], + componentCount != vector.second.words[3], DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" << inst->words[constituentIndex] << "' vector component count does not match Result Type <id> '" - << resultType->second.id << "'s vector component count."; + << resultType.second.id << "'s vector component count."; return false); } } break; case SpvOpTypeArray: { - auto elementType = find(resultType->second.inst->words[2]); - spvCheck(!found(elementType), assert(0 && "Unreachable!")); - auto length = find(resultType->second.inst->words[3]); - spvCheck(!found(length), assert(0 && "Unreachable!")); - spvCheck(length->second.inst->words[3] != constituentCount, + auto elementType = usedefs_.FindDef(resultType.second.words[2]); + assert(elementType.first); + auto length = usedefs_.FindDef(resultType.second.words[3]); + assert(length.first); + spvCheck(length.second.words[3] != constituentCount, DIAG(inst->words.size() - 1) << "OpConstantComposite Constituent count does not match " "Result Type <id> '" - << resultType->second.id << "'s array length."; + << resultType.second.id << "'s array length."; return false); for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); constituentIndex++) { - auto constituent = find(inst->words[constituentIndex]); - spvCheck(!found(constituent), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsConstant(constituent->second.opcode), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not a constant."; - return false); - auto constituentType = find(constituent->second.inst->words[1]); - spvCheck(!found(constituentType), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeAreTypesEqual(elementType->second.inst, - constituentType->second.inst), + auto constituent = usedefs_.FindDef(inst->words[constituentIndex]); + if (!constituent.first || + !spvOpcodeIsConstant(constituent.second.opcode)) { + DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" + << inst->words[constituentIndex] + << "' is not a constant."; + return false; + } + auto constituentType = usedefs_.FindDef(constituent.second.words[1]); + assert(constituentType.first); + spvCheck(elementType.second.id != constituentType.second.id, DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" << inst->words[constituentIndex] << "'s type does not match Result Type <id> '" - << resultType->second.id << "'s array element type."; + << resultType.second.id << "'s array element type."; return false); } } break; case SpvOpTypeStruct: { - auto memberCount = resultType->second.inst->words.size() - 2; + auto memberCount = resultType.second.words.size() - 2; spvCheck(memberCount != constituentCount, DIAG(resultTypeIndex) << "OpConstantComposite Constituent <id> '" << inst->words[resultTypeIndex] << "' count does not match Result Type <id> '" - << resultType->second.id << "'s struct member count."; + << resultType.second.id << "'s struct member count."; return false); for (uint32_t constituentIndex = 3, memberIndex = 2; constituentIndex < inst->words.size(); constituentIndex++, memberIndex++) { - auto constituent = find(inst->words[constituentIndex]); - spvCheck(!found(constituent), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not define."; - return false); - spvCheck(!spvOpcodeIsConstant(constituent->second.opcode), - DIAG(constituentIndex) - << "OpConstantComposite Constituent <id> '" - << inst->words[constituentIndex] << "' is not a constant."; - return false); - auto constituentType = find(constituent->second.inst->words[1]); - spvCheck(!found(constituentType), assert(0 && "Unreachable!")); - - auto memberType = find(resultType->second.inst->words[memberIndex]); - spvCheck(!found(memberType), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeAreTypesEqual(memberType->second.inst, - constituentType->second.inst), + auto constituent = usedefs_.FindDef(inst->words[constituentIndex]); + if (!constituent.first || + !spvOpcodeIsConstant(constituent.second.opcode)) { + DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" + << inst->words[constituentIndex] + << "' is not a constant."; + return false; + } + auto constituentType = usedefs_.FindDef(constituent.second.words[1]); + assert(constituentType.first); + + auto memberType = + usedefs_.FindDef(resultType.second.words[memberIndex]); + assert(memberType.first); + spvCheck(memberType.second.id != constituentType.second.id, DIAG(constituentIndex) << "OpConstantComposite Constituent <id> '" << inst->words[constituentIndex] << "' type does not match the Result Type <id> '" - << resultType->second.id << "'s member type."; + << resultType.second.id << "'s member type."; return false); } } break; @@ -777,17 +613,13 @@ template <> bool idUsage::isValid<SpvOpConstantSampler>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstantSampler Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeSampler != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpConstantSampler Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a sampler type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypeSampler != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpConstantSampler Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a sampler type."; + return false; + } return true; } @@ -795,46 +627,41 @@ template <> bool idUsage::isValid<SpvOpConstantNull>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpConstantNull Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - switch (resultType->second.inst->opcode) { + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first) return false; + switch (resultType.second.opcode) { default: { - spvCheck(!spvOpcodeIsBasicTypeNullable(resultType->second.inst->opcode), + spvCheck(!spvOpcodeIsBasicTypeNullable(resultType.second.opcode), DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '" << inst->words[resultTypeIndex] - << "' can not be null."; + << "' cannot be null."; return false); } break; case SpvOpTypeVector: { - auto type = find(resultType->second.inst->words[2]); - spvCheck(!found(type), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeIsBasicTypeNullable(type->second.inst->opcode), + auto type = usedefs_.FindDef(resultType.second.words[2]); + assert(type.first); + spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '" << inst->words[resultTypeIndex] - << "'s vector component type can not be null."; + << "'s vector component type cannot be null."; return false); } break; case SpvOpTypeArray: { - auto type = find(resultType->second.inst->words[2]); - spvCheck(!found(type), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeIsBasicTypeNullable(type->second.inst->opcode), - DIAG(resultTypeIndex) - << "OpConstantNull Result Type <id> '" - << inst->words[resultTypeIndex] - << "'s array element type can not be null."; + auto type = usedefs_.FindDef(resultType.second.words[2]); + assert(type.first); + spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), + DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '" + << inst->words[resultTypeIndex] + << "'s array element type cannot be null."; return false); } break; case SpvOpTypeMatrix: { - auto columnType = find(resultType->second.inst->words[2]); - spvCheck(!found(columnType), assert(0 && "Unreachable!")); - auto type = find(columnType->second.inst->words[2]); - spvCheck(!found(type), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeIsBasicTypeNullable(type->second.inst->opcode), + auto columnType = usedefs_.FindDef(resultType.second.words[2]); + assert(columnType.first); + auto type = usedefs_.FindDef(columnType.second.words[2]); + assert(type.first); + spvCheck(!spvOpcodeIsBasicTypeNullable(type.second.opcode), DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '" << inst->words[resultTypeIndex] @@ -843,15 +670,14 @@ bool idUsage::isValid<SpvOpConstantNull>(const spv_instruction_t* inst, } break; case SpvOpTypeStruct: { for (size_t elementIndex = 2; - elementIndex < resultType->second.inst->words.size(); - ++elementIndex) { - auto element = find(resultType->second.inst->words[elementIndex]); - spvCheck(!found(element), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeIsBasicTypeNullable(element->second.inst->opcode), + elementIndex < resultType.second.words.size(); ++elementIndex) { + auto element = usedefs_.FindDef(resultType.second.words[elementIndex]); + assert(element.first); + spvCheck(!spvOpcodeIsBasicTypeNullable(element.second.opcode), DIAG(resultTypeIndex) << "OpConstantNull Result Type <id> '" << inst->words[resultTypeIndex] - << "'s struct element type can not be null."; + << "'s struct element type cannot be null."; return false); } } break; @@ -863,17 +689,13 @@ template <> bool idUsage::isValid<SpvOpSpecConstantTrue>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpSpecConstantTrue Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeBool != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpSpecConstantTrue Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypeBool != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpSpecConstantTrue Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a boolean type."; + return false; + } return true; } @@ -881,17 +703,13 @@ template <> bool idUsage::isValid<SpvOpSpecConstantFalse>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpSpecConstantFalse Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeBool != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpSpecConstantFalse Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypeBool != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpSpecConstantFalse Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a boolean type."; + return false; + } return true; } @@ -899,17 +717,13 @@ template <> bool idUsage::isValid<SpvOpSpecConstant>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpSpecConstant Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsScalarType(resultType->second.opcode), - DIAG(resultTypeIndex) << "OpSpecConstant Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a scalar type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || !spvOpcodeIsScalarType(resultType.second.opcode)) { + DIAG(resultTypeIndex) << "OpSpecConstant Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a scalar type."; + return false; + } return true; } @@ -928,30 +742,22 @@ template <> bool idUsage::isValid<SpvOpVariable>(const spv_instruction_t* inst, const spv_opcode_desc opcodeEntry) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpVariable Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypePointer != resultType->second.opcode, - DIAG(resultTypeIndex) << "OpVariable Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not a pointer type."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first || SpvOpTypePointer != resultType.second.opcode) { + DIAG(resultTypeIndex) << "OpVariable Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not a pointer type."; + return false; + } if (opcodeEntry->numTypes < inst->words.size()) { auto initialiserIndex = 4; - auto initialiser = find(inst->words[initialiserIndex]); - spvCheck(!found(initialiser), DIAG(initialiserIndex) - << "OpVariable Initializer <id> '" - << inst->words[initialiserIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsConstant(initialiser->second.opcode), - DIAG(initialiserIndex) << "OpVariable Initializer <id> '" - << inst->words[initialiserIndex] - << "' is not a constant."; - return false); + auto initialiser = usedefs_.FindDef(inst->words[initialiserIndex]); + if (!initialiser.first || !spvOpcodeIsConstant(initialiser.second.opcode)) { + DIAG(initialiserIndex) << "OpVariable Initializer <id> '" + << inst->words[initialiserIndex] + << "' is not a constant."; + return false; + } } return true; } @@ -960,30 +766,26 @@ template <> bool idUsage::isValid<SpvOpLoad>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpLoad Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defind."; + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + spvCheck(!resultType.first, DIAG(resultTypeIndex) + << "OpLoad Result Type <id> '" + << inst->words[resultTypeIndex] + << "' is not defind."; return false); auto pointerIndex = 3; - auto pointer = find(inst->words[pointerIndex]); - spvCheck(!found(pointer), DIAG(pointerIndex) << "OpLoad Pointer <id> '" - << inst->words[pointerIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsPointer(pointer->second.opcode), - DIAG(pointerIndex) << "OpLoad Pointer <id> '" - << inst->words[pointerIndex] - << "' is not a pointer."; - return false); - auto type = find(pointer->second.inst->words[1]); - spvCheck(!found(type), assert(0 && "Unreachable!")); - spvCheck(resultType != type, DIAG(resultTypeIndex) - << "OpLoad Result Type <id> '" - << inst->words[resultTypeIndex] - << " does not match Pointer <id> '" - << pointer->second.id << "'s type."; + auto pointer = usedefs_.FindDef(inst->words[pointerIndex]); + if (!pointer.first || !spvOpcodeIsPointer(pointer.second.opcode)) { + DIAG(pointerIndex) << "OpLoad Pointer <id> '" << inst->words[pointerIndex] + << "' is not a pointer."; + return false; + } + auto type = usedefs_.FindDef(pointer.second.words[1]); + assert(type.first); + spvCheck(resultType.second.id != type.second.id, + DIAG(resultTypeIndex) + << "OpLoad Result Type <id> '" << inst->words[resultTypeIndex] + << " does not match Pointer <id> '" << pointer.second.id + << "'s type."; return false); return true; } @@ -992,49 +794,41 @@ template <> bool idUsage::isValid<SpvOpStore>(const spv_instruction_t* inst, const spv_opcode_desc) { auto pointerIndex = 1; - auto pointer = find(inst->words[pointerIndex]); - spvCheck(!found(pointer), DIAG(pointerIndex) << "OpStore Pointer <id> '" - << inst->words[pointerIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsPointer(pointer->second.opcode), - DIAG(pointerIndex) << "OpStore Pointer <id> '" - << inst->words[pointerIndex] - << "' is not a pointer."; - return false); - auto pointerType = find(pointer->second.inst->words[1]); - spvCheck(!found(pointerType), assert(0 && "Unreachable!")); - auto type = find(pointerType->second.inst->words[3]); - spvCheck(!found(type), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypeVoid == type->second.opcode, - DIAG(pointerIndex) << "OpStore Pointer <id> '" - << inst->words[pointerIndex] - << "'s type is void."; + auto pointer = usedefs_.FindDef(inst->words[pointerIndex]); + if (!pointer.first || !spvOpcodeIsPointer(pointer.second.opcode)) { + DIAG(pointerIndex) << "OpStore Pointer <id> '" << inst->words[pointerIndex] + << "' is not a pointer."; + return false; + } + auto pointerType = usedefs_.FindDef(pointer.second.words[1]); + assert(pointerType.first); + auto type = usedefs_.FindDef(pointerType.second.words[3]); + assert(type.first); + spvCheck(SpvOpTypeVoid == type.second.opcode, DIAG(pointerIndex) + << "OpStore Pointer <id> '" + << inst->words[pointerIndex] + << "'s type is void."; return false); auto objectIndex = 2; - auto object = find(inst->words[objectIndex]); - spvCheck(!found(object), DIAG(objectIndex) << "OpStore Object <id> '" - << inst->words[objectIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsObject(object->second.opcode), - DIAG(objectIndex) << "OpStore Object <id> '" - << inst->words[objectIndex] - << "' in not an object."; - return false); - auto objectType = find(object->second.inst->words[1]); - spvCheck(!found(objectType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypeVoid == objectType->second.opcode, + auto object = usedefs_.FindDef(inst->words[objectIndex]); + if (!object.first || !spvOpcodeIsObject(object.second.opcode)) { + DIAG(objectIndex) << "OpStore Object <id> '" << inst->words[objectIndex] + << "' in not an object."; + return false; + } + auto objectType = usedefs_.FindDef(object.second.words[1]); + assert(objectType.first); + spvCheck(SpvOpTypeVoid == objectType.second.opcode, DIAG(objectIndex) << "OpStore Object <id> '" << inst->words[objectIndex] << "'s type is void."; return false); - spvCheck(!spvOpcodeAreTypesEqual(type->second.inst, objectType->second.inst), - DIAG(pointerIndex) << "OpStore Pointer <id> '" - << inst->words[pointerIndex] - << "'s type does not match Object <id> '" - << objectType->second.id << "'s type."; + spvCheck(type.second.id != objectType.second.id, + DIAG(pointerIndex) + << "OpStore Pointer <id> '" << inst->words[pointerIndex] + << "'s type does not match Object <id> '" << objectType.second.id + << "'s type."; return false); return true; } @@ -1043,32 +837,25 @@ template <> bool idUsage::isValid<SpvOpCopyMemory>(const spv_instruction_t* inst, const spv_opcode_desc) { auto targetIndex = 1; - auto target = find(inst->words[targetIndex]); - spvCheck(!found(target), DIAG(targetIndex) << "OpCopyMemory Target <id> '" - << inst->words[targetIndex] - << "' is not defined."; - return false); + auto target = usedefs_.FindDef(inst->words[targetIndex]); + if (!target.first) return false; auto sourceIndex = 2; - auto source = find(inst->words[sourceIndex]); - spvCheck(!found(source), DIAG(targetIndex) << "OpCopyMemory Source <id> '" - << inst->words[targetIndex] - << "' is not defined."; + auto source = usedefs_.FindDef(inst->words[sourceIndex]); + if (!source.first) return false; + auto targetPointerType = usedefs_.FindDef(target.second.words[1]); + assert(targetPointerType.first); + auto targetType = usedefs_.FindDef(targetPointerType.second.words[3]); + assert(targetType.first); + auto sourcePointerType = usedefs_.FindDef(source.second.words[1]); + assert(sourcePointerType.first); + auto sourceType = usedefs_.FindDef(sourcePointerType.second.words[3]); + assert(sourceType.first); + spvCheck(targetType.second.id != sourceType.second.id, + DIAG(sourceIndex) + << "OpCopyMemory Target <id> '" << inst->words[sourceIndex] + << "'s type does not match Source <id> '" << sourceType.second.id + << "'s type."; return false); - auto targetPointerType = find(target->second.inst->words[1]); - spvCheck(!found(targetPointerType), assert(0 && "Unreachable!")); - auto targetType = find(targetPointerType->second.inst->words[3]); - spvCheck(!found(targetType), assert(0 && "Unreachable!")); - auto sourcePointerType = find(source->second.inst->words[1]); - spvCheck(!found(sourcePointerType), assert(0 && "Unreachable!")); - auto sourceType = find(sourcePointerType->second.inst->words[3]); - spvCheck(!found(sourceType), assert(0 && "Unreachable!")); - spvCheck( - !spvOpcodeAreTypesEqual(targetType->second.inst, sourceType->second.inst), - DIAG(sourceIndex) << "OpCopyMemory Target <id> '" - << inst->words[sourceIndex] - << "'s type does not match Source <id> '" - << sourceType->second.id << "'s type."; - return false); return true; } @@ -1076,57 +863,48 @@ template <> bool idUsage::isValid<SpvOpCopyMemorySized>(const spv_instruction_t* inst, const spv_opcode_desc) { auto targetIndex = 1; - auto target = find(inst->words[targetIndex]); - spvCheck(!found(target), - DIAG(targetIndex) << "OpCopyMemorySized Target <id> '" - << inst->words[targetIndex] << "' is not defined."; - return false); + auto target = usedefs_.FindDef(inst->words[targetIndex]); + if (!target.first) return false; auto sourceIndex = 2; - auto source = find(inst->words[sourceIndex]); - spvCheck(!found(source), - DIAG(sourceIndex) << "OpCopyMemorySized Source <id> '" - << inst->words[sourceIndex] << "' is not defined."; - return false); + auto source = usedefs_.FindDef(inst->words[sourceIndex]); + if (!source.first) return false; auto sizeIndex = 3; - auto size = find(inst->words[sizeIndex]); - spvCheck(!found(size), DIAG(sizeIndex) << "OpCopyMemorySized, Size <id> '" - << inst->words[sizeIndex] - << "' is not defined."; - return false); - auto targetPointerType = find(target->second.inst->words[1]); - spvCheck(!found(targetPointerType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypePointer != targetPointerType->second.opcode, + auto size = usedefs_.FindDef(inst->words[sizeIndex]); + if (!size.first) return false; + auto targetPointerType = usedefs_.FindDef(target.second.words[1]); + assert(targetPointerType.first); + spvCheck(SpvOpTypePointer != targetPointerType.second.opcode, DIAG(targetIndex) << "OpCopyMemorySized Target <id> '" << inst->words[targetIndex] << "' is not a pointer."; return false); - auto sourcePointerType = find(source->second.inst->words[1]); - spvCheck(!found(sourcePointerType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypePointer != sourcePointerType->second.opcode, + auto sourcePointerType = usedefs_.FindDef(source.second.words[1]); + assert(sourcePointerType.first); + spvCheck(SpvOpTypePointer != sourcePointerType.second.opcode, DIAG(sourceIndex) << "OpCopyMemorySized Source <id> '" << inst->words[sourceIndex] << "' is not a pointer."; return false); - switch (size->second.opcode) { + switch (size.second.opcode) { // TODO: The following opcode's are assumed to be valid, refer to the // following bug https://cvs.khronos.org/bugzilla/show_bug.cgi?id=13871 for // clarification case SpvOpConstant: case SpvOpSpecConstant: { - auto sizeType = find(size->second.inst->words[1]); - spvCheck(!found(sizeType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypeInt != sizeType->second.opcode, + auto sizeType = usedefs_.FindDef(size.second.words[1]); + assert(sizeType.first); + spvCheck(SpvOpTypeInt != sizeType.second.opcode, DIAG(sizeIndex) << "OpCopyMemorySized Size <id> '" << inst->words[sizeIndex] << "'s type is not an integer type."; return false); } break; case SpvOpVariable: { - auto pointerType = find(size->second.inst->words[1]); - spvCheck(!found(pointerType), assert(0 && "Unreachable!")); - auto sizeType = find(pointerType->second.inst->words[1]); - spvCheck(!found(sizeType), assert(0 && "Unreachable!")); - spvCheck(SpvOpTypeInt != sizeType->second.opcode, + auto pointerType = usedefs_.FindDef(size.second.words[1]); + assert(pointerType.first); + auto sizeType = usedefs_.FindDef(pointerType.second.words[1]); + assert(sizeType.first); + spvCheck(SpvOpTypeInt != sizeType.second.opcode, DIAG(sizeIndex) << "OpCopyMemorySized Size <id> '" << inst->words[sizeIndex] << "'s variable type is not an integer type."; @@ -1177,31 +955,23 @@ template <> bool idUsage::isValid<SpvOpFunction>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpFunction Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first) return false; auto functionTypeIndex = 4; - auto functionType = find(inst->words[functionTypeIndex]); - spvCheck(!found(functionType), DIAG(functionTypeIndex) - << "OpFunction Function Type <id> '" - << inst->words[functionTypeIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpTypeFunction != functionType->second.opcode, - DIAG(functionTypeIndex) << "OpFunction Function Type <id> '" - << inst->words[functionTypeIndex] - << "' is not a function type."; - return false); - auto returnType = find(functionType->second.inst->words[2]); - spvCheck(!found(returnType), assert(0 && "Unreachable!")); - spvCheck(returnType != resultType, + auto functionType = usedefs_.FindDef(inst->words[functionTypeIndex]); + if (!functionType.first || SpvOpTypeFunction != functionType.second.opcode) { + DIAG(functionTypeIndex) << "OpFunction Function Type <id> '" + << inst->words[functionTypeIndex] + << "' is not a function type."; + return false; + } + auto returnType = usedefs_.FindDef(functionType.second.words[2]); + assert(returnType.first); + spvCheck(returnType.second.id != resultType.second.id, DIAG(resultTypeIndex) << "OpFunction Result Type <id> '" << inst->words[resultTypeIndex] << "' does not match the Function Type <id> '" - << resultType->second.id << "'s return type."; + << resultType.second.id << "'s return type."; return false); return true; } @@ -1210,37 +980,34 @@ template <> bool idUsage::isValid<SpvOpFunctionParameter>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpFunctionParameter Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first) return false; // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. size_t paramIndex = 0; assert(firstInst < inst && "Invalid instruction pointer"); while (firstInst != --inst) { - spvCheck(SpvOpFunction != inst->opcode && SpvOpFunctionParameter != inst->opcode, - DIAG(0) << "OpFunctionParameter is not preceded by OpFunction or " - "OpFunctionParameter sequence."; - return false); + spvCheck( + SpvOpFunction != inst->opcode && SpvOpFunctionParameter != inst->opcode, + DIAG(0) << "OpFunctionParameter is not preceded by OpFunction or " + "OpFunctionParameter sequence."; + return false); if (SpvOpFunction == inst->opcode) { break; } else { paramIndex++; } } - auto functionType = find(inst->words[4]); - spvCheck(!found(functionType), assert(0 && "Unreachable!")); - auto paramType = find(functionType->second.inst->words[paramIndex + 3]); - spvCheck(!found(paramType), assert(0 && "Unreachable!")); - spvCheck( - !spvOpcodeAreTypesEqual(resultType->second.inst, paramType->second.inst), - DIAG(resultTypeIndex) << "OpFunctionParameter Result Type <id> '" - << inst->words[resultTypeIndex] - << "' does not match the OpTypeFunction parameter " - "type of the same index."; - return false); + auto functionType = usedefs_.FindDef(inst->words[4]); + assert(functionType.first); + auto paramType = usedefs_.FindDef(functionType.second.words[paramIndex + 3]); + assert(paramType.first); + spvCheck(resultType.second.id != paramType.second.id, + DIAG(resultTypeIndex) + << "OpFunctionParameter Result Type <id> '" + << inst->words[resultTypeIndex] + << "' does not match the OpTypeFunction parameter " + "type of the same index."; + return false); return true; } @@ -1248,37 +1015,27 @@ template <> bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst, const spv_opcode_desc) { auto resultTypeIndex = 1; - auto resultType = find(inst->words[resultTypeIndex]); - spvCheck(!found(resultType), DIAG(resultTypeIndex) - << "OpFunctionCall Result Type <id> '" - << inst->words[resultTypeIndex] - << "' is not defined."; - return false); + auto resultType = usedefs_.FindDef(inst->words[resultTypeIndex]); + if (!resultType.first) return false; auto functionIndex = 3; - auto function = find(inst->words[functionIndex]); - spvCheck(!found(function), DIAG(functionIndex) - << "OpFunctionCall Function <id> '" - << inst->words[functionIndex] - << "' is not defined."; - return false); - spvCheck(SpvOpFunction != function->second.opcode, - DIAG(functionIndex) << "OpFunctionCall Function <id> '" - << inst->words[functionIndex] - << "' is not a function."; + auto function = usedefs_.FindDef(inst->words[functionIndex]); + if (!function.first || SpvOpFunction != function.second.opcode) { + DIAG(functionIndex) << "OpFunctionCall Function <id> '" + << inst->words[functionIndex] << "' is not a function."; + return false; + } + auto returnType = usedefs_.FindDef(function.second.words[1]); + assert(returnType.first); + spvCheck(returnType.second.id != resultType.second.id, + DIAG(resultTypeIndex) << "OpFunctionCall Result Type <id> '" + << inst->words[resultTypeIndex] + << "'s type does not match Function <id> '" + << returnType.second.id << "'s return type."; return false); - auto returnType = find(function->second.inst->words[1]); - spvCheck(!found(returnType), assert(0 && "Unreachable!")); - spvCheck( - !spvOpcodeAreTypesEqual(returnType->second.inst, resultType->second.inst), - DIAG(resultTypeIndex) - << "OpFunctionCall Result Type <id> '" << inst->words[resultTypeIndex] - << "'s type does not match Function <id> '" << returnType->second.id - << "'s return type."; - return false); - auto functionType = find(function->second.inst->words[4]); - spvCheck(!found(functionType), assert(0 && "Unreachable!")); + auto functionType = usedefs_.FindDef(function.second.words[4]); + assert(functionType.first); auto functionCallArgCount = inst->words.size() - 4; - auto functionParamCount = functionType->second.inst->words.size() - 3; + auto functionParamCount = functionType.second.words.size() - 3; spvCheck( functionParamCount != functionCallArgCount, DIAG(inst->words.size() - 1) @@ -1287,22 +1044,18 @@ bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst, return false); for (size_t argumentIndex = 4, paramIndex = 3; argumentIndex < inst->words.size(); argumentIndex++, paramIndex++) { - auto argument = find(inst->words[argumentIndex]); - spvCheck(!found(argument), DIAG(argumentIndex) - << "OpFunctionCall Argument <id> '" - << inst->words[argumentIndex] - << "' is not defined."; - return false); - auto argumentType = find(argument->second.inst->words[1]); - spvCheck(!found(argumentType), assert(0 && "Unreachable!")); - auto parameterType = find(functionType->second.inst->words[paramIndex]); - spvCheck(!found(parameterType), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeAreTypesEqual(argumentType->second.inst, - parameterType->second.inst), + auto argument = usedefs_.FindDef(inst->words[argumentIndex]); + if (!argument.first) return false; + auto argumentType = usedefs_.FindDef(argument.second.words[1]); + assert(argumentType.first); + auto parameterType = + usedefs_.FindDef(functionType.second.words[paramIndex]); + assert(parameterType.first); + spvCheck(argumentType.second.id != parameterType.second.id, DIAG(argumentIndex) << "OpFunctionCall Argument <id> '" << inst->words[argumentIndex] << "'s type does not match Function <id> '" - << parameterType->second.id + << parameterType.second.id << "'s parameter type."; return false); } @@ -1939,18 +1692,14 @@ template <> bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst, const spv_opcode_desc) { auto valueIndex = 1; - auto value = find(inst->words[valueIndex]); - spvCheck(!found(value), DIAG(valueIndex) << "OpReturnValue Value <id> '" - << inst->words[valueIndex] - << "' is not defined."; - return false); - spvCheck(!spvOpcodeIsValue(value->second.opcode), - DIAG(valueIndex) << "OpReturnValue Value <id> '" - << inst->words[valueIndex] - << "' does not represent a value."; - return false); - auto valueType = find(value->second.inst->words[1]); - spvCheck(!found(valueType), assert(0 && "Unreachable!")); + auto value = usedefs_.FindDef(inst->words[valueIndex]); + if (!value.first || !spvOpcodeIsValue(value.second.opcode)) { + DIAG(valueIndex) << "OpReturnValue Value <id> '" << inst->words[valueIndex] + << "' does not represent a value."; + return false; + } + auto valueType = usedefs_.FindDef(value.second.words[1]); + assert(valueType.first); // NOTE: Find OpFunction const spv_instruction_t* function = inst - 1; while (firstInst != function) { @@ -1960,20 +1709,18 @@ bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst, spvCheck(SpvOpFunction != function->opcode, DIAG(valueIndex) << "OpReturnValue is not in a basic block."; return false); - auto returnType = find(function->words[1]); - spvCheck(!found(returnType), assert(0 && "Unreachable!")); - if (SpvOpTypePointer == valueType->second.opcode) { - auto pointerValueType = find(valueType->second.inst->words[3]); - spvCheck(!found(pointerValueType), assert(0 && "Unreachable!")); - spvCheck(!spvOpcodeAreTypesEqual(returnType->second.inst, - pointerValueType->second.inst), + auto returnType = usedefs_.FindDef(function->words[1]); + assert(returnType.first); + if (SpvOpTypePointer == valueType.second.opcode) { + auto pointerValueType = usedefs_.FindDef(valueType.second.words[3]); + assert(pointerValueType.first); + spvCheck(returnType.second.id != pointerValueType.second.id, DIAG(valueIndex) << "OpReturnValue Value <id> '" << inst->words[valueIndex] << "'s pointer type does not match OpFunction's return type."; return false); } else { - spvCheck(!spvOpcodeAreTypesEqual(returnType->second.inst, - valueType->second.inst), + spvCheck(returnType.second.id != valueType.second.id, DIAG(valueIndex) << "OpReturnValue Value <id> '" << inst->words[valueIndex] << "'s type does not match OpFunction's return type."; @@ -2378,10 +2125,8 @@ bool idUsage::isValid(const spv_instruction_t* inst) { return false; switch (inst->opcode) { FAIL(OpUndef) - CASE(OpName) CASE(OpMemberName) CASE(OpLine) - CASE(OpDecorate) CASE(OpMemberDecorate) CASE(OpGroupDecorate) FAIL(OpGroupMemberDecorate) @@ -2592,15 +2337,16 @@ bool idUsage::isValid(const spv_instruction_t* inst) { } } // anonymous namespace -spv_result_t spvValidateInstructionIDs( - const spv_instruction_t* pInsts, const uint64_t instCount, - const spv_id_info_t* pIdUses, const uint64_t idUsesCount, - const spv_id_info_t* pIdDefs, const uint64_t idDefsCount, - const spv_opcode_table opcodeTable, const spv_operand_table operandTable, - const spv_ext_inst_table extInstTable, spv_position position, - spv_diagnostic* pDiag) { - idUsage idUsage(opcodeTable, operandTable, extInstTable, pIdUses, idUsesCount, - pIdDefs, idDefsCount, pInsts, instCount, position, pDiag); +spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, + const uint64_t instCount, + const spv_opcode_table opcodeTable, + const spv_operand_table operandTable, + const spv_ext_inst_table extInstTable, + const libspirv::ValidationState_t& state, + spv_position position, + spv_diagnostic* pDiag) { + idUsage idUsage(opcodeTable, operandTable, extInstTable, pInsts, instCount, + state.usedefs(), state.entry_points(), position, pDiag); for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) { spvCheck(!idUsage.isValid(&pInsts[instIndex]), return SPV_ERROR_INVALID_ID); position->index += pInsts[instIndex].words.size(); diff --git a/source/validate_instruction.cpp b/source/validate_instruction.cpp index a5d8d0c4..cd35a4f2 100644 --- a/source/validate_instruction.cpp +++ b/source/validate_instruction.cpp @@ -27,7 +27,6 @@ // Performs validation on instructions that appear inside of a SPIR-V block. #include "validate_passes.h" -#include "validate_types.h" namespace libspirv { diff --git a/source/validate_layout.cpp b/source/validate_layout.cpp index b9fb5441..dd30c2b3 100644 --- a/source/validate_layout.cpp +++ b/source/validate_layout.cpp @@ -26,7 +26,6 @@ // Source code for logical layout validation as described in section 2.4 -#include "validate_types.h" #include "validate_passes.h" #include "libspirv/libspirv.h" diff --git a/source/validate_passes.h b/source/validate_passes.h index 5b78e541..de1d44a7 100644 --- a/source/validate_passes.h +++ b/source/validate_passes.h @@ -28,7 +28,7 @@ #define LIBSPIRV_VALIDATE_PASSES_H_ #include "binary.h" -#include "validate_types.h" +#include "validate.h" namespace libspirv { diff --git a/source/validate_ssa.cpp b/source/validate_ssa.cpp index 7e148e5e..2ce78efa 100644 --- a/source/validate_ssa.cpp +++ b/source/validate_ssa.cpp @@ -112,7 +112,7 @@ spv_result_t SsaPass(ValidationState_t& _, switch (type) { case SPV_OPERAND_TYPE_RESULT_ID: _.removeIfForwardDeclared(*operand_ptr); - ret = _.defineId(*operand_ptr); + ret = SPV_SUCCESS; break; case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_TYPE_ID: diff --git a/source/validate_types.cpp b/source/validate_types.cpp index 764cb34b..4130450c 100644 --- a/source/validate_types.cpp +++ b/source/validate_types.cpp @@ -24,9 +24,6 @@ // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. -#include "headers/spirv.h" -#include "validate_types.h" - #include <algorithm> #include <cassert> #include <map> @@ -34,6 +31,10 @@ #include <unordered_set> #include <vector> +#include "headers/spirv.h" + +#include "validate.h" + using std::find; using std::string; using std::unordered_set; @@ -210,22 +211,12 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic, uint32_t options) : diagnostic_(diagnostic), instruction_counter_(0), - defined_ids_{}, unresolved_forward_ids_{}, validation_flags_(options), operand_names_{}, - current_layout_stage_(kLayoutCapabilities), + current_layout_section_(kLayoutCapabilities), module_functions_(*this) {} -spv_result_t ValidationState_t::defineId(uint32_t id) { - if (defined_ids_.find(id) == end(defined_ids_)) { - defined_ids_.insert(id); - } else { - return diag(SPV_ERROR_INVALID_ID) << "ID cannot be assigned multiple times"; - } - return SPV_SUCCESS; -} - spv_result_t ValidationState_t::forwardDeclareId(uint32_t id) { unresolved_forward_ids_.insert(id); return SPV_SUCCESS; @@ -260,7 +251,7 @@ vector<uint32_t> ValidationState_t::unresolvedForwardIds() const { } bool ValidationState_t::isDefinedId(uint32_t id) const { - return defined_ids_.find(id) != end(defined_ids_); + return usedefs_.FindDef(id).first; } bool ValidationState_t::is_enabled(spv_validate_options_t flag) const { @@ -273,19 +264,19 @@ int ValidationState_t::incrementInstructionCount() { } ModuleLayoutSection ValidationState_t::getLayoutSection() const { - return current_layout_stage_; + return current_layout_section_; } void ValidationState_t::progressToNextLayoutSectionOrder() { // Guard against going past the last element(kLayoutFunctionDefinitions) - if (current_layout_stage_ <= kLayoutFunctionDefinitions) { - current_layout_stage_ = - static_cast<ModuleLayoutSection>(current_layout_stage_ + 1); + if (current_layout_section_ <= kLayoutFunctionDefinitions) { + current_layout_section_ = + static_cast<ModuleLayoutSection>(current_layout_section_ + 1); } } bool ValidationState_t::isOpcodeInCurrentLayoutSection(SpvOp op) { - return IsInstructionInLayoutSection(current_layout_stage_, op); + return IsInstructionInLayoutSection(current_layout_section_, op); } DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const { diff --git a/source/validate_types.h b/source/validate_types.h deleted file mode 100644 index 7725c896..00000000 --- a/source/validate_types.h +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) 2015-2016 The Khronos Group Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and/or associated documentation files (the -// "Materials"), to deal in the Materials without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Materials, and to -// permit persons to whom the Materials are furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be included -// in all copies or substantial portions of the Materials. -// -// MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS -// KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS -// SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT -// https://www.khronos.org/registry/ -// -// THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -// MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - -#ifndef LIBSPIRV_VALIDATE_TYPES_H_ -#define LIBSPIRV_VALIDATE_TYPES_H_ - -#include "binary.h" -#include "diagnostic.h" -#include "libspirv/libspirv.h" - -#include <map> -#include <string> -#include <unordered_set> -#include <vector> - -namespace libspirv { - -// This enum represents the sections of a SPIRV module. See section 2.4 -// of the SPIRV spec for additional details of the order. The enumerant values -// are in the same order as the vector returned by GetModuleOrder -enum ModuleLayoutSection { - kLayoutCapabilities, // < Section 2.4 #1 - kLayoutExtensions, // < Section 2.4 #2 - kLayoutExtInstImport, // < Section 2.4 #3 - kLayoutMemoryModel, // < Section 2.4 #4 - kLayoutEntryPoint, // < Section 2.4 #5 - kLayoutExecutionMode, // < Section 2.4 #6 - kLayoutDebug1, // < Section 2.4 #7 > 1 - kLayoutDebug2, // < Section 2.4 #7 > 2 - kLayoutAnnotations, // < Section 2.4 #8 - kLayoutTypes, // < Section 2.4 #9 - kLayoutFunctionDeclarations, // < Section 2.4 #10 - kLayoutFunctionDefinitions // < Section 2.4 #11 -}; - -enum class FunctionDecl { - kFunctionDeclUnknown, // < Unknown function declaration - kFunctionDeclDeclaration, // < Function declaration - kFunctionDeclDefinition // < Function definition -}; - -class ValidationState_t; - -// This class manages all function declaration and definitions in a module. It -// handles the state and id information while parsing a function in the SPIR-V -// binary. -// -// NOTE: This class is designed to be a Structure of Arrays. Therefore each -// member variable is a vector whose elements represent the values for the -// corresponding function in a SPIR-V module. Variables that are not vector -// types are used to manage the state while parsing the function. -class Functions { - public: - Functions(ValidationState_t& module); - - // Registers the function in the module. Subsequent instructions will be - // called against this function - spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id, - uint32_t function_control, - uint32_t function_type_id); - - // Registers a function parameter in the current function - spv_result_t RegisterFunctionParameter(uint32_t id, uint32_t type_id); - - // Register a function end instruction - spv_result_t RegisterFunctionEnd(); - - // Sets the declaration type of the current function - spv_result_t RegisterSetFunctionDeclType(FunctionDecl type); - - // Registers a block in the current function. Subsequent block instructions - // will target this block - // @param id The ID of the label of the block - spv_result_t RegisterBlock(uint32_t id); - - // Registers a variable in the current block - spv_result_t RegisterBlockVariable(uint32_t type_id, uint32_t id, - SpvStorageClass storage, uint32_t init_id); - - spv_result_t RegisterBlockLoopMerge(uint32_t merge_id, uint32_t continue_id, - SpvLoopControlMask control); - - spv_result_t RegisterBlockSelectionMerge(uint32_t merge_id, - SpvSelectionControlMask control); - - // Registers the end of the block - spv_result_t RegisterBlockEnd(); - - // Returns the number of blocks in the current function being parsed - size_t get_block_count() const; - - // Retuns true if called after a function instruction but before the - // function end instruction - bool in_function_body() const; - - // Returns true if called after a label instruction but before a branch - // instruction - bool in_block() const; - - libspirv::DiagnosticStream diag(spv_result_t error_code) const; - - private: - // Parent module - ValidationState_t& module_; - - // Funciton IDs in a module - std::vector<uint32_t> id_; - - // OpTypeFunction IDs of each of the id_ functions - std::vector<uint32_t> type_id_; - - // The type of declaration of each function - std::vector<FunctionDecl> declaration_type_; - - // TODO(umar): Probably needs better abstractions - // The beginning of the block of functions - std::vector<std::vector<uint32_t>> block_ids_; - - // The variable IDs of the functions - std::vector<std::vector<uint32_t>> variable_ids_; - - // The function parameter ids of the functions - std::vector<std::vector<uint32_t>> parameter_ids_; - - // NOTE: See correspoding getter functions - bool in_function_; - bool in_block_; -}; - -class ValidationState_t { - public: - ValidationState_t(spv_diagnostic* diagnostic, uint32_t options); - - // Defines the \p id for the module - spv_result_t defineId(uint32_t id); - - // Forward declares the id in the module - spv_result_t forwardDeclareId(uint32_t id); - - // Removes a forward declared ID if it has been defined - spv_result_t removeIfForwardDeclared(uint32_t id); - - // Assigns a name to an ID - void assignNameToId(uint32_t id, std::string name); - - // Returns a string representation of the ID in the format <id>[Name] where - // the <id> is the numeric valid of the id and the Name is a name assigned by - // the OpName instruction - std::string getIdName(uint32_t id) const; - - // Returns the number of ID which have been forward referenced but not defined - size_t unresolvedForwardIdCount() const; - - // Returns a list of unresolved forward ids. - std::vector<uint32_t> unresolvedForwardIds() const; - - // Returns true if the id has been defined - bool isDefinedId(uint32_t id) const; - - // Returns true if an spv_validate_options_t option is enabled in the - // validation instruction - bool is_enabled(spv_validate_options_t flag) const; - - // Increments the instruction count. Used for diagnostic - int incrementInstructionCount(); - - // Returns the current layout section which is being processed - ModuleLayoutSection getLayoutSection() const; - - // Increments the module_layout_order_stage_ - void progressToNextLayoutSectionOrder(); - - // Determines if the op instruction is part of the current stage - bool isOpcodeInCurrentLayoutSection(SpvOp op); - - libspirv::DiagnosticStream diag(spv_result_t error_code) const; - - // Returns the function states - Functions& get_functions(); - - // Retuns true if the called after a function instruction but before the - // function end instruction - bool in_function_body() const; - - // Returns true if called after a label instruction but before a branch - // instruction - bool in_block() const; - - private: - spv_diagnostic* diagnostic_; - // Tracks the number of instructions evaluated by the validator - int instruction_counter_; - - // All IDs which have been defined - std::unordered_set<uint32_t> defined_ids_; - - // IDs which have been forward declared but have not been defined - std::unordered_set<uint32_t> unresolved_forward_ids_; - - // Validation options to determine the passes to execute - uint32_t validation_flags_; - - std::map<uint32_t, std::string> operand_names_; - - // The section of the code being processed - ModuleLayoutSection current_layout_stage_; - - Functions module_functions_; - - std::vector<SpvCapability> module_capabilities_; -}; -} - -#define spvCheckReturn(expression) \ - if (spv_result_t error = (expression)) return error; - - -#endif diff --git a/test/ValidateID.cpp b/test/ValidateID.cpp index f8e3184c..64930fbe 100644 --- a/test/ValidateID.cpp +++ b/test/ValidateID.cpp @@ -229,7 +229,8 @@ TEST_F(ValidateID, OpExecutionModeGood) { OpFunctionEnd)"; CHECK(spirv, SPV_SUCCESS); } -TEST_F(ValidateID, OpExecutionModeEntryPointBad) { + +TEST_F(ValidateID, OpExecutionModeEntryPointMissing) { const char* spirv = R"( OpExecutionMode %3 LocalSize 1 1 1 %1 = OpTypeVoid @@ -241,6 +242,21 @@ TEST_F(ValidateID, OpExecutionModeEntryPointBad) { CHECK(spirv, SPV_ERROR_INVALID_ID); } +TEST_F(ValidateID, OpExecutionModeEntryPointBad) { + const char* spirv = R"( + OpEntryPoint GLCompute %3 "" %a + OpExecutionMode %a LocalSize 1 1 1 +%void = OpTypeVoid +%ptr = OpTypePointer Input %void +%a = OpVariable %ptr Input +%2 = OpTypeFunction %void +%3 = OpFunction %void None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + TEST_F(ValidateID, OpTypeVectorGood) { const char* spirv = R"( %1 = OpTypeFloat 32 @@ -1267,6 +1283,37 @@ TEST_F(ValidateID, OpReturnValueBad) { CHECK(spirv, SPV_ERROR_INVALID_ID); } +TEST_F(ValidateID, UndefinedTypeId) { + const char* spirv = R"( +%f32 = OpTypeFloat 32 +%atype = OpTypeRuntimeArray %f32 +%stype = OpTypeStruct %atype +%ptrtype = OpTypePointer Input %stype +%svar = OpVariable %ptrtype Input +%s = OpLoad %stype %svar +%len = OpArrayLength %undef %s 0 +)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + +TEST_F(ValidateID, UndefinedIdScope) { + const char* spirv = R"( +%u32 = OpTypeInt 32 0 +%memsem = OpConstant %u32 0 +OpMemoryBarrier %undef %memsem +)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + +TEST_F(ValidateID, UndefinedIdMemSem) { + const char* spirv = R"( +%u32 = OpTypeInt 32 0 +%scope = OpConstant %u32 0 +OpMemoryBarrier %scope %undef +)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + // TODO: OpLifetimeStart // TODO: OpLifetimeStop // TODO: OpAtomicInit |