Skip to content

Commit

Permalink
Merge pull request aws#250 from shinichy/develop
Browse files Browse the repository at this point in the history
Support API Gateway binary payloads
  • Loading branch information
PaulMaddox authored Feb 23, 2018
2 parents 2b3813b + 04d200e commit 2c7cf78
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 70 deletions.
21 changes: 15 additions & 6 deletions router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

const apiGatewayIntegrationExtension = "x-amazon-apigateway-integration"
const apiGatewayAnyMethodExtension = "x-amazon-apigateway-any-method"
const apiGatewayBinaryMediaTypesExtension = "x-amazon-apigateway-binary-media-types"

// temporary object. This is just used to marshal and unmarshal the any method
// API Gateway swagger extension
Expand Down Expand Up @@ -49,6 +50,11 @@ func (api *AWSServerlessApi) Mounts() ([]*ServerlessRouterMount, error) {

mounts := []*ServerlessRouterMount{}

binaryMediaTypes, ok := swagger.VendorExtensible.Extensions.GetStringSlice(apiGatewayBinaryMediaTypesExtension)
if !ok {
binaryMediaTypes = []string{}
}

for path, pathItem := range swagger.Paths.Paths {
// temporary tracking of mounted methods for the current path. Used to
// mount all non-existing methods for the any extension. This is because
Expand All @@ -75,7 +81,8 @@ func (api *AWSServerlessApi) Mounts() ([]*ServerlessRouterMount, error) {
mounts = append(mounts, api.createMount(
path,
strings.ToLower(method),
api.parseIntegrationSettings(integration)))
api.parseIntegrationSettings(integration),
binaryMediaTypes))
mappedMethods[method] = true
}
}
Expand All @@ -100,7 +107,8 @@ func (api *AWSServerlessApi) Mounts() ([]*ServerlessRouterMount, error) {
mounts = append(mounts, api.createMount(
path,
strings.ToLower(method),
api.parseIntegrationSettings(anyMethodObject.IntegrationSettings)))
api.parseIntegrationSettings(anyMethodObject.IntegrationSettings),
binaryMediaTypes))
}
}
}
Expand Down Expand Up @@ -129,11 +137,12 @@ func (api *AWSServerlessApi) parseIntegrationSettings(integrationData interface{
return &integration
}

func (api *AWSServerlessApi) createMount(path string, verb string, integration *ApiGatewayIntegration) *(ServerlessRouterMount) {
func (api *AWSServerlessApi) createMount(path string, verb string, integration *ApiGatewayIntegration, binaryMediaTypes []string) *(ServerlessRouterMount) {
newMount := &ServerlessRouterMount{
Name: path,
Path: path,
Method: verb,
Name: path,
Path: path,
Method: verb,
BinaryMediaTypes: binaryMediaTypes,
}

if integration == nil {
Expand Down
20 changes: 13 additions & 7 deletions event.go → router/event.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package router

import (
"encoding/json"
Expand All @@ -19,6 +19,7 @@ type Event struct {
PathParameters map[string]string `json:"pathParameters"`
StageVariables map[string]string `json:"stageVariables"`
Path string `json:"path"`
IsBase64Encoded bool `json:"isBase64Encoded"`
}

// RequestContext represents the context object that gets passed to an AWS Lambda function
Expand Down Expand Up @@ -49,12 +50,16 @@ type ContextIdentity struct {
}

// NewEvent initalises and populates a new ApiEvent with
// event details from a http.Request
func NewEvent(req *http.Request) (*Event, error) {

body, err := ioutil.ReadAll(req.Body)
if err != nil {
return nil, err
// event details from a http.Request and isBase64Encoded value
func NewEvent(req *http.Request, isBase64Encoded bool) (*Event, error) {

var body []byte
if req.Body != nil {
var err error
body, err = ioutil.ReadAll(req.Body)
if err != nil {
return nil, err
}
}

headers := map[string]string{}
Expand Down Expand Up @@ -84,6 +89,7 @@ func NewEvent(req *http.Request) (*Event, error) {
Path: req.URL.Path,
Resource: req.URL.Path,
PathParameters: pathParams,
IsBase64Encoded: isBase64Encoded,
}

event.RequestContext.Identity.SourceIP = req.RemoteAddr
Expand Down
21 changes: 8 additions & 13 deletions event_test.go → router/event_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package main
package router

import (
"bytes"
"net/http"
"net/http/httptest"

"github.com/awslabs/aws-sam-local/router"
"github.com/awslabs/goformation/cloudformation"

. "github.com/onsi/ginkgo"
Expand All @@ -14,9 +13,9 @@ import (

var _ = Describe("Event", func() {
Describe("PathParameters", func() {
var r *router.ServerlessRouter
var r *ServerlessRouter
BeforeEach(func() {
r = router.NewServerlessRouter(false)
r = NewServerlessRouter(false)
})

Context("with path parameters on the route", func() {
Expand Down Expand Up @@ -48,22 +47,20 @@ var _ = Describe("Event", func() {
req, _ := http.NewRequest("GET", "/get/1", new(bytes.Buffer))

It("returns the parameters on the event", func() {
r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
e, _ := NewEvent(r)
r.AddFunction(function, func(w http.ResponseWriter, e *Event) {
Expect(e.PathParameters).To(HaveKeyWithValue("parameter", "1"))
})

rec := httptest.NewRecorder()
r.Router().ServeHTTP(rec, req)
})
})

Context("and path parameters on the request", func() {
req, _ := http.NewRequest("GET", "/get/1", new(bytes.Buffer))

It("returns stage property with value \"prod\"", func() {
r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
e, _ := NewEvent(r)
r.AddFunction(function, func(w http.ResponseWriter, e *Event) {
Expect(e.RequestContext.Stage).To(BeIdenticalTo("prod"))
})

Expand All @@ -76,8 +73,7 @@ var _ = Describe("Event", func() {
req, _ := http.NewRequest("GET", "/get", new(bytes.Buffer))

It("returns nil for PathParameters on the event", func() {
r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
e, _ := NewEvent(r)
r.AddFunction(function, func(w http.ResponseWriter, e *Event) {
Expect(e.PathParameters).To(BeNil())
})

Expand Down Expand Up @@ -106,8 +102,7 @@ var _ = Describe("Event", func() {
req, _ := http.NewRequest("GET", "/get", new(bytes.Buffer))

It("returns nil for PathParameters on the event", func() {
r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
e, _ := NewEvent(r)
r.AddFunction(function, func(w http.ResponseWriter, e *Event) {
Expect(e.PathParameters).To(BeNil())
})

Expand Down
4 changes: 1 addition & 3 deletions router/function.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package router

import (
"net/http"

"github.com/awslabs/goformation/cloudformation"
)

Expand All @@ -11,7 +9,7 @@ import (
// from the event sources.
type AWSServerlessFunction struct {
*cloudformation.AWSServerlessFunction
handler http.HandlerFunc
handler EventHandlerFunc
}

// Mounts fetches an array of the ServerlessRouterMount's for this API.
Expand Down
6 changes: 3 additions & 3 deletions router/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ var _ = Describe("Function", func() {
},
}

err := r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
err := r.AddFunction(function, func(w http.ResponseWriter, e *router.Event) {
w.WriteHeader(200)
w.Write([]byte("ok"))
})
Expand Down Expand Up @@ -190,7 +190,7 @@ var _ = Describe("Function", func() {
},
}

err := r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
err := r.AddFunction(function, func(w http.ResponseWriter, e *router.Event) {
w.WriteHeader(200)
w.Write([]byte("ok"))
})
Expand Down Expand Up @@ -246,7 +246,7 @@ var _ = Describe("Function", func() {
Runtime: "nodejs6.10",
}

err := r.AddFunction(function, func(w http.ResponseWriter, r *http.Request) {
err := r.AddFunction(function, func(w http.ResponseWriter, e *router.Event) {
w.WriteHeader(200)
w.Write([]byte("ok"))
})
Expand Down
58 changes: 53 additions & 5 deletions router/mount.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,76 @@
package router

import (
"encoding/base64"
"io/ioutil"
"mime"
"net/http"
"strings"
"unicode/utf8"
"log"
"fmt"
)

const MuxPathRegex = ".+"
var HttpMethods = []string{"OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}

// EventHandlerFunc is similar to Go http.Handler but it receives an event from API Gateway
// instead of http.Request
type EventHandlerFunc func(http.ResponseWriter, *Event)

// ServerlessRouterMount represents a single mount point on the API
// Such as '/path', the HTTP method, and the function to resolve it
type ServerlessRouterMount struct {
Name string
Function *AWSServerlessFunction
Handler http.HandlerFunc
Path string
Method string
Name string
Function *AWSServerlessFunction
Handler EventHandlerFunc
Path string
Method string
BinaryMediaTypes []string

// authorization settings
AuthType string
AuthFunction *AWSServerlessFunction
IntegrationArn *LambdaFunctionArn
}

// Returns the wrapped handler to encode the body as base64 when binary
// media types contains Content-Type
func (m *ServerlessRouterMount) WrappedHandler() http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
binaryContent := false

if err == nil {
for _, value := range m.BinaryMediaTypes {
if value != "" && value == mediaType {
binaryContent = true
break
}
}
}

if binaryContent {
if body, err := ioutil.ReadAll(req.Body); err == nil && !utf8.Valid(body) {
req.Body = ioutil.NopCloser(strings.NewReader(base64.StdEncoding.EncodeToString(body)))
} else {
req.Body = ioutil.NopCloser(strings.NewReader(string(body)))
}
}

event, err := NewEvent(req, binaryContent)
if err != nil {
msg := fmt.Sprintf("Error creating a new event: %s", err)
log.Println(msg)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{ "message": "Internal server error" }`))
} else {
m.Handler(w, event)
}
})
}

// Methods gets an array of HTTP methods from a AWS::Serverless::Function
// API event source method declaration (which could include 'any')
func (m *ServerlessRouterMount) Methods() []string {
Expand Down
8 changes: 4 additions & 4 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewServerlessRouter(usePrefix bool) *ServerlessRouter {

// AddFunction adds a AWS::Serverless::Function to the router and mounts all of it's
// event sources that have type 'Api'
func (r *ServerlessRouter) AddFunction(f *cloudformation.AWSServerlessFunction, handler http.HandlerFunc) error {
func (r *ServerlessRouter) AddFunction(f *cloudformation.AWSServerlessFunction, handler EventHandlerFunc) error {

// Wrap GoFormation's AWS::Serverless::Function definition in our own, which provides
// convenience methods for extracting the ServerlessRouterMount(s) from it.
Expand Down Expand Up @@ -115,7 +115,7 @@ func (r *ServerlessRouter) Router() http.Handler {

// Mount all of the things!
for _, mount := range r.Mounts() {
r.mux.Handle(mount.GetMuxPath(), mount.Handler).Methods(mount.Methods()...)
r.mux.Handle(mount.GetMuxPath(), mount.WrappedHandler()).Methods(mount.Methods()...)
}

return r.mux
Expand All @@ -127,8 +127,8 @@ func (r *ServerlessRouter) Mounts() []*ServerlessRouterMount {
return r.mounts
}

func (r *ServerlessRouter) missingFunctionHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, req *http.Request) {
func (r *ServerlessRouter) missingFunctionHandler() func(http.ResponseWriter, *Event) {
return func(w http.ResponseWriter, event *Event) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte(`{ "message": "No function defined for resource method" }`))
Expand Down
Loading

0 comments on commit 2c7cf78

Please sign in to comment.