Skip to content

Commit

Permalink
Enable setOccupancyLimit for MBSK (#1438)
Browse files Browse the repository at this point in the history
* Refactor MBSK related functions

* Fix vgpr occupancy not calculated correctly in unified mode

* Support setOccupancyLimit for MBSK
  • Loading branch information
KKyang authored Dec 13, 2024
1 parent e453aeb commit 88fef99
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 32 deletions.
33 changes: 14 additions & 19 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,20 @@ class GlobalWriteBatchComponent(GlobalWriteComponents):
def __call__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss: StoreState, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleAVec, addrScaleBVec, addrScaleAlphaVec, isLocalBarrierInit: bool, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter, factorDim) -> Module:
tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, \
codeMulAlpha, packdata, parentWriter, factorDim) -> Module:
return GlobalWriteBatchWriter(kernel, tPA, tPB, activation, ss, batchIdx, applyAlpha, \
beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleAVec, addrScaleBVec, addrScaleAlphaVec, isLocalBarrierInit, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter, factorDim).emit()
tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, \
codeAccVgprRead, codeMulAlpha, packdata, parentWriter, factorDim).emit()

class GlobalWriteBatchWriter:
def __init__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss: StoreState, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleAVec, addrScaleBVec, addrScaleAlphaVec, isLocalBarrierInit: bool, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter, factorDim):
tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, \
codeMulAlpha, packdata, parentWriter, factorDim):
self.kernel = kernel
self.tPA = tPA
self.tPB = tPB
Expand All @@ -80,6 +80,9 @@ def __init__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss:
self.activationTypeStr = activationTypeStr
self.tmpVgpr = tmpVgpr.idx
self.tmpVgprSize = tmpVgpr.size
if tmpVgprDynamic:
self.tmpVgprDynamic = tmpVgprDynamic.idx
self.tmpVgprDynamicSize = tmpVgprDynamic.size
self.cvtVgprStruct = cvtVgprStruct
self.batchElementSgprs = batchElementSgprs
self.tmpSgpr = tmpSgpr
Expand Down Expand Up @@ -339,11 +342,11 @@ def GSUSynccodegen(self, labelend, vgprstart, globalOffset, vgproffset):
numDim = len(indices)
with self.parentWriter.allocTmpSgpr(5) as tmpSgprInfo:
tmpSgpr = tmpSgprInfo.idx
module.addModuleAsFlatItems(self.parentWriter.s_mul_u64_u32(sgpr(tmpSgpr+0), sgpr(tmpSgpr+1), sgpr("SizesFree+0"), 1, "Free0"))
module.addModuleAsFlatItems(self.parentWriter.s_mul_u64_u32(sgpr(tmpSgpr+0), sgpr(tmpSgpr+1), sgpr("SizesFree+0"), 1, self.tmpVgpr, "Free0"))
for i in range(1, numDim):
module.add(SSubU32(dst=sgpr(tmpSgpr+4), src0=sgpr("SizesFree+%u"%i), src1=1, comment="Free%u" % i))
module.add(SMulI32(dst=sgpr(tmpSgpr+4), src0=sgpr(tmpSgpr+4), src1=1, comment="Free%u" % i))
module.addModuleAsFlatItems(self.parentWriter.s_mul_u64_u32(sgpr(tmpSgpr+2), sgpr(tmpSgpr+3), sgpr(tmpSgpr+4), sgpr("StrideC%s"%self.parentWriter.states.indexChars[i]), "Free%u" % i))
module.addModuleAsFlatItems(self.parentWriter.s_mul_u64_u32(sgpr(tmpSgpr+2), sgpr(tmpSgpr+3), sgpr(tmpSgpr+4), sgpr("StrideC%s"%self.parentWriter.states.indexChars[i]), self.tmpVgpr, "Free%u" % i))
module.add(SAddU32(dst=sgpr(tmpSgpr+0), src0=sgpr(tmpSgpr+0), src1=sgpr(tmpSgpr+2), comment="Free%u" % i))
module.add(SAddCU32(dst=sgpr(tmpSgpr+1), src0=sgpr(tmpSgpr+1), src1=sgpr(tmpSgpr+3), comment="Free%u" % i))

Expand Down Expand Up @@ -387,10 +390,7 @@ def GSUSynccodegen(self, labelend, vgprstart, globalOffset, vgproffset):

addr0 = vgpr(addrCalc.addrDVgpr)

GSUtotal = 16
if (self.kernel["MIWaveTile"][0]*self.kernel["MIWaveTile"][1])*(self.kernel["MIWaveGroup"][0]*self.kernel["MIWaveGroup"][1]) > 8:
GSUtotal = int(GSUtotal/int((self.kernel["MIWaveTile"][0]*self.kernel["MIWaveTile"][1])*(self.kernel["MIWaveGroup"][0]*self.kernel["MIWaveGroup"][1])/8))
GSUtotal = max(2,GSUtotal)
GSUtotal = self.parentWriter.getMBSKGSUTotal(self.kernel)
SynchronizerAddEndlabel = [""] * GSUtotal

