Skip to content

Commit

Permalink
feat: get subscription context
Browse files Browse the repository at this point in the history
  • Loading branch information
zengchen221 committed Oct 15, 2019
1 parent b1e6221 commit 5009595
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
35 changes: 12 additions & 23 deletions endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ package rpc

import (
"net"
"net/http"
"net/url"
)

// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules
func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, *http.Server, error) {
func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) {
// Generate the whitelist based on the allowed modules
whitelist := make(map[string]bool)
for _, module := range modules {
Expand All @@ -34,7 +33,7 @@ func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []str
for _, api := range apis {
if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
return nil, nil, nil, err
return nil, nil, err
}
logger.Debug("HTTP registered ", "namespace ", api.Namespace)
}
Expand All @@ -46,23 +45,18 @@ func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []str
)
network, address, err := scheme(endpoint)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
if listener, err = net.Listen(network, address); err != nil {
return nil, nil, nil, err
return nil, nil, err
}

hServer := new(http.Server)
go func(hServer *http.Server) {
hServer = NewHTTPServer(cors, vhosts, timeouts, handler)
hServer.Serve(listener)
}(hServer)

return listener, handler, hServer, err
go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener)
return listener, handler, err
}

// StartWSEndpoint starts a websocket endpoint
func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, *http.Server, error) {
func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) {

// Generate the whitelist based on the allowed modules
whitelist := make(map[string]bool)
Expand All @@ -74,7 +68,7 @@ func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []
for _, api := range apis {
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
return nil, nil, nil, err
return nil, nil, err
}
logger.Debug("WebSocket registered ", " service ", api.Service, " namespace ", api.Namespace)
}
Expand All @@ -86,20 +80,15 @@ func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []
)
network, address, err := scheme(endpoint)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}

if listener, err = net.Listen(network, address); err != nil {
return nil, nil, nil, err
return nil, nil, err
}

hServer := new(http.Server)
go func(hServer *http.Server) {
hServer = NewWSServer(wsOrigins, handler)
hServer.Serve(listener)
}(hServer)

return listener, handler, hServer, err
go NewWSServer(wsOrigins, handler).Serve(listener)
return listener, handler, err

}

Expand Down
28 changes: 28 additions & 0 deletions subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"encoding/json"
"errors"
"math/rand"
"net/http"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -325,3 +326,30 @@ func (sub *ClientSubscription) requestUnsubscribe() error {
var result interface{}
return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid)
}

func SubscriptionContext() context.Context {
ctx := context.Background()
handler := new(handler)
handler.idgen = sequentialIDGenerator()
r := new(http.Request)
w := new(http.ResponseWriter)
handler.conn = newHTTPServerConn(r, *w)
return context.WithValue(ctx, notifierKey{}, &Notifier{
h: handler,
})
}

func sequentialIDGenerator() func() ID {
var (
mu sync.Mutex
counter uint64
)
return func() ID {
mu.Lock()
defer mu.Unlock()
counter++
id := make([]byte, 8)
binary.BigEndian.PutUint64(id, counter)
return encodeID(id)
}
}
17 changes: 0 additions & 17 deletions testservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ package rpc

import (
"context"
"encoding/binary"
"errors"
"sync"
"time"
)

Expand All @@ -36,21 +34,6 @@ func newTestServer() *Server {
return server
}

func sequentialIDGenerator() func() ID {
var (
mu sync.Mutex
counter uint64
)
return func() ID {
mu.Lock()
defer mu.Unlock()
counter++
id := make([]byte, 8)
binary.BigEndian.PutUint64(id, counter)
return encodeID(id)
}
}

type testService struct{}

type Args struct {
Expand Down

0 comments on commit 5009595

Please sign in to comment.