Skip to content

Commit

Permalink
cmd/compile: optimize single-precision floating point square root
Browse files Browse the repository at this point in the history
Add generic rule to rewrite the single-precision square root expression
with one single-precision instruction. The optimization will reduce two
times of precision converting between double-precision and single-precision.

On arm64 flatform.

previous:
  FCVTSD F0, F0
  FSQRTD F0, F0
  FCVTDS F0, F0

optimized:
  FSQRTS S0, S0

And this patch adds the test case to check the correctness.

This patch refers to CL 241877, contributed by Alice Xu
(dianhong.xu@arm.com)

Change-Id: I6de5d02281c693017ac4bd4c10963dd55989bd7e
Reviewed-on: https://go-review.googlesource.com/c/go/+/276873
Trust: fannie zhang <Fannie.Zhang@arm.com>
Run-TryBot: fannie zhang <Fannie.Zhang@arm.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
  • Loading branch information
zhangfannie committed Mar 2, 2021
1 parent ebb92df commit 2b50ab2
Show file tree
Hide file tree
Showing 41 changed files with 255 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/cmd/compile/internal/amd64/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,15 +1053,15 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
p.To.Type = obj.TYPE_REG
p.To.Reg = v.Reg0()

case ssa.OpAMD64BSFQ, ssa.OpAMD64BSRQ, ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD:
case ssa.OpAMD64BSFQ, ssa.OpAMD64BSRQ, ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD, ssa.OpAMD64SQRTSS:
p := s.Prog(v.Op.Asm())
p.From.Type = obj.TYPE_REG
p.From.Reg = v.Args[0].Reg()
p.To.Type = obj.TYPE_REG
switch v.Op {
case ssa.OpAMD64BSFQ, ssa.OpAMD64BSRQ:
p.To.Reg = v.Reg0()
case ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD:
case ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD, ssa.OpAMD64SQRTSS:
p.To.Reg = v.Reg()
}
case ssa.OpAMD64ROUNDSD:
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/arm/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpARMREV,
ssa.OpARMREV16,
ssa.OpARMRBIT,
ssa.OpARMSQRTF,
ssa.OpARMSQRTD,
ssa.OpARMNEGF,
ssa.OpARMNEGD,
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/arm64/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpARM64FMOVSgpfp,
ssa.OpARM64FNEGS,
ssa.OpARM64FNEGD,
ssa.OpARM64FSQRTS,
ssa.OpARM64FSQRTD,
ssa.OpARM64FCVTZSSW,
ssa.OpARM64FCVTZSDW,
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/mips/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpMIPSMOVDF,
ssa.OpMIPSNEGF,
ssa.OpMIPSNEGD,
ssa.OpMIPSSQRTF,
ssa.OpMIPSSQRTD,
ssa.OpMIPSCLZ:
p := s.Prog(v.Op.Asm())
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/mips64/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpMIPS64MOVDF,
ssa.OpMIPS64NEGF,
ssa.OpMIPS64NEGD,
ssa.OpMIPS64SQRTF,
ssa.OpMIPS64SQRTD:
p := s.Prog(v.Op.Asm())
p.From.Type = obj.TYPE_REG
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/compile/internal/s390x/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
p.Reg = v.Args[1].Reg()
p.To.Type = obj.TYPE_REG
p.To.Reg = v.Reg()
case ssa.OpS390XFSQRT:
case ssa.OpS390XFSQRTS, ssa.OpS390XFSQRT:
p := s.Prog(v.Op.Asm())
p.From.Type = obj.TYPE_REG
p.From.Reg = v.Args[0].Reg()
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/386.rules
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
(Bswap32 ...) => (BSWAPL ...)

(Sqrt ...) => (SQRTSD ...)
(Sqrt32 ...) => (SQRTSS ...)

