Skip to content

Commit

Permalink
schema: support loading objects with loops
Browse files Browse the repository at this point in the history
  • Loading branch information
Denys Smirnov authored and dennwc committed Oct 20, 2018
1 parent 330365b commit f312d21
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 36 deletions.
104 changes: 71 additions & 33 deletions schema/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ type loader struct {

pathForType map[reflect.Type]*path.Path
pathForTypeRoot map[reflect.Type]*path.Path

seen map[quad.Value]reflect.Value
}

func (c *Config) newLoader(qs graph.QuadStore) *loader {
Expand All @@ -157,6 +159,8 @@ func (c *Config) newLoader(qs graph.QuadStore) *loader {

pathForType: make(map[reflect.Type]*path.Path),
pathForTypeRoot: make(map[reflect.Type]*path.Path),

seen: make(map[quad.Value]reflect.Value),
}
}

Expand Down Expand Up @@ -288,7 +292,7 @@ func (l *loader) loadToValue(ctx context.Context, dst reflect.Value, depth int,
for i := 0; i < rt.NumField(); i++ {
select {
case <-ctx.Done():
return context.Canceled
return ctx.Err()
default:
}
f := rt.Field(i)
Expand All @@ -313,18 +317,32 @@ func (l *loader) loadToValue(ctx context.Context, dst reflect.Value, depth int,
}
ft := f.Type
native := isNative(ft)
ptr := ft.Kind() == reflect.Ptr
for ft.Kind() == reflect.Ptr || ft.Kind() == reflect.Slice {
native = native || isNative(ft)
ft = ft.Elem()
native = native || isNative(ft)
switch ft.Kind() {
case reflect.Ptr:
ptr = true
case reflect.Slice:
ptr = false
}
}
recursive := !native && ft.Kind() == reflect.Struct
for _, fv := range arr {
var sv reflect.Value
if recursive {
if ptr {
fv := l.qs.NameOf(fv)
var ok bool
sv, ok = l.seen[fv]
if ok && sv.Type().AssignableTo(f.Type) {
df.Set(sv)
continue
}
}
sv = reflect.New(ft).Elem()
sit := iterator.NewFixed()
sit.Add(fv)
err := l.loadIteratorToDepth(ctx, sv, depth-1, sit)
err := l.loadIteratorToDepth(ctx, sv, depth-1, iterator.NewFixed(fv))
if err == errRequiredFieldIsMissing {
continue
} else if err != nil {
Expand Down Expand Up @@ -353,6 +371,19 @@ func (l *loader) iteratorForType(root graph.Iterator, rt reflect.Type, rootOnly
return l.iteratorFromPath(root, p)
}

func mergeMap(dst map[string][]graph.Value, m map[string]graph.Value) {
loop:
for k, v := range m {
sl := dst[k]
for _, sv := range sl {
if keysEqual(sv, v) {
continue loop
}
}
dst[k] = append(sl, v)
}
}

func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, depth int, list graph.Iterator) error {
if ctx == nil {
ctx = context.TODO()
Expand All @@ -374,11 +405,20 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep
if err != nil {
return err
}
select {
case <-ctx.Done():

ctxDone := func() bool {
select {
case <-ctx.Done():
return true
default:
}
return false
}

if ctxDone() {
return ctx.Err()
default:
}

rootOnly := depth == 0
it, err := l.iteratorForType(list, et, rootOnly)
if err != nil {
Expand All @@ -388,10 +428,22 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep

ctx = context.WithValue(ctx, fieldsCtxKey{}, fields)
for it.Next(ctx) {
select {
case <-ctx.Done():
if ctxDone() {
return ctx.Err()
default:
}
id := l.qs.NameOf(it.Result())
if id != nil {
if sv, ok := l.seen[id]; ok {
if slice {
dst.Set(reflect.Append(dst, sv.Elem()))
} else if chanl {
dst.Send(sv.Elem())
} else {
dst.Set(sv)
return nil
}
continue
}
}
mp := make(map[string]graph.Value)
it.TagResults(mp)
Expand All @@ -407,37 +459,23 @@ func (l *loader) loadIteratorToDepth(ctx context.Context, dst reflect.Value, dep
mo[k] = []graph.Value{v}
}
for it.NextPath(ctx) {
select {
case <-ctx.Done():
if ctxDone() {
return ctx.Err()
default:
}
mp = make(map[string]graph.Value)
it.TagResults(mp)
if len(mp) == 0 {
continue
}
// TODO(dennwc): replace with something more efficient
for k, v := range mp {
if sl, ok := mo[k]; !ok {
mo[k] = []graph.Value{v}
} else if len(sl) == 1 {
if !keysEqual(sl[0], v) {
mo[k] = append(sl, v)
}
} else {
found := false
for _, sv := range sl {
if keysEqual(sv, v) {
found = true
break
}
}
if !found {
mo[k] = append(sl, v)
}
}
mergeMap(mo, mp)
}
if id != nil {
sv := cur
if sv.Kind() != reflect.Ptr && sv.CanAddr() {
sv = sv.Addr()
}
l.seen[id] = sv
}
err := l.loadToValue(ctx, cur, depth, mo, "")
if err == errRequiredFieldIsMissing {
Expand Down
77 changes: 74 additions & 3 deletions schema/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,78 @@ import (
"github.com/cayleygraph/cayley/schema"
)

func TestLoadLoop(t *testing.T) {
sch := schema.NewConfig()

a := &NodeLoop{ID: iri("A"), Name: "Node A"}
a.Next = a

qs := memstore.New([]quad.Quad{
{a.ID, iri("name"), quad.String(a.Name), nil},
{a.ID, iri("next"), a.ID, nil},
}...)

b := &NodeLoop{}
if err := sch.LoadIteratorTo(nil, qs, reflect.ValueOf(b), nil); err != nil {
t.Error(err)
return
}
if a.ID != b.ID || a.Name != b.Name {
t.Fatalf("%#v vs %#v", a, b)
}
if b != b.Next {
t.Fatalf("loop is broken: %p vs %p", b, b.Next)
}

a = &NodeLoop{ID: iri("A"), Name: "Node A"}
b = &NodeLoop{ID: iri("B"), Name: "Node B"}
c := &NodeLoop{ID: iri("C"), Name: "Node C"}
a.Next = b
b.Next = c
c.Next = a

qs = memstore.New([]quad.Quad{
{a.ID, iri("name"), quad.String(a.Name), nil},
{b.ID, iri("name"), quad.String(b.Name), nil},
{c.ID, iri("name"), quad.String(c.Name), nil},
{a.ID, iri("next"), b.ID, nil},
{b.ID, iri("next"), c.ID, nil},
{c.ID, iri("next"), a.ID, nil},
}...)

a1 := &NodeLoop{}
if err := sch.LoadIteratorTo(nil, qs, reflect.ValueOf(a1), nil); err != nil {
t.Error(err)
return
}
if a.ID != a1.ID || a.Name != a1.Name {
t.Fatalf("%#v vs %#v", a, b)
}
b1 := a1.Next
c1 := b1.Next
if b.ID != b1.ID || b.Name != b1.Name {
t.Fatalf("%#v vs %#v", a, b)
}
if c.ID != c1.ID || c.Name != c1.Name {
t.Fatalf("%#v vs %#v", a, b)
}
if a1 != c1.Next {
t.Fatalf("loop is broken: %p vs %p", a1, c1.Next)
}
}

func TestLoadIteratorTo(t *testing.T) {
sch := schema.NewConfig()
for i, c := range testFillValueCases {
t.Run(c.name, func(t *testing.T) {
qs := memstore.New(c.quads...)
out := reflect.New(reflect.TypeOf(c.expect))
rt := reflect.TypeOf(c.expect)
var out reflect.Value
if rt.Kind() == reflect.Ptr {
out = reflect.New(rt.Elem())
} else {
out = reflect.New(rt)
}
var it graph.Iterator
if c.from != nil {
fixed := iterator.NewFixed()
Expand All @@ -33,7 +99,12 @@ func TestLoadIteratorTo(t *testing.T) {
t.Errorf("case %d failed: %v", i+1, err)
return
}
got := out.Elem().Interface()
var got interface{}
if rt.Kind() == reflect.Ptr {
got = out.Interface()
} else {
got = out.Elem().Interface()
}
if s, ok := got.(interface {
Sort()
}); ok {
Expand All @@ -46,7 +117,7 @@ func TestLoadIteratorTo(t *testing.T) {
}
if !reflect.DeepEqual(got, c.expect) {
t.Errorf("case %d failed: objects are different\n%#v\n%#v",
i+1, out.Elem().Interface(), c.expect,
i+1, got, c.expect,
)
}
})
Expand Down

0 comments on commit f312d21

Please sign in to comment.