Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(Query): fix cascade with pagination #7440

Merged
merged 18 commits into from
Feb 19, 2021
Merged
77 changes: 68 additions & 9 deletions query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ type params struct {
// Cascade is the list of predicates to apply @cascade to.
// __all__ is special to mean @cascade i.e. all the children of this subgraph are mandatory
// and should have values otherwise the node will be excluded.
Cascade []string
Cascade *CascadeArgs
// IgnoreReflex is true if the @ignorereflex directive is specified.
IgnoreReflex bool

Expand Down Expand Up @@ -206,6 +206,14 @@ type params struct {
AllowedPreds []string
}

// CascadeArgs stores the arguments needed to process @cascade directive.
// It is introduced to ensure correct behaviour for cascade with pagination.
type CascadeArgs struct {
Fields []string
First int
Offset int
}

type pathMetadata struct {
weight float64 // Total weight of the path.
}
Expand Down Expand Up @@ -559,15 +567,22 @@ func treeCopy(gq *gql.GraphQuery, sg *SubGraph) error {
GroupbyAttrs: gchild.GroupbyAttrs,
IsGroupBy: gchild.IsGroupby,
IsInternal: gchild.IsInternal,
Cascade: &CascadeArgs{},
}

// Inherit from the parent.
if len(sg.Params.Cascade) > 0 {
args.Cascade = append(args.Cascade, sg.Params.Cascade...)
if len(sg.Params.Cascade.Fields) > 0 {
args.Cascade.Fields = append(args.Cascade.Fields, sg.Params.Cascade.Fields...)
}
// Allow over-riding at this level.
if len(gchild.Cascade) > 0 {
args.Cascade = gchild.Cascade
args.Cascade.Fields = gchild.Cascade
}

// Remove pagination arguments from the query if @cascade is mentioned since
// pagination will be applied post processing the data.
if len(args.Cascade.Fields) > 0 {
args.addCascadePaginationArguments(gchild)
}

if gchild.IsCount {
Expand Down Expand Up @@ -646,6 +661,13 @@ func treeCopy(gq *gql.GraphQuery, sg *SubGraph) error {
return nil
}

func (args *params) addCascadePaginationArguments(gq *gql.GraphQuery) {
args.Cascade.First, _ = strconv.Atoi(gq.Args["first"])
delete(gq.Args, "first")
args.Cascade.Offset, _ = strconv.Atoi(gq.Args["offset"])
delete(gq.Args, "offset")
}

func (args *params) fill(gq *gql.GraphQuery) error {
if v, ok := gq.Args["offset"]; ok {
offset, err := strconv.ParseInt(v, 0, 32)
Expand Down Expand Up @@ -777,7 +799,7 @@ func newGraph(ctx context.Context, gq *gql.GraphQuery) (*SubGraph, error) {
// The attr at root (if present) would stand for the source functions attr.
args := params{
Alias: gq.Alias,
Cascade: gq.Cascade,
Cascade: &CascadeArgs{Fields: gq.Cascade},
GetUid: isDebug(ctx),
IgnoreReflex: gq.IgnoreReflex,
IsEmpty: gq.IsEmpty,
Expand All @@ -795,6 +817,12 @@ func newGraph(ctx context.Context, gq *gql.GraphQuery) (*SubGraph, error) {
AllowedPreds: gq.AllowedPreds,
}

// Remove pagination arguments from the query if @cascade is mentioned since
// pagination will be applied post processing the data.
if len(args.Cascade.Fields) > 0 {
args.addCascadePaginationArguments(gq)
}

for argk := range gq.Args {
if !isValidArg(argk) {
return nil, errors.Errorf("Invalid argument: %s", argk)
Expand Down Expand Up @@ -1320,7 +1348,7 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
}

cascadeArgMap := make(map[string]bool)
for _, pred := range sg.Params.Cascade {
for _, pred := range sg.Params.Cascade.Fields {
cascadeArgMap[pred] = true
}
cascadeAllPreds := cascadeArgMap["__all__"]
Expand All @@ -1340,16 +1368,24 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
return err
}
sgPath = sgPath[:len(sgPath)-1] // Backtrack
if len(child.Params.Cascade) == 0 {
if len(child.Params.Cascade.Fields) == 0 {
continue
}

// Intersect the UidMatrix with the DestUids as some UIDs might have been removed
// by other operations. So we need to apply it on the UidMatrix.
child.updateUidMatrix()

// Apply pagination after the @cascade.
if len(child.Params.Cascade.Fields) > 0 && child.Params.Cascade.First != 0 && child.Params.Cascade.Offset != 0 {
for i := 0; i < len(child.uidMatrix); i++ {
start, end := x.PageRange(child.Params.Cascade.First, child.Params.Cascade.Offset, len(child.uidMatrix[i].Uids))
child.uidMatrix[i].Uids = child.uidMatrix[i].Uids[start:end]
}
}
}

if len(sg.Params.Cascade) == 0 {
if len(sg.Params.Cascade.Fields) == 0 {
goto AssignStep
}

Expand Down Expand Up @@ -2132,7 +2168,19 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) {

if parent == nil {
// I'm root. We reach here if root had a function.
sg.uidMatrix = []*pb.List{sg.DestUIDs}

if len(sg.Params.Cascade.Fields) >= 0 {
// DesitUIDs for this level becomes the sourceUIDs for the next level. In updateUidMatrix with cascade,
// we end up modifying the first list from the uidMatrix which ends up modifying the srcUids of the next level.
// So to avoid that we make a copy.
newDestUIDList := &pb.List{Uids: make([]uint64, 0, len(sg.DestUIDs.Uids))}
for _, uid := range sg.DestUIDs.GetUids() {
newDestUIDList.Uids = append(newDestUIDList.Uids, uid)
}
sg.uidMatrix = []*pb.List{newDestUIDList}
} else {
sg.uidMatrix = []*pb.List{sg.DestUIDs}
}
}
}
}
Expand Down Expand Up @@ -2803,6 +2851,17 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) {
if err := sg.populateVarMap(req.Vars, sgPath); err != nil {
return err
}
// first time at the root here.

// Apply pagination at the root after @cascade.
if len(sg.Params.Cascade.Fields) > 0 && sg.Params.Cascade.First != 0 && sg.Params.Cascade.Offset != 0 {
sg.updateUidMatrix()
for i := 0; i < len(sg.uidMatrix); i++ {
start, end := x.PageRange(sg.Params.Cascade.First, sg.Params.Cascade.Offset, len(sg.uidMatrix[i].Uids))
sg.uidMatrix[i].Uids = sg.uidMatrix[i].Uids[start:end]
}
}

if err := sg.populatePostAggregation(req.Vars, []*SubGraph{}, nil); err != nil {
return err
}
Expand Down
33 changes: 33 additions & 0 deletions query/query0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,39 @@ func TestCascadeDirective(t *testing.T) {
js)
}

func TestCascadeWithPaginationDeep(t *testing.T) {
query := `
{
me(func: type("Person")) @cascade{
name
friend {
name
friend(first: 2, offset: 1) {
name
alive
}
}
}
}
`

js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data":{"me":[{"name":"Rick Grimes","friend":[{"name": "Michonne","friend":[{"name":"Daryl Dixon","alive":false},{"name": "Andrea","alive": false}]}]}]}}`, js)
}

func TestCascadeWithPaginationAtRoot(t *testing.T) {
query := `
{
me(func: type(Person), first: 2, offset: 2) @cascade{
name
alive
}
}
`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data":{"me":[{"name":"Andrea","alive":false}]}}`, js)
}

func TestLevelBasedFacetVarAggSum(t *testing.T) {
query := `
{
Expand Down