Skip to content

Commit

Permalink
feat: Allow custom origin header for Websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
hamann committed Mar 11, 2024
1 parent 1e431c7 commit 2d302a3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
5 changes: 2 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,11 @@ func Ping(address string, config *Config) (bool, time.Duration) {
}

// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) {
func QueryWebSocket(address, body string, origin string, config *Config) (bool, []byte, error) {
const (
Origin = "http://localhost/"
MaximumMessageSize = 1024 // in bytes
)
wsConfig, err := websocket.NewConfig(address, Origin)
wsConfig, err := websocket.NewConfig(address, origin)
if err != nil {
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
}

func TestQueryWebSocket(t *testing.T) {
_, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second})
_, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second}, "")
if err == nil {
t.Error("expected an error due to the address being invalid")
}
_, _, err = QueryWebSocket("ws://example.org", "body", &Config{Timeout: 2 * time.Second})
_, _, err = QueryWebSocket("ws://example.org", "body", "", &Config{Timeout: 2 * time.Second})
if err == nil {
t.Error("expected an error due to the target not being websocket-friendly")
}
Expand Down
13 changes: 12 additions & 1 deletion core/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ const (
// GatusUserAgent is the default user agent that Gatus uses to send requests.
GatusUserAgent = "Gatus/1.0"

Origin = "http://localhost"
OriginHeader = "Origin"

EndpointTypeDNS EndpointType = "DNS"
EndpointTypeTCP EndpointType = "TCP"
EndpointTypeSCTP EndpointType = "SCTP"
Expand Down Expand Up @@ -221,6 +224,14 @@ func (endpoint *Endpoint) ValidateAndSetDefaults() error {
if _, contentTypeHeaderExists := endpoint.Headers[ContentTypeHeader]; !contentTypeHeaderExists && endpoint.GraphQL {
endpoint.Headers[ContentTypeHeader] = "application/json"
}

// Automatically add Origin header for websocket endpoints if there isn't one specified
if endpoint.Type() == EndpointTypeWS {
if _, originHeaderExists := endpoint.Headers[OriginHeader]; !originHeaderExists {
endpoint.Headers[OriginHeader] = Origin
}
}

for _, endpointAlert := range endpoint.Alerts {
if err := endpointAlert.ValidateAndSetDefaults(); err != nil {
return err
Expand Down Expand Up @@ -376,7 +387,7 @@ func (endpoint *Endpoint) call(result *Result) {
} else if endpointType == EndpointTypeICMP {
result.Connected, result.Duration = client.Ping(strings.TrimPrefix(endpoint.URL, "icmp://"), endpoint.ClientConfig)
} else if endpointType == EndpointTypeWS {
result.Connected, result.Body, err = client.QueryWebSocket(endpoint.URL, endpoint.Body, endpoint.ClientConfig)
result.Connected, result.Body, err = client.QueryWebSocket(endpoint.URL, endpoint.Body, endpoint.Headers[OriginHeader], endpoint.ClientConfig)
if err != nil {
result.AddError(err.Error())
return
Expand Down

0 comments on commit 2d302a3

Please sign in to comment.