Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: nested txns #8102

Merged
merged 8 commits into from
Jun 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Documentation/dev-guide/api_reference_v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ Empty field.
| request_range | | RangeRequest |
| request_put | | PutRequest |
| request_delete_range | | DeleteRangeRequest |
| request_txn | | TxnRequest |



Expand All @@ -691,6 +692,7 @@ Empty field.
| response_range | | RangeResponse |
| response_put | | PutResponse |
| response_delete_range | | DeleteRangeResponse |
| response_txn | | TxnResponse |



Expand Down
6 changes: 6 additions & 0 deletions Documentation/dev-guide/apispec/swagger/rpc.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,9 @@
},
"request_range": {
"$ref": "#/definitions/etcdserverpbRangeRequest"
},
"request_txn": {
"$ref": "#/definitions/etcdserverpbTxnRequest"
}
}
},
Expand Down Expand Up @@ -1993,6 +1996,9 @@
},
"response_range": {
"$ref": "#/definitions/etcdserverpbRangeResponse"
},
"response_txn": {
"$ref": "#/definitions/etcdserverpbTxnResponse"
}
}
},
Expand Down
84 changes: 65 additions & 19 deletions clientv3/integration/txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,35 @@ func TestTxnReadRetry(t *testing.T) {
defer clus.Terminate(t)

kv := clus.Client(0)
clus.Members[0].Stop(t)
<-clus.Members[0].StopNotify()

donec := make(chan struct{})
go func() {
ctx := context.TODO()
_, err := kv.Txn(ctx).Then(clientv3.OpGet("foo")).Commit()
if err != nil {
t.Fatalf("expected response, got error %v", err)
thenOps := [][]clientv3.Op{
{clientv3.OpGet("foo")},
{clientv3.OpTxn(nil, []clientv3.Op{clientv3.OpGet("foo")}, nil)},
{clientv3.OpTxn(nil, nil, nil)},
{},
}
for i := range thenOps {
clus.Members[0].Stop(t)
<-clus.Members[0].StopNotify()

donec := make(chan struct{})
go func() {
_, err := kv.Txn(context.TODO()).Then(thenOps[i]...).Commit()
if err != nil {
t.Fatalf("expected response, got error %v", err)
}
donec <- struct{}{}
}()
// wait for txn to fail on disconnect
time.Sleep(100 * time.Millisecond)

// restart node; client should resume
clus.Members[0].Restart(t)
select {
case <-donec:
case <-time.After(2 * clus.Members[1].ServerConfig.ReqTimeout()):
t.Fatalf("waited too long")
}
donec <- struct{}{}
}()
// wait for txn to fail on disconnect
time.Sleep(100 * time.Millisecond)

// restart node; client should resume
clus.Members[0].Restart(t)
select {
case <-donec:
case <-time.After(2 * clus.Members[1].ServerConfig.ReqTimeout()):
t.Fatalf("waited too long")
}
}

Expand Down Expand Up @@ -179,3 +187,41 @@ func TestTxnCompareRange(t *testing.T) {
t.Fatal("expected prefix compare to false, got compares as true")
}
}

func TestTxnNested(t *testing.T) {
defer testutil.AfterTest(t)

clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
defer clus.Terminate(t)

kv := clus.Client(0)

tresp, err := kv.Txn(context.TODO()).
If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
Then(
clientv3.OpPut("foo", "bar"),
clientv3.OpTxn(nil, []clientv3.Op{clientv3.OpPut("abc", "123")}, nil)).
Else(clientv3.OpPut("foo", "baz")).Commit()
if err != nil {
t.Fatal(err)
}
if len(tresp.Responses) != 2 {
t.Errorf("expected 2 top-level txn responses, got %+v", tresp.Responses)
}

// check txn writes were applied
resp, err := kv.Get(context.TODO(), "foo")
if err != nil {
t.Fatal(err)
}
if len(resp.Kvs) != 1 || string(resp.Kvs[0].Value) != "bar" {
t.Errorf("unexpected Get response %+v", resp)
}
resp, err = kv.Get(context.TODO(), "abc")
if err != nil {
t.Fatal(err)
}
if len(resp.Kvs) != 1 || string(resp.Kvs[0].Value) != "123" {
t.Errorf("unexpected Get response %+v", resp)
}
}
9 changes: 8 additions & 1 deletion clientv3/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ type OpResponse struct {
put *PutResponse
get *GetResponse
del *DeleteResponse
txn *TxnResponse
}

func (op OpResponse) Put() *PutResponse { return op.put }
func (op OpResponse) Get() *GetResponse { return op.get }
func (op OpResponse) Del() *DeleteResponse { return op.del }
func (op OpResponse) Txn() *TxnResponse { return op.txn }

type kv struct {
remote pb.KVClient
Expand Down Expand Up @@ -134,7 +136,6 @@ func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) {
func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
var err error
switch op.t {
// TODO: handle other ops
case tRange:
var resp *pb.RangeResponse
resp, err = kv.remote.Range(ctx, op.toRangeRequest(), grpc.FailFast(false))
Expand All @@ -155,6 +156,12 @@ func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
if err == nil {
return OpResponse{del: (*DeleteResponse)(resp)}, nil
}
case tTxn:
var resp *pb.TxnResponse
resp, err = kv.remote.Txn(ctx, op.toTxnRequest())
if err == nil {
return OpResponse{txn: (*TxnResponse)(resp)}, nil
}
default:
panic("Unknown op")
}
Expand Down
64 changes: 39 additions & 25 deletions clientv3/namespace/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (kv *kvPrefix) Delete(ctx context.Context, key string, opts ...clientv3.OpO
}

func (kv *kvPrefix) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) {
if len(op.KeyBytes()) == 0 {
if len(op.KeyBytes()) == 0 && !op.IsTxn() {
return clientv3.OpResponse{}, rpctypes.ErrEmptyKey
}
r, err := kv.KV.Do(ctx, kv.prefixOp(op))
Expand All @@ -88,6 +88,8 @@ func (kv *kvPrefix) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse
kv.unprefixPutResponse(r.Put())
case r.Del() != nil:
kv.unprefixDeleteResponse(r.Del())
case r.Txn() != nil:
kv.unprefixTxnResponse(r.Txn())
}
return r, nil
}
Expand All @@ -102,34 +104,17 @@ func (kv *kvPrefix) Txn(ctx context.Context) clientv3.Txn {
}

func (txn *txnPrefix) If(cs ...clientv3.Cmp) clientv3.Txn {
newCmps := make([]clientv3.Cmp, len(cs))
for i := range cs {
newCmps[i] = cs[i]
pfxKey, endKey := txn.kv.prefixInterval(cs[i].KeyBytes(), cs[i].RangeEnd)
newCmps[i].WithKeyBytes(pfxKey)
if len(cs[i].RangeEnd) != 0 {
newCmps[i].RangeEnd = endKey
}
}
txn.Txn = txn.Txn.If(newCmps...)
txn.Txn = txn.Txn.If(txn.kv.prefixCmps(cs)...)
return txn
}

func (txn *txnPrefix) Then(ops ...clientv3.Op) clientv3.Txn {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = txn.kv.prefixOp(ops[i])
}
txn.Txn = txn.Txn.Then(newOps...)
txn.Txn = txn.Txn.Then(txn.kv.prefixOps(ops)...)
return txn
}

