diff --git a/transport/tchannel/channel_transport.go b/transport/tchannel/channel_transport.go index 2ff8867b3..1dfe05b89 100644 --- a/transport/tchannel/channel_transport.go +++ b/transport/tchannel/channel_transport.go @@ -82,12 +82,13 @@ func (options transportOptions) newChannelTransport() *ChannelTransport { logger = zap.NewNop() } return &ChannelTransport{ - once: lifecycle.NewOnce(), - ch: options.ch, - addr: options.addr, - tracer: options.tracer, - logger: logger.Named("tchannel"), - originalHeaders: options.originalHeaders, + once: lifecycle.NewOnce(), + ch: options.ch, + addr: options.addr, + tracer: options.tracer, + logger: logger.Named("tchannel"), + originalHeaders: options.originalHeaders, + newResponseWriter: newHandlerWriter, } } @@ -96,15 +97,15 @@ func (options transportOptions) newChannelTransport() *ChannelTransport { // If you have a YARPC peer.Chooser, use the unqualified tchannel.Transport // instead. type ChannelTransport struct { - ch Channel - name string - addr string - tracer opentracing.Tracer - logger *zap.Logger - router transport.Router - originalHeaders bool - - once *lifecycle.Once + once *lifecycle.Once + ch Channel + name string + addr string + tracer opentracing.Tracer + logger *zap.Logger + router transport.Router + originalHeaders bool + newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter } // Channel returns the underlying TChannel "Channel" instance. @@ -139,7 +140,7 @@ func (t *ChannelTransport) start() error { for s := range services { sc := t.ch.GetSubChannel(s) existing := sc.GetHandlers() - sc.SetHandler(handler{existing: existing, router: t.router, tracer: t.tracer, logger: t.logger, newResponseWriter: newHandlerWriter}) + sc.SetHandler(handler{existing: existing, router: t.router, tracer: t.tracer, logger: t.logger, newResponseWriter: t.newResponseWriter}) } } diff --git a/transport/tchannel/transport.go b/transport/tchannel/transport.go index 07526f7e8..f97f37170 100644 --- a/transport/tchannel/transport.go +++ b/transport/tchannel/transport.go @@ -52,13 +52,14 @@ type Transport struct { lock sync.Mutex once *lifecycle.Once - ch *tchannel.Channel - router transport.Router - tracer opentracing.Tracer - logger *zap.Logger - name string - addr string - listener net.Listener + ch *tchannel.Channel + router transport.Router + tracer opentracing.Tracer + logger *zap.Logger + name string + addr string + listener net.Listener + newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter connTimeout time.Duration initialConnRetryDelay time.Duration @@ -111,6 +112,7 @@ func (o transportOptions) newTransport() *Transport { tracer: o.tracer, logger: logger, headerCase: headerCase, + newResponseWriter: newHandlerWriter, } } @@ -202,7 +204,7 @@ func (t *Transport) start() error { tracer: t.tracer, headerCase: t.headerCase, logger: t.logger, - newResponseWriter: newHandlerWriter, + newResponseWriter: t.newResponseWriter, }, OnPeerStatusChanged: t.onPeerStatusChanged, }