for idx in range(0, GSUtotal):
Expand Down Expand Up @@ -432,17 +432,14 @@ def GSUSynccodegen(self, labelend, vgprstart, globalOffset, vgproffset):
comment="load GSU D 0 "+str(vgprstart)))
SyncloadedData += 1

GSUMvgpr = self.parentWriter.vgprPool.checkOut(1, "GSUMvgpr")
module.add(SAndB32(dst=sgpr("GSUSync"), src0=sgpr("GSU"), src1=hex(0x3FFF), comment="Restore GSU"))

SynchronizerlabelString = "Synchronizer_read_add"
SynchronizerComment = "Synchronizer read add"
Synchronizerlabel = Label(self.parentWriter.labels.getNameInc(SynchronizerlabelString), SynchronizerComment)

if(self.kernel["ProblemType"]["DestDataType"].numRegisters() > 1):
tmpVAdd = self.parentWriter.vgprPool.checkOutAligned((GSUtotal-1)*self.gwvw*self.kernel["ProblemType"]["DestDataType"].numRegisters(), 4)
else:
tmpVAdd = self.parentWriter.vgprPool.checkOutAligned((GSUtotal-1)*self.gwvw, 4)
tmpVAdd = self.tmpVgprDynamic
GSUMvgpr = self.tmpVgpr

GSUP1 = GSUtotal-1

Expand Down Expand Up @@ -560,10 +557,8 @@ def GSUSynccodegen(self, labelend, vgprstart, globalOffset, vgproffset):

module.add(SynchronizerAddSkiplabel)

self.parentWriter.vgprPool.checkIn(GSUMvgpr)
module.addComment("buffer add end2\n")

self.parentWriter.vgprPool.checkIn(tmpVAdd)
self.parentWriter.sgprPool.checkIn(tmpS06)
self.parentWriter.sgprPool.checkIn(tmpS05)
self.parentWriter.sgprPool.checkIn(tmpS04)
Expand Down
67 changes: 54 additions & 13 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def getOccupancy(self, numThreads, vgprs, sgprs, ldsSize, accvgprs=0, doubleVgpr
vgprLimitedOccupancy = self.getVgprOccupancy(numThreads, vgprs, doubleVgpr)
accvgprLimitedOccupancy = self.getVgprOccupancy(numThreads, accvgprs, doubleVgpr)
else:
vgprLimitedOccupancy = self.getVgprOccupancy(numThreads, vgprs+accvgprs, doubleVgpr)
vgprLimitedOccupancy = self.getVgprOccupancy(numThreads, ceil(vgprs//8)*8+accvgprs, doubleVgpr)
accvgprLimitedOccupancy = vgprLimitedOccupancy
sgprLimitedOccupancy = self.getSgprOccupancy(sgprs)

Expand All @@ -112,9 +112,11 @@ def getMaxRegsForOccupancy(self, numThreads, vgprs, sgprs, ldsSize, accvgprs=0,
initOccupancy = self.getOccupancy(numThreads, vgprs, sgprs, ldsSize, accvgprs, doubleVgpr)
if initOccupancy == 0: return lastVgprs, 1

while (vgprs + considerAccVgprs) < totalVgprs and vgprs < self.states.regCaps["MaxVgpr"]:
def getVgpr(vgpr, doubleVgpr):
return vgpr if not doubleVgpr else ceil(vgpr/8)*8
while (getVgpr(vgprs, doubleVgpr) + considerAccVgprs) < totalVgprs and vgprs < self.states.regCaps["MaxVgpr"]:
vgprs += 1
if self.getVgprOccupancy(numThreads, vgprs + considerAccVgprs, doubleVgpr) >= initOccupancy:
if self.getVgprOccupancy(numThreads, getVgpr(vgprs, doubleVgpr) + considerAccVgprs, doubleVgpr) >= initOccupancy:
lastVgprs = vgprs
next
else:
Expand Down Expand Up @@ -9724,6 +9726,13 @@ def findInstCount(module, targetItem, count):
self.states.bpeCexternal = bpeCexternalBackup
return module

def getMBSKGSUTotal(self, kernel):
    """Return the GSU buffer count used by MultipleBufferSingleKernel.

    Starts from a budget of 16 buffers and scales it down when the
    per-workgroup MFMA work (wave tiles x wave groups) exceeds 8,
    clamping the result to a minimum of 2.
    """
    waveTiles = kernel["MIWaveTile"][0] * kernel["MIWaveTile"][1]
    waveGroups = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1]
    workPerWG = waveTiles * waveGroups
    gsuTotal = 16
    if workPerWG > 8:
        # Shrink proportionally to the work above the 8-unit baseline,
        # but never go below 2 buffers.
        gsuTotal = max(2, gsuTotal // (workPerWG // 8))
    return gsuTotal

##############################################################################
# globalWriteElementBatch :
##############################################################################
Expand Down Expand Up @@ -9753,6 +9762,33 @@ def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
########################################
# Calculate Vgprs for Write Batching
########################################
self.vgprPool.resetOccupancyLimit()
self.sgprPool.resetOccupancyLimit()

# Temporarily grow pool for sgpr
sgprList = []
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
for s in sgprList:
self.sgprPool.checkIn(s)
if actPCMaxTempSgpr > 0:
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))

tmpVgprDynamic = None
tmpVgprDynamicSize = 0
tmpVgprDynamicAlign = 0
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
GSUTotal = self.getMBSKGSUTotal(kernel)
vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
tmpVgprDynamicSize = vgprMbsk
tmpVgprDynamicAlign = 4
if tmpVgprDynamicSize > 0:
tmpVgprDynamic = RegisterPoolResource(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)

ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)

Expand All @@ -9761,9 +9797,8 @@ def setOccupancy():
maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
# Set occupancy limit for register pools
# TODO: Support GSUMBSK
# TODO: Support gfx120X
if (kernel["_GlobalAccumulation"] != 'MultipleBufferSingleKernel') and (kernel["ISA"][0] != 12):
# TODO: Support gfx12
if kernel["ISA"][0] != 12:
self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
return maxVgprs, occupancy
Expand All @@ -9772,7 +9807,7 @@ def setOccupancy():
# Get estimated numVgprAvailable
# print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)

# Grow the register pool if needed - we need enough regs for at least one element
# Unfortunate since this means the write logic is setting the VGPR requirement
# for the entire kernel but at least we have a functional kernel.
Expand Down Expand Up @@ -9940,14 +9975,19 @@ def setOccupancy():
applyAlpha, beta, edge, atomic, gwvw, atomicW, \
elementsThisBatch, self.vgprs.addrE, self.vgprs.addrD, self.vgprs.addrC, self.vgprs.addrBias, \
self.vgprs.addrScaleAVec, self.vgprs.addrScaleBVec, self.vgprs.addrScaleAlphaVec, \
biasLocalBarrierInit, tmpVgpr, cvtVgprStruct, activationSetPCStruct, \
biasLocalBarrierInit, tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, \
activationTypeStr, elementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, factorDim))
biasLocalBarrierInit = True

