diff --git a/go.mod b/go.mod index b5dd621295..2098930d8f 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/x448/float16 v0.8.4 github.com/yuin/goldmark v1.4.13 go.uber.org/zap v1.23.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/sync v0.4.0 golang.org/x/sys v0.13.0 golang.org/x/term v0.13.0 @@ -67,7 +68,6 @@ require ( go.opentelemetry.io/otel v0.16.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.8.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/mod v0.13.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/tools v0.14.0 // indirect diff --git a/runtime/sam/expr/agg/math.go b/runtime/sam/expr/agg/math.go index 18a879c4e3..5f779cb895 100644 --- a/runtime/sam/expr/agg/math.go +++ b/runtime/sam/expr/agg/math.go @@ -20,7 +20,6 @@ type mathReducer struct { function *anymath.Function hasval bool math consumer - pair coerce.Pair } var _ Function = (*mathReducer)(nil) @@ -47,10 +46,7 @@ func (m *mathReducer) consumeVal(val zed.Value) { var id int if m.math != nil { var err error - // XXX We're not using the value coercion parts of coerce.Pair here. - // Would be better if coerce had a function that just compared types - // and returned the type to coerce to. - id, err = m.pair.Coerce(zed.NewValue(m.math.typ(), nil), val) + id, err = coerce.Promote(zed.NewValue(m.math.typ(), nil), val) if err != nil { // Skip invalid values. return diff --git a/runtime/sam/expr/coerce/coerce.go b/runtime/sam/expr/coerce/coerce.go index 19c093cdf4..489fe64c1f 100644 --- a/runtime/sam/expr/coerce/coerce.go +++ b/runtime/sam/expr/coerce/coerce.go @@ -2,158 +2,59 @@ package coerce import ( "bytes" - "errors" - "math" "github.com/brimdata/zed" "github.com/brimdata/zed/pkg/byteconv" - "github.com/brimdata/zed/runtime/sam/expr/result" - "github.com/brimdata/zed/zcode" + "github.com/brimdata/zed/zson" + "golang.org/x/exp/constraints" ) -var Overflow = errors.New("integer overflow: uint64 value too large for int64") -var IncompatibleTypes = errors.New("incompatible types") - -// XXX Named types should probably be preserved according to the rank -// of the underlying number type. - -// Pair provides a buffer to decode values into while doing comparisons -// so the same buffers can be reused on each call without zcode.Bytes buffers -// escaping to GC. This method uses the zed.AppendInt(), zed.AppendUint(), -// etc to encode zcode.Bytes as an in-place slice instead of allocating -// new slice buffers for every value created. -type Pair struct { - // a and b point to inputs that can't change - A zcode.Bytes - B zcode.Bytes - // Buffer is a scratch buffer that stays around between calls and is the - // landing place for either the a or b value if one of them needs to - // be coerced (you never need to coerce both). Then we point a or b - // at buf and let go of the other input pointer. - result.Buffer - buf2 result.Buffer -} - -func (c *Pair) Equal() bool { - // bytes.Equal() returns true for nil compared to an empty-slice, - // which doesn't work for Zed null comparisons, so we explicitly check - // for the nil condition here. - if c.A == nil { - return c.B == nil - } - if c.B == nil { - return c.A == nil - } - return bytes.Equal(c.A, c.B) -} - -func (c *Pair) Coerce(a, b zed.Value) (int, error) { - a, b = a.Under(), b.Under() - c.A = a.Bytes() - c.B = b.Bytes() - aid := a.Type().ID() - bid := b.Type().ID() - if aid == bid { - return aid, nil - } - if aid == zed.IDNull { - return bid, nil - } - if bid == zed.IDNull { - return aid, nil - } - if zed.IsNumber(aid) { - if !zed.IsNumber(bid) { - return 0, IncompatibleTypes +func Equal(a, b zed.Value) bool { + if a.IsNull() { + return b.IsNull() + } else if b.IsNull() { + // We know a isn't null. + return false + } + switch aid, bid := a.Type().ID(), b.Type().ID(); { + case !zed.IsNumber(aid) || !zed.IsNumber(bid): + return aid == bid && bytes.Equal(a.Bytes(), b.Bytes()) + case zed.IsFloat(aid): + return a.Float() == ToNumeric[float64](b) + case zed.IsFloat(bid): + return b.Float() == ToNumeric[float64](a) + case zed.IsSigned(aid): + av := a.Int() + if zed.IsUnsigned(bid) { + return uint64(av) == b.Uint() && av >= 0 } - id, ok := c.coerceNumbers(aid, bid) - if !ok { - return 0, Overflow + return av == b.Int() + case zed.IsSigned(bid): + bv := b.Int() + if zed.IsUnsigned(aid) { + return uint64(bv) == a.Uint() && bv >= 0 } - return id, nil - } - return 0, IncompatibleTypes -} - -func intToFloat(id int, b zcode.Bytes) float64 { - if zed.IsSigned(id) { - return float64(zed.DecodeInt(b)) - } - return float64(zed.DecodeUint(b)) -} - -func (c *Pair) promoteToSigned(in zcode.Bytes) (zcode.Bytes, bool) { - v := zed.DecodeUint(in) - if v > math.MaxInt64 { - return nil, false + return bv == a.Int() + default: + return a.Uint() == b.Uint() } - return c.Int(int64(v)), true } -func (c *Pair) promoteToUnsigned(in zcode.Bytes) (zcode.Bytes, bool) { - v := zed.DecodeInt(in) - if v < 0 { - return nil, false - } - return c.Uint(uint64(v)), true -} - -func (c *Pair) coerceNumbers(aid, bid int) (int, bool) { - if zed.IsFloat(aid) { - if aid == zed.IDFloat16 { - c.A = c.buf2.Float64(float64(zed.DecodeFloat16(c.A))) - } else if aid == zed.IDFloat32 { - c.A = c.buf2.Float64(float64(zed.DecodeFloat32(c.A))) - } - c.B = c.Float64(intToFloat(bid, c.B)) - return aid, true - } - if zed.IsFloat(bid) { - if bid == zed.IDFloat16 { - c.B = c.buf2.Float64(float64(zed.DecodeFloat16(c.B))) - } else if bid == zed.IDFloat32 { - c.B = c.buf2.Float64(float64(zed.DecodeFloat32(c.B))) - } - c.A = c.Float64(intToFloat(aid, c.A)) - return bid, true - } - aIsSigned := zed.IsSigned(aid) - if aIsSigned == zed.IsSigned(bid) { - // They have the same signed-ness. Promote to the wider - // type by rank and leave the zcode.Bytes as is since - // the varint encoding is the same for all the widths. - // Width increasese with type ID. - id := aid - if bid > id { - id = bid - } - return id, true - } - id := promoteInt(aid, bid) - - // Otherwise, we'll promote mixed signed-ness to signed unless - // the unsigned value is greater than signed maxint, in which - // case, we report an overflow error. - var ok bool - if aIsSigned { - c.B, ok = c.promoteToSigned(c.B) - } else { - c.A, ok = c.promoteToSigned(c.A) - } - if !ok { - // We got overflow trying to turn the unsigned to signed, - // so try turning the signed into unsigned. - if aIsSigned { - c.A, ok = c.promoteToUnsigned(c.A) - } else { - c.B, ok = c.promoteToUnsigned(c.B) - } - id = zed.IDUint64 +func ToNumeric[T constraints.Integer | constraints.Float](val zed.Value) T { + val = val.Under() + switch id := val.Type().ID(); { + case zed.IsUnsigned(id): + return T(val.Uint()) + case zed.IsSigned(id): + return T(val.Int()) + case zed.IsFloat(id): + return T(val.Float()) } - return id, ok + panic(zson.FormatValue(val)) } func ToFloat(val zed.Value) (float64, bool) { + val = val.Under() switch id := val.Type().ID(); { case zed.IsUnsigned(id): return float64(val.Uint()), true @@ -169,6 +70,7 @@ func ToFloat(val zed.Value) (float64, bool) { } func ToUint(val zed.Value) (uint64, bool) { + val = val.Under() switch id := val.Type().ID(); { case zed.IsUnsigned(id): return val.Uint(), true @@ -188,6 +90,7 @@ func ToUint(val zed.Value) (uint64, bool) { } func ToInt(val zed.Value) (int64, bool) { + val = val.Under() switch id := val.Type().ID(); { case zed.IsUnsigned(id): return int64(val.Uint()), true @@ -204,6 +107,7 @@ func ToInt(val zed.Value) (int64, bool) { } func ToBool(val zed.Value) (bool, bool) { + val = val.Under() if val.IsString() { v, err := byteconv.ParseBool(val.Bytes()) return v, err == nil diff --git a/runtime/sam/expr/coerce/promote.go b/runtime/sam/expr/coerce/promote.go index 6799608556..31efbbda2e 100644 --- a/runtime/sam/expr/coerce/promote.go +++ b/runtime/sam/expr/coerce/promote.go @@ -1,10 +1,83 @@ package coerce import ( + "errors" + "math" + "github.com/brimdata/zed" ) -var promote = []int{ +var ErrIncompatibleTypes = errors.New("incompatible types") +var ErrOverflow = errors.New("integer overflow: uint64 value too large for int64") + +func Promote(a, b zed.Value) (int, error) { + a, b = a.Under(), b.Under() + aid, bid := a.Type().ID(), b.Type().ID() + switch { + case aid == bid: + return aid, nil + case aid == zed.IDNull: + return bid, nil + case bid == zed.IDNull: + return aid, nil + case !zed.IsNumber(aid) || !zed.IsNumber(bid): + return 0, ErrIncompatibleTypes + case zed.IsFloat(aid): + if !zed.IsFloat(bid) { + bid = promoteFloat[bid] + } + case zed.IsFloat(bid): + if !zed.IsFloat(aid) { + aid = promoteFloat[aid] + } + case zed.IsSigned(aid): + if zed.IsUnsigned(bid) { + if b.Uint() > math.MaxInt64 { + return 0, ErrOverflow + } + bid = promoteInt[bid] + } + case zed.IsSigned(bid): + if zed.IsUnsigned(aid) { + if a.Uint() > math.MaxInt64 { + return 0, ErrOverflow + } + aid = promoteInt[aid] + } + } + if aid > bid { + return aid, nil + } + return bid, nil +} + +var promoteFloat = []int{ + zed.IDFloat16, // IDUint8 = 0 + zed.IDFloat16, // IDUint16 = 1 + zed.IDFloat32, // IDUint32 = 2 + zed.IDFloat64, // IDUint64 = 3 + zed.IDFloat128, // IDUint128 = 4 + zed.IDFloat256, // IDUint256 = 5 + zed.IDFloat16, // IDInt8 = 6 + zed.IDFloat16, // IDInt16 = 7 + zed.IDFloat32, // IDInt32 = 8 + zed.IDFloat64, // IDInt64 = 9 + zed.IDFloat128, // IDInt128 = 10 + zed.IDFloat256, // IDInt256 = 11 + zed.IDFloat64, // IDDuration = 12 + zed.IDFloat64, // IDTime = 13 + zed.IDFloat16, // IDFloat16 = 14 + zed.IDFloat32, // IDFloat32 = 15 + zed.IDFloat64, // IDFloat64 = 16 + zed.IDFloat128, // IDFloat64 = 17 + zed.IDFloat256, // IDFloat64 = 18 + zed.IDFloat32, // IDDecimal32 = 19 + zed.IDFloat64, // IDDecimal64 = 20 + zed.IDFloat128, // IDDecimal128 = 21 + zed.IDFloat256, // IDDecimal256 = 22 +} + +var promoteInt = []int{ zed.IDInt8, // IDUint8 = 0 zed.IDInt16, // IDUint16 = 1 zed.IDInt32, // IDUint32 = 2 @@ -29,13 +102,3 @@ var promote = []int{ zed.IDDecimal128, // IDDecimal128 = 21 zed.IDDecimal256, // IDDecimal256 = 22 } - -// promoteInt promotes type to the largest signed type where the IDs must both -// satisfy zed.IsNumber. -func promoteInt(aid, bid int) int { - id := promote[aid] - if bid := promote[bid]; bid > id { - id = bid - } - return id -} diff --git a/runtime/sam/expr/eval.go b/runtime/sam/expr/eval.go index 73a71e6b7c..d9492e589d 100644 --- a/runtime/sam/expr/eval.go +++ b/runtime/sam/expr/eval.go @@ -2,6 +2,7 @@ package expr import ( "bytes" + "cmp" "errors" "fmt" "math" @@ -117,7 +118,6 @@ type In struct { zctx *zed.Context elem Evaluator container Evaluator - vals coerce.Pair } func NewIn(zctx *zed.Context, elem, container Evaluator) *In { @@ -138,11 +138,8 @@ func (i *In) Eval(ectx Context, this zed.Value) zed.Value { return container } err := container.Walk(func(typ zed.Type, body zcode.Bytes) error { - if _, err := i.vals.Coerce(elem, zed.NewValue(typ, body)); err != nil { - if err != coerce.IncompatibleTypes { - return err - } - } else if i.vals.Equal() { + tmpVal := zed.NewValue(typ, body) + if coerce.Equal(elem, tmpVal) { return errMatch } return nil @@ -175,22 +172,11 @@ func NewCompareEquality(zctx *zed.Context, lhs, rhs Evaluator, operator string) } func (e *Equal) Eval(ectx Context, this zed.Value) zed.Value { - _, zerr, err := e.numeric.eval(ectx, this) - if zerr != nil { - return *zerr - } - if err != nil { - if errors.Is(err, coerce.IncompatibleTypes) || errors.Is(err, coerce.Overflow) { - // If the types are incompatible or there was overflow, - // then, then we know the values can't be equal. - if e.equality { - return zed.False - } - return zed.True - } - return e.zctx.NewError(err) + lhsVal, rhsVal, errVal := e.numeric.eval(ectx, this) + if errVal != nil { + return *errVal } - result := e.vals.Equal() + result := coerce.Equal(lhsVal, rhsVal) if !e.equality { result = !result } @@ -221,7 +207,6 @@ type numeric struct { zctx *zed.Context lhs Evaluator rhs Evaluator - vals coerce.Pair } func newNumeric(zctx *zed.Context, lhs, rhs Evaluator) numeric { @@ -232,39 +217,40 @@ func newNumeric(zctx *zed.Context, lhs, rhs Evaluator) numeric { } } -func enumify(ectx Context, val zed.Value) zed.Value { - // automatically convert an enum to its index value when coercing - if _, ok := val.Type().(*zed.TypeEnum); ok { - return zed.NewValue(zed.TypeUint64, val.Bytes()) +func (n *numeric) evalAndPromote(ectx Context, zctx *zed.Context, this zed.Value) (zed.Value, zed.Value, zed.Type, *zed.Value) { + lhsVal, rhsVal, errVal := n.eval(ectx, this) + if errVal != nil { + return zed.Null, zed.Null, nil, errVal } - return val + id, err := coerce.Promote(lhsVal, rhsVal) + if err != nil { + return zed.Null, zed.Null, nil, n.zctx.NewError(err).Ptr() + } + typ, err := zctx.LookupType(id) + if err != nil { + return zed.Null, zed.Null, nil, n.zctx.NewError(err).Ptr() + } + return lhsVal, rhsVal, typ, nil } -func (n *numeric) eval(ectx Context, this zed.Value) (int, *zed.Value, error) { +func (n *numeric) eval(ectx Context, this zed.Value) (zed.Value, zed.Value, *zed.Value) { lhs := n.lhs.Eval(ectx, this) if lhs.IsError() { - return 0, &lhs, nil + return zed.Null, zed.Null, &lhs } - lhs = enumify(ectx, lhs) rhs := n.rhs.Eval(ectx, this) if rhs.IsError() { - return 0, &rhs, nil + return zed.Null, zed.Null, &rhs } - rhs = enumify(ectx, rhs) - id, err := n.vals.Coerce(lhs, rhs) - return id, nil, err -} - -func (n *numeric) floats() (float64, float64) { - return zed.DecodeFloat(n.vals.A), zed.DecodeFloat(n.vals.B) + return enumToIndex(ectx, lhs), enumToIndex(ectx, rhs), nil } -func (n *numeric) ints() (int64, int64) { - return zed.DecodeInt(n.vals.A), zed.DecodeInt(n.vals.B) -} - -func (n *numeric) uints() (uint64, uint64) { - return zed.DecodeUint(n.vals.A), zed.DecodeUint(n.vals.B) +// enumToIndex converts an enum to its index value. +func enumToIndex(ectx Context, val zed.Value) zed.Value { + if _, ok := val.Type().(*zed.TypeEnum); ok { + return zed.NewValue(zed.TypeUint64, val.Bytes()) + } + return val } type Compare struct { @@ -291,7 +277,10 @@ func NewCompareRelative(zctx *zed.Context, lhs, rhs Evaluator, operator string) } func (c *Compare) result(result int) zed.Value { - return zed.NewBool(c.convert(result)) + if c.convert(result) { + return zed.True + } + return zed.False } func (c *Compare) Eval(ectx Context, this zed.Value) zed.Value { @@ -303,65 +292,72 @@ func (c *Compare) Eval(ectx Context, this zed.Value) zed.Value { if rhs.IsError() { return rhs } - id, err := c.vals.Coerce(lhs, rhs) - if err != nil { - // If coercion fails due to overflow, then we know there is a - // mixed signed and unsigned situation and either the unsigned - // value couldn't be converted to an int64 because it was too big, - // or the signed value couldn't be converted to a uint64 because - // it was negative. In either case, the unsigned value is bigger - // than the signed value. - if err == coerce.Overflow { - result := 1 - if zed.IsSigned(lhs.Type().ID()) { - result = -1 - } - return c.result(result) + + if lhs.IsNull() { + if rhs.IsNull() { + return c.result(0) } return zed.False + } else if rhs.IsNull() { + // We know lhs isn't null. + return zed.False } - var result int - if !c.vals.Equal() { - switch { - case c.vals.A == nil || c.vals.B == nil: - return zed.False - case zed.IsFloat(id): - v1, v2 := c.floats() - if v1 < v2 { - result = -1 - } else { - result = 1 - } - case zed.IsSigned(id): - v1, v2 := c.ints() - if v1 < v2 { - result = -1 - } else { - result = 1 + + switch lid, rid := lhs.Type().ID(), rhs.Type().ID(); { + case zed.IsNumber(lid) && zed.IsNumber(rid): + return c.result(compareNumbers(lhs, rhs, lid, rid)) + case lid != rid: + return zed.False + case lid == zed.IDBool: + if lhs.Bool() { + if rhs.Bool() { + return c.result(0) } - case zed.IsNumber(id): - v1, v2 := c.uints() - if v1 < v2 { - result = -1 - } else { - result = 1 + + } + case lid == zed.IDBytes: + return c.result(bytes.Compare(zed.DecodeBytes(lhs.Bytes()), zed.DecodeBytes(rhs.Bytes()))) + case lid == zed.IDString: + return c.result(cmp.Compare(zed.DecodeString(lhs.Bytes()), zed.DecodeString(lhs.Bytes()))) + default: + if bytes.Equal(lhs.Bytes(), rhs.Bytes()) { + return c.result(0) + } + } + return zed.False +} + +func compareNumbers(a, b zed.Value, aid, bid int) int { + switch { + case zed.IsFloat(aid): + return cmp.Compare(a.Float(), toFloat(b)) + case zed.IsFloat(bid): + return cmp.Compare(toFloat(a), b.Float()) + case zed.IsSigned(aid): + av := a.Int() + if zed.IsUnsigned(bid) { + if av < 0 { + return -1 } - case id == zed.IDString: - if zed.DecodeString(c.vals.A) < zed.DecodeString(c.vals.B) { - result = -1 - } else { - result = 1 + return cmp.Compare(uint64(av), b.Uint()) + } + return cmp.Compare(av, b.Int()) + case zed.IsSigned(bid): + bv := b.Int() + if zed.IsUnsigned(aid) { + if bv < 0 { + return 1 } - default: - return c.zctx.NewErrorf("bad comparison type ID: %d", id) + return cmp.Compare(a.Uint(), uint64(bv)) } + return cmp.Compare(a.Int(), bv) } - if c.convert(result) { - return zed.True - } - return zed.False + return cmp.Compare(a.Uint(), b.Uint()) } +func toFloat(val zed.Value) float64 { return coerce.ToNumeric[float64](val) } +func toInt(val zed.Value) int64 { return coerce.ToNumeric[int64](val) } + type Add struct { zctx *zed.Context operands numeric @@ -409,153 +405,108 @@ func NewArithmetic(zctx *zed.Context, lhs, rhs Evaluator, op string) (Evaluator, } func (a *Add) Eval(ectx Context, this zed.Value) zed.Value { - id, zerr, err := a.operands.eval(ectx, this) - if err != nil { - return a.zctx.NewError(err) + lhsVal, rhsVal, typ, errVal := a.operands.evalAndPromote(ectx, a.zctx, this) + if errVal != nil { + return *errVal } - if zerr != nil { - return *zerr - } - typ, err := a.zctx.LookupType(id) - if err != nil { - return a.zctx.NewError(err) - } - switch { - case zed.IsFloat(id): - v1, v2 := a.operands.floats() - return zed.NewFloat(typ, v1+v2) + switch id := typ.ID(); { + case zed.IsUnsigned(id): + return zed.NewUint(typ, lhsVal.Uint()+rhsVal.Uint()) case zed.IsSigned(id): - v1, v2 := a.operands.ints() - return zed.NewInt(typ, v1+v2) - case zed.IsNumber(id): - v1, v2 := a.operands.uints() - return zed.NewUint(typ, v1+v2) + return zed.NewInt(typ, toInt(lhsVal)+toInt(rhsVal)) + case zed.IsFloat(id): + return zed.NewFloat(typ, toFloat(lhsVal)+toFloat(rhsVal)) case id == zed.IDString: - v1, v2 := zed.DecodeString(a.operands.vals.A), zed.DecodeString(a.operands.vals.B) - // XXX GC + v1, v2 := zed.DecodeString(lhsVal.Bytes()), zed.DecodeString(rhsVal.Bytes()) return zed.NewValue(typ, zed.EncodeString(v1+v2)) } return a.zctx.NewErrorf("type %s incompatible with '+' operator", zson.FormatType(typ)) } func (s *Subtract) Eval(ectx Context, this zed.Value) zed.Value { - id, zerr, err := s.operands.eval(ectx, this) - if err != nil { - return s.zctx.NewError(err) - } - if zerr != nil { - return *zerr - } - typ, err := s.zctx.LookupType(id) - if err != nil { - return s.zctx.NewError(err) + lhsVal, rhsVal, typ, errVal := s.operands.evalAndPromote(ectx, s.zctx, this) + if errVal != nil { + return *errVal } - switch { - case zed.IsFloat(id): - v1, v2 := s.operands.floats() - return zed.NewFloat(typ, v1-v2) + switch id := typ.ID(); { + case zed.IsUnsigned(id): + return zed.NewUint(typ, lhsVal.Uint()-rhsVal.Uint()) case zed.IsSigned(id): - v1, v2 := s.operands.ints() if id == zed.IDTime { // Return the difference of two times as a duration. typ = zed.TypeDuration } - return zed.NewInt(typ, v1-v2) - case zed.IsNumber(id): - v1, v2 := s.operands.uints() - return zed.NewUint(typ, v1-v2) + return zed.NewInt(typ, toInt(lhsVal)-toInt(rhsVal)) + case zed.IsFloat(id): + return zed.NewFloat(typ, toFloat(lhsVal)-toFloat(rhsVal)) } return s.zctx.NewErrorf("type %s incompatible with '-' operator", zson.FormatType(typ)) } func (m *Multiply) Eval(ectx Context, this zed.Value) zed.Value { - id, zerr, err := m.operands.eval(ectx, this) - if err != nil { - return m.zctx.NewError(err) - } - if zerr != nil { - return *zerr + lhsVal, rhsVal, typ, errVal := m.operands.evalAndPromote(ectx, m.zctx, this) + if errVal != nil { + return *errVal } - typ, err := m.zctx.LookupType(id) - if err != nil { - return m.zctx.NewError(err) - } - switch { - case zed.IsFloat(id): - v1, v2 := m.operands.floats() - return zed.NewFloat(typ, v1*v2) + switch id := typ.ID(); { + case zed.IsUnsigned(id): + return zed.NewUint(typ, lhsVal.Uint()*rhsVal.Uint()) case zed.IsSigned(id): - v1, v2 := m.operands.ints() - return zed.NewInt(typ, v1*v2) - case zed.IsNumber(id): - v1, v2 := m.operands.uints() - return zed.NewUint(typ, v1*v2) + return zed.NewInt(typ, toInt(lhsVal)*toInt(rhsVal)) + case zed.IsFloat(id): + return zed.NewFloat(typ, toFloat(lhsVal)*toFloat(rhsVal)) } return m.zctx.NewErrorf("type %s incompatible with '*' operator", zson.FormatType(typ)) } func (d *Divide) Eval(ectx Context, this zed.Value) zed.Value { - id, zerr, err := d.operands.eval(ectx, this) - if err != nil { - return d.zctx.NewError(err) - } - if zerr != nil { - return *zerr - } - typ, err := d.zctx.LookupType(id) - if err != nil { - return d.zctx.NewError(err) - } - switch { - case zed.IsFloat(id): - v1, v2 := d.operands.floats() - if v2 == 0 { + lhsVal, rhsVal, typ, errVal := d.operands.evalAndPromote(ectx, d.zctx, this) + if errVal != nil { + return *errVal + } + switch id := typ.ID(); { + case zed.IsUnsigned(id): + v := rhsVal.Uint() + if v == 0 { return d.zctx.NewError(DivideByZero) } - return zed.NewFloat(typ, v1/v2) + return zed.NewUint(typ, lhsVal.Uint()/v) case zed.IsSigned(id): - v1, v2 := d.operands.ints() - if v2 == 0 { + v := toInt(rhsVal) + if v == 0 { return d.zctx.NewError(DivideByZero) } - return zed.NewInt(typ, v1/v2) - case zed.IsNumber(id): - v1, v2 := d.operands.uints() - if v2 == 0 { + return zed.NewInt(typ, toInt(lhsVal)/v) + case zed.IsFloat(id): + v := toFloat(rhsVal) + if v == 0 { return d.zctx.NewError(DivideByZero) } - return zed.NewUint(typ, v1/v2) + return zed.NewFloat(typ, toFloat(lhsVal)/v) } return d.zctx.NewErrorf("type %s incompatible with '/' operator", zson.FormatType(typ)) } func (m *Modulo) Eval(ectx Context, this zed.Value) zed.Value { - id, zerr, err := m.operands.eval(ectx, this) - if err != nil { - return m.zctx.NewError(err) - } - if zerr != nil { - return *zerr - } - typ, err := m.zctx.LookupType(id) - if err != nil { - return m.zctx.NewError(err) - } - if zed.IsFloat(id) || !zed.IsNumber(id) { - return m.zctx.NewErrorf("type %s incompatible with '%%' operator", zson.FormatType(typ)) - } - if zed.IsSigned(id) { - x, y := m.operands.ints() - if y == 0 { + lhsVal, rhsVal, typ, errVal := m.operands.evalAndPromote(ectx, m.zctx, this) + if errVal != nil { + return *errVal + } + switch id := typ.ID(); { + case zed.IsUnsigned(id): + v := rhsVal.Uint() + if v == 0 { return m.zctx.NewError(DivideByZero) } - return zed.NewInt(typ, x%y) - } - x, y := m.operands.uints() - if y == 0 { - return m.zctx.NewError(DivideByZero) + return zed.NewUint(typ, lhsVal.Uint()%v) + case zed.IsSigned(id): + v := toInt(rhsVal) + if v == 0 { + return m.zctx.NewError(DivideByZero) + } + return zed.NewInt(typ, toInt(lhsVal)%v) } - return zed.NewUint(typ, x%y) + return m.zctx.NewErrorf("type %s incompatible with '%%' operator", zson.FormatType(typ)) } type UnaryMinus struct { @@ -799,6 +750,34 @@ func (c *Call) Eval(ectx Context, this zed.Value) zed.Value { return c.fn.Call(ectx, c.args) } +func NewCast(zctx *zed.Context, expr Evaluator, typ zed.Type) (Evaluator, error) { + // XXX should handle named type casts. need type context. + // compile is going to need a local type context to create literals + // of complex types? + c := LookupPrimitiveCaster(zctx, typ) + if c == nil { + // XXX See issue #1572. To implement named cast here. + return nil, fmt.Errorf("cast to %q not implemented", zson.FormatType(typ)) + } + return &evalCast{expr, c, typ}, nil +} + +type evalCast struct { + expr Evaluator + caster Evaluator + typ zed.Type +} + +func (c *evalCast) Eval(ectx Context, this zed.Value) zed.Value { + val := c.expr.Eval(ectx, this) + if val.IsNull() || val.Type() == c.typ { + // If value is null or the type won't change, just return a + // copy of the value. + return zed.NewValue(c.typ, val.Bytes()) + } + return c.caster.Eval(ectx, val) +} + type Assignment struct { LHS *Lval RHS Evaluator diff --git a/runtime/sam/expr/sort.go b/runtime/sam/expr/sort.go index 0135c82885..8df082be32 100644 --- a/runtime/sam/expr/sort.go +++ b/runtime/sam/expr/sort.go @@ -10,7 +10,6 @@ import ( "github.com/brimdata/zed" "github.com/brimdata/zed/order" - "github.com/brimdata/zed/runtime/sam/expr/coerce" "github.com/brimdata/zed/zcode" "github.com/brimdata/zed/zio" ) @@ -72,7 +71,7 @@ func (c *Comparator) sortStableIndices(vals []zed.Value) []uint32 { ival = expr.Eval(ectx, vals[iidx]) jval = expr.Eval(ectx, vals[jidx]) } - if v := compareValues(ival, jval, c.comparefns, &c.pair, c.nullsMax); v != 0 { + if v := compareValues(ival, jval, c.nullsMax); v != 0 { return v < 0 } } @@ -98,13 +97,10 @@ func NewValueCompareFn(o order.Which, nullsMax bool) CompareFn { } type Comparator struct { + ectx Context exprs []Evaluator nullsMax bool reverse bool - - comparefns map[zed.Type]comparefn - ectx Context - pair coerce.Pair } type comparefn func(a, b zcode.Bytes) int @@ -116,11 +112,10 @@ type comparefn func(a, b zcode.Bytes) int // reverse reverses the sense of comparisons. func NewComparator(nullsMax, reverse bool, exprs ...Evaluator) *Comparator { return &Comparator{ - exprs: slices.Clone(exprs), - nullsMax: nullsMax, - reverse: reverse, - comparefns: make(map[zed.Type]comparefn), - ectx: NewContext(), + ectx: NewContext(), + exprs: slices.Clone(exprs), + nullsMax: nullsMax, + reverse: reverse, } } @@ -152,14 +147,14 @@ func (c *Comparator) Compare(a, b zed.Value) int { for _, k := range c.exprs { aval := k.Eval(c.ectx, a) bval := k.Eval(c.ectx, b) - if v := compareValues(aval, bval, c.comparefns, &c.pair, c.nullsMax); v != 0 { + if v := compareValues(aval, bval, c.nullsMax); v != 0 { return v } } return 0 } -func compareValues(a, b zed.Value, comparefns map[zed.Type]comparefn, pair *coerce.Pair, nullsMax bool) int { +func compareValues(a, b zed.Value, nullsMax bool) int { // Handle nulls according to nullsMax nullA := a.IsNull() nullB := b.IsNull() @@ -180,27 +175,54 @@ func compareValues(a, b zed.Value, comparefns map[zed.Type]comparefn, pair *coer return 1 } } - - typ := a.Type() - abytes, bbytes := a.Bytes(), b.Bytes() - if a.Type().ID() != b.Type().ID() { - id, err := pair.Coerce(a, b) - if err == nil { - typ, err = zed.LookupPrimitiveByID(id) - } - if err != nil { - return zed.CompareTypes(a.Type(), b.Type()) + switch aid, bid := a.Type().ID(), b.Type().ID(); { + case zed.IsNumber(aid) && zed.IsNumber(bid): + return compareNumbers(a, b, aid, bid) + case aid != bid: + return zed.CompareTypes(a.Type(), b.Type()) + case aid == zed.IDBool: + if av, bv := a.Bool(), b.Bool(); av == bv { + return 0 + } else if av { + return 1 } - abytes, bbytes = pair.A, pair.B + return -1 + case aid == zed.IDBytes: + return bytes.Compare(zed.DecodeBytes(a.Bytes()), zed.DecodeBytes(b.Bytes())) + case aid == zed.IDString: + return cmp.Compare(zed.DecodeString(a.Bytes()), zed.DecodeString(b.Bytes())) + case aid == zed.IDIP: + return zed.DecodeIP(a.Bytes()).Compare(zed.DecodeIP(b.Bytes())) + case aid == zed.IDType: + zctx := zed.NewContext() // XXX This is expensive. + // XXX This isn't cheap eventually we should add + // zed.CompareTypeValues(a, b zcode.Bytes). + av, _ := zctx.DecodeTypeValue(a.Bytes()) + bv, _ := zctx.DecodeTypeValue(b.Bytes()) + return zed.CompareTypes(av, bv) } - - cfn, ok := comparefns[typ] - if !ok { - cfn = LookupCompare(typ) - comparefns[typ] = cfn + // XXX record support easy to add here if we moved the creation of the + // field resolvers into this package. + if innerType := zed.InnerType(a.Type()); innerType != nil { + ait, bit := a.Iter(), b.Iter() + for { + if ait.Done() { + if bit.Done() { + return 0 + } + return -1 + } + if bit.Done() { + return 1 + } + aa := zed.NewValue(innerType, ait.Next()) + bb := zed.NewValue(innerType, bit.Next()) + if v := compareValues(aa, bb, nullsMax); v != 0 { + return v + } + } } - - return cfn(abytes, bbytes) + return bytes.Compare(a.Bytes(), b.Bytes()) } // SortStable sorts vals according to c, with equal values in their original diff --git a/runtime/vam/expr/coerce.go b/runtime/vam/expr/coerce.go index 68e00b4a36..82e7ab8c6d 100644 --- a/runtime/vam/expr/coerce.go +++ b/runtime/vam/expr/coerce.go @@ -31,7 +31,7 @@ func coerceVals(zctx *zed.Context, a, b vector.Any) (vector.Any, vector.Any, vec return a, b, nil //XXX } if !zed.IsNumber(aid) || !zed.IsNumber(bid) { - return nil, nil, vector.NewStringError(zctx, coerce.IncompatibleTypes.Error(), a.Len()) + return nil, nil, vector.NewStringError(zctx, coerce.ErrIncompatibleTypes.Error(), a.Len()) } // Both a and b are numbers. We need to promote to a common // type based on Zed's coercion rules. diff --git a/runtime/vam/expr/compare.go b/runtime/vam/expr/compare.go index 703f216eaa..1be8fd4d22 100644 --- a/runtime/vam/expr/compare.go +++ b/runtime/vam/expr/compare.go @@ -46,7 +46,7 @@ func (c *Compare) Eval(val vector.Any) vector.Any { return compareUints(op, lhs, rhs) default: //XXX incompatible types - return vector.NewStringError(c.zctx, coerce.IncompatibleTypes.Error(), lhs.Len()) + return vector.NewStringError(c.zctx, coerce.ErrIncompatibleTypes.Error(), lhs.Len()) } }