From f312d212e412ff252f75c260a4b5e6540b8f7668 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Sat, 20 Oct 2018 22:09:59 +0300 Subject: [PATCH] schema: support loading objects with loops --- schema/loader.go | 104 ++++++++++++++++++++++++++++-------------- schema/loader_test.go | 77 +++++++++++++++++++++++++++++-- 2 files changed, 145 insertions(+), 36 deletions(-) diff --git a/schema/loader.go b/schema/loader.go index e4c76ac89..03b3d3e49 100644 --- a/schema/loader.go +++ b/schema/loader.go @@ -148,6 +148,8 @@ type loader struct { pathForType map[reflect.Type]*path.Path pathForTypeRoot map[reflect.Type]*path.Path + + seen map[quad.Value]reflect.Value } func (c *Config) newLoader(qs graph.QuadStore) *loader { @@ -157,6 +159,8 @@ func (c *Config) newLoader(qs graph.QuadStore) *loader { pathForType: make(map[reflect.Type]*path.Path), pathForTypeRoot: make(map[reflect.Type]*path.Path), + + seen: make(map[quad.Value]reflect.Value), } } @@ -288,7 +292,7 @@ func (l *loader) loadToValue(ctx context.Context, dst reflect.Value, depth int, for i := 0; i < rt.NumField(); i++ { select { case <-ctx.Done(): - return context.Canceled + return ctx.Err() default: } f := rt.Field(i) @@ -313,18 +317,32 @@ func (l *loader) loadToValue(ctx context.Context, dst reflect.Value, depth int, } ft := f.Type native := isNative(ft) + ptr := ft.Kind() == reflect.Ptr for ft.Kind() == reflect.Ptr || ft.Kind() == reflect.Slice { - native = native || isNative(ft) ft = ft.Elem() + native = native || isNative(ft) + switch ft.Kind() { + case reflect.Ptr: + ptr = true + case reflect.Slice: + ptr = false + } } recursive := !native && ft.Kind() == reflect.Struct for _, fv := range arr { var sv reflect.Value if recursive { + if ptr { + fv := l.qs.NameOf(fv) + var ok bool + sv, ok = l.seen[fv] + if ok && sv.Type().AssignableTo(f.Type) { + df.Set(sv) + continue + } + } sv = reflect.New(ft).Elem() - sit := iterator.NewFixed() - sit.Add(fv) - err := l.loadIteratorToDepth(ctx, sv, depth-1, sit) + err := l.loadIteratorToDepth(ctx, sv, depth-1, iterator.NewFixed(fv)) if err == errRequiredFieldIsMissing { continue } else if err != nil { @@ -353,6 +371,19 @@ func (l *loader) iteratorForType(root graph.Iterator, rt reflect.Type, rootOnly return l.iteratorFromPath(root, p) } +func mergeMap(dst map[string][]graph.Value, m map[string]graph.Value) { +loop: + for k, v := range m { + sl := dst[k] + for _, sv := range sl { + if keysEqual(sv, v) { + continue loop + } + } + dst[k] = append(sl, v) + } +} + func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, depth int, list graph.Iterator) error { if ctx == nil { ctx = context.TODO() @@ -374,11 +405,20 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep if err != nil { return err } - select { - case <-ctx.Done(): + + ctxDone := func() bool { + select { + case <-ctx.Done(): + return true + default: + } + return false + } + + if ctxDone() { return ctx.Err() - default: } + rootOnly := depth == 0 it, err := l.iteratorForType(list, et, rootOnly) if err != nil { @@ -388,10 +428,22 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep ctx = context.WithValue(ctx, fieldsCtxKey{}, fields) for it.Next(ctx) { - select { - case <-ctx.Done(): + if ctxDone() { return ctx.Err() - default: + } + id := l.qs.NameOf(it.Result()) + if id != nil { + if sv, ok := l.seen[id]; ok { + if slice { + dst.Set(reflect.Append(dst, sv.Elem())) + } else if chanl { + dst.Send(sv.Elem()) + } else { + dst.Set(sv) + return nil + } + continue + } } mp := make(map[string]graph.Value) it.TagResults(mp) @@ -407,10 +459,8 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep mo[k] = []graph.Value{v} } for it.NextPath(ctx) { - select { - case <-ctx.Done(): + if ctxDone() { return ctx.Err() - default: } mp = make(map[string]graph.Value) it.TagResults(mp) @@ -418,26 +468,14 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep continue } // TODO(dennwc): replace with something more efficient - for k, v := range mp { - if sl, ok := mo[k]; !ok { - mo[k] = []graph.Value{v} - } else if len(sl) == 1 { - if !keysEqual(sl[0], v) { - mo[k] = append(sl, v) - } - } else { - found := false - for _, sv := range sl { - if keysEqual(sv, v) { - found = true - break - } - } - if !found { - mo[k] = append(sl, v) - } - } + mergeMap(mo, mp) + } + if id != nil { + sv := cur + if sv.Kind() != reflect.Ptr && sv.CanAddr() { + sv = sv.Addr() } + l.seen[id] = sv } err := l.loadToValue(ctx, cur, depth, mo, "") if err == errRequiredFieldIsMissing { diff --git a/schema/loader_test.go b/schema/loader_test.go index 728f63d30..c89481a80 100644 --- a/schema/loader_test.go +++ b/schema/loader_test.go @@ -11,12 +11,78 @@ import ( "github.com/cayleygraph/cayley/schema" ) +func TestLoadLoop(t *testing.T) { + sch := schema.NewConfig() + + a := &NodeLoop{ID: iri("A"), Name: "Node A"} + a.Next = a + + qs := memstore.New([]quad.Quad{ + {a.ID, iri("name"), quad.String(a.Name), nil}, + {a.ID, iri("next"), a.ID, nil}, + }...) + + b := &NodeLoop{} + if err := sch.LoadIteratorTo(nil, qs, reflect.ValueOf(b), nil); err != nil { + t.Error(err) + return + } + if a.ID != b.ID || a.Name != b.Name { + t.Fatalf("%#v vs %#v", a, b) + } + if b != b.Next { + t.Fatalf("loop is broken: %p vs %p", b, b.Next) + } + + a = &NodeLoop{ID: iri("A"), Name: "Node A"} + b = &NodeLoop{ID: iri("B"), Name: "Node B"} + c := &NodeLoop{ID: iri("C"), Name: "Node C"} + a.Next = b + b.Next = c + c.Next = a + + qs = memstore.New([]quad.Quad{ + {a.ID, iri("name"), quad.String(a.Name), nil}, + {b.ID, iri("name"), quad.String(b.Name), nil}, + {c.ID, iri("name"), quad.String(c.Name), nil}, + {a.ID, iri("next"), b.ID, nil}, + {b.ID, iri("next"), c.ID, nil}, + {c.ID, iri("next"), a.ID, nil}, + }...) + + a1 := &NodeLoop{} + if err := sch.LoadIteratorTo(nil, qs, reflect.ValueOf(a1), nil); err != nil { + t.Error(err) + return + } + if a.ID != a1.ID || a.Name != a1.Name { + t.Fatalf("%#v vs %#v", a, b) + } + b1 := a1.Next + c1 := b1.Next + if b.ID != b1.ID || b.Name != b1.Name { + t.Fatalf("%#v vs %#v", a, b) + } + if c.ID != c1.ID || c.Name != c1.Name { + t.Fatalf("%#v vs %#v", a, b) + } + if a1 != c1.Next { + t.Fatalf("loop is broken: %p vs %p", a1, c1.Next) + } +} + func TestLoadIteratorTo(t *testing.T) { sch := schema.NewConfig() for i, c := range testFillValueCases { t.Run(c.name, func(t *testing.T) { qs := memstore.New(c.quads...) - out := reflect.New(reflect.TypeOf(c.expect)) + rt := reflect.TypeOf(c.expect) + var out reflect.Value + if rt.Kind() == reflect.Ptr { + out = reflect.New(rt.Elem()) + } else { + out = reflect.New(rt) + } var it graph.Iterator if c.from != nil { fixed := iterator.NewFixed() @@ -33,7 +99,12 @@ func TestLoadIteratorTo(t *testing.T) { t.Errorf("case %d failed: %v", i+1, err) return } - got := out.Elem().Interface() + var got interface{} + if rt.Kind() == reflect.Ptr { + got = out.Interface() + } else { + got = out.Elem().Interface() + } if s, ok := got.(interface { Sort() }); ok { @@ -46,7 +117,7 @@ func TestLoadIteratorTo(t *testing.T) { } if !reflect.DeepEqual(got, c.expect) { t.Errorf("case %d failed: objects are different\n%#v\n%#v", - i+1, out.Elem().Interface(), c.expect, + i+1, got, c.expect, ) } })