Skip to content

Commit

Permalink
Add PipelineClient name (valyala#994)
Browse files Browse the repository at this point in the history
* Improve documentation about DelClientCookie which related with valyala#951.

* Add pipeline name
  • Loading branch information
kiyonlin authored Mar 15, 2021
1 parent 1a7995b commit 02e0722
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
37 changes: 37 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2120,6 +2120,13 @@ type PipelineClient struct {
// Address of the host to connect to.
Addr string

// PipelineClient name. Used in User-Agent request header.
Name string

// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool

// The maximum number of concurrent connections to the Addr.
//
// A single connection is used by default.
Expand Down Expand Up @@ -2225,6 +2232,8 @@ type pipelineConnClient struct {
noCopy noCopy //nolint:unused,structcheck

Addr string
Name string
NoDefaultUserAgentHeader bool
MaxPendingRequests int
MaxBatchDelay time.Duration
Dial DialFunc
Expand All @@ -2248,6 +2257,7 @@ type pipelineConnClient struct {

tlsConfigLock sync.Mutex
tlsConfig *tls.Config
clientName atomic.Value
}

type pipelineWork struct {
Expand Down Expand Up @@ -2316,6 +2326,11 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
req.URI().DisablePathNormalizing = true
}

userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}

w := acquirePipelineWork(&c.workPool, timeout)
w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.req = &w.reqCopy
Expand Down Expand Up @@ -2380,6 +2395,11 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
req.URI().DisablePathNormalizing = true
}

userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}

w := acquirePipelineWork(&c.workPool, 0)
w.req = req
if resp != nil {
Expand Down Expand Up @@ -2459,6 +2479,8 @@ func (c *PipelineClient) getConnClientUnlocked() *pipelineConnClient {
func (c *PipelineClient) newConnClient() *pipelineConnClient {
cc := &pipelineConnClient{
Addr: c.Addr,
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
MaxPendingRequests: c.MaxPendingRequests,
MaxBatchDelay: c.MaxBatchDelay,
Dial: c.Dial,
Expand Down Expand Up @@ -2770,6 +2792,21 @@ func (c *pipelineConnClient) PendingRequests() int {
return n
}

func (c *pipelineConnClient) getClientName() []byte {
v := c.clientName.Load()
var clientName []byte
if v == nil {
clientName = []byte(c.Name)
if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
clientName = defaultUserAgent
}
c.clientName.Store(clientName)
} else {
clientName = v.([]byte)
}
return clientName
}

var errPipelineConnStopped = errors.New("pipeline connection has been stopped")

func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork {
Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ func TestCloseIdleConnections(t *testing.T) {
}
}

func TestPipelineClientSetUserAgent(t *testing.T) {
t.Parallel()

testPipelineClientSetUserAgent(t, 0)
}

func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
t.Parallel()

testPipelineClientSetUserAgent(t, time.Second)
}

func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
ln := fasthttputil.NewInmemoryListener()

userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck

userAgent := "I'm not fasthttp"
c := &HostClient{
Name: userAgent,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()

req.SetRequestURI("http://example.com")

var err error
if timeout <= 0 {
err = c.Do(req, res)
} else {
err = c.DoTimeout(req, res, timeout)
}

if err != nil {
t.Fatal(err)
}
if userAgentSeen != userAgent {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
}
}

func TestPipelineClientIssue832(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 02e0722

Please sign in to comment.