From c573752b32b4d0dedf6c5e96dbc2bf61b06b278e Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Sat, 12 Feb 2022 19:06:11 +0400 Subject: [PATCH] refactor Selector --- desc/desc.go | 28 ++------- desc/impl.go | 18 ++++-- envelope.go | 33 +++++++++- service.go | 1 + std/bundle/fasthttp/bundle.go | 104 ++++++++++++++++---------------- std/bundle/fasthttp/router.go | 18 +++--- std/bundle/fasthttp/selector.go | 63 +++++++------------ std/bundle/fastws/bundle.go | 18 ++---- std/bundle/fastws/selector.go | 32 +++------- utils/spinlock.go | 2 +- 10 files changed, 148 insertions(+), 169 deletions(-) diff --git a/desc/desc.go b/desc/desc.go index 6fbf320b..7b507fb9 100644 --- a/desc/desc.go +++ b/desc/desc.go @@ -1,24 +1,16 @@ package desc import ( - "reflect" - "github.com/ronaksoft/ronykit" ) -// Selector is the interface which should be provided by the Bundle developer. That is we -// make it as interface instead of concrete implementation. -type Selector interface { - Generate(f ronykit.MessageFactory) ronykit.RouteSelector -} - // Contract is the description of the ronykit.Contract you are going to create. type Contract struct { Name string Handlers []ronykit.Handler Input ronykit.Message Output ronykit.Message - Selectors []Selector + Selectors []ronykit.RouteSelector Modifiers []ronykit.Modifier } @@ -44,7 +36,7 @@ func (c *Contract) SetOutput(m ronykit.Message) *Contract { return c } -func (c *Contract) AddSelector(s Selector) *Contract { +func (c *Contract) AddSelector(s ronykit.RouteSelector) *Contract { c.Selectors = append(c.Selectors, s) return c @@ -71,25 +63,13 @@ func (c *Contract) SetHandler(h ...ronykit.Handler) *Contract { func (c *Contract) Generate() []ronykit.Contract { //nolint:prealloc var contracts []ronykit.Contract - makeFunc := func(m ronykit.Message) func(args []reflect.Value) (results []reflect.Value) { - return func(args []reflect.Value) (results []reflect.Value) { - return []reflect.Value{reflect.New(reflect.TypeOf(m).Elem())} - } - } for _, s := range c.Selectors { ci := &contractImpl{} ci.setHandler(c.Handlers...) ci.setModifier(c.Modifiers...) - - reflect.ValueOf(&ci.factoryFunc).Elem().Set( - reflect.MakeFunc( - reflect.TypeOf(ci.factoryFunc), - makeFunc(c.Input), - ), - ) - - ci.setSelector(s.Generate(ci.factoryFunc)) + ci.setInput(c.Input) + ci.setSelector(s) contracts = append(contracts, ci) } diff --git a/desc/impl.go b/desc/impl.go index ccf09265..82658556 100644 --- a/desc/impl.go +++ b/desc/impl.go @@ -46,10 +46,16 @@ func (s *serviceImpl) addContract(contracts ...ronykit.Contract) *serviceImpl { // contractImpl is simple implementation of ronykit.Contract interface. type contractImpl struct { - selector ronykit.RouteSelector - handlers []ronykit.Handler - modifiers []ronykit.Modifier - factoryFunc func() ronykit.Message + selector ronykit.RouteSelector + handlers []ronykit.Handler + modifiers []ronykit.Modifier + input ronykit.Message +} + +func (r *contractImpl) setInput(input ronykit.Message) *contractImpl { + r.input = input + + return r } func (r *contractImpl) setSelector(selector ronykit.RouteSelector) *contractImpl { @@ -81,3 +87,7 @@ func (r *contractImpl) Handlers() []ronykit.Handler { func (r *contractImpl) Modifiers() []ronykit.Modifier { return r.modifiers } + +func (r *contractImpl) Input() ronykit.Message { + return r.input +} diff --git a/envelope.go b/envelope.go index 2dd152aa..e94751b6 100644 --- a/envelope.go +++ b/envelope.go @@ -1,6 +1,7 @@ package ronykit import ( + "reflect" "sync" "github.com/ronaksoft/ronykit/utils" @@ -8,6 +9,10 @@ import ( var envelopePool = &sync.Pool{} +type Walker interface { + Walk(f func(k, v string) bool) +} + type Envelope struct { ctx *Context conn Conn @@ -59,6 +64,18 @@ func (e *Envelope) SetHdr(key, value string) *Envelope { return e } +func (e *Envelope) SetHdrWalker(walker Walker) *Envelope { + e.kvl.Lock() + walker.Walk(func(k, v string) bool { + e.kv[k] = v + + return true + }) + e.kvl.Unlock() + + return e +} + func (e *Envelope) SetHdrMap(kv map[string]string) *Envelope { e.kvl.Lock() for k, v := range kv { @@ -120,12 +137,26 @@ func (e *Envelope) Send() { e.release() } -type MessageFactory func() Message +type MessageFactoryFunc func() Message type Message interface { Marshal() ([]byte, error) } +func CreateMessageFactory(in Message) MessageFactoryFunc { + var ff MessageFactoryFunc + reflect.ValueOf(&ff).Elem().Set( + reflect.MakeFunc( + reflect.TypeOf(ff), + func(args []reflect.Value) (results []reflect.Value) { + return []reflect.Value{reflect.New(reflect.TypeOf(in).Elem())} + }, + ), + ) + + return ff +} + // RawMessage is a bytes slice which could be used as Message. This is helpful for // raw data messages. type RawMessage []byte diff --git a/service.go b/service.go index 71914eb4..81c7f88d 100644 --- a/service.go +++ b/service.go @@ -38,6 +38,7 @@ type Service interface { // order to be usable by Bundle 'b' otherwise it panics. type Contract interface { RouteSelector + Input() Message Handlers() []Handler Modifiers() []Modifier } diff --git a/std/bundle/fasthttp/bundle.go b/std/bundle/fasthttp/bundle.go index dcfba9bc..501b6416 100644 --- a/std/bundle/fasthttp/bundle.go +++ b/std/bundle/fasthttp/bundle.go @@ -12,7 +12,7 @@ import ( ) const ( - queryMethod = "fasthttp.method" + queryMethod = "fasthttp.Method" queryPath = "fasthttp.path" queryDecoder = "fasthttp.decoder" ) @@ -83,23 +83,23 @@ func (r *bundle) wsHandler(ctx *fasthttp.RequestCtx) { } func (r *bundle) Register(svc ronykit.Service) { - for _, rt := range svc.Contracts() { + for _, contract := range svc.Contracts() { var h []ronykit.Handler h = append(h, svc.PreHandlers()...) - h = append(h, rt.Handlers()...) + h = append(h, contract.Handlers()...) h = append(h, svc.PostHandlers()...) - method, ok := rt.Query(queryMethod).(string) + method, ok := contract.Query(queryMethod).(string) if !ok { continue } - path, ok := rt.Query(queryPath).(string) + path, ok := contract.Query(queryPath).(string) if !ok { continue } - decoder, ok := rt.Query(queryDecoder).(DecoderFunc) + decoder, ok := contract.Query(queryDecoder).(DecoderFunc) if !ok { - continue + decoder = reflectDecoder(ronykit.CreateMessageFactory(contract.Input())) } r.mux.Handle( @@ -110,57 +110,29 @@ func (r *bundle) Register(svc ronykit.Service) { ServiceName: svc.Name(), Decoder: decoder, Handlers: h, - Modifiers: rt.Modifiers(), + Modifiers: contract.Modifiers(), }, ) } } func (r *bundle) Dispatch(c ronykit.Conn, in []byte) (ronykit.DispatchFunc, error) { - rc, ok := c.(*httpConn) + conn, ok := c.(*httpConn) if !ok { panic("BUG!! incorrect connection") } - routeData, params, _ := r.mux.Lookup(rc.GetMethod(), rc.GetPath()) - if routeData == nil { - if r.cors != nil { - // ByPass cors (Cross Origin Resource Sharing) check - if r.cors.origins == "*" { - rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin)) - } else { - rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins) - } - - if rc.ctx.IsOptions() { - reqHeaders := rc.ctx.Request.Header.Peek(headerAccessControlRequestHeaders) - if len(reqHeaders) > 0 { - rc.ctx.Response.Header.SetBytesV(headerAccessControlAllowHeaders, reqHeaders) - } else { - rc.ctx.Response.Header.Set(headerAccessControlAllowHeaders, r.cors.headers) - } - - rc.ctx.Response.Header.Set(headerAccessControlAllowMethods, r.cors.methods) - rc.ctx.SetStatusCode(fasthttp.StatusNoContent) - } else { - rc.ctx.SetStatusCode(fasthttp.StatusNotImplemented) - } - } + routeData, params, _ := r.mux.Lookup(conn.GetMethod(), conn.GetPath()) - return nil, errRouteNotFound - } + // check CORS rules + r.handleCORS(conn, routeData != nil) - if r.cors != nil { - // ByPass cors (Cross Origin Resource Sharing) check - if r.cors.origins == "*" { - rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin)) - } else { - rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins) - } + if routeData == nil { + return nil, errRouteNotFound } // Walk over all the query params - rc.ctx.QueryArgs().VisitAll( + conn.ctx.QueryArgs().VisitAll( func(key, value []byte) { params = append( params, @@ -202,26 +174,52 @@ func (r *bundle) Dispatch(c ronykit.Conn, in []byte) (ronykit.DispatchFunc, erro } return func(ctx *ronykit.Context, execFunc ronykit.ExecuteFunc) error { - // Walk over all the connection headers - rc.Walk( - func(key string, val string) bool { - ctx.In().SetHdr(key, val) - - return true - }, - ) - // Set the route and service name ctx.Set(ronykit.CtxServiceName, routeData.ServiceName) ctx.Set(ronykit.CtxRoute, fmt.Sprintf("%s %s", routeData.Method, routeData.Path)) - ctx.In().SetMsg(routeData.Decoder(params, in)) + ctx.In(). + SetHdrWalker(conn). + SetMsg(routeData.Decoder(params, in)) + + // execute handler functions execFunc(writeFunc, routeData.Handlers...) return nil }, nil } +func (r *bundle) handleCORS(rc *httpConn, routeFound bool) { + if r.cors == nil { + return + } + + // ByPass cors (Cross Origin Resource Sharing) check + if r.cors.origins == "*" { + rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin)) + } else { + rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins) + } + + if routeFound { + return + } + + if rc.ctx.IsOptions() { + reqHeaders := rc.ctx.Request.Header.Peek(headerAccessControlRequestHeaders) + if len(reqHeaders) > 0 { + rc.ctx.Response.Header.SetBytesV(headerAccessControlAllowHeaders, reqHeaders) + } else { + rc.ctx.Response.Header.Set(headerAccessControlAllowHeaders, r.cors.headers) + } + + rc.ctx.Response.Header.Set(headerAccessControlAllowMethods, r.cors.methods) + rc.ctx.SetStatusCode(fasthttp.StatusNoContent) + } else { + rc.ctx.SetStatusCode(fasthttp.StatusNotImplemented) + } +} + func (r *bundle) Start() { ln, err := net.Listen("tcp4", r.listen) if err != nil { diff --git a/std/bundle/fasthttp/router.go b/std/bundle/fasthttp/router.go index e6eb5a44..eb19c396 100644 --- a/std/bundle/fasthttp/router.go +++ b/std/bundle/fasthttp/router.go @@ -43,7 +43,7 @@ type mux struct { // RedirectTrailingSlash is independent of this option. RedirectFixedPath bool - // If enabled, the mux checks if another method is allowed for the + // If enabled, the mux checks if another Method is allowed for the // current route, if the current request can not be routed. // If this is the case, the request is answered with 'Method Not Allowed' // and HTTP status code 405. @@ -131,7 +131,7 @@ func (r *mux) DELETE(path string, handle *routeData) { r.Handle(http.MethodDelete, path, handle) } -// Handle registers a new request handle with the given path and method. +// Handle registers a new request handle with the given path and Method. // // For GET, POST, PUT, PATCH and DELETE requests the respective shortcut // functions can be used. @@ -143,7 +143,7 @@ func (r *mux) Handle(method, path string, handle *routeData) { varsCount := uint16(0) if method == "" { - panic("method must not be empty") + panic("Method must not be empty") } if len(path) < 1 || path[0] != '/' { panic("path must begin with '/' in path '" + path + "'") @@ -178,7 +178,7 @@ func (r *mux) Handle(method, path string, handle *routeData) { } } -// Lookup allows the manual lookup of a method + path combo. +// Lookup allows the manual lookup of a Method + path combo. // This is e.g. useful to build a framework around this mux. // If the path was found, it returns the handle function and the path parameter // values. Otherwise the third return value indicates whether a redirection to @@ -205,13 +205,13 @@ func (r *mux) allowed(path, reqMethod string) (allow string) { allowed := make([]string, 0, 9) if path == "*" { // server-wide - // empty method is used for internal calls to refresh the cache + // empty Method is used for internal calls to refresh the cache if reqMethod == "" { for method := range r.trees { if method == http.MethodOptions { continue } - // Add request method to list of allowed methods + // Add request Method to list of allowed methods allowed = append(allowed, method) } } else { @@ -219,21 +219,21 @@ func (r *mux) allowed(path, reqMethod string) (allow string) { } } else { // specific path for method := range r.trees { - // Skip the requested method - we already tried this one + // Skip the requested Method - we already tried this one if method == reqMethod || method == http.MethodOptions { continue } handle, _, _ := r.trees[method].getValue(path, nil) if handle != nil { - // Add request method to list of allowed methods + // Add request Method to list of allowed methods allowed = append(allowed, method) } } } if len(allowed) > 0 { - // Add request method to list of allowed methods + // Add request Method to list of allowed methods allowed = append(allowed, http.MethodOptions) // Sort allowed methods. diff --git a/std/bundle/fasthttp/selector.go b/std/bundle/fasthttp/selector.go index 5920660c..5c2d4802 100644 --- a/std/bundle/fasthttp/selector.go +++ b/std/bundle/fasthttp/selector.go @@ -1,11 +1,11 @@ package fasthttp import ( - "encoding/json" "reflect" "strings" "unsafe" + "github.com/goccy/go-json" "github.com/ronaksoft/ronykit" "github.com/ronaksoft/ronykit/utils" ) @@ -35,64 +35,43 @@ func (ps Params) ByName(name string) string { type ( DecoderFunc func(bag Params, data []byte) ronykit.Message - Selector struct { - Method string - Path string - Predicate string - CustomDecoder DecoderFunc - } ) -func (sd Selector) Generate(f ronykit.MessageFactory) ronykit.RouteSelector { - route := &routeSelector{ - method: sd.Method, - path: sd.Path, - predicate: sd.Predicate, - } - if sd.CustomDecoder != nil { - route.decoder = sd.CustomDecoder - } else { - route.decoder = reflectDecoder(f) - } - - return route -} - var ( - _ ronykit.RouteSelector = routeSelector{} - _ ronykit.RESTRouteSelector = routeSelector{} - _ ronykit.RPCRouteSelector = routeSelector{} + _ ronykit.RouteSelector = Selector{} + _ ronykit.RESTRouteSelector = Selector{} + _ ronykit.RPCRouteSelector = Selector{} ) -// routeSelector selector implements ronykit.RouteSelector and +// Selector implements ronykit.RouteSelector and // also ronykit.RPCRouteSelector and ronykit.RESTRouteSelector -type routeSelector struct { - method string - path string - predicate string - decoder DecoderFunc +type Selector struct { + Method string + Path string + Predicate string + Decoder DecoderFunc } -func (r routeSelector) GetMethod() string { - return r.method +func (r Selector) GetMethod() string { + return r.Method } -func (r routeSelector) GetPath() string { - return r.path +func (r Selector) GetPath() string { + return r.Path } -func (r routeSelector) GetPredicate() string { - return r.predicate +func (r Selector) GetPredicate() string { + return r.Predicate } -func (r routeSelector) Query(q string) interface{} { +func (r Selector) Query(q string) interface{} { switch q { case queryDecoder: - return r.decoder + return r.Decoder case queryMethod: - return r.method + return r.Method case queryPath: - return r.path + return r.Path } return nil @@ -118,7 +97,7 @@ type paramCaster struct { typ reflect.Type } -func reflectDecoder(factory ronykit.MessageFactory) DecoderFunc { +func reflectDecoder(factory ronykit.MessageFactoryFunc) DecoderFunc { rVal := reflect.ValueOf(factory()) rType := rVal.Type() if rType.Kind() != reflect.Ptr { diff --git a/std/bundle/fastws/bundle.go b/std/bundle/fastws/bundle.go index 6a63c701..91d2773c 100644 --- a/std/bundle/fastws/bundle.go +++ b/std/bundle/fastws/bundle.go @@ -12,8 +12,7 @@ import ( ) const ( - queryPredicate = "__rpc__predicate" - queryFactory = "__rpc__factory" + queryPredicate = "fastws.predicate" ) type bundle struct { @@ -56,18 +55,13 @@ func MustNew(opts ...Option) *bundle { } func (b *bundle) Register(svc ronykit.Service) { - for _, rt := range svc.Contracts() { + for _, contract := range svc.Contracts() { var h []ronykit.Handler h = append(h, svc.PreHandlers()...) - h = append(h, rt.Handlers()...) + h = append(h, contract.Handlers()...) h = append(h, svc.PostHandlers()...) - predicate, ok := rt.Query(queryPredicate).(string) - if !ok { - continue - } - - factory, ok := rt.Query(queryFactory).(ronykit.MessageFactory) + predicate, ok := contract.Query(queryPredicate).(string) if !ok { continue } @@ -76,8 +70,8 @@ func (b *bundle) Register(svc ronykit.Service) { ServiceName: svc.Name(), Predicate: predicate, Handlers: h, - Modifiers: rt.Modifiers(), - Factory: factory, + Modifiers: contract.Modifiers(), + Factory: ronykit.CreateMessageFactory(contract.Input()), } } } diff --git a/std/bundle/fastws/selector.go b/std/bundle/fastws/selector.go index 5267e22b..0fd582e3 100644 --- a/std/bundle/fastws/selector.go +++ b/std/bundle/fastws/selector.go @@ -4,37 +4,23 @@ import ( "github.com/ronaksoft/ronykit" ) -type Selector struct { - Predicate string -} - -func (s Selector) Generate(f ronykit.MessageFactory) ronykit.RouteSelector { - return &routeSelector{ - predicate: s.Predicate, - factory: f, - } -} - var ( - _ ronykit.RouteSelector = routeSelector{} - _ ronykit.RPCRouteSelector = routeSelector{} + _ ronykit.RouteSelector = Selector{} + _ ronykit.RPCRouteSelector = Selector{} ) -type routeSelector struct { - predicate string - factory ronykit.MessageFactory +type Selector struct { + Predicate string } -func (r routeSelector) GetPredicate() string { - return r.predicate +func (r Selector) GetPredicate() string { + return r.Predicate } -func (r routeSelector) Query(q string) interface{} { +func (r Selector) Query(q string) interface{} { switch q { case queryPredicate: - return r.predicate - case queryFactory: - return r.factory + return r.Predicate } return nil @@ -45,7 +31,7 @@ type routerData struct { Predicate string Handlers []ronykit.Handler Modifiers []ronykit.Modifier - Factory ronykit.MessageFactory + Factory ronykit.MessageFactoryFunc } type mux struct { diff --git a/utils/spinlock.go b/utils/spinlock.go index f4f3ddbf..d095d861 100644 --- a/utils/spinlock.go +++ b/utils/spinlock.go @@ -20,8 +20,8 @@ import ( // A SpinLock must not be copied after first use. // This SpinLock intended to be used to synchronize exceptionally short-lived operations. type SpinLock struct { - _ sync.Mutex // for copy protection compiler warning lock uintptr + _ sync.Mutex // for copy protection compiler warning } // Lock locks l.