This repository has been archived by the owner on Mar 29, 2024. It is now read-only.

Limit concurrent connection attempts
dhaavi committed Oct 5, 2023
1 parent c7d4ec8 commit a7622da
Showing 2 changed files with 104 additions and 45 deletions.
98 changes: 59 additions & 39 deletions crew/op_connect.go
@@ -40,8 +40,8 @@ type ConnectOp struct {
doneWriting chan struct{}

// Metrics
incomingTraffic *uint64
outgoingTraffic *uint64
incomingTraffic atomic.Uint64
outgoingTraffic atomic.Uint64
started time.Time

// Connection
@@ -157,8 +157,6 @@ func NewConnectOp(tunnel *Tunnel) (*ConnectOp, *terminal.Error) {
}

// Setup metrics.
op.incomingTraffic = new(uint64)
op.outgoingTraffic = new(uint64)
op.started = time.Now()

module.StartWorker("connect op conn reader", op.connReader)
@@ -203,56 +201,75 @@ func startConnectOp(t terminal.Terminal, opID uint32, data *container.Container)
op.ctx, op.cancelCtx = context.WithCancel(t.Ctx())
op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream)

// Setup metrics.
op.incomingTraffic = new(uint64)
op.outgoingTraffic = new(uint64)

// Start worker to complete setting up the connection.
module.StartWorker("connect op setup", op.setup)
module.StartWorker("connect op setup", op.handleSetup)

return op, nil
}

func (op *ConnectOp) setup(_ context.Context) error {
func (op *ConnectOp) handleSetup(_ context.Context) error {
// Get terminal session for rate limiting.
var session *terminal.Session
if sessionTerm, ok := op.t.(terminal.SessionTerminal); ok {
session = sessionTerm.GetSession()
} else {
log.Errorf("spn/crew: %T is not a session terminal", op.t)
log.Errorf("spn/crew: %T is not a session terminal, aborting op %s#%d", op.t, op.t.FmtID(), op.ID())
op.Stop(op, terminal.ErrInternalError.With("no session available"))
return nil
}

// Limit concurrency of connecting.
cancelErr := session.LimitConcurrency(op.Ctx(), func() {
op.setup(session)
})

// If context was canceled, stop operation.
if cancelErr != nil {
op.Stop(op, terminal.ErrCanceled.With(cancelErr.Error()))
}

// Do not return a worker error.
return nil
}

func (op *ConnectOp) setup(session *terminal.Session) {
// Rate limit before connecting.
if tErr := session.RateLimit(); tErr != nil {
// Fake connection error when rate limited.
if tErr.Is(terminal.ErrRateLimited) {
log.Debugf("spn/crew: op %s#%d is rate limited: %s", op.t.FmtID(), op.ID(), session.RateLimitInfo())
}
op.Stop(op, tErr)
return
}

// Check if connection target is in global scope.
ipScope := netutils.GetIPScope(op.request.IP)
if ipScope != netutils.Global {
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
op.Stop(op, terminal.ErrPermissionDenied.With("denied request to connect to non-global IP %s", op.request.IP))
return nil
return
}

// Check exit policy.
if tErr := checkExitPolicy(op.request); tErr != nil {
session.ReportSuspiciousActivity(terminal.SusFactorQuiteUnusual)
op.Stop(op, tErr)
return nil
return
}

// Rate limit before connecting.
if tErr := session.RateLimit(); tErr != nil {
// Fake connection error when rate limited.
if tErr.Is(terminal.ErrRateLimited) {
log.Debugf("spn/crew: op %s#%d is rate limited: %s", op.t.FmtID(), op.ID(), session.RateLimitInfo())
}
op.Stop(op, tErr)
return nil
// Check one last time before connecting that the operation was not canceled.
if op.Ctx().Err() != nil {
op.Stop(op, terminal.ErrCanceled.With(op.Ctx().Err().Error()))
return
}

// Connect to destination.
dialNet := op.request.DialNetwork()
if dialNet == "" {
session.ReportSuspiciousActivity(terminal.SusFactorCommon)
op.Stop(op, terminal.ErrIncorrectUsage.With("protocol %s is not supported", op.request.Protocol))
return nil
return
}
dialer := &net.Dialer{
Timeout: 10 * time.Second,
@@ -274,7 +291,7 @@ func (op *ConnectOp) setup(_ context.Context) error {
}

op.Stop(op, terminal.ErrConnectionError.With("failed to connect to %s: %w", op.request, err))
return nil
return
}
op.conn = conn

@@ -284,7 +301,6 @@ func (op *ConnectOp) setup(_ context.Context) error {
module.StartWorker("connect op flow handler", op.dfq.FlowHandler)

log.Infof("spn/crew: connected op %s#%d to %s", op.t.FmtID(), op.ID(), op.request)
return nil
}

func (op *ConnectOp) submitUpstream(msg *terminal.Msg, timeout time.Duration) {
@@ -313,7 +329,7 @@ func (op *ConnectOp) connReader(_ context.Context) error {
defer func() {
atomic.AddInt64(activeConnectOps, -1)
connectOpDurationHistogram.UpdateDuration(op.started)
connectOpIncomingDataHistogram.Update(float64(atomic.LoadUint64(op.incomingTraffic)))
connectOpIncomingDataHistogram.Update(float64(op.incomingTraffic.Load()))
}()

rateLimiter := terminal.NewRateLimiter(rateLimitMaxMbit)
@@ -337,7 +353,7 @@ func (op *ConnectOp) connReader(_ context.Context) error {

// Submit metrics.
connectOpIncomingBytes.Add(n)
inBytes := atomic.AddUint64(op.incomingTraffic, uint64(n))
inBytes := op.incomingTraffic.Add(uint64(n))

// Rate limit if over threshold.
if inBytes > rateLimitThreshold {
@@ -376,7 +392,7 @@ func (op *ConnectOp) Deliver(msg *terminal.Msg) *terminal.Error {
func (op *ConnectOp) connWriter(_ context.Context) error {
// Metrics submitting.
defer func() {
connectOpOutgoingDataHistogram.Update(float64(atomic.LoadUint64(op.outgoingTraffic)))
connectOpOutgoingDataHistogram.Update(float64(op.outgoingTraffic.Load()))
}()

defer func() {
@@ -420,7 +436,7 @@ writing:

// Submit metrics.
connectOpOutgoingBytes.Add(len(data))
out := atomic.AddUint64(op.outgoingTraffic, uint64(len(data)))
out := op.outgoingTraffic.Add(uint64(len(data)))

// Rate limit if over threshold.
if out > rateLimitThreshold {
@@ -479,17 +495,21 @@ func (op *ConnectOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Erro
reportConnectError(err)
}

// If the op was ended locally, send all data before closing.
// If the op was ended remotely, don't bother sending remaining data.
if !err.IsExternal() {
// Flushing could mean sending a full buffer of 50000 packets.
op.dfq.Flush(5 * time.Minute)
}
// If the connection has sent or received any data so far, finish the data
// flows as appropriate.
if op.incomingTraffic.Load() > 0 || op.outgoingTraffic.Load() > 0 {
// If the op was ended locally, send all data before closing.
// If the op was ended remotely, don't bother sending remaining data.
if !err.IsExternal() {
// Flushing could mean sending a full buffer of 50000 packets.
op.dfq.Flush(5 * time.Minute)
}

// If the op was ended remotely, write all remaining received data.
// If the op was ended locally, don't bother writing remaining data.
if err.IsExternal() {
<-op.doneWriting
// If the op was ended remotely, write all remaining received data.
// If the op was ended locally, don't bother writing remaining data.
if err.IsExternal() {
<-op.doneWriting
}
}

// Cancel workers.
@@ -499,7 +519,7 @@ func (op *ConnectOp) HandleStop(err *terminal.Erro
// error and no data was received.
if op.entry && // On clients only.
err.Is(terminal.ErrConnectionError) &&
atomic.LoadUint64(op.outgoingTraffic) == 0 {
op.outgoingTraffic.Load() == 0 {
// Only if no data was received (ie. sent to local application).
op.tunnel.avoidDestinationHub()
}
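A note on the metrics change in this file: the incomingTraffic and outgoingTraffic counters move from *uint64 fields updated with atomic.AddUint64 and atomic.LoadUint64 to sync/atomic's Uint64 type (Go 1.19+), whose zero value is ready to use, which is why the new(uint64) initialization in NewConnectOp and startConnectOp disappears. A minimal sketch of that counter pattern, with illustrative names only:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// counters mirrors the op's metric fields: atomic.Uint64 needs no separate
// allocation and can only be accessed through its atomic methods.
type counters struct {
	incoming atomic.Uint64
	outgoing atomic.Uint64
}

func main() {
	var c counters
	var wg sync.WaitGroup

	// Simulate concurrent reader/writer workers bumping the counters.
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.incoming.Add(512)
			c.outgoing.Add(128)
		}()
	}
	wg.Wait()

	fmt.Println("in:", c.incoming.Load(), "out:", c.outgoing.Load())
}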
51 changes: 45 additions & 6 deletions terminal/session.go
@@ -1,6 +1,7 @@
package terminal

import (
"context"
"fmt"
"sync"
"sync/atomic"
@@ -15,6 +16,9 @@ const (

rateLimitMinSuspicion = 25
rateLimitMaxSuspicionPerSecond = 2 // TODO: Reduce to 1 after test phase.

// Make this big enough to trigger the suspicion limit in the first blast.
concurrencyPoolSize = 30
)
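The sizing presumably lines up with the suspicion thresholds above: a full pool of 30 concurrent attempts that each add at least 1 to the suspicion score yields 30, which already clears rateLimitMinSuspicion (25), so a first blast of bad connection attempts can trip the suspicion-based rate limiter on its own.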

// Session holds terminal metadata for operations.
@@ -34,6 +38,8 @@ type Session struct {
// Every suspicious operation is counted as at least 1.
// Operations that are rate limited because of suspicion are also counted as 1.
suspicionScore atomic.Int64

concurrencyPool chan struct{}
}

// SessionTerminal is an interface for terminals that support authorization.
@@ -56,14 +62,20 @@ func (t *SessionAddOn) GetSession() *Session {

// Create session if it does not exist.
if t.session == nil {
t.session = &Session{
started: time.Now().Unix() - 1, // Ensure a 1 second difference to current time.
}
t.session = NewSession()
}

return t.session
}

// NewSession returns a new session.
func NewSession() *Session {
return &Session{
started: time.Now().Unix() - 1, // Ensure a 1 second difference to current time.
concurrencyPool: make(chan struct{}, concurrencyPoolSize),
}
}

// RateLimitInfo returns some basic information about the status of the rate limiter.
func (s *Session) RateLimitInfo() string {
secondsActive := time.Now().Unix() - s.started
@@ -82,7 +94,7 @@ func (s *Session) RateLimit() *Error {

// Check the suspicion limit.
score := s.suspicionScore.Load()
if score >= rateLimitMinSuspicion {
if score > rateLimitMinSuspicion {
scorePerSecond := score / secondsActive
if scorePerSecond >= rateLimitMaxSuspicionPerSecond {
// Add current try to suspicion score.
@@ -94,7 +106,7 @@ func (s *Session) RateLimit() *Error {

// Check the rate limit.
count := s.opCount.Add(1)
if count >= rateLimitMinOps {
if count > rateLimitMinOps {
opsPerSecond := count / secondsActive
if opsPerSecond >= rateLimitMaxOpsPerSecond {
return ErrRateLimited
@@ -114,6 +126,33 @@ const (

// ReportSuspiciousActivity reports suspicious activity of the terminal.
func (s *Session) ReportSuspiciousActivity(factor int64) {
log.Debugf("session: suspicion raised by %d", factor)
s.suspicionScore.Add(factor)
}

// LimitConcurrency limits concurrent executions.
// If over the limit, waiting goroutines are selected randomly.
// It returns the context error if it was canceled.
func (s *Session) LimitConcurrency(ctx context.Context, f func()) error {
// Wait for place in pool.
select {
case <-ctx.Done():
return ctx.Err()
case s.concurrencyPool <- struct{}{}:
// We added our entry to the pool, continue with execution.
}

// Drain own spot in the pool after execution.
defer func() {
select {
case <-s.concurrencyPool:
// Own entry drained.
default:
// This should never happen, but let's play it safe and not deadlock when the pool is empty.
log.Warningf("spn/session: failed to drain own entry from concurrency pool")
}
}()

// Execute and return.
f()
return nil
}
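For reference, the buffered channel above acts as a counting semaphore: sending into concurrencyPool claims one of the 30 slots (or blocks until a slot frees up or the context is canceled), and receiving from it releases the slot. Below is a self-contained sketch of the same pattern with a small pool; the limiter type, pool size, and timings are illustrative and not taken from the repository:

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// limiter mimics Session.LimitConcurrency: a buffered channel used as a
// counting semaphore; goroutines waiting for a slot are served in no
// particular order.
type limiter struct {
	pool chan struct{}
}

func newLimiter(size int) *limiter {
	return &limiter{pool: make(chan struct{}, size)}
}

func (l *limiter) Do(ctx context.Context, f func()) error {
	select {
	case <-ctx.Done():
		return ctx.Err() // canceled while waiting for a slot
	case l.pool <- struct{}{}:
		// Slot claimed, continue with execution.
	}
	defer func() { <-l.pool }() // release the slot when done
	f()
	return nil
}

func main() {
	l := newLimiter(2) // only two "connection attempts" run at once
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			err := l.Do(ctx, func() { time.Sleep(30 * time.Millisecond) })
			if errors.Is(err, context.DeadlineExceeded) {
				fmt.Println("attempt", id, "canceled while waiting for a slot")
			}
		}(i)
	}
	wg.Wait()
}

Unlike this sketch, the commit's drain uses a select with a default branch, so an unexpectedly empty pool logs a warning instead of blocking.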