func (txn *txnPrefix) Else(ops ...clientv3.Op) clientv3.Txn {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = txn.kv.prefixOp(ops[i])
}
txn.Txn = txn.Txn.Else(newOps...)
txn.Txn = txn.Txn.Else(txn.kv.prefixOps(ops)...)
return txn
}

Expand All @@ -143,10 +128,14 @@ func (txn *txnPrefix) Commit() (*clientv3.TxnResponse, error) {
}

func (kv *kvPrefix) prefixOp(op clientv3.Op) clientv3.Op {
begin, end := kv.prefixInterval(op.KeyBytes(), op.RangeBytes())
op.WithKeyBytes(begin)
op.WithRangeBytes(end)
return op
if !op.IsTxn() {
begin, end := kv.prefixInterval(op.KeyBytes(), op.RangeBytes())
op.WithKeyBytes(begin)
op.WithRangeBytes(end)
return op
}
cmps, thenOps, elseOps := op.Txn()
return clientv3.OpTxn(kv.prefixCmps(cmps), kv.prefixOps(thenOps), kv.prefixOps(elseOps))
}

func (kv *kvPrefix) unprefixGetResponse(resp *clientv3.GetResponse) {
Expand Down Expand Up @@ -182,6 +171,10 @@ func (kv *kvPrefix) unprefixTxnResponse(resp *clientv3.TxnResponse) {
if tv.ResponseDeleteRange != nil {
kv.unprefixDeleteResponse((*clientv3.DeleteResponse)(tv.ResponseDeleteRange))
}
case *pb.ResponseOp_ResponseTxn:
if tv.ResponseTxn != nil {
kv.unprefixTxnResponse((*clientv3.TxnResponse)(tv.ResponseTxn))
}
default:
}
}
Expand All @@ -190,3 +183,24 @@ func (kv *kvPrefix) unprefixTxnResponse(resp *clientv3.TxnResponse) {
func (p *kvPrefix) prefixInterval(key, end []byte) (pfxKey []byte, pfxEnd []byte) {
return prefixInterval(p.pfx, key, end)
}

func (kv *kvPrefix) prefixCmps(cs []clientv3.Cmp) []clientv3.Cmp {
newCmps := make([]clientv3.Cmp, len(cs))
for i := range cs {
newCmps[i] = cs[i]
pfxKey, endKey := kv.prefixInterval(cs[i].KeyBytes(), cs[i].RangeEnd)
newCmps[i].WithKeyBytes(pfxKey)
if len(cs[i].RangeEnd) != 0 {
newCmps[i].RangeEnd = endKey
}
}
return newCmps
}

func (kv *kvPrefix) prefixOps(ops []clientv3.Op) []clientv3.Op {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = kv.prefixOp(ops[i])
}
return newOps
}
44 changes: 44 additions & 0 deletions clientv3/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
tRange opType = iota + 1
tPut
tDeleteRange
tTxn
)

