Skip to content

Commit

Permalink
Better logging, allow exclusion of certain IPs from max_throttled_con…
Browse files Browse the repository at this point in the history
…nections_per_ip
  • Loading branch information
ayyghost committed Mar 9, 2024
1 parent dc81ad3 commit 2a41b5e
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 46 deletions.
10 changes: 6 additions & 4 deletions config.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
{
"listen_address": ":8008",
"websocket_endpoint": "/ws",
"upstream_websocket_url": "ws://cryptodog-server:8009/ws",
"listen_address": ":8111",
"upstream_websocket_url": "ws://ejabberd:5280/websocket",
"use_x_forwarded_for": true,
"num_proxies": 2,
"max_throttled_connections_per_ip": 3,
"max_throttled_connections_exclude_cidrs": ["127.0.0.1/32"],
"max_message_size": 250000,
"rate_limit": 250000,
"rate_measure_period": 5,
"ping_period": 30
"ping_period": 30,
"log_websocket_messages": false
}
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module github.com/cryptodog/updog

go 1.22.0

require (
github.com/gorilla/websocket v1.5.1
github.com/rs/zerolog v1.32.0
)

require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect
)
21 changes: 21 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
187 changes: 145 additions & 42 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,41 @@
package main

import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"unicode"

"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)

type Config struct {
ListenAddress string `json:"listen_address"`
WebsocketEndpoint string `json:"websocket_endpoint"`
ListenAddress string `json:"listen_address"`

// WebSocket URL of XMPP server.
UpstreamWebsocketURL string `json:"upstream_websocket_url"`

// For when we're behind another proxy, such as nginx.
UseXForwardedFor bool `json:"use_x_forwarded_for"`

// The number of proxies in front of the service. Used to retrieve the correct client IP.
NumProxies int `json:"num_proxies"`

// How many connections per IP are allowed to max out the rate limit at any given time?
MaxThrottledConnectionsPerIP int `json:"max_throttled_connections_per_ip"`

// Network ranges to exclude from the MaxThrottledConnectionsPerIP count. Other rate limits will still apply.
MaxThrottledConnectionsExcludeCIDRs []string `json:"max_throttled_connections_exclude_cidrs"`

// The maxiumum allowed size, in bytes, for incoming WebSocket messages from clients.
// Offending clients will have their connection terminated.
MaxMessageSize int64 `json:"max_message_size"`
Expand All @@ -41,9 +51,10 @@ type Config struct {

// How often, in seconds, to send pings to downstream client.
PingPeriod int `json:"ping_period"`
}

const configFile = "config.json"
// Whether to log WebSocket messages from the client.
LogWebSocketMessages bool `json:"log_websocket_messages"`
}

var config Config

Expand All @@ -56,29 +67,84 @@ var upgrader = websocket.Upgrader{

var throttledConnectionsPerIP = make(map[string]int)
var throttledConnectionsPerIPLock = sync.Mutex{}
var maxThrottledConnectionsExcludeCIDRs []*net.IPNet

func isControlMessage(messageType int) bool {
return (messageType != websocket.TextMessage) && (messageType != websocket.BinaryMessage)
}

func proxy(w http.ResponseWriter, r *http.Request) {
var ip string
func newID() string {
id := make([]byte, 12)
_, err := rand.Read(id)
if err != nil {
panic(err)
}
return hex.EncodeToString(id)
}

// TODO: verify this works for IPv6
func getClientIP(r *http.Request) (string, error) {
if config.UseXForwardedFor {
forwarded := strings.Split(r.Header.Get("X-Forwarded-For"), ",")
ip = forwarded[len(forwarded)-1]
if len(forwarded) < config.NumProxies {
return "", fmt.Errorf("number of proxies is %v but x-forwarded-for has only %v entries (%v)",
config.NumProxies, len(forwarded), forwarded)
}
return strings.TrimSpace(forwarded[len(forwarded)-config.NumProxies]), nil
} else {
ip = strings.Split(r.RemoteAddr, ":")[0]
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
panic(err)
}
return ip, nil
}
}

func cidrListContainsIP(cidrs []*net.IPNet, ip string) (bool, error) {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false, fmt.Errorf("unable to parse IP '%v'", ip)
}

for _, cidr := range cidrs {
if cidr.Contains(parsedIP) {
return true, nil
}
}
return false, nil
}

func isPrint(bs []byte) bool {
for _, b := range bs {
if !unicode.IsPrint(rune(b)) {
return false
}
}
return true
}

