Skip to content

Commit

Permalink
Detect if a function needs authentication
Browse files Browse the repository at this point in the history
Signed-off-by: Han Verstraete (OpenFaaS Ltd) <han@openfaas.com>
  • Loading branch information
welteki authored and alexellis committed Jun 19, 2024
1 parent 9210b4e commit 06ceb3c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
54 changes: 51 additions & 3 deletions commands/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ var (
sigHeader string
key string
functionInvokeNamespace string
authenticate bool
)

const functionInvokeRealm = "IAM function invoke"

func init() {
// Setup flags that are used by multiple commands (variables defined in faas.go)
invokeCmd.Flags().StringVar(&functionName, "name", "", "Name of the deployed function")
Expand All @@ -40,6 +43,7 @@ func init() {
invokeCmd.Flags().StringVar(&contentType, "content-type", "text/plain", "The content-type HTTP header such as application/json")
invokeCmd.Flags().StringArrayVar(&query, "query", []string{}, "pass query-string options")
invokeCmd.Flags().StringArrayVarP(&headers, "header", "H", []string{}, "pass HTTP request header")
invokeCmd.Flags().BoolVar(&authenticate, "auth", false, "Authenticate with an OpenFaaS token when invoking the function")
invokeCmd.Flags().BoolVarP(&invokeAsync, "async", "a", false, "Invoke the function asynchronously")
invokeCmd.Flags().StringVarP(&httpMethod, "method", "m", "POST", "pass HTTP request method")
invokeCmd.Flags().BoolVar(&tlsInsecure, "tls-no-verify", false, "Disable TLS validation")
Expand Down Expand Up @@ -124,16 +128,39 @@ func runInvoke(cmd *cobra.Command, args []string) error {
}
req.Header = httpHeader

authenticate := false
res, err := client.InvokeFunction(functionName, functionInvokeNamespace, invokeAsync, authenticate, req)
if err != nil {
return fmt.Errorf("cannot connect to OpenFaaS on URL: %s", client.GatewayURL)
return fmt.Errorf("failed to invoke function: %s", err)
}

if res.Body != nil {
defer res.Body.Close()
}

if !authenticate && res.StatusCode == http.StatusUnauthorized {
authenticateHeader := res.Header.Get("WWW-Authenticate")
realm := getRealm(authenticateHeader)

// Retry the request and authenticate with an OpenFaaS function access token if the realm directive in the
// WWW-Authenticate header is the function invoke realm.
if realm == functionInvokeRealm {
authenticate := true
body := bytes.NewReader(functionInput)
req, err := http.NewRequest(httpMethod, u.String(), body)
if err != nil {
return err
}
req.Header = httpHeader

res, err = client.InvokeFunction(functionName, functionInvokeNamespace, invokeAsync, authenticate, req)
if err != nil {
return fmt.Errorf("failed to invoke function: %s", err)
}
if res.Body != nil {
defer res.Body.Close()
}
}
}

if code := res.StatusCode; code < 200 || code > 299 {
resBody, err := io.ReadAll(res.Body)
if err != nil {
Expand Down Expand Up @@ -228,3 +255,24 @@ func validateHTTPMethod(httpMethod string) error {
}
return nil
}

// NOTE: This is far from a fully compliant parser per RFC 7235.
// It is only intended to correctly capture the realm directive in the
// known format as returned by the OpenFaaS watchdogs.
func getRealm(headerVal string) string {
parts := strings.SplitN(headerVal, " ", 2)

realm := ""
if len(parts) > 1 {
directives := strings.Split(parts[1], ", ")

for _, part := range directives {
if strings.HasPrefix(part, "realm=") {
realm = strings.Trim(part[6:], `"`)
break
}
}
}

return realm
}
32 changes: 32 additions & 0 deletions commands/invoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,35 @@ func Test_parseQueryValues_invalid(t *testing.T) {
})
}
}

func Test_getRealm(t *testing.T) {
tests := []struct {
header string
want string
}{
{
header: "Bearer",
want: "",
},
{
header: `Bearer realm="OpenFaaS API"`,
want: "OpenFaaS API",
},
{
header: `Bearer realm="OpenFaaS API", charset="UTF-8"`,
want: "OpenFaaS API",
},
{
header: "",
want: "",
},
}

for _, test := range tests {
got := getRealm(test.header)

if test.want != got {
t.Errorf("want: %s, got: %s", test.want, got)
}
}
}

0 comments on commit 06ceb3c

Please sign in to comment.