diff --git a/go/arrow/flight/cookie_middleware.go b/go/arrow/flight/cookie_middleware.go index 27754a13b829a..39c86d8303434 100644 --- a/go/arrow/flight/cookie_middleware.go +++ b/go/arrow/flight/cookie_middleware.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "golang.org/x/exp/maps" "google.golang.org/grpc/metadata" ) @@ -40,11 +41,34 @@ func NewClientCookieMiddleware() ClientMiddleware { return CreateClientMiddleware(&clientCookieMiddleware{jar: make(map[string]http.Cookie)}) } +func NewCookieMiddleware() CookieMiddleware { + return &clientCookieMiddleware{jar: make(map[string]http.Cookie)} +} + +// CookieMiddleware is a go-routine safe middleware for flight clients +// which properly handles Set-Cookie headers for storing cookies. +// This can be passed into `CreateClientMiddleware` to create a new +// middleware object. You can also clone it to create middleware for a +// new client which starts with the same cookies. +type CookieMiddleware interface { + CustomClientMiddleware + // Clone creates a new CookieMiddleware that starts out with the same + // cookies that this one already has. This is useful when creating a + // new client connection for the same server. + Clone() CookieMiddleware +} + type clientCookieMiddleware struct { jar map[string]http.Cookie mx sync.Mutex } +func (cc *clientCookieMiddleware) Clone() CookieMiddleware { + cc.mx.Lock() + defer cc.mx.Unlock() + return &clientCookieMiddleware{jar: maps.Clone(cc.jar)} +} + func (cc *clientCookieMiddleware) StartCall(ctx context.Context) context.Context { cc.mx.Lock() defer cc.mx.Unlock() diff --git a/go/arrow/flight/cookie_middleware_test.go b/go/arrow/flight/cookie_middleware_test.go index 0adf4927652d4..4007d056b2c99 100644 --- a/go/arrow/flight/cookie_middleware_test.go +++ b/go/arrow/flight/cookie_middleware_test.go @@ -239,3 +239,63 @@ func TestCookieExpiration(t *testing.T) { cookieMiddleware.expectedCookies = map[string]string{} makeReq(client, t) } + +func TestCookiesClone(t *testing.T) { + cookieMiddleware := &serverAddCookieMiddleware{} + + s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{ + flight.CreateServerMiddleware(cookieMiddleware), + }) + s.Init("localhost:0") + f := &flightServer{} + s.RegisterFlightService(f) + + go s.Serve() + defer s.Shutdown() + + makeReq := func(c flight.Client, t *testing.T) { + flightStream, err := c.ListFlights(context.Background(), &flight.Criteria{}) + assert.NoError(t, err) + + for { + _, err := flightStream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + assert.NoError(t, err) + } + } + } + + credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials()) + cookies := flight.NewCookieMiddleware() + client1, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, + []flight.ClientMiddleware{flight.CreateClientMiddleware(cookies)}, credsOpt) + require.NoError(t, err) + defer client1.Close() + + // set cookies + cookieMiddleware.cookies = []*http.Cookie{ + {Name: "foo", Value: "bar"}, + {Name: "foo2", Value: "bar2", MaxAge: 1}, + } + makeReq(client1, t) + + // validate set + cookieMiddleware.expectedCookies = map[string]string{ + "foo": "bar", "foo2": "bar2", + } + makeReq(client1, t) + + client2, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, + []flight.ClientMiddleware{flight.CreateClientMiddleware(cookies.Clone())}, credsOpt) + require.NoError(t, err) + defer client2.Close() + + // validate clone worked + cookieMiddleware.expectedCookies = map[string]string{ + "foo": "bar", "foo2": "bar2", + } + makeReq(client2, t) +}