func proxy(w http.ResponseWriter, r *http.Request) {
// Could be GET or HEAD
if r.Method != "GET" {
return
}

id := newID()
ip, err := getClientIP(r)
if err != nil {
panic(err)
}
origin := r.Header.Get("Origin")

logger := zerolog.New(os.Stderr).With().Timestamp().Str("request_id", id).Str("client_ip", ip).Logger()
logger.Info().Msg(fmt.Sprintf("request from origin '%v'", origin))

header := http.Header{}
header.Add("Origin", r.Header.Get("Origin"))
header.Add("Origin", origin)
header.Add("Sec-WebSocket-Protocol", r.Header.Get("Sec-WebSocket-Protocol"))

upstream, resp, err := websocket.DefaultDialer.Dial(config.UpstreamWebsocketURL, header)
if err != nil {
log.Printf("[client: %v] dial upstream: %v", ip, err)
logger.Error().Err(fmt.Errorf("dial upstream: %v", err)).Send()
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -89,11 +155,19 @@ func proxy(w http.ResponseWriter, r *http.Request) {

downstream, err := upgrader.Upgrade(w, r, header)
if err != nil {
log.Printf("[client: %v] upgrade: %v", ip, err)
logger.Error().Err(fmt.Errorf("upgrade: %v", err)).Send()
return
}
defer downstream.Close()

excludeFromThrottledConnectionsPerIP, err := cidrListContainsIP(maxThrottledConnectionsExcludeCIDRs, ip)
if err != nil {
panic(err)
}
if excludeFromThrottledConnectionsPerIP {
logger.Info().Msg("excluding from throttled connections per IP limit")
}

downstream.SetReadLimit(config.MaxMessageSize)

pongWait := time.Duration(config.PingPeriod*2) * time.Second
Expand All @@ -108,7 +182,7 @@ func proxy(w http.ResponseWriter, r *http.Request) {
go func() {
byteCount := 0
msgCount := 0
var startedCountingAt time.Time
startedCountingAt := time.Now()
measureLock := sync.Mutex{}

go func() {
Expand Down Expand Up @@ -146,34 +220,48 @@ func proxy(w http.ResponseWriter, r *http.Request) {
// XXX: magic
// TODO: make this a config option
if mc > 100 {
errc <- fmt.Errorf("sending messages too fast!")
errc <- fmt.Errorf("sending messages too fast (%v in %v)", mc, period)
break
}

rate := volume / period.Seconds()

if rate > float64(config.RateLimit) {
throttledConnectionsPerIPLock.Lock()
throttledConnectionsPerIP[ip]++
tc := throttledConnectionsPerIP[ip]
throttledConnectionsPerIPLock.Unlock()
var tc int

if !excludeFromThrottledConnectionsPerIP {
throttledConnectionsPerIPLock.Lock()
throttledConnectionsPerIP[ip]++
tc = throttledConnectionsPerIP[ip]
throttledConnectionsPerIPLock.Unlock()
}

// Throttle down to the specified rate limit.
time.Sleep(time.Duration((float64(time.Second)*volume)/float64(config.RateLimit)) - period)

throttledConnectionsPerIPLock.Lock()
if throttledConnectionsPerIP[ip] > 0 {
throttledConnectionsPerIP[ip]--
} else {
// This invariant should never be violated.
// XXX: remove after testing
panic("throttledConnectionsPerIP[ip] <= 0")
if !excludeFromThrottledConnectionsPerIP {
throttledConnectionsPerIPLock.Lock()
if throttledConnectionsPerIP[ip] > 0 {
throttledConnectionsPerIP[ip]--
} else {
// This invariant should never be violated.
// XXX: remove after testing
panic("throttledConnectionsPerIP[ip] <= 0")
}
throttledConnectionsPerIPLock.Unlock()

if tc > config.MaxThrottledConnectionsPerIP {
errc <- fmt.Errorf("too many throttled connections")
break
}
}
throttledConnectionsPerIPLock.Unlock()
}

if tc > config.MaxThrottledConnectionsPerIP {
errc <- fmt.Errorf("too many throttled connections!")
break
if config.LogWebSocketMessages {
if isPrint(msg) {
logger.Info().Str("websocket_message", string(msg)).Send()
} else {
logger.Info().Str("websocket_message_base64", base64.StdEncoding.EncodeToString(msg)).Send()
}
}
}
Expand Down Expand Up @@ -272,30 +360,45 @@ func proxy(w http.ResponseWriter, r *http.Request) {
}
}()

log.Printf("[client: %v] disconnect: %v", ip, <-errc)
logger.Warn().Err(fmt.Errorf("disconnect: %v", <-errc)).Send()

throttledConnectionsPerIPLock.Lock()
if throttledConnectionsPerIP[ip] == 0 {
// Doesn't matter if there are other connections from this IP; zero-value takes care of it.
delete(throttledConnectionsPerIP, ip)
if !excludeFromThrottledConnectionsPerIP {
throttledConnectionsPerIPLock.Lock()
if n, ok := throttledConnectionsPerIP[ip]; ok && n == 0 {
// Doesn't matter if there are other connections from this IP; zero-value takes care of it.
delete(throttledConnectionsPerIP, ip)
}
throttledConnectionsPerIPLock.Unlock()
}
throttledConnectionsPerIPLock.Unlock()

// Signal remaining routines to clean up.
close(done)
}

func main() {
b, err := ioutil.ReadFile(configFile)
const configFile = "config.json"
b, err := os.ReadFile(configFile)
if err != nil {
log.Fatal(err)
panic(err)
}

err = json.Unmarshal(b, &config)
if err != nil {
log.Fatal(err)
panic(err)
}

if config.UseXForwardedFor && config.NumProxies == 0 {
panic("using x-forwarded-for but number of proxies is 0")
}

for _, s := range config.MaxThrottledConnectionsExcludeCIDRs {
_, cidr, err := net.ParseCIDR(s)
if err != nil {
panic(fmt.Errorf("could not parse cidr %v: %v", s, err))
}
maxThrottledConnectionsExcludeCIDRs = append(maxThrottledConnectionsExcludeCIDRs, cidr)
}

http.HandleFunc(config.WebsocketEndpoint, proxy)
log.Fatal(http.ListenAndServe(config.ListenAddress, nil))
http.HandleFunc("GET /", proxy)
panic(http.ListenAndServe(config.ListenAddress, nil))
}

0 comments on commit 2a41b5e

Please sign in to comment.