Skip to content

Commit

Permalink
better auth error handling (#1032)
Browse files Browse the repository at this point in the history
* better auth error handling

* fix unit test
  • Loading branch information
missing1984 authored and neel-astro committed Jan 20, 2023
1 parent aa58254 commit c8af954
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 11 deletions.
24 changes: 13 additions & 11 deletions cloud/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ var (

var (
err error
callbackChannel = make(chan string, 1)
callbackChannel = make(chan CallbackMessage, 1)
callbackTimeout = time.Second * 300
redirectURI = "http://localhost:12345/callback"
callbackServer = "localhost:12345"
userEmail = ""
)

Expand Down Expand Up @@ -139,17 +140,16 @@ func requestToken(authConfig astro.AuthConfig, verifier, code string) (Result, e

func authorizeCallbackHandler() (string, error) {
m := http.NewServeMux()
s := http.Server{Addr: "localhost:12345", Handler: m, ReadHeaderTimeout: 0}
s := http.Server{Addr: callbackServer, Handler: m, ReadHeaderTimeout: 0}
m.HandleFunc("/callback", func(w http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
if errorCode, ok := req.URL.Query()["error"]; ok {
log.Fatalf(
"Could not authorize your device. %s: %s",
errorCode, req.URL.Query()["error_description"],
)
callbackChannel <- CallbackMessage{errorMessage: fmt.Sprintf("Could not authorize your device. %s: %s",
errorCode, req.URL.Query()["error_description"])}
resp := &http.Request{}
http.Redirect(w, resp, "https://auth.astronomer.io/device/denied", http.StatusFound)
} else {
authorizationCode := req.URL.Query().Get("code")
callbackChannel <- authorizationCode
callbackChannel <- CallbackMessage{authorizationCode: req.URL.Query().Get("code")}
resp := &http.Request{}
http.Redirect(w, resp, "https://auth.astronomer.io/device/success", http.StatusFound)
}
Expand All @@ -164,10 +164,12 @@ func authorizeCallbackHandler() (string, error) {
authorizationCode := ""
for authorizationCode == "" {
select {
case code := <-callbackChannel:
authorizationCode = code
case callbackMessage := <-callbackChannel:
if callbackMessage.errorMessage != "" {
return "", errors.New(callbackMessage.errorMessage)
}
authorizationCode = callbackMessage.authorizationCode
case <-time.After(callbackTimeout):

err := s.Shutdown(http_context.Background())
if err != nil {
fmt.Printf("error: %s", err)
Expand Down
17 changes: 17 additions & 0 deletions cloud/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func TestRequestToken(t *testing.T) {
func TestAuthorizeCallbackHandler(t *testing.T) {
httpClient = httputil.NewHTTPClient()
t.Run("success", func(t *testing.T) {
callbackServer = "localhost:12345"
go func() {
time.Sleep(2 * time.Second) // time to spinup the server in authorizeCallbackHandler

Expand All @@ -269,7 +270,23 @@ func TestAuthorizeCallbackHandler(t *testing.T) {
assert.NoError(t, err)
})

t.Run("error", func(t *testing.T) {
callbackServer = "localhost:12346"
go func() {
time.Sleep(2 * time.Second) // time to spinup the server in authorizeCallbackHandler
opts := &httputil.DoOptions{
Method: http.MethodGet,
Path: "http://localhost:12346/callback?error=error&error_description=fatal_error",
}
_, err = httpClient.Do(opts) //nolint
assert.NoError(t, err)
}()
_, err := authorizeCallbackHandler()
assert.Contains(t, err.Error(), "fatal_error")
})

t.Run("timeout", func(t *testing.T) {
callbackServer = "localhost:12347"
callbackTimeout = 5 * time.Millisecond
_, err := authorizeCallbackHandler()
assert.Contains(t, err.Error(), "the operation has timed out")
Expand Down
5 changes: 5 additions & 0 deletions cloud/auth/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type Result struct {
UserEmail string
}

type CallbackMessage struct {
authorizationCode string
errorMessage string
}

func (res Result) writeToContext(c *config.Context) error {
err = c.SetContextKey("token", "Bearer "+res.AccessToken)
if err != nil {
Expand Down

0 comments on commit c8af954

Please sign in to comment.