Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SmokeScreen Context Fields Public #234

Merged
merged 3 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 46 additions & 45 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const (
type ipType int

type ACLDecision struct {
reason, role, project, outboundHost string
Reason, Role, Project, OutboundHost string
ResolvedAddr *net.TCPAddr
allow bool
enforceWouldDeny bool
Expand All @@ -79,9 +79,9 @@ type SmokescreenContext struct {
cfg *Config
start time.Time
Decision *ACLDecision
proxyType string
logger *logrus.Entry
requestedHost string
ProxyType string
Logger *logrus.Entry
RequestedHost string

// Time spent resolving the requested hostname
lookupTime time.Duration
Expand Down Expand Up @@ -257,11 +257,11 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
}
d := sctx.Decision

// If an address hasn't been resolved, does not match the original outboundHost,
// If an address hasn't been resolved, does not match the original OutboundHost,
// or is not tcp we must re-resolve it before establishing the connection.
if d.ResolvedAddr == nil || d.outboundHost != addr || network != "tcp" {
if d.ResolvedAddr == nil || d.OutboundHost != addr || network != "tcp" {
var err error
d.ResolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr)
d.ResolvedAddr, d.Reason, err = safeResolve(sctx.cfg, network, addr)
if err != nil {
if _, ok := err.(denyError); ok {
sctx.cfg.Log.WithFields(
Expand Down Expand Up @@ -289,25 +289,25 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout)
}
connTime := time.Since(start)
sctx.logger = sctx.logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime))
sctx.Logger = sctx.Logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime))

if sctx.cfg.TimeConnect {
sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.requestedHost}, 1)
sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.RequestedHost}, 1)
}

if err != nil {
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "false"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, false)
sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, false)
metrics.ReportConnError(sctx.cfg.MetricsClient, err)
return nil, err
}
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, true)
sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, true)

// Only wrap CONNECT conns with an InstrumentedConn. Connections used for traditional HTTP proxy
// requests are pooled and reused by net.Transport.
if sctx.proxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType)
if sctx.ProxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.Logger, d.Role, d.OutboundHost, sctx.ProxyType)
pctx.ConnErrorHandler = ic.Error
conn = ic
} else {
Expand Down Expand Up @@ -346,11 +346,11 @@ func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) {
resp := rejectResponse(pctx, err)

if err := resp.Write(w); err != nil {
sctx.logger.Errorf("Failed to write HTTP error response: %s", err)
sctx.Logger.Errorf("Failed to write HTTP error response: %s", err)
}

if err := w.Close(); err != nil {
sctx.logger.Errorf("Failed to close proxy client connection: %s", err)
sctx.Logger.Errorf("Failed to close proxy client connection: %s", err)
}
}

Expand Down Expand Up @@ -384,12 +384,12 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response {
status = "Internal server error"
code = http.StatusInternalServerError
msg = "An unexpected error occurred: " + err.Error()
sctx.logger.WithField("error", err.Error()).Warn("rejectResponse called with unexpected error")
sctx.Logger.WithField("error", err.Error()).Warn("rejectResponse called with unexpected error")
}

// Do not double log deny errors, they are logged in a previous call to logProxy.
if _, ok := err.(denyError); !ok {
sctx.logger.Error(msg)
sctx.Logger.Error(msg)
}

if sctx.cfg.AdditionalErrorMessageOnDeny != "" {
Expand Down Expand Up @@ -438,10 +438,10 @@ func newContext(cfg *Config, proxyType string, req *http.Request) *SmokescreenCo

return &SmokescreenContext{
cfg: cfg,
logger: logger,
proxyType: proxyType,
Logger: logger,
ProxyType: proxyType,
start: start,
requestedHost: req.Host,
RequestedHost: req.Host,
}
}

Expand Down Expand Up @@ -493,7 +493,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
req.Header.Del(traceHeader)
}()

sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")
sctx.Logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")
// Build an address parsable by net.ResolveTCPAddr
destination, err := hostport.NewWithScheme(req.Host, req.URL.Scheme, false)
if err != nil {
Expand All @@ -510,7 +510,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
return req, rejectResponse(pctx, pctx.Error)
}
if !sctx.Decision.allow {
return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.reason)})
return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.Reason)})
}

