From 201842d027fef2220bf7291ae3ab425eb599f599 Mon Sep 17 00:00:00 2001 From: xrstf Date: Tue, 28 Nov 2023 23:24:24 +0100 Subject: [PATCH] use generic algorithm to base equality checks solely on the coalescer --- aliases.go | 12 - pkg/coalescing/coalescer.go | 4 +- pkg/coalescing/humane.go | 132 +++++++-- pkg/coalescing/humane_test.go | 49 ++++ pkg/coalescing/strict.go | 38 +-- pkg/coalescing/util_test.go | 166 +++++++++++ pkg/deepcopy/deepcopy_test.go | 4 +- pkg/equality/coalesced.go | 411 +++++++++++++++++++++++++++ pkg/equality/coalesced_test.go | 373 ++++++++++++++++++++++++ pkg/equality/loose.go | 282 ------------------ pkg/equality/strict.go | 154 ---------- pkg/eval/builtin/comparisons.go | 45 +-- pkg/eval/builtin/comparisons_test.go | 15 +- pkg/eval/builtin/lists.go | 13 +- pkg/eval/builtin/types_test.go | 20 ++ pkg/testutil/php/php.go | 62 ++++ pkg/testutil/testcase.go | 2 +- program.go | 5 +- 18 files changed, 1231 insertions(+), 556 deletions(-) create mode 100644 pkg/coalescing/humane_test.go create mode 100644 pkg/coalescing/util_test.go create mode 100644 pkg/equality/coalesced.go create mode 100644 pkg/equality/coalesced_test.go delete mode 100644 pkg/equality/loose.go delete mode 100644 pkg/equality/strict.go create mode 100644 pkg/testutil/php/php.go diff --git a/aliases.go b/aliases.go index 7d94527..7bc8fa3 100644 --- a/aliases.go +++ b/aliases.go @@ -59,15 +59,3 @@ func NewVariables() Variables { func NewDocument(data any) (Document, error) { return types.NewDocument(data) } - -// Unwrap returns the native Go value for either native Go values or an -// Rudi AST node (like turning an ast.Number into an int64). -func Unwrap(val any) (any, error) { - return types.UnwrapType(val) -} - -// WrapNative returns the Rudi node equivalent of a native Go value, like turning -// a string into ast.String. -func WrapNative(val any) (any, error) { - return types.WrapNative(val) -} diff --git a/pkg/coalescing/coalescer.go b/pkg/coalescing/coalescer.go index 2ec00c1..a7dba81 100644 --- a/pkg/coalescing/coalescer.go +++ b/pkg/coalescing/coalescer.go @@ -10,14 +10,14 @@ import ( ) type Coalescer interface { + ToNull(val any) (bool, error) ToBool(val any) (bool, error) - ToFloat64(val any) (float64, error) ToInt64(val any) (int64, error) + ToFloat64(val any) (float64, error) ToNumber(val any) (ast.Number, error) ToString(val any) (string, error) ToVector(val any) ([]any, error) ToObject(val any) (map[string]any, error) - ToNull(val any) (bool, error) } func deliteral(val any) any { diff --git a/pkg/coalescing/humane.go b/pkg/coalescing/humane.go index 387da01..549b7f0 100644 --- a/pkg/coalescing/humane.go +++ b/pkg/coalescing/humane.go @@ -4,6 +4,7 @@ package coalescing import ( + "errors" "fmt" "strconv" "strings" @@ -24,7 +25,50 @@ func (humane) ToNull(val any) (bool, error) { case nil: return true, nil case bool: - return !v, nil + if v { + return false, fmt.Errorf("cannot coalesce true into null") + } + return true, nil + case int: + if v != 0 { + return false, fmt.Errorf("cannot coalesce %v (%T) into null", v, v) + } + return true, nil + case int32: + if v != 0 { + return false, fmt.Errorf("cannot coalesce %v (%T) into null", v, v) + } + return true, nil + case int64: + if v != 0 { + return false, fmt.Errorf("cannot coalesce %v (%T) into null", v, v) + } + return true, nil + case float32: + if v != 0 { + return false, fmt.Errorf("cannot coalesce %v (%T) into null", v, v) + } + return true, nil + case float64: + if v != 0 { + return false, fmt.Errorf("cannot coalesce %v (%T) into null", v, v) + } + return true, nil + case string: + if len(v) != 0 { + return false, fmt.Errorf("cannot coalesce %q (%T) into null", v, v) + } + return true, nil + case []any: + if len(v) != 0 { + return false, errors.New("cannot coalesce non-empty vector into null") + } + return true, nil + case map[string]any: + if len(v) != 0 { + return false, errors.New("cannot coalesce non-empty object into null") + } + return true, nil default: return false, fmt.Errorf("cannot coalesce %T into null", v) } @@ -32,6 +76,8 @@ func (humane) ToNull(val any) (bool, error) { func (humane) ToBool(val any) (bool, error) { switch v := deliteral(val).(type) { + case nil: + return false, nil case bool: return v, nil case int: @@ -54,8 +100,6 @@ func (humane) ToBool(val any) (bool, error) { return len(v) > 0, nil case map[string]any: return len(v) > 0, nil - case nil: - return false, nil default: return false, fmt.Errorf("cannot coalesce %T into bool", v) } @@ -63,6 +107,14 @@ func (humane) ToBool(val any) (bool, error) { func (humane) ToFloat64(val any) (float64, error) { switch v := deliteral(val).(type) { + case nil: + return 0, nil + case bool: + if v { + return 1, nil + } else { + return 0, nil + } case int: return float64(v), nil case int32: @@ -74,19 +126,16 @@ func (humane) ToFloat64(val any) (float64, error) { case float64: return v, nil case string: + v = strings.TrimSpace(v) + if v == "" { + return 0, nil + } + parsed, err := strconv.ParseFloat(v, 64) if err != nil { - return 0, fmt.Errorf("cannot convert %q losslessly to float64", v) + return 0, fmt.Errorf("cannot coalesce %T into float64", v) } return parsed, nil - case bool: - if v { - return 1, nil - } else { - return 0, nil - } - case nil: - return 0, nil default: return 0, fmt.Errorf("cannot coalesce %T into float64", v) } @@ -94,26 +143,47 @@ func (humane) ToFloat64(val any) (float64, error) { func (humane) ToInt64(val any) (int64, error) { switch v := deliteral(val).(type) { + case nil: + return 0, nil + case bool: + if v { + return 1, nil + } else { + return 0, nil + } case int: return int64(v), nil case int32: return int64(v), nil case int64: return v, nil + case float32: + if v == float32(int32(v)) { + return int64(v), nil + } + return 0, fmt.Errorf("cannot convert %s losslessly to int64", formatFloat(float64(v))) + case float64: + if v == float64(int64(v)) { + return int64(v), nil + } + return 0, fmt.Errorf("cannot convert %s losslessly to int64", formatFloat(v)) case string: + v = strings.TrimSpace(v) + if v == "" { + return 0, nil + } + parsed, err := strconv.ParseInt(v, 10, 64) if err != nil { - return 0, fmt.Errorf("cannot convert %q losslessly to int64", v) + // allows "2.0" to turn into int64(2) + parsed, err := strconv.ParseFloat(v, 64) + if err == nil && parsed == float64(int64(parsed)) { + return int64(parsed), nil + } + + return 0, fmt.Errorf("cannot coalesce %T into int64", v) } return parsed, nil - case bool: - if v { - return 1, nil - } else { - return 0, nil - } - case nil: - return 0, nil default: return 0, fmt.Errorf("cannot coalesce %T into int64", v) } @@ -125,8 +195,8 @@ func (h humane) ToNumber(val any) (ast.Number, error) { func (humane) ToString(val any) (string, error) { switch v := deliteral(val).(type) { - case string: - return v, nil + case nil: + return "", nil case bool: return strconv.FormatBool(v), nil case int: @@ -137,8 +207,8 @@ func (humane) ToString(val any) (string, error) { return strconv.FormatInt(v, 10), nil case float64: return formatFloat(v), nil - case nil: - return "", nil + case string: + return v, nil default: return "", fmt.Errorf("cannot coalesce %T into string", v) } @@ -159,6 +229,12 @@ func (humane) ToVector(val any) ([]any, error) { return []any{}, nil case []any: return v, nil + case map[string]any: + if len(v) == 0 { + return []any{}, nil + } else { + return nil, fmt.Errorf("cannot coalesce %T into vector", v) + } default: return nil, fmt.Errorf("cannot coalesce %T into vector", v) } @@ -168,6 +244,12 @@ func (humane) ToObject(val any) (map[string]any, error) { switch v := deliteral(val).(type) { case nil: return map[string]any{}, nil + case []any: + if len(v) == 0 { + return map[string]any{}, nil + } else { + return nil, fmt.Errorf("cannot coalesce %T into object", v) + } case map[string]any: return v, nil default: diff --git a/pkg/coalescing/humane_test.go b/pkg/coalescing/humane_test.go new file mode 100644 index 0000000..5b4a28b --- /dev/null +++ b/pkg/coalescing/humane_test.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2023 Christoph Mewes +// SPDX-License-Identifier: MIT + +package coalescing + +import ( + "testing" +) + +func TestHumaneCoalescer(t *testing.T) { + testCoalescer(t, NewHumane(), getHumaneTestcases()) +} + +func getHumaneTestcases() []testcase { + return []testcase{ + // (source, canBeNull, toBool, toInt, toFloat, toNumber, toString, toVector, toObject) + // nil source value + newTestcase(nil, true, false, int64(0), 0.0, newNum(int64(0)), "", []any{}, map[string]any{}), + // boolean source values + newTestcase(true, invalid, true, int64(1), 1.0, newNum(int64(1)), "true", invalid, invalid), + newTestcase(false, true, false, int64(0), 0.0, newNum(int64(0)), "false", invalid, invalid), + // numeric source values + newTestcase(0, true, false, int64(0), 0.0, newNum(int64(0)), "0", invalid, invalid), + newTestcase(0.0, true, false, int64(0), 0.0, newNum(int64(0)), "0", invalid, invalid), + newTestcase(0.1, invalid, true, invalid, 0.1, newNum(0.1), "0.1", invalid, invalid), + newTestcase(1, invalid, true, int64(1), 1.0, newNum(int64(1)), "1", invalid, invalid), + newTestcase(1.0, invalid, true, int64(1), 1.0, newNum(int64(1)), "1", invalid, invalid), + newTestcase(-3.14, invalid, true, invalid, -3.14, newNum(-3.14), "-3.14", invalid, invalid), + // string source values + newTestcase("", true, false, int64(0), 0.0, newNum(int64(0)), "", invalid, invalid), + newTestcase(" ", invalid, true, int64(0), 0.0, newNum(int64(0)), " ", invalid, invalid), + newTestcase("\n", invalid, true, int64(0), 0.0, newNum(int64(0)), "\n", invalid, invalid), + newTestcase("0", invalid, false, int64(0), 0.0, newNum(int64(0)), "0", invalid, invalid), + newTestcase("000", invalid, true, int64(0), 0.0, newNum(int64(0)), "000", invalid, invalid), + newTestcase(" 0 ", invalid, true, int64(0), 0.0, newNum(int64(0)), " 0 ", invalid, invalid), + newTestcase(" 000 ", invalid, true, int64(0), 0.0, newNum(int64(0)), " 000 ", invalid, invalid), + newTestcase("1", invalid, true, int64(1), 1.0, newNum(int64(1)), "1", invalid, invalid), + newTestcase("001", invalid, true, int64(1), 1.0, newNum(int64(1)), "001", invalid, invalid), + newTestcase("1.0", invalid, true, int64(1), 1.0, newNum(int64(1)), "1.0", invalid, invalid), + newTestcase(" 1.1 ", invalid, true, invalid, 1.1, newNum(1.1), " 1.1 ", invalid, invalid), + // vector source values + newTestcase([]any{}, true, false, invalid, invalid, invalid, invalid, []any{}, map[string]any{}), + newTestcase([]any{""}, invalid, true, invalid, invalid, invalid, invalid, []any{""}, invalid), + newTestcase([]any{"foo"}, invalid, true, invalid, invalid, invalid, invalid, []any{"foo"}, invalid), + // object source values + newTestcase(map[string]any{}, true, false, invalid, invalid, invalid, invalid, []any{}, map[string]any{}), + newTestcase(map[string]any{"": ""}, invalid, true, invalid, invalid, invalid, invalid, invalid, map[string]any{"": ""}), + } +} diff --git a/pkg/coalescing/strict.go b/pkg/coalescing/strict.go index 29796f3..d043933 100644 --- a/pkg/coalescing/strict.go +++ b/pkg/coalescing/strict.go @@ -28,10 +28,10 @@ func (strict) ToNull(val any) (bool, error) { func (strict) ToBool(val any) (bool, error) { switch v := deliteral(val).(type) { - case bool: - return v, nil case nil: return false, nil + case bool: + return v, nil default: return false, fmt.Errorf("cannot coalesce %T into bool", v) } @@ -39,6 +39,8 @@ func (strict) ToBool(val any) (bool, error) { func (strict) ToFloat64(val any) (float64, error) { switch v := deliteral(val).(type) { + case nil: + return 0, nil case int: return float64(v), nil case int32: @@ -49,8 +51,6 @@ func (strict) ToFloat64(val any) (float64, error) { return float64(v), nil case float64: return v, nil - case nil: - return 0, nil default: return 0, fmt.Errorf("cannot coalesce %T into float64", v) } @@ -58,20 +58,24 @@ func (strict) ToFloat64(val any) (float64, error) { func (strict) ToInt64(val any) (int64, error) { switch v := deliteral(val).(type) { + case nil: + return 0, nil case int: return int64(v), nil case int32: return int64(v), nil case int64: return v, nil - case ast.Number: - intVal, ok := v.ToInteger() - if !ok { - return 0, fmt.Errorf("cannot convert %f losslessly to int64", val) + case float32: + if v == float32(int32(v)) { + return int64(v), nil } - return intVal, nil - case nil: - return 0, nil + return 0, fmt.Errorf("cannot convert %s losslessly to int64", formatFloat(float64(v))) + case float64: + if v == float64(int64(v)) { + return int64(v), nil + } + return 0, fmt.Errorf("cannot convert %s losslessly to int64", formatFloat(v)) default: return 0, fmt.Errorf("cannot coalesce %T into int64", v) } @@ -83,10 +87,10 @@ func (s strict) ToNumber(val any) (ast.Number, error) { func (strict) ToString(val any) (string, error) { switch v := deliteral(val).(type) { - case string: - return v, nil case nil: return "", nil + case string: + return v, nil default: return "", fmt.Errorf("cannot coalesce %T into string", v) } @@ -94,10 +98,10 @@ func (strict) ToString(val any) (string, error) { func (strict) ToVector(val any) ([]any, error) { switch v := deliteral(val).(type) { - case []any: - return v, nil case nil: return []any{}, nil + case []any: + return v, nil default: return nil, fmt.Errorf("cannot coalesce %T into vector", v) } @@ -105,10 +109,10 @@ func (strict) ToVector(val any) ([]any, error) { func (strict) ToObject(val any) (map[string]any, error) { switch v := deliteral(val).(type) { - case map[string]any: - return v, nil case nil: return map[string]any{}, nil + case map[string]any: + return v, nil default: return nil, fmt.Errorf("cannot coalesce %T into object", v) } diff --git a/pkg/coalescing/util_test.go b/pkg/coalescing/util_test.go new file mode 100644 index 0000000..a5136a0 --- /dev/null +++ b/pkg/coalescing/util_test.go @@ -0,0 +1,166 @@ +// SPDX-FileCopyrightText: 2023 Christoph Mewes +// SPDX-License-Identifier: MIT + +package coalescing + +import ( + "fmt" + "strings" + "testing" + + "go.xrstf.de/rudi/pkg/lang/ast" + + "github.com/google/go-cmp/cmp" +) + +type invalidConversion int + +const invalid invalidConversion = iota + +type testcase struct { + value any + toNull any + toBool any + toInt any + toFloat any + toNumber any + toString any + toVector any + toObject any +} + +func newTestcase(value, toNull, toBool, toInt, toFloat, toNumber, toString, toVector, toObject any) testcase { + return testcase{ + value: value, + toNull: toNull, + toBool: toBool, + toInt: toInt, + toFloat: toFloat, + toNumber: toNumber, + toString: toString, + toVector: toVector, + toObject: toObject, + } +} + +func newNum(value any) ast.Number { + return ast.Number{Value: value} +} + +func testCoalescer(t *testing.T, coalescer Coalescer, testcases []testcase) { + t.Helper() + + testConversion( + t, + coalescer, + testcases, + "null", + func(val any) (any, error) { return coalescer.ToNull(val) }, + func(tc testcase) any { return tc.toNull }, + ) + + testConversion( + t, + coalescer, + testcases, + "bool", + func(val any) (any, error) { return coalescer.ToBool(val) }, + func(tc testcase) any { return tc.toBool }, + ) + + testConversion( + t, + coalescer, + testcases, + "int", + func(val any) (any, error) { return coalescer.ToInt64(val) }, + func(tc testcase) any { return tc.toInt }, + ) + + testConversion( + t, + coalescer, + testcases, + "float", + func(val any) (any, error) { return coalescer.ToFloat64(val) }, + func(tc testcase) any { return tc.toFloat }, + ) + + testConversion( + t, + coalescer, + testcases, + "number", + func(val any) (any, error) { return coalescer.ToNumber(val) }, + func(tc testcase) any { return tc.toNumber }, + ) + + testConversion( + t, + coalescer, + testcases, + "string", + func(val any) (any, error) { return coalescer.ToString(val) }, + func(tc testcase) any { return tc.toString }, + ) + + testConversion( + t, + coalescer, + testcases, + "vector", + func(val any) (any, error) { return coalescer.ToVector(val) }, + func(tc testcase) any { return tc.toVector }, + ) + + testConversion( + t, + coalescer, + testcases, + "object", + func(val any) (any, error) { return coalescer.ToObject(val) }, + func(tc testcase) any { return tc.toObject }, + ) +} + +func testConversion( + t *testing.T, + coalescer Coalescer, + testcases []testcase, + name string, + convert func(any) (any, error), + getExpected func(testcase) any, +) { + t.Helper() + + for _, tc := range testcases { + t.Run(fmt.Sprintf("%v to %s", tc.value, name), func(t *testing.T) { + expected := getExpected(tc) + _, expectErr := expected.(invalidConversion) + + result, err := convert(tc.value) + if err != nil { + if !expectErr { + t.Fatalf("Failed to do %s(%s): %v", name, printValue(tc.value), err) + } + + return + } + + if expectErr { + t.Fatalf("Should not have been able to do %s(%s), but got: %v", name, printValue(tc.value), result) + } + + if !cmp.Equal(result, expected) { + t.Fatalf("expected %s(%s) => %s, but got %s", name, printValue(tc.value), printValue(expected), printValue(result)) + } + }) + } +} + +func printValue(val any) string { + t := fmt.Sprintf("%T(%#v)", val, val) + t = strings.ReplaceAll(t, "interface {}", "any") + + return t +} diff --git a/pkg/deepcopy/deepcopy_test.go b/pkg/deepcopy/deepcopy_test.go index 5e84b10..a9effa6 100644 --- a/pkg/deepcopy/deepcopy_test.go +++ b/pkg/deepcopy/deepcopy_test.go @@ -450,7 +450,7 @@ func TestCloneObjectPointer(t *testing.T) { } input.Data["new"] = "new-value" - if _, ok := (*cloned).Data["new"]; ok { + if _, ok := cloned.Data["new"]; ok { t.Fatal("Both input and output data point to the same memory address, no actual cloning happened.") } } @@ -464,7 +464,7 @@ func TestCloneVectorPointer(t *testing.T) { } input.Data[1] = "new" - if (*cloned).Data[1] == "new" { + if cloned.Data[1] == "new" { t.Fatal("Both input and output data point to the same memory address, no actual cloning happened.") } } diff --git a/pkg/equality/coalesced.go b/pkg/equality/coalesced.go new file mode 100644 index 0000000..c471fcb --- /dev/null +++ b/pkg/equality/coalesced.go @@ -0,0 +1,411 @@ +// SPDX-FileCopyrightText: 2023 Christoph Mewes +// SPDX-License-Identifier: MIT + +package equality + +import ( + "errors" + "fmt" + + "go.xrstf.de/rudi/pkg/coalescing" + "go.xrstf.de/rudi/pkg/lang/ast" +) + +var ErrIncompatibleTypes = errors.New("types are incompatible") + +func deliteral(val any) any { + lit, ok := val.(ast.Literal) + if ok { + return lit.LiteralValue() + } + + return val +} + +// Re-use a pedantic coalescer to not repeat the logic to turn int/int32/int64 into int64. +var typeChecker = coalescing.NewPedantic() + +func EqualCoalesced(c coalescing.Coalescer, left, right any) (bool, error) { + if c == nil { + c = coalescing.NewStrict() + } + + left = deliteral(left) + right = deliteral(right) + + // if either of the sides is a null, convert the other to null + matched, equal, err := nullishEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // if either of the sides is a bool, convert the other to bool + matched, equal, err = boolishEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // if either of the sides is a int, convert the other to a int + matched, equal, err = intEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // if either of the sides is a float, convert the other to a float + matched, equal, err = floatEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // if either of the sides is a string, convert the other to a string + matched, equal, err = stringishEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // if either of the sides is a vector, convert the other to a vector + matched, equal, err = vectorishEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + // now only objects are left + matched, equal, err = objectishEqualCoalesced(c, left, right) + if err != nil { + return false, err + } + if matched { + return equal, nil + } + + return false, fmt.Errorf("cannot compare with %T with %T", left, right) +} + +func nullishEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftOk := left == nil + rightOk := right == nil + + if !leftOk && !rightOk { + return false, false, nil + } + + matched = true + + if leftOk && rightOk { + return matched, true, nil + } + + var other any + + if leftOk { + other = right + } else { + other = left + } + + isNullish, err := c.ToNull(other) + if err != nil { + return matched, false, err + } + + return matched, isNullish, nil +} + +func boolishEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftBool, leftOk := left.(bool) + rightBool, rightOk := right.(bool) + + if !leftOk && !rightOk { + return false, false, nil + } + + matched = true + + if leftOk && rightOk { + return matched, leftBool == rightBool, nil + } + + var ( + a bool + b any + ) + + if leftOk { + a = leftBool + b = right + } else { + a = rightBool + b = left + } + + bValue, err := c.ToBool(b) + if err != nil { + return matched, false, err + } + + return matched, a == bValue, nil +} + +func intEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftInt, leftErr := typeChecker.ToInt64(left) + rightInt, rightErr := typeChecker.ToInt64(right) + + if leftErr != nil && rightErr != nil { + return false, false, nil + } + + matched = true + + if leftErr == nil && rightErr == nil { + return matched, leftInt == rightInt, nil + } + + var ( + a int64 + b any + ) + + if leftErr == nil { + a = leftInt + b = right + } else { + a = rightInt + b = left + } + + bValue, err := c.ToInt64(b) + if err != nil { + return matched, false, err + } + + return matched, a == bValue, nil +} + +func floatEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftFloat, leftErr := typeChecker.ToFloat64(left) + rightFloat, rightErr := typeChecker.ToFloat64(right) + + if leftErr != nil && rightErr != nil { + return false, false, nil + } + + matched = true + + if leftErr == nil && rightErr == nil { + return matched, leftFloat == rightFloat, nil + } + + var ( + a float64 + b any + ) + + if leftErr == nil { + a = leftFloat + b = right + } else { + a = rightFloat + b = left + } + + bValue, err := c.ToFloat64(b) + if err != nil { + return matched, false, err + } + + return matched, a == bValue, nil +} + +func stringishEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftString, leftOk := left.(string) + rightString, rightOk := right.(string) + + if !leftOk && !rightOk { + return false, false, nil + } + + matched = true + + if leftOk && rightOk { + return matched, leftString == rightString, nil + } + + var ( + a string + b any + ) + + if leftOk { + a = leftString + b = right + } else { + a = rightString + b = left + } + + bValue, err := c.ToString(b) + if err != nil { + return matched, false, err + } + + return matched, a == bValue, nil +} + +func vectorishEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftVector, leftErr := typeChecker.ToVector(left) + rightVector, rightErr := typeChecker.ToVector(right) + + if leftErr != nil && rightErr != nil { + return false, false, nil + } + + matched = true + + if leftErr == nil && rightErr == nil { + equal, err := vectorEqualCoalesced(c, leftVector, rightVector) + + return matched, equal, err + } + + var ( + a []any + b any + ) + + if leftErr == nil { + a = leftVector + b = right + } else { + a = rightVector + b = left + } + + // vector conversion is only allowed if the a vector is empty, so that [] == {} depending on the coalescer + if len(a) > 0 { + return matched, false, ErrIncompatibleTypes + } + + bVector, err := c.ToVector(b) + if err != nil { + return matched, false, ErrIncompatibleTypes + } + + equal, err = vectorEqualCoalesced(c, a, bVector) + + return matched, equal, err +} + +func vectorEqualCoalesced(c coalescing.Coalescer, left, right []any) (bool, error) { + if len(left) != len(right) { + return false, nil + } + + for i, leftItem := range left { + rightItem := right[i] + + // wrapping always returns literals, so type assertions are safe here + equal, err := EqualCoalesced(c, leftItem, rightItem) + if err != nil { + return false, err + } + if !equal { + return false, nil + } + } + + return true, nil +} + +func objectishEqualCoalesced(c coalescing.Coalescer, left any, right any) (matched bool, equal bool, err error) { + leftObject, leftErr := typeChecker.ToObject(left) + rightObject, rightErr := typeChecker.ToObject(right) + + if leftErr != nil && rightErr != nil { + return false, false, nil + } + + matched = true + + if leftErr == nil && rightErr == nil { + equal, err := objectEqualCoalesced(c, leftObject, rightObject) + + return matched, equal, err + } + + var ( + a map[string]any + b any + ) + + if leftErr == nil { + a = leftObject + b = right + } else { + a = rightObject + b = left + } + + // vector conversion is only allowed if the a vector is empty, so that [] == {} depending on the coalescer + if len(a) > 0 { + return matched, false, ErrIncompatibleTypes + } + + bObject, err := c.ToObject(b) + if err != nil { + return matched, false, ErrIncompatibleTypes + } + + equal, err = objectEqualCoalesced(c, a, bObject) + + return matched, equal, err +} + +func objectEqualCoalesced(c coalescing.Coalescer, left, right map[string]any) (bool, error) { + if len(left) != len(right) { + return false, nil + } + + keysSeen := map[string]struct{}{} + + for key, leftItem := range left { + rightItem, exists := right[key] + if !exists { + return false, nil + } + + keysSeen[key] = struct{}{} + + // wrapping always returns literals, so type assertions are safe here + equal, err := EqualCoalesced(c, leftItem, rightItem) + if err != nil { + return false, err + } + if !equal { + return false, nil + } + } + + for key := range right { + delete(keysSeen, key) + } + + return len(keysSeen) == 0, nil +} diff --git a/pkg/equality/coalesced_test.go b/pkg/equality/coalesced_test.go new file mode 100644 index 0000000..5897d44 --- /dev/null +++ b/pkg/equality/coalesced_test.go @@ -0,0 +1,373 @@ +// SPDX-FileCopyrightText: 2023 Christoph Mewes +// SPDX-License-Identifier: MIT + +package equality + +import ( + "fmt" + "testing" + + "go.xrstf.de/rudi/pkg/coalescing" + "go.xrstf.de/rudi/pkg/lang/ast" +) + +type invalidConversion int + +const invalid invalidConversion = iota + +type coalescedTestcase struct { + left any + right any + pedantic any + strict any + humane any +} + +func newCoalescedTest(left, right any, pedantic, strict, humane any) coalescedTestcase { + return coalescedTestcase{ + left: left, + right: right, + pedantic: pedantic, + strict: strict, + humane: humane, + } +} + +// type checklist: +// null, bool, int64, float64, string, vector, object +// for brevity's sake, we know that int==int32==int64 internally, likewise for floats + +func getEqualCoalescedTestcases() []coalescedTestcase { + return []coalescedTestcase{ + /////////////////////////////////////////////////////////// + // test nil against all other types + + newCoalescedTest(nil, nil, true, true, true), + newCoalescedTest(nil, ast.Null{}, true, true, true), + + newCoalescedTest(nil, true, invalid, invalid, invalid), + newCoalescedTest(nil, false, invalid, invalid, true), + newCoalescedTest(nil, ast.Bool(true), invalid, invalid, invalid), + newCoalescedTest(nil, ast.Bool(false), invalid, invalid, true), + + newCoalescedTest(nil, int64(0), invalid, invalid, true), + newCoalescedTest(nil, float64(0), invalid, invalid, true), + newCoalescedTest(nil, float64(0.0), invalid, invalid, true), + newCoalescedTest(nil, float64(0.1), invalid, invalid, invalid), + newCoalescedTest(nil, int64(1), invalid, invalid, invalid), + newCoalescedTest(nil, float64(1), invalid, invalid, invalid), + newCoalescedTest(nil, int64(-1), invalid, invalid, invalid), + newCoalescedTest(nil, float64(-1), invalid, invalid, invalid), + newCoalescedTest(nil, ast.Number{Value: int64(0)}, invalid, invalid, true), + newCoalescedTest(nil, ast.Number{Value: float64(0)}, invalid, invalid, true), + newCoalescedTest(nil, ast.Number{Value: float64(0.0)}, invalid, invalid, true), + newCoalescedTest(nil, ast.Number{Value: float64(0.1)}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Number{Value: int64(1)}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Number{Value: float64(1)}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Number{Value: int64(-1)}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Number{Value: float64(-1)}, invalid, invalid, invalid), + + newCoalescedTest(nil, "", invalid, invalid, true), + newCoalescedTest(nil, " ", invalid, invalid, invalid), + newCoalescedTest(nil, "test", invalid, invalid, invalid), + newCoalescedTest(nil, ast.String(""), invalid, invalid, true), + newCoalescedTest(nil, ast.String(" "), invalid, invalid, invalid), + newCoalescedTest(nil, ast.String("test"), invalid, invalid, invalid), + + newCoalescedTest(nil, []any{}, invalid, invalid, true), + newCoalescedTest(nil, []any{0}, invalid, invalid, invalid), + newCoalescedTest(nil, []any{1}, invalid, invalid, invalid), + newCoalescedTest(nil, []any{""}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Vector{Data: nil}, invalid, invalid, true), + newCoalescedTest(nil, ast.Vector{Data: []any{}}, invalid, invalid, true), + newCoalescedTest(nil, ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Vector{Data: []any{1}}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Vector{Data: []any{""}}, invalid, invalid, invalid), + + newCoalescedTest(nil, map[string]any{}, invalid, invalid, true), + newCoalescedTest(nil, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest(nil, ast.Object{Data: nil}, invalid, invalid, true), + newCoalescedTest(nil, ast.Object{Data: map[string]any{}}, invalid, invalid, true), + newCoalescedTest(nil, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + + /////////////////////////////////////////////////////////// + // test bool against all other types, except nils + + newCoalescedTest(true, true, true, true, true), + newCoalescedTest(false, false, true, true, true), + newCoalescedTest(true, false, false, false, false), + newCoalescedTest(true, ast.Bool(true), true, true, true), + newCoalescedTest(false, ast.Bool(false), true, true, true), + newCoalescedTest(true, ast.Bool(false), false, false, false), + + newCoalescedTest(true, int64(0), invalid, invalid, false), + newCoalescedTest(true, float64(0), invalid, invalid, false), + newCoalescedTest(true, int64(1), invalid, invalid, true), + newCoalescedTest(true, float64(1), invalid, invalid, true), + newCoalescedTest(true, float64(1.0), invalid, invalid, true), + newCoalescedTest(true, int64(-1), invalid, invalid, true), + newCoalescedTest(true, float64(-1), invalid, invalid, true), + newCoalescedTest(false, int64(0), invalid, invalid, true), + newCoalescedTest(false, float64(0), invalid, invalid, true), + newCoalescedTest(false, int64(1), invalid, invalid, false), + newCoalescedTest(false, float64(1), invalid, invalid, false), + newCoalescedTest(false, float64(1.0), invalid, invalid, false), + newCoalescedTest(false, int64(-1), invalid, invalid, false), + newCoalescedTest(false, float64(-1), invalid, invalid, false), + newCoalescedTest(true, ast.Number{Value: int64(0)}, invalid, invalid, false), + newCoalescedTest(true, ast.Number{Value: float64(0)}, invalid, invalid, false), + newCoalescedTest(true, ast.Number{Value: int64(1)}, invalid, invalid, true), + newCoalescedTest(true, ast.Number{Value: float64(1)}, invalid, invalid, true), + newCoalescedTest(true, ast.Number{Value: float64(1.0)}, invalid, invalid, true), + newCoalescedTest(true, ast.Number{Value: int64(-1)}, invalid, invalid, true), + newCoalescedTest(true, ast.Number{Value: float64(-1)}, invalid, invalid, true), + newCoalescedTest(false, ast.Number{Value: int64(0)}, invalid, invalid, true), + newCoalescedTest(false, ast.Number{Value: float64(0)}, invalid, invalid, true), + newCoalescedTest(false, ast.Number{Value: int64(1)}, invalid, invalid, false), + newCoalescedTest(false, ast.Number{Value: float64(1)}, invalid, invalid, false), + newCoalescedTest(false, ast.Number{Value: float64(1.0)}, invalid, invalid, false), + newCoalescedTest(false, ast.Number{Value: int64(-1)}, invalid, invalid, false), + newCoalescedTest(false, ast.Number{Value: float64(-1)}, invalid, invalid, false), + + newCoalescedTest(true, "", invalid, invalid, false), + newCoalescedTest(true, " ", invalid, invalid, true), + newCoalescedTest(true, "test", invalid, invalid, true), + newCoalescedTest(false, "", invalid, invalid, true), + newCoalescedTest(false, " ", invalid, invalid, false), + newCoalescedTest(false, "test", invalid, invalid, false), + newCoalescedTest(true, ast.String(""), invalid, invalid, false), + newCoalescedTest(true, ast.String(" "), invalid, invalid, true), + newCoalescedTest(true, ast.String("test"), invalid, invalid, true), + newCoalescedTest(false, ast.String(""), invalid, invalid, true), + newCoalescedTest(false, ast.String(" "), invalid, invalid, false), + newCoalescedTest(false, ast.String("test"), invalid, invalid, false), + + newCoalescedTest(true, []any{}, invalid, invalid, false), + newCoalescedTest(true, []any{0}, invalid, invalid, true), + newCoalescedTest(true, []any{1}, invalid, invalid, true), + newCoalescedTest(true, []any{""}, invalid, invalid, true), + newCoalescedTest(false, []any{}, invalid, invalid, true), + newCoalescedTest(false, []any{0}, invalid, invalid, false), + newCoalescedTest(false, []any{1}, invalid, invalid, false), + newCoalescedTest(false, []any{""}, invalid, invalid, false), + newCoalescedTest(true, ast.Vector{Data: []any{}}, invalid, invalid, false), + newCoalescedTest(true, ast.Vector{Data: []any{0}}, invalid, invalid, true), + newCoalescedTest(true, ast.Vector{Data: []any{1}}, invalid, invalid, true), + newCoalescedTest(true, ast.Vector{Data: []any{""}}, invalid, invalid, true), + newCoalescedTest(false, ast.Vector{Data: []any{}}, invalid, invalid, true), + newCoalescedTest(false, ast.Vector{Data: []any{0}}, invalid, invalid, false), + newCoalescedTest(false, ast.Vector{Data: []any{1}}, invalid, invalid, false), + newCoalescedTest(false, ast.Vector{Data: []any{""}}, invalid, invalid, false), + + newCoalescedTest(true, map[string]any{}, invalid, invalid, false), + newCoalescedTest(true, map[string]any{"": ""}, invalid, invalid, true), + newCoalescedTest(false, map[string]any{}, invalid, invalid, true), + newCoalescedTest(false, map[string]any{"": ""}, invalid, invalid, false), + newCoalescedTest(true, ast.Object{Data: nil}, invalid, invalid, false), + newCoalescedTest(true, ast.Object{Data: map[string]any{}}, invalid, invalid, false), + newCoalescedTest(true, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, true), + newCoalescedTest(false, ast.Object{Data: nil}, invalid, invalid, true), + newCoalescedTest(false, ast.Object{Data: map[string]any{}}, invalid, invalid, true), + newCoalescedTest(false, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, false), + + /////////////////////////////////////////////////////////// + // test numbers against all other types, except nils and bools + + newCoalescedTest(0, 0, true, true, true), + newCoalescedTest(0, 1, false, false, false), + newCoalescedTest(0, -1, false, false, false), + newCoalescedTest(0, ast.Number{Value: -1}, false, false, false), + newCoalescedTest(2, 2, true, true, true), + newCoalescedTest(0, 0.0, invalid, true, true), + newCoalescedTest(0, 1.0, invalid, false, false), + newCoalescedTest(2, 2.0, invalid, true, true), + newCoalescedTest(2, ast.Number{Value: 2.0}, invalid, true, true), + newCoalescedTest(-3.14, -3.14, true, true, true), + newCoalescedTest(-3.14, ast.Number{Value: -3.14}, true, true, true), + newCoalescedTest(ast.Number{Value: -3.14}, ast.Number{Value: -3.14}, true, true, true), + + newCoalescedTest(0, "", invalid, invalid, true), + newCoalescedTest(0.0, "", invalid, invalid, true), + newCoalescedTest(0, " ", invalid, invalid, true), + newCoalescedTest(0.0, " ", invalid, invalid, true), + newCoalescedTest(0, "0", invalid, invalid, true), + newCoalescedTest(0.0, "0", invalid, invalid, true), + newCoalescedTest(0, "0000", invalid, invalid, true), + newCoalescedTest(0, "1", invalid, invalid, false), + newCoalescedTest(1, "1", invalid, invalid, true), + newCoalescedTest(1, " 1 ", invalid, invalid, true), + newCoalescedTest(3, "3", invalid, invalid, true), + newCoalescedTest(3.1, "3.1", invalid, invalid, true), + newCoalescedTest(3.1, ast.String("3.1"), invalid, invalid, true), + + newCoalescedTest(0, []any{}, invalid, invalid, invalid), + newCoalescedTest(0, []any{0}, invalid, invalid, invalid), + newCoalescedTest(0.0, []any{}, invalid, invalid, invalid), + newCoalescedTest(0.0, []any{0}, invalid, invalid, invalid), + newCoalescedTest(1, []any{}, invalid, invalid, invalid), + newCoalescedTest(1, []any{0}, invalid, invalid, invalid), + newCoalescedTest(3.14, []any{}, invalid, invalid, invalid), + newCoalescedTest(3.14, []any{0}, invalid, invalid, invalid), + newCoalescedTest(0, ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest(0, ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + newCoalescedTest(0.0, ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest(0.0, ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + newCoalescedTest(1, ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest(1, ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + newCoalescedTest(3.14, ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest(3.14, ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + + newCoalescedTest(0, map[string]any{}, invalid, invalid, invalid), + newCoalescedTest(0, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest(0.0, map[string]any{}, invalid, invalid, invalid), + newCoalescedTest(0.0, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest(1, map[string]any{}, invalid, invalid, invalid), + newCoalescedTest(1, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest(3.14, map[string]any{}, invalid, invalid, invalid), + newCoalescedTest(3.14, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest(0, ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest(0, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest(0.0, ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest(0.0, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest(1, ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest(1, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest(3.14, ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest(3.14, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + + /////////////////////////////////////////////////////////// + // test strings against all other types, except nils, bools and numbers + + newCoalescedTest("", "", true, true, true), + newCoalescedTest("", " ", false, false, false), + newCoalescedTest("", "0", false, false, false), + newCoalescedTest("", "a", false, false, false), + newCoalescedTest("a", "a", true, true, true), + newCoalescedTest("a", "A", false, false, false), + newCoalescedTest("a", " a ", false, false, false), + + newCoalescedTest("", []any{}, invalid, invalid, invalid), + newCoalescedTest("", []any{0}, invalid, invalid, invalid), + newCoalescedTest("0", []any{}, invalid, invalid, invalid), + newCoalescedTest("0", []any{0}, invalid, invalid, invalid), + newCoalescedTest("", ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest("", ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + newCoalescedTest("0", ast.Vector{Data: []any{}}, invalid, invalid, invalid), + newCoalescedTest("0", ast.Vector{Data: []any{0}}, invalid, invalid, invalid), + + newCoalescedTest("", map[string]any{}, invalid, invalid, invalid), + newCoalescedTest("", map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest("0", map[string]any{}, invalid, invalid, invalid), + newCoalescedTest("0", map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest("", ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest("", ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest("0", ast.Object{Data: map[string]any{}}, invalid, invalid, invalid), + newCoalescedTest("0", ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + + /////////////////////////////////////////////////////////// + // test vectors against all other types, except nils, bools, numbers and strings + + newCoalescedTest([]any{}, []any{}, true, true, true), + newCoalescedTest([]any{}, []any{0}, false, false, false), + newCoalescedTest([]any{0}, []any{0}, true, true, true), + newCoalescedTest([]any{0}, []any{"0"}, invalid, invalid, true), + newCoalescedTest([]any{false}, []any{0}, invalid, invalid, true), + newCoalescedTest([]any{}, ast.Vector{Data: []any{}}, true, true, true), + newCoalescedTest([]any{}, ast.Vector{Data: []any{0}}, false, false, false), + newCoalescedTest([]any{0}, ast.Vector{Data: []any{0}}, true, true, true), + newCoalescedTest([]any{0}, ast.Vector{Data: []any{"0"}}, invalid, invalid, true), + newCoalescedTest([]any{false}, ast.Vector{Data: []any{0}}, invalid, invalid, true), + + newCoalescedTest([]any{}, map[string]any{}, invalid, invalid, true), + newCoalescedTest([]any{}, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest([]any{0}, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest([]any{1}, map[string]any{}, invalid, invalid, invalid), + newCoalescedTest([]any{nil}, map[string]any{"": ""}, invalid, invalid, invalid), + newCoalescedTest([]any{}, ast.Object{Data: map[string]any{}}, invalid, invalid, true), + newCoalescedTest([]any{}, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest([]any{0}, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + newCoalescedTest([]any{nil}, ast.Object{Data: map[string]any{"": ""}}, invalid, invalid, invalid), + + /////////////////////////////////////////////////////////// + // test objects + + newCoalescedTest(map[string]any{}, map[string]any{}, true, true, true), + newCoalescedTest(map[string]any{}, map[string]any{"foo": "bar"}, false, false, false), + newCoalescedTest(map[string]any{"foo": false}, map[string]any{"foo": ""}, invalid, invalid, true), + newCoalescedTest(map[string]any{"foo": "bar"}, map[string]any{"foo": "bar"}, true, true, true), + newCoalescedTest(map[string]any{"foo": "bar"}, map[string]any{"foo": "X"}, false, false, false), + newCoalescedTest(map[string]any{}, ast.Object{Data: map[string]any{}}, true, true, true), + newCoalescedTest(map[string]any{}, ast.Object{Data: map[string]any{"foo": "bar"}}, false, false, false), + newCoalescedTest(map[string]any{"foo": false}, ast.Object{Data: map[string]any{"foo": ""}}, invalid, invalid, true), + newCoalescedTest(map[string]any{"foo": "bar"}, ast.Object{Data: map[string]any{"foo": "bar"}}, true, true, true), + newCoalescedTest(map[string]any{"foo": "bar"}, ast.Object{Data: map[string]any{"foo": "X"}}, false, false, false), + } +} + +func TestEqualCoalesced(t *testing.T) { + pedanticCoalescer := coalescing.NewPedantic() + strictCoalescer := coalescing.NewStrict() + humaneCoalescer := coalescing.NewHumane() + + type subtest struct { + left any + right any + coal coalescing.Coalescer + expected any + } + + for _, testcase := range getEqualCoalescedTestcases() { + t.Run(fmt.Sprintf("%v %v", testcase.left, testcase.right), func(t *testing.T) { + subtests := []subtest{ + { + left: testcase.left, + right: testcase.right, + coal: pedanticCoalescer, + expected: testcase.pedantic, + }, + { + left: testcase.left, + right: testcase.right, + coal: strictCoalescer, + expected: testcase.strict, + }, + { + left: testcase.left, + right: testcase.right, + coal: humaneCoalescer, + expected: testcase.humane, + }, + } + + for _, subtest := range subtests { + _, expectErr := subtest.expected.(invalidConversion) + + equal, err := EqualCoalesced(subtest.coal, subtest.left, subtest.right) + if err != nil { + if !expectErr { + t.Errorf("%T unexpectedly failed: %v (%T) == %v (%T): %v", subtest.coal, subtest.left, subtest.left, subtest.right, subtest.right, err) + } + } else { + if expectErr { + t.Errorf("Expected %T to fail on %v (%T) == %v (%T), but got %v", subtest.coal, subtest.left, subtest.left, subtest.right, subtest.right, equal) + } else if equal != subtest.expected { + t.Errorf("Expected %T to return %v (%T) == %v (%T) => %v", subtest.coal, subtest.left, subtest.left, subtest.right, subtest.right, subtest.expected) + } + } + + // comparisons must be associated (a == b means b == a) + flippedEqual, err := EqualCoalesced(subtest.coal, subtest.right, subtest.left) + if err != nil { + if !expectErr { + t.Errorf("%T unexpectedly failed on reverse test: %v (%T) == %v (%T): %v", subtest.coal, subtest.right, subtest.right, subtest.left, subtest.left, err) + } + } else { + if expectErr { + t.Errorf("Expected %T to fail on %v (%T) == %v (%T), but got %v", subtest.coal, subtest.left, subtest.left, subtest.right, subtest.right, flippedEqual) + } else if equal != flippedEqual { + t.Errorf("Expected %T to be associative, but is not", subtest.coal) + } + } + } + }) + } +} diff --git a/pkg/equality/loose.go b/pkg/equality/loose.go deleted file mode 100644 index ce560e3..0000000 --- a/pkg/equality/loose.go +++ /dev/null @@ -1,282 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Christoph Mewes -// SPDX-License-Identifier: MIT - -package equality - -import ( - "fmt" - - "go.xrstf.de/rudi/pkg/coalescing" - "go.xrstf.de/rudi/pkg/eval/types" - "go.xrstf.de/rudi/pkg/lang/ast" -) - -// equality, but with using coalescing, so 1 == "1". -func EqualEnough(left, right ast.Literal) (bool, error) { - // if either of the sides is a null, convert the other to null - matched, equal, err := nullishEqualEnough(left, right) - if matched { - return equal, err - } - - // if either of the sides is a bool, convert the other to bool - matched, equal, err = boolishEqualEnough(left, right) - if matched { - return equal, err - } - - // if either of the sides is a number, convert the other to a number - matched, equal, err = numberishEqualEnough(left, right) - if matched { - return equal, err - } - - // if either of the sides is a string, convert the other to a string - matched, equal, err = stringishEqualEnough(left, right) - if matched { - return equal, err - } - - // now both sides can basically just be vectors or objects - - switch leftAsserted := left.(type) { - case ast.Vector: - return vectorishEqualEnough(leftAsserted, right) - case ast.Object: - return objectishEqualEnough(leftAsserted, right) - default: - return false, fmt.Errorf("cannot compare with %T with %T", left, right) - } -} - -func nullishEqualEnough(left ast.Literal, right ast.Literal) (matched bool, equal bool, err error) { - _, leftOk := left.(ast.Null) - _, rightOk := right.(ast.Null) - - if !leftOk && !rightOk { - return false, false, nil - } - - matched = true - - if leftOk && rightOk { - return matched, true, nil - } - - var b ast.Literal - - if leftOk { - b = right - } else { - b = left - } - - bValue, err := coalescing.NewHumane().ToBool(b) - if err != nil { - return matched, false, ErrIncompatibleTypes - } - - return matched, !bValue, nil -} - -func boolishEqualEnough(left ast.Literal, right ast.Literal) (matched bool, equal bool, err error) { - leftBool, leftOk := left.(ast.Bool) - rightBool, rightOk := right.(ast.Bool) - - if !leftOk && !rightOk { - return false, false, nil - } - - matched = true - - if leftOk && rightOk { - return matched, leftBool.Equal(rightBool), nil - } - - var ( - a bool - b ast.Literal - ) - - if leftOk { - a = bool(leftBool) - b = right - } else { - a = bool(rightBool) - b = left - } - - bValue, err := coalescing.NewHumane().ToBool(b) - if err != nil { - return matched, false, ErrIncompatibleTypes - } - - return matched, a == bValue, nil -} - -func numberishEqualEnough(left ast.Literal, right ast.Literal) (matched bool, equal bool, err error) { - leftNumber, leftOk := left.(ast.Number) - rightNumber, rightOk := right.(ast.Number) - - if !leftOk && !rightOk { - return false, false, nil - } - - matched = true - - if leftOk && rightOk { - return matched, leftNumber.ToFloat() == rightNumber.ToFloat(), nil - } - - var ( - a ast.Number - b ast.Literal - ) - - if leftOk { - a = leftNumber - b = right - } else { - a = rightNumber - b = left - } - - bValue, err := coalescing.NewHumane().ToFloat64(b) - if err != nil { - return matched, false, ErrIncompatibleTypes - } - - return matched, a.ToFloat() == bValue, nil -} - -func stringishEqualEnough(left ast.Literal, right ast.Literal) (matched bool, equal bool, err error) { - leftString, leftOk := left.(ast.String) - rightString, rightOk := right.(ast.String) - - if !leftOk && !rightOk { - return false, false, nil - } - - matched = true - - if leftOk && rightOk { - return matched, leftString.Equal(rightString), nil - } - - var ( - a string - b ast.Literal - ) - - if leftOk { - a = string(leftString) - b = right - } else { - a = string(rightString) - b = left - } - - bValue, err := coalescing.NewHumane().ToString(b) - if err != nil { - return matched, false, ErrIncompatibleTypes - } - - return matched, a == bValue, nil -} - -func vectorishEqualEnough(left ast.Vector, right any) (bool, error) { - // extra: [] == {} - rightObject, ok := right.(ast.Object) - if ok { - return len(left.Data) == 0 && len(rightObject.Data) == 0, nil - } - - rightValue, ok := right.(ast.Vector) - if !ok { - return false, ErrIncompatibleTypes - } - - if len(left.Data) != len(rightValue.Data) { - return false, nil - } - - for i, leftItem := range left.Data { - rightItem := rightValue.Data[i] - - leftWrapped, err := types.WrapNative(leftItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - rightWrapped, err := types.WrapNative(rightItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - // wrapping always returns literals, so type assertions are safe here - equal, err := EqualEnough(leftWrapped, rightWrapped) - if err != nil { - return false, err - } - - if !equal { - return false, nil - } - } - - return true, nil -} - -func objectishEqualEnough(left ast.Object, right any) (bool, error) { - // extra: [] == {} - rightVector, ok := right.(ast.Vector) - if ok { - return len(left.Data) == 0 && len(rightVector.Data) == 0, nil - } - - rightValue, ok := right.(ast.Object) - if !ok { - return false, ErrIncompatibleTypes - } - - if len(left.Data) != len(rightValue.Data) { - return false, nil - } - - keysSeen := map[string]struct{}{} - - for key, leftItem := range left.Data { - rightItem, exists := rightValue.Data[key] - if !exists { - return false, nil - } - - keysSeen[key] = struct{}{} - - leftWrapped, err := types.WrapNative(leftItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - rightWrapped, err := types.WrapNative(rightItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - // wrapping always returns literals, so type assertions are safe here - equal, err := EqualEnough(leftWrapped, rightWrapped) - if err != nil { - return false, err - } - - if !equal { - return false, nil - } - } - - for key := range rightValue.Data { - delete(keysSeen, key) - } - - return len(keysSeen) == 0, nil -} diff --git a/pkg/equality/strict.go b/pkg/equality/strict.go deleted file mode 100644 index efdd73d..0000000 --- a/pkg/equality/strict.go +++ /dev/null @@ -1,154 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Christoph Mewes -// SPDX-License-Identifier: MIT - -package equality - -import ( - "errors" - "fmt" - - "go.xrstf.de/rudi/pkg/eval/types" - "go.xrstf.de/rudi/pkg/lang/ast" -) - -var ErrIncompatibleTypes = errors.New("types are incompatible") - -func StrictEqual(left, right ast.Literal) (bool, error) { - switch leftAsserted := left.(type) { - case ast.Null: - return nullStrictEquals(leftAsserted, right) - case ast.Bool: - return boolStrictEquals(leftAsserted, right) - case ast.String: - return stringStrictEquals(leftAsserted, right) - case ast.Number: - return numberStrictEquals(leftAsserted, right) - case ast.Vector: - return vectorStrictEquals(leftAsserted, right) - case ast.Object: - return objectStrictEquals(leftAsserted, right) - default: - return false, fmt.Errorf("cannot compare with %T with %T", left, right) - } -} - -func boolStrictEquals(left ast.Bool, right ast.Literal) (bool, error) { - rightValue, ok := right.(ast.Bool) - if !ok { - return false, ErrIncompatibleTypes - } - - return left.Equal(rightValue), nil -} - -func nullStrictEquals(left ast.Null, right any) (bool, error) { - rightValue, ok := right.(ast.Null) - if !ok { - return false, ErrIncompatibleTypes - } - - return left.Equal(rightValue), nil -} - -func stringStrictEquals(left ast.String, right any) (bool, error) { - rightValue, ok := right.(ast.String) - if !ok { - return false, ErrIncompatibleTypes - } - - return left.Equal(rightValue), nil -} - -func numberStrictEquals(left ast.Number, right any) (bool, error) { - rightValue, ok := right.(ast.Number) - if !ok { - return false, ErrIncompatibleTypes - } - - return left.Equal(rightValue), nil -} - -func vectorStrictEquals(left ast.Vector, right any) (bool, error) { - rightValue, ok := right.(ast.Vector) - if !ok { - return false, ErrIncompatibleTypes - } - - if len(left.Data) != len(rightValue.Data) { - return false, nil - } - - for i, leftItem := range left.Data { - rightItem := rightValue.Data[i] - - leftWrapped, err := types.WrapNative(leftItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - rightWrapped, err := types.WrapNative(rightItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - // wrapping always returns literals, so type assertions are safe here - equal, err := StrictEqual(leftWrapped, rightWrapped) - if err != nil { - return false, err - } - - if !equal { - return false, nil - } - } - - return true, nil -} - -func objectStrictEquals(left ast.Object, right any) (bool, error) { - rightValue, ok := right.(ast.Object) - if !ok { - return false, ErrIncompatibleTypes - } - - if len(left.Data) != len(rightValue.Data) { - return false, nil - } - - keysSeen := map[string]struct{}{} - - for key, leftItem := range left.Data { - rightItem, exists := rightValue.Data[key] - if !exists { - return false, nil - } - - keysSeen[key] = struct{}{} - - leftWrapped, err := types.WrapNative(leftItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - rightWrapped, err := types.WrapNative(rightItem) - if err != nil { - return false, ErrIncompatibleTypes - } - - // wrapping always returns literals, so type assertions are safe here - equal, err := StrictEqual(leftWrapped, rightWrapped) - if err != nil { - return false, err - } - - if !equal { - return false, nil - } - } - - for key := range rightValue.Data { - delete(keysSeen, key) - } - - return len(keysSeen) == 0, nil -} diff --git a/pkg/eval/builtin/comparisons.go b/pkg/eval/builtin/comparisons.go index 708f706..443b447 100644 --- a/pkg/eval/builtin/comparisons.go +++ b/pkg/eval/builtin/comparisons.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" + "go.xrstf.de/rudi/pkg/coalescing" "go.xrstf.de/rudi/pkg/equality" "go.xrstf.de/rudi/pkg/eval" "go.xrstf.de/rudi/pkg/eval/types" @@ -23,32 +24,12 @@ func eqFunction(ctx types.Context, args []ast.Expression) (any, error) { return nil, fmt.Errorf("argument #0: %w", err) } - leftData, err = types.WrapNative(leftData) - if err != nil { - return nil, fmt.Errorf("argument #0: %w", err) - } - - leftValue, ok := leftData.(ast.Literal) - if !ok { - return nil, fmt.Errorf("argument #0 is not a literal, but %T", leftData) - } - _, rightData, err := eval.EvalExpression(ctx, args[1]) if err != nil { return nil, fmt.Errorf("argument #1: %w", err) } - rightData, err = types.WrapNative(rightData) - if err != nil { - return nil, fmt.Errorf("argument #1: %w", err) - } - - rightValue, ok := rightData.(ast.Literal) - if !ok { - return nil, fmt.Errorf("argument #1 is not a literal, but %T", rightData) - } - - equal, err := equality.StrictEqual(leftValue, rightValue) + equal, err := equality.EqualCoalesced(ctx.Coalesce(), leftData, rightData) if err != nil { return nil, err } @@ -66,32 +47,12 @@ func likeFunction(ctx types.Context, args []ast.Expression) (any, error) { return nil, fmt.Errorf("argument #0: %w", err) } - leftData, err = types.WrapNative(leftData) - if err != nil { - return nil, fmt.Errorf("argument #0: %w", err) - } - - leftValue, ok := leftData.(ast.Literal) - if !ok { - return nil, fmt.Errorf("argument #0 is not a literal, but %T", leftData) - } - _, rightData, err := eval.EvalExpression(ctx, args[1]) if err != nil { return nil, fmt.Errorf("argument #1: %w", err) } - rightData, err = types.WrapNative(rightData) - if err != nil { - return nil, fmt.Errorf("argument #1: %w", err) - } - - rightValue, ok := rightData.(ast.Literal) - if !ok { - return nil, fmt.Errorf("argument #1 is not a literal, but %T", rightData) - } - - equal, err := equality.EqualEnough(leftValue, rightValue) + equal, err := equality.EqualCoalesced(coalescing.NewHumane(), leftData, rightData) if err != nil { return nil, err } diff --git a/pkg/eval/builtin/comparisons_test.go b/pkg/eval/builtin/comparisons_test.go index 17ab7c2..33a503e 100644 --- a/pkg/eval/builtin/comparisons_test.go +++ b/pkg/eval/builtin/comparisons_test.go @@ -114,9 +114,10 @@ func TestEqFunction(t *testing.T) { document: testDoc, }, { + // strict coalescing allows lossless float->int conversion left: `1`, right: `1.0`, - expected: false, + expected: true, }, { left: `1`, @@ -343,14 +344,14 @@ func TestLikeFunction(t *testing.T) { expected: true, }, { - left: `{}`, - right: `[1]`, - expected: false, + left: `{}`, + right: `[1]`, + invalid: true, }, { - left: `{foo "bar"}`, - right: `[]`, - expected: false, + left: `{foo "bar"}`, + right: `[]`, + invalid: true, }, }) diff --git a/pkg/eval/builtin/lists.go b/pkg/eval/builtin/lists.go index 926b00f..a18e7b3 100644 --- a/pkg/eval/builtin/lists.go +++ b/pkg/eval/builtin/lists.go @@ -637,19 +637,12 @@ func containsFunction(ctx types.Context, args []ast.Expression) (any, error) { } if vec, err := ctx.Coalesce().ToVector(haystack); err == nil { - needleLiteral, err := types.WrapNative(needle) - if err != nil { - return nil, fmt.Errorf("cannot compare with %T: %w", needle, err) - } - for _, val := range vec { - valLiteral, err := types.WrapNative(val) + equal, err := equality.EqualCoalesced(ctx.Coalesce(), val, needle) if err != nil { - return nil, fmt.Errorf("cannot compare with %T: %w", val, err) + return false, err } - - equal, err := equality.StrictEqual(valLiteral, needleLiteral) - if err == nil && equal { + if equal { return true, nil } } diff --git a/pkg/eval/builtin/types_test.go b/pkg/eval/builtin/types_test.go index 156254b..ad5d032 100644 --- a/pkg/eval/builtin/types_test.go +++ b/pkg/eval/builtin/types_test.go @@ -43,6 +43,10 @@ func TestToStringFunction(t *testing.T) { Expression: `(to-string true)`, Expected: "true", }, + { + Expression: `(to-string false)`, + Expected: "false", + }, { Expression: `(to-string null)`, Expected: "", @@ -109,10 +113,18 @@ func TestToIntFunction(t *testing.T) { Expression: `(to-int [])`, Invalid: true, }, + { + Expression: `(to-int [0])`, + Invalid: true, + }, { Expression: `(to-int {})`, Invalid: true, }, + { + Expression: `(to-int {"" ""})`, + Invalid: true, + }, } for _, testcase := range testcases { @@ -167,10 +179,18 @@ func TestToFloatFunction(t *testing.T) { Expression: `(to-float [])`, Invalid: true, }, + { + Expression: `(to-float [""])`, + Invalid: true, + }, { Expression: `(to-float {})`, Invalid: true, }, + { + Expression: `(to-float {"" ""})`, + Invalid: true, + }, } for _, testcase := range testcases { diff --git a/pkg/testutil/php/php.go b/pkg/testutil/php/php.go new file mode 100644 index 0000000..97512c5 --- /dev/null +++ b/pkg/testutil/php/php.go @@ -0,0 +1,62 @@ +//go:build integration + +// SPDX-FileCopyrightText: 2023 Christoph Mewes +// SPDX-License-Identifier: MIT + +package php + +import ( + "fmt" + "strings" +) + +// Convert converts a Go value to a rough PHP representation. +func Convert(val any) string { + switch asserted := val.(type) { + case nil: + return `null` + case bool: + return fmt.Sprintf("(bool) %v", asserted) + case int: + return fmt.Sprintf("(int) %d", asserted) + case int32: + return fmt.Sprintf("(int) %d", asserted) + case int64: + return fmt.Sprintf("(int) %d", asserted) + case float32: + return fmt.Sprintf("(double) %f", asserted) + case float64: + return fmt.Sprintf("(double) %f", asserted) + case string: + return fmt.Sprintf("%q", asserted) + case []any: + items := make([]string, len(asserted)) + for i, item := range asserted { + phpItem := Convert(item) + if phpItem == "" { + return "" + } + + items[i] = phpItem + } + + return fmt.Sprintf(`[%s]`, strings.Join(items, ",")) + + case map[string]any: + items := make([]string, len(asserted)) + i := 0 + for key, value := range asserted { + phpValue := Convert(value) + if phpValue == "" { + return "" + } + + items[i] = fmt.Sprintf(`%q => %s`, key, phpValue) + i++ + } + + return fmt.Sprintf(`[%s]`, strings.Join(items, ",")) + default: + return "" + } +} diff --git a/pkg/testutil/testcase.go b/pkg/testutil/testcase.go index 11258cb..2f42099 100644 --- a/pkg/testutil/testcase.go +++ b/pkg/testutil/testcase.go @@ -129,7 +129,7 @@ func assertResultValue(t *testing.T, expected any, actual any) { if !ok { t.Errorf("Result has invalid type:\n%s", renderDiff(expected, actual)) } else { - equal, err := equality.StrictEqual(expectedNode, resultNode) + equal, err := equality.EqualCoalesced(nil, expectedNode, resultNode) if err != nil { t.Errorf("Could not compare result with expectation: %v", err) } else if !equal { diff --git a/program.go b/program.go index afd3a91..2537b2a 100644 --- a/program.go +++ b/program.go @@ -9,6 +9,7 @@ import ( "go.xrstf.de/rudi/pkg/debug" "go.xrstf.de/rudi/pkg/eval" + "go.xrstf.de/rudi/pkg/eval/types" "go.xrstf.de/rudi/pkg/lang/ast" "go.xrstf.de/rudi/pkg/lang/parser" ) @@ -87,7 +88,7 @@ func (p *rudiProgram) Run(data any, variables Variables, funcs Functions, coales // get current state of the document docData := finalCtx.GetDocument().Data() - unwrappedDocData, err := Unwrap(docData) + unwrappedDocData, err := types.UnwrapType(docData) if err != nil { // this should never happen return nil, nil, fmt.Errorf("failed to unwrap final document data: %w", err) @@ -105,7 +106,7 @@ func (p *rudiProgram) RunContext(ctx Context) (finalCtx Context, result any, err return ctx, nil, err } - unwrappedResult, err := Unwrap(result) + unwrappedResult, err := types.UnwrapType(result) if err != nil { // this should never happen return ctx, nil, fmt.Errorf("failed to unwrap result: %w", err)