diff --git a/set.go b/set.go new file mode 100644 index 0000000..f56a275 --- /dev/null +++ b/set.go @@ -0,0 +1,66 @@ +package multihash + +// Set is a set of Multihashes, holding one copy per Multihash. +type Set struct { + set map[string]struct{} +} + +// NewSet creates a new set correctly initialized. +func NewSet() *Set { + return &Set{ + set: make(map[string]struct{}), + } +} + +// Add adds a new multihash to the set. +func (s *Set) Add(m Multihash) { + s.set[string(m)] = struct{}{} +} + +// Len returns the number of elements in the set. +func (s *Set) Len() int { + return len(s.set) +} + +// Has returns true if the element is in the set. +func (s *Set) Has(m Multihash) bool { + _, ok := s.set[string(m)] + return ok +} + +// Visit adds a multihash only if it is not in the set already. Returns true +// if the multihash was added (was not in the set before). +func (s *Set) Visit(m Multihash) bool { + _, ok := s.set[string(m)] + if !ok { + s.set[string(m)] = struct{}{} + return true + } + return false +} + +// ForEach runs f(m) with each multihash in the set. If returns immediately if +// f(m) returns an error. +func (s *Set) ForEach(f func(m Multihash) error) error { + for elem := range s.set { + mh := Multihash(elem) + if err := f(mh); err != nil { + return err + } + } + return nil +} + +// Remove removes an element from the set. +func (s *Set) Remove(m Multihash) { + delete(s.set, string(m)) +} + +// All returns a slice with all the elements in the set. +func (s *Set) All() []Multihash { + out := make([]Multihash, 0, len(s.set)) + for m := range s.set { + out = append(out, Multihash(m)) + } + return out +} diff --git a/set_test.go b/set_test.go new file mode 100644 index 0000000..2d3f607 --- /dev/null +++ b/set_test.go @@ -0,0 +1,86 @@ +package multihash + +import ( + "crypto/rand" + "errors" + "testing" +) + +func makeRandomMultihash(t *testing.T) Multihash { + t.Helper() + + p := make([]byte, 256) + _, err := rand.Read(p) + if err != nil { + t.Fatal(err) + } + + m, err := Sum(p, SHA3, 4) + if err != nil { + t.Fatal(err) + } + return m +} + +func TestSet(t *testing.T) { + mhSet := NewSet() + + total := 10 + for i := 0; i < total; i++ { + mhSet.Add(makeRandomMultihash(t)) + } + + m0 := makeRandomMultihash(t) + + if mhSet.Len() != total { + t.Error("bad length") + } + + if mhSet.Has(m0) { + t.Error("m0 should not be in set") + } + + mhSet.Add(m0) + + if !mhSet.Has(m0) { + t.Error("m0 should be in set") + } + + i := 0 + f := func(m Multihash) error { + i++ + if i == 3 { + return errors.New("3") + } + return nil + } + + mhSet.ForEach(f) + if i != 3 { + t.Error("forEach should have run 3 times") + } + + mhSet.Remove(m0) + + if mhSet.Len() != total { + t.Error("an element should have been removed") + } + + if mhSet.Has(m0) { + t.Error("m0 should not be in set") + } + + if !mhSet.Visit(m0) { + t.Error("Visit() should return true when new element added") + } + + all := mhSet.All() + if len(all) != mhSet.Len() { + t.Error("All() should return all") + } + for _, mh := range all { + if !mhSet.Has(mh) { + t.Error("element in All() not in set") + } + } +}