Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix integer power #215

Merged
merged 3 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions engine/integer.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ func (i Integer) Unify(t Term, occursCheck bool, env *Env) (*Env, bool) {

// Unparse emits tokens that represent the integer.
func (i Integer) Unparse(emit func(token Token), _ *Env, _ ...WriteOption) {
if i < 0 {
s := strconv.FormatInt(int64(i), 10)

if s[0] == '-' {
emit(Token{Kind: TokenGraphic, Val: "-"})
i *= -1
emit(Token{Kind: TokenInteger, Val: s[1:]})
return
}
s := strconv.FormatInt(int64(i), 10)

emit(Token{Kind: TokenInteger, Val: s})
}

Expand Down
12 changes: 12 additions & 0 deletions engine/integer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"math"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -77,6 +78,17 @@ func TestInteger_Unparse(t *testing.T) {
{Kind: TokenGraphic, Val: "-"},
{Kind: TokenInteger, Val: "33"},
}, tokens)

t.Run("math.MinInt64", func(t *testing.T) {
var tokens []Token
Integer(math.MinInt64).Unparse(func(token Token) {
tokens = append(tokens, token)
}, nil)
assert.Equal(t, []Token{
{Kind: TokenGraphic, Val: "-"},
{Kind: TokenInteger, Val: "9223372036854775808"},
}, tokens)
})
})
}

Expand Down
66 changes: 53 additions & 13 deletions engine/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,20 +919,60 @@ func Min(x, y Number) (Number, error) {

// IntegerPower returns x raised to the power of y.
func IntegerPower(x, y Number) (Number, error) {
if x, ok := x.(Integer); ok {
if y, ok := y.(Integer); ok {
if x != 1 && y < -1 {
return nil, TypeErrorFloat(x, nil)
}
vx, ok := x.(Integer)
if !ok {
return Power(x, y)
}

r, err := Power(x, y)
vy, ok := y.(Integer)
if !ok {
return Power(x, y)
}

if vy < 0 {
switch vx {
case 0:
return nil, ErrUndefined
case 1, -1:
vy, err := negI(vy) // y can be math.MinInt64
if err != nil {
return nil, err
}
return truncateFtoI(r.(Float))
r, _ := intPow(vx, vy) // Since x is either 1 or -1, no errors occur.
return intDivI(1, r)
default:
return nil, TypeErrorFloat(vx, nil)
}
}

return intPow(vx, vy)
}

// Loosely based on https://www.programminglogic.com/fast-exponentiation-algorithms/
func intPow(a, b Integer) (Integer, error) {
var (
r = Integer(1)
err error
)
for {
if b&1 != 0 {
r, err = mulI(r, a)
if err != nil {
return 0, err
}
}

b >>= 1
if b == 0 {
break
}

a, err = mulI(a, a)
if err != nil {
return 0, err
}
}
return Power(x, y)
return r, nil
}

// Asin returns the arc sine of x.
Expand Down Expand Up @@ -1210,12 +1250,12 @@ func mulI(x, y Integer) (Integer, error) {
return 0, ErrIntOverflow
case y == 0:
return 0, nil
case x > math.MaxInt64/y:
return 0, ErrIntOverflow
case x < math.MinInt64/y:
return 0, ErrIntOverflow
default:
return x * y, nil
r := x * y
if r/y != x {
return 0, ErrIntOverflow
}
return r, nil
}
}

Expand Down
52 changes: 46 additions & 6 deletions engine/number_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1546,13 +1546,41 @@ func TestMin(t *testing.T) {
func TestIntegerPower(t *testing.T) {
t.Run("integer", func(t *testing.T) {
t.Run("integer", func(t *testing.T) {
r, err := IntegerPower(Integer(1), Integer(1))
assert.NoError(t, err)
assert.Equal(t, Integer(1), r)
t.Run("y is positive", func(t *testing.T) {
r, err := IntegerPower(Integer(1), Integer(1))
assert.NoError(t, err)
assert.Equal(t, Integer(1), r)
})

t.Run("x is not equal to 1 and y is less than -1", func(t *testing.T) {
_, err := IntegerPower(Integer(2), Integer(-2))
assert.Equal(t, TypeErrorFloat(Integer(2), nil), err)
t.Run("y is negative", func(t *testing.T) {
t.Run("x is 1", func(t *testing.T) {
r, err := IntegerPower(Integer(1), Integer(-1))
assert.NoError(t, err)
assert.Equal(t, Integer(1), r)
})

t.Run("x is 0", func(t *testing.T) {
_, err := IntegerPower(Integer(0), Integer(-1))
assert.Equal(t, ErrUndefined, err)
})

t.Run("x is -1", func(t *testing.T) {
t.Run("ok", func(t *testing.T) {
r, err := IntegerPower(Integer(-1), Integer(-1))
assert.NoError(t, err)
assert.Equal(t, Integer(-1), r)
})

t.Run("y is math.MinInt64", func(t *testing.T) {
_, err := IntegerPower(Integer(-1), Integer(math.MinInt64))
assert.Equal(t, ErrIntOverflow, err)
})
})

t.Run("x is neither 1, 0, nor -1", func(t *testing.T) {
_, err := IntegerPower(Integer(2), Integer(-2))
assert.Equal(t, TypeErrorFloat(Integer(2), nil), err)
})
})
})

Expand All @@ -1566,6 +1594,18 @@ func TestIntegerPower(t *testing.T) {
_, err := IntegerPower(Integer(1), &mockNumber{})
assert.Equal(t, ErrUndefined, err)
})

t.Run("overflow", func(t *testing.T) {
t.Run("x is too large", func(t *testing.T) {
_, err := IntegerPower(Integer(math.MaxInt64), Integer(2))
assert.Equal(t, ErrIntOverflow, err)
})

t.Run("y is too large", func(t *testing.T) {
_, err := IntegerPower(Integer(2), Integer(63))
assert.Equal(t, ErrIntOverflow, err)
})
})
})

t.Run("float", func(t *testing.T) {
Expand Down