Skip to content

Commit

Permalink
Use archVGPR when accVGPR is not enough. (#1460)
Browse files Browse the repository at this point in the history
This PR is to support larger MT such as 256x320.
Generally, we only have 256 accVGPRs.
If MT is larger than 256x256, we need some extra archVGPR to store the acc
results.
  • Loading branch information
hcman2 authored Dec 23, 2024
1 parent 0c8494a commit 09ba034
Show file tree
Hide file tree
Showing 6 changed files with 460 additions and 44 deletions.
17 changes: 7 additions & 10 deletions tensilelite/Tensile/Components/LSU.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def writeReadReduction(self, writer, kernel):
numTotalAccVgprLdsReduction = numTotalAccVgprLdsReduction // kernel["LocalSplitU"]
self.accVgprLdsReduction = writer.vgprPool.checkOutAligned(numTotalAccVgprLdsReduction, 4, "LsuReduction")
module.add(RegSet("v", "vgprLsuReduction", self.accVgprLdsReduction))
module.addComment0("Size of vgprLsuReduction is %u"%(numTotalAccVgprLdsReduction))
writer.states.c.startVgprValu = self.accVgprLdsReduction

# Local Read VGPR idx
Expand Down Expand Up @@ -211,16 +212,12 @@ def writeReadReduction(self, writer, kernel):
destIdx = 0
for lsu in range(kernel["LocalSplitU"]):
for i in range(numVgprPerLSU):
srcIdx = neededAccVGPRIdx[lsu][i]
if not kernel["MIArchVgpr"]:
accStr = accvgpr(srcIdx)
module.add(VAccvgprReadB32(dst=vgpr(accVgprRes+destIdx),
src=accStr,
comment="copy acc[%u] to vreg[%u], LSU%u will process" % (srcIdx,destIdx,lsu)))
else:
module.add(VMovB32(dst=vgpr(accVgprRes+destIdx),
src=vgpr("ValuC+%u"%srcIdx),
comment="copy MI out reg to vreg[%u], LSU%u will process" % (destIdx,lsu)))
srcIdx = neededAccVGPRIdx[lsu][i]
readInst = writer.accVgprReadWriteFunction(kernel, srcIdx, True)
srcVgpr = writer.accVgprReadWriteIndex(kernel, srcIdx)
module.add(readInst(dst=vgpr(accVgprRes+destIdx),
src=srcVgpr,
comment="copy acc[%u] to vreg[%u], LSU%u will process" % (srcIdx,destIdx,lsu)))
destIdx += 1

dataPerWave = numAccVgpr * kernel["WavefrontSize"] * 4
Expand Down
16 changes: 8 additions & 8 deletions tensilelite/Tensile/Components/ShiftVectorComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,13 @@ def ShiftVectorComponentsMFMAPartialThread(self, writer, kernel, tP):
for c in range(complexMultiplier):
for nr in range(regPerElem):
vgprOffsetForSCIU = 0
copyInst = VAccvgprReadB32 if not kernel["MIArchVgpr"] else VMovB32
for e in range(min(r, allContOutCoal)):
src = (e+(glvw-r)) % allContOutCoal
srcVgpr = (src + (vw * glvw) + allContOutCoal * mb) * regStrideCoal
srcVgpr = srcVgpr + ot * regStridePrep
srcVgpr = arch2acc[srcVgpr] * regPerElem + nr + c * accImOffset + vgprOffsetForSCIU
srcVal = accvgpr(srcVgpr) if not kernel["MIArchVgpr"] else vgpr(srcVgpr)
srcVal = writer.accVgprReadWriteIndex(kernel, srcVgpr)
copyInst = writer.accVgprReadWriteFunction(kernel, srcVgpr, True)
module.add(copyInst(dst=vgpr(tReg+e), src=srcVal, comment="glvw %u mb %u tt1 %u r %u" % (r, mb, ot, nr)))

if not kernel["MIArchVgpr"]:
Expand All @@ -313,12 +313,12 @@ def ShiftVectorComponentsMFMAPartialThread(self, writer, kernel, tP):
if needWait:
module.add(SWaitCnt(waitAll=True, comment="wait for swizzle operation"))

copyInst = VAccvgprWriteB32 if not kernel["MIArchVgpr"] else VMovB32
for e in range(min(r, allContOutCoal)):
dstVgpr = (e + (vw * glvw) + allContOutCoal * mb) * regStrideCoal
dstVgpr = dstVgpr + ot * regStridePrep
dstVgpr = arch2acc[dstVgpr] * regPerElem + nr + c * accImOffset + vgprOffsetForSCIU
dstStr = accvgpr(dstVgpr) if not kernel["MIArchVgpr"] else vgpr(dstVgpr)
dstStr = writer.accVgprReadWriteIndex(kernel, dstVgpr)
copyInst = writer.accVgprReadWriteFunction(kernel, dstVgpr, False)
module.add(copyInst(dst=dstStr, src=vgpr(tReg+e)))

# end shift reset mask and jump out
Expand Down Expand Up @@ -495,7 +495,6 @@ def ShiftVectorComponentsMFMAAllThread(self, writer, kernel, tP):
for dstMbblkId in range(glvw//(numContOutCoal*numThreadInCoal)):
for dstThreadId in range(numThreadInCoal):
skip = True
copyInst = VAccvgprReadB32 if not kernel["MIArchVgpr"] else VMovB32
for dstContId in range(numContOutCoal):
dst = dstContId + dstThreadId * numContOutCoal + dstMbblkId * numThreadInCoal * numContOutCoal
src = dst + (glvw - shift)
Expand All @@ -509,7 +508,8 @@ def ShiftVectorComponentsMFMAAllThread(self, writer, kernel, tP):
srcGpr = srcContId + srcMbblkId * numContOutCoal + glvwBlk * numRegInGlvwblkCoal + tt * numRegInMIBCoal
srcGpr = srcGpr * regStrideCoal + ot * regStridePrep
srcGpr = arch2acc[srcGpr]
srcGprStr = accvgpr(srcGpr) if not kernel["MIArchVgpr"] else vgpr(srcGpr)
srcGprStr = writer.accVgprReadWriteIndex(kernel, srcGpr)
copyInst = copyInst = writer.accVgprReadWriteFunction(kernel, srcGpr, True)
module.add(copyInst(dst=vgpr(movRegId), src=srcGprStr, comment=""))

if not skip:
Expand Down Expand Up @@ -540,7 +540,6 @@ def ShiftVectorComponentsMFMAAllThread(self, writer, kernel, tP):
module.add(VCmpXEqU32(dst=sgpr(tmpSgpr, writer.states.laneSGPRCount), src0=vgpr(threadIdInCoalReg), src1=sgpr(tmpSgpr), comment="is thread in edge glvw region"))
module.add(SNop(waitState=3, comment="wait for exec mask"))

copyInst = VAccvgprWriteB32 if not kernel["MIArchVgpr"] else VMovB32
for dstContId in range(numContOutCoal):
dst = dstContId + dstThreadId * numContOutCoal + dstMbblkId * numThreadInCoal * numContOutCoal
src = dst + (glvw - shift)
Expand All @@ -550,7 +549,8 @@ def ShiftVectorComponentsMFMAAllThread(self, writer, kernel, tP):
dstGpr = dstContId + dstMbblkId * numContOutCoal + glvwBlk * numRegInGlvwblkCoal + tt * numRegInMIBCoal
dstGpr = dstGpr * regStrideCoal + ot * regStridePrep
dstGpr = arch2acc[dstGpr]
dstGprStr = accvgpr(dstGpr) if not kernel["MIArchVgpr"] else vgpr(dstGpr)
dstGprStr = writer.accVgprReadWriteIndex(kernel, dstGpr)
copyInst = writer.accVgprReadWriteFunction(kernel, dstGpr, False)
module.add(copyInst(dst=dstGprStr, src=vgpr(movRegId), comment=""))

if not skip:
Expand Down
13 changes: 10 additions & 3 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class StateValues:
bias: MatrixInfo = field(default_factory=MatrixInfo)
m: ABMatrixInfo = field(default_factory=ABMatrixInfo) # For Sparse Metadata
totalAgprs: int = 0
maxLimitAgprs: int = 0
totalMixedAgprs: int = 0
totalVgprs: int = 0
totalSgprs: int = 0
lastValuAB: int = 0
Expand Down Expand Up @@ -3635,7 +3637,9 @@ def readWriteVectors(mat, vw, kernel):
# VGPR Assignment
####################################
vgprIdx = 0
self.states.totalAgprs = 0
self.states.totalAgprs = 0
self.states.totalMixedAgprs = 0
self.states.maxLimitAgprs = self.states.regCaps["PhysicalMaxVgpr"] - self.states.regCaps["MaxVgpr"]
self.states.c.startVgprValu = vgprIdx; vgprIdx += self.states.c.numVgprValu

if kernel["EnableMatrixInstruction"]:
Expand All @@ -3653,8 +3657,11 @@ def readWriteVectors(mat, vw, kernel):
########################################
if not kernel["MIArchVgpr"]:
self.states.totalAgprs = self.states.c.numVgprValu
vgprIdx = 0
self.states.c.numVgprValu = 0
if self.states.totalAgprs > self.states.maxLimitAgprs:
self.states.totalMixedAgprs = self.states.totalAgprs - self.states.maxLimitAgprs
self.states.totalAgprs = self.states.maxLimitAgprs
vgprIdx = self.states.totalMixedAgprs
self.states.c.numVgprValu = self.states.totalMixedAgprs

# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
Expand Down
47 changes: 33 additions & 14 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3815,15 +3815,15 @@ def initC(self, kernel):
module.addComment1("initC: remove acc vgpr buffer [%u...%u) from pool"%(0, numAccvgprs))
self.vgprPool.remove(self.states.a.startVgprValu , self.states.lastValuAB - self.states.a.startVgprValu , "ValuAB")
module.addComment1("initC: remove ValuA/B vgpr buffer [%u...%u) from pool"%(self.states.a.startVgprValu , self.states.lastValuAB))
numCVgpr = max(self.states.c.numVgprValu, numAccvgprs)
numCVgpr = self.states.c.numVgprValu + numAccvgprs

if kernel["LdsInitCVgprs"]:
tmpAddr = self.vgprPool.checkOut(1,"tmp vgpr for lds init C registers")
module.add(VMovB32(dst=vgpr(tmpAddr), src=self.consts.ldsOOB, comment="set out-of-bound addr"))

for i in range(0, numCVgpr):
copyInst = VMovB32 if self.states.c.numVgprValu else VAccvgprWrite
regStr = vgpr("ValuC+%u"%i) if self.states.c.numVgprValu else accvgpr(i)
copyInst = VMovB32 if i >= numAccvgprs else VAccvgprWrite
regStr = vgpr("ValuC+%u"%(i-numAccvgprs)) if i >= numAccvgprs else accvgpr(i)
if not kernel["LdsInitCVgprs"]:
module.add(copyInst(dst=regStr, src=hex(0), comment="initC"))
else:
Expand Down Expand Up @@ -5134,9 +5134,9 @@ def fixPreloadOffset(offset, sgpxIdxVec, numStoreSgprToLoad):
#instCycles = kernel["MatrixInstM"] // 2 # 32x32 is 64 cycles, 16x16 is 32 cycles, 4x4 is 8 cycles
#module.add(SNop(waitState=instCycles))
module.addComment1("Mapping of Acc register -> C Vgpr register")
self.codes.accVgprRead = mapAcctoArchRegs(kernel, write=False)
self.codes.accVgprRead = mapAcctoArchRegs(kernel, self.states.maxLimitAgprs, write=False)
if kernel["StreamK"] > 0 and kernel["StreamKAtomic"] == 0:
self.codes.accVgprWrite = mapAcctoArchRegs(kernel, write=True)
self.codes.accVgprWrite = mapAcctoArchRegs(kernel, self.states.maxLimitAgprs, write=True)
if kernel["MIArchVgpr"]:
module.addComment1("Multiply MI out register with Alpha -> C Vgpr register")
self.codes.mulAlphaMultipleBuffer = moveMIoutToArch(kernel, self.states.startVgprAlphaTmp)
Expand Down Expand Up @@ -5214,6 +5214,26 @@ def macIter(self, kernel, tPA, tPB, bufferIdx, iuiCount, useMacro, isTail=False)
imod.addSpaceLine()
return imod

##############################################################################
# ACC Vgpr R/W Function
##############################################################################
def accVgprReadWriteFunction(self, kernel, idx, read=True):
if not kernel["MIArchVgpr"]:
if idx >= self.states.maxLimitAgprs:
return VMovB32
else:
return VAccvgprReadB32 if read else VAccvgprWriteB32
else:
return VMovB32
def accVgprReadWriteIndex(self, kernel, idx, sz=1):
if not kernel["MIArchVgpr"]:
if idx >= self.states.maxLimitAgprs:
return vgpr(idx - self.states.maxLimitAgprs, sz)
else:
return accvgpr(idx, sz)
else:
return vgpr(idx, sz)

##############################################################################
# MFMA Iteration
##############################################################################
Expand Down Expand Up @@ -5249,7 +5269,6 @@ def mfmaIter(self, kernel, tPA, tPB, u, innerUnroll, vregSetIdx, unrollLoopIdx =
vgprPerInput = max(vgprPerInputA,vgprPerInputB)
shiftPerElement = int(numRegistersIn * 32)
s_nop = 0
gprfunc = accvgpr if not kernel["MIArchVgpr"] else vgpr
accumRegType = "acc" if not kernel["MIArchVgpr"] else "v"
mfma_1k = True if kernel["MFMA_BF16_1K"] else False
accStoreCIdx = 0
Expand Down Expand Up @@ -5625,19 +5644,19 @@ def mfmaIter(self, kernel, tPA, tPB, u, innerUnroll, vregSetIdx, unrollLoopIdx =
imod.add(inst)
variant = [kernel["MatrixInstM"], kernel["MatrixInstN"], kernel["MatrixInstK"], kernel["MatrixInstB"]]
imod.add(MFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=False, \
acc=gprfunc(accStart, (accEnd-accStart+1)), a=src0, b=src1, acc2=gprfunc(accStart, (accEnd-accStart+1)), \
acc=self.accVgprReadWriteIndex(kernel, accStart, (accEnd-accStart+1)), a=src0, b=src1, acc2=self.accVgprReadWriteIndex(kernel, accStart, (accEnd-accStart+1)), \
comment="Cr += Ar*Br"))
(src0, src1) = (bi, (vgpr(ccVgprs[0] + offsetVgpr[0], numRegistersOut) if ccVgprs[0] else ai)) if kernel["SourceSwap"] else ((vgpr(ccVgprs[0] + offsetVgpr[0], numRegistersOut) if ccVgprs[0] else ai), bi)
imod.add(MFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=False, \
acc=gprfunc((accStart+accStoreCIdx), (accEnd-accStart+1)), a=src0, b=src1, acc2=gprfunc(accStart, (accEnd-accStart+1)), \
acc=self.accVgprReadWriteIndex(kernel, (accStart+accStoreCIdx), (accEnd-accStart+1)), a=src0, b=src1, acc2=self.accVgprReadWriteIndex(kernel, accStart, (accEnd-accStart+1)), \
comment="Cr += %sAi*Bi"%("-" if ccVgprs[0] else "")))
(src0, src1) = (br, (vgpr(ccVgprs[1] + offsetVgpr[1], numRegistersOut) if ccVgprs[1] else ai)) if kernel["SourceSwap"] else ((vgpr(ccVgprs[1] + offsetVgpr[1], numRegistersOut) if ccVgprs[1] else ai), br)
imod.add(MFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=False, \
acc=gprfunc((accStart+accImOffset), (accEnd-accStart+1)), a=src0, b=src1, acc2=gprfunc(accStartSrcImg, (accEndSrcImg-accStartSrcImg+1)), \
acc=self.accVgprReadWriteIndex(kernel, (accStart+accImOffset), (accEnd-accStart+1)), a=src0, b=src1, acc2=self.accVgprReadWriteIndex(kernel, accStartSrcImg, (accEndSrcImg-accStartSrcImg+1)), \
comment="Ci += %sAi*Br"%("-" if ccVgprs[1] else "")))
(src0, src1) = (bi, (vgpr(ccVgprs[2] + offsetVgpr[2], numRegistersOut) if ccVgprs[2] else ar)) if kernel["SourceSwap"] else ((vgpr(ccVgprs[2] + offsetVgpr[2], numRegistersOut) if ccVgprs[2] else ar), bi)
imod.add(MFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=False, \
acc=gprfunc((accStart+accImOffset+accStoreCIdx), (accEnd-accStart+1)), a=src0, b=src1, acc2=gprfunc(accStartSrcImg, (accEndSrcImg-accStartSrcImg+1)), \
acc=self.accVgprReadWriteIndex(kernel, (accStart+accImOffset+accStoreCIdx), (accEnd-accStart+1)), a=src0, b=src1, acc2=self.accVgprReadWriteIndex(kernel, accStartSrcImg, (accEndSrcImg-accStartSrcImg+1)), \
comment="Ci += %sAr*Bi"%("-" if ccVgprs[2] else "")))
for v in ccVgprs:
if v is not None: self.vgprPool.checkIn(v)
Expand All @@ -5661,18 +5680,18 @@ def mfmaIter(self, kernel, tPA, tPB, u, innerUnroll, vregSetIdx, unrollLoopIdx =
idx = idx1 if kernel["ProblemType"]["Sparse"] == 2 else idx0
accInStart = miWaveTile * kernel["LoopIters"] * unrollLoopIdx + idx * kernel["LoopIters"] + unrollIdx
imod.add(SMFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=mfma_1k, \
acc=gprfunc((accStart+accStoreCIdx), (accEnd-accStart+1)), \
acc=self.accVgprReadWriteIndex(kernel, (accStart+accStoreCIdx), (accEnd-accStart+1)), \
a=src0, b=src1, metadata=vgpr("ValuMetadata+%u"%(accInStart)), \
comment="left value = %s[%u+%u:%u+%u]" % (accumRegType, accStart, accStoreCIdx, accEnd, accStoreCIdx)))
else:
imod.add(SMFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=mfma_1k, \
acc=gprfunc((accStart+accStoreCIdx), (accEnd-accStart+1)), \
acc=self.accVgprReadWriteIndex(kernel, (accStart+accStoreCIdx), (accEnd-accStart+1)), \
a=src0, b=src1, metadata=mStr, \
comment="left value = %s[%u+%u:%u+%u]" % (accumRegType, accStart, accStoreCIdx, accEnd, accStoreCIdx)))
else:
imod.add(MFMAInstruction(instType=miInInstType, accType=miOutInstType, variant=variant, mfma1k=mfma_1k, \
acc=gprfunc((accStart+accStoreCIdx), (accEnd-accStart+1)), \
a=src0, b=src1, acc2=gprfunc(accStart, (accEnd-accStart+1)), neg=neg_flag,\
acc=self.accVgprReadWriteIndex(kernel, (accStart+accStoreCIdx), (accEnd-accStart+1)), \
a=src0, b=src1, acc2=self.accVgprReadWriteIndex(kernel, accStart, (accEnd-accStart+1)), neg=neg_flag,\
comment="left value = %s[%u+%u:%u+%u]" % (accumRegType, accStart, accStoreCIdx, accEnd, accStoreCIdx)))
prevAccIdx = accIdx

Expand Down
Loading

0 comments on commit 09ba034

Please sign in to comment.