diff --git a/membership.go b/membership.go index 51f4290..cd6e4b6 100644 --- a/membership.go +++ b/membership.go @@ -284,6 +284,31 @@ func guessMulticastAddress() string { return multicastAddress } +// getListenInterface gets the network interface for the listen IP +func getListenInterface() (*net.Interface, error) { + ifaces, err := net.Interfaces() + if err == nil { + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + logfWarn("Can not get addresses of interface %s", iface.Name) + continue + } + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + continue + } + if ip.String() == GetListenIP().String() { + logfInfo("Found interface with listen IP: %s", iface.Name) + return &iface, nil + } + } + } + } + return nil, errors.New("Could not determine the interface of the listen IP address") +} + // Returns a random slice of valid ping/forward request targets; i.e., not // this node, and not dead. func getTargetNodes(count int, exclude ...*Node) []*Node { @@ -346,7 +371,11 @@ func listenUDPMulticast(port int) error { } /* Now listen at selected port */ - c, err := net.ListenMulticastUDP("udp", nil, listenAddress) + iface, err := getListenInterface() + if err != nil { + return err + } + c, err := net.ListenMulticastUDP("udp", iface, listenAddress) if err != nil { return err } @@ -402,14 +431,16 @@ func multicastAnnounce(addr string) error { logError(err) return err } - + laddr := &net.UDPAddr{ + IP: GetListenIP(), + Port: 0, + } for { - c, err := net.DialUDP("udp", nil, address) + c, err := net.DialUDP("udp", laddr, address) if err != nil { logError(err) return err } - // Compose and send the multicast announcement msgBytes := encodeMulticastAnnounceBytes() _, err = c.Write(msgBytes) @@ -418,7 +449,7 @@ func multicastAnnounce(addr string) error { return err } - logfTrace("Sent announcement multicast to %v", fullAddr) + logfTrace("Sent announcement multicast from %v to %v", laddr, fullAddr) if GetMulticastAnnounceIntervalSeconds() > 0 { time.Sleep(time.Second * time.Duration(GetMulticastAnnounceIntervalSeconds()))