diff --git a/dmsgget/dmsgget.go b/dmsgget/dmsgget.go index 8df7a9a6f..f37080761 100644 --- a/dmsgget/dmsgget.go +++ b/dmsgget/dmsgget.go @@ -110,7 +110,7 @@ func (dg *DmsgGet) Run(ctx context.Context, log *logging.Logger, skStr string, a } defer closeDmsg() - httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(dmsgC)} + httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(ctx, dmsgC)} for i := 0; i < dg.dlF.Tries; i++ { log.Infof("Download attempt %d/%d ...", i, dg.dlF.Tries) diff --git a/dmsgget/dmsgget_test.go b/dmsgget/dmsgget_test.go index 5139f778d..a841b282d 100644 --- a/dmsgget/dmsgget_test.go +++ b/dmsgget/dmsgget_test.go @@ -172,5 +172,8 @@ func newHTTPClient(t *testing.T, dc disc.APIClient) *http.Client { t.Cleanup(func() { assert.NoError(t, dmsgC.Close()) }) <-dmsgC.Ready() - return &http.Client{Transport: dmsghttp.MakeHTTPTransport(dmsgC)} + log := logging.MustGetLogger(fmt.Sprintf("http_client")) + ctx, cancel := cmdutil.SignalContext(context.Background(), log) + defer cancel() + return &http.Client{Transport: dmsghttp.MakeHTTPTransport(ctx, dmsgC)} } diff --git a/dmsghttp/examples_test.go b/dmsghttp/examples_test.go index aed628778..d8c47a9d2 100644 --- a/dmsghttp/examples_test.go +++ b/dmsghttp/examples_test.go @@ -12,8 +12,10 @@ import ( "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/cmdutil" "github.com/skycoin/dmsg/disc" "github.com/skycoin/dmsg/dmsghttp" + "github.com/skycoin/skycoin/src/util/logging" ) func ExampleMakeHTTPTransport() { @@ -87,8 +89,11 @@ func ExampleMakeHTTPTransport() { go dmsgC2.Serve(context.Background()) <-dmsgC2.Ready() + log := logging.MustGetLogger(fmt.Sprintf("http_client")) + ctx, cancel := cmdutil.SignalContext(context.Background(), log) + defer cancel() // Run HTTP client. - httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(dmsgC2)} + httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(ctx, dmsgC2)} resp, err := httpC.Get(fmt.Sprintf("http://%s:%d/", c1PK.String(), dmsgHTTPPort)) if err != nil { panic(err) diff --git a/dmsghttp/http_transport.go b/dmsghttp/http_transport.go index 4ae90ac98..331a16282 100644 --- a/dmsghttp/http_transport.go +++ b/dmsghttp/http_transport.go @@ -2,8 +2,10 @@ package dmsghttp import ( "bufio" + "context" "fmt" "net/http" + "time" "github.com/skycoin/dmsg" ) @@ -13,12 +15,16 @@ const defaultHTTPPort = uint16(80) // HTTPTransport implements http.RoundTripper // Do not confuse this with a Skywire Transport implementation. type HTTPTransport struct { + ctx context.Context dmsgC *dmsg.Client } // MakeHTTPTransport makes an HTTPTransport. -func MakeHTTPTransport(dmsgC *dmsg.Client) HTTPTransport { - return HTTPTransport{dmsgC: dmsgC} +func MakeHTTPTransport(ctx context.Context, dmsgC *dmsg.Client) HTTPTransport { + return HTTPTransport{ + ctx: ctx, + dmsgC: dmsgC, + } } // RoundTrip implements golang's http package support for alternative HTTP transport protocols. @@ -39,10 +45,42 @@ func (t HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, err } - if err := req.Write(stream); err != nil { return nil, err } bufR := bufio.NewReader(stream) - return http.ReadResponse(bufR, req) + resp, err := http.ReadResponse(bufR, req) + if err != nil { + return nil, err + } + + defer func() { + go test(t.ctx, resp, stream) + }() + + return resp, nil +} + +func test(ctx context.Context, resp *http.Response, stream *dmsg.Stream) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _, err := resp.Body.Read(nil) + log := stream.Logger() + log.Errorf("err %v", err) + if err == nil { + // can still read from body so it's not closed + + } else if err != nil && err.Error() == "http: invalid Read on closed Body" { + stream.Close() + return + } + } + } + } diff --git a/dmsghttp/http_transport_test.go b/dmsghttp/http_transport_test.go index 065736f82..e43254005 100644 --- a/dmsghttp/http_transport_test.go +++ b/dmsghttp/http_transport_test.go @@ -15,6 +15,7 @@ import ( "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/cmdutil" "github.com/skycoin/dmsg/disc" ) @@ -63,10 +64,13 @@ func TestHTTPTransport_RoundTrip(t *testing.T) { startHTTPServer(t, server0Results, lis) addr := lis.Addr().String() + log := logging.MustGetLogger(fmt.Sprintf("http_client")) + ctx, cancel := cmdutil.SignalContext(context.Background(), log) + defer cancel() // Arrange: create http clients (in which each http client has an underlying dmsg client). - httpC1 := http.Client{Transport: MakeHTTPTransport(newDmsgClient(t, dc, minSessions, "client1"))} - httpC2 := http.Client{Transport: MakeHTTPTransport(newDmsgClient(t, dc, minSessions, "client2"))} - httpC3 := http.Client{Transport: MakeHTTPTransport(newDmsgClient(t, dc, minSessions, "client3"))} + httpC1 := http.Client{Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client1"))} + httpC2 := http.Client{Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client2"))} + httpC3 := http.Client{Transport: MakeHTTPTransport(ctx, newDmsgClient(t, dc, minSessions, "client3"))} httpC1.Timeout = timeout httpC2.Timeout = timeout httpC3.Timeout = timeout