Skip to content

Commit

Permalink
Merge pull request #17 from nkryuchkov/feature/change-http-servemux-t…
Browse files Browse the repository at this point in the history
…o-chi

Change http.ServeMux to chi
  • Loading branch information
nkryuchkov authored Oct 16, 2020
2 parents c65f9c8 + 3ba8834 commit 0423423
Show file tree
Hide file tree
Showing 96 changed files with 11,134 additions and 178 deletions.
211 changes: 97 additions & 114 deletions cmd/dmsg-discovery/internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"encoding/json"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/gorilla/handlers"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/sirupsen/logrus"
"github.com/skycoin/skycoin/src/util/logging"

Expand All @@ -24,76 +23,69 @@ const maxGetAvailableServersResult = 512

// API represents the api of the dmsg-discovery service`
type API struct {
log logrus.FieldLogger
http.Handler
db store.Storer
testMode bool
mux *http.ServeMux
}

// New returns a new API object, which can be started as a server
func New(log logrus.FieldLogger, db store.Storer, testMode bool) *API {
if log != nil {
log = logging.MustGetLogger("dmsg_disc")
}

if db == nil {
panic("cannot create new api without a store.Storer")
}

mux := http.NewServeMux()
r := chi.NewRouter()
api := &API{
log: log,
Handler: r,
db: db,
testMode: testMode,
mux: mux,
}
mux.HandleFunc("/dmsg-discovery/entry/", api.muxEntry())
mux.HandleFunc("/dmsg-discovery/available_servers", api.getAvailableServers())
mux.HandleFunc("/dmsg-discovery/health", api.health())
return api
}

// ServeHTTP implements http.Handler.
func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log := a.log.WithField("_module", "dmsgdisc_api")
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(httputil.SetLoggerMiddleware(log))

w.Header().Set("Content-Type", "application/json")
handlers.CustomLoggingHandler(log.Writer(), a.mux, httputil.WriteLog).
ServeHTTP(w, r)
r.Get("/dmsg-discovery/entry/{pk}", api.getEntry())
r.Post("/dmsg-discovery/entry/", api.setEntry())
r.Post("/dmsg-discovery/entry/{pk}", api.setEntry())
r.Get("/dmsg-discovery/available_servers", api.getAvailableServers())
r.Get("/dmsg-discovery/health", api.health())

return api
}

// muxEntry calls either getEntry or setEntry depending on the
// http method used on the endpoint /dmsg-discovery/entry/:pk
func (a *API) muxEntry() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
a.setEntry(w, r)
default:
a.getEntry(w, r)
}
}
func (a *API) log(r *http.Request) logrus.FieldLogger {
return httputil.GetLogger(r)
}

// getEntry returns the entry associated with the given public key
// URI: /dmsg-discovery/entry/:pk
// Method: GET
func (a *API) getEntry(w http.ResponseWriter, r *http.Request) {
staticPK, err := retrievePkFromURL(r.URL)
if err != nil {
a.handleError(w, disc.ErrBadInput)
return
}
func (a *API) getEntry() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
staticPK := cipher.PubKey{}
if err := staticPK.UnmarshalText([]byte(chi.URLParam(r, "pk"))); err != nil {
a.handleError(w, r, disc.ErrBadInput)
return
}

entry, err := a.db.Entry(r.Context(), staticPK)
entry, err := a.db.Entry(r.Context(), staticPK)

// If we make sure that every error is handled then we can
// remove the if and make the entry return the switch default
if err != nil {
a.handleError(w, err)
return
}
// If we make sure that every error is handled then we can
// remove the if and make the entry return the switch default
if err != nil {
a.handleError(w, r, err)
return
}

a.writeJSON(w, http.StatusOK, entry)
a.writeJSON(w, r, http.StatusOK, entry)
}
}

// setEntry adds a new entry associated with the given public key
Expand All @@ -103,75 +95,77 @@ func (a *API) getEntry(w http.ResponseWriter, r *http.Request) {
// Method: POST
// Args:
// json serialized entry object
func (a *API) setEntry(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := r.Body.Close(); err != nil {
log.WithError(err).Warn("Failed to decode HTTP response body")
}
}()
func (a *API) setEntry() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := r.Body.Close(); err != nil {
log.WithError(err).Warn("Failed to decode HTTP response body")
}
}()

