diff --git a/go.mod b/go.mod index ab0ed14..c7fac60 100644 --- a/go.mod +++ b/go.mod @@ -16,4 +16,5 @@ require ( github.com/multiformats/go-multiaddr v0.2.2 github.com/multiformats/go-multihash v0.0.14 github.com/multiformats/go-varint v0.0.6 + github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index fe9af52..8913ef7 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,9 @@ github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtE github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495 h1:6IyqGr3fnd0tM3YxipK27TUskaOVUjU2nG45yzwcQKY= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= @@ -58,10 +59,15 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 h1:RC6RW7j+1+HkWaX/Yh71Ee5ZHaHYt7ZP4sQgUrm6cDU= github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572/go.mod h1:w0SWMsp6j9O/dk4/ZpIhL+3CkG8ofA2vuv7k+ltqUMc= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU= @@ -82,3 +88,5 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/network/context.go b/network/context.go index 01f3177..7fabfb5 100644 --- a/network/context.go +++ b/network/context.go @@ -14,12 +14,13 @@ type noDialCtxKey struct{} type dialPeerTimeoutCtxKey struct{} type forceDirectDialCtxKey struct{} type useTransientCtxKey struct{} -type simConnectCtxKey struct{} +type simConnectCtxKey struct{ isClient bool } var noDial = noDialCtxKey{} var forceDirectDial = forceDirectDialCtxKey{} var useTransient = useTransientCtxKey{} -var simConnect = simConnectCtxKey{} +var simConnectIsServer = simConnectCtxKey{} +var simConnectIsClient = simConnectCtxKey{isClient: true} // EXPERIMENTAL // WithForceDirectDial constructs a new context with an option that instructs the network @@ -39,22 +40,26 @@ func GetForceDirectDial(ctx context.Context) (forceDirect bool, reason string) { return false, "" } -// EXPERIMENTAL // WithSimultaneousConnect constructs a new context with an option that instructs the transport // to apply hole punching logic where applicable. -func WithSimultaneousConnect(ctx context.Context, reason string) context.Context { - return context.WithValue(ctx, simConnect, reason) +// EXPERIMENTAL +func WithSimultaneousConnect(ctx context.Context, isClient bool, reason string) context.Context { + if isClient { + return context.WithValue(ctx, simConnectIsClient, reason) + } + return context.WithValue(ctx, simConnectIsServer, reason) } -// EXPERIMENTAL // GetSimultaneousConnect returns true if the simultaneous connect option is set in the context. -func GetSimultaneousConnect(ctx context.Context) (simconnect bool, reason string) { - v := ctx.Value(simConnect) - if v != nil { - return true, v.(string) +// EXPERIMENTAL +func GetSimultaneousConnect(ctx context.Context) (simconnect bool, isClient bool, reason string) { + if v := ctx.Value(simConnectIsClient); v != nil { + return true, true, v.(string) } - - return false, "" + if v := ctx.Value(simConnectIsServer); v != nil { + return true, false, v.(string) + } + return false, false, "" } // WithNoDial constructs a new context with an option that instructs the network diff --git a/network/context_test.go b/network/context_test.go index 0912551..b12def5 100644 --- a/network/context_test.go +++ b/network/context_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestDefaultTimeout(t *testing.T) { @@ -38,3 +40,20 @@ func TestSettingTimeout(t *testing.T) { t.Fatal("peer timeout doesn't match set timeout") } } + +func TestSimultaneousConnect(t *testing.T) { + t.Run("for the server", func(t *testing.T) { + serverCtx := WithSimultaneousConnect(context.Background(), false, "foobar") + ok, isClient, reason := GetSimultaneousConnect(serverCtx) + require.True(t, ok) + require.False(t, isClient) + require.Equal(t, reason, "foobar") + }) + t.Run("for the client", func(t *testing.T) { + serverCtx := WithSimultaneousConnect(context.Background(), true, "foo") + ok, isClient, reason := GetSimultaneousConnect(serverCtx) + require.True(t, ok) + require.True(t, isClient) + require.Equal(t, reason, "foo") + }) +}