Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(query): Support for between func with count at root #6556

Merged
merged 14 commits into from
Sep 24, 2020
49 changes: 49 additions & 0 deletions query/query0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3234,6 +3234,55 @@ func TestBetweenInt(t *testing.T) {
}
}

func TestBetweenCount(t *testing.T) {
tests := []struct {
name string
query string
result string
}{
{
`Test between on valid bounds`,
`
{
me(func: between(count(friend), 1, 3)) {
name
}
}
`,
`{"data":{"me":[{"name":"Rick Grimes"},{"name":"Andrea"}]}}`,
},
{
`Test between on count equal bounds`,
`
{
me(func: between(count(friend), 5, 5)) {
name
}
}
`,
`{"data":{"me":[{"name":"Michonne"}]}}`,
},
{
`Test between on count invalid bounds`,
`
{
me(func: between(count(friend), 3, 1)) {
name
}
}
`,
`{"data":{"me":[]}}`,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
js := processQueryNoErr(t, tc.query)
require.JSONEq(t, js, tc.result)
})
}
}

var client *dgo.Dgraph

func TestMain(m *testing.M) {
Expand Down
69 changes: 47 additions & 22 deletions worker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ func (qs *queryState) handleUidPostings(
return posting.ErrTsTooOld
}
count := int64(len)
if evalCompare(srcFn.fname, count, srcFn.threshold) {
if evalCompare(srcFn.fname, count, srcFn.threshold[0]) {
tlist := &pb.List{Uids: []uint64{q.UidList.Uids[i]}}
out.UidMatrix = append(out.UidMatrix, tlist)
}
Expand Down Expand Up @@ -1061,10 +1061,10 @@ func (qs *queryState) handleCompareScalarFunction(ctx context.Context, arg funcA
return errors.Errorf("Need @count directive in schema for attr: %s for fn: %s at root",
attr, arg.srcFn.fname)
}
count := arg.srcFn.threshold
counts := arg.srcFn.threshold
cp := countParams{
fn: arg.srcFn.fname,
count: count,
counts: counts,
attr: attr,
gid: arg.gid,
readTs: arg.q.ReadTs,
Expand Down Expand Up @@ -1411,7 +1411,7 @@ func (qs *queryState) handleMatchFunction(ctx context.Context, arg funcArgs) err
return err
}

max := int(arg.srcFn.threshold)
max := int(arg.srcFn.threshold[0])
for _, val := range vals {
// convert data from binary to appropriate format
strVal, err := types.Convert(val, types.StringID)
Expand Down Expand Up @@ -1621,7 +1621,7 @@ type functionContext struct {
eqTokens []types.Val
ineqValueToken []string
n int
threshold int64
threshold []int64
uidsPresent []uint64
fname string
fnType FuncType
Expand Down Expand Up @@ -1769,13 +1769,23 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) {
fc.n = len(fc.tokens)
}
case compareScalarFn:
if err = ensureArgsCount(q.SrcFunc, 1); err != nil {
argCount := 1
if q.SrcFunc.Name == between {
argCount = 2
}
if err = ensureArgsCount(q.SrcFunc, argCount); err != nil {
return nil, err
}
if fc.threshold, err = strconv.ParseInt(q.SrcFunc.Args[0], 0, 64); err != nil {
return nil, errors.Wrapf(err, "Compare %v(%v) require digits, but got invalid num",
q.SrcFunc.Name, q.SrcFunc.Args[0])
var thresholds []int64
for _, arg := range q.SrcFunc.Args {
threshold, err := strconv.ParseInt(arg, 0, 64)
if err != nil {
return nil, errors.Wrapf(err, "Compare %v(%v) require digits, but got invalid num",
q.SrcFunc.Name, q.SrcFunc.Args[0])
}
thresholds = append(thresholds, threshold)
}
fc.threshold = thresholds
checkRoot(q, fc)
case geoFn:
// For geo functions, we get extra information used for filtering.
Expand Down Expand Up @@ -1824,7 +1834,7 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) {
if max < 0 {
return nil, errors.Errorf("Levenshtein distance value must be greater than 0, got %v", s)
}
fc.threshold = int64(max)
fc.threshold = []int64{int64(max)}
fc.tokens = q.SrcFunc.Args
fc.n = len(fc.tokens)
case customIndexFn:
Expand Down Expand Up @@ -2166,27 +2176,33 @@ func preprocessFilter(tree *pb.FilterTree) (*facetsTree, error) {

type countParams struct {
readTs uint64
count int64
counts []int64
attr string
gid uint32
reverse bool // If query is asking for ~pred
fn string // function name
}

func (qs *queryState) evaluate(cp countParams, out *pb.Result) error {
count := cp.count
countl := cp.counts[0]
var counth int64
if cp.fn == between {
counth = cp.counts[1]
}
var illegal bool
switch cp.fn {
case "eq":
illegal = count <= 0
illegal = countl <= 0
case "lt":
illegal = count <= 1
illegal = countl <= 1
case "le":
illegal = count <= 0
illegal = countl <= 0
case "gt":
illegal = count < 0
illegal = countl < 0
case "ge":
illegal = count <= 0
illegal = countl <= 0
case "between":
illegal = countl <= 0 || counth <= 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another check could be countl > counth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not returning any error in case of improper bounds. Returned response would be empty.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just check that case here as we are already checking other illegal cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. But the point is not to return error from here, to have between behaviour consistent.

default:
x.AssertTruef(false, "unhandled count comparison fn: %v", cp.fn)
}
Expand All @@ -2195,7 +2211,7 @@ func (qs *queryState) evaluate(cp countParams, out *pb.Result) error {
"negative counts (nonsensical) or zero counts (not tracked).")
}

countKey := x.CountKey(cp.attr, uint32(count), cp.reverse)
countKey := x.CountKey(cp.attr, uint32(countl), cp.reverse)
if cp.fn == "eq" {
pl, err := qs.cache.Get(countKey)
if err != nil {
Expand All @@ -2211,13 +2227,13 @@ func (qs *queryState) evaluate(cp countParams, out *pb.Result) error {

switch cp.fn {
case "lt":
count--
countl--
case "gt":
count++
countl++
}

x.AssertTrue(count >= 1)
countKey = x.CountKey(cp.attr, uint32(count), cp.reverse)
x.AssertTrue(countl >= 1)
countKey = x.CountKey(cp.attr, uint32(countl), cp.reverse)

txn := pstore.NewTransactionAt(cp.readTs, false)
defer txn.Discard()
Expand All @@ -2234,6 +2250,15 @@ func (qs *queryState) evaluate(cp countParams, out *pb.Result) error {
for itr.Seek(countKey); itr.Valid(); itr.Next() {
item := itr.Item()
var key []byte
key = item.KeyCopy(key)
k, err := x.Parse(key)
if err != nil {
return err
}
if cp.fn == between && int64(k.Count) > counth {
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we will continue here to the next item and not break, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to break here as soon as we cross the upper bound in case of between.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it because all future keys count will be higher. (i.e keys are sorted?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.

}

pl, err := qs.cache.Get(item.KeyCopy(key))
if err != nil {
return err
Expand Down