summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Chip Davis <chip@holochip.com>, 2022-11-08 16:00:06 -0800
committer: Chip Davis <chip@holochip.com>, 2022-11-20 00:20:49 -0800
commit: 51d2dfe02ae39aa820edad9ce676441de83c42d1 (patch)
tree: d2ef42bd7f45d76ea3adb6bb711b81802ca687fc
parent: edd66a2fc9e932ad0d3dce78f2627eeae91c2660 (diff)
download: SPIRV-Cross-51d2dfe02ae39aa820edad9ce676441de83c42d1.tar.gz
SPIRV-Cross-51d2dfe02ae39aa820edad9ce676441de83c42d1.tar.bz2
SPIRV-Cross-51d2dfe02ae39aa820edad9ce676441de83c42d1.zip
MSL: Add missing casts to `Op?MulExtended`.
It is possible to pass unsigned integers to `OpSMulExtended`. In that case, we want to do a signed multiply with sign extension, so make sure the operands are forced to be interpreted as signed. This was an oversight on my part when I added these instructions. Fixes the CTS test `dEQP-VK.spirv_assembly.instruction.compute.signed_op.uint_smulextended`.
-rw-r--r-- reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp | 25
-rw-r--r-- reference/shaders-msl/asm/comp/uint_smulextended.asm.comp | 25
-rw-r--r-- shaders-msl/asm/comp/uint_smulextended.asm.comp | 61
-rw-r--r-- spirv_msl.cpp | 29
4 files changed, 136 insertions, 4 deletions
diff --git a/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp b/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp
new file mode 100644
index 00000000..6996f7fd
--- /dev/null
+++ b/reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp
@@ -0,0 +1,25 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _4
+{
+ uint _m0[1];
+};
+
+struct _20
+{
+ uint _m0;
+ uint _m1;
+};
+
+kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+ _20 _28;
+ _28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
+ _28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
+ _7._m0[gl_GlobalInvocationID.x] = _28._m0;
+ _8._m0[gl_GlobalInvocationID.x] = _28._m1;
+}
+
diff --git a/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp b/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp
new file mode 100644
index 00000000..6996f7fd
--- /dev/null
+++ b/reference/shaders-msl/asm/comp/uint_smulextended.asm.comp
@@ -0,0 +1,25 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _4
+{
+ uint _m0[1];
+};
+
+struct _20
+{
+ uint _m0;
+ uint _m1;
+};
+
+kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+ _20 _28;
+ _28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
+ _28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
+ _7._m0[gl_GlobalInvocationID.x] = _28._m0;
+ _8._m0[gl_GlobalInvocationID.x] = _28._m1;
+}
+
diff --git a/shaders-msl/asm/comp/uint_smulextended.asm.comp b/shaders-msl/asm/comp/uint_smulextended.asm.comp
new file mode 100644
index 00000000..32d48363
--- /dev/null
+++ b/shaders-msl/asm/comp/uint_smulextended.asm.comp
@@ -0,0 +1,61 @@
+ OpCapability Shader
+
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationId
+ OpExecutionMode %main LocalSize 1 1 1
+
+ OpDecorate %gl_GlobalInvocationId BuiltIn GlobalInvocationId
+ OpDecorate %ra_uint ArrayStride 4
+ OpDecorate %struct_uint4 BufferBlock
+ OpMemberDecorate %struct_uint4 0 Offset 0
+ OpDecorate %input0 DescriptorSet 0
+ OpDecorate %input0 Binding 0
+ OpDecorate %input1 DescriptorSet 0
+ OpDecorate %input1 Binding 1
+ OpDecorate %output0 DescriptorSet 0
+ OpDecorate %output0 Binding 2
+ OpDecorate %output1 DescriptorSet 0
+ OpDecorate %output1 Binding 3
+
+ %uint = OpTypeInt 32 0
+ %ptr_uint = OpTypePointer Uniform %uint
+ %ptr_input_uint = OpTypePointer Input %uint
+ %uint3 = OpTypeVector %uint 3
+ %ptr_input_uint3 = OpTypePointer Input %uint3
+ %void = OpTypeVoid
+ %voidFn = OpTypeFunction %void
+
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %ra_uint = OpTypeRuntimeArray %uint
+ %uint4 = OpTypeVector %uint 4
+ %struct_uint4 = OpTypeStruct %ra_uint
+ %ptr_struct_uint4 = OpTypePointer Uniform %struct_uint4
+ %resulttype = OpTypeStruct %uint %uint
+%gl_GlobalInvocationId = OpVariable %ptr_input_uint3 Input
+ %input0 = OpVariable %ptr_struct_uint4 Uniform
+ %input1 = OpVariable %ptr_struct_uint4 Uniform
+
+ %output0 = OpVariable %ptr_struct_uint4 Uniform
+ %output1 = OpVariable %ptr_struct_uint4 Uniform
+
+ %main = OpFunction %void None %voidFn
+ %mainStart = OpLabel
+ %index_ptr = OpAccessChain %ptr_input_uint %gl_GlobalInvocationId %uint_0
+ %index = OpLoad %uint %index_ptr
+ %in_ptr0 = OpAccessChain %ptr_uint %input0 %uint_0 %index
+ %invalue0 = OpLoad %uint %in_ptr0
+ %in_ptr1 = OpAccessChain %ptr_uint %input1 %uint_0 %index
+ %invalue1 = OpLoad %uint %in_ptr1
+
+ %outvalue = OpSMulExtended %resulttype %invalue0 %invalue1
+ %outvalue0 = OpCompositeExtract %uint %outvalue 0
+ %out_ptr0 = OpAccessChain %ptr_uint %output0 %uint_0 %index
+ OpStore %out_ptr0 %outvalue0
+ %outvalue1 = OpCompositeExtract %uint %outvalue 1
+ %out_ptr1 = OpAccessChain %ptr_uint %output1 %uint_0 %index
+ OpStore %out_ptr1 %outvalue1
+
+
+ OpReturn
+ OpFunctionEnd
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 58090ebb..f6d22ced 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -8937,12 +8937,33 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t op0 = ops[2];
uint32_t op1 = ops[3];
auto &type = get<SPIRType>(result_type);
+ auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
+ auto &output_type = get_type(result_type);
+ string cast_op0, cast_op1;
+
+ auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false);
+
emit_uninitialized_temporary_expression(result_type, result_id);
- statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
- to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";");
- statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(",
- to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");");
+ string mullo_expr, mulhi_expr;
+ mullo_expr = join(cast_op0, " * ", cast_op1);
+ mulhi_expr = join("mulhi(", cast_op0, ", ", cast_op1, ")");
+
+ auto &low_type = get_type(output_type.member_types[0]);
+ auto &high_type = get_type(output_type.member_types[1]);
+ if (low_type.basetype != input_type)
+ {
+ expected_type.basetype = input_type;
+ mullo_expr = join(bitcast_glsl_op(low_type, expected_type), "(", mullo_expr, ")");
+ }
+ if (high_type.basetype != input_type)
+ {
+ expected_type.basetype = input_type;
+ mulhi_expr = join(bitcast_glsl_op(high_type, expected_type), "(", mulhi_expr, ")");
+ }
+
+ statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", mullo_expr, ";");
+ statement(to_expression(result_id), ".", to_member_name(type, 1), " = ", mulhi_expr, ";");
break;
}