Skip to content

Commit

Permalink
X-Forwarded-For (#4380)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai committed Apr 17, 2018
1 parent f7e886f commit 80b1770
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 42 deletions.
34 changes: 27 additions & 7 deletions command/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/gated-writer"
Expand Down Expand Up @@ -92,6 +93,11 @@ type ServerCommand struct {
flagTestVerifyOnly bool
}

type ServerListener struct {
net.Listener
config map[string]interface{}
}

func (c *ServerCommand) Synopsis() string {
return "Start a Vault server"
}
Expand Down Expand Up @@ -670,16 +676,19 @@ CLUSTER_SYNTHESIS_COMPLETE:
clusterAddrs := []*net.TCPAddr{}

// Initialize the listeners
lns := make([]ServerListener, 0, len(config.Listeners))
c.reloadFuncsLock.Lock()
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate, c.UI)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err))
return 1
}

lns = append(lns, ln)
lns = append(lns, ServerListener{
Listener: ln,
config: lnConfig.Config,
})

if reloadFunc != nil {
relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type]
Expand Down Expand Up @@ -738,7 +747,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
// Make sure we close all listeners from this point on
listenerCloseFunc := func() {
for _, ln := range lns {
ln.Close()
ln.Listener.Close()
}
}

Expand Down Expand Up @@ -776,12 +785,10 @@ CLUSTER_SYNTHESIS_COMPLETE:
return 0
}

handler := vaulthttp.Handler(core)

// This needs to happen before we first unseal, so before we trigger dev
// mode if it's set
core.SetClusterListenerAddrs(clusterAddrs)
core.SetClusterHandler(handler)
core.SetClusterHandler(vaulthttp.Handler(core))

err = core.UnsealWithStoredKeys(context.Background())
if err != nil {
Expand Down Expand Up @@ -914,10 +921,23 @@ CLUSTER_SYNTHESIS_COMPLETE:

// Initialize the HTTP servers
for _, ln := range lns {
handler := vaulthttp.Handler(core)

// We perform validation on the config earlier, we can just cast here
if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok {
hopSkips := ln.config["x_forwarded_for_hop_skips"].(int)
authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler)
rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool)
rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool)
if len(authzdAddrs) > 0 {
handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips)
}
}

server := &http.Server{
Handler: handler,
}
go server.Serve(ln)
go server.Serve(ln.Listener)
}

if newCoreError != nil {
Expand Down
4 changes: 4 additions & 0 deletions command/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,10 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
"address",
"cluster_address",
"endpoint",
"x_forwarded_for_authorized_addrs",
"x_forwarded_for_hop_skips",
"x_forwarded_for_reject_not_authorized",
"x_forwarded_for_reject_not_present",
"infrastructure",
"node_id",
"proxy_protocol_behavior",
Expand Down
55 changes: 55 additions & 0 deletions command/server/listener_tcp.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package server

import (
"fmt"
"io"
"net"
"strconv"
"strings"
"time"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/reload"
"github.com/mitchellh/cli"
)
Expand Down Expand Up @@ -39,6 +43,57 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
}

props := map[string]string{"addr": addr}

ffAllowedRaw, ffAllowedOK := config["x_forwarded_for_authorized_addrs"]
if ffAllowedOK {
ffAllowed, err := parseutil.ParseAddrs(ffAllowedRaw)
if err != nil {
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_authorized_addrs\": {{err}}", err)
}
props["x_forwarded_for_authorized_addrs"] = fmt.Sprintf("%v", ffAllowed)
config["x_forwarded_for_authorized_addrs"] = ffAllowed
}

if ffHopsRaw, ok := config["x_forwarded_for_hop_skips"]; ok {
ffHops64, err := parseutil.ParseInt(ffHopsRaw)
if err != nil {
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_hop_skips\": {{err}}", err)
}
if ffHops64 < 0 {
return nil, nil, nil, fmt.Errorf("\"x_forwarded_for_hop_skips\" cannot be negative")
}
ffHops := int(ffHops64)
props["x_forwarded_for_hop_skips"] = strconv.Itoa(ffHops)
config["x_forwarded_for_hop_skips"] = ffHops
} else if ffAllowedOK {
props["x_forwarded_for_hop_skips"] = "0"
config["x_forwarded_for_hop_skips"] = int(0)
}

