From db58b3aae279e9086f366ccecaf3d77a7f249e8d Mon Sep 17 00:00:00 2001 From: Kirill Parasotchenko Date: Mon, 5 Mar 2018 18:33:58 +0300 Subject: [PATCH] Add NATS transport NATS is an open-source, cloud-native messaging system (https://www.nats.io). The functional provides API that lets one works with NATS in a similar way as HTTP. - nats.MsgHandler could be used in queue or simple subscriber - Sync publisher --- .../addsvc/thrift/gen-go/addsvc/addsvc.go | 14 +- examples/stringsvc4/main.go | 206 ++++++++ transport/nats/doc.go | 2 + transport/nats/encode_decode.go | 32 ++ transport/nats/publisher.go | 110 ++++ transport/nats/publisher_test.go | 252 +++++++++ transport/nats/request_response_funcs.go | 22 + transport/nats/subscriber.go | 167 ++++++ transport/nats/subscriber_test.go | 477 ++++++++++++++++++ 9 files changed, 1275 insertions(+), 7 deletions(-) create mode 100644 examples/stringsvc4/main.go create mode 100644 transport/nats/doc.go create mode 100644 transport/nats/encode_decode.go create mode 100644 transport/nats/publisher.go create mode 100644 transport/nats/publisher_test.go create mode 100644 transport/nats/request_response_funcs.go create mode 100644 transport/nats/subscriber.go create mode 100644 transport/nats/subscriber_test.go diff --git a/examples/addsvc/thrift/gen-go/addsvc/addsvc.go b/examples/addsvc/thrift/gen-go/addsvc/addsvc.go index 729ad6226..00353d009 100644 --- a/examples/addsvc/thrift/gen-go/addsvc/addsvc.go +++ b/examples/addsvc/thrift/gen-go/addsvc/addsvc.go @@ -373,7 +373,7 @@ func (p *AddServiceProcessor) Process(ctx context.Context, iprot, oprot thrift.T oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) x5.Write(oprot) oprot.WriteMessageEnd() - oprot.Flush() + oprot.Flush(ctx) return false, x5 } @@ -390,7 +390,7 @@ func (p *addServiceProcessorSum) Process(ctx context.Context, seqId int32, iprot oprot.WriteMessageBegin("Sum", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() - oprot.Flush() + oprot.Flush(ctx) return false, err } @@ -403,7 +403,7 @@ var retval *SumReply oprot.WriteMessageBegin("Sum", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() - oprot.Flush() + oprot.Flush(ctx) return true, err2 } else { result.Success = retval @@ -417,7 +417,7 @@ var retval *SumReply if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } - if err2 = oprot.Flush(); err == nil && err2 != nil { + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { err = err2 } if err != nil { @@ -438,7 +438,7 @@ func (p *addServiceProcessorConcat) Process(ctx context.Context, seqId int32, ip oprot.WriteMessageBegin("Concat", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() - oprot.Flush() + oprot.Flush(ctx) return false, err } @@ -451,7 +451,7 @@ var retval *ConcatReply oprot.WriteMessageBegin("Concat", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() - oprot.Flush() + oprot.Flush(ctx) return true, err2 } else { result.Success = retval @@ -465,7 +465,7 @@ var retval *ConcatReply if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } - if err2 = oprot.Flush(); err == nil && err2 != nil { + if err2 = oprot.Flush(ctx); err == nil && err2 != nil { err = err2 } if err != nil { diff --git a/examples/stringsvc4/main.go b/examples/stringsvc4/main.go new file mode 100644 index 000000000..c6447079e --- /dev/null +++ b/examples/stringsvc4/main.go @@ -0,0 +1,206 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "log" + "strings" + "flag" + "net/http" + + "github.com/go-kit/kit/endpoint" + natstransport "github.com/go-kit/kit/transport/nats" + httptransport "github.com/go-kit/kit/transport/http" + + "github.com/nats-io/go-nats" +) + +// StringService provides operations on strings. +type StringService interface { + Uppercase(context.Context, string) (string, error) + Count(context.Context, string) int +} + +// stringService is a concrete implementation of StringService +type stringService struct{} + +func (stringService) Uppercase(_ context.Context, s string) (string, error) { + if s == "" { + return "", ErrEmpty + } + return strings.ToUpper(s), nil +} + +func (stringService) Count(_ context.Context, s string) int { + return len(s) +} + +// ErrEmpty is returned when an input string is empty. +var ErrEmpty = errors.New("empty string") + +// For each method, we define request and response structs +type uppercaseRequest struct { + S string `json:"s"` +} + +type uppercaseResponse struct { + V string `json:"v"` + Err string `json:"err,omitempty"` // errors don't define JSON marshaling +} + +type countRequest struct { + S string `json:"s"` +} + +type countResponse struct { + V int `json:"v"` +} + +// Endpoints are a primary abstraction in go-kit. An endpoint represents a single RPC (method in our service interface) +func makeUppercaseHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint { + return natstransport.NewPublisher( + nc, + "stringsvc.uppercase", + natstransport.EncodeJSONRequest, + decodeUppercaseResponse, + ).Endpoint() +} + +func makeCountHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint { + return natstransport.NewPublisher( + nc, + "stringsvc.count", + natstransport.EncodeJSONRequest, + decodeCountResponse, + ).Endpoint() +} + +func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(uppercaseRequest) + v, err := svc.Uppercase(ctx, req.S) + if err != nil { + return uppercaseResponse{v, err.Error()}, nil + } + return uppercaseResponse{v, ""}, nil + } +} + +func makeCountEndpoint(svc StringService) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(countRequest) + v := svc.Count(ctx, req.S) + return countResponse{v}, nil + } +} + +// Transports expose the service to the network. In this fourth example we utilize JSON over NATS and HTTP. +func main() { + svc := stringService{} + + natsURL := flag.String("nats-url", nats.DefaultURL, "URL for connection to NATS") + flag.Parse() + + nc, err := nats.Connect(*natsURL) + if err != nil { + log.Fatal(err) + } + defer nc.Close() + + uppercaseHTTPHandler := httptransport.NewServer( + makeUppercaseHTTPEndpoint(nc), + decodeUppercaseHTTPRequest, + httptransport.EncodeJSONResponse, + ) + + countHTTPHandler := httptransport.NewServer( + makeCountHTTPEndpoint(nc), + decodeCountHTTPRequest, + httptransport.EncodeJSONResponse, + ) + + uppercaseHandler := natstransport.NewSubscriber( + makeUppercaseEndpoint(svc), + decodeUppercaseRequest, + natstransport.EncodeJSONResponse, + ) + + countHandler := natstransport.NewSubscriber( + makeCountEndpoint(svc), + decodeCountRequest, + natstransport.EncodeJSONResponse, + ) + + uSub, err := nc.QueueSubscribe("stringsvc.uppercase", "stringsvc", uppercaseHandler.ServeMsg(nc)) + if err != nil { + log.Fatal(err) + } + defer uSub.Unsubscribe() + + cSub, err := nc.QueueSubscribe("stringsvc.count", "stringsvc", countHandler.ServeMsg(nc)) + if err != nil { + log.Fatal(err) + } + defer cSub.Unsubscribe() + + http.Handle("/uppercase", uppercaseHTTPHandler) + http.Handle("/count", countHTTPHandler) + log.Fatal(http.ListenAndServe(":8080", nil)) + +} + +func decodeUppercaseHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) { + var request uppercaseRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + return nil, err + } + return request, nil +} + +func decodeCountHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) { + var request countRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + return nil, err + } + return request, nil +} + +func decodeUppercaseResponse(_ context.Context, msg *nats.Msg) (interface{}, error) { + var response uppercaseResponse + + if err := json.Unmarshal(msg.Data, &response); err != nil { + return nil, err + } + + return response, nil +} + +func decodeCountResponse(_ context.Context, msg *nats.Msg) (interface{}, error) { + var response countResponse + + if err := json.Unmarshal(msg.Data, &response); err != nil { + return nil, err + } + + return response, nil +} + +func decodeUppercaseRequest(_ context.Context, msg *nats.Msg) (interface{}, error) { + var request uppercaseRequest + + if err := json.Unmarshal(msg.Data, &request); err != nil { + return nil, err + } + return request, nil +} + +func decodeCountRequest(_ context.Context, msg *nats.Msg) (interface{}, error) { + var request countRequest + + if err := json.Unmarshal(msg.Data, &request); err != nil { + return nil, err + } + return request, nil +} + diff --git a/transport/nats/doc.go b/transport/nats/doc.go new file mode 100644 index 000000000..e34a06870 --- /dev/null +++ b/transport/nats/doc.go @@ -0,0 +1,2 @@ +// Package nats provides a NATS transport. +package nats diff --git a/transport/nats/encode_decode.go b/transport/nats/encode_decode.go new file mode 100644 index 000000000..ec4accf50 --- /dev/null +++ b/transport/nats/encode_decode.go @@ -0,0 +1,32 @@ +package nats + +import ( + "context" + + "github.com/nats-io/go-nats" +) + +// DecodeRequestFunc extracts a user-domain request object from a publisher +// request object. It's designed to be used in NATS subscribers, for subscriber-side +// endpoints. One straightforward DecodeRequestFunc could be something that +// JSON decodes from the request body to the concrete response type. +type DecodeRequestFunc func(context.Context, *nats.Msg) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed request object into the NATS request +// object. It's designed to be used in NATS publishers, for publisher-side +// endpoints. One straightforward EncodeRequestFunc could something that JSON +// encodes the object directly to the request payload. +type EncodeRequestFunc func(context.Context, *nats.Msg, interface{}) error + +// EncodeResponseFunc encodes the passed response object to the subscriber reply. +// It's designed to be used in NATS subscribers, for subscriber-side +// endpoints. One straightforward EncodeResponseFunc could be something that +// JSON encodes the object directly to the response body. +type EncodeResponseFunc func(context.Context, string, *nats.Conn, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from an NATS +// response object. It's designed to be used in NATS publisher, for publisher-side +// endpoints. One straightforward DecodeResponseFunc could be something that +// JSON decodes from the response payload to the concrete response type. +type DecodeResponseFunc func(context.Context, *nats.Msg) (response interface{}, err error) + diff --git a/transport/nats/publisher.go b/transport/nats/publisher.go new file mode 100644 index 000000000..7ba40fc52 --- /dev/null +++ b/transport/nats/publisher.go @@ -0,0 +1,110 @@ +package nats + +import ( + "context" + "encoding/json" + "github.com/go-kit/kit/endpoint" + "github.com/nats-io/go-nats" + "time" +) + +// Publisher wraps a URL and provides a method that implements endpoint.Endpoint. +type Publisher struct { + publisher *nats.Conn + subject string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []RequestFunc + after []PublisherResponseFunc + timeout time.Duration +} + +// NewClient constructs a usable Publisher for a single remote method. +func NewPublisher( + publisher *nats.Conn, + subject string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...PublisherOption, +) *Publisher { + p := &Publisher{ + publisher: publisher, + subject: subject, + enc: enc, + dec: dec, + timeout: 10 * time.Second, + } + for _, option := range options { + option(p) + } + return p +} + +// PublisherOption sets an optional parameter for clients. +type PublisherOption func(*Publisher) + +// PublisherBefore sets the RequestFuncs that are applied to the outgoing NATS +// request before it's invoked. +func PublisherBefore(before ...RequestFunc) PublisherOption { + return func(p *Publisher) { p.before = append(p.before, before...) } +} + +// PublisherAfter sets the ClientResponseFuncs applied to the incoming NATS +// request prior to it being decoded. This is useful for obtaining anything off +// of the response and adding onto the context prior to decoding. +func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { + return func(p *Publisher) { p.after = append(p.after, after...) } +} + +// PublisherTimeout sets the available timeout for NATS request. +func PublisherTimeout(timeout time.Duration) PublisherOption { + return func(p *Publisher) { p.timeout = timeout } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (p Publisher) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + defer cancel() + + msg := nats.Msg{Subject: p.subject} + + if err := p.enc(ctx, &msg, request); err != nil { + return nil, err + } + + for _, f := range p.before { + ctx = f(ctx, &msg) + } + + resp, err := p.publisher.RequestWithContext(ctx, msg.Subject, msg.Data) + if err != nil { + return nil, err + } + + for _, f := range p.after { + ctx = f(ctx, resp) + } + + response, err := p.dec(ctx, resp) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object to the Data of the Msg. Many JSON-over-NATS services can use it as +// a sensible default. +func EncodeJSONRequest(_ context.Context, msg *nats.Msg, request interface{}) error { + b, err := json.Marshal(request) + if err != nil { + return err + } + + msg.Data = b + + return nil +} diff --git a/transport/nats/publisher_test.go b/transport/nats/publisher_test.go new file mode 100644 index 000000000..8468f1b2f --- /dev/null +++ b/transport/nats/publisher_test.go @@ -0,0 +1,252 @@ +package nats_test + +import ( + "testing" + "context" + "time" + "strings" + + "github.com/nats-io/go-nats" + natstransport "github.com/go-kit/kit/transport/nats" +) + +func TestPublisher(t *testing.T) { + var ( + testdata = "testdata" + encode = func(context.Context, *nats.Msg, interface{}) error { return nil } + decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { + return TestResponse{string(msg.Data), ""}, nil + } + ) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil { + t.Fatal(err) + } + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + encode, + decode, + ) + + res, err := publisher.Endpoint()(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + response, ok := res.(TestResponse) + if !ok { + t.Fatal("response should be TestResponse") + } + if want, have := testdata, response.String; want != have { + t.Errorf("want %q, have %q", want, have) + } + +} + +func TestPublisherBefore(t *testing.T) { + var ( + testdata = "testdata" + encode = func(context.Context, *nats.Msg, interface{}) error { return nil } + decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { + return TestResponse{string(msg.Data), ""}, nil + } + ) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + if err := nc.Publish(msg.Reply, msg.Data); err != nil { + t.Fatal(err) + } + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + encode, + decode, + natstransport.PublisherBefore(func(ctx context.Context, msg *nats.Msg) context.Context { + msg.Data = []byte(strings.ToUpper(string(testdata))) + return ctx + }), + ) + + res, err := publisher.Endpoint()(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + response, ok := res.(TestResponse) + if !ok { + t.Fatal("response should be TestResponse") + } + if want, have := strings.ToUpper(testdata), response.String; want != have { + t.Errorf("want %q, have %q", want, have) + } + +} + +func TestPublisherAfter(t *testing.T) { + var ( + testdata = "testdata" + encode = func(context.Context, *nats.Msg, interface{}) error { return nil } + decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { + return TestResponse{string(msg.Data), ""}, nil + } + ) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil { + t.Fatal(err) + } + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + encode, + decode, + natstransport.PublisherAfter(func(ctx context.Context, msg *nats.Msg) context.Context { + msg.Data = []byte(strings.ToUpper(string(msg.Data))) + return ctx + }), + ) + + res, err := publisher.Endpoint()(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + response, ok := res.(TestResponse) + if !ok { + t.Fatal("response should be TestResponse") + } + if want, have := strings.ToUpper(testdata), response.String; want != have { + t.Errorf("want %q, have %q", want, have) + } + +} + +func TestPublisherTimeout(t *testing.T) { + var ( + encode = func(context.Context, *nats.Msg, interface{}) error { return nil } + decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { + return TestResponse{string(msg.Data), ""}, nil + } + ) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + ch := make(chan struct{}) + defer close(ch) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + <-ch + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + encode, + decode, + natstransport.PublisherTimeout(time.Second), + ) + + _, err = publisher.Endpoint()(context.Background(), struct{}{}) + if err != context.DeadlineExceeded { + t.Errorf("want %s, have %s", context.DeadlineExceeded, err) + + } + +} + +func TestEncodeJSONRequest(t *testing.T) { + var data string + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + data = string(msg.Data) + + if err := nc.Publish(msg.Reply, []byte("")); err != nil { + t.Fatal(err) + } + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + natstransport.EncodeJSONRequest, + func(context.Context, *nats.Msg) (interface{}, error) { return nil, nil }, + ).Endpoint() + + for _, test := range []struct { + value interface{} + body string + }{ + {nil, "null"}, + {12, "12"}, + {1.2, "1.2"}, + {true, "true"}, + {"test", "\"test\""}, + {struct{ Foo string `json:"foo"` }{"foo"}, "{\"foo\":\"foo\"}"}, + } { + if _, err := publisher(context.Background(), test.value); err != nil { + t.Fatal(err) + continue + } + + if data != test.body { + t.Errorf("%v: actual %#v, expected %#v", test.value, data, test.body) + } + } + +} diff --git a/transport/nats/request_response_funcs.go b/transport/nats/request_response_funcs.go new file mode 100644 index 000000000..32cec57de --- /dev/null +++ b/transport/nats/request_response_funcs.go @@ -0,0 +1,22 @@ +package nats + +import ( + "context" + + "github.com/nats-io/go-nats" +) + +// RequestFunc may take information from a publisher request and put it into a +// request context. In Subscribers, RequestFuncs are executed prior to invoking the +// endpoint. +type RequestFunc func(context.Context, *nats.Msg) context.Context + +// SubscriberResponseFunc may take information from a request context and use it to +// manipulate a Publisher. SubscriberResponseFuncs are only executed in +// subscribers, after invoking the endpoint but prior to publishing a reply. +type SubscriberResponseFunc func(context.Context, *nats.Conn) context.Context + +// PublisherResponseFunc may take information from an NATS request and make the +// response available for consumption. ClientResponseFuncs are only executed in +// clients, after a request has been made, but prior to it being decoded. +type PublisherResponseFunc func(context.Context, *nats.Msg) context.Context diff --git a/transport/nats/subscriber.go b/transport/nats/subscriber.go new file mode 100644 index 000000000..c562c2a52 --- /dev/null +++ b/transport/nats/subscriber.go @@ -0,0 +1,167 @@ +package nats + +import ( + "context" + "encoding/json" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + + "github.com/nats-io/go-nats" +) + +// Subscriber wraps an endpoint and provides nats.MsgHandler. +type Subscriber struct { + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + before []RequestFunc + after []SubscriberResponseFunc + errorEncoder ErrorEncoder + logger log.Logger +} + +// NewSubscriber constructs a new subscriber, which provides nats.MsgHandler and wraps +// the provided endpoint. +func NewSubscriber( + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + options ...SubscriberOption, +) *Subscriber { + s := &Subscriber{ + e: e, + dec: dec, + enc: enc, + errorEncoder: DefaultErrorEncoder, + logger: log.NewNopLogger(), + } + for _, option := range options { + option(s) + } + return s +} + +// SubscriberOption sets an optional parameter for subscribers. +type SubscriberOption func(*Subscriber) + +// SubscriberBefore functions are executed on the publisher request object before the +// request is decoded. +func SubscriberBefore(before ...RequestFunc) SubscriberOption { + return func(s *Subscriber) { s.before = append(s.before, before...) } +} + +// SubscriberAfter functions are executed on the subscriber reply after the +// endpoint is invoked, but before anything is published to the reply. +func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { + return func(s *Subscriber) { s.after = append(s.after, after...) } +} + +// SubscriberErrorEncoder is used to encode errors to the subscriber reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { + return func(s *Subscriber) { s.errorEncoder = ee } +} + +// SubscriberErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom SubscriberErrorEncoder which has access to the context. +func SubscriberErrorLogger(logger log.Logger) SubscriberOption { + return func(s *Subscriber) { s.logger = logger } +} + +// ServeMsg provides nats.MsgHandler. +func (s Subscriber) ServeMsg(nc *nats.Conn) func(msg *nats.Msg) { + return func(msg *nats.Msg) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, f := range s.before { + ctx = f(ctx, msg) + } + + request, err := s.dec(ctx, msg) + if err != nil { + s.logger.Log("err", err) + if msg.Reply == "" { + return + } + s.errorEncoder(ctx, err, msg.Reply, nc) + return + } + + response, err := s.e(ctx, request) + if err != nil { + s.logger.Log("err", err) + if msg.Reply == "" { + return + } + s.errorEncoder(ctx, err, msg.Reply, nc) + return + } + + for _, f := range s.after { + ctx = f(ctx, nc) + } + + if msg.Reply == "" { + return + } + + if err := s.enc(ctx, msg.Reply, nc, response); err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, msg.Reply, nc) + return + } + } +} + +// ErrorEncoder is responsible for encoding an error to the subscriber reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, err error, reply string, nc *nats.Conn) + +// NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not +// need to be decoded, and simply returns nil, nil. +func NopRequestDecoder(_ context.Context, _ *nats.Msg) (interface{}, error) { + return nil, nil +} + +// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a +// JSON object to the subscriber reply. Many JSON-over services can use it as +// a sensible default. +func EncodeJSONResponse(_ context.Context, reply string, nc *nats.Conn, response interface{}) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + + return nc.Publish(reply, b) +} + +// DefaultErrorEncoder writes the error to the subscriber reply. +func DefaultErrorEncoder(_ context.Context, err error, reply string, nc *nats.Conn) { + logger := log.NewNopLogger() + + type Response struct { + Error string `json:"err"` + } + + var response Response + + response.Error = err.Error() + + b, err := json.Marshal(response) + if err != nil { + logger.Log("err", err) + return + } + + if err := nc.Publish(reply, b); err != nil { + logger.Log("err", err) + } +} diff --git a/transport/nats/subscriber_test.go b/transport/nats/subscriber_test.go new file mode 100644 index 000000000..a2ba160f3 --- /dev/null +++ b/transport/nats/subscriber_test.go @@ -0,0 +1,477 @@ +package nats_test + +import ( + "testing" + "context" + "errors" + "time" + "sync" + "strings" + "encoding/json" + + "github.com/nats-io/go-nats" + "github.com/nats-io/gnatsd/server" + + natstransport "github.com/go-kit/kit/transport/nats" + "github.com/go-kit/kit/endpoint" +) + +type TestResponse struct { + String string `json:"str"` + Error string `json:"err"` +} + +func init() { + opts := server.Options{Host: "localhost", Port: 4222} + natsServer := server.New(&opts) + + go func() { + natsServer.Start() + }() + + if ok := natsServer.ReadyForConnections(2 * time.Second); !ok { + panic("Failed start of NATS") + } +} + +func TestSubscriberBadDecode(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, string, *nats.Conn, interface{}) error { return nil }, + ) + + resp := testRequest(t, nc, handler) + + if want, have := "dang", resp.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } + +} + +func TestSubscriberBadEndpoint(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, string, *nats.Conn, interface{}) error { return nil }, + ) + + resp := testRequest(t, nc, handler) + + if want, have := "dang", resp.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSubscriberBadEncode(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, string, *nats.Conn, interface{}) error { return errors.New("dang") }, + ) + + resp := testRequest(t, nc, handler) + + if want, have := "dang", resp.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSubscriberErrorEncoder(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + errTeapot := errors.New("teapot") + code := func(err error) error { + if err == errTeapot { + return err + } + return errors.New("dang") + } + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, string, *nats.Conn, interface{}) error { return nil }, + natstransport.SubscriberErrorEncoder(func(_ context.Context, err error, reply string, nc *nats.Conn) { + var r TestResponse + r.Error = code(err).Error() + + b, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + + if err := nc.Publish(reply, b); err != nil { + t.Fatal(err) + } + }), + ) + + resp := testRequest(t, nc, handler) + + if want, have := errTeapot.Error(), resp.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSubscriberHappySubject(t *testing.T) { + step, response := testSubscriber(t) + step() + r := <-response + + var resp TestResponse + err := json.Unmarshal(r.Data, &resp) + if err != nil { + t.Fatal(err) + } + + if want, have := "", resp.Error; want != have { + t.Errorf("want %s, have %s (%s)", want, have, r.Data) + } +} + +func TestMultipleSubscriberBefore(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + var ( + response = struct{ Body string }{"go eat a fly ugly\n"} + wg sync.WaitGroup + done = make(chan struct{}) + ) + handler := natstransport.NewSubscriber( + endpoint.Nop, + func(context.Context, *nats.Msg) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + + return nc.Publish(reply, b) + }, + natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerBefores are used") + } + + close(done) + return ctx + }), + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + wg.Add(1) + go func() { + defer wg.Done() + _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } + + wg.Wait() +} + +func TestMultipleSubscriberAfter(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + var ( + response = struct{ Body string }{"go eat a fly ugly\n"} + wg sync.WaitGroup + done = make(chan struct{}) + ) + handler := natstransport.NewSubscriber( + endpoint.Nop, + func(context.Context, *nats.Msg) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + + return nc.Publish(reply, b) + }, + natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerAfters are used") + } + + close(done) + return ctx + }), + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + wg.Add(1) + go func() { + defer wg.Done() + _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } + + wg.Wait() +} + +func TestEncodeJSONResponse(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{ Foo string `json:"foo"` }{"bar"}, nil }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + natstransport.EncodeJSONResponse, + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(r.Data)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +type responseError struct { + msg string +} + +func (m responseError) Error() string { + return m.msg +} + +func TestErrorEncoder(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + errResp := struct{ Error string `json:"err"` }{"oh no"} + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { + return nil, responseError{msg: errResp.Error} + }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + natstransport.EncodeJSONResponse, + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + b, err := json.Marshal(errResp) + if err != nil { + t.Fatal(err) + } + if string(b) != string(r.Data) { + t.Errorf("ErrorEncoder: got: %q, expected: %q", r.Data, b) + } +} + +type noContentResponse struct{} + +func TestEncodeNoContent(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + natstransport.EncodeJSONResponse, + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + if want, have := `{}`, strings.TrimSpace(string(r.Data)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +func TestNoOpRequestDecoder(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + handler := natstransport.NewSubscriber( + func(ctx context.Context, request interface{}) (interface{}, error) { + if request != nil { + t.Error("Expected nil request in endpoint when using NopRequestDecoder") + } + return nil, nil + }, + natstransport.NopRequestDecoder, + natstransport.EncodeJSONResponse, + ) + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + if want, have := `null`, strings.TrimSpace(string(r.Data)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +func testSubscriber(t *testing.T) (step func(), resp <-chan *nats.Msg) { + var ( + stepch = make(chan bool) + endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } + response = make(chan *nats.Msg) + handler = natstransport.NewSubscriber( + endpoint, + func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, + natstransport.EncodeJSONResponse, + natstransport.SubscriberBefore(func(ctx context.Context, msg *nats.Msg) context.Context { return ctx }), + natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { return ctx }), + ) + ) + + go func() { + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + response <- r + }() + + return func() { stepch <- true }, response +} + +func testRequest(t *testing.T, nc *nats.Conn, handler *natstransport.Subscriber) TestResponse { + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) + if err != nil { + t.Fatal(err) + } + + var resp TestResponse + err = json.Unmarshal(r.Data, &resp) + if err != nil { + t.Fatal(err) + } + + return resp +}