From fd1c22a045ca7f8b2460c09d5f4de6f32f06a9e4 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 19 Sep 2024 16:36:54 -0700 Subject: [PATCH 01/17] cedar-go/internal: add a MapSet type, which is basically just some convenience functions around a map[T]struct{} Signed-off-by: Patrick Jakubowski --- internal/sets/mapset.go | 154 +++++++++++++++++++ internal/sets/mapset_test.go | 276 +++++++++++++++++++++++++++++++++++ 2 files changed, 430 insertions(+) create mode 100644 internal/sets/mapset.go create mode 100644 internal/sets/mapset_test.go diff --git a/internal/sets/mapset.go b/internal/sets/mapset.go new file mode 100644 index 0000000..5a0ee5a --- /dev/null +++ b/internal/sets/mapset.go @@ -0,0 +1,154 @@ +<<<<<<<< HEAD:internal/hashset.go +package internal +======== +package sets +>>>>>>>> 0f61d5f (fixup):internal/sets/mapset.go + +import ( + "encoding/json" + "fmt" + + "golang.org/x/exp/maps" +) + +// Similar to the concept of a [legal peppercorn](https://en.wikipedia.org/wiki/Peppercorn_(law)), this instance of +// nothingness is required in order to transact with Go's map[T]struct{} idiom. +var peppercorn = struct{}{} + +// MapSet is a struct that adds some convenience to the otherwise cumbersome map[T]struct{} idiom used in Go to +// implement sets of comparable types. +type MapSet[T comparable] struct { + m map[T]struct{} +} + +// NewMapSet returns a MapSet ready for use. Optionally, a desired size for the MapSet can be passed as an argument, +// as in the argument to make() for a map type. +func NewMapSet[T comparable](args ...int) MapSet[T] { + if len(args) > 1 { + panic(fmt.Sprintf("too many arguments passed to NewMapSet(). got: %v, expected 0 or 1", len(args))) + } + + var size int + if len(args) == 1 { + size = args[0] + } + + var m map[T]struct{} + if size > 0 { + m = make(map[T]struct{}, size) + } + + return MapSet[T]{m: m} +} + +// NewMapSetFromSlice creates a MapSet of size len(items) and calls AddSlice(items) on it. +func NewMapSetFromSlice[T comparable](items []T) MapSet[T] { + h := NewMapSet[T](len(items)) + h.AddSlice(items) + return h +} + +// Add an item to the set. Returns true if the item did not exist in the set. +func (h *MapSet[T]) Add(item T) bool { + if h.m == nil { + h.m = map[T]struct{}{} + } + + _, exists := h.m[item] + h.m[item] = peppercorn + return !exists +} + +// AddSlice adds a slice of items to the set, returning true if any new items were added to the set. +func (h *MapSet[T]) AddSlice(items []T) bool { + modified := false + for _, i := range items { + modified = h.Add(i) || modified + } + return modified +} + +// Remove an item from the Set. Returns true if the item existed in the set. +func (h *MapSet[T]) Remove(item T) bool { + _, exists := h.m[item] + delete(h.m, item) + return exists +} + +// RemoveSlice removes a slice of items from the set, returning true if any items existed in the set. +func (h *MapSet[T]) RemoveSlice(items []T) bool { + modified := false + for _, i := range items { + modified = h.Remove(i) || modified + } + return modified +} + +// Contains returns whether the item exists in the set +func (h MapSet[T]) Contains(item T) bool { + _, exists := h.m[item] + return exists +} + +// Intersection returns the items common to both h and o. +func (h MapSet[T]) Intersection(o MapSet[T]) MapSet[T] { + intersection := NewMapSet[T]() + for item := range h.m { + if o.Contains(item) { + intersection.Add(item) + } + } + return intersection +} + +// Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted. +// Iteration order is undefined. +func (h MapSet[T]) Iterate(callback func(item T) bool) { + for item := range h.m { + if !callback(item) { + break + } + } +} + +func (h MapSet[T]) Slice() []T { + if h.m == nil { + return nil + } + return maps.Keys(h.m) +} + +// Len returns the size of the MapSet +func (h MapSet[T]) Len() int { + return len(h.m) +} + +// Equal returns whether the same items exist in both h and o +func (h MapSet[T]) Equal(o MapSet[T]) bool { + if len(h.m) != len(o.m) { + return false + } + + for item := range h.m { + if !o.Contains(item) { + return false + } + } + return true +} + +// MarshalJSON serializes a MapSet as a JSON array. Order is non-deterministic. +func (h MapSet[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(h.Slice()) +} + +// UnmarshalJSON deserializes a MapSet from a JSON array. +func (h *MapSet[T]) UnmarshalJSON(b []byte) error { + var s []T + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + *h = NewMapSetFromSlice(s) + return nil +} diff --git a/internal/sets/mapset_test.go b/internal/sets/mapset_test.go new file mode 100644 index 0000000..c41b8c1 --- /dev/null +++ b/internal/sets/mapset_test.go @@ -0,0 +1,276 @@ +<<<<<<<< HEAD:internal/hashset_test.go +package internal +======== +package sets +>>>>>>>> 0f61d5f (fixup):internal/sets/mapset_test.go + +import ( + "encoding/json" + "slices" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func mustNotContain[T comparable](t *testing.T, s MapSet[T], item T) { + testutil.FatalIf(t, s.Contains(item), "set %v unexpectedly contained item %v", s, 1) +} + +func TestHashSet(t *testing.T) { + t.Run("empty set contains nothing", func(t *testing.T) { + s := MapSet[int]{} + mustNotContain(t, s, 1) + + s = NewMapSet[int]() + mustNotContain(t, s, 1) + + s = NewMapSet[int](10) + mustNotContain(t, s, 1) + }) + + t.Run("add => contains", func(t *testing.T) { + s := MapSet[int]{} + s.Add(1) + testutil.Equals(t, s.Contains(1), true) + }) + + t.Run("add twice", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.Add(1), true) + testutil.Equals(t, s.Add(1), false) + }) + + t.Run("add slice", func(t *testing.T) { + s := MapSet[int]{} + s.AddSlice([]int{1, 2}) + testutil.Equals(t, s.Contains(1), true) + testutil.Equals(t, s.Contains(2), true) + mustNotContain(t, s, 3) + }) + + t.Run("add same slice", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.AddSlice([]int{1, 2}), true) + testutil.Equals(t, s.AddSlice([]int{1, 2}), false) + }) + + t.Run("add disjoint slices", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.AddSlice([]int{1, 2}), true) + testutil.Equals(t, s.AddSlice([]int{3, 4}), true) + testutil.Equals(t, s.AddSlice([]int{1, 2, 3, 4}), false) + }) + + t.Run("add overlapping slices", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.AddSlice([]int{1, 2}), true) + testutil.Equals(t, s.AddSlice([]int{2, 3}), true) + testutil.Equals(t, s.AddSlice([]int{1, 3}), false) + }) + + t.Run("remove nonexistent", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.Remove(1), false) + }) + + t.Run("remove existing", func(t *testing.T) { + s := MapSet[int]{} + s.Add(1) + testutil.Equals(t, s.Remove(1), true) + }) + + t.Run("remove => !contains", func(t *testing.T) { + s := MapSet[int]{} + s.Add(1) + s.Remove(1) + testutil.FatalIf(t, s.Contains(1), "set unexpectedly contained item") + }) + + t.Run("remove slice", func(t *testing.T) { + s := MapSet[int]{} + s.AddSlice([]int{1, 2, 3}) + s.RemoveSlice([]int{1, 2}) + mustNotContain(t, s, 1) + mustNotContain(t, s, 2) + testutil.Equals(t, s.Contains(3), true) + }) + + t.Run("remove non-existent slice", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) + }) + + t.Run("remove overlapping slice", func(t *testing.T) { + s := MapSet[int]{} + s.Add(1) + testutil.Equals(t, s.RemoveSlice([]int{1, 2}), true) + testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) + }) + + t.Run("new from slice", func(t *testing.T) { + s := NewMapSetFromSlice([]int{1, 2, 2, 3}) + testutil.Equals(t, s.Len(), 3) + testutil.Equals(t, s.Contains(1), true) + testutil.Equals(t, s.Contains(2), true) + testutil.Equals(t, s.Contains(3), true) + }) + + t.Run("slice", func(t *testing.T) { + s := MapSet[int]{} + testutil.Equals(t, s.Slice(), nil) + + s = NewMapSet[int]() + testutil.Equals(t, s.Slice(), nil) + + s = NewMapSet[int](10) + testutil.Equals(t, s.Slice(), []int{}) + + inSlice := []int{1, 2, 3} + s = NewMapSetFromSlice(inSlice) + outSlice := s.Slice() + slices.Sort(outSlice) + testutil.Equals(t, inSlice, outSlice) + }) + + t.Run("equal", func(t *testing.T) { + s1 := NewMapSetFromSlice([]int{1, 2, 3}) + testutil.Equals(t, s1.Equal(s1), true) + + s2 := NewMapSetFromSlice([]int{1, 2, 3}) + testutil.Equals(t, s1.Equal(s2), true) + + s2.Add(4) + testutil.Equals(t, s1.Equal(s2), false) + + s2.Remove(3) + testutil.Equals(t, s1.Equal(s2), false) + + s1.Add(4) + s1.Remove(3) + testutil.Equals(t, s1.Equal(s2), true) + }) + + t.Run("iterate", func(t *testing.T) { + s1 := NewMapSetFromSlice([]int{1, 2, 3}) + + var s2 MapSet[int] + s1.Iterate(func(item int) bool { + s2.Add(item) + return true + }) + + testutil.Equals(t, s1.Equal(s2), true) + }) + + t.Run("iterate break early", func(t *testing.T) { + s1 := NewMapSetFromSlice([]int{1, 2, 3}) + + i := 0 + var items []int + s1.Iterate(func(item int) bool { + if i == 2 { + return false + } + items = append(items, item) + i++ + return true + }) + + // Because iteration order is non-deterministic, all we can say is that the right number of items ended up in + // the set and that the items were in the original set. + testutil.Equals(t, len(items), 2) + testutil.Equals(t, s1.Contains(items[0]), true) + testutil.Equals(t, s1.Contains(items[1]), true) + }) + + t.Run("intersection with overlap", func(t *testing.T) { + s1 := NewMapSetFromSlice([]int{1, 2, 3}) + s2 := NewMapSetFromSlice([]int{2, 3, 4}) + + s3 := s1.Intersection(s2) + testutil.Equals(t, s3, NewMapSetFromSlice([]int{2, 3})) + + s4 := s1.Intersection(s2) + testutil.Equals(t, s4, NewMapSetFromSlice([]int{2, 3})) + }) + + t.Run("intersection disjoint", func(t *testing.T) { + s1 := NewMapSetFromSlice([]int{1, 2}) + s2 := NewMapSetFromSlice([]int{3, 4}) + + s3 := s1.Intersection(s2) + testutil.Equals(t, s3.Len(), 0) + + s4 := s1.Intersection(s2) + testutil.Equals(t, s4.Len(), 0) + }) + + t.Run("encode nil set", func(t *testing.T) { + s := NewMapSet[int]() + + out, err := json.Marshal(s) + + testutil.OK(t, err) + testutil.Equals(t, string(out), "[]") + }) + + t.Run("encode json", func(t *testing.T) { + s := NewMapSetFromSlice([]int{1, 2, 3}) + + out, err := json.Marshal(s) + + correctOutputs := []string{ + "[1,2,3]", + "[1,3,2]", + "[2,1,3]", + "[2,3,1]", + "[3,1,2]", + "[3,2,1]", + } + + testutil.OK(t, err) + testutil.FatalIf(t, !slices.Contains(correctOutputs, string(out)), "%v is not a valid output", string(out)) + }) + + t.Run("decode json", func(t *testing.T) { + var s1 MapSet[int] + err := s1.UnmarshalJSON([]byte("[2,3,1,2]")) + testutil.OK(t, err) + testutil.Equals(t, s1, NewMapSetFromSlice([]int{1, 2, 3})) + }) + + t.Run("decode json empty", func(t *testing.T) { + var s1 MapSet[int] + err := s1.UnmarshalJSON([]byte("[]")) + testutil.OK(t, err) + testutil.Equals(t, s1.Len(), 0) + }) + + t.Run("decode mixed types in array", func(t *testing.T) { + var s1 MapSet[int] + err := s1.UnmarshalJSON([]byte(`[2,3,1,"2"]`)) + testutil.Error(t, err) + testutil.Equals(t, err.Error(), "json: cannot unmarshal string into Go value of type int") + testutil.Equals(t, s1.Len(), 0) + }) + + t.Run("decode wrong type", func(t *testing.T) { + var s1 MapSet[int] + err := s1.UnmarshalJSON([]byte(`"1,2,3"`)) + testutil.Error(t, err) + testutil.Equals(t, err.Error(), "json: cannot unmarshal string into Go value of type []int") + testutil.Equals(t, s1.Len(), 0) + }) + + t.Run("panic if too many args", func(t *testing.T) { + t.Parallel() + + defer func() { + if r := recover(); r == nil { + t.Fatalf("code did not panic as expected") + } + }() + + NewMapSet[int](0, 1) + }) +} From 4fc3a7486ad13eca9e85e9ea4e942cbb7db58edb Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 19 Sep 2024 16:52:35 -0700 Subject: [PATCH 02/17] cedar-go/types: convert Entity.Parents to a MapSet[EntityUID] Signed-off-by: Patrick Jakubowski --- authorize_test.go | 462 +++++++++++++++++----------------- internal/eval/evalers.go | 59 ++--- internal/eval/evalers_test.go | 27 +- internal/eval/partial.go | 6 +- internal/sets/mapset.go | 7 +- internal/sets/mapset_test.go | 4 - types.go | 12 + types/entities.go | 27 -- types/entities_test.go | 49 ++-- types/entity.go | 31 +++ types/entity_test.go | 29 +++ types/entity_uid.go | 15 ++ types/entity_uid_test.go | 34 +++ 13 files changed, 419 insertions(+), 343 deletions(-) create mode 100644 types/entity.go create mode 100644 types/entity_test.go diff --git a/authorize_test.go b/authorize_test.go index 8db8582..7e43a62 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -1,68 +1,68 @@ -package cedar +package cedar_test import ( "testing" + "github.com/cedar-policy/cedar-go" "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/types" ) //nolint:revive // due to table test function-length func TestIsAuthorized(t *testing.T) { t.Parallel() - cuzco := NewEntityUID("coder", "cuzco") - dropTable := NewEntityUID("table", "drop") + cuzco := cedar.NewEntityUID("coder", "cuzco") + dropTable := cedar.NewEntityUID("table", "drop") tests := []struct { Name string Policy string - Entities Entities - Principal, Action, Resource EntityUID - Context Record - Want Decision + Entities cedar.Entities + Principal, Action, Resource cedar.EntityUID + Context cedar.Record + Want cedar.Decision DiagErr int ParseErr bool }{ { Name: "simple-permit", Policy: `permit(principal,action,resource);`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, }, { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, }, { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, @@ -71,343 +71,343 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource in "foo" }; permit(principal,action,resource); `, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 1, }, { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: NewRecord(RecordMap{"x": Long(42)}), + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(42)}), Want: true, DiagErr: 0, }, { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: NewRecord(RecordMap{"x": Long(43)}), + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(43)}), Want: false, DiagErr: 0, }, { Name: "permit-requires-entities-success", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: Entities{ - cuzco: &Entity{ + Entities: cedar.Entities{ + cuzco: &cedar.Entity{ UID: cuzco, - Attributes: NewRecord(RecordMap{"x": Long(42)}), + Attributes: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(42)}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-requires-entities-fail", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: Entities{ - cuzco: &Entity{ + Entities: cedar.Entities{ + cuzco: &cedar.Entity{ UID: cuzco, - Attributes: NewRecord(types.RecordMap{"x": Long(43)}), + Attributes: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(43)}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, }, { Name: "permit-requires-entities-parent-success", Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, - Entities: Entities{ - cuzco: &Entity{ + Entities: cedar.Entities{ + cuzco: &cedar.Entity{ UID: cuzco, - Parents: []EntityUID{types.NewEntityUID("parent", "bob")}, + Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("parent", "bob")}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-principal-in", Policy: `permit(principal in team::"osiris",action,resource);`, - Entities: Entities{ - cuzco: &Entity{ + Entities: cedar.Entities{ + cuzco: &cedar.Entity{ UID: cuzco, - Parents: []EntityUID{types.NewEntityUID("team", "osiris")}, + Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("team", "osiris")}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-action-in", Policy: `permit(principal,action in scary::"stuff",resource);`, - Entities: Entities{ - dropTable: &Entity{ + Entities: cedar.Entities{ + dropTable: &cedar.Entity{ UID: dropTable, - Parents: []EntityUID{types.NewEntityUID("scary", "stuff")}, + Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-action-in-set", Policy: `permit(principal,action in [scary::"stuff"],resource);`, - Entities: Entities{ - dropTable: &Entity{ + Entities: cedar.Entities{ + dropTable: &cedar.Entity{ UID: dropTable, - Parents: []EntityUID{types.NewEntityUID("scary", "stuff")}, + Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-relations-has", Policy: `permit(principal,action,resource) when { principal has name };`, - Entities: Entities{ - cuzco: &Entity{ + Entities: cedar.Entities{ + cuzco: &cedar.Entity{ UID: cuzco, - Attributes: NewRecord(types.RecordMap{"name": String("bob")}), + Attributes: cedar.NewRecord(cedar.RecordMap{"name": cedar.String("bob")}), }, }, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, ParseErr: true, @@ -415,22 +415,22 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, ParseErr: true, @@ -438,22 +438,22 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, ParseErr: true, @@ -461,22 +461,22 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, ParseErr: true, @@ -484,22 +484,22 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 0, ParseErr: true, @@ -511,22 +511,22 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").lessThanOrEqual(decimal("11.0")) && decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, @@ -538,22 +538,22 @@ func TestIsAuthorized(t *testing.T) { datetime("1970-01-01T09:08:07Z") > (datetime("1970-01-01")) && datetime("1970-01-01T09:08:07Z") >= (datetime("1970-01-01")) && datetime("1970-01-01T09:08:07Z").toDate() == datetime("1970-01-01")};`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-datetime-fun-wrong-arity", Policy: `permit(principal,action,resource) when { datetime("1970-01-01", "UTC") };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, @@ -573,22 +573,22 @@ func TestIsAuthorized(t *testing.T) { datetime("1970-01-01").offset(duration("1ms")).toTime() == duration("1ms") && datetime("1970-01-01T00:00:00.001Z").durationSince(datetime("1970-01-01")) == duration("1ms")};`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-duration-fun-wrong-arity", Policy: `permit(principal,action,resource) when { duration("1h", "huh?") };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, @@ -600,183 +600,183 @@ func TestIsAuthorized(t *testing.T) { ip("::1").isLoopback() && ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - Entities: Entities{}, + Entities: cedar.Entities{}, Principal: cuzco, Action: dropTable, - Resource: NewEntityUID("table", "whatever"), - Context: Record{}, + Resource: cedar.NewEntityUID("table", "whatever"), + Context: cedar.Record{}, Want: false, DiagErr: 1, }, { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, - Entities: Entities{}, - Context: NewRecord(RecordMap{"value": Long(-42)}), + Entities: cedar.Entities{}, + Context: cedar.NewRecord(cedar.RecordMap{"value": cedar.Long(-42)}), Want: true, DiagErr: 0, }, { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, - Entities: Entities{}, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Entities: cedar.Entities{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, - Entities: Entities{ - NewEntityUID("Resource", "table"): &Entity{ - UID: NewEntityUID("Resource", "table"), - Parents: []EntityUID{types.NewEntityUID("Parent", "id")}, + Entities: cedar.Entities{ + cedar.NewEntityUID("Resource", "table"): &cedar.Entity{ + UID: cedar.NewEntityUID("Resource", "table"), + Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("Parent", "id")}), }, }, - Principal: NewEntityUID("Actor", "cuzco"), - Action: NewEntityUID("Action", "drop"), - Resource: NewEntityUID("Resource", "table"), - Context: Record{}, + Principal: cedar.NewEntityUID("Actor", "cuzco"), + Action: cedar.NewEntityUID("Action", "drop"), + Resource: cedar.NewEntityUID("Resource", "table"), + Context: cedar.Record{}, Want: true, DiagErr: 0, }, { Name: "rfc-57", // https://github.com/cedar-policy/rfcs/blob/main/text/0057-general-multiplication.md Policy: `permit(principal, action, resource) when { context.foo * principal.bar >= 100 };`, - Entities: Entities{ - NewEntityUID("Principal", "1"): &Entity{ - UID: NewEntityUID("Principal", "1"), - Attributes: NewRecord(types.RecordMap{"bar": Long(42)}), + Entities: cedar.Entities{ + cedar.NewEntityUID("Principal", "1"): &cedar.Entity{ + UID: cedar.NewEntityUID("Principal", "1"), + Attributes: cedar.NewRecord(cedar.RecordMap{"bar": cedar.Long(42)}), }, }, - Principal: NewEntityUID("Principal", "1"), - Action: NewEntityUID("Action", "action"), - Resource: NewEntityUID("Resource", "resource"), - Context: NewRecord(RecordMap{"foo": Long(43)}), + Principal: cedar.NewEntityUID("Principal", "1"), + Action: cedar.NewEntityUID("Action", "action"), + Resource: cedar.NewEntityUID("Resource", "resource"), + Context: cedar.NewRecord(cedar.RecordMap{"foo": cedar.Long(43)}), Want: true, DiagErr: 0, }, @@ -785,9 +785,9 @@ func TestIsAuthorized(t *testing.T) { tt := tt t.Run(tt.Name, func(t *testing.T) { t.Parallel() - ps, err := NewPolicySetFromBytes("policy.cedar", []byte(tt.Policy)) + ps, err := cedar.NewPolicySetFromBytes("policy.cedar", []byte(tt.Policy)) testutil.Equals(t, err != nil, tt.ParseErr) - ok, diag := ps.IsAuthorized(tt.Entities, Request{ + ok, diag := ps.IsAuthorized(tt.Entities, cedar.Request{ Principal: tt.Principal, Action: tt.Action, Resource: tt.Resource, diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 46f938e..51ddae9 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -6,6 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/types" ) @@ -973,11 +974,6 @@ func newInEval(lhs, rhs Evaler) Evaler { return &inEval{lhs: lhs, rhs: rhs} } -func hasKnown(known map[types.EntityUID]struct{}, k types.EntityUID) bool { - _, ok := known[k] - return ok -} - func entityInOne(env *Env, entity types.EntityUID, parent types.EntityUID) bool { key := inKey{a: entity, b: parent} if cached, ok := env.inCache[key]; ok { @@ -987,31 +983,28 @@ func entityInOne(env *Env, entity types.EntityUID, parent types.EntityUID) bool env.inCache[key] = result return result } + func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) bool { if entity == parent { return true } - var known map[types.EntityUID]struct{} + var known sets.MapSet[types.EntityUID] var todo []types.EntityUID var candidate = entity for { if fe, ok := env.Entities[candidate]; ok { - for _, k := range fe.Parents { - if k == parent { - return true - } + if fe.Parents.Contains(parent) { + return true } - for _, k := range fe.Parents { + fe.Parents.Iterate(func(k types.EntityUID) bool { p, ok := env.Entities[k] - if !ok || len(p.Parents) == 0 || k == entity || hasKnown(known, k) { - continue + if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { + return true } todo = append(todo, k) - if known == nil { - known = map[types.EntityUID]struct{}{} - } - known[k] = struct{}{} - } + known.Add(k) + return true + }) } if len(todo) == 0 { return false @@ -1020,31 +1013,27 @@ func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) b } } -func entityInSet(env *Env, entity types.EntityUID, parents map[types.EntityUID]struct{}) bool { - if _, ok := parents[entity]; ok { +func entityInSet(env *Env, entity types.EntityUID, parents types.EntityUIDSet) bool { + if parents.Contains(entity) { return true } - var known map[types.EntityUID]struct{} + var known sets.MapSet[types.EntityUID] var todo []types.EntityUID var candidate = entity for { if fe, ok := env.Entities[candidate]; ok { - for _, k := range fe.Parents { - if _, ok := parents[k]; ok { - return true - } + if fe.Parents.Intersection(parents).Len() > 0 { + return true } - for _, k := range fe.Parents { + fe.Parents.Iterate(func(k types.EntityUID) bool { p, ok := env.Entities[k] - if !ok || len(p.Parents) == 0 || k == entity || hasKnown(known, k) { - continue + if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { + return true } todo = append(todo, k) - if known == nil { - known = map[types.EntityUID]struct{}{} - } - known[k] = struct{}{} - } + known.Add(k) + return true + }) } if len(todo) == 0 { return false @@ -1072,14 +1061,14 @@ func doInEval(env *Env, lhs types.EntityUID, rhs types.Value) (types.Value, erro case types.EntityUID: return types.Boolean(entityInOne(env, lhs, rhsv)), nil case types.Set: - query := make(map[types.EntityUID]struct{}, rhsv.Len()) + query := sets.NewMapSet[types.EntityUID](rhsv.Len()) var err error rhsv.Iterate(func(rhv types.Value) bool { var e types.EntityUID if e, err = ValueToEntity(rhv); err != nil { return false } - query[e] = struct{}{} + query.Add(e) return true }) if err != nil { diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index c7a200a..54a3a04 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -9,6 +9,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/parser" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -1833,15 +1834,15 @@ func TestEntityIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - rhs := map[types.EntityUID]struct{}{} + var rhs sets.MapSet[types.EntityUID] for _, v := range tt.rhs { - rhs[strEnt(v)] = struct{}{} + rhs.Add(strEnt(v)) } entityMap := types.Entities{} for k, p := range tt.parents { - var ps []types.EntityUID + var ps sets.MapSet[types.EntityUID] for _, pp := range p { - ps = append(ps, strEnt(pp)) + ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ @@ -1860,10 +1861,10 @@ func TestEntityIn(t *testing.T) { entityMap := types.Entities{} for i := 0; i < 100; i++ { - p := []types.EntityUID{ + p := sets.NewMapSetFromSlice([]types.EntityUID{ types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), - } + }) uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") entityMap[uid1] = &types.Entity{ UID: uid1, @@ -1876,7 +1877,11 @@ func TestEntityIn(t *testing.T) { } } - res := entityInSet(&Env{Entities: entityMap}, types.NewEntityUID("0", "1"), map[types.EntityUID]struct{}{types.NewEntityUID("0", "3"): {}}) + res := entityInSet( + &Env{Entities: entityMap}, + types.NewEntityUID("0", "1"), + sets.NewMapSetFromSlice([]types.EntityUID{types.NewEntityUID("0", "3")}), + ) testutil.Equals(t, res, false) }) } @@ -2005,9 +2010,9 @@ func TestInNode(t *testing.T) { n := newInEval(tt.lhs, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - var ps []types.EntityUID + var ps sets.MapSet[types.EntityUID] for _, pp := range p { - ps = append(ps, strEnt(pp)) + ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ @@ -2145,9 +2150,9 @@ func TestIsInNode(t *testing.T) { n := newIsInEval(tt.lhs, tt.is, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - var ps []types.EntityUID + var ps sets.MapSet[types.EntityUID] for _, pp := range p { - ps = append(ps, strEnt(pp)) + ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ diff --git a/internal/eval/partial.go b/internal/eval/partial.go index a086a89..5f73fce 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -6,6 +6,7 @@ import ( "slices" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/types" ) @@ -140,10 +141,7 @@ func partialScopeEval(env *Env, ent types.Value, in ast.IsScopeNode) (evaled boo case ast.ScopeTypeIn: return true, entityInOne(env, e, t.Entity) case ast.ScopeTypeInSet: - set := make(map[types.EntityUID]struct{}, len(t.Entities)) - for _, e := range t.Entities { - set[e] = struct{}{} - } + set := sets.NewMapSetFromSlice(t.Entities) return true, entityInSet(env, e, set) case ast.ScopeTypeIs: return true, e.Type == t.Type diff --git a/internal/sets/mapset.go b/internal/sets/mapset.go index 5a0ee5a..66bcc80 100644 --- a/internal/sets/mapset.go +++ b/internal/sets/mapset.go @@ -1,8 +1,4 @@ -<<<<<<<< HEAD:internal/hashset.go -package internal -======== package sets ->>>>>>>> 0f61d5f (fixup):internal/sets/mapset.go import ( "encoding/json" @@ -139,6 +135,9 @@ func (h MapSet[T]) Equal(o MapSet[T]) bool { // MarshalJSON serializes a MapSet as a JSON array. Order is non-deterministic. func (h MapSet[T]) MarshalJSON() ([]byte, error) { + if h.m == nil { + return []byte("[]"), nil + } return json.Marshal(h.Slice()) } diff --git a/internal/sets/mapset_test.go b/internal/sets/mapset_test.go index c41b8c1..fab97af 100644 --- a/internal/sets/mapset_test.go +++ b/internal/sets/mapset_test.go @@ -1,8 +1,4 @@ -<<<<<<<< HEAD:internal/hashset_test.go -package internal -======== package sets ->>>>>>>> 0f61d5f (fixup):internal/sets/mapset_test.go import ( "encoding/json" diff --git a/types.go b/types.go index 4a5b364..6378b55 100644 --- a/types.go +++ b/types.go @@ -32,6 +32,7 @@ type String = types.String type Entities = types.Entities type Entity = types.Entity type EntityType = types.EntityType +type EntityUIDSet = types.EntityUIDSet type Pattern = types.Pattern type Wildcard = types.Wildcard @@ -85,6 +86,17 @@ func NewEntityUID(typ EntityType, id String) EntityUID { return types.NewEntityUID(typ, id) } +// NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the set can be passed as an +// argument, as in the argument to make() for a map type. +func NewEntityUIDSet(args ...int) EntityUIDSet { + return types.NewEntityUIDSet(args...) +} + +// NewEntityUIDSetFromSlice creates an EntityUIDSet of size len(items) and calls AddSlice(items) on it. +func NewEntityUIDSetFromSlice(items []EntityUID) EntityUIDSet { + return types.NewEntityUIDSetFromSlice(items) +} + // NewPattern permits for the programmatic construction of a Pattern out of a slice of pattern components. // The pattern components may be one of string, cedar.String, or cedar.Wildcard. Any other types will // cause a panic. diff --git a/types/entities.go b/types/entities.go index 6cb8f7c..27fb854 100644 --- a/types/entities.go +++ b/types/entities.go @@ -13,33 +13,6 @@ import ( // the Entity (it must be the same as the UID within the Entity itself.) type Entities map[EntityUID]*Entity -// An Entity defines the parents and attributes for an EntityUID. -type Entity struct { - UID EntityUID `json:"uid"` - Parents []EntityUID `json:"parents"` - Attributes Record `json:"attrs"` -} - -// MarshalJSON serializes Entity as a JSON object, using the implicit form of EntityUID encoding to match the Rust -// SDK's behavior. -func (e Entity) MarshalJSON() ([]byte, error) { - parents := make([]ImplicitlyMarshaledEntityUID, len(e.Parents)) - for i, p := range e.Parents { - parents[i] = ImplicitlyMarshaledEntityUID(p) - } - - m := struct { - UID ImplicitlyMarshaledEntityUID `json:"uid"` - Parents []ImplicitlyMarshaledEntityUID `json:"parents"` - Attributes Record `json:"attrs"` - }{ - ImplicitlyMarshaledEntityUID(e.UID), - parents, - e.Attributes, - } - return json.Marshal(m) -} - func (e Entities) MarshalJSON() ([]byte, error) { s := maps.Values(e) slices.SortFunc(s, func(a, b *Entity) int { diff --git a/types/entities_test.go b/types/entities_test.go index bcff61e..6972b08 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -1,9 +1,11 @@ package types_test import ( + "bytes" "encoding/json" "testing" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -24,6 +26,16 @@ func TestEntities(t *testing.T) { } +func assertJSONEquals(t *testing.T, e any, want string) { + b, err := json.MarshalIndent(e, "", "\t") + testutil.OK(t, err) + + var wantBuf bytes.Buffer + err = json.Indent(&wantBuf, []byte(want), "", "\t") + testutil.OK(t, err) + testutil.Equals(t, string(b), wantBuf.String()) +} + func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { @@ -31,19 +43,23 @@ func TestEntitiesJSON(t *testing.T) { e := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: []types.EntityUID{}, + Parents: sets.MapSet[types.EntityUID]{}, Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } ent2 := &types.Entity{ UID: types.NewEntityUID("Type", "id2"), - Parents: []types.EntityUID{ent.UID}, + Parents: sets.NewMapSetFromSlice([]types.EntityUID{ent.UID}), Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } e[ent.UID] = ent e[ent2.UID] = ent2 - b, err := e.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}},{"uid":{"type":"Type","id":"id2"},"parents":[{"type":"Type","id":"id"}],"attrs":{"key":42}}]`) + assertJSONEquals( + t, + e, + `[ + {"uid": {"type": "Type", "id": "id"}, "parents": [], "attrs": {"key": 42}}, + {"uid": {"type": "Type" ,"id" :"id2"}, "parents": [{"type":"Type","id":"id"}], "attrs": {"key": 42}} + ]`) }) t.Run("Unmarshal", func(t *testing.T) { @@ -55,7 +71,7 @@ func TestEntitiesJSON(t *testing.T) { want := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: []types.EntityUID{}, + Parents: sets.MapSet[types.EntityUID]{}, Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } want[ent.UID] = ent @@ -69,24 +85,3 @@ func TestEntitiesJSON(t *testing.T) { testutil.Error(t, err) }) } - -func TestEntityIsZero(t *testing.T) { - t.Parallel() - tests := []struct { - name string - uid types.EntityUID - want bool - }{ - {"empty", types.EntityUID{}, true}, - {"empty-type", types.NewEntityUID("one", ""), false}, - {"empty-id", types.NewEntityUID("", "one"), false}, - {"not-empty", types.NewEntityUID("one", "two"), false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - testutil.Equals(t, tt.uid.IsZero(), tt.want) - }) - } -} diff --git a/types/entity.go b/types/entity.go new file mode 100644 index 0000000..76b58c5 --- /dev/null +++ b/types/entity.go @@ -0,0 +1,31 @@ +package types + +import "encoding/json" + +// An Entity defines the parents and attributes for an EntityUID. +type Entity struct { + UID EntityUID `json:"uid"` + Parents EntityUIDSet `json:"parents"` + Attributes Record `json:"attrs"` +} + +// MarshalJSON serializes Entity as a JSON object, using the implicit form of EntityUID encoding to match the Rust +// SDK's behavior. +func (e Entity) MarshalJSON() ([]byte, error) { + parents := make([]ImplicitlyMarshaledEntityUID, 0, e.Parents.Len()) + e.Parents.Iterate(func(p EntityUID) bool { + parents = append(parents, ImplicitlyMarshaledEntityUID(p)) + return true + }) + + m := struct { + UID ImplicitlyMarshaledEntityUID `json:"uid"` + Parents []ImplicitlyMarshaledEntityUID `json:"parents"` + Attributes Record `json:"attrs"` + }{ + ImplicitlyMarshaledEntityUID(e.UID), + parents, + e.Attributes, + } + return json.Marshal(m) +} diff --git a/types/entity_test.go b/types/entity_test.go new file mode 100644 index 0000000..b9f6e0d --- /dev/null +++ b/types/entity_test.go @@ -0,0 +1,29 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestEntityIsZero(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uid types.EntityUID + want bool + }{ + {"empty", types.EntityUID{}, true}, + {"empty-type", types.NewEntityUID("one", ""), false}, + {"empty-id", types.NewEntityUID("", "one"), false}, + {"not-empty", types.NewEntityUID("one", "two"), false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, tt.uid.IsZero(), tt.want) + }) + } +} diff --git a/types/entity_uid.go b/types/entity_uid.go index 0c06e57..19946b8 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -4,6 +4,8 @@ import ( "encoding/json" "hash/fnv" "strconv" + + "github.com/cedar-policy/cedar-go/internal/sets" ) // Path is a series of idents separated by :: @@ -91,3 +93,16 @@ func (i ImplicitlyMarshaledEntityUID) MarshalJSON() ([]byte, error) { }{i.Type, i.ID} return json.Marshal(s) } + +type EntityUIDSet = sets.MapSet[EntityUID] + +// NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the EntityUIDSet can be passed +// as an argument, as in the argument to make() for a map type. +func NewEntityUIDSet(args ...int) EntityUIDSet { + return sets.NewMapSet[EntityUID](args...) +} + +// NewEntityUIDSetFromSlice creates a EntityUIDSet of size len(items) and calls AddSlice(items) on it. +func NewEntityUIDSetFromSlice(items []EntityUID) EntityUIDSet { + return sets.NewMapSetFromSlice[EntityUID](items) +} diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 259dd90..81ab55b 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -31,3 +31,37 @@ func TestEntity(t *testing.T) { testutil.Equals(t, string(types.EntityUID{"type", "id"}.MarshalCedar()), `type::"id"`) }) } + +func TestEntityUIDSet(t *testing.T) { + t.Parallel() + + t.Run("new empty set", func(t *testing.T) { + emptySets := []types.EntityUIDSet{ + types.NewEntityUIDSet(), + types.NewEntityUIDSet(0), + types.NewEntityUIDSet(1), + types.NewEntityUIDSetFromSlice(nil), + types.NewEntityUIDSetFromSlice([]types.EntityUID{}), + } + + for _, es := range emptySets { + testutil.Equals(t, es.Len(), 0) + testutil.Equals(t, emptySets[0].Equal(es), true) + testutil.Equals(t, es.Equal(emptySets[0]), true) + } + }) + + t.Run("new set from slice", func(t *testing.T) { + a := types.NewEntityUID("typeA", "1") + b := types.NewEntityUID("typeB", "2") + o := types.NewEntityUID("typeO", "2") + s1 := types.NewEntityUIDSet() + s1.Add(a) + s1.Add(b) + s1.Add(o) + + s2 := types.NewEntityUIDSetFromSlice([]types.EntityUID{o, b, a}) + + testutil.Equals(t, s1.Equal(s2), true) + }) +} From c2ef4304968f7f6caa18e2327461f687246df3df Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 20 Sep 2024 11:57:57 -0700 Subject: [PATCH 03/17] cedar-go/internal/parser: use a MapSet[string] instead of a map[string]struct{} for annotation names Signed-off-by: Patrick Jakubowski --- internal/parser/cedar_unmarshal.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index cce1c17..3745ff0 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -8,6 +8,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/types" ) @@ -127,10 +128,10 @@ func (p *parser) errorf(s string, args ...interface{}) error { func (p *parser) annotations() (ast.Annotations, error) { var res ast.Annotations - known := map[string]struct{}{} + var known sets.MapSet[string] for p.peek().Text == "@" { p.advance() - err := p.annotation(&res, known) + err := p.annotation(&res, &known) if err != nil { return res, err } @@ -139,11 +140,11 @@ func (p *parser) annotations() (ast.Annotations, error) { } -func (p *parser) annotation(a *ast.Annotations, known map[string]struct{}) error { +func (p *parser) annotation(a *ast.Annotations, known *sets.MapSet[string]) error { var err error t := p.advance() - // As of 2024-09-13, the ability to use reserved keywords is not documented in the Cedar schema. The ability to use - // reserved keywords was added in this commit: + // As of 2024-09-13, the ability to use reserved keywords for annotation keys is not documented in the Cedar schema. + // This ability was added to the Rust implementation in this commit: // https://github.com/cedar-policy/cedar/commit/5f62c6df06b59abc5634d6668198a826839c6fb7 if !(t.isIdent() || t.isReservedKeyword()) { return p.errorf("expected ident or reserved keyword") @@ -152,10 +153,10 @@ func (p *parser) annotation(a *ast.Annotations, known map[string]struct{}) error if err = p.exact("("); err != nil { return err } - if _, ok := known[name]; ok { + if known.Contains(name) { return p.errorf("duplicate annotation: @%s", name) } - known[name] = struct{}{} + known.Add(name) t = p.advance() if !t.isString() { return p.errorf("expected string") From 58729be26da90ecf2e29de1de984676f20c109d0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 20 Sep 2024 11:59:31 -0700 Subject: [PATCH 04/17] internal/parser: use a MapSet[string] instead of a map[string]struct{} for record keys Signed-off-by: Patrick Jakubowski --- internal/parser/cedar_unmarshal.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 3745ff0..da26258 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -816,7 +816,7 @@ func (p *parser) expressions(endOfListMarker string) ([]ast.Node, error) { func (p *parser) record() (ast.Node, error) { var res ast.Node var elements ast.Pairs - known := map[string]struct{}{} + var known sets.MapSet[string] for { t := p.peek() if t.Text == "}" { @@ -833,10 +833,10 @@ func (p *parser) record() (ast.Node, error) { return res, err } - if _, ok := known[k]; ok { + if known.Contains(k) { return res, p.errorf("duplicate key: %v", k) } - known[k] = struct{}{} + known.Add(k) elements = append(elements, ast.Pair{Key: types.String(k), Value: v}) } } From 32380ba8f3e3fedc2a21a6c8984a329f0fddb6a5 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 20 Sep 2024 12:04:05 -0700 Subject: [PATCH 05/17] x/exp/batch: use a MapSet[types.String] instead of map[types.String]{} for detecting unbound or unused variables Signed-off-by: Patrick Jakubowski --- x/exp/batch/batch.go | 27 +++++++++++++++++---------- x/exp/batch/batch_test.go | 19 ++++++++++--------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/x/exp/batch/batch.go b/x/exp/batch/batch.go index 5de52f1..3fd9edb 100644 --- a/x/exp/batch/batch.go +++ b/x/exp/batch/batch.go @@ -10,6 +10,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/types" ) @@ -103,18 +104,24 @@ var errInvalidPart = fmt.Errorf("invalid part") // The result passed to the callback must be used / cloned immediately and not modified. func Authorize(ctx context.Context, ps *cedar.PolicySet, entityMap types.Entities, request Request, cb Callback) error { be := &batchEvaler{} - found := map[types.String]struct{}{} - findVariables(found, request.Principal) - findVariables(found, request.Action) - findVariables(found, request.Resource) - findVariables(found, request.Context) - for key := range found { + var found sets.MapSet[types.String] + findVariables(&found, request.Principal) + findVariables(&found, request.Action) + findVariables(&found, request.Resource) + findVariables(&found, request.Context) + var err error + found.Iterate(func(key types.String) bool { if _, ok := request.Variables[key]; !ok { - return fmt.Errorf("%w: %v", errUnboundVariable, key) + err = fmt.Errorf("%w: %v", errUnboundVariable, key) + return false } + return true + }) + if err != nil { + return err } for k := range request.Variables { - if _, ok := found[k]; !ok { + if !found.Contains(k) { return fmt.Errorf("%w: %v", errUnusedVariable, k) } } @@ -375,11 +382,11 @@ func cloneSub(r types.Value, k types.String, v types.Value) (types.Value, bool) return r, false } -func findVariables(found map[types.String]struct{}, r types.Value) { +func findVariables(found *sets.MapSet[types.String], r types.Value) { switch t := r.(type) { case types.EntityUID: if key, ok := eval.ToVariable(t); ok { - found[key] = struct{}{} + found.Add(key) } case types.Record: t.Iterate(func(_ types.String, vv types.Value) bool { diff --git a/x/exp/batch/batch_test.go b/x/exp/batch/batch_test.go index d80bc07..1c4cf48 100644 --- a/x/exp/batch/batch_test.go +++ b/x/exp/batch/batch_test.go @@ -10,6 +10,7 @@ import ( "github.com/cedar-policy/cedar-go" publicast "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -663,22 +664,22 @@ func TestFindVariables(t *testing.T) { tests := []struct { name string in types.Value - out map[types.String]struct{} + out []types.String }{ - {"record", types.NewRecord(types.RecordMap{"key": Variable("bananas")}), map[types.String]struct{}{"bananas": {}}}, - {"set", types.NewSet([]types.Value{Variable("bananas")}), map[types.String]struct{}{"bananas": {}}}, - {"dupes", types.NewSet([]types.Value{Variable("bananas"), Variable("bananas")}), map[types.String]struct{}{"bananas": {}}}, - {"none", types.String("test"), map[types.String]struct{}{}}, - {"multi", types.NewSet([]types.Value{Variable("bananas"), Variable("test")}), map[types.String]struct{}{"bananas": {}, "test": {}}}, + {"record", types.NewRecord(types.RecordMap{"key": Variable("bananas")}), []types.String{"bananas"}}, + {"set", types.NewSet([]types.Value{Variable("bananas")}), []types.String{"bananas"}}, + {"dupes", types.NewSet([]types.Value{Variable("bananas"), Variable("bananas")}), []types.String{"bananas"}}, + {"none", types.String("test"), nil}, + {"multi", types.NewSet([]types.Value{Variable("bananas"), Variable("test")}), []types.String{"bananas", "test"}}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out := map[types.String]struct{}{} - findVariables(out, tt.in) - testutil.Equals(t, out, tt.out) + var out sets.MapSet[types.String] + findVariables(&out, tt.in) + testutil.Equals(t, out, sets.NewMapSetFromSlice(tt.out)) }) } From 2e445e2bbee91a1c520d817692358efe070d37c6 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 14:39:42 -0700 Subject: [PATCH 06/17] internal/mapset: move mapset to its own package, rename functions appropriately, and return pointers to underscore risk of copy by value Signed-off-by: Patrick Jakubowski --- authorize_test.go | 10 +-- internal/eval/evalers.go | 10 +-- internal/eval/evalers_test.go | 26 +++--- internal/eval/partial.go | 4 +- internal/{sets => mapset}/mapset.go | 31 +++---- internal/{sets => mapset}/mapset_test.go | 101 +++++++++++++---------- internal/parser/cedar_unmarshal.go | 8 +- policy.go | 2 +- types.go | 4 +- types/entities_test.go | 8 +- types/entity_uid.go | 12 +-- types/entity_uid_test.go | 2 +- types/set.go | 2 +- x/exp/batch/batch.go | 6 +- x/exp/batch/batch_test.go | 8 +- 15 files changed, 123 insertions(+), 111 deletions(-) rename internal/{sets => mapset}/mapset.go (79%) rename internal/{sets => mapset}/mapset_test.go (76%) diff --git a/authorize_test.go b/authorize_test.go index 7e43a62..fa80ace 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -139,7 +139,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cuzco: &cedar.Entity{ UID: cuzco, - Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("parent", "bob")}), + Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("parent", "bob")}), }, }, Principal: cuzco, @@ -166,7 +166,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cuzco: &cedar.Entity{ UID: cuzco, - Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("team", "osiris")}), + Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("team", "osiris")}), }, }, Principal: cuzco, @@ -193,7 +193,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ dropTable: &cedar.Entity{ UID: dropTable, - Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), + Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), }, }, Principal: cuzco, @@ -209,7 +209,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ dropTable: &cedar.Entity{ UID: dropTable, - Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), + Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), }, }, Principal: cuzco, @@ -754,7 +754,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cedar.NewEntityUID("Resource", "table"): &cedar.Entity{ UID: cedar.NewEntityUID("Resource", "table"), - Parents: cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("Parent", "id")}), + Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("Parent", "id")}), }, }, Principal: cedar.NewEntityUID("Actor", "cuzco"), diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 51ddae9..beddaa4 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/types" ) @@ -988,7 +988,7 @@ func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) b if entity == parent { return true } - var known sets.MapSet[types.EntityUID] + var known mapset.MapSet[types.EntityUID] var todo []types.EntityUID var candidate = entity for { @@ -1013,11 +1013,11 @@ func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) b } } -func entityInSet(env *Env, entity types.EntityUID, parents types.EntityUIDSet) bool { +func entityInSet(env *Env, entity types.EntityUID, parents *types.EntityUIDSet) bool { if parents.Contains(entity) { return true } - var known sets.MapSet[types.EntityUID] + var known mapset.MapSet[types.EntityUID] var todo []types.EntityUID var candidate = entity for { @@ -1061,7 +1061,7 @@ func doInEval(env *Env, lhs types.EntityUID, rhs types.Value) (types.Value, erro case types.EntityUID: return types.Boolean(entityInOne(env, lhs, rhsv)), nil case types.Set: - query := sets.NewMapSet[types.EntityUID](rhsv.Len()) + query := mapset.New[types.EntityUID](rhsv.Len()) var err error rhsv.Iterate(func(rhv types.Value) bool { var e types.EntityUID diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 54a3a04..079ed6a 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -8,8 +8,8 @@ import ( "time" "github.com/cedar-policy/cedar-go/internal/consts" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/internal/parser" - "github.com/cedar-policy/cedar-go/internal/sets" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -1486,7 +1486,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run("not quadratic", func(t *testing.T) { t.Parallel() - // Make two totally disjoint sets to force a worst case search + // Make two totally disjoint mapset to force a worst case search setSize := 200000 set1 := make([]types.Value, setSize) set2 := make([]types.Value, setSize) @@ -1834,20 +1834,20 @@ func TestEntityIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - var rhs sets.MapSet[types.EntityUID] + rhs := types.NewEntityUIDSet(len(tt.rhs)) for _, v := range tt.rhs { rhs.Add(strEnt(v)) } entityMap := types.Entities{} for k, p := range tt.parents { - var ps sets.MapSet[types.EntityUID] + ps := types.NewEntityUIDSet(len(p)) for _, pp := range p { ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: ps, + Parents: *ps, } } res := entityInSet(&Env{Entities: entityMap}, strEnt(tt.lhs), rhs) @@ -1861,26 +1861,26 @@ func TestEntityIn(t *testing.T) { entityMap := types.Entities{} for i := 0; i < 100; i++ { - p := sets.NewMapSetFromSlice([]types.EntityUID{ + p := mapset.FromSlice([]types.EntityUID{ types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), }) uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") entityMap[uid1] = &types.Entity{ UID: uid1, - Parents: p, + Parents: *p, } uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") entityMap[uid2] = &types.Entity{ UID: uid2, - Parents: p, + Parents: *p, } } res := entityInSet( &Env{Entities: entityMap}, types.NewEntityUID("0", "1"), - sets.NewMapSetFromSlice([]types.EntityUID{types.NewEntityUID("0", "3")}), + mapset.FromSlice([]types.EntityUID{types.NewEntityUID("0", "3")}), ) testutil.Equals(t, res, false) }) @@ -2010,14 +2010,14 @@ func TestInNode(t *testing.T) { n := newInEval(tt.lhs, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - var ps sets.MapSet[types.EntityUID] + ps := types.NewEntityUIDSet(len(p)) for _, pp := range p { ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: ps, + Parents: *ps, } } ec := InitEnv(&Env{Entities: entityMap}) @@ -2150,14 +2150,14 @@ func TestIsInNode(t *testing.T) { n := newIsInEval(tt.lhs, tt.is, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - var ps sets.MapSet[types.EntityUID] + ps := types.NewEntityUIDSet(len(p)) for _, pp := range p { ps.Add(strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: ps, + Parents: *ps, } } ec := InitEnv(&Env{Entities: entityMap}) diff --git a/internal/eval/partial.go b/internal/eval/partial.go index 5f73fce..5b08c24 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -6,7 +6,7 @@ import ( "slices" "github.com/cedar-policy/cedar-go/internal/ast" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/types" ) @@ -141,7 +141,7 @@ func partialScopeEval(env *Env, ent types.Value, in ast.IsScopeNode) (evaled boo case ast.ScopeTypeIn: return true, entityInOne(env, e, t.Entity) case ast.ScopeTypeInSet: - set := sets.NewMapSetFromSlice(t.Entities) + set := mapset.FromSlice(t.Entities) return true, entityInSet(env, e, set) case ast.ScopeTypeIs: return true, e.Type == t.Type diff --git a/internal/sets/mapset.go b/internal/mapset/mapset.go similarity index 79% rename from internal/sets/mapset.go rename to internal/mapset/mapset.go index 66bcc80..3de1e75 100644 --- a/internal/sets/mapset.go +++ b/internal/mapset/mapset.go @@ -1,4 +1,4 @@ -package sets +package mapset import ( "encoding/json" @@ -12,16 +12,16 @@ import ( var peppercorn = struct{}{} // MapSet is a struct that adds some convenience to the otherwise cumbersome map[T]struct{} idiom used in Go to -// implement sets of comparable types. +// implement mapset of comparable types. type MapSet[T comparable] struct { m map[T]struct{} } -// NewMapSet returns a MapSet ready for use. Optionally, a desired size for the MapSet can be passed as an argument, +// New returns a MapSet ready for use. Optionally, a desired size for the MapSet can be passed as an argument, // as in the argument to make() for a map type. -func NewMapSet[T comparable](args ...int) MapSet[T] { +func New[T comparable](args ...int) *MapSet[T] { if len(args) > 1 { - panic(fmt.Sprintf("too many arguments passed to NewMapSet(). got: %v, expected 0 or 1", len(args))) + panic(fmt.Sprintf("too many arguments passed to New(). got: %v, expected 0 or 1", len(args))) } var size int @@ -29,17 +29,12 @@ func NewMapSet[T comparable](args ...int) MapSet[T] { size = args[0] } - var m map[T]struct{} - if size > 0 { - m = make(map[T]struct{}, size) - } - - return MapSet[T]{m: m} + return &MapSet[T]{m: make(map[T]struct{}, size)} } -// NewMapSetFromSlice creates a MapSet of size len(items) and calls AddSlice(items) on it. -func NewMapSetFromSlice[T comparable](items []T) MapSet[T] { - h := NewMapSet[T](len(items)) +// FromSlice creates a MapSet of size len(items) and calls AddSlice(items) on it. +func FromSlice[T comparable](items []T) *MapSet[T] { + h := New[T](len(items)) h.AddSlice(items) return h } @@ -87,8 +82,8 @@ func (h MapSet[T]) Contains(item T) bool { } // Intersection returns the items common to both h and o. -func (h MapSet[T]) Intersection(o MapSet[T]) MapSet[T] { - intersection := NewMapSet[T]() +func (h MapSet[T]) Intersection(o *MapSet[T]) *MapSet[T] { + intersection := New[T]() for item := range h.m { if o.Contains(item) { intersection.Add(item) @@ -120,7 +115,7 @@ func (h MapSet[T]) Len() int { } // Equal returns whether the same items exist in both h and o -func (h MapSet[T]) Equal(o MapSet[T]) bool { +func (h MapSet[T]) Equal(o *MapSet[T]) bool { if len(h.m) != len(o.m) { return false } @@ -148,6 +143,6 @@ func (h *MapSet[T]) UnmarshalJSON(b []byte) error { return err } - *h = NewMapSetFromSlice(s) + *h = *FromSlice(s) return nil } diff --git a/internal/sets/mapset_test.go b/internal/mapset/mapset_test.go similarity index 76% rename from internal/sets/mapset_test.go rename to internal/mapset/mapset_test.go index fab97af..4088719 100644 --- a/internal/sets/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -1,4 +1,4 @@ -package sets +package mapset import ( "encoding/json" @@ -8,36 +8,33 @@ import ( "github.com/cedar-policy/cedar-go/internal/testutil" ) -func mustNotContain[T comparable](t *testing.T, s MapSet[T], item T) { +func mustNotContain[T comparable](t *testing.T, s *MapSet[T], item T) { testutil.FatalIf(t, s.Contains(item), "set %v unexpectedly contained item %v", s, 1) } func TestHashSet(t *testing.T) { t.Run("empty set contains nothing", func(t *testing.T) { - s := MapSet[int]{} - mustNotContain(t, s, 1) - - s = NewMapSet[int]() + s := New[int]() mustNotContain(t, s, 1) - s = NewMapSet[int](10) + s = New[int](10) mustNotContain(t, s, 1) }) t.Run("add => contains", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.Add(1) testutil.Equals(t, s.Contains(1), true) }) t.Run("add twice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.Add(1), true) testutil.Equals(t, s.Add(1), false) }) t.Run("add slice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.AddSlice([]int{1, 2}) testutil.Equals(t, s.Contains(1), true) testutil.Equals(t, s.Contains(2), true) @@ -45,45 +42,45 @@ func TestHashSet(t *testing.T) { }) t.Run("add same slice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{1, 2}), false) }) t.Run("add disjoint slices", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{3, 4}), true) testutil.Equals(t, s.AddSlice([]int{1, 2, 3, 4}), false) }) t.Run("add overlapping slices", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{2, 3}), true) testutil.Equals(t, s.AddSlice([]int{1, 3}), false) }) t.Run("remove nonexistent", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.Remove(1), false) }) t.Run("remove existing", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.Add(1) testutil.Equals(t, s.Remove(1), true) }) t.Run("remove => !contains", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.Add(1) s.Remove(1) testutil.FatalIf(t, s.Contains(1), "set unexpectedly contained item") }) t.Run("remove slice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.AddSlice([]int{1, 2, 3}) s.RemoveSlice([]int{1, 2}) mustNotContain(t, s, 1) @@ -92,19 +89,19 @@ func TestHashSet(t *testing.T) { }) t.Run("remove non-existent slice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) }) t.Run("remove overlapping slice", func(t *testing.T) { - s := MapSet[int]{} + s := New[int]() s.Add(1) testutil.Equals(t, s.RemoveSlice([]int{1, 2}), true) testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) }) t.Run("new from slice", func(t *testing.T) { - s := NewMapSetFromSlice([]int{1, 2, 2, 3}) + s := FromSlice([]int{1, 2, 2, 3}) testutil.Equals(t, s.Len(), 3) testutil.Equals(t, s.Contains(1), true) testutil.Equals(t, s.Contains(2), true) @@ -112,27 +109,24 @@ func TestHashSet(t *testing.T) { }) t.Run("slice", func(t *testing.T) { - s := MapSet[int]{} - testutil.Equals(t, s.Slice(), nil) - - s = NewMapSet[int]() - testutil.Equals(t, s.Slice(), nil) + s := New[int]() + testutil.Equals(t, s.Slice(), []int{}) - s = NewMapSet[int](10) + s = New[int](10) testutil.Equals(t, s.Slice(), []int{}) inSlice := []int{1, 2, 3} - s = NewMapSetFromSlice(inSlice) + s = FromSlice(inSlice) outSlice := s.Slice() slices.Sort(outSlice) testutil.Equals(t, inSlice, outSlice) }) t.Run("equal", func(t *testing.T) { - s1 := NewMapSetFromSlice([]int{1, 2, 3}) + s1 := FromSlice([]int{1, 2, 3}) testutil.Equals(t, s1.Equal(s1), true) - s2 := NewMapSetFromSlice([]int{1, 2, 3}) + s2 := FromSlice([]int{1, 2, 3}) testutil.Equals(t, s1.Equal(s2), true) s2.Add(4) @@ -147,9 +141,9 @@ func TestHashSet(t *testing.T) { }) t.Run("iterate", func(t *testing.T) { - s1 := NewMapSetFromSlice([]int{1, 2, 3}) + s1 := FromSlice([]int{1, 2, 3}) - var s2 MapSet[int] + s2 := New[int]() s1.Iterate(func(item int) bool { s2.Add(item) return true @@ -159,7 +153,7 @@ func TestHashSet(t *testing.T) { }) t.Run("iterate break early", func(t *testing.T) { - s1 := NewMapSetFromSlice([]int{1, 2, 3}) + s1 := FromSlice([]int{1, 2, 3}) i := 0 var items []int @@ -180,19 +174,19 @@ func TestHashSet(t *testing.T) { }) t.Run("intersection with overlap", func(t *testing.T) { - s1 := NewMapSetFromSlice([]int{1, 2, 3}) - s2 := NewMapSetFromSlice([]int{2, 3, 4}) + s1 := FromSlice([]int{1, 2, 3}) + s2 := FromSlice([]int{2, 3, 4}) s3 := s1.Intersection(s2) - testutil.Equals(t, s3, NewMapSetFromSlice([]int{2, 3})) + testutil.Equals(t, s3, FromSlice([]int{2, 3})) s4 := s1.Intersection(s2) - testutil.Equals(t, s4, NewMapSetFromSlice([]int{2, 3})) + testutil.Equals(t, s4, FromSlice([]int{2, 3})) }) t.Run("intersection disjoint", func(t *testing.T) { - s1 := NewMapSetFromSlice([]int{1, 2}) - s2 := NewMapSetFromSlice([]int{3, 4}) + s1 := FromSlice([]int{1, 2}) + s2 := FromSlice([]int{3, 4}) s3 := s1.Intersection(s2) testutil.Equals(t, s3.Len(), 0) @@ -202,7 +196,7 @@ func TestHashSet(t *testing.T) { }) t.Run("encode nil set", func(t *testing.T) { - s := NewMapSet[int]() + s := New[int]() out, err := json.Marshal(s) @@ -211,7 +205,7 @@ func TestHashSet(t *testing.T) { }) t.Run("encode json", func(t *testing.T) { - s := NewMapSetFromSlice([]int{1, 2, 3}) + s := FromSlice([]int{1, 2, 3}) out, err := json.Marshal(s) @@ -232,7 +226,7 @@ func TestHashSet(t *testing.T) { var s1 MapSet[int] err := s1.UnmarshalJSON([]byte("[2,3,1,2]")) testutil.OK(t, err) - testutil.Equals(t, s1, NewMapSetFromSlice([]int{1, 2, 3})) + testutil.Equals(t, &s1, FromSlice([]int{1, 2, 3})) }) t.Run("decode json empty", func(t *testing.T) { @@ -267,6 +261,29 @@ func TestHashSet(t *testing.T) { } }() - NewMapSet[int](0, 1) + New[int](0, 1) + }) + + // The zero value MapSet is usable, but care must be taken to ensure that it is not mutated when passed by value + // because those mutations may or may not be reflected in the caller's version of the MapSet. + t.Run("zero value", func(t *testing.T) { + s := MapSet[int]{} + mustNotContain(t, &s, 1) + testutil.Equals(t, s.Slice(), nil) + + addByValue := func(m MapSet[int], val int) { + m.Add(val) + } + + // Calling addByValue when s is still the zero value results in no mutation + addByValue(s, 1) + testutil.Equals(t, s.Len(), 0) + + // However, calling addByValue after the internal map in s has been initialized results in mutation + s.Add(0) + testutil.Equals(t, s.Len(), 1) + addByValue(s, 1) + testutil.Equals(t, s.Len(), 2) }) + } diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index da26258..1971b40 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -8,7 +8,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/types" ) @@ -128,7 +128,7 @@ func (p *parser) errorf(s string, args ...interface{}) error { func (p *parser) annotations() (ast.Annotations, error) { var res ast.Annotations - var known sets.MapSet[string] + var known mapset.MapSet[string] for p.peek().Text == "@" { p.advance() err := p.annotation(&res, &known) @@ -140,7 +140,7 @@ func (p *parser) annotations() (ast.Annotations, error) { } -func (p *parser) annotation(a *ast.Annotations, known *sets.MapSet[string]) error { +func (p *parser) annotation(a *ast.Annotations, known *mapset.MapSet[string]) error { var err error t := p.advance() // As of 2024-09-13, the ability to use reserved keywords for annotation keys is not documented in the Cedar schema. @@ -816,7 +816,7 @@ func (p *parser) expressions(endOfListMarker string) ([]ast.Node, error) { func (p *parser) record() (ast.Node, error) { var res ast.Node var elements ast.Pairs - var known sets.MapSet[string] + var known mapset.MapSet[string] for { t := p.peek() if t.Text == "}" { diff --git a/policy.go b/policy.go index adcffad..9e21c1c 100644 --- a/policy.go +++ b/policy.go @@ -97,7 +97,7 @@ func (p *Policy) Position() Position { return Position(p.ast.Position) } -// SetFilename sets the filename of this policy. +// SetFilename mapset the filename of this policy. func (p *Policy) SetFilename(fileName string) { p.ast.Position.Filename = fileName } diff --git a/types.go b/types.go index 6378b55..be72d78 100644 --- a/types.go +++ b/types.go @@ -88,12 +88,12 @@ func NewEntityUID(typ EntityType, id String) EntityUID { // NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the set can be passed as an // argument, as in the argument to make() for a map type. -func NewEntityUIDSet(args ...int) EntityUIDSet { +func NewEntityUIDSet(args ...int) *EntityUIDSet { return types.NewEntityUIDSet(args...) } // NewEntityUIDSetFromSlice creates an EntityUIDSet of size len(items) and calls AddSlice(items) on it. -func NewEntityUIDSetFromSlice(items []EntityUID) EntityUIDSet { +func NewEntityUIDSetFromSlice(items []EntityUID) *EntityUIDSet { return types.NewEntityUIDSetFromSlice(items) } diff --git a/types/entities_test.go b/types/entities_test.go index 6972b08..848d81b 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -5,7 +5,7 @@ import ( "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -43,12 +43,12 @@ func TestEntitiesJSON(t *testing.T) { e := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: sets.MapSet[types.EntityUID]{}, + Parents: mapset.MapSet[types.EntityUID]{}, Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } ent2 := &types.Entity{ UID: types.NewEntityUID("Type", "id2"), - Parents: sets.NewMapSetFromSlice([]types.EntityUID{ent.UID}), + Parents: *mapset.FromSlice([]types.EntityUID{ent.UID}), Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } e[ent.UID] = ent @@ -71,7 +71,7 @@ func TestEntitiesJSON(t *testing.T) { want := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: sets.MapSet[types.EntityUID]{}, + Parents: *types.NewEntityUIDSet(), Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } want[ent.UID] = ent diff --git a/types/entity_uid.go b/types/entity_uid.go index 19946b8..16a344f 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -5,7 +5,7 @@ import ( "hash/fnv" "strconv" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" ) // Path is a series of idents separated by :: @@ -94,15 +94,15 @@ func (i ImplicitlyMarshaledEntityUID) MarshalJSON() ([]byte, error) { return json.Marshal(s) } -type EntityUIDSet = sets.MapSet[EntityUID] +type EntityUIDSet = mapset.MapSet[EntityUID] // NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the EntityUIDSet can be passed // as an argument, as in the argument to make() for a map type. -func NewEntityUIDSet(args ...int) EntityUIDSet { - return sets.NewMapSet[EntityUID](args...) +func NewEntityUIDSet(args ...int) *EntityUIDSet { + return mapset.New[EntityUID](args...) } // NewEntityUIDSetFromSlice creates a EntityUIDSet of size len(items) and calls AddSlice(items) on it. -func NewEntityUIDSetFromSlice(items []EntityUID) EntityUIDSet { - return sets.NewMapSetFromSlice[EntityUID](items) +func NewEntityUIDSetFromSlice(items []EntityUID) *EntityUIDSet { + return mapset.FromSlice[EntityUID](items) } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 81ab55b..0e60229 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -36,7 +36,7 @@ func TestEntityUIDSet(t *testing.T) { t.Parallel() t.Run("new empty set", func(t *testing.T) { - emptySets := []types.EntityUIDSet{ + emptySets := []*types.EntityUIDSet{ types.NewEntityUIDSet(), types.NewEntityUIDSet(0), types.NewEntityUIDSet(1), diff --git a/types/set.go b/types/set.go index 07958b2..4e9e0e7 100644 --- a/types/set.go +++ b/types/set.go @@ -89,7 +89,7 @@ func (s Set) Slice() []Value { return maps.Values(s.s) } -// Equal returns true if the sets are Equal. +// Equal returns true if the mapset are Equal. func (as Set) Equal(bi Value) bool { bs, ok := bi.(Set) if !ok { diff --git a/x/exp/batch/batch.go b/x/exp/batch/batch.go index 3fd9edb..11884b6 100644 --- a/x/exp/batch/batch.go +++ b/x/exp/batch/batch.go @@ -10,7 +10,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/eval" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/types" ) @@ -104,7 +104,7 @@ var errInvalidPart = fmt.Errorf("invalid part") // The result passed to the callback must be used / cloned immediately and not modified. func Authorize(ctx context.Context, ps *cedar.PolicySet, entityMap types.Entities, request Request, cb Callback) error { be := &batchEvaler{} - var found sets.MapSet[types.String] + var found mapset.MapSet[types.String] findVariables(&found, request.Principal) findVariables(&found, request.Action) findVariables(&found, request.Resource) @@ -382,7 +382,7 @@ func cloneSub(r types.Value, k types.String, v types.Value) (types.Value, bool) return r, false } -func findVariables(found *sets.MapSet[types.String], r types.Value) { +func findVariables(found *mapset.MapSet[types.String], r types.Value) { switch t := r.(type) { case types.EntityUID: if key, ok := eval.ToVariable(t); ok { diff --git a/x/exp/batch/batch_test.go b/x/exp/batch/batch_test.go index 1c4cf48..68116a2 100644 --- a/x/exp/batch/batch_test.go +++ b/x/exp/batch/batch_test.go @@ -10,7 +10,7 @@ import ( "github.com/cedar-policy/cedar-go" publicast "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/ast" - "github.com/cedar-policy/cedar-go/internal/sets" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -677,9 +677,9 @@ func TestFindVariables(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - var out sets.MapSet[types.String] - findVariables(&out, tt.in) - testutil.Equals(t, out, sets.NewMapSetFromSlice(tt.out)) + out := mapset.New[types.String]() + findVariables(out, tt.in) + testutil.Equals(t, out, mapset.FromSlice(tt.out)) }) } From d479334719968ea4253d5fdd650f00ea6dcbb13e Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:05:11 -0700 Subject: [PATCH 07/17] internal/mapset: create new immutable mapset type Signed-off-by: Patrick Jakubowski --- internal/mapset/immutable.go | 58 +++++++++++ internal/mapset/immutable_test.go | 168 ++++++++++++++++++++++++++++++ internal/mapset/mapset.go | 6 +- internal/mapset/mapset_test.go | 16 +-- 4 files changed, 240 insertions(+), 8 deletions(-) create mode 100644 internal/mapset/immutable.go create mode 100644 internal/mapset/immutable_test.go diff --git a/internal/mapset/immutable.go b/internal/mapset/immutable.go new file mode 100644 index 0000000..4b6db36 --- /dev/null +++ b/internal/mapset/immutable.go @@ -0,0 +1,58 @@ +package mapset + +import ( + "encoding/json" +) + +type ImmutableMapSet[T comparable] MapSet[T] + +func Immutable[T comparable](args ...T) ImmutableMapSet[T] { + return ImmutableMapSet[T](*FromSlice(args)) +} + +// Contains returns whether the item exists in the set +func (h ImmutableMapSet[T]) Contains(item T) bool { + return MapSet[T](h).Contains(item) +} + +// Intersection returns the items common to both h and o. +func (h ImmutableMapSet[T]) Intersection(o Container[T]) ImmutableMapSet[T] { + return ImmutableMapSet[T](*MapSet[T](h).Intersection(o)) +} + +// Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted. +// Iteration order is undefined. +func (h ImmutableMapSet[T]) Iterate(callback func(item T) bool) { + MapSet[T](h).Iterate(callback) +} + +func (h ImmutableMapSet[T]) Slice() []T { + return MapSet[T](h).Slice() +} + +// Len returns the size of the set +func (h ImmutableMapSet[T]) Len() int { + return MapSet[T](h).Len() +} + +// Equal returns whether the same items exist in both h and o +func (h ImmutableMapSet[T]) Equal(o ImmutableMapSet[T]) bool { + om := MapSet[T](o) + return MapSet[T](h).Equal(&om) +} + +// MarshalJSON serializes an ImmutableMapSet as a JSON array. Order is non-deterministic. +func (h ImmutableMapSet[T]) MarshalJSON() ([]byte, error) { + return MapSet[T](h).MarshalJSON() +} + +// UnmarshalJSON deserializes an ImmutableMapSet from a JSON array. +func (h *ImmutableMapSet[T]) UnmarshalJSON(b []byte) error { + var s MapSet[T] + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + *h = ImmutableMapSet[T](s) + return nil +} diff --git a/internal/mapset/immutable_test.go b/internal/mapset/immutable_test.go new file mode 100644 index 0000000..3baf661 --- /dev/null +++ b/internal/mapset/immutable_test.go @@ -0,0 +1,168 @@ +package mapset + +import ( + "encoding/json" + "slices" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func immutableHashSetMustNotContain[T comparable](t *testing.T, s ImmutableMapSet[T], item T) { + testutil.FatalIf(t, s.Contains(item), "set %v unexpectedly contained item %v", s, 1) +} + +func TestImmutableHashSet(t *testing.T) { + t.Run("empty set contains nothing", func(t *testing.T) { + s := Immutable[int]() + testutil.Equals(t, s.Len(), 0) + immutableHashSetMustNotContain(t, s, 1) + }) + + t.Run("one element", func(t *testing.T) { + s := Immutable[int](1) + testutil.Equals(t, s.Contains(1), true) + }) + + t.Run("two elements", func(t *testing.T) { + s := Immutable[int](1, 2) + testutil.Equals(t, s.Contains(1), true) + testutil.Equals(t, s.Contains(2), true) + testutil.Equals(t, s.Contains(3), false) + }) + + t.Run("deduplicate elements", func(t *testing.T) { + s := Immutable[int](1, 1) + testutil.Equals(t, s.Contains(1), true) + testutil.Equals(t, s.Contains(2), false) + }) + + t.Run("slice", func(t *testing.T) { + s := Immutable[int]() + testutil.Equals(t, s.Slice(), []int{}) + + inSlice := []int{1, 2, 3} + s = Immutable[int](inSlice...) + + outSlice := s.Slice() + slices.Sort(outSlice) + testutil.Equals(t, inSlice, outSlice) + }) + + t.Run("equal", func(t *testing.T) { + s1 := Immutable(1, 2, 3) + testutil.Equals(t, s1.Equal(s1), true) + + s2 := Immutable(1, 2, 3) + testutil.Equals(t, s1.Equal(s2), true) + + s3 := Immutable(1, 2, 3, 4) + testutil.Equals(t, s1.Equal(s3), false) + }) + + t.Run("iterate", func(t *testing.T) { + s1 := Immutable(1, 2, 3) + + var items []int + s1.Iterate(func(item int) bool { + items = append(items, item) + return true + }) + + testutil.Equals(t, s1.Equal(Immutable(items...)), true) + }) + + t.Run("iterate break early", func(t *testing.T) { + s1 := Immutable(1, 2, 3) + + i := 0 + var items []int + s1.Iterate(func(item int) bool { + if i == 2 { + return false + } + items = append(items, item) + i++ + return true + }) + + // Because iteration order is non-deterministic, all we can say is that the right number of items ended up in + // the set and that the items were in the original set. + testutil.Equals(t, len(items), 2) + testutil.Equals(t, s1.Contains(items[0]), true) + testutil.Equals(t, s1.Contains(items[1]), true) + }) + + t.Run("intersection with overlap", func(t *testing.T) { + s1 := Immutable(1, 2, 3) + s2 := Immutable(2, 3, 4) + + s3 := s1.Intersection(s2) + testutil.Equals(t, s3, Immutable(2, 3)) + }) + + t.Run("intersection disjoint", func(t *testing.T) { + s1 := Immutable(1, 2) + s2 := Immutable(3, 4) + + s3 := s1.Intersection(s2) + testutil.Equals(t, s3.Len(), 0) + }) + + t.Run("encode nil set", func(t *testing.T) { + s := ImmutableMapSet[int]{} + + out, err := json.Marshal(s) + + testutil.OK(t, err) + testutil.Equals(t, string(out), "[]") + }) + + t.Run("encode json", func(t *testing.T) { + s := Immutable(1, 2, 3) + + out, err := json.Marshal(s) + + correctOutputs := []string{ + "[1,2,3]", + "[1,3,2]", + "[2,1,3]", + "[2,3,1]", + "[3,1,2]", + "[3,2,1]", + } + + testutil.OK(t, err) + testutil.FatalIf(t, !slices.Contains(correctOutputs, string(out)), "%v is not a valid output", string(out)) + }) + + t.Run("decode json", func(t *testing.T) { + var s1 ImmutableMapSet[int] + err := s1.UnmarshalJSON([]byte("[2,3,1,2]")) + testutil.OK(t, err) + testutil.Equals(t, s1, Immutable(1, 2, 3)) + }) + + t.Run("decode json empty", func(t *testing.T) { + var s1 ImmutableMapSet[int] + err := s1.UnmarshalJSON([]byte("[]")) + testutil.OK(t, err) + testutil.Equals(t, s1.Len(), 0) + }) + + t.Run("decode mixed types in array", func(t *testing.T) { + var s1 ImmutableMapSet[int] + err := s1.UnmarshalJSON([]byte(`[2,3,1,"2"]`)) + testutil.Error(t, err) + testutil.Equals(t, err.Error(), "json: cannot unmarshal string into Go value of type int") + testutil.Equals(t, s1.Len(), 0) + }) + + t.Run("decode wrong type", func(t *testing.T) { + var s1 ImmutableMapSet[int] + err := s1.UnmarshalJSON([]byte(`"1,2,3"`)) + testutil.Error(t, err) + testutil.Equals(t, err.Error(), "json: cannot unmarshal string into Go value of type []int") + testutil.Equals(t, s1.Len(), 0) + }) +} diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 3de1e75..e7c25c5 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -81,8 +81,12 @@ func (h MapSet[T]) Contains(item T) bool { return exists } +type Container[T comparable] interface { + Contains(T) bool +} + // Intersection returns the items common to both h and o. -func (h MapSet[T]) Intersection(o *MapSet[T]) *MapSet[T] { +func (h MapSet[T]) Intersection(o Container[T]) *MapSet[T] { intersection := New[T]() for item := range h.m { if o.Contains(item) { diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index 4088719..72ad558 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -8,17 +8,19 @@ import ( "github.com/cedar-policy/cedar-go/internal/testutil" ) -func mustNotContain[T comparable](t *testing.T, s *MapSet[T], item T) { +func hashSetMustNotContain[T comparable](t *testing.T, s *MapSet[T], item T) { testutil.FatalIf(t, s.Contains(item), "set %v unexpectedly contained item %v", s, 1) } func TestHashSet(t *testing.T) { t.Run("empty set contains nothing", func(t *testing.T) { s := New[int]() - mustNotContain(t, s, 1) + testutil.Equals(t, s.Len(), 0) + hashSetMustNotContain(t, s, 1) s = New[int](10) - mustNotContain(t, s, 1) + testutil.Equals(t, s.Len(), 0) + hashSetMustNotContain(t, s, 1) }) t.Run("add => contains", func(t *testing.T) { @@ -38,7 +40,7 @@ func TestHashSet(t *testing.T) { s.AddSlice([]int{1, 2}) testutil.Equals(t, s.Contains(1), true) testutil.Equals(t, s.Contains(2), true) - mustNotContain(t, s, 3) + hashSetMustNotContain(t, s, 3) }) t.Run("add same slice", func(t *testing.T) { @@ -83,8 +85,8 @@ func TestHashSet(t *testing.T) { s := New[int]() s.AddSlice([]int{1, 2, 3}) s.RemoveSlice([]int{1, 2}) - mustNotContain(t, s, 1) - mustNotContain(t, s, 2) + hashSetMustNotContain(t, s, 1) + hashSetMustNotContain(t, s, 2) testutil.Equals(t, s.Contains(3), true) }) @@ -268,7 +270,7 @@ func TestHashSet(t *testing.T) { // because those mutations may or may not be reflected in the caller's version of the MapSet. t.Run("zero value", func(t *testing.T) { s := MapSet[int]{} - mustNotContain(t, &s, 1) + hashSetMustNotContain(t, &s, 1) testutil.Equals(t, s.Slice(), nil) addByValue := func(m MapSet[int], val int) { From a81d2604b238fcd8f9f97f83a689897c5a83cf7b Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:26:18 -0700 Subject: [PATCH 08/17] types: make EntityUIDSet an ImmutableHashSet Signed-off-by: Patrick Jakubowski --- authorize_test.go | 10 +++++----- internal/eval/evalers.go | 2 +- internal/eval/evalers_test.go | 35 +++++++++++++++++------------------ internal/eval/partial.go | 2 +- types.go | 13 ++++--------- types/entities_test.go | 7 +++---- types/entity_uid.go | 14 ++++---------- types/entity_uid_test.go | 21 ++++++++------------- 8 files changed, 43 insertions(+), 61 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index fa80ace..d9cf776 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -139,7 +139,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cuzco: &cedar.Entity{ UID: cuzco, - Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("parent", "bob")}), + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("parent", "bob")), }, }, Principal: cuzco, @@ -166,7 +166,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cuzco: &cedar.Entity{ UID: cuzco, - Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("team", "osiris")}), + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("team", "osiris")), }, }, Principal: cuzco, @@ -193,7 +193,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ dropTable: &cedar.Entity{ UID: dropTable, - Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("scary", "stuff")), }, }, Principal: cuzco, @@ -209,7 +209,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ dropTable: &cedar.Entity{ UID: dropTable, - Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("scary", "stuff")}), + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("scary", "stuff")), }, }, Principal: cuzco, @@ -754,7 +754,7 @@ func TestIsAuthorized(t *testing.T) { Entities: cedar.Entities{ cedar.NewEntityUID("Resource", "table"): &cedar.Entity{ UID: cedar.NewEntityUID("Resource", "table"), - Parents: *cedar.NewEntityUIDSetFromSlice([]cedar.EntityUID{cedar.NewEntityUID("Parent", "id")}), + Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("Parent", "id")), }, }, Principal: cedar.NewEntityUID("Actor", "cuzco"), diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index beddaa4..a556e6e 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -1013,7 +1013,7 @@ func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) b } } -func entityInSet(env *Env, entity types.EntityUID, parents *types.EntityUIDSet) bool { +func entityInSet(env *Env, entity types.EntityUID, parents mapset.Container[types.EntityUID]) bool { if parents.Contains(entity) { return true } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 079ed6a..86c0b9f 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/cedar-policy/cedar-go/internal/consts" - "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" @@ -1834,23 +1833,23 @@ func TestEntityIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - rhs := types.NewEntityUIDSet(len(tt.rhs)) + var rhs []types.EntityUID for _, v := range tt.rhs { - rhs.Add(strEnt(v)) + rhs = append(rhs, strEnt(v)) } entityMap := types.Entities{} for k, p := range tt.parents { - ps := types.NewEntityUIDSet(len(p)) + var ps []types.EntityUID for _, pp := range p { - ps.Add(strEnt(pp)) + ps = append(ps, strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: *ps, + Parents: types.NewEntityUIDSet(ps...), } } - res := entityInSet(&Env{Entities: entityMap}, strEnt(tt.lhs), rhs) + res := entityInSet(&Env{Entities: entityMap}, strEnt(tt.lhs), types.NewEntityUIDSet(rhs...)) testutil.Equals(t, res, tt.result) }) } @@ -1861,26 +1860,26 @@ func TestEntityIn(t *testing.T) { entityMap := types.Entities{} for i := 0; i < 100; i++ { - p := mapset.FromSlice([]types.EntityUID{ + p := types.NewEntityUIDSet( types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), - }) + ) uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") entityMap[uid1] = &types.Entity{ UID: uid1, - Parents: *p, + Parents: p, } uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") entityMap[uid2] = &types.Entity{ UID: uid2, - Parents: *p, + Parents: p, } } res := entityInSet( &Env{Entities: entityMap}, types.NewEntityUID("0", "1"), - mapset.FromSlice([]types.EntityUID{types.NewEntityUID("0", "3")}), + types.NewEntityUIDSet(types.NewEntityUID("0", "3")), ) testutil.Equals(t, res, false) }) @@ -2010,14 +2009,14 @@ func TestInNode(t *testing.T) { n := newInEval(tt.lhs, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - ps := types.NewEntityUIDSet(len(p)) + var ps []types.EntityUID for _, pp := range p { - ps.Add(strEnt(pp)) + ps = append(ps, strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: *ps, + Parents: types.NewEntityUIDSet(ps...), } } ec := InitEnv(&Env{Entities: entityMap}) @@ -2150,14 +2149,14 @@ func TestIsInNode(t *testing.T) { n := newIsInEval(tt.lhs, tt.is, tt.rhs) entityMap := types.Entities{} for k, p := range tt.parents { - ps := types.NewEntityUIDSet(len(p)) + var ps []types.EntityUID for _, pp := range p { - ps.Add(strEnt(pp)) + ps = append(ps, strEnt(pp)) } uid := strEnt(k) entityMap[uid] = &types.Entity{ UID: uid, - Parents: *ps, + Parents: types.NewEntityUIDSet(ps...), } } ec := InitEnv(&Env{Entities: entityMap}) diff --git a/internal/eval/partial.go b/internal/eval/partial.go index 5b08c24..5b86a21 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -141,7 +141,7 @@ func partialScopeEval(env *Env, ent types.Value, in ast.IsScopeNode) (evaled boo case ast.ScopeTypeIn: return true, entityInOne(env, e, t.Entity) case ast.ScopeTypeInSet: - set := mapset.FromSlice(t.Entities) + set := mapset.Immutable(t.Entities...) return true, entityInSet(env, e, set) case ast.ScopeTypeIs: return true, e.Type == t.Type diff --git a/types.go b/types.go index be72d78..1681175 100644 --- a/types.go +++ b/types.go @@ -3,6 +3,7 @@ package cedar import ( "time" + "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/types" ) @@ -86,15 +87,9 @@ func NewEntityUID(typ EntityType, id String) EntityUID { return types.NewEntityUID(typ, id) } -// NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the set can be passed as an -// argument, as in the argument to make() for a map type. -func NewEntityUIDSet(args ...int) *EntityUIDSet { - return types.NewEntityUIDSet(args...) -} - -// NewEntityUIDSetFromSlice creates an EntityUIDSet of size len(items) and calls AddSlice(items) on it. -func NewEntityUIDSetFromSlice(items []EntityUID) *EntityUIDSet { - return types.NewEntityUIDSetFromSlice(items) +// NewEntityUIDSet returns an immutable EntityUIDSet ready for use. +func NewEntityUIDSet(args ...EntityUID) EntityUIDSet { + return mapset.Immutable[EntityUID](args...) } // NewPattern permits for the programmatic construction of a Pattern out of a slice of pattern components. diff --git a/types/entities_test.go b/types/entities_test.go index 848d81b..b8926e3 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/internal/mapset" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -43,12 +42,12 @@ func TestEntitiesJSON(t *testing.T) { e := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: mapset.MapSet[types.EntityUID]{}, + Parents: types.EntityUIDSet{}, Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } ent2 := &types.Entity{ UID: types.NewEntityUID("Type", "id2"), - Parents: *mapset.FromSlice([]types.EntityUID{ent.UID}), + Parents: types.NewEntityUIDSet(ent.UID), Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } e[ent.UID] = ent @@ -71,7 +70,7 @@ func TestEntitiesJSON(t *testing.T) { want := types.Entities{} ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), - Parents: *types.NewEntityUIDSet(), + Parents: types.NewEntityUIDSet(), Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), } want[ent.UID] = ent diff --git a/types/entity_uid.go b/types/entity_uid.go index 16a344f..3ed45db 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -94,15 +94,9 @@ func (i ImplicitlyMarshaledEntityUID) MarshalJSON() ([]byte, error) { return json.Marshal(s) } -type EntityUIDSet = mapset.MapSet[EntityUID] +type EntityUIDSet = mapset.ImmutableMapSet[EntityUID] -// NewEntityUIDSet returns an EntityUIDSet ready for use. Optionally, a desired size for the EntityUIDSet can be passed -// as an argument, as in the argument to make() for a map type. -func NewEntityUIDSet(args ...int) *EntityUIDSet { - return mapset.New[EntityUID](args...) -} - -// NewEntityUIDSetFromSlice creates a EntityUIDSet of size len(items) and calls AddSlice(items) on it. -func NewEntityUIDSetFromSlice(items []EntityUID) *EntityUIDSet { - return mapset.FromSlice[EntityUID](items) +// NewEntityUIDSet returns an immutable EntityUIDSet ready for use. +func NewEntityUIDSet(args ...EntityUID) EntityUIDSet { + return mapset.Immutable[EntityUID](args...) } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 0e60229..97b47ab 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -36,12 +36,9 @@ func TestEntityUIDSet(t *testing.T) { t.Parallel() t.Run("new empty set", func(t *testing.T) { - emptySets := []*types.EntityUIDSet{ + emptySets := []types.EntityUIDSet{ + types.EntityUIDSet{}, types.NewEntityUIDSet(), - types.NewEntityUIDSet(0), - types.NewEntityUIDSet(1), - types.NewEntityUIDSetFromSlice(nil), - types.NewEntityUIDSetFromSlice([]types.EntityUID{}), } for _, es := range emptySets { @@ -51,17 +48,15 @@ func TestEntityUIDSet(t *testing.T) { } }) - t.Run("new set from slice", func(t *testing.T) { + t.Run("new set", func(t *testing.T) { a := types.NewEntityUID("typeA", "1") b := types.NewEntityUID("typeB", "2") o := types.NewEntityUID("typeO", "2") - s1 := types.NewEntityUIDSet() - s1.Add(a) - s1.Add(b) - s1.Add(o) + s := types.NewEntityUIDSet(a, b, o) - s2 := types.NewEntityUIDSetFromSlice([]types.EntityUID{o, b, a}) - - testutil.Equals(t, s1.Equal(s2), true) + testutil.Equals(t, s.Len(), 3) + testutil.Equals(t, s.Contains(a), true) + testutil.Equals(t, s.Contains(b), true) + testutil.Equals(t, s.Contains(o), true) }) } From 68386de2f8e4bbeb74eca24df9dca8cc6ba81e50 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:29:20 -0700 Subject: [PATCH 09/17] internal/mapset: rename mapset.New to mapset.Make Signed-off-by: Patrick Jakubowski --- internal/eval/evalers.go | 2 +- internal/mapset/mapset.go | 10 ++++----- internal/mapset/mapset_test.go | 38 +++++++++++++++++----------------- x/exp/batch/batch_test.go | 2 +- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index a556e6e..6783666 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -1061,7 +1061,7 @@ func doInEval(env *Env, lhs types.EntityUID, rhs types.Value) (types.Value, erro case types.EntityUID: return types.Boolean(entityInOne(env, lhs, rhsv)), nil case types.Set: - query := mapset.New[types.EntityUID](rhsv.Len()) + query := mapset.Make[types.EntityUID](rhsv.Len()) var err error rhsv.Iterate(func(rhv types.Value) bool { var e types.EntityUID diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index e7c25c5..879a1b3 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -17,11 +17,11 @@ type MapSet[T comparable] struct { m map[T]struct{} } -// New returns a MapSet ready for use. Optionally, a desired size for the MapSet can be passed as an argument, +// Make returns a MapSet ready for use. Optionally, a desired size for the MapSet can be passed as an argument, // as in the argument to make() for a map type. -func New[T comparable](args ...int) *MapSet[T] { +func Make[T comparable](args ...int) *MapSet[T] { if len(args) > 1 { - panic(fmt.Sprintf("too many arguments passed to New(). got: %v, expected 0 or 1", len(args))) + panic(fmt.Sprintf("too many arguments passed to Make(). got: %v, expected 0 or 1", len(args))) } var size int @@ -34,7 +34,7 @@ func New[T comparable](args ...int) *MapSet[T] { // FromSlice creates a MapSet of size len(items) and calls AddSlice(items) on it. func FromSlice[T comparable](items []T) *MapSet[T] { - h := New[T](len(items)) + h := Make[T](len(items)) h.AddSlice(items) return h } @@ -87,7 +87,7 @@ type Container[T comparable] interface { // Intersection returns the items common to both h and o. func (h MapSet[T]) Intersection(o Container[T]) *MapSet[T] { - intersection := New[T]() + intersection := Make[T]() for item := range h.m { if o.Contains(item) { intersection.Add(item) diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index 72ad558..acbe57d 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -14,29 +14,29 @@ func hashSetMustNotContain[T comparable](t *testing.T, s *MapSet[T], item T) { func TestHashSet(t *testing.T) { t.Run("empty set contains nothing", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.Len(), 0) hashSetMustNotContain(t, s, 1) - s = New[int](10) + s = Make[int](10) testutil.Equals(t, s.Len(), 0) hashSetMustNotContain(t, s, 1) }) t.Run("add => contains", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.Add(1) testutil.Equals(t, s.Contains(1), true) }) t.Run("add twice", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.Add(1), true) testutil.Equals(t, s.Add(1), false) }) t.Run("add slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.AddSlice([]int{1, 2}) testutil.Equals(t, s.Contains(1), true) testutil.Equals(t, s.Contains(2), true) @@ -44,45 +44,45 @@ func TestHashSet(t *testing.T) { }) t.Run("add same slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{1, 2}), false) }) t.Run("add disjoint slices", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{3, 4}), true) testutil.Equals(t, s.AddSlice([]int{1, 2, 3, 4}), false) }) t.Run("add overlapping slices", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.AddSlice([]int{1, 2}), true) testutil.Equals(t, s.AddSlice([]int{2, 3}), true) testutil.Equals(t, s.AddSlice([]int{1, 3}), false) }) t.Run("remove nonexistent", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.Remove(1), false) }) t.Run("remove existing", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.Add(1) testutil.Equals(t, s.Remove(1), true) }) t.Run("remove => !contains", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.Add(1) s.Remove(1) testutil.FatalIf(t, s.Contains(1), "set unexpectedly contained item") }) t.Run("remove slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.AddSlice([]int{1, 2, 3}) s.RemoveSlice([]int{1, 2}) hashSetMustNotContain(t, s, 1) @@ -91,12 +91,12 @@ func TestHashSet(t *testing.T) { }) t.Run("remove non-existent slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) }) t.Run("remove overlapping slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() s.Add(1) testutil.Equals(t, s.RemoveSlice([]int{1, 2}), true) testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) @@ -111,10 +111,10 @@ func TestHashSet(t *testing.T) { }) t.Run("slice", func(t *testing.T) { - s := New[int]() + s := Make[int]() testutil.Equals(t, s.Slice(), []int{}) - s = New[int](10) + s = Make[int](10) testutil.Equals(t, s.Slice(), []int{}) inSlice := []int{1, 2, 3} @@ -145,7 +145,7 @@ func TestHashSet(t *testing.T) { t.Run("iterate", func(t *testing.T) { s1 := FromSlice([]int{1, 2, 3}) - s2 := New[int]() + s2 := Make[int]() s1.Iterate(func(item int) bool { s2.Add(item) return true @@ -198,7 +198,7 @@ func TestHashSet(t *testing.T) { }) t.Run("encode nil set", func(t *testing.T) { - s := New[int]() + s := Make[int]() out, err := json.Marshal(s) @@ -263,7 +263,7 @@ func TestHashSet(t *testing.T) { } }() - New[int](0, 1) + Make[int](0, 1) }) // The zero value MapSet is usable, but care must be taken to ensure that it is not mutated when passed by value diff --git a/x/exp/batch/batch_test.go b/x/exp/batch/batch_test.go index 68116a2..2271a8b 100644 --- a/x/exp/batch/batch_test.go +++ b/x/exp/batch/batch_test.go @@ -677,7 +677,7 @@ func TestFindVariables(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out := mapset.New[types.String]() + out := mapset.Make[types.String]() findVariables(out, tt.in) testutil.Equals(t, out, mapset.FromSlice(tt.out)) }) From 7b923a35d8b90544cfae1d0be2c0040db795582d Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:33:52 -0700 Subject: [PATCH 10/17] internal/mapset: remove some unused methods on MapSet Signed-off-by: Patrick Jakubowski --- internal/mapset/mapset.go | 24 +++-------------- internal/mapset/mapset_test.go | 49 ---------------------------------- 2 files changed, 4 insertions(+), 69 deletions(-) diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 879a1b3..31feea6 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -32,10 +32,12 @@ func Make[T comparable](args ...int) *MapSet[T] { return &MapSet[T]{m: make(map[T]struct{}, size)} } -// FromSlice creates a MapSet of size len(items) and calls AddSlice(items) on it. +// FromSlice creates a MapSet of size len(items) and calls Add for each of the items to it. func FromSlice[T comparable](items []T) *MapSet[T] { h := Make[T](len(items)) - h.AddSlice(items) + for _, i := range items { + h.Add(i) + } return h } @@ -50,15 +52,6 @@ func (h *MapSet[T]) Add(item T) bool { return !exists } -// AddSlice adds a slice of items to the set, returning true if any new items were added to the set. -func (h *MapSet[T]) AddSlice(items []T) bool { - modified := false - for _, i := range items { - modified = h.Add(i) || modified - } - return modified -} - // Remove an item from the Set. Returns true if the item existed in the set. func (h *MapSet[T]) Remove(item T) bool { _, exists := h.m[item] @@ -66,15 +59,6 @@ func (h *MapSet[T]) Remove(item T) bool { return exists } -// RemoveSlice removes a slice of items from the set, returning true if any items existed in the set. -func (h *MapSet[T]) RemoveSlice(items []T) bool { - modified := false - for _, i := range items { - modified = h.Remove(i) || modified - } - return modified -} - // Contains returns whether the item exists in the set func (h MapSet[T]) Contains(item T) bool { _, exists := h.m[item] diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index acbe57d..34088db 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -35,34 +35,6 @@ func TestHashSet(t *testing.T) { testutil.Equals(t, s.Add(1), false) }) - t.Run("add slice", func(t *testing.T) { - s := Make[int]() - s.AddSlice([]int{1, 2}) - testutil.Equals(t, s.Contains(1), true) - testutil.Equals(t, s.Contains(2), true) - hashSetMustNotContain(t, s, 3) - }) - - t.Run("add same slice", func(t *testing.T) { - s := Make[int]() - testutil.Equals(t, s.AddSlice([]int{1, 2}), true) - testutil.Equals(t, s.AddSlice([]int{1, 2}), false) - }) - - t.Run("add disjoint slices", func(t *testing.T) { - s := Make[int]() - testutil.Equals(t, s.AddSlice([]int{1, 2}), true) - testutil.Equals(t, s.AddSlice([]int{3, 4}), true) - testutil.Equals(t, s.AddSlice([]int{1, 2, 3, 4}), false) - }) - - t.Run("add overlapping slices", func(t *testing.T) { - s := Make[int]() - testutil.Equals(t, s.AddSlice([]int{1, 2}), true) - testutil.Equals(t, s.AddSlice([]int{2, 3}), true) - testutil.Equals(t, s.AddSlice([]int{1, 3}), false) - }) - t.Run("remove nonexistent", func(t *testing.T) { s := Make[int]() testutil.Equals(t, s.Remove(1), false) @@ -81,27 +53,6 @@ func TestHashSet(t *testing.T) { testutil.FatalIf(t, s.Contains(1), "set unexpectedly contained item") }) - t.Run("remove slice", func(t *testing.T) { - s := Make[int]() - s.AddSlice([]int{1, 2, 3}) - s.RemoveSlice([]int{1, 2}) - hashSetMustNotContain(t, s, 1) - hashSetMustNotContain(t, s, 2) - testutil.Equals(t, s.Contains(3), true) - }) - - t.Run("remove non-existent slice", func(t *testing.T) { - s := Make[int]() - testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) - }) - - t.Run("remove overlapping slice", func(t *testing.T) { - s := Make[int]() - s.Add(1) - testutil.Equals(t, s.RemoveSlice([]int{1, 2}), true) - testutil.Equals(t, s.RemoveSlice([]int{1, 2}), false) - }) - t.Run("new from slice", func(t *testing.T) { s := FromSlice([]int{1, 2, 2, 3}) testutil.Equals(t, s.Len(), 3) From 6b1e959477a5115dc61931a3f91679b609fcd651 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:37:34 -0700 Subject: [PATCH 11/17] internal/mapset: rename FromSlice to FromItems and use variadic args Signed-off-by: Patrick Jakubowski --- internal/mapset/immutable.go | 2 +- internal/mapset/mapset.go | 6 +++--- internal/mapset/mapset_test.go | 28 ++++++++++++++-------------- x/exp/batch/batch_test.go | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/internal/mapset/immutable.go b/internal/mapset/immutable.go index 4b6db36..3be21d9 100644 --- a/internal/mapset/immutable.go +++ b/internal/mapset/immutable.go @@ -7,7 +7,7 @@ import ( type ImmutableMapSet[T comparable] MapSet[T] func Immutable[T comparable](args ...T) ImmutableMapSet[T] { - return ImmutableMapSet[T](*FromSlice(args)) + return ImmutableMapSet[T](*FromItems(args...)) } // Contains returns whether the item exists in the set diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 31feea6..09965d2 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -32,8 +32,8 @@ func Make[T comparable](args ...int) *MapSet[T] { return &MapSet[T]{m: make(map[T]struct{}, size)} } -// FromSlice creates a MapSet of size len(items) and calls Add for each of the items to it. -func FromSlice[T comparable](items []T) *MapSet[T] { +// FromItems creates a MapSet of size len(items) and calls Add for each of the items to it. +func FromItems[T comparable](items ...T) *MapSet[T] { h := Make[T](len(items)) for _, i := range items { h.Add(i) @@ -131,6 +131,6 @@ func (h *MapSet[T]) UnmarshalJSON(b []byte) error { return err } - *h = *FromSlice(s) + *h = *FromItems(s...) return nil } diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index 34088db..1b23ed6 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -54,7 +54,7 @@ func TestHashSet(t *testing.T) { }) t.Run("new from slice", func(t *testing.T) { - s := FromSlice([]int{1, 2, 2, 3}) + s := FromItems(1, 2, 2, 3) testutil.Equals(t, s.Len(), 3) testutil.Equals(t, s.Contains(1), true) testutil.Equals(t, s.Contains(2), true) @@ -69,17 +69,17 @@ func TestHashSet(t *testing.T) { testutil.Equals(t, s.Slice(), []int{}) inSlice := []int{1, 2, 3} - s = FromSlice(inSlice) + s = FromItems(inSlice...) outSlice := s.Slice() slices.Sort(outSlice) testutil.Equals(t, inSlice, outSlice) }) t.Run("equal", func(t *testing.T) { - s1 := FromSlice([]int{1, 2, 3}) + s1 := FromItems(1, 2, 3) testutil.Equals(t, s1.Equal(s1), true) - s2 := FromSlice([]int{1, 2, 3}) + s2 := FromItems(1, 2, 3) testutil.Equals(t, s1.Equal(s2), true) s2.Add(4) @@ -94,7 +94,7 @@ func TestHashSet(t *testing.T) { }) t.Run("iterate", func(t *testing.T) { - s1 := FromSlice([]int{1, 2, 3}) + s1 := FromItems(1, 2, 3) s2 := Make[int]() s1.Iterate(func(item int) bool { @@ -106,7 +106,7 @@ func TestHashSet(t *testing.T) { }) t.Run("iterate break early", func(t *testing.T) { - s1 := FromSlice([]int{1, 2, 3}) + s1 := FromItems(1, 2, 3) i := 0 var items []int @@ -127,19 +127,19 @@ func TestHashSet(t *testing.T) { }) t.Run("intersection with overlap", func(t *testing.T) { - s1 := FromSlice([]int{1, 2, 3}) - s2 := FromSlice([]int{2, 3, 4}) + s1 := FromItems(1, 2, 3) + s2 := FromItems(2, 3, 4) s3 := s1.Intersection(s2) - testutil.Equals(t, s3, FromSlice([]int{2, 3})) + testutil.Equals(t, s3, FromItems(2, 3)) s4 := s1.Intersection(s2) - testutil.Equals(t, s4, FromSlice([]int{2, 3})) + testutil.Equals(t, s4, FromItems(2, 3)) }) t.Run("intersection disjoint", func(t *testing.T) { - s1 := FromSlice([]int{1, 2}) - s2 := FromSlice([]int{3, 4}) + s1 := FromItems(1, 2) + s2 := FromItems(3, 4) s3 := s1.Intersection(s2) testutil.Equals(t, s3.Len(), 0) @@ -158,7 +158,7 @@ func TestHashSet(t *testing.T) { }) t.Run("encode json", func(t *testing.T) { - s := FromSlice([]int{1, 2, 3}) + s := FromItems(1, 2, 3) out, err := json.Marshal(s) @@ -179,7 +179,7 @@ func TestHashSet(t *testing.T) { var s1 MapSet[int] err := s1.UnmarshalJSON([]byte("[2,3,1,2]")) testutil.OK(t, err) - testutil.Equals(t, &s1, FromSlice([]int{1, 2, 3})) + testutil.Equals(t, &s1, FromItems(1, 2, 3)) }) t.Run("decode json empty", func(t *testing.T) { diff --git a/x/exp/batch/batch_test.go b/x/exp/batch/batch_test.go index 2271a8b..a282909 100644 --- a/x/exp/batch/batch_test.go +++ b/x/exp/batch/batch_test.go @@ -679,7 +679,7 @@ func TestFindVariables(t *testing.T) { t.Parallel() out := mapset.Make[types.String]() findVariables(out, tt.in) - testutil.Equals(t, out, mapset.FromSlice(tt.out)) + testutil.Equals(t, out, mapset.FromItems(tt.out...)) }) } From efe617a8aced05586bead9bebecd00766cc63b8a Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:41:47 -0700 Subject: [PATCH 12/17] internal/mapset: reduce the functionality of Intersection to just a boolean Intersects Signed-off-by: Patrick Jakubowski --- internal/eval/evalers.go | 2 +- internal/mapset/immutable.go | 6 +++--- internal/mapset/immutable_test.go | 6 ++---- internal/mapset/mapset.go | 9 ++++----- internal/mapset/mapset_test.go | 12 ++---------- 5 files changed, 12 insertions(+), 23 deletions(-) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 6783666..bb864aa 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -1022,7 +1022,7 @@ func entityInSet(env *Env, entity types.EntityUID, parents mapset.Container[type var candidate = entity for { if fe, ok := env.Entities[candidate]; ok { - if fe.Parents.Intersection(parents).Len() > 0 { + if fe.Parents.Intersects(parents) { return true } fe.Parents.Iterate(func(k types.EntityUID) bool { diff --git a/internal/mapset/immutable.go b/internal/mapset/immutable.go index 3be21d9..357c1d9 100644 --- a/internal/mapset/immutable.go +++ b/internal/mapset/immutable.go @@ -15,9 +15,9 @@ func (h ImmutableMapSet[T]) Contains(item T) bool { return MapSet[T](h).Contains(item) } -// Intersection returns the items common to both h and o. -func (h ImmutableMapSet[T]) Intersection(o Container[T]) ImmutableMapSet[T] { - return ImmutableMapSet[T](*MapSet[T](h).Intersection(o)) +// Intersects returns whether any items in this set exist in o +func (h ImmutableMapSet[T]) Intersects(o Container[T]) bool { + return MapSet[T](h).Intersects(o) } // Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted. diff --git a/internal/mapset/immutable_test.go b/internal/mapset/immutable_test.go index 3baf661..fc3aaee 100644 --- a/internal/mapset/immutable_test.go +++ b/internal/mapset/immutable_test.go @@ -97,16 +97,14 @@ func TestImmutableHashSet(t *testing.T) { s1 := Immutable(1, 2, 3) s2 := Immutable(2, 3, 4) - s3 := s1.Intersection(s2) - testutil.Equals(t, s3, Immutable(2, 3)) + testutil.Equals(t, s1.Intersects(s2), true) }) t.Run("intersection disjoint", func(t *testing.T) { s1 := Immutable(1, 2) s2 := Immutable(3, 4) - s3 := s1.Intersection(s2) - testutil.Equals(t, s3.Len(), 0) + testutil.Equals(t, s1.Intersects(s2), false) }) t.Run("encode nil set", func(t *testing.T) { diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 09965d2..91c4ea5 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -69,15 +69,14 @@ type Container[T comparable] interface { Contains(T) bool } -// Intersection returns the items common to both h and o. -func (h MapSet[T]) Intersection(o Container[T]) *MapSet[T] { - intersection := Make[T]() +// Intersects returns whether any items in this set exist in o +func (h MapSet[T]) Intersects(o Container[T]) bool { for item := range h.m { if o.Contains(item) { - intersection.Add(item) + return true } } - return intersection + return false } // Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted. diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index 1b23ed6..21d43df 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -130,22 +130,14 @@ func TestHashSet(t *testing.T) { s1 := FromItems(1, 2, 3) s2 := FromItems(2, 3, 4) - s3 := s1.Intersection(s2) - testutil.Equals(t, s3, FromItems(2, 3)) - - s4 := s1.Intersection(s2) - testutil.Equals(t, s4, FromItems(2, 3)) + testutil.Equals(t, s1.Intersects(s2), true) }) t.Run("intersection disjoint", func(t *testing.T) { s1 := FromItems(1, 2) s2 := FromItems(3, 4) - s3 := s1.Intersection(s2) - testutil.Equals(t, s3.Len(), 0) - - s4 := s1.Intersection(s2) - testutil.Equals(t, s4.Len(), 0) + testutil.Equals(t, s1.Intersects(s2), false) }) t.Run("encode nil set", func(t *testing.T) { From 05c78eb745574f10d5da0db2792c9c280732fa97 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:46:32 -0700 Subject: [PATCH 13/17] internal/mapset: add some documentation around the zero value of MapSet Signed-off-by: Patrick Jakubowski --- internal/mapset/mapset.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 91c4ea5..8a612fe 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -13,6 +13,12 @@ var peppercorn = struct{}{} // MapSet is a struct that adds some convenience to the otherwise cumbersome map[T]struct{} idiom used in Go to // implement mapset of comparable types. +// +// Note: the zero value of MapSet[T] (i.e. MapSet[T]{}) is fully usable and avoids unnecessary allocations in the case +// where nothing gets added to the MapSet. However, take care in using it, especially when passing it by value to other +// functions. If passed by value, mutating operations (e.g. Add(), Remove()) in the called function will persist in the +// calling function's version if the MapSet[T] has been changed from the zero value prior to the call. +// See the "zero value" test for an example. type MapSet[T comparable] struct { m map[T]struct{} } From c0c3981e30fbacd35ab4470121a578b7554790d6 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 24 Sep 2024 15:50:22 -0700 Subject: [PATCH 14/17] types: fix some tiny typos Signed-off-by: Patrick Jakubowski --- internal/eval/evalers_test.go | 2 +- types/entity_uid_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 86c0b9f..6262940 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1485,7 +1485,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run("not quadratic", func(t *testing.T) { t.Parallel() - // Make two totally disjoint mapset to force a worst case search + // Make two totally disjoint sets to force a worst case search setSize := 200000 set1 := make([]types.Value, setSize) set2 := make([]types.Value, setSize) diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 97b47ab..a8588e8 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -37,7 +37,7 @@ func TestEntityUIDSet(t *testing.T) { t.Run("new empty set", func(t *testing.T) { emptySets := []types.EntityUIDSet{ - types.EntityUIDSet{}, + {}, types.NewEntityUIDSet(), } From b7a52e1638136ccc721672351da6415dcc191363 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 26 Sep 2024 10:46:43 -0700 Subject: [PATCH 15/17] internal/mapset: give MarshalJSON a deterministic output Signed-off-by: Patrick Jakubowski --- internal/mapset/immutable.go | 2 +- internal/mapset/immutable_test.go | 11 +---------- internal/mapset/mapset.go | 17 +++++++++++++++-- internal/mapset/mapset_test.go | 31 ++++++++++++++++++++----------- 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/internal/mapset/immutable.go b/internal/mapset/immutable.go index 357c1d9..73b54a1 100644 --- a/internal/mapset/immutable.go +++ b/internal/mapset/immutable.go @@ -41,7 +41,7 @@ func (h ImmutableMapSet[T]) Equal(o ImmutableMapSet[T]) bool { return MapSet[T](h).Equal(&om) } -// MarshalJSON serializes an ImmutableMapSet as a JSON array. Order is non-deterministic. +// MarshalJSON serializes a MapSet as a JSON array. Elements are ordered lexicographically by their marshaled value. func (h ImmutableMapSet[T]) MarshalJSON() ([]byte, error) { return MapSet[T](h).MarshalJSON() } diff --git a/internal/mapset/immutable_test.go b/internal/mapset/immutable_test.go index fc3aaee..af5457a 100644 --- a/internal/mapset/immutable_test.go +++ b/internal/mapset/immutable_test.go @@ -121,17 +121,8 @@ func TestImmutableHashSet(t *testing.T) { out, err := json.Marshal(s) - correctOutputs := []string{ - "[1,2,3]", - "[1,3,2]", - "[2,1,3]", - "[2,3,1]", - "[3,1,2]", - "[3,2,1]", - } - testutil.OK(t, err) - testutil.FatalIf(t, !slices.Contains(correctOutputs, string(out)), "%v is not a valid output", string(out)) + testutil.Equals(t, string(out), "[1,2,3]") }) t.Run("decode json", func(t *testing.T) { diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 8a612fe..3f6385f 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -1,8 +1,10 @@ package mapset import ( + "bytes" "encoding/json" "fmt" + "slices" "golang.org/x/exp/maps" ) @@ -121,12 +123,23 @@ func (h MapSet[T]) Equal(o *MapSet[T]) bool { return true } -// MarshalJSON serializes a MapSet as a JSON array. Order is non-deterministic. +// MarshalJSON serializes a MapSet as a JSON array. Elements are ordered lexicographically by their marshaled value. func (h MapSet[T]) MarshalJSON() ([]byte, error) { if h.m == nil { return []byte("[]"), nil } - return json.Marshal(h.Slice()) + + elems := h.Slice() + marshaledElems := make([][]byte, 0, len(elems)) + for _, elem := range elems { + b, err := json.Marshal(elem) + if err != nil { + return nil, err + } + marshaledElems = append(marshaledElems, b) + } + slices.SortFunc(marshaledElems, func(a, b []byte) int { return slices.Compare(a, b) }) + return slices.Concat([]byte{'['}, bytes.Join(marshaledElems, []byte{','}), []byte{']'}), nil } // UnmarshalJSON deserializes a MapSet from a JSON array. diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index 21d43df..ca126fe 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -149,22 +149,31 @@ func TestHashSet(t *testing.T) { testutil.Equals(t, string(out), "[]") }) - t.Run("encode json", func(t *testing.T) { - s := FromItems(1, 2, 3) + t.Run("encode json one int", func(t *testing.T) { + s := FromItems(1) out, err := json.Marshal(s) - correctOutputs := []string{ - "[1,2,3]", - "[1,3,2]", - "[2,1,3]", - "[2,3,1]", - "[3,1,2]", - "[3,2,1]", - } + testutil.OK(t, err) + testutil.Equals(t, string(out), "[1]") + }) + + t.Run("encode json multiple int", func(t *testing.T) { + s := FromItems(3, 2, 1) + + out, err := json.Marshal(s) + + testutil.OK(t, err) + testutil.Equals(t, string(out), "[1,2,3]") + }) + + t.Run("encode json multiple string", func(t *testing.T) { + s := FromItems("1", "2", "3") + + out, err := json.Marshal(s) testutil.OK(t, err) - testutil.FatalIf(t, !slices.Contains(correctOutputs, string(out)), "%v is not a valid output", string(out)) + testutil.Equals(t, string(out), `["1","2","3"]`) }) t.Run("decode json", func(t *testing.T) { From b71df7edfb48ecf5b040497ecbf40b3d0f835bd8 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 26 Sep 2024 10:52:11 -0700 Subject: [PATCH 16/17] Fix various nits Signed-off-by: Patrick Jakubowski --- internal/mapset/immutable_test.go | 2 +- internal/mapset/mapset.go | 6 ++++-- internal/mapset/mapset_test.go | 3 ++- policy.go | 2 +- types/set.go | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/mapset/immutable_test.go b/internal/mapset/immutable_test.go index af5457a..663b8d6 100644 --- a/internal/mapset/immutable_test.go +++ b/internal/mapset/immutable_test.go @@ -34,7 +34,7 @@ func TestImmutableHashSet(t *testing.T) { t.Run("deduplicate elements", func(t *testing.T) { s := Immutable[int](1, 1) testutil.Equals(t, s.Contains(1), true) - testutil.Equals(t, s.Contains(2), false) + testutil.Equals(t, s.Len(), 1) }) t.Run("slice", func(t *testing.T) { diff --git a/internal/mapset/mapset.go b/internal/mapset/mapset.go index 3f6385f..8807064 100644 --- a/internal/mapset/mapset.go +++ b/internal/mapset/mapset.go @@ -55,9 +55,11 @@ func (h *MapSet[T]) Add(item T) bool { h.m = map[T]struct{}{} } - _, exists := h.m[item] + if _, exists := h.m[item]; exists { + return false + } h.m[item] = peppercorn - return !exists + return true } // Remove an item from the Set. Returns true if the item existed in the set. diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index ca126fe..5f5171d 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -91,6 +91,7 @@ func TestHashSet(t *testing.T) { s1.Add(4) s1.Remove(3) testutil.Equals(t, s1.Equal(s2), true) + testutil.Equals(t, s2.Equal(s1), true) }) t.Run("iterate", func(t *testing.T) { @@ -222,7 +223,7 @@ func TestHashSet(t *testing.T) { // because those mutations may or may not be reflected in the caller's version of the MapSet. t.Run("zero value", func(t *testing.T) { s := MapSet[int]{} - hashSetMustNotContain(t, &s, 1) + hashSetMustNotContain(t, &s, 0) testutil.Equals(t, s.Slice(), nil) addByValue := func(m MapSet[int], val int) { diff --git a/policy.go b/policy.go index 9e21c1c..adcffad 100644 --- a/policy.go +++ b/policy.go @@ -97,7 +97,7 @@ func (p *Policy) Position() Position { return Position(p.ast.Position) } -// SetFilename mapset the filename of this policy. +// SetFilename sets the filename of this policy. func (p *Policy) SetFilename(fileName string) { p.ast.Position.Filename = fileName } diff --git a/types/set.go b/types/set.go index 4e9e0e7..07958b2 100644 --- a/types/set.go +++ b/types/set.go @@ -89,7 +89,7 @@ func (s Set) Slice() []Value { return maps.Values(s.s) } -// Equal returns true if the mapset are Equal. +// Equal returns true if the sets are Equal. func (as Set) Equal(bi Value) bool { bs, ok := bi.(Set) if !ok { From e796ce27f1c09a201f0b5076a13309214e857fb8 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 26 Sep 2024 11:16:30 -0700 Subject: [PATCH 17/17] types: ensure Entity marshals to JSON with a consistent ordering Signed-off-by: Patrick Jakubowski --- internal/testutil/testutil.go | 14 ++++++++++++++ types/entities_test.go | 13 +------------ types/entity.go | 13 ++++++++++++- types/entity_test.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 13 deletions(-) diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 148372d..f5d85f8 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -1,6 +1,8 @@ package testutil import ( + "bytes" + "encoding/json" "errors" "reflect" ) @@ -60,3 +62,15 @@ func Panic(t TB, f func()) { }() f() } + +// JSONMarshalsTo asserts that obj marshals as JSON to the given string, allowing for formatting differences and +// displaying an easy-to-read diff. +func JSONMarshalsTo[T any](t TB, obj T, want string) { + b, err := json.MarshalIndent(obj, "", "\t") + OK(t, err) + + var wantBuf bytes.Buffer + err = json.Indent(&wantBuf, []byte(want), "", "\t") + OK(t, err) + Equals(t, string(b), wantBuf.String()) +} diff --git a/types/entities_test.go b/types/entities_test.go index b8926e3..0be79f8 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -1,7 +1,6 @@ package types_test import ( - "bytes" "encoding/json" "testing" @@ -25,16 +24,6 @@ func TestEntities(t *testing.T) { } -func assertJSONEquals(t *testing.T, e any, want string) { - b, err := json.MarshalIndent(e, "", "\t") - testutil.OK(t, err) - - var wantBuf bytes.Buffer - err = json.Indent(&wantBuf, []byte(want), "", "\t") - testutil.OK(t, err) - testutil.Equals(t, string(b), wantBuf.String()) -} - func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { @@ -52,7 +41,7 @@ func TestEntitiesJSON(t *testing.T) { } e[ent.UID] = ent e[ent2.UID] = ent2 - assertJSONEquals( + testutil.JSONMarshalsTo( t, e, `[ diff --git a/types/entity.go b/types/entity.go index 76b58c5..a41fc1f 100644 --- a/types/entity.go +++ b/types/entity.go @@ -1,6 +1,10 @@ package types -import "encoding/json" +import ( + "encoding/json" + "slices" + "strings" +) // An Entity defines the parents and attributes for an EntityUID. type Entity struct { @@ -17,6 +21,13 @@ func (e Entity) MarshalJSON() ([]byte, error) { parents = append(parents, ImplicitlyMarshaledEntityUID(p)) return true }) + slices.SortFunc(parents, func(a, b ImplicitlyMarshaledEntityUID) int { + if cmp := strings.Compare(string(a.Type), string(b.Type)); cmp != 0 { + return cmp + } + + return strings.Compare(string(a.ID), string(b.ID)) + }) m := struct { UID ImplicitlyMarshaledEntityUID `json:"uid"` diff --git a/types/entity_test.go b/types/entity_test.go index b9f6e0d..28085ac 100644 --- a/types/entity_test.go +++ b/types/entity_test.go @@ -27,3 +27,31 @@ func TestEntityIsZero(t *testing.T) { }) } } + +func TestEntityMarshalJSON(t *testing.T) { + t.Parallel() + e := types.Entity{ + UID: types.NewEntityUID("FooType", "1"), + Parents: types.NewEntityUIDSet( + types.NewEntityUID("BazType", "1"), + types.NewEntityUID("BarType", "2"), + types.NewEntityUID("BarType", "1"), + types.NewEntityUID("QuuxType", "30"), + types.NewEntityUID("QuuxType", "3"), + ), + Attributes: types.Record{}, + } + + testutil.JSONMarshalsTo(t, e, + `{ + "uid": {"type":"FooType","id":"1"}, + "parents": [ + {"type":"BarType","id":"1"}, + {"type":"BarType","id":"2"}, + {"type":"BazType","id":"1"}, + {"type":"QuuxType","id":"3"}, + {"type":"QuuxType","id":"30"} + ], + "attrs":{} + }`) +}