Skip to content

Commit

Permalink
feat: adding stream interceptor for logging middleware (#3359)
Browse files Browse the repository at this point in the history
  • Loading branch information
akoserwal authored Sep 18, 2024
1 parent 908e625 commit e1f5dc4
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 28 deletions.
77 changes: 74 additions & 3 deletions transport/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
grpcinsecure "google.golang.org/grpc/credentials/insecure"
grpcmd "google.golang.org/grpc/metadata"

"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
Expand Down Expand Up @@ -132,6 +133,7 @@ type clientOptions struct {
timeout time.Duration
discovery registry.Discovery
middleware []middleware.Middleware
streamMiddleware []middleware.Middleware
ints []grpc.UnaryClientInterceptor
streamInts []grpc.StreamClientInterceptor
grpcOpts []grpc.DialOption
Expand Down Expand Up @@ -166,7 +168,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
unaryClientInterceptor(options.middleware, options.timeout, options.filters),
}
sints := []grpc.StreamClientInterceptor{
streamClientInterceptor(options.filters),
streamClientInterceptor(options.streamMiddleware, options.filters),
}

if len(options.ints) > 0 {
Expand Down Expand Up @@ -239,7 +241,54 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f
}
}

func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor {
// wrappedClientStream wraps the grpc.ClientStream and applies middleware
type wrappedClientStream struct {
grpc.ClientStream
ctx context.Context
middleware matcher.Matcher
}

func (w *wrappedClientStream) Context() context.Context {
return w.ctx
}

func (w *wrappedClientStream) SendMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.SendMsg(m)
}

info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func (w *wrappedClientStream) RecvMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.RecvMsg(m)
}

info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func streamClientInterceptor(ms []middleware.Middleware, filters []selector.NodeFilter) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(),
Expand All @@ -249,6 +298,28 @@ func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInt
})
var p selector.Peer
ctx = selector.NewPeerContext(ctx, &p)
return streamer(ctx, desc, cc, method, opts...)

clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, err
}

h := func(ctx context.Context, req interface{}) (interface{}, error) {
return streamer, nil
}

m := matcher.New()
if len(ms) > 0 {
m.Use(ms...)
middleware.Chain(ms...)(h)
}

wrappedStream := &wrappedClientStream{
ClientStream: clientStream,
ctx: ctx,
middleware: m,
}

return wrappedStream, nil
}
}
67 changes: 64 additions & 3 deletions transport/grpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package grpc

import (
"context"
"fmt"

"google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"

ic "github.com/go-kratos/kratos/v2/internal/context"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
Expand Down Expand Up @@ -48,13 +50,15 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
// wrappedStream is rewrite grpc stream's context
type wrappedStream struct {
grpc.ServerStream
ctx context.Context
ctx context.Context
middleware matcher.Matcher
}

func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream, m matcher.Matcher) grpc.ServerStream {
return &wrappedStream{
ServerStream: stream,
ctx: ctx,
middleware: m,
}
}

Expand All @@ -76,7 +80,19 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
replyHeader: headerCarrier(replyHeader),
})

ws := NewWrappedStream(ctx, ss)
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(srv, ss), nil
}

if next := s.streamMiddleware.Match(info.FullMethod); len(next) > 0 {
middleware.Chain(next...)(h)
}

ctx = context.WithValue(ctx, stream{
ServerStream: ss,
streamMiddleware: s.streamMiddleware,
}, ss)
ws := NewWrappedStream(ctx, ss, s.streamMiddleware)

err := handler(srv, ws)
if len(replyHeader) > 0 {
Expand All @@ -85,3 +101,48 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return err
}
}

type stream struct {
grpc.ServerStream
streamMiddleware matcher.Matcher
}

func GetStream(ctx context.Context) grpc.ServerStream {
return ctx.Value(stream{}).(grpc.ServerStream)
}

func (w *wrappedStream) SendMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.SendMsg(m)
}

info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func (w *wrappedStream) RecvMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.RecvMsg(m)
}

info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}
52 changes: 30 additions & 22 deletions transport/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption {
}
}

func StreamMiddleware(m ...middleware.Middleware) ServerOption {
return func(s *Server) {
s.streamMiddleware.Use(m...)
}
}

// CustomHealth Checks server.
func CustomHealth() ServerOption {
return func(s *Server) {
Expand Down Expand Up @@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption {
// Server is a gRPC server wrapper.
type Server struct {
*grpc.Server
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
streamMiddleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
}

// NewServer creates a gRPC server by options.
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}
for _, o := range opts {
o(srv)
Expand Down
77 changes: 77 additions & 0 deletions transport/grpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/matcher"
Expand Down Expand Up @@ -280,6 +281,82 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
}
}

type mockServerStream struct {
ctx context.Context
sentMsg interface{}
recvMsg interface{}
metadata metadata.MD
grpc.ServerStream
}

func (m *mockServerStream) SetHeader(md metadata.MD) error {
m.metadata = md
return nil
}

func (m *mockServerStream) SendHeader(md metadata.MD) error {
m.metadata = md
return nil
}

func (m *mockServerStream) SetTrailer(md metadata.MD) {
m.metadata = md
}

func (m *mockServerStream) Context() context.Context {
return m.ctx
}

func (m *mockServerStream) SendMsg(msg interface{}) error {
m.sentMsg = msg
return nil
}

func (m *mockServerStream) RecvMsg(msg interface{}) error {
m.recvMsg = msg
return nil
}

func TestServer_streamServerInterceptor(t *testing.T) {
u, err := url.Parse("grpc://hello/world")
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
srv := &Server{
baseCtx: context.Background(),
endpoint: u,
timeout: time.Duration(10),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}

srv.streamMiddleware.Use(EmptyMiddleware())

mockStream := &mockServerStream{
ctx: srv.baseCtx,
}

handler := func(_ interface{}, stream grpc.ServerStream) error {
resp := &testResp{Data: "stream hi"}
return stream.SendMsg(resp)
}

info := &grpc.StreamServerInfo{
FullMethod: "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
}

err = srv.streamServerInterceptor()(nil, mockStream, info, handler)
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}

// Check response
resp := mockStream.sentMsg.(*testResp)
if !reflect.DeepEqual("stream hi", resp.Data) {
t.Errorf("expect %s, got %s", "stream hi", resp.Data)
}
}

func TestListener(t *testing.T) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
Expand Down

0 comments on commit e1f5dc4

Please sign in to comment.