From 51d2dfe02ae39aa820edad9ce676441de83c42d1 Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Tue, 8 Nov 2022 16:00:06 -0800 Subject: MSL: Add missing casts to `Op?MulExtended`. It is possible to pass unsigned integers to `OpSMulExtended`. In that case, we want to do a signed multiply with sign extension, so make sure the operands are forced to be interpreted as signed. This was an oversight on my part when I added these instructions. Fixes the CTS test `dEQP-VK.spirv_assembly.instruction.compute.signed_op.uint_smulextended`. --- .../asm/comp/uint_smulextended.asm.comp | 25 +++++++++ .../asm/comp/uint_smulextended.asm.comp | 25 +++++++++ shaders-msl/asm/comp/uint_smulextended.asm.comp | 61 ++++++++++++++++++++++ spirv_msl.cpp | 29 ++++++++-- 4 files changed, 136 insertions(+), 4 deletions(-) create mode 100644 reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp create mode 100644 reference/shaders-msl/asm/comp/uint_smulextended.asm.comp create mode 100644 shaders-msl/asm/comp/uint_smulextended.asm.comp diff --git a/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp b/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp new file mode 100644 index 00000000..6996f7fd --- /dev/null +++ b/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp @@ -0,0 +1,25 @@ +#include +#include + +using namespace metal; + +struct _4 +{ + uint _m0[1]; +}; + +struct _20 +{ + uint _m0; + uint _m1; +}; + +kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + _20 _28; + _28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x])); + _28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x]))); + _7._m0[gl_GlobalInvocationID.x] = _28._m0; + _8._m0[gl_GlobalInvocationID.x] = _28._m1; +} + diff --git a/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp b/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp new file mode 100644 index 00000000..6996f7fd --- /dev/null +++ b/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp @@ -0,0 +1,25 @@ +#include +#include + +using namespace metal; + +struct _4 +{ + uint _m0[1]; +}; + +struct _20 +{ + uint _m0; + uint _m1; +}; + +kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + _20 _28; + _28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x])); + _28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x]))); + _7._m0[gl_GlobalInvocationID.x] = _28._m0; + _8._m0[gl_GlobalInvocationID.x] = _28._m1; +} + diff --git a/shaders-msl/asm/comp/uint_smulextended.asm.comp b/shaders-msl/asm/comp/uint_smulextended.asm.comp new file mode 100644 index 00000000..32d48363 --- /dev/null +++ b/shaders-msl/asm/comp/uint_smulextended.asm.comp @@ -0,0 +1,61 @@ + OpCapability Shader + + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationId + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %gl_GlobalInvocationId BuiltIn GlobalInvocationId + OpDecorate %ra_uint ArrayStride 4 + OpDecorate %struct_uint4 BufferBlock + OpMemberDecorate %struct_uint4 0 Offset 0 + OpDecorate %input0 DescriptorSet 0 + OpDecorate %input0 Binding 0 + OpDecorate %input1 DescriptorSet 0 + OpDecorate %input1 Binding 1 + OpDecorate %output0 DescriptorSet 0 + OpDecorate %output0 Binding 2 + OpDecorate %output1 DescriptorSet 0 + OpDecorate %output1 Binding 3 + + %uint = OpTypeInt 32 0 + %ptr_uint = OpTypePointer Uniform %uint + %ptr_input_uint = OpTypePointer Input %uint + %uint3 = OpTypeVector %uint 3 + %ptr_input_uint3 = OpTypePointer Input %uint3 + %void = OpTypeVoid + %voidFn = OpTypeFunction %void + + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %ra_uint = OpTypeRuntimeArray %uint + %uint4 = OpTypeVector %uint 4 + %struct_uint4 = OpTypeStruct %ra_uint + %ptr_struct_uint4 = OpTypePointer Uniform %struct_uint4 + %resulttype = OpTypeStruct %uint %uint +%gl_GlobalInvocationId = OpVariable %ptr_input_uint3 Input + %input0 = OpVariable %ptr_struct_uint4 Uniform + %input1 = OpVariable %ptr_struct_uint4 Uniform + + %output0 = OpVariable %ptr_struct_uint4 Uniform + %output1 = OpVariable %ptr_struct_uint4 Uniform + + %main = OpFunction %void None %voidFn + %mainStart = OpLabel + %index_ptr = OpAccessChain %ptr_input_uint %gl_GlobalInvocationId %uint_0 + %index = OpLoad %uint %index_ptr + %in_ptr0 = OpAccessChain %ptr_uint %input0 %uint_0 %index + %invalue0 = OpLoad %uint %in_ptr0 + %in_ptr1 = OpAccessChain %ptr_uint %input1 %uint_0 %index + %invalue1 = OpLoad %uint %in_ptr1 + + %outvalue = OpSMulExtended %resulttype %invalue0 %invalue1 + %outvalue0 = OpCompositeExtract %uint %outvalue 0 + %out_ptr0 = OpAccessChain %ptr_uint %output0 %uint_0 %index + OpStore %out_ptr0 %outvalue0 + %outvalue1 = OpCompositeExtract %uint %outvalue 1 + %out_ptr1 = OpAccessChain %ptr_uint %output1 %uint_0 %index + OpStore %out_ptr1 %outvalue1 + + + OpReturn + OpFunctionEnd diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 58090ebb..f6d22ced 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -8937,12 +8937,33 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) uint32_t op0 = ops[2]; uint32_t op1 = ops[3]; auto &type = get(result_type); + auto input_type = opcode == OpSMulExtended ? int_type : uint_type; + auto &output_type = get_type(result_type); + string cast_op0, cast_op1; + + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false); + emit_uninitialized_temporary_expression(result_type, result_id); - statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", - to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";"); - statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", - to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");"); + string mullo_expr, mulhi_expr; + mullo_expr = join(cast_op0, " * ", cast_op1); + mulhi_expr = join("mulhi(", cast_op0, ", ", cast_op1, ")"); + + auto &low_type = get_type(output_type.member_types[0]); + auto &high_type = get_type(output_type.member_types[1]); + if (low_type.basetype != input_type) + { + expected_type.basetype = input_type; + mullo_expr = join(bitcast_glsl_op(low_type, expected_type), "(", mullo_expr, ")"); + } + if (high_type.basetype != input_type) + { + expected_type.basetype = input_type; + mulhi_expr = join(bitcast_glsl_op(high_type, expected_type), "(", mulhi_expr, ")"); + } + + statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", mullo_expr, ";"); + statement(to_expression(result_id), ".", to_member_name(type, 1), " = ", mulhi_expr, ";"); break; } -- cgit v1.2.3