diff --git a/clientv3/concurrency/example_stm_test.go b/clientv3/concurrency/example_stm_test.go index d49862c7db5..54c871d364c 100644 --- a/clientv3/concurrency/example_stm_test.go +++ b/clientv3/concurrency/example_stm_test.go @@ -58,7 +58,7 @@ func ExampleSTM_apply() { // transfer amount xfer := fromInt / 2 - fromInt, toInt = fromInt-xfer, toInt-xfer + fromInt, toInt = fromInt-xfer, toInt+xfer // writeback stm.Put(fromK, fmt.Sprintf("%d", fromInt)) diff --git a/clientv3/concurrency/stm.go b/clientv3/concurrency/stm.go index a9e69e8021a..9b6576f6839 100644 --- a/clientv3/concurrency/stm.go +++ b/clientv3/concurrency/stm.go @@ -193,11 +193,12 @@ func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { } } +// first returns the store revision from the first fetch func (rs readSet) first() int64 { ret := int64(math.MaxInt64 - 1) for _, resp := range rs { - if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret { - ret = resp.Kvs[0].ModRevision + if rev := resp.Header.Revision; rev < ret { + ret = rev } } return ret diff --git a/integration/v3_stm_test.go b/integration/v3_stm_test.go index 057bfb88abd..7965b3c2c9c 100644 --- a/integration/v3_stm_test.go +++ b/integration/v3_stm_test.go @@ -15,6 +15,7 @@ package integration import ( + "context" "fmt" "math/rand" "strconv" @@ -22,7 +23,7 @@ import ( v3 "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" - "golang.org/x/net/context" + "github.com/coreos/etcd/pkg/testutil" ) // TestSTMConflict tests that conflicts are retried. @@ -253,3 +254,36 @@ func TestSTMApplyOnConcurrentDeletion(t *testing.T) { t.Fatalf("bad value. got %+v, expected 'bar2' value", resp) } } + +func TestSTMSerializableSnapshotPut(t *testing.T) { + clus := NewClusterV3(t, &ClusterConfig{Size: 1}) + defer clus.Terminate(t) + + cli := clus.Client(0) + // key with lower create/mod revision than keys being updated + _, err := cli.Put(context.TODO(), "a", "0") + testutil.AssertNil(t, err) + + tries := 0 + applyf := func(stm concurrency.STM) error { + if tries > 2 { + return fmt.Errorf("too many retries") + } + tries++ + stm.Get("a") + stm.Put("b", "1") + return nil + } + + iso := concurrency.WithIsolation(concurrency.SerializableSnapshot) + _, err = concurrency.NewSTM(cli, applyf, iso) + testutil.AssertNil(t, err) + _, err = concurrency.NewSTM(cli, applyf, iso) + testutil.AssertNil(t, err) + + resp, err := cli.Get(context.TODO(), "b") + testutil.AssertNil(t, err) + if resp.Kvs[0].Version != 2 { + t.Fatalf("bad version. got %+v, expected version 2", resp) + } +}