diff --git a/frontend/src/components/ui/RunPipelineButton.jsx b/frontend/src/components/ui/RunPipelineButton.jsx index 4264062c..0c2d3da0 100644 --- a/frontend/src/components/ui/RunPipelineButton.jsx +++ b/frontend/src/components/ui/RunPipelineButton.jsx @@ -40,7 +40,6 @@ export default function RunPipelineButton({ children, action }) { if (!(await validateAnvilOnline())) return; if (!validateSchema()) return; - console.log(pipelineSpecs); const newExecution = await execute(pipelineSpecs, executionId); if (!newExecution) { return; diff --git a/frontend/src/hooks/useStableWebsocket.js b/frontend/src/hooks/useStableWebsocket.js index c9020cc0..15e47f40 100644 --- a/frontend/src/hooks/useStableWebsocket.js +++ b/frontend/src/hooks/useStableWebsocket.js @@ -17,7 +17,9 @@ export const useStableWebSocket = (url) => { const [wsError, setWsError] = useState(null); const reconnectCount = useRef(0); const [configuration] = useAtom(activeConfigurationAtom); - const protocols = url?.startsWith("wss") ? [configuration.anvil.token] : null; + const protocols = url?.startsWith("wss") + ? ["Bearer", configuration.anvil.token] + : null; const { lastMessage, readyState, sendMessage } = useWebSocket(url, { shouldReconnect: (closeEvent) => { @@ -36,7 +38,6 @@ export const useStableWebSocket = (url) => { protocols: protocols, reconnectInterval: (attemptNumber) => Math.min(1000 * 2, 30000), onOpen: (event) => { - console.log("Open: ", event); console.log("WebSocket connection established."); setWsError(null); reconnectCount.current = 0; diff --git a/main.go b/main.go index 401fb1a8..745eda8e 100644 --- a/main.go +++ b/main.go @@ -355,25 +355,18 @@ func main() { return } - // Check if it's a WebSocket upgrade request + headerKey := "Authorization" if isWebSocketRequest(ctx.Request) { - token := ctx.GetHeader("Sec-WebSocket-Protocol") - code, prefix := validateSocketToken(token, certsPath) - if code != http.StatusOK { - ctx.AbortWithStatus(code) - return - } - ctx.Set("prefix", prefix) - } else { - // Existing token validation for non-WebSocket requests - code, prefix := validateToken(ctx, certsPath) - if code != http.StatusOK { - ctx.AbortWithStatus(code) - return - } - ctx.Set("prefix", prefix) + headerKey = "Sec-WebSocket-Protocol" } + token := ctx.GetHeader(headerKey) + code, prefix := validateToken(token, certsPath) + if code != http.StatusOK { + ctx.AbortWithStatus(code) + return + } + ctx.Set("prefix", prefix) ctx.Next() }) } @@ -464,6 +457,7 @@ func main() { }, ReadBufferSize: 1024, WriteBufferSize: 1024, + Subprotocols: []string{"Bearer"}, } if !websocket.IsWebSocketUpgrade(ctx.Request) { diff --git a/token.go b/token.go index 23d886bc..3fcc9412 100644 --- a/token.go +++ b/token.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" - "github.com/gin-gonic/gin" jwt "github.com/golang-jwt/jwt/v5" ) @@ -39,13 +38,12 @@ func loadCertificates(folder string) (map[string]crypto.PublicKey, error) { return certificates, nil } -func validateToken(ctx *gin.Context, folder string) (int, string) { +func validateToken(bearer string, folder string) (int, string) { certs, err := loadCertificates(folder) if err != nil { log.Printf("Could not load certificates") return http.StatusUnauthorized, "" } - bearer := ctx.Request.Header.Get("Authorization") if len(strings.Fields(bearer)) != 2 { log.Printf("Invalid authorization header") return http.StatusUnauthorized, "" @@ -76,41 +74,3 @@ func validateToken(ctx *gin.Context, folder string) (int, string) { sub, _ := token.Claims.GetSubject() return http.StatusOK, sub } - -func validateSocketToken(token string, folder string) (int, string) { - certs, err := loadCertificates(folder) - if err != nil { - log.Printf("Could not load certificates") - return http.StatusUnauthorized, "" - } - - parsedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - sub, err := t.Claims.GetSubject() - if err != nil { - return "", err - } - key, ok := certs[sub] - if !ok { - return "", errors.New("Subject certificate missing") - } - return key, nil - }, jwt.WithValidMethods([]string{"EdDSA"})) - - if err != nil { - log.Printf("Error parsing WebSocket token: %v", err) - return http.StatusUnauthorized, "" - } - - if !parsedToken.Valid { - log.Printf("Invalid WebSocket token") - return http.StatusUnauthorized, "" - } - - sub, err := parsedToken.Claims.GetSubject() - if err != nil { - log.Printf("Error getting subject from WebSocket token: %v", err) - return http.StatusUnauthorized, "" - } - - return http.StatusOK, sub -}