Skip to content

Commit

Permalink
fix ws token auth
Browse files Browse the repository at this point in the history
  • Loading branch information
jmagoon committed Jul 25, 2024
1 parent 4a275e0 commit 6ba4369
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 60 deletions.
1 change: 0 additions & 1 deletion frontend/src/components/ui/RunPipelineButton.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ export default function RunPipelineButton({ children, action }) {

if (!(await validateAnvilOnline())) return;
if (!validateSchema()) return;
console.log(pipelineSpecs);
const newExecution = await execute(pipelineSpecs, executionId);
if (!newExecution) {
return;
Expand Down
5 changes: 3 additions & 2 deletions frontend/src/hooks/useStableWebsocket.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ export const useStableWebSocket = (url) => {
const [wsError, setWsError] = useState(null);
const reconnectCount = useRef(0);
const [configuration] = useAtom(activeConfigurationAtom);
const protocols = url?.startsWith("wss") ? [configuration.anvil.token] : null;
const protocols = url?.startsWith("wss")
? ["Bearer", configuration.anvil.token]
: null;

const { lastMessage, readyState, sendMessage } = useWebSocket(url, {
shouldReconnect: (closeEvent) => {
Expand All @@ -36,7 +38,6 @@ export const useStableWebSocket = (url) => {
protocols: protocols,
reconnectInterval: (attemptNumber) => Math.min(1000 * 2, 30000),

Check failure on line 39 in frontend/src/hooks/useStableWebsocket.js

View workflow job for this annotation

GitHub Actions / lint (20.x)

'attemptNumber' is defined but never used
onOpen: (event) => {

Check failure on line 40 in frontend/src/hooks/useStableWebsocket.js

View workflow job for this annotation

GitHub Actions / lint (20.x)

'event' is defined but never used
console.log("Open: ", event);
console.log("WebSocket connection established.");
setWsError(null);
reconnectCount.current = 0;
Expand Down
26 changes: 10 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,25 +355,18 @@ func main() {
return
}

// Check if it's a WebSocket upgrade request
headerKey := "Authorization"
if isWebSocketRequest(ctx.Request) {
token := ctx.GetHeader("Sec-WebSocket-Protocol")
code, prefix := validateSocketToken(token, certsPath)
if code != http.StatusOK {
ctx.AbortWithStatus(code)
return
}
ctx.Set("prefix", prefix)
} else {
// Existing token validation for non-WebSocket requests
code, prefix := validateToken(ctx, certsPath)
if code != http.StatusOK {
ctx.AbortWithStatus(code)
return
}
ctx.Set("prefix", prefix)
headerKey = "Sec-WebSocket-Protocol"
}

token := ctx.GetHeader(headerKey)
code, prefix := validateToken(token, certsPath)
if code != http.StatusOK {
ctx.AbortWithStatus(code)
return
}
ctx.Set("prefix", prefix)
ctx.Next()
})
}
Expand Down Expand Up @@ -464,6 +457,7 @@ func main() {
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Subprotocols: []string{"Bearer"},
}

if !websocket.IsWebSocketUpgrade(ctx.Request) {
Expand Down
42 changes: 1 addition & 41 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"

"github.com/gin-gonic/gin"
jwt "github.com/golang-jwt/jwt/v5"
)

Expand Down Expand Up @@ -39,13 +38,12 @@ func loadCertificates(folder string) (map[string]crypto.PublicKey, error) {
return certificates, nil
}

func validateToken(ctx *gin.Context, folder string) (int, string) {
func validateToken(bearer string, folder string) (int, string) {
certs, err := loadCertificates(folder)
if err != nil {
log.Printf("Could not load certificates")
return http.StatusUnauthorized, ""
}
bearer := ctx.Request.Header.Get("Authorization")
if len(strings.Fields(bearer)) != 2 {
log.Printf("Invalid authorization header")
return http.StatusUnauthorized, ""
Expand Down Expand Up @@ -76,41 +74,3 @@ func validateToken(ctx *gin.Context, folder string) (int, string) {
sub, _ := token.Claims.GetSubject()
return http.StatusOK, sub
}

func validateSocketToken(token string, folder string) (int, string) {
certs, err := loadCertificates(folder)
if err != nil {
log.Printf("Could not load certificates")
return http.StatusUnauthorized, ""
}

parsedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
sub, err := t.Claims.GetSubject()
if err != nil {
return "", err
}
key, ok := certs[sub]
if !ok {
return "", errors.New("Subject certificate missing")
}
return key, nil
}, jwt.WithValidMethods([]string{"EdDSA"}))

if err != nil {
log.Printf("Error parsing WebSocket token: %v", err)
return http.StatusUnauthorized, ""
}

if !parsedToken.Valid {
log.Printf("Invalid WebSocket token")
return http.StatusUnauthorized, ""
}

sub, err := parsedToken.Claims.GetSubject()
if err != nil {
log.Printf("Error getting subject from WebSocket token: %v", err)
return http.StatusUnauthorized, ""
}

return http.StatusOK, sub
}

0 comments on commit 6ba4369

Please sign in to comment.