Skip to content

Commit

Permalink
Revert *8 scaled loads for AVX2 code and fix before conversion t…
Browse files Browse the repository at this point in the history
…o SVE
  • Loading branch information
fwessels committed Jun 17, 2024
1 parent 4ab53e5 commit 205abfe
Show file tree
Hide file tree
Showing 3 changed files with 803 additions and 794 deletions.
68 changes: 68 additions & 0 deletions _gen/gen-arm-sve.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,72 @@ func convertRoutine(asmBuf *bytes.Buffer, instructions []string) {
}
}

// convert (R..*1) memory accesses into (R..*8) offsets
func patchScaledLoads(code string, outputs int, isXor bool) (patched []string) {

scaledMemOps := strings.Count(code, "*1)")
if scaledMemOps == 0 {
// in case of no scaled loads, exit out early
return strings.Split(code, "\n")
}

sanityCheck := outputs
if isXor {
sanityCheck *= 2 // need to load all values as well as store them
}
if scaledMemOps != sanityCheck {
panic("Couldn't find expected number of scaled memory ops")
}

scaledReg := ""
re := regexp.MustCompile(`R(\d+)\*1`)
if match := re.FindStringSubmatch(code); len(match) > 1 {
scaledReg = fmt.Sprintf("R%s", match[1])
} else {
panic("Failed to find register used for scaled memory ops")
}

const inputs = 10

scaledRegUses := strings.Count(code, scaledReg)
sanityCheck += inputs // needed to add start offset to input
sanityCheck += 1 // needed to load offset from stack
sanityCheck += 1 // needed to increment offset

if scaledRegUses != sanityCheck {
panic("Did not find expected number of uses of scaled register")
}

// Adjust all scaled loads
code = strings.ReplaceAll(code, fmt.Sprintf("(%s*1)", scaledReg), fmt.Sprintf("(%s*8)", scaledReg))

// Adjust increment at end of loop
reAdd := regexp.MustCompile(`ADDQ\s*\$(0x[0-9a-f]+),\s*` + scaledReg)
if match := reAdd.FindStringSubmatch(code); len(match) > 1 && match[1][:2] == "0x" {
if increment, err := strconv.ParseInt(match[1][2:], 16, 64); err == nil {
code = strings.ReplaceAll(code, fmt.Sprintf("0x%x, %s", increment, scaledReg), fmt.Sprintf("0x%02x, %s", increment>>3, scaledReg))
} else {
panic(err)
}
} else {
panic("Failed to find increment of offset")
}

// Add shift instruction during initialization after inputs have been adjusted
reShift := regexp.MustCompilePOSIX(fmt.Sprintf(`^[[:blank:]]+ADDQ[[:blank:]]+%s.*$`, scaledReg))
if matches := reShift.FindAllStringIndex(code, -1); len(matches) == inputs {
lastInpIncr := code[matches[inputs-1][0]:matches[inputs-1][1]]
shiftCorrection := strings.ReplaceAll(strings.Split(lastInpIncr, scaledReg)[0], "ADDQ", "SHRQ")
shiftCorrection += "$0x03, " + scaledReg
code = strings.ReplaceAll(code, lastInpIncr, lastInpIncr+"\n"+shiftCorrection)
} else {
fmt.Println(matches)
panic("Did not find expected number start offset corrections")
}

return strings.Split(code, "\n")
}

func fromAvx2ToSve() {
asmOut, goOut := &bytes.Buffer{}, &bytes.Buffer{}

Expand All @@ -147,6 +213,7 @@ func fromAvx2ToSve() {
if err != nil {
log.Fatal(err)
}
lines = patchScaledLoads(strings.Join(lines, "\n"), output, strings.HasSuffix(templName, "Xor"))
lines = expandHashDefines(lines)

convertRoutine(asmOut, lines)
Expand All @@ -170,6 +237,7 @@ func fromAvx2ToSve() {
if err != nil {
log.Fatal(err)
}
lines = patchScaledLoads(strings.Join(lines, "\n"), output, strings.HasSuffix(templName, "Xor"))
lines = expandHashDefines(lines)

// add additional initialization for SVE
Expand Down
17 changes: 7 additions & 10 deletions _gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,6 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
ADDQ(offset, ptr)
}
// Offset no longer needed unless not regDst
if !regDst {
SHRQ(U8(3), offset) // divide by 8 since we'll be scaling it up when loading or storing
}

tmpMask := GP64()
MOVQ(U32(15), tmpMask)
Expand Down Expand Up @@ -486,9 +483,9 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 8}, dst[i])
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
}
Expand Down Expand Up @@ -516,9 +513,9 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
} else {
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: j * 24}, ptr)
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 8}, dst[j])
VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[j])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
}
Expand Down Expand Up @@ -551,14 +548,14 @@ func genMulAvx2(name string, inputs int, outputs int, xor bool) {
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 8})
VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 8})
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
Comment("Prepare for next loop")
if !regDst {
ADDQ(U8(perLoop>>3), offset)
ADDQ(U8(perLoop), offset)
}
DECQ(length)
JNZ(LabelRef(name + "_loop"))
Expand Down
Loading

0 comments on commit 205abfe

Please sign in to comment.