(Ctz16 x) => (BSFL (ORLconst <typ.UInt32> [0x10000] x))
(Ctz16NonZero ...) => (BSFL ...)
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/386Ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ func init() {
{name: "BSWAPL", argLength: 1, reg: gp11, asm: "BSWAPL", resultInArg0: true, clobberFlags: true}, // arg0 swap bytes

{name: "SQRTSD", argLength: 1, reg: fp11, asm: "SQRTSD"}, // sqrt(arg0)
{name: "SQRTSS", argLength: 1, reg: fp11, asm: "SQRTSS"}, // sqrt(arg0), float32

{name: "SBBLcarrymask", argLength: 1, reg: flagsgp, asm: "SBBL"}, // (int32)(-1) if carry is set, 0 if carry is clear.
// Note: SBBW and SBBB are subsumed by SBBL
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/AMD64.rules
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
(PopCount8 x) => (POPCNTL (MOVBQZX <typ.UInt32> x))

(Sqrt ...) => (SQRTSD ...)
(Sqrt32 ...) => (SQRTSS ...)

(RoundToEven x) => (ROUNDSD [0] x)
(Floor x) => (ROUNDSD [1] x)
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/AMD64Ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ func init() {
{name: "POPCNTL", argLength: 1, reg: gp11, asm: "POPCNTL", clobberFlags: true}, // count number of set bits in arg0

{name: "SQRTSD", argLength: 1, reg: fp11, asm: "SQRTSD"}, // sqrt(arg0)
{name: "SQRTSS", argLength: 1, reg: fp11, asm: "SQRTSS"}, // sqrt(arg0), float32

// ROUNDSD instruction isn't guaranteed to be on the target platform (it is SSE4.1)
// Any use must be preceded by a successful check of runtime.x86HasSSE41.
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/ARM.rules
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
(Com(32|16|8) ...) => (MVN ...)

(Sqrt ...) => (SQRTD ...)
(Sqrt32 ...) => (SQRTF ...)
(Abs ...) => (ABSD ...)

// TODO: optimize this for ARMv5 and ARMv6
Expand Down
2 changes: 2 additions & 0 deletions src/cmd/compile/internal/ssa/gen/ARM64.rules
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
(Trunc ...) => (FRINTZD ...)
(FMA x y z) => (FMADDD z x y)

(Sqrt32 ...) => (FSQRTS ...)

// lowering rotates
(RotateLeft8 <t> x (MOVDconst [c])) => (Or8 (Lsh8x64 <t> x (MOVDconst [c&7])) (Rsh8Ux64 <t> x (MOVDconst [-c&7])))
(RotateLeft16 <t> x (MOVDconst [c])) => (Or16 (Lsh16x64 <t> x (MOVDconst [c&15])) (Rsh16Ux64 <t> x (MOVDconst [-c&15])))
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/ARM64Ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ func init() {
{name: "FNEGS", argLength: 1, reg: fp11, asm: "FNEGS"}, // -arg0, float32
{name: "FNEGD", argLength: 1, reg: fp11, asm: "FNEGD"}, // -arg0, float64
{name: "FSQRTD", argLength: 1, reg: fp11, asm: "FSQRTD"}, // sqrt(arg0), float64
{name: "FSQRTS", argLength: 1, reg: fp11, asm: "FSQRTS"}, // sqrt(arg0), float32
{name: "REV", argLength: 1, reg: gp11, asm: "REV"}, // byte reverse, 64-bit
{name: "REVW", argLength: 1, reg: gp11, asm: "REVW"}, // byte reverse, 32-bit
{name: "REV16W", argLength: 1, reg: gp11, asm: "REV16W"}, // byte reverse in each 16-bit halfword, 32-bit
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/ARMOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func init() {
{name: "NEGF", argLength: 1, reg: fp11, asm: "NEGF"}, // -arg0, float32
{name: "NEGD", argLength: 1, reg: fp11, asm: "NEGD"}, // -arg0, float64
{name: "SQRTD", argLength: 1, reg: fp11, asm: "SQRTD"}, // sqrt(arg0), float64
{name: "SQRTF", argLength: 1, reg: fp11, asm: "SQRTF"}, // sqrt(arg0), float32
{name: "ABSD", argLength: 1, reg: fp11, asm: "ABSD"}, // abs(arg0), float64

{name: "CLZ", argLength: 1, reg: gp11, asm: "CLZ"}, // count leading zero
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/MIPS.rules
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
(Com(32|16|8) x) => (NORconst [0] x)

(Sqrt ...) => (SQRTD ...)
(Sqrt32 ...) => (SQRTF ...)

// TODO: optimize this case?
(Ctz32NonZero ...) => (Ctz32 ...)
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/MIPS64.rules
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
(Com(64|32|16|8) x) => (NOR (MOVVconst [0]) x)

(Sqrt ...) => (SQRTD ...)
(Sqrt32 ...) => (SQRTF ...)

// boolean ops -- booleans are represented with 0=false, 1=true
(AndB ...) => (AND ...)
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/MIPS64Ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func init() {
{name: "NEGF", argLength: 1, reg: fp11, asm: "NEGF"}, // -arg0, float32
{name: "NEGD", argLength: 1, reg: fp11, asm: "NEGD"}, // -arg0, float64
{name: "SQRTD", argLength: 1, reg: fp11, asm: "SQRTD"}, // sqrt(arg0), float64
{name: "SQRTF", argLength: 1, reg: fp11, asm: "SQRTF"}, // sqrt(arg0), float32

// shifts
{name: "SLLV", argLength: 2, reg: gp21, asm: "SLLV"}, // arg0 << arg1, shift amount is mod 64
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/MIPSOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ func init() {
{name: "NEGF", argLength: 1, reg: fp11, asm: "NEGF"}, // -arg0, float32
{name: "NEGD", argLength: 1, reg: fp11, asm: "NEGD"}, // -arg0, float64
{name: "SQRTD", argLength: 1, reg: fp11, asm: "SQRTD"}, // sqrt(arg0), float64
{name: "SQRTF", argLength: 1, reg: fp11, asm: "SQRTF"}, // sqrt(arg0), float32

// shifts
{name: "SLL", argLength: 2, reg: gp21, asm: "SLL"}, // arg0 << arg1, shift amount is mod 32
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/PPC64.rules
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
(Round(32|64)F ...) => (LoweredRound(32|64)F ...)

(Sqrt ...) => (FSQRT ...)
(Sqrt32 ...) => (FSQRTS ...)
(Floor ...) => (FFLOOR ...)
(Ceil ...) => (FCEIL ...)
(Trunc ...) => (FTRUNC ...)
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/RISCV64.rules
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
(Com8 ...) => (NOT ...)

(Sqrt ...) => (FSQRTD ...)
(Sqrt32 ...) => (FSQRTS ...)

// Sign and zero extension.

Expand Down
2 changes: 2 additions & 0 deletions src/cmd/compile/internal/ssa/gen/S390X.rules
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@
(Round x) => (FIDBR [1] x)
(FMA x y z) => (FMADD z x y)

(Sqrt32 ...) => (FSQRTS ...)

// Atomic loads and stores.
// The SYNC instruction (fast-BCR-serialization) prevents store-load
// reordering. Other sequences of memory operations (load-load,
Expand Down
1 change: 1 addition & 0 deletions src/cmd/compile/internal/ssa/gen/S390XOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ func init() {
{name: "NOTW", argLength: 1, reg: gp11, resultInArg0: true, clobberFlags: true}, // ^arg0

{name: "FSQRT", argLength: 1, reg: fp11, asm: "FSQRT"}, // sqrt(arg0)
{name: "FSQRTS", argLength: 1, reg: fp11, asm: "FSQRTS"}, // sqrt(arg0), float32

// Conditional register-register moves.
// The aux for these values is an s390x.CCMask value representing the condition code mask.
Expand Down
2 changes: 2 additions & 0 deletions src/cmd/compile/internal/ssa/gen/Wasm.rules
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@
(Abs ...) => (F64Abs ...)
(Copysign ...) => (F64Copysign ...)

(Sqrt32 ...) => (F32Sqrt ...)

(Ctz64 ...) => (I64Ctz ...)
(Ctz32 x) => (I64Ctz (I64Or x (I64Const [0x100000000])))
(Ctz16 x) => (I64Ctz (I64Or x (I64Const [0x10000])))
Expand Down
14 changes: 7 additions & 7 deletions src/cmd/compile/internal/ssa/gen/WasmOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ func init() {
{name: "I64Extend16S", asm: "I64Extend16S", argLength: 1, reg: gp11, typ: "Int64"}, // sign-extend arg0 from 16 to 64 bit
{name: "I64Extend32S", asm: "I64Extend32S", argLength: 1, reg: gp11, typ: "Int64"}, // sign-extend arg0 from 32 to 64 bit

{name: "F32Sqrt", asm: "F32Sqrt", argLength: 1, reg: fp64_11, typ: "Float32"}, // sqrt(arg0)
{name: "F32Trunc", asm: "F32Trunc", argLength: 1, reg: fp64_11, typ: "Float32"}, // trunc(arg0)
{name: "F32Ceil", asm: "F32Ceil", argLength: 1, reg: fp64_11, typ: "Float32"}, // ceil(arg0)
{name: "F32Floor", asm: "F32Floor", argLength: 1, reg: fp64_11, typ: "Float32"}, // floor(arg0)
{name: "F32Nearest", asm: "F32Nearest", argLength: 1, reg: fp64_11, typ: "Float32"}, // round(arg0)
{name: "F32Abs", asm: "F32Abs", argLength: 1, reg: fp64_11, typ: "Float32"}, // abs(arg0)
{name: "F32Copysign", asm: "F32Copysign", argLength: 2, reg: fp64_21, typ: "Float32"}, // copysign(arg0, arg1)
{name: "F32Sqrt", asm: "F32Sqrt", argLength: 1, reg: fp32_11, typ: "Float32"}, // sqrt(arg0)
{name: "F32Trunc", asm: "F32Trunc", argLength: 1, reg: fp32_11, typ: "Float32"}, // trunc(arg0)
{name: "F32Ceil", asm: "F32Ceil", argLength: 1, reg: fp32_11, typ: "Float32"}, // ceil(arg0)
{name: "F32Floor", asm: "F32Floor", argLength: 1, reg: fp32_11, typ: "Float32"}, // floor(arg0)
{name: "F32Nearest", asm: "F32Nearest", argLength: 1, reg: fp32_11, typ: "Float32"}, // round(arg0)
{name: "F32Abs", asm: "F32Abs", argLength: 1, reg: fp32_11, typ: "Float32"}, // abs(arg0)
{name: "F32Copysign", asm: "F32Copysign", argLength: 2, reg: fp32_21, typ: "Float32"}, // copysign(arg0, arg1)

{name: "F64Sqrt", asm: "F64Sqrt", argLength: 1, reg: fp64_11, typ: "Float64"}, // sqrt(arg0)
{name: "F64Trunc", asm: "F64Trunc", argLength: 1, reg: fp64_11, typ: "Float64"}, // trunc(arg0)
Expand Down
3 changes: 3 additions & 0 deletions src/cmd/compile/internal/ssa/gen/generic.rules
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,9 @@
(Div32F x (Const32F <t> [c])) && reciprocalExact32(c) => (Mul32F x (Const32F <t> [1/c]))
(Div64F x (Const64F <t> [c])) && reciprocalExact64(c) => (Mul64F x (Const64F <t> [1/c]))

// rewrite single-precision sqrt expression "float32(math.Sqrt(float64(x)))"
(Cvt64Fto32F sqrt0:(Sqrt (Cvt32Fto64F x))) && sqrt0.Uses==1 => (Sqrt32 x)

(Sqrt (Const64F [c])) && !math.IsNaN(math.Sqrt(c)) => (Const64F [math.Sqrt(c)])

// for rewriting results of some late-expanded rewrites (below)
Expand Down
5 changes: 3 additions & 2 deletions src/cmd/compile/internal/ssa/gen/genericOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,14 @@ var genericOps = []opData{
{name: "RotateLeft32", argLength: 2}, // Rotate bits in arg[0] left by arg[1]
{name: "RotateLeft64", argLength: 2}, // Rotate bits in arg[0] left by arg[1]

// Square root, float64 only.
// Square root.
// Special cases:
// +∞ → +∞
// ±0 → ±0 (sign preserved)
// x<0 → NaN
// NaN → NaN
{name: "Sqrt", argLength: 1}, // √arg0
{name: "Sqrt", argLength: 1}, // √arg0 (floating point, double precision)
{name: "Sqrt32", argLength: 1}, // √arg0 (floating point, single precision)

// Round to integer, float64 only.
// Special cases:
Expand Down
Loading

0 comments on commit 2b50ab2

Please sign in to comment.