// Call the custom request handler if it exists
Expand Down Expand Up @@ -576,7 +576,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// We don't want to log if the connection is a MITM as it will be done in HandleConnectFunc
if pctx.ConnectAction != goproxy.ConnectMitm {
// In case of an error, this function is called a second time to filter the
// response we generate so this logger will be called once.
// response we generate so this Logger will be called once.
logProxy(pctx)
}
return resp
Expand All @@ -586,6 +586,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// The goproxy OnResponse() function above is only called for non-https responses.
if config.AcceptResponseHandler != nil {
proxy.ConnectRespHandler = func(pctx *goproxy.ProxyCtx, resp *http.Response) error {

sctx, ok := pctx.UserData.(*SmokescreenContext)
if !ok {
return fmt.Errorf("goproxy ProxyContext missing required UserData *SmokescreenContext")
Expand All @@ -611,7 +612,7 @@ func logProxy(pctx *goproxy.ProxyCtx) {
}

if sctx.Decision != nil {
fields[LogFieldDecisionReason] = decision.reason
fields[LogFieldDecisionReason] = decision.Reason
fields[LogFieldEnforceWouldDeny] = decision.enforceWouldDeny
fields[LogFieldAllow] = decision.allow
}
Expand All @@ -621,7 +622,7 @@ func logProxy(pctx *goproxy.ProxyCtx) {
fields[LogFieldError] = err.Error()
}

entry := sctx.logger.WithFields(fields)
entry := sctx.Logger.WithFields(fields)
var logMethod func(...interface{})
if _, ok := err.(denyError); !ok && err != nil {
logMethod = entry.Error
Expand All @@ -648,16 +649,16 @@ func extractContextLogFields(pctx *goproxy.ProxyCtx, sctx *SmokescreenContext) l
// Retrieve information from the ACL decision
decision := sctx.Decision
if sctx.Decision != nil {
fields[LogFieldRole] = decision.role
fields[LogFieldProject] = decision.project
fields[LogFieldRole] = decision.Role
fields[LogFieldProject] = decision.Project
}
return fields
}

func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string, error) {
sctx := pctx.UserData.(*SmokescreenContext)

// Check if requesting role is allowed to talk to remote
// Check if requesting Role is allowed to talk to remote
destination, err := hostport.New(pctx.Req.Host, false)
if err != nil {
pctx.Error = denyError{err}
Expand All @@ -673,11 +674,11 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectActi
return nil, "", pctx.Error
}

// add context fields to all future log messages sent using this smokescreen context's logger
sctx.logger = sctx.logger.WithFields(extractContextLogFields(pctx, sctx))
// add context fields to all future log messages sent using this smokescreen context's Logger
sctx.Logger = sctx.Logger.WithFields(extractContextLogFields(pctx, sctx))

if !sctx.Decision.allow {
return nil, "", denyError{errors.New(sctx.Decision.reason)}
return nil, "", denyError{errors.New(sctx.Decision.Reason)}
}

// Call the custom request handler if it exists
Expand All @@ -696,7 +697,7 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (*goproxy.ConnectActi
deny := denyError{errors.New("ACLDecision specified MITM but Smokescreen doesn't have MITM enabled")}
sctx.Decision.allow = false
sctx.Decision.MitmConfig = nil
sctx.Decision.reason = deny.Error()
sctx.Decision.Reason = deny.Error()
return nil, "", deny
}
mitm := sctx.Decision.MitmConfig
Expand Down Expand Up @@ -912,10 +913,10 @@ func runServer(config *Config, server *http.Server, listener net.Listener, quit
}
}

// Extract the client's ACL role from the HTTP request, using the configured
// RoleFromRequest function. Returns the role, or an error if the role cannot
// Extract the client's ACL Role from the HTTP request, using the configured
// RoleFromRequest function. Returns the Role, or an error if the Role cannot
// be determined (including no RoleFromRequest configured), unless
// AllowMissingRole is configured, in which case an empty role and no error is
// AllowMissingRole is configured, in which case an empty Role and no error is
// returned.
func getRole(config *Config, req *http.Request) (string, error) {
var role string
Expand Down Expand Up @@ -955,7 +956,7 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio
if _, ok := err.(denyError); !ok {
return decision, lookupTime, err
}
decision.reason = fmt.Sprintf("%s. %s", err.Error(), reason)
decision.Reason = fmt.Sprintf("%s. %s", err.Error(), reason)
decision.allow = false
decision.enforceWouldDeny = true
} else {
Expand All @@ -968,28 +969,28 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio

func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *ACLDecision {
decision := &ACLDecision{
outboundHost: destination.String(),
OutboundHost: destination.String(),
}

if config.EgressACL == nil {
decision.allow = true
decision.reason = "Egress ACL is not configured"
decision.Reason = "Egress ACL is not configured"
return decision
}

role, roleErr := getRole(config, req)
if roleErr != nil {
config.MetricsClient.Incr("acl.role_not_determined", 1)
decision.reason = "Client role cannot be determined"
decision.Reason = "Client role cannot be determined"
return decision
}

decision.role = role
decision.Role = role

// This host validation prevents IPv6 addresses from being used as destinations.
// Added for backwards compatibility.
if strings.ContainsAny(destination.Host, ":") {
decision.reason = "Destination host cannot be determined"
decision.Reason = "Destination host cannot be determined"
return decision
}

Expand Down Expand Up @@ -1024,8 +1025,8 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
}

ACLDecision, err := config.EgressACL.Decide(role, destination.Host, connectProxyHost)
decision.project = ACLDecision.Project
decision.reason = ACLDecision.Reason
decision.Project = ACLDecision.Project
decision.Reason = ACLDecision.Reason
decision.MitmConfig = ACLDecision.MitmConfig
if err != nil {
config.Log.WithFields(logrus.Fields{
Expand All @@ -1038,7 +1039,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
}

tags := map[string]string{
"role": decision.role,
"role": decision.Role,
"def_rule": fmt.Sprintf("%t", ACLDecision.Default),
"project": ACLDecision.Project,
}
Expand All @@ -1064,7 +1065,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
"destination": destination.Host,
"action": ACLDecision.Result.String(),
}).Warn("Unknown ACL action")
decision.reason = "Internal error"
decision.Reason = "Internal error"
config.MetricsClient.IncrWithTags("acl.unknown_error", tags, 1)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func TestProxyTimeouts(t *testing.T) {
// for an EOF returned from HTTP client to indicate a connection interruption
// which in our case represents the timeout.
//
// To correctly hook into this, we'd need to pass a logger from Smokescreen to Goproxy
// To correctly hook into this, we'd need to pass a Logger from Smokescreen to Goproxy
// which we have hooks into. This would be able to verify the timeout as errors from
// each end of the connection pair are logged by Goproxy.
t.Run("CONNECT proxy timeouts", func(t *testing.T) {
Expand Down