Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Jan 7, 2025
1 parent 242f3a7 commit 7c884c1
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 21 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/lint-test-build-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ jobs:
- name: Wait for container to be ready
run: |
for i in {1..10}; do
if curl -s \
-H "Authorization: Bearer $ACTIONS_ID_TOKEN" \
http://localhost:8080/test | grep "ok"; then
TOKEN=$(curl -s \
-H "Accept: application/json; api-version=2.0" \
-H "Content-Type: application/json" -d "{}" \
-H "Authorization: bearer $ACTIONS_ID_TOKEN_REQUEST_TOKEN" \
"$ACTIONS_ID_TOKEN_REQUEST_URL&audience=ghat" | jq -r '.value')
for i in {1..5}; do
if curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8080/test | grep "ok"; then
echo "Auth passed!"
exit 0
fi
Expand Down
52 changes: 37 additions & 15 deletions lib/handler/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type contextKey string

const claimsKey contextKey = "claims"

const githubKeysURL = "https://token.actions.githubusercontent.com/.well-known/openid-configuration"
const githubKeysURL = "https://token.actions.githubusercontent.com/.well-known/jwks"

func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -76,25 +76,30 @@ func JWTAuthMiddleware(next http.Handler) http.Handler {
func verifyJWT(tokenString string) (*GitHubClaims, error) {
keySet, err := fetchJWKS()
if err != nil {
return nil, err
return nil, fmt.Errorf("Unable to fetch JWKS: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

token, err := jwt.Parse([]byte(tokenString), jwt.WithKeySet(keySet), jwt.WithContext(ctx))
token, err := jwt.Parse(
[]byte(tokenString),
jwt.WithKeySet(keySet),
jwt.WithContext(ctx),
jwt.WithVerify(true))
if err != nil {
return nil, err
return nil, fmt.Errorf("Unable to parse token: %v", err)
}

var claims GitHubClaims
if err := validateClaims(token); err != nil {
return nil, err
return nil, fmt.Errorf("Unable to validate claims: %v", err)
}
rawClaims, err := json.Marshal(token)
if err != nil {
return nil, fmt.Errorf("failed to marshal claims: %v", err)
}

var claims GitHubClaims
if err := json.Unmarshal(rawClaims, &claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
}
Expand All @@ -115,16 +120,33 @@ func validateClaims(token jwt.Token) error {
return fmt.Errorf("invalid issuer: %v", iss)
}

if time.Now().After(token.Expiration()) {
return fmt.Errorf("token has expired")
}

ro, ok := token.Get("repository_owner")
if !ok {
return fmt.Errorf("repository_owner claim not found")
claims := map[string]string{
"aud": "https://github.com/libops",
"repository_owner": "libops",
}
if strings.ToLower(ro.(string)) != "libops" {
return fmt.Errorf("invalid repository_owner: %v", iss)
for c, expectedValue := range claims {
claim, ok := token.Get(c)
if !ok {
return fmt.Errorf("%s claim not found", c)
}
found := false
switch v := claim.(type) {
case string:
if strings.ToLower(v) == expectedValue {
found = true
}
case []string:
for _, claimValue := range v {
if strings.ToLower(claimValue) == expectedValue {
found = true
break
}
}
}
if !found {
return fmt.Errorf("invalid %s: %v", c, claim)
}
}

return nil
}
7 changes: 5 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ func main() {
r.Use(handler.JWTAuthMiddleware)
r.HandleFunc("/repo/admin", wh.RepoAdminToken).Methods("POST")
r.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`ok`))
_, err := w.Write([]byte(`ok`))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
slog.Error("Unable to write for healthcheck")
}
}).Methods("GET")

port := os.Getenv("PORT")
Expand Down

0 comments on commit 7c884c1

Please sign in to comment.