Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F use sockaddr #108

Merged
merged 2 commits into from
Feb 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 24 additions & 33 deletions memberlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/hashicorp/go-multierror"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/miekg/dns"
)

Expand Down Expand Up @@ -326,7 +327,7 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
// as if we received an alive notification our own network channel for
// ourself.
func (m *Memberlist) setAlive() error {
var advertiseAddr []byte
var advertiseAddr net.IP
var advertisePort int
if m.config.AdvertiseAddr != "" {
// If AdvertiseAddr is not empty, then advertise
Expand All @@ -345,42 +346,21 @@ func (m *Memberlist) setAlive() error {
advertisePort = m.config.AdvertisePort
} else {
if m.config.BindAddr == "0.0.0.0" {
// Otherwise, if we're not bound to a specific IP,
//let's list the interfaces on this machine and use
// the first private IP we find.
addresses, err := net.InterfaceAddrs()
// Otherwise, if we're not bound to a specific IP, let's use a suitable
// private IP address.
var err error
m.config.AdvertiseAddr, err = sockaddr.GetPrivateIP()
if err != nil {
return fmt.Errorf("Failed to get interface addresses! Err: %v", err)
return fmt.Errorf("Failed to get interface addresses: %v", err)
}

// Find private IPv4 address
for _, rawAddr := range addresses {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}

if ip.To4() == nil {
continue
}
if !IsPrivateIP(ip.String()) {
continue
}

advertiseAddr = ip
break
if m.config.AdvertiseAddr == "" {
return fmt.Errorf("No private IP address found, and explicit IP not provided")
}

// Failed to find private IP, error
advertiseAddr = net.ParseIP(m.config.AdvertiseAddr)
if advertiseAddr == nil {
return fmt.Errorf("No private IP address found, and explicit IP not provided")
return fmt.Errorf("Failed to parse advertise address: %q", m.config.AdvertiseAddr)
}

} else {
// Use the IP that we're bound to.
addr := m.tcpListener.Addr().(*net.TCPAddr)
Expand All @@ -392,8 +372,19 @@ func (m *Memberlist) setAlive() error {
}

// Check if this is a public address without encryption
addrStr := net.IP(advertiseAddr).String()
if !IsPrivateIP(addrStr) && !isLoopbackIP(addrStr) && !m.config.EncryptionEnabled() {
ipAddr, err := sockaddr.NewIPAddr(advertiseAddr.String())
if err != nil {
return fmt.Errorf("Failed to parse interface addresses: %v", err)
}

ifAddrs := []sockaddr.IfAddr{
sockaddr.IfAddr{
SockAddr: ipAddr,
},
}

_, publicIfs, err := sockaddr.IfByRFC("6890", ifAddrs)
if len(publicIfs) > 0 && !m.config.EncryptionEnabled() {
m.logger.Printf("[WARN] memberlist: Binding to public address without encryption!")
}

Expand Down
112 changes: 0 additions & 112 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"math"
"math/rand"
"net"
"strings"
"time"

Expand All @@ -22,19 +21,6 @@ import (
// while the 65th will triple it.
const pushPullScaleThreshold = 32

/*
* Contains an entry for each private block:
* 10.0.0.0/8
* 100.64.0.0/10
* 127.0.0.0/8
* 169.254.0.0/16
* 172.16.0.0/12
* 192.168.0.0/16
*/
var privateBlocks []*net.IPNet

var loopbackBlock *net.IPNet

const (
// Constant litWidth 2-8
lzwLitWidth = 8
Expand All @@ -43,51 +29,6 @@ const (
func init() {
// Seed the random number generator
rand.Seed(time.Now().UnixNano())

// Add each private block
privateBlocks = make([]*net.IPNet, 6)

_, block, err := net.ParseCIDR("10.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[0] = block

_, block, err = net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[1] = block

_, block, err = net.ParseCIDR("127.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[2] = block

_, block, err = net.ParseCIDR("169.254.0.0/16")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[3] = block

_, block, err = net.ParseCIDR("172.16.0.0/12")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[4] = block

_, block, err = net.ParseCIDR("192.168.0.0/16")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
privateBlocks[5] = block

_, block, err = net.ParseCIDR("127.0.0.0/8")
if err != nil {
panic(fmt.Sprintf("Bad cidr. Got %v", err))
}
loopbackBlock = block
}

// Decode reverses the encode operation on a byte slice input
Expand All @@ -108,42 +49,6 @@ func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
return buf, err
}

// GetPrivateIP returns the first private IP address found in a list of
// addresses.
func GetPrivateIP(addresses []net.Addr) (net.IP, error) {
var candidates []net.IP

// Find private IPv4 address
for _, rawAddr := range addresses {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}

if ip.To4() == nil {
continue
}
if !IsPrivateIP(ip.String()) {
continue
}
candidates = append(candidates, ip)
}
numIps := len(candidates)
switch numIps {
case 0:
return nil, fmt.Errorf("No private IP address found")
case 1:
return candidates[0], nil
default:
return nil, fmt.Errorf("Multiple private IPs found. Please configure one.")
}
}

// Returns a random offset between 0 and n
func randomOffset(n int) int {
if n == 0 {
Expand Down Expand Up @@ -305,23 +210,6 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
return
}

// Returns if the given IP is in a private block
func IsPrivateIP(ip_str string) bool {
ip := net.ParseIP(ip_str)
for _, priv := range privateBlocks {
if priv.Contains(ip) {
return true
}
}
return false
}

// Returns if the given IP is in a loopback block
func isLoopbackIP(ip_str string) bool {
ip := net.ParseIP(ip_str)
return loopbackBlock.Contains(ip)
}

// Given a string of the form "host", "host:port",
// "ipv6::addr" or "[ipv6::address]:port",
// return true if the string includes a port.
Expand Down
96 changes: 0 additions & 96 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -1,108 +1,12 @@
package memberlist

import (
"errors"
"fmt"
"net"
"reflect"
"testing"
"time"
)

func TestGetPrivateIP(t *testing.T) {
ip, _, err := net.ParseCIDR("10.1.2.3/32")
if err != nil {
t.Fatalf("failed to parse private cidr: %v", err)
}

pubIP, _, err := net.ParseCIDR("8.8.8.8/32")
if err != nil {
t.Fatalf("failed to parse public cidr: %v", err)
}

tests := []struct {
addrs []net.Addr
expected net.IP
err error
}{
{
addrs: []net.Addr{
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: pubIP,
},
},
expected: ip,
},
{
addrs: []net.Addr{
&net.IPAddr{
IP: pubIP,
},
},
err: errors.New("No private IP address found"),
},
{
addrs: []net.Addr{
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: ip,
},
&net.IPAddr{
IP: pubIP,
},
},
err: errors.New("Multiple private IPs found. Please configure one."),
},
}

for _, test := range tests {
ip, err := GetPrivateIP(test.addrs)
switch {
case test.err != nil && err != nil:
if err.Error() != test.err.Error() {
t.Fatalf("unexpected error: %v != %v", test.err, err)
}
case (test.err == nil && err != nil) || (test.err != nil && err == nil):
t.Fatalf("unexpected error: %v != %v", test.err, err)
default:
if !test.expected.Equal(ip) {
t.Fatalf("unexpected ip: %v != %v", ip, test.expected)
}
}
}
}

func TestIsPrivateIP(t *testing.T) {
privateIPs := []string{
"10.1.2.3",
"100.115.110.19",
"127.0.0.1",
"169.254.1.254",
"172.16.45.100",
"192.168.1.1",
}
publicIPs := []string{
"8.8.8.8",
"208.67.222.222",
}

for _, privateIP := range privateIPs {
if !IsPrivateIP(privateIP) {
t.Fatalf("bad")
}
}
for _, publicIP := range publicIPs {
if IsPrivateIP(publicIP) {
t.Fatalf("bad")
}
}
}

func Test_hasPort(t *testing.T) {
cases := []struct {
s string
Expand Down