Skip to content

Commit

Permalink
refactor Selector
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed Feb 12, 2022
1 parent fb88f1f commit c573752
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 169 deletions.
28 changes: 4 additions & 24 deletions desc/desc.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
package desc

import (
"reflect"

"github.com/ronaksoft/ronykit"
)

// Selector is the interface which should be provided by the Bundle developer. That is we
// make it as interface instead of concrete implementation.
type Selector interface {
Generate(f ronykit.MessageFactory) ronykit.RouteSelector
}

// Contract is the description of the ronykit.Contract you are going to create.
type Contract struct {
Name string
Handlers []ronykit.Handler
Input ronykit.Message
Output ronykit.Message
Selectors []Selector
Selectors []ronykit.RouteSelector
Modifiers []ronykit.Modifier
}

Expand All @@ -44,7 +36,7 @@ func (c *Contract) SetOutput(m ronykit.Message) *Contract {
return c
}

func (c *Contract) AddSelector(s Selector) *Contract {
func (c *Contract) AddSelector(s ronykit.RouteSelector) *Contract {
c.Selectors = append(c.Selectors, s)

return c
Expand All @@ -71,25 +63,13 @@ func (c *Contract) SetHandler(h ...ronykit.Handler) *Contract {
func (c *Contract) Generate() []ronykit.Contract {
//nolint:prealloc
var contracts []ronykit.Contract
makeFunc := func(m ronykit.Message) func(args []reflect.Value) (results []reflect.Value) {
return func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.New(reflect.TypeOf(m).Elem())}
}
}

for _, s := range c.Selectors {
ci := &contractImpl{}
ci.setHandler(c.Handlers...)
ci.setModifier(c.Modifiers...)

reflect.ValueOf(&ci.factoryFunc).Elem().Set(
reflect.MakeFunc(
reflect.TypeOf(ci.factoryFunc),
makeFunc(c.Input),
),
)

ci.setSelector(s.Generate(ci.factoryFunc))
ci.setInput(c.Input)
ci.setSelector(s)

contracts = append(contracts, ci)
}
Expand Down
18 changes: 14 additions & 4 deletions desc/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,16 @@ func (s *serviceImpl) addContract(contracts ...ronykit.Contract) *serviceImpl {

// contractImpl is simple implementation of ronykit.Contract interface.
type contractImpl struct {
selector ronykit.RouteSelector
handlers []ronykit.Handler
modifiers []ronykit.Modifier
factoryFunc func() ronykit.Message
selector ronykit.RouteSelector
handlers []ronykit.Handler
modifiers []ronykit.Modifier
input ronykit.Message
}

func (r *contractImpl) setInput(input ronykit.Message) *contractImpl {
r.input = input

return r
}

func (r *contractImpl) setSelector(selector ronykit.RouteSelector) *contractImpl {
Expand Down Expand Up @@ -81,3 +87,7 @@ func (r *contractImpl) Handlers() []ronykit.Handler {
func (r *contractImpl) Modifiers() []ronykit.Modifier {
return r.modifiers
}

func (r *contractImpl) Input() ronykit.Message {
return r.input
}
33 changes: 32 additions & 1 deletion envelope.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package ronykit

import (
"reflect"
"sync"

"github.com/ronaksoft/ronykit/utils"
)

var envelopePool = &sync.Pool{}

type Walker interface {
Walk(f func(k, v string) bool)
}

type Envelope struct {
ctx *Context
conn Conn
Expand Down Expand Up @@ -59,6 +64,18 @@ func (e *Envelope) SetHdr(key, value string) *Envelope {
return e
}

func (e *Envelope) SetHdrWalker(walker Walker) *Envelope {
e.kvl.Lock()
walker.Walk(func(k, v string) bool {
e.kv[k] = v

return true
})
e.kvl.Unlock()

return e
}

func (e *Envelope) SetHdrMap(kv map[string]string) *Envelope {
e.kvl.Lock()
for k, v := range kv {
Expand Down Expand Up @@ -120,12 +137,26 @@ func (e *Envelope) Send() {
e.release()
}

type MessageFactory func() Message
type MessageFactoryFunc func() Message

type Message interface {
Marshal() ([]byte, error)
}

func CreateMessageFactory(in Message) MessageFactoryFunc {
var ff MessageFactoryFunc
reflect.ValueOf(&ff).Elem().Set(
reflect.MakeFunc(
reflect.TypeOf(ff),
func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.New(reflect.TypeOf(in).Elem())}
},
),
)

return ff
}

// RawMessage is a bytes slice which could be used as Message. This is helpful for
// raw data messages.
type RawMessage []byte
Expand Down
1 change: 1 addition & 0 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type Service interface {
// order to be usable by Bundle 'b' otherwise it panics.
type Contract interface {
RouteSelector
Input() Message
Handlers() []Handler
Modifiers() []Modifier
}
Expand Down
104 changes: 51 additions & 53 deletions std/bundle/fasthttp/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

const (
queryMethod = "fasthttp.method"
queryMethod = "fasthttp.Method"
queryPath = "fasthttp.path"
queryDecoder = "fasthttp.decoder"
)
Expand Down Expand Up @@ -83,23 +83,23 @@ func (r *bundle) wsHandler(ctx *fasthttp.RequestCtx) {
}

func (r *bundle) Register(svc ronykit.Service) {
for _, rt := range svc.Contracts() {
for _, contract := range svc.Contracts() {
var h []ronykit.Handler
h = append(h, svc.PreHandlers()...)
h = append(h, rt.Handlers()...)
h = append(h, contract.Handlers()...)
h = append(h, svc.PostHandlers()...)

method, ok := rt.Query(queryMethod).(string)
method, ok := contract.Query(queryMethod).(string)
if !ok {
continue
}
path, ok := rt.Query(queryPath).(string)
path, ok := contract.Query(queryPath).(string)
if !ok {
continue
}
decoder, ok := rt.Query(queryDecoder).(DecoderFunc)
decoder, ok := contract.Query(queryDecoder).(DecoderFunc)
if !ok {
continue
decoder = reflectDecoder(ronykit.CreateMessageFactory(contract.Input()))
}

r.mux.Handle(
Expand All @@ -110,57 +110,29 @@ func (r *bundle) Register(svc ronykit.Service) {
ServiceName: svc.Name(),
Decoder: decoder,
Handlers: h,
Modifiers: rt.Modifiers(),
Modifiers: contract.Modifiers(),
},
)
}
}

func (r *bundle) Dispatch(c ronykit.Conn, in []byte) (ronykit.DispatchFunc, error) {
rc, ok := c.(*httpConn)
conn, ok := c.(*httpConn)
if !ok {
panic("BUG!! incorrect connection")
}

routeData, params, _ := r.mux.Lookup(rc.GetMethod(), rc.GetPath())
if routeData == nil {
if r.cors != nil {
// ByPass cors (Cross Origin Resource Sharing) check
if r.cors.origins == "*" {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin))
} else {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins)
}

if rc.ctx.IsOptions() {
reqHeaders := rc.ctx.Request.Header.Peek(headerAccessControlRequestHeaders)
if len(reqHeaders) > 0 {
rc.ctx.Response.Header.SetBytesV(headerAccessControlAllowHeaders, reqHeaders)
} else {
rc.ctx.Response.Header.Set(headerAccessControlAllowHeaders, r.cors.headers)
}

rc.ctx.Response.Header.Set(headerAccessControlAllowMethods, r.cors.methods)
rc.ctx.SetStatusCode(fasthttp.StatusNoContent)
} else {
rc.ctx.SetStatusCode(fasthttp.StatusNotImplemented)
}
}
routeData, params, _ := r.mux.Lookup(conn.GetMethod(), conn.GetPath())

return nil, errRouteNotFound
}
// check CORS rules
r.handleCORS(conn, routeData != nil)

if r.cors != nil {
// ByPass cors (Cross Origin Resource Sharing) check
if r.cors.origins == "*" {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin))
} else {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins)
}
if routeData == nil {
return nil, errRouteNotFound
}

// Walk over all the query params
rc.ctx.QueryArgs().VisitAll(
conn.ctx.QueryArgs().VisitAll(
func(key, value []byte) {
params = append(
params,
Expand Down Expand Up @@ -202,26 +174,52 @@ func (r *bundle) Dispatch(c ronykit.Conn, in []byte) (ronykit.DispatchFunc, erro
}

return func(ctx *ronykit.Context, execFunc ronykit.ExecuteFunc) error {
// Walk over all the connection headers
rc.Walk(
func(key string, val string) bool {
ctx.In().SetHdr(key, val)

return true
},
)

// Set the route and service name
ctx.Set(ronykit.CtxServiceName, routeData.ServiceName)
ctx.Set(ronykit.CtxRoute, fmt.Sprintf("%s %s", routeData.Method, routeData.Path))

ctx.In().SetMsg(routeData.Decoder(params, in))
ctx.In().
SetHdrWalker(conn).
SetMsg(routeData.Decoder(params, in))

// execute handler functions
execFunc(writeFunc, routeData.Handlers...)

return nil
}, nil
}

func (r *bundle) handleCORS(rc *httpConn, routeFound bool) {
if r.cors == nil {
return
}

// ByPass cors (Cross Origin Resource Sharing) check
if r.cors.origins == "*" {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, rc.Get(headerOrigin))
} else {
rc.ctx.Response.Header.Set(headerAccessControlAllowOrigin, r.cors.origins)
}

if routeFound {
return
}

if rc.ctx.IsOptions() {
reqHeaders := rc.ctx.Request.Header.Peek(headerAccessControlRequestHeaders)
if len(reqHeaders) > 0 {
rc.ctx.Response.Header.SetBytesV(headerAccessControlAllowHeaders, reqHeaders)
} else {
rc.ctx.Response.Header.Set(headerAccessControlAllowHeaders, r.cors.headers)
}

rc.ctx.Response.Header.Set(headerAccessControlAllowMethods, r.cors.methods)
rc.ctx.SetStatusCode(fasthttp.StatusNoContent)
} else {
rc.ctx.SetStatusCode(fasthttp.StatusNotImplemented)
}
}

func (r *bundle) Start() {
ln, err := net.Listen("tcp4", r.listen)
if err != nil {
Expand Down
Loading

0 comments on commit c573752

Please sign in to comment.