ss.resetState()
actLoopModuleList.append(actLoopModule)
actLoopModuleCodeLength.append(actLoopModule.countType(Instruction))

#################
# Free after final vgpr calculation
if tmpVgprDynamic:
self.vgprPool.checkIn(tmpVgprDynamic.idx)

if len(actLoopLabelModules) > 1:
actInstCounter = 0
# Add activation branch
Expand Down Expand Up @@ -10492,15 +10532,15 @@ def globalWriteBatch(self, kernel, tPA, tPB, activation, ss: StoreState, batchId
applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, \
addrScaleAVec, addrScaleBVec, addrScaleAlphaVec, biasLocalBarrierInit: bool, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, factorDim) -> Module:
packdata = Component.PackData.find(self)
gwriter = Component.GlobalWriteComponents.find(self)
return gwriter(kernel, tPA, tPB, activation, ss, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, \
addrScaleAVec, addrScaleBVec, addrScaleAlphaVec, biasLocalBarrierInit, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
tmpVgpr, tmpVgprDynamic, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, packdata, self, factorDim)

##############################################################################
Expand Down Expand Up @@ -11696,11 +11736,12 @@ def sMagicDivWrapper(self, dest, dividend, magicNumber, magicShift):
self.vgprPool.checkIn(tmpVgpr)
return module

def s_mul_u64_u32 (self, dst0, dst1, src0, src1, tmpVgpr=None, comment=""):
    """Emit a 32x32 -> 64-bit unsigned scalar multiply (dst1:dst0 = src0 * src1).

    Parameters:
        dst0, dst1: sgprs receiving the low/high 32 bits of the product.
        src0, src1: 32-bit scalar sources.
        tmpVgpr: optional caller-provided scratch vgpr pair; when None a
                 pair is checked out of (and returned to) the vgpr pool.
        comment: assembly comment attached to the emitted instructions.

    Returns:
        Module containing the multiply sequence.
    """
    # PEP 8: compare to the None singleton with `is`, not `==`.
    vtmp0 = self.vgprPool.checkOut(2) if tmpVgpr is None else tmpVgpr
    module = SMulInt64to32(self.states.asmCaps["HasSMulHi"], \
        dst0, dst1, src0, src1, False, vtmp0, comment)
    # Only release scratch registers we allocated ourselves; a
    # caller-supplied tmpVgpr remains the caller's responsibility.
    if tmpVgpr is None:
        self.vgprPool.checkIn(vtmp0)
    return module

def s_mul_i64_i32 (self, dst0, dst1, src0, src1, comment):
Expand Down

0 comments on commit 88fef99

Please sign in to comment.