diff --git a/.env.sample b/.env.sample index 6b9b6d54..1f405f1b 100644 --- a/.env.sample +++ b/.env.sample @@ -3,7 +3,7 @@ SERVER_WEBSOCKET_CHECK_ORIGIN="true" SERVER_WEBSOCKET_MAX_CONN="30000" SERVER_WEBSOCKET_READ_BUFFER_SIZE="10240" SERVER_WEBSOCKET_WRITE_BUFFER_SIZE="10240" -SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER="x-user-id" +SERVER_WEBSOCKET_CONN_ID_HEADER="X-User-ID" SERVER_WEBSOCKET_PING_INTERVAL_MS=30000 SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS=60000 SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS=5000 diff --git a/.env.test b/.env.test index b29e054c..12596ea2 100644 --- a/.env.test +++ b/.env.test @@ -3,10 +3,11 @@ SERVER_WEBSOCKET_CHECK_ORIGIN="true" SERVER_WEBSOCKET_MAX_CONN="30000" SERVER_WEBSOCKET_READ_BUFFER_SIZE="10240" SERVER_WEBSOCKET_WRITE_BUFFER_SIZE="10240" -SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER="x-user-id" -SERVER_WEBSOCKET_PING_INTERVAL_MS=30000 -SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS=60000 -SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS=5000 +SERVER_WEBSOCKET_CONN_ID_HEADER="X-User-ID" +SERVER_WEBSOCKET_CONN_GROUP_HEADER="X-User-Group" +SERVER_WEBSOCKET_PING_INTERVAL_MS=10000 +SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS=10000 +SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS=1000 SERVER_WEBSOCKET_PINGER_SIZE=1 WORKER_BUFFER_CHANNEL_SIZE=5 diff --git a/README.md b/README.md index c686243a..69e31f9b 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ $ docker pull odpf/raccoon # Run the following docker command with minimal config. 
$ docker run -p 8080:8080 \ -e SERVER_WEBSOCKET_PORT=8080 \ - -e SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER=x-user-id \ + -e SERVER_WEBSOCKET_CONN_ID_HEADER=X-User-ID \ -e PUBLISHER_KAFKA_CLIENT_BOOTSTRAP_SERVERS=host.docker.internal:9093 \ -e EVENT_DISTRIBUTION_PUBLISHER_PATTERN=clickstream-%s-log \ odpf/raccoon diff --git a/config/load_test.go b/config/load_test.go index 88a7917a..53de284b 100644 --- a/config/load_test.go +++ b/config/load_test.go @@ -27,7 +27,7 @@ func TestServerConfig(t *testing.T) { os.Setenv("SERVER_WEBSOCKET_PING_INTERVAL_MS", "1") os.Setenv("SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS", "1") os.Setenv("SERVER_WEBSOCKET_SERVER_SHUTDOWN_GRACE_PERIOD_MS", "3") - os.Setenv("SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER", "x-user-id") + os.Setenv("SERVER_WEBSOCKET_CONN_ID_HEADER", "X-User-ID") serverWsConfigLoader() assert.Equal(t, "8080", ServerWs.AppPort) assert.Equal(t, time.Duration(1)*time.Millisecond, ServerWs.PingInterval) diff --git a/config/server.go b/config/server.go index c5328557..a3c79e22 100644 --- a/config/server.go +++ b/config/server.go @@ -10,16 +10,18 @@ import ( var ServerWs serverWs type serverWs struct { - AppPort string - ServerMaxConn int - ReadBufferSize int - WriteBufferSize int - CheckOrigin bool - PingInterval time.Duration - PongWaitInterval time.Duration - WriteWaitInterval time.Duration - PingerSize int - UniqConnIDHeader string + AppPort string + ServerMaxConn int + ReadBufferSize int + WriteBufferSize int + CheckOrigin bool + PingInterval time.Duration + PongWaitInterval time.Duration + WriteWaitInterval time.Duration + PingerSize int + ConnIDHeader string + ConnGroupHeader string + ConnGroupDefault string } func serverWsConfigLoader() { @@ -32,17 +34,21 @@ func serverWsConfigLoader() { viper.SetDefault("SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS", "60000") //should be more than the ping period viper.SetDefault("SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS", "5000") viper.SetDefault("SERVER_WEBSOCKET_PINGER_SIZE", 1) + 
viper.SetDefault("SERVER_WEBSOCKET_CONN_GROUP_HEADER", "") + viper.SetDefault("SERVER_WEBSOCKET_CONN_GROUP_DEFAULT", "--default--") ServerWs = serverWs{ - AppPort: util.MustGetString("SERVER_WEBSOCKET_PORT"), - ServerMaxConn: util.MustGetInt("SERVER_WEBSOCKET_MAX_CONN"), - ReadBufferSize: util.MustGetInt("SERVER_WEBSOCKET_READ_BUFFER_SIZE"), - WriteBufferSize: util.MustGetInt("SERVER_WEBSOCKET_WRITE_BUFFER_SIZE"), - CheckOrigin: util.MustGetBool("SERVER_WEBSOCKET_CHECK_ORIGIN"), - PingInterval: util.MustGetDuration("SERVER_WEBSOCKET_PING_INTERVAL_MS", time.Millisecond), - PongWaitInterval: util.MustGetDuration("SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS", time.Millisecond), - WriteWaitInterval: util.MustGetDuration("SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS", time.Microsecond), - PingerSize: util.MustGetInt("SERVER_WEBSOCKET_PINGER_SIZE"), - UniqConnIDHeader: util.MustGetString("SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER"), + AppPort: util.MustGetString("SERVER_WEBSOCKET_PORT"), + ServerMaxConn: util.MustGetInt("SERVER_WEBSOCKET_MAX_CONN"), + ReadBufferSize: util.MustGetInt("SERVER_WEBSOCKET_READ_BUFFER_SIZE"), + WriteBufferSize: util.MustGetInt("SERVER_WEBSOCKET_WRITE_BUFFER_SIZE"), + CheckOrigin: util.MustGetBool("SERVER_WEBSOCKET_CHECK_ORIGIN"), + PingInterval: util.MustGetDuration("SERVER_WEBSOCKET_PING_INTERVAL_MS", time.Millisecond), + PongWaitInterval: util.MustGetDuration("SERVER_WEBSOCKET_PONG_WAIT_INTERVAL_MS", time.Millisecond), + WriteWaitInterval: util.MustGetDuration("SERVER_WEBSOCKET_WRITE_WAIT_INTERVAL_MS", time.Millisecond), + PingerSize: util.MustGetInt("SERVER_WEBSOCKET_PINGER_SIZE"), + ConnIDHeader: util.MustGetString("SERVER_WEBSOCKET_CONN_ID_HEADER"), + ConnGroupHeader: util.MustGetString("SERVER_WEBSOCKET_CONN_GROUP_HEADER"), + ConnGroupDefault: util.MustGetString("SERVER_WEBSOCKET_CONN_GROUP_DEFAULT"), } } diff --git a/docs/concepts/architecture.md b/docs/concepts/architecture.md index 18b78dcf..e7076704 100644 --- a/docs/concepts/architecture.md +++ 
b/docs/concepts/architecture.md @@ -22,19 +22,19 @@ Note: The internals of each of the components like channel size, buffer sizes, p ### Connections -Raccoon has long running persistent connections with the client. Once a client makes a http request with a websocket upgrade header, raccoon upgrades the http request to a websocket connection end of which a persistent connection is established with the client. +Raccoon has long-running persistent connections with the client. Once a client makes an HTTP request with a WebSocket upgrade header, Raccoon upgrades the HTTP request to a WebSocket connection, at the end of which a persistent connection is established with the client. -The following sequence outlines the connection handling by Raccoon. +The following sequence outlines the connection handling by Raccoon: -* Fetch connection id details from the initial request header based on the configured header name in `SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER`. The header name uniquely identifies a client. A client in this case can be the user in the app. There can be multiple connections from the same client. The no., of connections allowed per client is determined by `SERVER_WEBSOCKET_MAX_CONN`. -* Once the connection id is fetched, verify if the user has connection limit reached based on the configured `SERVER_WEBSOCKET_MAX_CONN`. For each client an internal map stores the `SERVER_WEBSOCKET_MAX_CONN` along with the connection objects. On reaching the max connections for the client, the connection is disconnected with an appropriate error message as a response proto. -* Upgrade the connection -* Add this user-id -> connection mapping -* Add ping/pong handlers on this connection, readtimeout deadline. More about these handlers in the following sections -* Handle the message and send it to the events-channel -* Remove connection/user when the client closes the connection +* Construct the connection identifier from the request header.
The identifier is constructed from the value of the `SERVER_WEBSOCKET_CONN_ID_HEADER` header. For example, if Raccoon is configured with `SERVER_WEBSOCKET_CONN_ID_HEADER=X-User-ID`, Raccoon reads the value of the X-User-ID header and uses it as the identifier. Raccoon then uses this identifier to check whether there is already an existing connection with the same identifier. If such a connection already exists, Raccoon will disconnect the new connection with an appropriate error message as a response proto. + * Optionally, you can also configure `SERVER_WEBSOCKET_CONN_GROUP_HEADER` to support multi-tenancy. For example, suppose you want to use one instance of Raccoon with multiple mobile clients. You can configure Raccoon with `SERVER_WEBSOCKET_CONN_GROUP_HEADER=X-Mobile-Client`. Raccoon will then use the value of X-Mobile-Client along with X-User-ID as the identifier, so uniqueness is determined by the combination of the X-User-ID value and the X-Mobile-Client value. This way, Raccoon can maintain the same X-User-ID across different X-Mobile-Client values. +* Verify whether the total number of connections has reached the limit configured by `SERVER_WEBSOCKET_MAX_CONN`. On reaching the max connections, Raccoon disconnects the connection with an appropriate error message as a response proto. +* Upgrade the connection and persist the identifier. +* Add ping/pong handlers and a read timeout deadline on this connection. More about these handlers in the following sections. +* At this point, the connection is completely upgraded and Raccoon is ready to accept EventRequest. The handler handles each EventRequest by sending it to the events-channel to be asynchronously published by the publisher. +* When the connection is closed, Raccoon cleans up the connection along with the identifier. The same identifier can then be reused by a subsequent connection.
-### Event Delivery gurantee \(at-least-once for most time\) +### Event Delivery Guarantee \(at-least-once for most time\) The server for the most times provide at-least-once event delivery gurantee. diff --git a/docs/example/main.go b/docs/example/main.go index 75476ef6..6bbe4cdf 100644 --- a/docs/example/main.go +++ b/docs/example/main.go @@ -14,7 +14,7 @@ import ( var ( url = "ws://localhost:8080/api/v1/events" header = http.Header{ - "x-user-id": []string{"1234"}, + "X-User-ID": []string{"1234"}, } pingInterval = 5 * time.Second ) diff --git a/docs/example/readme.md b/docs/example/readme.md index 66f85b25..b9c5cca9 100644 --- a/docs/example/readme.md +++ b/docs/example/readme.md @@ -14,7 +14,7 @@ You are free to use any websocket client as long as it supports passing header. var ( url = "ws://localhost:8080/api/v1/events" header = http.Header{ - "x-user-id": []string{"1234"}, + "X-User-ID": []string{"1234"}, } ) diff --git a/docs/guides/deployment.md b/docs/guides/deployment.md index 12c84a2d..46fa0dea 100644 --- a/docs/guides/deployment.md +++ b/docs/guides/deployment.md @@ -28,7 +28,7 @@ metadata: data: METRIC_STATSD_ADDRESS: "host.docker.internal:8125" PUBLISHER_KAFKA_CLIENT_BOOTSTRAP_SERVERS: "host.docker.internal:9093" - SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER: "x-user-id" + SERVER_WEBSOCKET_CONN_ID_HEADER: "X-User-ID" SERVER_WEBSOCKET_PORT: "8080" ``` @@ -193,7 +193,7 @@ Followings are main configurations closely related to deployment that you need t * [`EVENT_DISTRIBUTION_PUBLISHER_PATTERN`](https://odpf.gitbook.io/raccoon/reference/configurations#event_distribution_publisher_pattern) * [`PUBLISHER_KAFKA_CLIENT_BOOTSTRAP_SERVERS`](https://odpf.gitbook.io/raccoon/reference/configurations#publisher_kafka_client_bootstrap_servers) * [`METRIC_STATSD_ADDRESS`](https://odpf.gitbook.io/raccoon/reference/configurations#metric_statsd_address) -*
[`SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER`](https://odpf.gitbook.io/raccoon/reference/configurations#server_websocket_conn_uniq_id_header) +* [`SERVER_WEBSOCKET_CONN_ID_HEADER`](https://odpf.gitbook.io/raccoon/reference/configurations#server_websocket_conn_id_header) **TLS/HTTPS** diff --git a/docs/guides/publishing.md b/docs/guides/publishing.md index 4279d9bc..8fcfc648 100644 --- a/docs/guides/publishing.md +++ b/docs/guides/publishing.md @@ -93,7 +93,7 @@ The above response model is self-explanatory. Clients can choose to retry for er ## Headers -Raccoon service accepts headers to identify a user connection uniquely. The header name is made configurable as it enables clients to specify a header name that works for them. For, e.g. for a mobile app having a request header as `X-User-id` which identifies the user \(client\) connecting to Raccoon, can configure Raccoon service with the config set as below `SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER=X-User-id` +Raccoon service accepts headers to identify a user connection uniquely. The header name is made configurable as it enables clients to specify a header name that works for them. For, e.g. for a mobile app having a request header as `X-User-ID` which identifies the user \(client\) connecting to Raccoon, can configure Raccoon service with the config set as below `SERVER_WEBSOCKET_CONN_ID_HEADER=X-User-ID`. Optionally, `SERVER_WEBSOCKET_CONN_GROUP_HEADER` can also be configured to [support multitenancy](https://odpf.gitbook.io/raccoon/concepts/architecture#connections) such as multiple apps connecting to a single Raccoon instance. Raccoon uses the config to fetch the header name and uses the value passed in the request header with this name, as the connection id. This header name uniquely identifies a client. A client, in this case, can be the user in the app. @@ -101,7 +101,7 @@ The following header is a sample providing a user id: 654785432. 
Once the client ```text { - "X-User-id": "654785432" + "X-User-ID": "654785432" } ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 865304ea..a1b1c902 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -8,7 +8,7 @@ Run the following command. Make sure to set `PUBLISHER_KAFKA_CLIENT_BOOTSTRAP_SE ```bash $ docker run -p 8080:8080 \ - -e SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER=x-user-id \ + -e SERVER_WEBSOCKET_CONN_ID_HEADER=X-User-ID \ -e PUBLISHER_KAFKA_CLIENT_BOOTSTRAP_SERVERS=host.docker.internal:9092 \ -e EVENT_DISTRIBUTION_PUBLISHER_PATTERN=clickstream-log \ odpf/raccoon:latest diff --git a/docs/reference/configurations.md b/docs/reference/configurations.md index 39f39eb8..fb4d6524 100644 --- a/docs/reference/configurations.md +++ b/docs/reference/configurations.md @@ -41,13 +41,27 @@ Specify I/O buffer sizes in bytes: [Refer gorilla websocket API](https://pkg.go. * Type: `Optional` * Default value: `10240` -### `SERVER_WEBSOCKET_CONN_UNIQ_ID_HEADER` +### `SERVER_WEBSOCKET_CONN_ID_HEADER` Unique identifier for the server to maintain the connection. A single uniq id can only connect once in a session. If, there is a subsequence connection with the same uniq id the connection will be rejected. -* Example value: `x-user-id` +* Example value: `X-User-ID` * Type: `Required` +### `SERVER_WEBSOCKET_CONN_GROUP_HEADER` + +Additional identifier for the server to maintain the connection. The value of the conn group header combined with the user id acts as the unique identifier instead of the user id alone. You can use this if you want to differentiate between user groups or clients, e.g. mobile and web. The group name is used as the `conn_group` tag in some of the metrics. + +* Example value: `X-User-Group` +* Type: `Optional` + +### `SERVER_WEBSOCKET_CONN_GROUP_DEFAULT` + +Default connection group name. The default is used as a fallback when `SERVER_WEBSOCKET_CONN_GROUP_HEADER` is not set or when the value of the group header is empty.
In case the connection group default is clashing with your actual group name, override this config. + +* Default value: `--default--` +* Type: `Optional` + ### `SERVER_WEBSOCKET_PING_INTERVAL_MS` Interval of each ping to client. The interval is in seconds. diff --git a/docs/reference/metrics.md b/docs/reference/metrics.md index 27e3118d..50f2e56b 100644 --- a/docs/reference/metrics.md +++ b/docs/reference/metrics.md @@ -16,37 +16,42 @@ Raccoon uses Statsd protocol as way to report metrics. You can capture the metri Total ping that server fails to send * Type: `Counting` +* Tags: `conn_group=*` ### `server_pong_failure_total` Total pong that server fails to send * Type: `Counting` +* Tags: `conn_group=*` ### `connections_count_current` Number of alive connections * Type: `Gauge` +* Tags: `conn_group=*` ### `user_connection_success_total` Number of successful connections established to the server * Type: `Count` +* Tags: `conn_group=*` ### `user_connection_failure_total` Number of fail connections established to the server * Type: `Count` -* Tags: `reason=ugfailure` `reason=exists` `reason=serverlimit` +* Tags: `reason=ugfailure` `reason=exists` `reason=serverlimit` `conn_group=*` ### `user_session_duration_milliseconds` Duration of alive connection per session per connection * Type: `Timing` +* Tags: `conn_group=*` ## Kafka Publisher @@ -55,7 +60,7 @@ Duration of alive connection per session per connection Number of delivered events to Kafka * Type: `Count` -* Tags: `success=false` `success=true` +* Tags: `success=false` `success=true` `conn_group=*` ### `kafka_unknown_topic_failure_total` @@ -164,6 +169,7 @@ Following metrics are event delivery reports. 
Each metrics reported at a differe Total byte receieved in requests * Type: `Count` +* Tags: `conn_group=*` ### `events_rx_total` @@ -176,7 +182,7 @@ Number of events received in requests Request count * Type: `Count` -* Tags: `status=failed` `status=success` `reason=*` +* Tags: `status=failed` `status=success` `reason=*` `conn_group=*` ### `batch_idle_in_channel_milliseconds` @@ -190,12 +196,14 @@ Duration from when the request is received to when the request is processed. Hig Duration from the time request is sent to the time events are published. This metric is calculated per event by following formula `(PublishedTime - SentTime)/CountEvents` * Type: `Timing` +* Tags: `conn_group=*` ### `server_processing_latency_milliseconds` Duration from the time request is receieved to the time events are published. This metric is calculated per event by following formula`(PublishedTime - ReceievedTime)/CountEvents` * Type: `Timing` +* Tags: `conn_group=*` ### `worker_processing_duration_milliseconds` diff --git a/go.mod b/go.mod index 0cfcace6..b3b34608 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,14 @@ go 1.14 require ( github.com/confluentinc/confluent-kafka-go v1.4.2 // indirect - github.com/golang/protobuf v1.4.1 + github.com/golang/protobuf v1.5.0 github.com/gorilla/mux v1.7.4 github.com/gorilla/websocket v1.4.2 github.com/sirupsen/logrus v1.6.0 github.com/spf13/viper v1.7.0 github.com/stretchr/testify v1.6.0 golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb // indirect - google.golang.org/protobuf v1.22.0 + google.golang.org/protobuf v1.26.0 gopkg.in/alexcesaro/statsd.v2 v2.0.0 gopkg.in/confluentinc/confluent-kafka-go.v1 v1.4.2 ) diff --git a/go.sum b/go.sum index ef1fbb88..8c7d7632 100644 --- a/go.sum +++ b/go.sum @@ -57,20 +57,14 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod 
h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1 h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod 
h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -307,13 +301,9 @@ google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvx google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0 h1:cJv5/xdbk1NnMPR1VP9+HU6gupuG9MLBoH1r6RHZ2MY= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alexcesaro/statsd.v2 v2.0.0 h1:FXkZSCZIH17vLCO5sO2UucTHsH9pc+17F6pl3JVCwMc= gopkg.in/alexcesaro/statsd.v2 v2.0.0/go.mod h1:i0ubccKGzBVNBpdGV5MocxyA/XlLUJzA7SLonnE4drU= diff --git a/integration/integration_test.go b/integration/integration_test.go index 908ffb60..29389754 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -8,12 
+8,13 @@ import ( "testing" "time" + pb "raccoon/websocket/proto" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "gopkg.in/confluentinc/confluent-kafka-go.v1/kafka" - pb "raccoon/websocket/proto" ) var uuid string @@ -24,7 +25,7 @@ var bootstrapServers string func TestMain(m *testing.M) { uuid = fmt.Sprintf("%d-test", rand.Int()) - timeout = 120 * time.Second + timeout = 20 * time.Second topicFormat = os.Getenv("INTEGTEST_TOPIC_FORMAT") url = fmt.Sprintf("%v/api/v1/events", os.Getenv("INTEGTEST_HOST")) bootstrapServers = os.Getenv("INTEGTEST_BOOTSTRAP_SERVER") @@ -35,7 +36,7 @@ func TestIntegration(t *testing.T) { var err error assert.NoError(t, err) header := http.Header{ - "x-user-id": []string{"1234"}, + "X-User-ID": []string{"1234"}, } t.Run("Should response with BadRequest when sending invalid request", func(t *testing.T) { wss, _, err := websocket.DefaultDialer.Dial(url, header) @@ -126,7 +127,7 @@ func TestIntegration(t *testing.T) { t.Log("error", err) continue } - if (string(msg.Value) == "event_1") { + if string(msg.Value) == "event_1" { return } } @@ -160,7 +161,7 @@ func TestIntegration(t *testing.T) { t.Log("error", err) continue } - if (string(msg.Value) == "event_2") { + if string(msg.Value) == "event_2" { return } } @@ -230,4 +231,42 @@ func TestIntegration(t *testing.T) { } }) + t.Run("Should accept connections with same user id with different connection group", func(t *testing.T) { + done := make(chan int) + _, _, err := websocket.DefaultDialer.Dial(url, http.Header{ + "X-User-ID": []string{"1234"}, + "X-User-Group": []string{"viewer"}, + }) + + assert.NoError(t, err) + + secondWss, _, err := websocket.DefaultDialer.Dial(url, http.Header{ + "X-User-ID": []string{"1234"}, + "X-User-Group": []string{"editor"}, + }) + + assert.NoError(t, err) + + go func() { + for { + _, _, err := secondWss.ReadMessage() + assert.NoError(t, err) + if err != nil { + 
close(done) + break + } + } + }() + select { + case <-time.After(timeout): + assert.Fail(t, "Timeout. Expecting second connection to close") + break + case <-time.After(3 * time.Second): + // Second connection is established and there is no error + break + case <-done: + break + } + }) + } diff --git a/logger/logger.go b/logger/logger.go index 8633c9bb..4cf61273 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -31,6 +31,10 @@ func Debug(args ...interface{}) { logger.Debug(args...) } +func Debugf(format string, args ...interface{}) { + logger.Debugf(format, args...) +} + func Info(args ...interface{}) { logger.Info(args...) } diff --git a/websocket/connection/conn.go b/websocket/connection/conn.go new file mode 100644 index 00000000..a3b0ac3f --- /dev/null +++ b/websocket/connection/conn.go @@ -0,0 +1,41 @@ +package connection + +import ( + "fmt" + "raccoon/logger" + "raccoon/metrics" + "time" + + "github.com/gorilla/websocket" +) + +type Conn struct { + Identifier Identifier + conn *websocket.Conn + connectedAt time.Time + closeHook func(c Conn) +} + +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + return c.conn.ReadMessage() +} + +func (c *Conn) WriteMessage(messageType int, data []byte) error { + return c.conn.WriteMessage(messageType, data) +} + +func (c *Conn) Ping(writeWaitInterval time.Duration) error { + return c.conn.WriteControl(websocket.PingMessage, []byte("--ping--"), time.Now().Add(writeWaitInterval)) +} + +func (c *Conn) Close() { + c.conn.Close() + c.calculateSessionTime() + c.closeHook(*c) +} + +func (c *Conn) calculateSessionTime() { + connectionTime := time.Now().Sub(c.connectedAt) + logger.Debugf("[websocket.calculateSessionTime] %s, total time connected in minutes: %v", c.Identifier, connectionTime.Minutes()) + metrics.Timing("user_session_duration_milliseconds", connectionTime.Milliseconds(), fmt.Sprintf("conn_group=%s", c.Identifier.Group)) +} diff --git a/websocket/connection/identifier.go 
b/websocket/connection/identifier.go new file mode 100644 index 00000000..cff9a64f --- /dev/null +++ b/websocket/connection/identifier.go @@ -0,0 +1,14 @@ +package connection + +import ( + "fmt" +) + +type Identifier struct { + ID string + Group string +} + +func (i Identifier) String() string { + return fmt.Sprintf("connection [%s] %s", i.Group, i.ID) +} diff --git a/websocket/connection/table.go b/websocket/connection/table.go new file mode 100644 index 00000000..60644e89 --- /dev/null +++ b/websocket/connection/table.go @@ -0,0 +1,64 @@ +package connection + +import ( + "errors" + "sync" +) + +type Table struct { + m *sync.RWMutex + connMap map[Identifier]struct{} + counter map[string]int + maxUser int +} + +func NewTable(maxUser int) *Table { + return &Table{ + m: &sync.RWMutex{}, + connMap: make(map[Identifier]struct{}), + maxUser: maxUser, + counter: make(map[string]int), + } +} + +func (t *Table) Exists(c Identifier) bool { + t.m.Lock() + defer t.m.Unlock() + _, ok := t.connMap[c] + return ok +} + +func (t *Table) Store(c Identifier) error { + t.m.Lock() + defer t.m.Unlock() + if len(t.connMap) >= t.maxUser { + return errMaxConnectionReached + } + if _, ok := t.connMap[c]; ok == true { + return errConnDuplicated + } + t.connMap[c] = struct{}{} + t.counter[c.Group] = t.counter[c.Group] + 1 + return nil +} + +func (t *Table) Remove(c Identifier) { + t.m.Lock() + defer t.m.Unlock() + delete(t.connMap, c) + t.counter[c.Group] = t.counter[c.Group] - 1 +} + +func (t *Table) TotalConnection() int { + t.m.Lock() + defer t.m.Unlock() + return len(t.connMap) +} + +func (t *Table) TotalConnectionPerGroup() map[string]int { + return t.counter +} + +var errMaxConnectionReached = errors.New("max connection reached") + +var errConnDuplicated = errors.New("duplicated connection") diff --git a/websocket/connection/table_test.go b/websocket/connection/table_test.go new file mode 100644 index 00000000..2964894b --- /dev/null +++ b/websocket/connection/table_test.go @@ -0,0 
+1,49 @@ +package connection + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectionPerGroup(t *testing.T) { + t.Run("Should return all the group on the table with the count", func(t *testing.T) { + table := NewTable(10) + table.Store(Identifier{ID: "user1", Group: "group1"}) + table.Store(Identifier{ID: "user2", Group: "group1"}) + table.Store(Identifier{ID: "user3", Group: "group1"}) + table.Store(Identifier{ID: "user1", Group: "group2"}) + table.Store(Identifier{ID: "user2", Group: "group2"}) + assert.Equal(t, map[string]int{"group1": 3, "group2": 2}, table.TotalConnectionPerGroup()) + }) +} + +func TestStore(t *testing.T) { + t.Run("Should store new connection", func(t *testing.T) { + table := NewTable(10) + err := table.Store(Identifier{ID: "user1", Group: ""}) + assert.NoError(t, err) + assert.True(t, table.Exists(Identifier{ID: "user1"})) + }) + + t.Run("Should return max connection reached error when connection is maxed", func(t *testing.T) { + table := NewTable(0) + err := table.Store(Identifier{ID: "user1", Group: ""}) + assert.Error(t, err, errMaxConnectionReached) + }) + + t.Run("Should return duplicated error when connection already exists", func(t *testing.T) { + table := NewTable(2) + err := table.Store(Identifier{ID: "user1", Group: ""}) + assert.NoError(t, err) + err = table.Store(Identifier{ID: "user1", Group: ""}) + assert.Error(t, err, errConnDuplicated) + }) + + t.Run("Should remove connection when identifier match", func(t *testing.T) { + table := NewTable(10) + table.Store(Identifier{ID: "user1", Group: ""}) + table.Remove(Identifier{ID: "user1", Group: ""}) + assert.False(t, table.Exists(Identifier{ID: "user1", Group: ""})) + }) +} diff --git a/websocket/connection/upgrader.go b/websocket/connection/upgrader.go new file mode 100644 index 00000000..e2db4f26 --- /dev/null +++ b/websocket/connection/upgrader.go @@ -0,0 +1,142 @@ +package connection + +import ( + "errors" + "fmt" + "net/http" + "raccoon/logger" 
+ "raccoon/metrics" + pb "raccoon/websocket/proto" + "time" + + "github.com/golang/protobuf/proto" + "github.com/gorilla/websocket" +) + +type Upgrader struct { + gorillaUg websocket.Upgrader + Table *Table + pongWaitInterval time.Duration + writeWaitInterval time.Duration + connIDHeader string + connGroupHeader string + connGroupDefault string +} + +type UpgraderConfig struct { + ReadBufferSize int + WriteBufferSize int + CheckOrigin bool + MaxUser int + PongWaitInterval time.Duration + WriteWaitInterval time.Duration + ConnIDHeader string + ConnGroupHeader string + ConnGroupDefault string +} + +func NewUpgrader(conf UpgraderConfig) *Upgrader { + var checkOriginFunc func(r *http.Request) bool + if conf.CheckOrigin == false { + checkOriginFunc = func(r *http.Request) bool { + return true + } + } + return &Upgrader{ + gorillaUg: websocket.Upgrader{ + ReadBufferSize: conf.ReadBufferSize, + WriteBufferSize: conf.WriteBufferSize, + CheckOrigin: checkOriginFunc, + }, + Table: NewTable(conf.MaxUser), + pongWaitInterval: conf.PongWaitInterval, + writeWaitInterval: conf.WriteWaitInterval, + connIDHeader: conf.ConnIDHeader, + connGroupHeader: conf.ConnGroupHeader, + connGroupDefault: conf.ConnGroupDefault, + } +} + +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request) (Conn, error) { + identifier := u.newIdentifier(r.Header) + logger.Debug(fmt.Sprintf("%s connected at %v", identifier, time.Now())) + + conn, err := u.gorillaUg.Upgrade(w, r, nil) + if err != nil { + metrics.Increment("user_connection_failure_total", fmt.Sprintf("reason=ugfailure,conn_group=%s", identifier.Group)) + return Conn{}, fmt.Errorf("failed to upgrade %s: %v", identifier, err) + } + err = u.Table.Store(identifier) + if errors.Is(err, errConnDuplicated) { + duplicateConnResp := createEmptyErrorResponse(pb.Code_MAX_USER_LIMIT_REACHED) + + conn.WriteMessage(websocket.BinaryMessage, duplicateConnResp) + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1008, "Duplicate 
connection")) + metrics.Increment("user_connection_failure_total", fmt.Sprintf("reason=exists,conn_group=%s", identifier.Group)) + conn.Close() + return Conn{}, fmt.Errorf("disconnecting %s: already connected", identifier) + } + if errors.Is(err, errMaxConnectionReached) { + logger.Errorf("[websocket.Handler] Disconnecting %v, max connection reached", identifier) + maxConnResp := createEmptyErrorResponse(pb.Code_MAX_CONNECTION_LIMIT_REACHED) + conn.WriteMessage(websocket.BinaryMessage, maxConnResp) + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1008, "Max connection reached")) + metrics.Increment("user_connection_failure_total", fmt.Sprintf("reason=serverlimit,conn_group=%s", identifier.Group)) + conn.Close() + return Conn{}, fmt.Errorf("max connection reached") + } + + u.setUpControlHandlers(conn, identifier) + metrics.Increment("user_connection_success_total", fmt.Sprintf("conn_group=%s", identifier.Group)) + + return Conn{ + Identifier: identifier, + conn: conn, + connectedAt: time.Now(), + closeHook: func(c Conn) { + u.Table.Remove(c.Identifier) + }}, nil +} + +func (u *Upgrader) setUpControlHandlers(conn *websocket.Conn, identifier Identifier) { + //expects the client to send a ping, mark this channel as idle timed out post the deadline + conn.SetReadDeadline(time.Now().Add(u.pongWaitInterval)) + conn.SetPongHandler(func(string) error { + // extends the read deadline since we have received this pong on this channel + conn.SetReadDeadline(time.Now().Add(u.pongWaitInterval)) + return nil + }) + + conn.SetPingHandler(func(s string) error { + logger.Debug(fmt.Sprintf("Client %s pinged", identifier)) + if err := conn.WriteControl(websocket.PongMessage, []byte(s), time.Now().Add(u.writeWaitInterval)); err != nil { + metrics.Increment("server_pong_failure_total", fmt.Sprintf("conn_group=%s", identifier.Group)) + logger.Debug(fmt.Sprintf("Failed to send pong event %s: %v", identifier, err)) + } + return nil + }) +} + +func (u *Upgrader) 
newIdentifier(h http.Header) Identifier { + // If connGroupHeader is empty string. By default, it will always return an empty string as Group. This means the group is fallback to default value. + var group = h.Get(u.connGroupHeader) + if group == "" { + group = u.connGroupDefault + } + return Identifier{ + ID: h.Get(u.connIDHeader), + Group: group, + } +} + +func createEmptyErrorResponse(errCode pb.Code) []byte { + resp := pb.EventResponse{ + Status: pb.Status_ERROR, + Code: errCode, + SentTime: time.Now().Unix(), + Reason: "", + Data: nil, + } + duplicateConnResp, _ := proto.Marshal(&resp) + return duplicateConnResp +} diff --git a/websocket/connection/upgrader_test.go b/websocket/connection/upgrader_test.go new file mode 100644 index 00000000..ce58b089 --- /dev/null +++ b/websocket/connection/upgrader_test.go @@ -0,0 +1,230 @@ +package connection + +import ( + "net/http" + "net/http/httptest" + "os" + "raccoon/logger" + "raccoon/metrics" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +type void struct{} + +func (v void) Write(_ []byte) (int, error) { + return 0, nil +} + +func TestMain(t *testing.M) { + logger.SetOutput(void{}) + metrics.SetVoid() + os.Exit(t.Run()) +} + +var config = UpgraderConfig{ + ReadBufferSize: 10240, + WriteBufferSize: 10240, + CheckOrigin: false, + MaxUser: 2, + PongWaitInterval: time.Duration(60 * time.Second), + WriteWaitInterval: time.Duration(5 * time.Second), + ConnIDHeader: "X-User-ID", + ConnGroupHeader: "", + ConnGroupDefault: "--default--", +} + +func TestConnectionLifecycle(t *testing.T) { + t.Run("Should increment total connection when upgraded", func(t *testing.T) { + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.NoError(t, u.err) + assert.Equal(t, 1, 
upgrader.Table.TotalConnection()) + }, + onIteration: 1, + }) + }) + + t.Run("Should decrement total connection when client close the conn", func(t *testing.T) { + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + }, { + "X-User-ID": []string{"user1"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + if u.iteration == 1 { + assert.Equal(t, 1, upgrader.Table.TotalConnection()) + } + u.conn.Close() + }, + }) + }) +} + +func TestConnectionGroup(t *testing.T) { + t.Run("Should accept connections with same userid and different group", func(t *testing.T) { + config.ConnGroupHeader = "X-User-Group" + defer func() { config.ConnGroupHeader = "" }() + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"viewer"}, + }, { + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"editor"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.Equal(t, 2, upgrader.Table.TotalConnection()) + assert.NoError(t, u.err) + }, + onIteration: 2, + }) + }) + + t.Run("Should use default when ConnGroupHeader is not provided", func(t *testing.T) { + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + }, { + "X-User-ID": []string{"user1"}, + }, { + "X-User-ID": []string{"user1"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.EqualError(t, u.err, "disconnecting connection [--default--] user1: already connected") + }, + onIteration: 3, + }) + }) + + t.Run("Should reject connections with same userid and same group", func(t *testing.T) { + config.ConnGroupHeader = "X-User-Group" + defer func() { config.ConnGroupHeader = "" }() + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + "X-User-Group": 
[]string{"viewer"}, + }, { + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"viewer"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.Equal(t, 1, upgrader.Table.TotalConnection()) + assert.EqualError(t, u.err, "disconnecting connection [viewer] user1: already connected") + }, + onIteration: 2, + }) + }) + + t.Run("Should be able to reconnect when connection is closed", func(t *testing.T) { + config.ConnGroupHeader = "X-User-Group" + defer func() { config.ConnGroupHeader = "" }() + upgrader := NewUpgrader(config) + headers := []http.Header{{ + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"viewer"}, + }, { + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"viewer"}, + }, { + "X-User-ID": []string{"user1"}, + "X-User-Group": []string{"viewer"}, + }} + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.Equal(t, 1, upgrader.Table.TotalConnection()) + assert.NoError(t, u.err) + u.conn.Close() + }, + }) + }) +} + +func TestConnectionRejection(t *testing.T) { + t.Run("Should close new connection when max is reached", func(t *testing.T) { + upgrader := NewUpgrader(config) + headers := make([]http.Header, 0) + for _, i := range []string{"1", "2", "3"} { + headers = append(headers, http.Header{ + "X-User-ID": []string{"user-" + i}, + }) + } + + upgradeConnectionTestHelper(t, upgrader, headers, assertUpgrade{ + callback: func(u upgradeRes) { + assert.EqualError(t, u.err, "max connection reached") + }, + onIteration: 3, + }) + }) +} + +// Prepare a websocket server with given upgrader and establish the connections with the given headers as many as given headers. 
+func upgradeConnectionTestHelper(t *testing.T, upgrader *Upgrader, headers []http.Header, f assertUpgrade) { + res := make(chan upgradeRes) + m := sync.Mutex{} + iteration := 0 + r := mux.NewRouter() + r.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + m.Lock() + iteration++ + i := iteration + m.Unlock() + c, err := upgrader.Upgrade(rw, r) + if f.onIteration == 0 { + res <- upgradeRes{ + err: err, + iteration: i, + conn: c} + } + if i == f.onIteration { + res <- upgradeRes{ + err: err, + iteration: i, + conn: c} + } + }) + server := httptest.NewServer(r) + defer server.Close() + connect := func(h http.Header) { + websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(server.URL, "http"), h) + } + for _, header := range headers { + connect(header) + } + timeout := 5 * time.Second + select { + case <-time.After(timeout): + t.Fatal("timeout, no error return from upgrader") + case e := <-res: + f.callback(e) + } +} + +// Struct to prepare upgrade assertion. +// If onIteration is provided, the assertion only run on the specified iteration of the passed headers. If onIteration is not provided or 0, assertion is run every upgrade. 
+type assertUpgrade struct { + callback func(u upgradeRes) + onIteration int +} + +type upgradeRes struct { + err error + conn Conn + iteration int +} diff --git a/websocket/handler.go b/websocket/handler.go index 01759986..75f3550b 100644 --- a/websocket/handler.go +++ b/websocket/handler.go @@ -5,28 +5,25 @@ import ( "net/http" "raccoon/logger" "raccoon/metrics" + "raccoon/websocket/connection" "time" + pb "raccoon/websocket/proto" + "github.com/golang/protobuf/proto" "github.com/gorilla/websocket" - pb "raccoon/websocket/proto" ) type Handler struct { - websocketUpgrader websocket.Upgrader - bufferChannel chan EventsBatch - user *User - PongWaitInterval time.Duration - WriteWaitInterval time.Duration - PingChannel chan connection - UniqConnIDHeader string + upgrader *connection.Upgrader + bufferChannel chan EventsBatch + PingChannel chan connection.Conn } - type EventsBatch struct { - UniqConnID string - EventReq *pb.EventRequest - TimeConsumed time.Time - TimePushed time.Time + ConnIdentifier connection.Identifier + EventReq *pb.EventRequest + TimeConsumed time.Time + TimePushed time.Time } func PingHandler(w http.ResponseWriter, r *http.Request) { @@ -35,45 +32,14 @@ func PingHandler(w http.ResponseWriter, r *http.Request) { } //HandlerWSEvents handles the upgrade and the events sent by the peers -func (wsHandler *Handler) HandlerWSEvents(w http.ResponseWriter, r *http.Request) { - uniqConnID := r.Header.Get(wsHandler.UniqConnIDHeader) - connectedTime := time.Now() - logger.Debug(fmt.Sprintf("UniqConnID %s connected at %v", uniqConnID, connectedTime)) - conn, err := wsHandler.websocketUpgrader.Upgrade(w, r, nil) +func (h *Handler) HandlerWSEvents(w http.ResponseWriter, r *http.Request) { + conn, err := h.upgrader.Upgrade(w, r) if err != nil { - logger.Error(fmt.Sprintf("[websocket.Handler] Failed to upgrade connection UniqConnID: %s : %v", uniqConnID, err)) - metrics.Increment("user_connection_failure_total", "reason=ugfailure") + 
logger.Debugf("[websocket.Handler] %v", err) return } defer conn.Close() - - if wsHandler.user.Exists(uniqConnID) { - logger.Errorf("[websocket.Handler] Disconnecting %v, already connected", uniqConnID) - duplicateConnResp := createEmptyErrorResponse(pb.Code_MAX_USER_LIMIT_REACHED) - - conn.WriteMessage(websocket.BinaryMessage, duplicateConnResp) - conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1008, "Duplicate connection")) - metrics.Increment("user_connection_failure_total", "reason=exists") - return - } - if wsHandler.user.HasReachedLimit() { - logger.Errorf("[websocket.Handler] Disconnecting %v, max connection reached", uniqConnID) - maxConnResp := createEmptyErrorResponse(pb.Code_MAX_CONNECTION_LIMIT_REACHED) - conn.WriteMessage(websocket.BinaryMessage, maxConnResp) - conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1008, "Max connection reached")) - metrics.Increment("user_connection_failure_total", "reason=serverlimit") - return - } - wsHandler.user.Store(uniqConnID) - defer wsHandler.user.Remove(uniqConnID) - defer calculateSessionTime(uniqConnID, connectedTime) - - setUpControlHandlers(conn, uniqConnID, wsHandler.PongWaitInterval, wsHandler.WriteWaitInterval) - wsHandler.PingChannel <- connection{ - uniqConnID: uniqConnID, - conn: conn, - } - metrics.Increment("user_connection_success_total", "") + h.PingChannel <- conn for { _, message, err := conn.ReadMessage() @@ -82,34 +48,34 @@ func (wsHandler *Handler) HandlerWSEvents(w http.ResponseWriter, r *http.Request websocket.CloseNormalClosure, websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure) { - logger.Error(fmt.Sprintf("[websocket.Handler] UniqConnID %s Connection Closed Abruptly: %v", uniqConnID, err)) - metrics.Increment("batches_read_total", "status=failed,reason=closeerror") + logger.Error(fmt.Sprintf("[websocket.Handler] %s closed abruptly: %v", conn.Identifier, err)) + metrics.Increment("batches_read_total", 
fmt.Sprintf("status=failed,reason=closeerror,conn_group=%s", conn.Identifier.Group)) break } - metrics.Increment("batches_read_total", "status=failed,reason=unknown") - logger.Error(fmt.Sprintf("[websocket.Handler] Reading message failed. Unknown failure: %v User ID: %s ", err, uniqConnID)) //no connection issue here + metrics.Increment("batches_read_total", fmt.Sprintf("status=failed,reason=unknown,conn_group=%s", conn.Identifier.Group)) + logger.Error(fmt.Sprintf("[websocket.Handler] reading message failed. Unknown failure for %s: %v", conn.Identifier, err)) //no connection issue here break } timeConsumed := time.Now() - metrics.Count("events_rx_bytes_total", len(message), "") + metrics.Count("events_rx_bytes_total", len(message), fmt.Sprintf("conn_group=%s", conn.Identifier.Group)) payload := &pb.EventRequest{} err = proto.Unmarshal(message, payload) if err != nil { - logger.Error(fmt.Sprintf("[websocket.Handler] Reading message failed. %v UniqConnID: %s ", err, uniqConnID)) - metrics.Increment("batches_read_total", "status=failed,reason=serde") + logger.Error(fmt.Sprintf("[websocket.Handler] reading message failed for %s: %v", conn.Identifier, err)) + metrics.Increment("batches_read_total", fmt.Sprintf("status=failed,reason=serde,conn_group=%s", conn.Identifier.Group)) badrequest := createBadrequestResponse(err) conn.WriteMessage(websocket.BinaryMessage, badrequest) continue } - metrics.Increment("batches_read_total", "status=success") - metrics.Count("events_rx_total", len(payload.Events), "") - - wsHandler.bufferChannel <- EventsBatch{ - UniqConnID: uniqConnID, - EventReq: payload, - TimeConsumed: timeConsumed, - TimePushed: (time.Now()), + metrics.Increment("batches_read_total", fmt.Sprintf("status=success,conn_group=%s", conn.Identifier.Group)) + metrics.Count("events_rx_total", len(payload.Events), fmt.Sprintf("conn_group=%s", conn.Identifier.Group)) + + h.bufferChannel <- EventsBatch{ + ConnIdentifier: conn.Identifier, + EventReq: payload, + TimeConsumed: 
timeConsumed, + TimePushed: (time.Now()), } resp := createSuccessResponse(payload.ReqGuid) @@ -117,29 +83,3 @@ func (wsHandler *Handler) HandlerWSEvents(w http.ResponseWriter, r *http.Request conn.WriteMessage(websocket.BinaryMessage, success) } } - -func calculateSessionTime(uniqConnID string, connectedAt time.Time) { - connectionTime := time.Now().Sub(connectedAt) - logger.Debug(fmt.Sprintf("[websocket.calculateSessionTime] UniqConnID: %s, total time connected in minutes: %v", uniqConnID, connectionTime.Minutes())) - metrics.Timing("user_session_duration_milliseconds", connectionTime.Milliseconds(), "") -} - -func setUpControlHandlers(conn *websocket.Conn, uniqConnID string, - PongWaitInterval time.Duration, WriteWaitInterval time.Duration) { - //expects the client to send a ping, mark this channel as idle timed out post the deadline - conn.SetReadDeadline(time.Now().Add(PongWaitInterval)) - conn.SetPongHandler(func(string) error { - // extends the read deadline since we have received this pong on this channel - conn.SetReadDeadline(time.Now().Add(PongWaitInterval)) - return nil - }) - - conn.SetPingHandler(func(s string) error { - logger.Debug(fmt.Sprintf("Client connection with UniqConnID: %s Pinged", uniqConnID)) - if err := conn.WriteControl(websocket.PongMessage, []byte(s), time.Now().Add(WriteWaitInterval)); err != nil { - metrics.Increment("server_pong_failure_total", "") - logger.Debug(fmt.Sprintf("Failed to send pong event: %s UniqConnID: %s", err, uniqConnID)) - } - return nil - }) -} diff --git a/websocket/handler_test.go b/websocket/handler_test.go index e8dfcad1..289b3a5f 100644 --- a/websocket/handler_test.go +++ b/websocket/handler_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "raccoon/websocket/connection" pb "raccoon/websocket/proto" "github.com/golang/protobuf/proto" @@ -44,27 +45,27 @@ func TestPingHandler(t *testing.T) { func TestHandler_HandlerWSEvents(t *testing.T) { // ---- Setup ---- - hlr := &Handler{ - websocketUpgrader: 
websocket.Upgrader{ - ReadBufferSize: 10240, - WriteBufferSize: 10240, - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, - user: NewUserStore(2), - bufferChannel: make(chan EventsBatch, 10), + upgrader := connection.NewUpgrader(connection.UpgraderConfig{ + ReadBufferSize: 10240, + WriteBufferSize: 10240, + CheckOrigin: false, + MaxUser: 2, PongWaitInterval: time.Duration(60 * time.Second), WriteWaitInterval: time.Duration(5 * time.Second), - PingChannel: make(chan connection, 100), - UniqConnIDHeader: "x-user-id", + ConnIDHeader: "X-User-ID", + ConnGroupHeader: "string", + }) + hlr := &Handler{ + upgrader: upgrader, + bufferChannel: make(chan EventsBatch, 10), + PingChannel: make(chan connection.Conn, 100), } ts := httptest.NewServer(Router(hlr)) defer ts.Close() url := "ws" + strings.TrimPrefix(ts.URL+"/api/v1/events", "http") header := http.Header{ - "x-user-id": []string{"test1-user1"}, + "X-User-ID": []string{"test1-user1"}, } t.Run("Should return success response after successfully push to channel", func(t *testing.T) { @@ -102,7 +103,7 @@ func TestHandler_HandlerWSEvents(t *testing.T) { defer ts.Close() wss, _, err := websocket.DefaultDialer.Dial(url, http.Header{ - "x-user-id": []string{"test2-user2"}, + "X-User-ID": []string{"test2-user2"}, }) defer wss.Close() require.NoError(t, err) @@ -120,71 +121,4 @@ func TestHandler_HandlerWSEvents(t *testing.T) { assert.Equal(t, pb.Code_BAD_REQUEST, resp.GetCode()) assert.Empty(t, resp.GetData()) }) - - t.Run("Should close subsequence connection of the same user", func(t *testing.T) { - ts := httptest.NewServer(Router(hlr)) - defer ts.Close() - - url := "ws" + strings.TrimPrefix(ts.URL+"/api/v1/events", "http") - header := http.Header{ - "x-user-id": []string{"test1-user1"}, - } - w1, _, err := websocket.DefaultDialer.Dial(url, header) - defer w1.Close() - require.NoError(t, err) - - w2, _, err := websocket.DefaultDialer.Dial(url, header) - defer w2.Close() - require.NoError(t, err) - _, message, 
err := w2.ReadMessage() - p := &pb.EventResponse{} - proto.Unmarshal(message, p) - assert.Equal(t, p.Code, pb.Code_MAX_USER_LIMIT_REACHED) - assert.Equal(t, p.Status, pb.Status_ERROR) - _, _, err = w2.ReadMessage() - assert.True(t, websocket.IsCloseError(err, websocket.ClosePolicyViolation)) - assert.Equal(t, "Duplicate connection", err.(*websocket.CloseError).Text) - }) - - t.Run("Should close new connection when reach max connection", func(t *testing.T) { - ts := httptest.NewServer(Router(hlr)) - defer ts.Close() - - url := "ws" + strings.TrimPrefix(ts.URL+"/api/v1/events", "http") - header := http.Header{ - "x-user-id": []string{"test1-user1"}, - } - w1, _, _ := websocket.DefaultDialer.Dial(url, http.Header{"x-user-id": []string{"test1-user2"}}) - defer w1.Close() - w2, _, _ := websocket.DefaultDialer.Dial(url, http.Header{"x-user-id": []string{"test1-user3"}}) - defer w2.Close() - - w3, _, err := websocket.DefaultDialer.Dial(url, header) - defer w3.Close() - require.NoError(t, err) - _, message, err := w3.ReadMessage() - p := &pb.EventResponse{} - proto.Unmarshal(message, p) - assert.Equal(t, p.Code, pb.Code_MAX_CONNECTION_LIMIT_REACHED) - assert.Equal(t, p.Status, pb.Status_ERROR) - _, _, err = w3.ReadMessage() - assert.True(t, websocket.IsCloseError(err, websocket.ClosePolicyViolation)) - assert.Equal(t, "Max connection reached", err.(*websocket.CloseError).Text) - }) - - t.Run("Should decrement total connection when client close the conn", func(t *testing.T) { - ts := httptest.NewServer(Router(hlr)) - defer ts.Close() - - url := "ws" + strings.TrimPrefix(ts.URL+"/api/v1/events", "http") - w1, _, _ := websocket.DefaultDialer.Dial(url, http.Header{"x-user-id": []string{"test1-user2"}}) - defer w1.Close() - w2, _, _ := websocket.DefaultDialer.Dial(url, http.Header{"x-user-id": []string{"test1-user3"}}) - defer w2.Close() - w3, _, err := websocket.DefaultDialer.Dial(url, http.Header{"x-user-id": []string{"test1-user1"}}) - defer w3.Close() - - assert.Equal(t, 2, 
hlr.user.TotalUsers()) - assert.Empty(t, err) - }) } diff --git a/websocket/pinger.go b/websocket/pinger.go index a35eb97c..fb259f17 100644 --- a/websocket/pinger.go +++ b/websocket/pinger.go @@ -4,33 +4,27 @@ import ( "fmt" "raccoon/logger" "raccoon/metrics" + "raccoon/websocket/connection" "time" - - "github.com/gorilla/websocket" ) -type connection struct { - uniqConnID string - conn *websocket.Conn -} - -//Pinger is a worker groroutine that pings the connected peers based on ping interval. -func Pinger(c chan connection, size int, PingInterval time.Duration, WriteWaitInterval time.Duration) { +//Pinger is worker that pings the connected peers based on ping interval. +func Pinger(c chan connection.Conn, size int, PingInterval time.Duration, WriteWaitInterval time.Duration) { for i := 0; i < size; i++ { go func() { - cSet := make(map[string]*websocket.Conn) - timer := time.NewTicker(PingInterval) + cSet := make(map[connection.Identifier]connection.Conn) + ticker := time.NewTicker(PingInterval) for { select { case conn := <-c: - cSet[conn.uniqConnID] = conn.conn - case <-timer.C: - for uniqConnID, conn := range cSet { - logger.Debug(fmt.Sprintf("Pinging UniqConnID: %s ", uniqConnID)) - if err := conn.WriteControl(websocket.PingMessage, []byte("--ping--"), time.Now().Add(WriteWaitInterval)); err != nil { - logger.Error(fmt.Sprintf("[websocket.pingPeer] - Failed to ping User: %s Error: %v", uniqConnID, err)) - metrics.Increment("server_ping_failure_total", "") - delete(cSet, uniqConnID) + cSet[conn.Identifier] = conn + case <-ticker.C: + for identifier, conn := range cSet { + logger.Debug(fmt.Sprintf("Pinging %s ", identifier)) + if err := conn.Ping(WriteWaitInterval); err != nil { + logger.Error(fmt.Sprintf("[websocket.pingPeer] - Failed to ping %s: %v", identifier, err)) + metrics.Increment("server_ping_failure_total", fmt.Sprintf("conn_group=%s", identifier.Group)) + delete(cSet, identifier) } } } diff --git a/websocket/responsefactory.go 
b/websocket/responsefactory.go index 060c9b04..0bcce572 100644 --- a/websocket/responsefactory.go +++ b/websocket/responsefactory.go @@ -32,15 +32,3 @@ func createBadrequestResponse(err error) []byte { badrequestResp, _ := proto.Marshal(&response) return badrequestResp } - -func createEmptyErrorResponse(errCode pb.Code) []byte { - resp := pb.EventResponse{ - Status: pb.Status_ERROR, - Code: errCode, - SentTime: time.Now().Unix(), - Reason: "", - Data: nil, - } - duplicateConnResp, _ := proto.Marshal(&resp) - return duplicateConnResp -} diff --git a/websocket/server.go b/websocket/server.go index 38acbbaf..76f3ecaf 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -2,16 +2,18 @@ package websocket import ( "context" + "fmt" "net/http" "raccoon/config" "raccoon/logger" + "raccoon/websocket/connection" "runtime" "time" "raccoon/metrics" "github.com/gorilla/mux" - "github.com/gorilla/websocket" + // https://golang.org/pkg/net/http/pprof/ _ "net/http/pprof" ) @@ -19,8 +21,8 @@ import ( type Server struct { HTTPServer *http.Server bufferChannel chan EventsBatch - user *User - pingChannel chan connection + table *connection.Table + pingChannel chan connection.Conn } func (s *Server) StartHTTPServer(ctx context.Context, cancel context.CancelFunc) { @@ -49,7 +51,9 @@ func (s *Server) ReportServerMetrics() { m := &runtime.MemStats{} for { <-t - metrics.Gauge("connections_count_current", s.user.TotalUsers(), "") + for k, v := range s.table.TotalConnectionPerGroup() { + metrics.Gauge("connections_count_current", v, fmt.Sprintf("conn_group=%s", k)) + } metrics.Gauge("server_go_routines_count_current", runtime.NumGoroutine(), "") runtime.ReadMemStats(m) @@ -68,24 +72,31 @@ func (s *Server) ReportServerMetrics() { func CreateServer() (*Server, chan EventsBatch) { //create the websocket handler that upgrades the http request bufferChannel := make(chan EventsBatch, config.Worker.ChannelSize) - pingChannel := make(chan connection, config.ServerWs.ServerMaxConn) - user := 
NewUserStore(config.ServerWs.ServerMaxConn) - wsHandler := &Handler{ - websocketUpgrader: newWebSocketUpgrader(config.ServerWs.ReadBufferSize, config.ServerWs.WriteBufferSize, config.ServerWs.CheckOrigin), - bufferChannel: bufferChannel, - user: user, + pingChannel := make(chan connection.Conn, config.ServerWs.ServerMaxConn) + ugConfig := connection.UpgraderConfig{ + ReadBufferSize: config.ServerWs.ReadBufferSize, + WriteBufferSize: config.ServerWs.WriteBufferSize, + CheckOrigin: config.ServerWs.CheckOrigin, + MaxUser: config.ServerWs.ServerMaxConn, PongWaitInterval: config.ServerWs.PongWaitInterval, WriteWaitInterval: config.ServerWs.WriteWaitInterval, - PingChannel: pingChannel, - UniqConnIDHeader: config.ServerWs.UniqConnIDHeader, + ConnIDHeader: config.ServerWs.ConnIDHeader, + ConnGroupHeader: config.ServerWs.ConnGroupHeader, + ConnGroupDefault: config.ServerWs.ConnGroupDefault, + } + upgrader := connection.NewUpgrader(ugConfig) + wsHandler := &Handler{ + upgrader: upgrader, + bufferChannel: bufferChannel, + PingChannel: pingChannel, } server := &Server{ HTTPServer: &http.Server{ Handler: Router(wsHandler), Addr: ":" + config.ServerWs.AppPort, }, + table: upgrader.Table, bufferChannel: bufferChannel, - user: user, pingChannel: pingChannel, } //Wrap the handler with a Server instance and return it @@ -100,18 +111,3 @@ func Router(h *Handler) http.Handler { subRouter.HandleFunc("/events", h.HandlerWSEvents).Methods(http.MethodGet).Name("events") return router } - -func newWebSocketUpgrader(readBufferSize int, writeBufferSize int, checkOrigin bool) websocket.Upgrader { - var checkOriginFunc func(r *http.Request) bool - if checkOrigin == false { - checkOriginFunc = func(r *http.Request) bool { - return true - } - } - ug := websocket.Upgrader{ - ReadBufferSize: readBufferSize, - WriteBufferSize: writeBufferSize, - CheckOrigin: checkOriginFunc, - } - return ug -} diff --git a/websocket/userstore.go b/websocket/userstore.go deleted file mode 100644 index 
50a935ad..00000000 --- a/websocket/userstore.go +++ /dev/null @@ -1,46 +0,0 @@ -package websocket - -import "sync" - -type User struct { - m sync.Mutex - userMap map[string]string - maxUser int -} - -func NewUserStore(maxUser int) *User { - return &User{ - m: sync.Mutex{}, - userMap: make(map[string]string), - maxUser: maxUser, - } -} - -func (u *User) Exists(userID string) bool { - u.m.Lock() - defer u.m.Unlock() - _, ok := u.userMap[userID] - return ok -} - -func (u *User) Store(userID string) { - u.m.Lock() - defer u.m.Unlock() - u.userMap[userID] = userID -} - -func (u *User) Remove(userID string) { - u.m.Lock() - defer u.m.Unlock() - delete(u.userMap, userID) -} - -func (u *User) HasReachedLimit() bool { - return u.TotalUsers() >= u.maxUser -} - -func (u *User) TotalUsers() int { - u.m.Lock() - defer u.m.Unlock() - return len(u.userMap) -} diff --git a/worker/worker.go b/worker/worker.go index 72cbaa1e..2f0c9d17 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -59,14 +59,13 @@ func (w *Pool) StartWorkers() { logger.Debug(fmt.Sprintf("Success sending messages, %v", lenBatch-int64(totalErr))) if lenBatch > 0 { eventTimingMs := time.Since(time.Unix(request.EventReq.SentTime.Seconds, 0)).Milliseconds() / lenBatch - logger.Debug(fmt.Sprintf("Currenttime: %d, eventTimingMs: %d, UniqConnID: %s, ReqGUID: %s", request.EventReq.SentTime.Seconds, eventTimingMs, request.UniqConnID, request.EventReq.ReqGuid)) - metrics.Timing("event_processing_duration_milliseconds", eventTimingMs, "") + metrics.Timing("event_processing_duration_milliseconds", eventTimingMs, fmt.Sprintf("conn_group=%s", request.ConnIdentifier.Group)) now := time.Now() metrics.Timing("worker_processing_duration_milliseconds", (now.Sub(batchReadTime).Milliseconds())/lenBatch, "worker="+workerName) - metrics.Timing("server_processing_latency_milliseconds", (now.Sub(request.TimeConsumed)).Milliseconds()/lenBatch, "") + metrics.Timing("server_processing_latency_milliseconds", 
(now.Sub(request.TimeConsumed)).Milliseconds()/lenBatch, fmt.Sprintf("conn_group=%s", request.ConnIdentifier.Group)) } - metrics.Count("kafka_messages_delivered_total", totalErr, "success=false") - metrics.Count("kafka_messages_delivered_total", len(request.EventReq.GetEvents())-totalErr, "success=true") + metrics.Count("kafka_messages_delivered_total", totalErr, fmt.Sprintf("success=false,conn_group=%s", request.ConnIdentifier.Group)) + metrics.Count("kafka_messages_delivered_total", len(request.EventReq.GetEvents())-totalErr, fmt.Sprintf("success=true,conn_group=%s", request.ConnIdentifier.Group)) } w.wg.Done() }(fmt.Sprintf("worker-%d", i))