Skip to content

Commit

Permalink
Merge pull request #108 from hashicorp/f-use-sockaddr
Browse files Browse the repository at this point in the history
F use sockaddr
  • Loading branch information
sean- authored Feb 8, 2017
2 parents 591e85c + e300cfb commit 902b55b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 241 deletions.
57 changes: 24 additions & 33 deletions memberlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/hashicorp/go-multierror"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/miekg/dns"
)

Expand Down Expand Up @@ -326,7 +327,7 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
// as if we received an alive notification our own network channel for
// ourself.
func (m *Memberlist) setAlive() error {
var advertiseAddr []byte
var advertiseAddr net.IP
var advertisePort int
if m.config.AdvertiseAddr != "" {
// If AdvertiseAddr is not empty, then advertise
Expand All @@ -345,42 +346,21 @@ func (m *Memberlist) setAlive() error {
advertisePort = m.config.AdvertisePort
} else {
if m.config.BindAddr == "0.0.0.0" {
// Otherwise, if we're not bound to a specific IP,
//let's list the interfaces on this machine and use
// the first private IP we find.
addresses, err := net.InterfaceAddrs()
// Otherwise, if we're not bound to a specific IP, let's use a suitable
// private IP address.
var err error
m.config.AdvertiseAddr, err = sockaddr.GetPrivateIP()
if err != nil {
return fmt.Errorf("Failed to get interface addresses! Err: %v", err)
return fmt.Errorf("Failed to get interface addresses: %v", err)
}

// Find private IPv4 address
for _, rawAddr := range addresses {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}

if ip.To4() == nil {
continue
}
if !IsPrivateIP(ip.String()) {
continue
}

advertiseAddr = ip
break
if m.config.AdvertiseAddr == "" {
return fmt.Errorf("No private IP address found, and explicit IP not provided")
}

// Failed to find private IP, error
advertiseAddr = net.ParseIP(m.config.AdvertiseAddr)
if advertiseAddr == nil {
return fmt.Errorf("No private IP address found, and explicit IP not provided")
return fmt.Errorf("Failed to parse advertise address: %q", m.config.AdvertiseAddr)
}

} else {
// Use the IP that we're bound to.
addr := m.tcpListener.Addr().(*net.TCPAddr)
Expand All @@ -392,8 +372,19 @@ func (m *Memberlist) setAlive() error {
}

// Check if this is a public address without encryption
addrStr := net.IP(advertiseAddr).String()
if !IsPrivateIP(addrStr) && !isLoopbackIP(addrStr) && !m.config.EncryptionEnabled() {
ipAddr, err := sockaddr.NewIPAddr(advertiseAddr.String())
if err != nil {
return fmt.Errorf("Failed to parse interface addresses: %v", err)
}

ifAddrs := []sockaddr.IfAddr{
sockaddr.IfAddr{
SockAddr: ipAddr,
},
}

_, publicIfs, err := sockaddr.IfByRFC("6890", ifAddrs)
if len(publicIfs) > 0 && !m.config.EncryptionEnabled() {
m.logger.Printf("[WARN] memberlist: Binding to public address without encryption!")
}

Expand Down
112 changes: 0 additions & 112 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"math"
"math/rand"
"net"
"strings"
"time"

Expand All @@ -22,19 +21,6 @@ import (
// while the 65th will triple it.
const pushPullScaleThreshold = 32

/*
* Contains an entry for each private block:
* 10.0.0.0/8
* 100.64.0.0/10
* 127.0.0.0/8
* 169.254.0.0/16
* 172.16.0.0/12
* 192.168.0.0/16
*/
var privateBlocks []*net.IPNet

var loopbackBlock *net.IPNet

const (
// Constant litWidth 2-8
lzwLitWidth = 8
Expand All @@ -43,51 +29,6 @@ const (
func init() {
// Seed the random number generator
rand.Seed(time.Now().UnixNano())

// Add each private block
privateBlocks = make([]*net.IPNet, 6)

_, block, err := net.ParseCIDR("10.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[0] = block

_, block, err = net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[1] = block

_, block, err = net.ParseCIDR("127.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[2] = block

_, block, err = net.ParseCIDR("169.254.0.0/16")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[3] = block

_, block, err = net.ParseCIDR("172.16.0.0/12")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[4] = block

_, block, err = net.ParseCIDR("192.168.0.0/16")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[5] = block

_, block, err = net.ParseCIDR("127.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
loopbackBlock = block
}

// Decode reverses the encode operation on a byte slice input
Expand All @@ -108,42 +49,6 @@ func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
return buf, err
}

// GetPrivateIP returns the first private IP address found in a list of
// addresses.
func GetPrivateIP(addresses []net.Addr) (net.IP, error) {
var candidates []net.IP

// Find private IPv4 address
for _, rawAddr := range addresses {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}

if ip.To4() == nil {
continue
}
if !IsPrivateIP(ip.String()) {
continue
}
candidates = append(candidates, ip)
}
numIps := len(candidates)
switch numIps {
case 0:
return nil, fmt.Errorf("No private IP address found")
case 1:
return candidates[0], nil
default:
return nil, fmt.Errorf("Multiple private IPs found. Please configure one.")
}
}

// Returns a random offset between 0 and n
func randomOffset(n int) int {
if n == 0 {
Expand Down Expand Up @@ -305,23 +210,6 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
return
}

// Returns if the given IP is in a private block
func IsPrivateIP(ip_str string) bool {
ip := net.ParseIP(ip_str)
for _, priv := range privateBlocks {
if priv.Contains(ip) {
return true
}
}
return false
}

// Returns if the given IP is in a loopback block
func isLoopbackIP(ip_str string) bool {
ip := net.ParseIP(ip_str)
return loopbackBlock.Contains(ip)
}

// Given a string of the form "host", "host:port",
// "ipv6::addr" or "[ipv6::address]:port",
// return true if the string includes a port.
Expand Down
96 changes: 0 additions & 96 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -1,108 +1,12 @@
package memberlist

import (
"errors"
"fmt"
"net"
"reflect"
"testing"
"time"
)

func TestGetPrivateIP(t *testing.T) {
ip, _, err := net.ParseCIDR("10.1.2.3/32")
if err != nil {
t.Fatalf("failed to parse private cidr: %v", err)
}

pubIP, _, err := net.ParseCIDR("8.8.8.8/32")
if err != nil {
t.Fatalf("failed to parse public cidr: %v", err)
}

tests := []struct {
addrs []net.Addr
expected net.IP
err error
}{
{
addrs: []net.Addr{
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: pubIP,
},
},
expected: ip,
},
{
addrs: []net.Addr{
&net.IPAddr{
IP: pubIP,
},
},
err: errors.New("No private IP address found"),
},
{
addrs: []net.Addr{
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: pubIP,
},
},
err: errors.New("Multiple private IPs found. Please configure one."),
},
}

for _, test := range tests {
ip, err := GetPrivateIP(test.addrs)
switch {
case test.err != nil && err != nil:
if err.Error() != test.err.Error() {
t.Fatalf("unexpected error: %v != %v", test.err, err)
}
case (test.err == nil && err != nil) || (test.err != nil && err == nil):
t.Fatalf("unexpected error: %v != %v", test.err, err)
default:
if !test.expected.Equal(ip) {
t.Fatalf("unexpected ip: %v != %v", ip, test.expected)
}
}
}
}

func TestIsPrivateIP(t *testing.T) {
privateIPs := []string{
"10.1.2.3",
"100.115.110.19",
"127.0.0.1",
"169.254.1.254",
"172.16.45.100",
"192.168.1.1",
}
publicIPs := []string{
"8.8.8.8",
"208.67.222.222",
}

for _, privateIP := range privateIPs {
if !IsPrivateIP(privateIP) {
t.Fatalf("bad")
}
}
for _, publicIP := range publicIPs {
if IsPrivateIP(publicIP) {
t.Fatalf("bad")
}
}
}

func Test_hasPort(t *testing.T) {
cases := []struct {
s string
Expand Down

0 comments on commit 902b55b

Please sign in to comment.