Skip to content

Commit

Permalink
cmd/compile: try to rewrite loops to count down
Browse files Browse the repository at this point in the history
Fixes #61629

This reduce the pressure on regalloc because then the loop only keep alive
one value (the iterator) instead of the iterator and the upper bound since
the comparison now acts against an immediate, often zero which can be skipped.

This optimize things like:
  for i := 0; i < n; i++ {
Or a range over a slice where the index is not used:
  for _, v := range someSlice {
Or the new range over int from #61405:
  for range n {

It is hit in 975 unique places while doing ./make.bash.

Change-Id: I5facff8b267a0b60ea3c1b9a58c4d74cdb38f03f
Reviewed-on: https://go-review.googlesource.com/c/go/+/512935
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Auto-Submit: Keith Randall <khr@golang.org>
  • Loading branch information
Jorropo authored and gopherbot committed Jul 31, 2023
1 parent 8613ef8 commit bac4e2f
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 21 deletions.
8 changes: 6 additions & 2 deletions src/cmd/compile/internal/ssa/loopbce.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
type indVarFlags uint8

const (
indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
indVarMaxInc // maximum value is inclusive (default: exclusive)
indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
indVarMaxInc // maximum value is inclusive (default: exclusive)
indVarCountDown // if set the iteration starts at max and count towards min (default: min towards max)
)

type indVar struct {
ind *Value // induction variable
nxt *Value // the incremented variable
min *Value // minimum value, inclusive/exclusive depends on flags
max *Value // maximum value, inclusive/exclusive depends on flags
entry *Block // entry block in the loop.
Expand Down Expand Up @@ -277,6 +279,7 @@ func findIndVar(f *Func) []indVar {
if !inclusive {
flags |= indVarMinExc
}
flags |= indVarCountDown
step = -step
}
if f.pass.debug >= 1 {
Expand All @@ -285,6 +288,7 @@ func findIndVar(f *Func) []indVar {

iv = append(iv, indVar{
ind: ind,
nxt: nxt,
min: min,
max: max,
entry: b.Succs[0].b,
Expand Down
169 changes: 160 additions & 9 deletions src/cmd/compile/internal/ssa/prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) {
// its negation. If either leads to a contradiction, it can trim that
// successor.
func prove(f *Func) {
// Find induction variables. Currently, findIndVars
// is limited to one induction variable per block.
var indVars map[*Block]indVar
for _, v := range findIndVar(f) {
ind := v.ind
if len(ind.Args) != 2 {
// the rewrite code assumes there is only ever two parents to loops
panic("unexpected induction with too many parents")
}

nxt := v.nxt
if !(ind.Uses == 2 && // 2 used by comparison and next
nxt.Uses == 1) { // 1 used by induction
// ind or nxt is used inside the loop, add it for the facts table
if indVars == nil {
indVars = make(map[*Block]indVar)
}
indVars[v.entry] = v
continue
} else {
// Since this induction variable is not used for anything but counting the iterations,
// no point in putting it into the facts table.
}

// try to rewrite to a downward counting loop checking against start if the
// loop body does not depends on ind or nxt and end is known before the loop.
// This reduce pressure on the register allocator because this do not need
// to use end on each iteration anymore. We compare against the start constant instead.
// That means this code:
//
// loop:
// ind = (Phi (Const [x]) nxt),
// if ind < end
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// do something without using ind nor nxt
// nxt = inc + ind
// goto loop
//
// exit_loop:
//
// is rewritten to:
//
// loop:
// ind = (Phi end nxt)
// if (Const [x]) < ind
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// do something without using ind nor nxt
// nxt = ind - inc
// goto loop
//
// exit_loop:
//
// this is better because it only require to keep ind then nxt alive while looping,
// while the original form keeps ind then nxt and end alive
start, end := v.min, v.max
if v.flags&indVarCountDown != 0 {
start, end = end, start
}

if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) {
// if start is not a constant we would be winning nothing from inverting the loop
continue
}
if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 {
// TODO: if both start and end are constants we should rewrite such that the comparison
// is against zero and nxt is ++ or -- operation
// That means:
// for i := 2; i < 11; i += 2 {
// should be rewritten to:
// for i := 5; 0 < i; i-- {
continue
}

header := ind.Block
check := header.Controls[0]
if check == nil {
// we don't know how to rewrite a loop that not simple comparison
continue
}
switch check.Op {
case OpLeq64, OpLeq32, OpLeq16, OpLeq8,
OpLess64, OpLess32, OpLess16, OpLess8:
default:
// we don't know how to rewrite a loop that not simple comparison
continue
}
if !((check.Args[0] == ind && check.Args[1] == end) ||
(check.Args[1] == ind && check.Args[0] == end)) {
// we don't know how to rewrite a loop that not simple comparison
continue
}
if end.Block == ind.Block {
// we can't rewrite loops where the condition depends on the loop body
// this simple check is forced to work because if this is true a Phi in ind.Block must exists
continue
}

// invert the check
check.Args[0], check.Args[1] = check.Args[1], check.Args[0]

// invert start and end in the loop
for i, v := range check.Args {
if v != end {
continue
}

check.SetArg(i, start)
goto replacedEnd
}
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
replacedEnd:

for i, v := range ind.Args {
if v != start {
continue
}

ind.SetArg(i, end)
goto replacedStart
}
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
replacedStart:

if nxt.Args[0] != ind {
// unlike additions subtractions are not commutative so be sure we get it right
nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0]
}

switch nxt.Op {
case OpAdd8:
nxt.Op = OpSub8
case OpAdd16:
nxt.Op = OpSub16
case OpAdd32:
nxt.Op = OpSub32
case OpAdd64:
nxt.Op = OpSub64
case OpSub8:
nxt.Op = OpAdd8
case OpSub16:
nxt.Op = OpAdd16
case OpSub32:
nxt.Op = OpAdd32
case OpSub64:
nxt.Op = OpAdd64
default:
panic("unreachable")
}

if f.pass.debug > 0 {
f.Warnl(ind.Pos, "Inverted loop iteration")
}
}

ft := newFactsTable(f)
ft.checkpoint()

Expand Down Expand Up @@ -933,15 +1093,6 @@ func prove(f *Func) {
}
}
}
// Find induction variables. Currently, findIndVars
// is limited to one induction variable per block.
var indVars map[*Block]indVar
for _, v := range findIndVar(f) {
if indVars == nil {
indVars = make(map[*Block]indVar)
}
indVars[v.entry] = v
}

// current node state
type walkState int
Expand Down
22 changes: 12 additions & 10 deletions test/codegen/compare_and_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,25 @@ func si64(x, y chan int64) {
}

// Signed 64-bit compare-and-branch with 8-bit immediate.
func si64x8() {
func si64x8(doNotOptimize int64) {
// take in doNotOptimize as an argument to avoid the loops being rewritten to count down
// s390x:"CGIJ\t[$]12, R[0-9]+, [$]127, "
for i := int64(0); i < 128; i++ {
for i := doNotOptimize; i < 128; i++ {
dummy()
}

// s390x:"CGIJ\t[$]10, R[0-9]+, [$]-128, "
for i := int64(0); i > -129; i-- {
for i := doNotOptimize; i > -129; i-- {
dummy()
}

// s390x:"CGIJ\t[$]2, R[0-9]+, [$]127, "
for i := int64(0); i >= 128; i++ {
for i := doNotOptimize; i >= 128; i++ {
dummy()
}

// s390x:"CGIJ\t[$]4, R[0-9]+, [$]-128, "
for i := int64(0); i <= -129; i-- {
for i := doNotOptimize; i <= -129; i-- {
dummy()
}
}
Expand Down Expand Up @@ -95,24 +96,25 @@ func si32(x, y chan int32) {
}

// Signed 32-bit compare-and-branch with 8-bit immediate.
func si32x8() {
func si32x8(doNotOptimize int32) {
// take in doNotOptimize as an argument to avoid the loops being rewritten to count down
// s390x:"CIJ\t[$]12, R[0-9]+, [$]127, "
for i := int32(0); i < 128; i++ {
for i := doNotOptimize; i < 128; i++ {
dummy()
}

// s390x:"CIJ\t[$]10, R[0-9]+, [$]-128, "
for i := int32(0); i > -129; i-- {
for i := doNotOptimize; i > -129; i-- {
dummy()
}

// s390x:"CIJ\t[$]2, R[0-9]+, [$]127, "
for i := int32(0); i >= 128; i++ {
for i := doNotOptimize; i >= 128; i++ {
dummy()
}

// s390x:"CIJ\t[$]4, R[0-9]+, [$]-128, "
for i := int32(0); i <= -129; i-- {
for i := doNotOptimize; i <= -129; i-- {
dummy()
}
}
Expand Down
10 changes: 10 additions & 0 deletions test/prove_invert_loop_with_unused_iterators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// +build amd64
// errorcheck -0 -d=ssa/prove/debug=1

package main

func invert(b func(), n int) {
for i := 0; i < n; i++ { // ERROR "(Inverted loop iteration|Induction variable: limits \[0,\?\), increment 1)"
b()
}
}

0 comments on commit bac4e2f

Please sign in to comment.