entryTimeout := time.Duration(0) // no timeout
entryTimeout := time.Duration(0) // no timeout

if timeout := r.URL.Query().Get("timeout"); timeout == "true" {
entryTimeout = store.DefaultTimeout
}
if timeout := r.URL.Query().Get("timeout"); timeout == "true" {
entryTimeout = store.DefaultTimeout
}

entry := new(disc.Entry)
if err := json.NewDecoder(r.Body).Decode(entry); err != nil {
a.handleError(w, disc.ErrUnexpected)
return
}
entry := new(disc.Entry)
if err := json.NewDecoder(r.Body).Decode(entry); err != nil {
a.handleError(w, r, disc.ErrUnexpected)
return
}

if entry.Server != nil && !a.testMode {
if ok, err := isLoopbackAddr(entry.Server.Address); ok {
if err != nil {
a.log(r).Warningf("failed to parse hostname and port: %s", err)
}

if entry.Server != nil && !a.testMode {
if ok, err := isLoopbackAddr(entry.Server.Address); ok {
if err != nil && a.log != nil {
a.log.Warningf("failed to parse hostname and port: %s", err)
a.handleError(w, r, disc.ErrValidationServerAddress)
return
}
}

a.handleError(w, disc.ErrValidationServerAddress)
if err := entry.Validate(); err != nil {
a.handleError(w, r, err)
return
}
}

if err := entry.Validate(); err != nil {
a.handleError(w, err)
return
}
if err := entry.VerifySignature(); err != nil {
a.handleError(w, r, disc.ErrUnauthorized)
return
}

if err := entry.VerifySignature(); err != nil {
a.handleError(w, disc.ErrUnauthorized)
return
}
// Recover previous entry. If key not found we insert with sequence 0
// If there was a previous entry we check the new one is a valid iteration
oldEntry, err := a.db.Entry(r.Context(), entry.Static)
if err == disc.ErrKeyNotFound {
setErr := a.db.SetEntry(r.Context(), entry, entryTimeout)
if setErr != nil {
a.handleError(w, r, setErr)
return
}

// Recover previous entry. If key not found we insert with sequence 0
// If there was a previous entry we check the new one is a valid iteration
oldEntry, err := a.db.Entry(r.Context(), entry.Static)
if err == disc.ErrKeyNotFound {
setErr := a.db.SetEntry(r.Context(), entry, entryTimeout)
if setErr != nil {
a.handleError(w, setErr)
a.writeJSON(w, r, http.StatusOK, disc.MsgEntrySet)

return
} else if err != nil {
a.handleError(w, r, err)
return
}

a.writeJSON(w, http.StatusOK, disc.MsgEntrySet)

return
} else if err != nil {
a.handleError(w, err)
return
}
if err := oldEntry.ValidateIteration(entry); err != nil {
a.handleError(w, r, err)
return
}

if err := oldEntry.ValidateIteration(entry); err != nil {
a.handleError(w, err)
return
}
if err := a.db.SetEntry(r.Context(), entry, entryTimeout); err != nil {
a.handleError(w, r, err)
return
}

if err := a.db.SetEntry(r.Context(), entry, entryTimeout); err != nil {
a.handleError(w, err)
return
a.writeJSON(w, r, http.StatusOK, disc.MsgEntryUpdated)
}

a.writeJSON(w, http.StatusOK, disc.MsgEntryUpdated)
}

// getAvailableServers returns all available server entries as an array of json codified entry objects
Expand All @@ -181,20 +175,20 @@ func (a *API) getAvailableServers() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
entries, err := a.db.AvailableServers(r.Context(), maxGetAvailableServersResult)
if err != nil {
a.handleError(w, err)
a.handleError(w, r, err)
return
}

if len(entries) == 0 {
a.writeJSON(w, http.StatusNotFound, disc.HTTPMessage{
a.writeJSON(w, r, http.StatusNotFound, disc.HTTPMessage{
Code: http.StatusNotFound,
Message: disc.ErrNoAvailableServers.Error(),
})

return
}

a.writeJSON(w, http.StatusOK, entries)
a.writeJSON(w, r, http.StatusOK, entries)
}
}

Expand All @@ -203,7 +197,7 @@ func (a *API) getAvailableServers() http.HandlerFunc {
// Method: GET
func (a *API) health() http.HandlerFunc {
const expBase = "health"
return httputil.MakeHealthHandler(a.log, expBase, nil)
return httputil.MakeHealthHandler(expBase, nil)
}

// isLoopbackAddr checks if string is loopback interface
Expand All @@ -220,29 +214,18 @@ func isLoopbackAddr(addr string) (bool, error) {
return net.ParseIP(host).IsLoopback(), nil
}

// retrievePkFromURL returns the id used on endpoints of the form path/:pk
// it doesn't checks if the endpoint has this form and can fail with other
// endpoint forms
func retrievePkFromURL(url *url.URL) (cipher.PubKey, error) {
splitPath := strings.Split(url.EscapedPath(), "/")
v := splitPath[len(splitPath)-1]
pk := cipher.PubKey{}
err := pk.UnmarshalText([]byte(v))
return pk, err
}

// writeJSON writes a json object on a http.ResponseWriter with the given code.
func (a *API) writeJSON(w http.ResponseWriter, code int, object interface{}) {
func (a *API) writeJSON(w http.ResponseWriter, r *http.Request, code int, object interface{}) {
jsonObject, err := json.Marshal(object)
if err != nil {
a.log.Warnf("Failed to encode json response: %s", err)
a.log(r).Warnf("Failed to encode json response: %s", err)
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)

_, err = w.Write(jsonObject)
if err != nil {
a.log.Warnf("Failed to write response: %s", err)
a.log(r).Warnf("Failed to write response: %s", err)
}
}
6 changes: 3 additions & 3 deletions cmd/dmsg-discovery/internal/api/entries_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ func TestEntriesEndpoint(t *testing.T) {
req.Header.Set("Content-Type", contentType)

rr := httptest.NewRecorder()
api.mux.ServeHTTP(rr, req)
api.Handler.ServeHTTP(rr, req)

status := rr.Code
require.Equal(t, tc.status, status, "case: %s, handler returned wrong status code: got `%v` want `%v`",
tc.name, status, tc.status)
require.Equal(t, tc.status, status, "case: %s, handler for %s %s returned wrong status code: got `%v` want `%v`",
tc.name, tc.method, tc.endpoint, status, tc.status)

if tc.responseIsEntry {
var resEntry disc.Entry
Expand Down
8 changes: 4 additions & 4 deletions cmd/dmsg-discovery/internal/api/error_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var apiErrors = map[error]func() (int, string){
},
}

func (a *API) handleError(w http.ResponseWriter, e error) {
func (a *API) handleError(w http.ResponseWriter, r *http.Request, e error) {
var (
code int
msg string
Expand All @@ -42,9 +42,9 @@ func (a *API) handleError(w http.ResponseWriter, e error) {
code, msg = f()
}

if a.log != nil && code != http.StatusNotFound {
a.log.Warnf("%d: %s", code, e)
if code != http.StatusNotFound {
a.log(r).Warnf("%d: %s", code, e)
}

a.writeJSON(w, code, disc.HTTPMessage{Code: code, Message: msg})
a.writeJSON(w, r, code, disc.HTTPMessage{Code: code, Message: msg})
}
2 changes: 1 addition & 1 deletion cmd/dmsg-discovery/internal/api/error_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestErrorHandler(t *testing.T) {
t.Run(tc.err.Error(), func(t *testing.T) {
w := httptest.NewRecorder()
api := New(nil, store.NewMock(), true)
api.handleError(w, tc.err)
api.handleError(w, &http.Request{}, tc.err)

msg := new(disc.HTTPMessage)
err := json.NewDecoder(w.Body).Decode(&msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestGetAvailableServers(t *testing.T) {
require.NoError(t, err)

rr := httptest.NewRecorder()
api.mux.ServeHTTP(rr, req)
api.Handler.ServeHTTP(rr, req)

status := rr.Code
require.Equal(t, tc.status, status, "case: %s, handler returned wrong status code: got `%v` want `%v`",
Expand Down
Loading

0 comments on commit 0423423

Please sign in to comment.