if ffRejectNotPresentRaw, ok := config["x_forwarded_for_reject_not_present"]; ok {
ffRejectNotPresent, err := parseutil.ParseBool(ffRejectNotPresentRaw)
if err != nil {
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_present\": {{err}}", err)
}
props["x_forwarded_for_reject_not_present"] = strconv.FormatBool(ffRejectNotPresent)
config["x_forwarded_for_reject_not_present"] = ffRejectNotPresent
} else if ffAllowedOK {
props["x_forwarded_for_reject_not_present"] = "true"
config["x_forwarded_for_reject_not_present"] = true
}

if ffRejectNonAuthorizedRaw, ok := config["x_forwarded_for_reject_not_authorized"]; ok {
ffRejectNonAuthorized, err := parseutil.ParseBool(ffRejectNonAuthorizedRaw)
if err != nil {
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_authorized\": {{err}}", err)
}
props["x_forwarded_for_reject_not_authorized"] = strconv.FormatBool(ffRejectNonAuthorized)
config["x_forwarded_for_reject_not_authorized"] = ffRejectNonAuthorized
} else if ffAllowedOK {
props["x_forwarded_for_reject_not_authorized"] = "true"
config["x_forwarded_for_reject_not_authorized"] = true
}

return listenerWrapTLS(ln, props, config, ui)
}

Expand Down
43 changes: 43 additions & 0 deletions helper/parseutil/parseutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package parseutil
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"

"github.com/hashicorp/errwrap"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/helper/strutil"
"github.com/mitchellh/mapstructure"
)
Expand Down Expand Up @@ -118,3 +121,43 @@ func ParseCommaStringSlice(in interface{}) ([]string, error) {
}
return strutil.TrimStrings(result), nil
}

func ParseAddrs(addrs interface{}) ([]*sockaddr.SockAddrMarshaler, error) {
out := make([]*sockaddr.SockAddrMarshaler, 0)
stringAddrs := make([]string, 0)

switch addrs.(type) {
case string:
stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",")
if len(stringAddrs) == 0 {
return nil, fmt.Errorf("unable to parse addresses from %v", addrs)
}

case []string:
stringAddrs = addrs.([]string)

case []interface{}:
for _, v := range addrs.([]interface{}) {
stringAddr, ok := v.(string)
if !ok {
return nil, fmt.Errorf("error parsing %v as string", v)
}
stringAddrs = append(stringAddrs, stringAddr)
}

default:
return nil, fmt.Errorf("unknown address input type %T", addrs)
}

for _, addr := range stringAddrs {
sa, err := sockaddr.NewSockAddr(addr)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error parsing address %q: {{err}}", addr), err)
}
out = append(out, &sockaddr.SockAddrMarshaler{
SockAddr: sa,
})
}

return out, nil
}
40 changes: 5 additions & 35 deletions helper/proxyutil/proxyutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
proxyproto "github.com/armon/go-proxyproto"
"github.com/hashicorp/errwrap"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/helper/parseutil"
)

// ProxyProtoConfig contains configuration for the PROXY protocol
Expand All @@ -19,42 +19,12 @@ type ProxyProtoConfig struct {
}

func (p *ProxyProtoConfig) SetAuthorizedAddrs(addrs interface{}) error {
p.AuthorizedAddrs = make([]*sockaddr.SockAddrMarshaler, 0)
stringAddrs := make([]string, 0)

switch addrs.(type) {
case string:
stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",")
if len(stringAddrs) == 0 {
return fmt.Errorf("unable to parse addresses from %v", addrs)
}

case []string:
stringAddrs = addrs.([]string)

case []interface{}:
for _, v := range addrs.([]interface{}) {
stringAddr, ok := v.(string)
if !ok {
return fmt.Errorf("error parsing %v as string", v)
}
stringAddrs = append(stringAddrs, stringAddr)
}

default:
return fmt.Errorf("unknown address input type %T", addrs)
}

for _, addr := range stringAddrs {
sa, err := sockaddr.NewSockAddr(addr)
if err != nil {
return errwrap.Wrapf("error parsing authorized address: {{err}}", err)
}
p.AuthorizedAddrs = append(p.AuthorizedAddrs, &sockaddr.SockAddrMarshaler{
SockAddr: sa,
})
aa, err := parseutil.ParseAddrs(addrs)
if err != nil {
return err
}

p.AuthorizedAddrs = aa
return nil
}

Expand Down
Loading

0 comments on commit 80b1770

Please sign in to comment.