Skip to content

Commit 1c2fe0e

Browse files
authored
[hipblaslt] Fix fails with dtl.yaml and xfp32.yaml on gfx950_mx_rebase (#4906)
## Motivation

Fix failures with dtl.yaml and xfp32.yaml on the gfx950_mx_rebase branch.

## Technical Details

- Fixed a merge issue with `if kernel["ScheduleIterAlg"] == 3`
- Added an int cast for a float-valued constant in asm
- Fixed an incorrect parameter for `calcLdsBlockSizePerPad()`
- Fixed an incorrect local read calculation caused by incorrectly applying MX logic to TF32 emulation
- Fixed incorrect ShiftK code vreg caused by a missing `if` condition for TF32 emulation

## Test Plan

Ran dtl.yaml and xfp32.yaml on the gfx950_mx_rebase branch.

## Test Result

All passed except for a known issue (should be fixed with the latest develop branch).

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent a2ce1ab commit 1c2fe0e

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
252252
# fp64 TLU=1 reading 0.5element/lane/read..
253253
# for TLU=0 case, blockWidth and LRVW should match
254254
miInputPerGroup = kernel["MIInputPerThread%s"%tc]
255-
if writer.states.asmCaps["HasMFMA_f8f6f4"] and ((tP["bpeDS"] * miInputPerGroup) > 24):
255+
if writer.states.asmCaps["HasMFMA_f8f6f4"] and ((tP["bpeDS"] * miInputPerGroup) > 24) and not kernel["UseF32XEmulation"]:
256256
miInputPerGroup = int(16 / tP["bpeDS"])
257257
miInputGroup = kernel["MIInputPerThread%s"%tc] // miInputPerGroup
258258
numReadsPerUnroll = ceil(tP["bpeDS"] * miInputPerGroup / int(unrollBlockWidth * bpr))

projects/hipblaslt/tensilelite/Tensile/KernelWriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,7 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length):
18721872
localReads += (localReadsA + localReadsB + localReadsMXSA + localReadsMXSB)
18731873

18741874
# some of localReads is interleaved after waitcnt in SIA3
1875-
if kernel["ScheduleIterAlg"] == 3 and self.states.numItersPLR and\
1875+
if scheduleIterAlg == 3 and self.states.numItersPLR and\
18761876
(iteration < maxNumberReadIter or numPrefetchIter):
18771877
if ((iteration < numReadsIterA and not dataAtIterA < maxDataAtIter) or numPrefetchIter) and (not kernel["DirectToVgprA"]):
18781878
localReads -= self.states.numReadsPerIterA * readFactorA
@@ -4024,9 +4024,9 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB):
40244024
unitA = 1
40254025
unitB = 1
40264026
if ((not tluA) and (bpeGRA * asem < 4) and grvwa > 1):
4027-
unitA = 4 // (bpeGRA * asem)
4027+
unitA = int(4 / (bpeGRA * asem))
40284028
if ((not tluB) and (bpeGRB * asem < 4) and grvwb > 1):
4029-
unitB = 4 // (bpeGRB * asem)
4029+
unitB = int(4 / (bpeGRB * asem))
40304030
self.states.tailloopInNllmaxUnit = max(unitA, unitB)
40314031

40324032
# Only assembly supports scheduling

projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5282,7 +5282,7 @@ def generateFindTheLastElementLocation(tc):
52825282
comment="Calculate the remaining dimension along I/J direction."))
52835283
imod.add(SSubU32(dst=sgpr(sTmp0), src0=sgpr(strSize), src1=sgpr(sTmp0), \
52845284
comment="Calculate the remaining dimension along I/J direction."))
5285-
imod.add(SMulI32(dst=sgpr(sTmp0), src0=sgpr(sTmp0), src1=tP["bpeGR"], \
5285+
imod.add(SMulI32(dst=sgpr(sTmp0), src0=sgpr(sTmp0), src1=int(tP["bpeGR"]), \
52865286
comment="In bytes"))
52875287
imod.add(SAndB32(dst=sgpr(sTmp1), src0=sgpr("SizeL"), src1=(kernel["DepthU"] - 1), \
52885288
comment="Calculate the remaining dimension along L direction."))
@@ -6611,6 +6611,8 @@ def generateSrcStrForMFMA(self, kernel, tP, innerUnroll, vregSetIdx, vgprPerInpu
66116611
iui_new_offset = iui%numReadsIterCoalesced*vgprPerInput
66126612
ab_new = idxAB*vgprPerInput*numReadsIterCoalesced
66136613
abStr = "Valu%s_X%u_I%u+%u+%u+%u" % (tc, vgprBuffer_new, iui_new, ab_new, vgprBuffer_new_offset, iui_new_offset)
6614+
if kernel["UseDirect32XEmulation"] and bk != None and (int(bk) % 8) < 4:
6615+
abStr = "Valu%c_T%u_I%u+%u+%u+%u" % (tc, vgprBuffer_new, iui_new, ab_new // 2, vgprBuffer_new_offset, iui_new_offset)
66146616
if kernel["DirectToVgpr%s"%tc] and not (packDTV or convDTV):
66156617
# overwrite aStr/bStr for DirectToVgpr (except for pack DTV case)
66166618
numVgprPerBlock = statesTc.numVgprG2LAllocated

projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2947,7 +2947,7 @@ def calSwizzlePackK(state, tc):
29472947
auto_LdsBlockSizePerPadB_for_mix = 0
29482948
if state["LdsBlockSizePerPadB"] == -1:
29492949
auto_LdsBlockSizePerPadB_for_mix = 1
2950-
state["LdsBlockSizePerPadA"], state["LdsBlockSizePerPadB"] = calcLdsBlockSizePerPad(-1) # for MX datatypes, the lrvw argument is ignored
2950+
state["LdsBlockSizePerPadA"], state["LdsBlockSizePerPadB"] = calcLdsBlockSizePerPad(state["LocalReadVectorWidth"])
29512951

29522952
if state["LdsBlockSizePerPadMetadata"] == -1:
29532953
state["LdsBlockSizePerPadMetadata"] = state["LdsBlockSizePerPadA"]

0 commit comments

Comments
 (0)