Skip to content

Commit

Permalink
Add int overflow checks
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed May 9, 2024
1 parent 1a5df77 commit 45c1ae7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1039,9 +1039,11 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType refl
case *ast.IntegerNode:
(*node).SetType(newType)
case *ast.UnaryNode:
(*node).SetType(newType)
unaryNode := (*node).(*ast.UnaryNode)
traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType)
case *ast.BinaryNode:
// TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node.
binaryNode := (*node).(*ast.BinaryNode)
switch binaryNode.Operator {
case "+", "-", "*":
Expand Down
28 changes: 28 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package compiler

import (
"fmt"
"math"
"reflect"
"regexp"

Expand Down Expand Up @@ -329,22 +330,49 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) {
case reflect.Int:
c.emitPush(node.Value)
case reflect.Int8:
if node.Value > math.MaxInt8 || node.Value < math.MinInt8 {
panic(fmt.Sprintf("constant %d overflows int8", node.Value))
}
c.emitPush(int8(node.Value))
case reflect.Int16:
if node.Value > math.MaxInt16 || node.Value < math.MinInt16 {
panic(fmt.Sprintf("constant %d overflows int16", node.Value))
}
c.emitPush(int16(node.Value))
case reflect.Int32:
if node.Value > math.MaxInt32 || node.Value < math.MinInt32 {
panic(fmt.Sprintf("constant %d overflows int32", node.Value))
}
c.emitPush(int32(node.Value))
case reflect.Int64:
if node.Value > math.MaxInt64 || node.Value < math.MinInt64 {
panic(fmt.Sprintf("constant %d overflows int64", node.Value))
}
c.emitPush(int64(node.Value))
case reflect.Uint:
if node.Value < 0 {
panic(fmt.Sprintf("constant %d overflows uint", node.Value))
}
c.emitPush(uint(node.Value))
case reflect.Uint8:
if node.Value > math.MaxUint8 || node.Value < 0 {
panic(fmt.Sprintf("constant %d overflows uint8", node.Value))
}
c.emitPush(uint8(node.Value))
case reflect.Uint16:
if node.Value > math.MaxUint16 || node.Value < 0 {
panic(fmt.Sprintf("constant %d overflows uint16", node.Value))
}
c.emitPush(uint16(node.Value))
case reflect.Uint32:
if node.Value > math.MaxUint32 || node.Value < 0 {
panic(fmt.Sprintf("constant %d overflows uint32", node.Value))
}
c.emitPush(uint32(node.Value))
case reflect.Uint64:
if node.Value < 0 {
panic(fmt.Sprintf("constant %d overflows uint64", node.Value))
}
c.emitPush(uint64(node.Value))
default:
c.emitPush(node.Value)
Expand Down
14 changes: 14 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2645,3 +2645,17 @@ func TestIssue_570(t *testing.T) {
require.NoError(t, err)
require.IsType(t, nil, out)
}

func TestIssue_integer_truncated_by_compiler(t *testing.T) {
env := map[string]any{
"fn": func(x byte) byte {
return x
},
}

_, err := expr.Compile("fn(255)", expr.Env(env))
require.NoError(t, err)

_, err = expr.Compile("fn(256)", expr.Env(env))
require.Error(t, err)
}

0 comments on commit 45c1ae7

Please sign in to comment.