diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h index 47743fa0f8..3c2b60a023 100644 --- a/tools/clang/lib/Headers/hlsl/dx/linalg.h +++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -186,8 +186,11 @@ __MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t) __MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double) template struct DstN { + // Make sure to round up in case SrcN isn't an even multiple of the number of + // elements per scalar static const int Value = - (SrcN * ComponentTypeTraits::ElementsPerScalar) / + (SrcN * ComponentTypeTraits::ElementsPerScalar + + ComponentTypeTraits::ElementsPerScalar - 1) / ComponentTypeTraits::ElementsPerScalar; }; diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl index 268e814e7e..14e2f04e85 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl @@ -82,4 +82,9 @@ void main(uint ID : SV_GroupID) { half16 srcF16 = BAB.Load(128); InterpretedVector convertedPacked = Convert(srcF16); + // CHECK: call <1 x i32> @dx.op.linAlgConvert.v1i32.v3f16(i32 -2147483618, <3 x half> %25, i32 8, i32 21) + // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + half3 ThreeF16 = BAB.Load(256); + InterpretedVector convertedPacked2 = + Convert(ThreeF16); }