diff --git a/transport/nats/subscriber.go b/transport/nats/subscriber.go index c562c2a52..7cb8eb8ec 100644 --- a/transport/nats/subscriber.go +++ b/transport/nats/subscriber.go @@ -18,6 +18,7 @@ type Subscriber struct { before []RequestFunc after []SubscriberResponseFunc errorEncoder ErrorEncoder + finalizer []SubscriberFinalizerFunc logger log.Logger } @@ -73,12 +74,26 @@ func SubscriberErrorLogger(logger log.Logger) SubscriberOption { return func(s *Subscriber) { s.logger = logger } } +// SubscriberFinalizer is executed at the end of every request from a publisher through NATS. +// By default, no finalizer is registered. +func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption { + return func(s *Subscriber) { s.finalizer = f } +} + // ServeMsg provides nats.MsgHandler. func (s Subscriber) ServeMsg(nc *nats.Conn) func(msg *nats.Msg) { return func(msg *nats.Msg) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + if len(s.finalizer) > 0 { + defer func() { + for _, f := range s.finalizer { + f(ctx, msg) + } + }() + } + for _, f := range s.before { ctx = f(ctx, msg) } @@ -125,6 +140,11 @@ func (s Subscriber) ServeMsg(nc *nats.Conn) func(msg *nats.Msg) { // types. type ErrorEncoder func(ctx context.Context, err error, reply string, nc *nats.Conn) +// ServerFinalizerFunc can be used to perform work at the end of an request +// from a publisher, after the response has been written to the publisher. The principal +// intended use is for request logging. +type SubscriberFinalizerFunc func(ctx context.Context, msg *nats.Msg) + // NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not // need to be decoded, and simply returns nil, nil. func NopRequestDecoder(_ context.Context, _ *nats.Msg) (interface{}, error) { diff --git a/transport/nats/subscriber_test.go b/transport/nats/subscriber_test.go index 4d44dbbd0..b23544354 100644 --- a/transport/nats/subscriber_test.go +++ b/transport/nats/subscriber_test.go @@ -289,6 +289,57 @@ func TestMultipleSubscriberAfter(t *testing.T) { wg.Wait() } +func TestSubscriberFinalizerFunc(t *testing.T) { + nc := newNatsConn(t) + defer nc.Close() + + var ( + response = struct{ Body string }{"go eat a fly ugly\n"} + wg sync.WaitGroup + done = make(chan struct{}) + ) + handler := natstransport.NewSubscriber( + endpoint.Nop, + func(context.Context, *nats.Msg) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + + return nc.Publish(reply, b) + }, + natstransport.SubscriberFinalizer(func(ctx context.Context, _ *nats.Msg) { + close(done) + }), + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + wg.Add(1) + go func() { + defer wg.Done() + _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } + + wg.Wait() +} + func TestEncodeJSONResponse(t *testing.T) { nc := newNatsConn(t) defer nc.Close()