var (
Expand Down Expand Up @@ -67,10 +68,18 @@ type Op struct {
// for put
val []byte
leaseID LeaseID

// txn
cmps []Cmp
thenOps []Op
elseOps []Op
}

// accesors / mutators

func (op Op) IsTxn() bool { return op.t == tTxn }
func (op Op) Txn() ([]Cmp, []Op, []Op) { return op.cmps, op.thenOps, op.elseOps }

// KeyBytes returns the byte slice holding the Op's key.
func (op Op) KeyBytes() []byte { return op.key }

Expand Down Expand Up @@ -113,6 +122,22 @@ func (op Op) toRangeRequest() *pb.RangeRequest {
return r
}

func (op Op) toTxnRequest() *pb.TxnRequest {
thenOps := make([]*pb.RequestOp, len(op.thenOps))
for i, tOp := range op.thenOps {
thenOps[i] = tOp.toRequestOp()
}
elseOps := make([]*pb.RequestOp, len(op.elseOps))
for i, eOp := range op.elseOps {
elseOps[i] = eOp.toRequestOp()
}
cmps := make([]*pb.Compare, len(op.cmps))
for i := range op.cmps {
cmps[i] = (*pb.Compare)(&op.cmps[i])
}
return &pb.TxnRequest{Compare: cmps, Success: thenOps, Failure: elseOps}
}

func (op Op) toRequestOp() *pb.RequestOp {
switch op.t {
case tRange:
Expand All @@ -123,12 +148,27 @@ func (op Op) toRequestOp() *pb.RequestOp {
case tDeleteRange:
r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end, PrevKv: op.prevKV}
return &pb.RequestOp{Request: &pb.RequestOp_RequestDeleteRange{RequestDeleteRange: r}}
case tTxn:
return &pb.RequestOp{Request: &pb.RequestOp_RequestTxn{RequestTxn: op.toTxnRequest()}}
default:
panic("Unknown Op")
}
}

func (op Op) isWrite() bool {
if op.t == tTxn {
for _, tOp := range op.thenOps {
if tOp.isWrite() {
return true
}
}
for _, tOp := range op.elseOps {
if tOp.isWrite() {
return true
}
}
return false
}
return op.t != tRange
}

Expand Down Expand Up @@ -194,6 +234,10 @@ func OpPut(key, val string, opts ...OpOption) Op {
return ret
}

func OpTxn(cmps []Cmp, thenOps []Op, elseOps []Op) Op {
return Op{t: tTxn, cmps: cmps, thenOps: thenOps, elseOps: elseOps}
}

func opWatch(key string, opts ...OpOption) Op {
ret := Op{t: tRange, key: []byte(key)}
ret.applyOpts(opts)
Expand Down
Loading