diff --git a/async_producer.go b/async_producer.go index 9b15cd192..f05f0018f 100644 --- a/async_producer.go +++ b/async_producer.go @@ -661,13 +661,28 @@ func (p *asyncProducer) newBrokerProducer(broker *Broker) *brokerProducer { go withRecover(func() { for set := range bridge { request := set.buildRequest() + // Capture the current set to use in the callback + sendResponse := func(set *produceSet) ProduceCallback { + return func(response *ProduceResponse, err error) { + responses <- &brokerProducerResponse{ + set: set, + err: err, + res: response, + } + } + }(set) - response, err := broker.Produce(request) - - responses <- &brokerProducerResponse{ - set: set, - err: err, - res: response, + // Use AsyncProduce to not block on waiting for the response so that we can + // pipeline mutliple produce requests and achieve higher throughput, see: + // https://kafka.apache.org/protocol#protocol_network + err := broker.AsyncProduce(request, sendResponse) + if err != nil { + // Request failed to be sent + sendResponse(nil, err) + } + // Callback is not called when using NoResponse + if p.conf.Producer.RequiredAcks == NoResponse { + sendResponse(nil, nil) } } close(responses) diff --git a/broker.go b/broker.go index 9c3e5a04a..f494fcb33 100644 --- a/broker.go +++ b/broker.go @@ -119,6 +119,21 @@ type responsePromise struct { correlationID int32 packets chan []byte errors chan error + handler func([]byte, error) +} + +func (p *responsePromise) handle(packets []byte, err error) { + // Privilegiate callback when provided + if p.handler != nil { + p.handler(packets, err) + return + } + // Over channels + if err != nil { + p.errors <- err + return + } + p.packets <- packets } // NewBroker creates and returns a Broker targeting the given host:port address. @@ -333,6 +348,47 @@ func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, e return response, nil } +// ProduceCallback function is called once the produce response has been parsed or +// if it failed. +type ProduceCallback func(*ProduceResponse, error) + +// AsyncProduce sends a produce request and eventually call the provided callback +// with a produce response or an error. +// Waiting for the response is not blocking on the contrary to Produce. +// If the maximum number of in flight request configured has been met then +// the request will be blocked till a previous response has been received. +// When configured with RequiredAcks == NoResponse, the callback will be skipped. +func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error { + needAcks := request.RequiredAcks != NoResponse + // Use a nil promise when no acks is required + var promise *responsePromise + + if needAcks { + promise = &responsePromise{ + // Convert packets to ProduceResponse in the responseReceiver goroutine + handler: func(packets []byte, err error) { + if err != nil { + // Failed request + cb(nil, err) + return + } + + response := new(ProduceResponse) + if err := versionedDecode(packets, response, request.version()); err != nil { + // Malformed response + cb(nil, err) + return + } + + // Wellformed response + cb(response, nil) + }, + } + } + + return b.sendWithPromise(request, promise) +} + //Produce returns a produce response or error func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) { var ( @@ -660,49 +716,69 @@ func (b *Broker) DeleteGroups(request *DeleteGroupsRequest) (*DeleteGroupsRespon } func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) { + var promise *responsePromise + if promiseResponse { + // Packets or error will be sent to the following channels + // once the response is received + promise = &responsePromise{ + packets: make(chan []byte), + errors: make(chan error), + } + } + + if err := b.sendWithPromise(rb, promise); err != nil { + return nil, err + } + + return promise, nil +} + +func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error { b.lock.Lock() defer b.lock.Unlock() if b.conn == nil { if b.connErr != nil { - return nil, b.connErr + return b.connErr } - return nil, ErrNotConnected + return ErrNotConnected } if !b.conf.Version.IsAtLeast(rb.requiredVersion()) { - return nil, ErrUnsupportedVersion + return ErrUnsupportedVersion } req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { - return nil, err + return err } err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) if err != nil { - return nil, err + return err } requestTime := time.Now() bytes, err := b.conn.Write(buf) b.updateOutgoingCommunicationMetrics(bytes) //TODO: should it be after error check if err != nil { - return nil, err + return err } b.correlationID++ - if !promiseResponse { + if promise == nil { // Record request latency without the response b.updateRequestLatencyMetrics(time.Since(requestTime)) - return nil, nil + return nil } - promise := responsePromise{requestTime, req.correlationID, make(chan []byte), make(chan error)} - b.responses <- promise + promise.requestTime = requestTime + promise.correlationID = req.correlationID + // TODO check if we need a pointer + b.responses <- *promise - return &promise, nil + return nil } func (b *Broker) sendAndReceive(req protocolBody, res versionedDecoder) error { @@ -790,14 +866,14 @@ func (b *Broker) responseReceiver() { for response := range b.responses { if dead != nil { - response.errors <- dead + response.handle(nil, dead) continue } err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)) if err != nil { dead = err - response.errors <- err + response.handle(nil, err) continue } @@ -806,7 +882,7 @@ func (b *Broker) responseReceiver() { if err != nil { b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) dead = err - response.errors <- err + response.handle(nil, err) continue } @@ -815,7 +891,7 @@ func (b *Broker) responseReceiver() { if err != nil { b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) dead = err - response.errors <- err + response.handle(nil, err) continue } if decodedHeader.correlationID != response.correlationID { @@ -823,7 +899,7 @@ func (b *Broker) responseReceiver() { // TODO if decoded ID < cur ID, discard until we catch up // TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response dead = PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", response.correlationID, decodedHeader.correlationID)} - response.errors <- dead + response.handle(nil, dead) continue } @@ -832,11 +908,11 @@ func (b *Broker) responseReceiver() { b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency) if err != nil { dead = err - response.errors <- err + response.handle(nil, err) continue } - response.packets <- buf + response.handle(buf, nil) } close(b.done) }