Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] fix(commit): make txn context more robust #7659

Merged
merged 5 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions dgraph/cmd/alpha/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
return
}

hash := r.URL.Query().Get("hash")
abort, err := parseBool(r, "abort")
if err != nil {
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
Expand All @@ -472,15 +473,15 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
ctx := x.AttachAccessJwt(context.Background(), r)
var response map[string]interface{}
if abort {
response, err = handleAbort(ctx, startTs)
response, err = handleAbort(ctx, startTs, hash)
} else {
// Keys are sent as an array in the body.
reqText := readRequest(w, r)
if reqText == nil {
return
}

response, err = handleCommit(ctx, startTs, reqText)
response, err = handleCommit(ctx, startTs, hash, reqText)
}
if err != nil {
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
Expand All @@ -496,10 +497,11 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
_, _ = x.WriteResponse(w, r, js)
}

func handleAbort(ctx context.Context, startTs uint64) (map[string]interface{}, error) {
func handleAbort(ctx context.Context, startTs uint64, hash string) (map[string]interface{}, error) {
tc := &api.TxnContext{
StartTs: startTs,
Aborted: true,
Hash: hash,
}

tctx, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
Expand All @@ -516,10 +518,11 @@ func handleAbort(ctx context.Context, startTs uint64) (map[string]interface{}, e
}
}

func handleCommit(ctx context.Context, startTs uint64, reqText []byte) (map[string]interface{},
error) {
func handleCommit(ctx context.Context,
startTs uint64, hash string, reqText []byte) (map[string]interface{}, error) {
tc := &api.TxnContext{
StartTs: startTs,
Hash: hash,
}

var reqList []string
Expand Down
29 changes: 16 additions & 13 deletions dgraph/cmd/alpha/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,15 @@ type mutationResponse struct {
keys []string
preds []string
startTs uint64
hash string
data json.RawMessage
cost string
}

func mutationWithTs(m, t string, isJson bool, commitNow bool, ts uint64) (
mutationResponse, error) {

params := make([]string, 2)
params := make([]string, 0, 2)
if ts != 0 {
params = append(params, "startTs="+strconv.FormatUint(ts, 10))
}
Expand All @@ -219,7 +220,6 @@ func mutationWithTs(m, t string, isJson bool, commitNow bool, ts uint64) (
if commitNow {
params = append(params, "commitNow=true")
}

url := addr + "/mutate?" + strings.Join(params, "&")
_, body, resp, err := runWithRetriesForResp("POST", t, url, m)
if err != nil {
Expand All @@ -235,6 +235,7 @@ func mutationWithTs(m, t string, isJson bool, commitNow bool, ts uint64) (
mr.keys = r.Extensions.Txn.Keys
mr.preds = r.Extensions.Txn.Preds
mr.startTs = r.Extensions.Txn.StartTs
mr.hash = r.Extensions.Txn.Hash
sort.Strings(mr.preds)

var d map[string]interface{}
Expand Down Expand Up @@ -322,22 +323,23 @@ func runWithRetriesForResp(method, contentType, url string, body string) (
return qr, respBody, resp, err
}

func commitWithTs(keys, preds []string, ts uint64, abort bool) error {
func commitWithTs(mr mutationResponse, abort bool) error {
url := addr + "/commit"
if ts != 0 {
url += "?startTs=" + strconv.FormatUint(ts, 10)
if mr.startTs != 0 {
url += "?startTs=" + strconv.FormatUint(mr.startTs, 10)
url += "&hash=" + mr.hash
}
if abort {
if ts != 0 {
if mr.startTs != 0 {
url += "&abort=true"
} else {
url += "?abort=true"
}
}

m := make(map[string]interface{})
m["keys"] = keys
m["preds"] = preds
m["keys"] = mr.keys
m["preds"] = mr.preds
b, err := json.Marshal(m)
if err != nil {
return err
Expand All @@ -350,10 +352,11 @@ func commitWithTs(keys, preds []string, ts uint64, abort bool) error {
return err
}

func commitWithTsKeysOnly(keys []string, ts uint64) error {
func commitWithTsKeysOnly(keys []string, ts uint64, hash string) error {
url := addr + "/commit"
if ts != 0 {
url += "?startTs=" + strconv.FormatUint(ts, 10)
url += "&hash=" + hash
}

b, err := json.Marshal(keys)
Expand Down Expand Up @@ -418,7 +421,7 @@ func TestTransactionBasic(t *testing.T) {
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)

// Commit and query.
require.NoError(t, commitWithTs(mr.keys, mr.preds, ts, false))
require.NoError(t, commitWithTs(mr, false))
data, _, err = queryWithTs(q1, "application/dql", "", 0)
require.NoError(t, err)
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)
Expand Down Expand Up @@ -464,7 +467,7 @@ func TestTransactionBasicNoPreds(t *testing.T) {
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)

// Commit and query.
require.NoError(t, commitWithTs(mr.keys, nil, ts, false))
require.NoError(t, commitWithTs(mr, false))
data, _, err = queryWithTs(q1, "application/dql", "", 0)
require.NoError(t, err)
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)
Expand Down Expand Up @@ -550,13 +553,13 @@ func TestTransactionBasicOldCommitFormat(t *testing.T) {
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)

// Commit (using a list of keys instead of a map) and query.
require.NoError(t, commitWithTsKeysOnly(mr.keys, ts))
require.NoError(t, commitWithTsKeysOnly(mr.keys, ts, mr.hash))
data, _, err = queryWithTs(q1, "application/dql", "", 0)
require.NoError(t, err)
require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data)

// Aborting a transaction
url := fmt.Sprintf("%s/commit?startTs=%d&abort=true", addr, ts)
url := fmt.Sprintf("%s/commit?startTs=%d&abort=true&hash=%s", addr, ts, mr.hash)
req, err := http.NewRequest("POST", url, nil)
require.NoError(t, err)
_, _, _, err = runRequest(req)
Expand Down
16 changes: 8 additions & 8 deletions dgraph/cmd/alpha/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ func TestMetricTxnCommits(t *testing.T) {
// first normal commit
mr, err := mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr.keys, mr.preds, mr.startTs, false))
require.NoError(t, commitWithTs(mr, false))

metrics := fetchMetrics(t, metricName)

// second normal commit
mr, err = mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr.keys, mr.preds, mr.startTs, false))
require.NoError(t, commitWithTs(mr, false))

require.NoError(t, retryableFetchMetrics(t, map[string]int{
metricName: metrics[metricName] + 1,
Expand All @@ -68,14 +68,14 @@ func TestMetricTxnDiscards(t *testing.T) {
// first normal commit
mr, err := mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr.keys, mr.preds, mr.startTs, false))
require.NoError(t, commitWithTs(mr, false))

metrics := fetchMetrics(t, metricName)

// second commit discarded
mr, err = mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr.keys, mr.preds, mr.startTs, true))
require.NoError(t, commitWithTs(mr, true))

require.NoError(t, retryableFetchMetrics(t, map[string]int{
metricName: metrics[metricName] + 1,
Expand All @@ -96,17 +96,17 @@ func TestMetricTxnAborts(t *testing.T) {
require.NoError(t, err)
mr2, err := mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr1.keys, mr1.preds, mr1.startTs, false))
require.Error(t, commitWithTs(mr2.keys, mr2.preds, mr2.startTs, false))
require.NoError(t, commitWithTs(mr1, false))
require.Error(t, commitWithTs(mr2, false))

metrics := fetchMetrics(t, metricName)

mr1, err = mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
mr2, err = mutationWithTs(mt, "application/rdf", false, false, 0)
require.NoError(t, err)
require.NoError(t, commitWithTs(mr1.keys, mr1.preds, mr1.startTs, false))
require.Error(t, commitWithTs(mr2.keys, mr2.preds, mr2.startTs, false))
require.NoError(t, commitWithTs(mr1, false))
require.Error(t, commitWithTs(mr2, false))

require.NoError(t, retryableFetchMetrics(t, map[string]int{
metricName: metrics[metricName] + 1,
Expand Down
52 changes: 27 additions & 25 deletions edgraph/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package edgraph
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"math"
Expand Down Expand Up @@ -1261,6 +1263,14 @@ func (s *Server) doQuery(ctx context.Context, req *Request) (
EncodingNs: uint64(l.Json.Nanoseconds()),
TotalNs: uint64((time.Since(l.Start)).Nanoseconds()),
}
if x.WorkerConfig.AclEnabled {
// attach the hash, user should send this hash when further operating on this startTs.
ns, err := x.ExtractNamespace(ctx)
if err != nil {
return nil, err
}
resp.Txn.Hash = getHash(ns, resp.Txn.StartTs)
}
md := metadata.Pairs(x.DgraphCostHeader, fmt.Sprint(resp.Metrics.NumUids["_total"]))
grpc.SendHeader(ctx, md)
return resp, gqlErrs
Expand Down Expand Up @@ -1471,30 +1481,24 @@ func authorizeRequest(ctx context.Context, qc *queryContext) error {
return nil
}

func validateNamespace(ctx context.Context, preds []string) error {
func getHash(ns, startTs uint64) string {
h := sha256.New()
h.Write([]byte(fmt.Sprintf("%#x%#x%s", ns, startTs, x.WorkerConfig.HmacSecret)))
return hex.EncodeToString(h.Sum(nil))
}

func validateNamespace(ctx context.Context, tc *api.TxnContext) error {
if !x.WorkerConfig.AclEnabled {
return nil
}

ns, err := x.ExtractJWTNamespace(ctx)
if err != nil {
return err
}

// Do a basic validation that all the predicates passed in transaction context matches the
// claimed namespace and user is not accidently commiting a transaction that it did not create.
for _, pred := range preds {
// Format for Preds in TxnContext is gid-<namespace><pred> (see fillPreds in posting pkg)
splits := strings.Split(pred, "-")
if len(splits) < 2 {
return errors.Errorf("Unable to find group id in %s", pred)
}
pred = strings.Join(splits[1:], "-")
if len(pred) < 8 {
return errors.Errorf("found invalid pred %s of length < 8 in transaction context", pred)
}
if parsedNs := x.ParseNamespace(pred); parsedNs != ns {
return errors.Errorf("Please login into correct namespace. "+
"Currently logged in namespace %#x", ns)
}
if tc.Hash != getHash(ns, tc.StartTs) {
return errors.Errorf("hash mismatch the claimed startTs|namespace")
}

return nil
}

Expand All @@ -1507,19 +1511,17 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
return &api.TxnContext{}, err
}

if x.WorkerConfig.AclEnabled {
if err := validateNamespace(ctx, tc.Preds); err != nil {
return &api.TxnContext{}, err
}
}

tctx := &api.TxnContext{}
if tc.StartTs == 0 {
return &api.TxnContext{}, errors.Errorf(
"StartTs cannot be zero while committing a transaction")
}
annotateStartTs(span, tc.StartTs)

if err := validateNamespace(ctx, tc); err != nil {
return &api.TxnContext{}, err
}

span.Annotatef(nil, "Txn Context received: %+v", tc)
commitTs, err := worker.CommitOverNetwork(ctx, tc)
if err == dgo.ErrAborted {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/blevesearch/bleve v1.0.13
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd
github.com/dgraph-io/badger/v3 v3.0.0-20210309075542-2245c18dfd1f
github.com/dgraph-io/dgo/v200 v200.0.0-20210212152539-e0a5bde40ba2
github.com/dgraph-io/dgo/v200 v200.0.0-20210331134112-3dd0035583a4
github.com/dgraph-io/gqlgen v0.13.2
github.com/dgraph-io/gqlparser/v2 v2.2.0
github.com/dgraph-io/graphql-transport-ws v0.0.0-20210223074046-e5b8b80bb4ed
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Ev
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
github.com/dgraph-io/badger/v3 v3.0.0-20210309075542-2245c18dfd1f h1:dZpGNLp9YUpq4h2DRcWAjW5dWj47SM3W3NK71z6FRa0=
github.com/dgraph-io/badger/v3 v3.0.0-20210309075542-2245c18dfd1f/go.mod h1:GHMCYxuDWyzbHkh4k3yyg4PM61tJPFfEGSMbE3Vd5QE=
github.com/dgraph-io/dgo/v200 v200.0.0-20210212152539-e0a5bde40ba2 h1:3STgJCaLdsBynA0FDLqocwk/OdlnrQ5MGbR74C4NvUs=
github.com/dgraph-io/dgo/v200 v200.0.0-20210212152539-e0a5bde40ba2/go.mod h1:zCfS4R3E/UC/PhETXJYq/Blia0eCH1EQqKrWDvvimxE=
github.com/dgraph-io/dgo/v200 v200.0.0-20210331134112-3dd0035583a4 h1:e4IfJ6Ut//KS9dN293rxnAQq3wvjcjFXhjlAQQ/SPFk=
github.com/dgraph-io/dgo/v200 v200.0.0-20210331134112-3dd0035583a4/go.mod h1:zCfS4R3E/UC/PhETXJYq/Blia0eCH1EQqKrWDvvimxE=
github.com/dgraph-io/gqlgen v0.13.2 h1:TNhndk+eHKj5qE7BenKKSYdSIdOGhLqxR1rCiMso9KM=
github.com/dgraph-io/gqlgen v0.13.2/go.mod h1:iCOrOv9lngN7KAo+jMgvUPVDlYHdf7qDwsTkQby2Sis=
github.com/dgraph-io/gqlparser/v2 v2.1.1/go.mod h1:MYS4jppjyx8b9tuUtjV7jU1UFZK6P9fvO8TsIsQtRKU=
Expand Down