diff --git a/connection_manager.go b/connection_manager.go index 81563a4eb..900db07cb 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - //TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out - newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp) - if !newHostinfo.HandshakeReady { - ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo) - } - - //If this is a static host, we don't need to wait for the HostQueryReply - //We can trigger the handshake right now - if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { - select { - case n.intf.handshakeManager.trigger <- hostinfo.vpnIp: - default: - } - } + n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index e220819f3..e802904e1 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -58,7 +58,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { firewall: &Firewall{}, lightHouse: lh, pki: &PKI{}, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } ifce.pki.cs.Store(cs) @@ -138,7 +138,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { firewall: &Firewall{}, lightHouse: lh, pki: &PKI{}, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } ifce.pki.cs.Store(cs) @@ -258,7 +258,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, disconnectInvalid: true, pki: &PKI{}, diff --git a/connection_state.go b/connection_state.go index 52607496e..f8c31f65a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -23,14 +23,12 @@ type ConnectionState struct { initiator bool messageCounter atomic.Uint64 window *Bits - queueLock sync.Mutex writeLock sync.Mutex ready bool } func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - switch certState.Certificate.Details.Curve { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 diff --git a/control_tester.go b/control_tester.go index 680cd5a77..b786ba383 100644 --- a/control_tester.go +++ b/control_tester.go @@ -165,15 +165,5 @@ func (c *Control) GetCert() *cert.NebulaCertificate { } func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { - hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp) - ixHandshakeStage0(c.f, vpnIp, hostinfo) - - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger the handshake right now - if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { - select { - case c.f.handshakeManager.trigger <- hostinfo.vpnIp: - default: - } - } + c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/handshake_ix.go b/handshake_ix.go index 94f408f89..7e60c7907 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -13,19 +13,12 @@ import ( // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { - // This queries the lighthouse if we don't know a remote for the host - // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send - // more quickly, effect is a quicker handshake. - if hostinfo.remote == nil { - f.lightHouse.QueryServer(vpnIp, f) - } - - err := f.handshakeManager.AddIndexHostInfo(hostinfo) +func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { + err := f.handshakeManager.allocateIndex(hostinfo) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") - return + return false } certState := f.pki.GetCertState() @@ -46,9 +39,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hsBytes, err = hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return + return false } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) @@ -56,9 +49,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return + return false } // We are sending handshake packet 1, so we don't expect to receive @@ -68,6 +61,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hostinfo.HandshakePacket[0] = msg hostinfo.HandshakeReady = true hostinfo.handshakeStart = time.Now() + return true } func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { @@ -428,31 +422,27 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - //TODO: this adds it to the timer wheel in a way that aggressively retries - newHostInfo := f.getOrHandshake(hostinfo.vpnIp) - newHostInfo.Lock() - - // Block the current used address - newHostInfo.remotes = hostinfo.remotes - newHostInfo.remotes.BlockRemote(addr) - - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). - Info("Blocked addresses for handshakes") - - // Swap the packet store to benefit the original intended recipient - hostinfo.ConnectionState.queueLock.Lock() - newHostInfo.packetStore = hostinfo.packetStore - hostinfo.packetStore = []*cachedPacket{} - hostinfo.ConnectionState.queueLock.Unlock() - - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp - f.sendCloseTunnel(hostinfo) - newHostInfo.Unlock() + f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) { + //TODO: this doesnt know if its being added or is being used for caching a packet + // Block the current used address + newHostInfo.remotes = hostinfo.remotes + newHostInfo.remotes.BlockRemote(addr) + + // Get the correct remote list for the host we did handshake with + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + + f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + Info("Blocked addresses for handshakes") + + // Swap the packet store to benefit the original intended recipient + newHostInfo.packetStore = hostinfo.packetStore + hostinfo.packetStore = []*cachedPacket{} + + // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnIp = vpnIp + f.sendCloseTunnel(hostinfo) + }) return true } diff --git a/handshake_manager.go b/handshake_manager.go index e15b79433..e2c2cf548 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -57,13 +57,14 @@ type HandshakeManager struct { messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter + f *Interface l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp trigger chan iputil.VpnIp } -func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[iputil.VpnIp]*HostInfo{}, indexes: map[uint32]*HostInfo{}, @@ -80,7 +81,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -89,25 +90,25 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { case <-ctx.Done(): return case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, f, true) + c.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now, f) + c.NextOutboundHandshakeTimerTick(now) } } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, f, false) + c.handleOutbound(vpnIp, false) } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { hostinfo := c.QueryVpnIp(vpnIp) if hostinfo == nil { return @@ -122,14 +123,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light return } - // Check if we have a handshake packet to transmit yet - if !hostinfo.HandshakeReady { - // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly - // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) - return - } - // If we are out of time, clean up if hostinfo.HandshakeCounter >= c.config.retries { hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)). @@ -143,6 +136,17 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light return } + // Increment the counter to increase our delay, linear backoff + hostinfo.HandshakeCounter++ + + // Check if we have a handshake packet to transmit yet + if !hostinfo.HandshakeReady { + if !ixHandshakeStage0(c.f, hostinfo) { + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + return + } + } + // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -170,7 +174,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIp, f) + c.lightHouse.QueryServer(vpnIp, c.f) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply @@ -214,7 +218,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light relayHostInfo := c.mainHostMap.QueryVpnIp(*relay) if relayHostInfo == nil || relayHostInfo.remote == nil { hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - f.Handshake(*relay) + c.f.Handshake(*relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -222,7 +226,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light switch existingRelay.State { case Established: hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay") - f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. @@ -239,7 +243,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light Error("Failed to marshal Control message to create relay") } else { // This must send over the hostinfo, not over hm.Hosts[ip] - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -274,7 +278,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -287,23 +291,40 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light } } - // Increment the counter to increase our delay, linear backoff - hostinfo.HandshakeCounter++ - // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } -// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it. -func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { - // A write lock is used to avoid having to recheck the map and trading a read lock for a write lock - c.Lock() - defer c.Unlock() +// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present +// The 2nd argument will be true if the hostinfo is ready to transmit traffic +func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) { + // Check the main hostmap and maintain a read lock if our host is not there + hm.mainHostMap.RLock() + if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { + hm.mainHostMap.RUnlock() + // Do not attempt promotion if you are a lighthouse + if !hm.lightHouse.amLighthouse { + h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f) + } + return h, true + } + + defer hm.mainHostMap.RUnlock() + return hm.StartHandshake(vpnIp, cacheCb), false +} - if hostinfo, ok := c.vpnIps[vpnIp]; ok { - // We are already tracking this vpn ip +// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip +func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo { + hm.Lock() + defer hm.Unlock() + + if hostinfo, ok := hm.vpnIps[vpnIp]; ok { + // We are already trying to handshake with this vpn ip + if cacheCb != nil { + cacheCb(hostinfo) + } return hostinfo } @@ -317,10 +338,30 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { }, } - c.vpnIps[vpnIp] = hostinfo - c.metricInitiated.Inc(1) - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) + hm.vpnIps[vpnIp] = hostinfo + hm.metricInitiated.Inc(1) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + + if cacheCb != nil { + cacheCb(hostinfo) + } + + // If this is a static host, we don't need to wait for the HostQueryReply + // We can trigger the handshake right now + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + if !doTrigger { + // Add any calculated remotes, and trigger early handshake if one found + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + } + + if doTrigger { + select { + case hm.trigger <- vpnIp: + default: + } + } + hm.lightHouse.QueryServer(vpnIp, hm.f) return hostinfo } @@ -342,10 +383,10 @@ var ( // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.Lock() - defer c.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() + c.Lock() + defer c.Unlock() // Check if we already have a tunnel with this vpn ip existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] @@ -396,47 +437,47 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the // pendingHostMap. An existing hostinfo is returned if there was one. -func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { - c.Lock() - defer c.Unlock() - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() +func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). Info("New host shadows existing host remoteIndex") } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. - c.unlockedDeleteHostInfo(hostinfo) - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.unlockedDeleteHostInfo(hostinfo) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) } -// AddIndexHostInfo generates a unique localIndexId for this HostInfo +// allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { - c.Lock() - defer c.Unlock() - c.mainHostMap.RLock() - defer c.mainHostMap.RUnlock() +func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { + hm.mainHostMap.RLock() + defer hm.mainHostMap.RUnlock() + hm.Lock() + defer hm.Unlock() for i := 0; i < 32; i++ { - index, err := generateIndex(c.l) + index, err := generateIndex(hm.l) if err != nil { return err } - _, inPending := c.indexes[index] - _, inMain := c.mainHostMap.Indexes[index] + _, inPending := hm.indexes[index] + _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { h.localIndexId = index - c.indexes[index] = h + hm.indexes[index] = h return nil } } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index c6df37d51..d318a9def 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -14,22 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} mainHM := NewHostMap(l, vpncidr, preferredRanges) lh := newTestLighthouse() - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) + blah.NextOutboundHandshakeTimerTick(now) - i := blah.AddVpnIp(ip) - i2 := blah.AddVpnIp(ip) + i := blah.StartHandshake(ip, nil) + i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) i.remotes = NewRemoteList(nil) @@ -44,14 +42,14 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right for i := 1; i <= DefaultHandshakeRetries+1; i++ { now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval) - blah.NextOutboundHandshakeTimerTick(now, mw) + blah.NextOutboundHandshakeTimerTick(now) } // Confirm they are still in the pending index list assert.Contains(t, blah.vpnIps, ip) // Tick 1 more time, a minute will certainly flush it out - blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw) + blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute)) // Confirm they have been removed assert.NotContains(t, blah.vpnIps, ip) diff --git a/hostmap.go b/hostmap.go index 829c7c0bb..f2618c7dc 100644 --- a/hostmap.go +++ b/hostmap.go @@ -456,12 +456,6 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every -// `PromoteEvery` calls to this function for a given host. -func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) *HostInfo { - return hm.queryVpnIp(vpnIp, ifce) -} - func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { @@ -579,7 +573,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } -func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { //TODO: return the error so we can log with more context if len(i.packetStore) < 100 { tempPacket := make([]byte, len(packet)) @@ -608,7 +602,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical - i.ConnectionState.queueLock.Lock() i.HandshakeComplete = true //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. // Clamping it to 2 gets us out of the woods for now @@ -630,7 +623,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { i.remotes.ResetBlockedRemotes() i.packetStore = make([]*cachedPacket, 0) i.ConnectionState.ready = true - i.ConnectionState.queueLock.Unlock() } func (i *HostInfo) GetCert() *cert.NebulaCertificate { diff --git a/inside.go b/inside.go index 6a0e078ab..2219d2bd6 100644 --- a/inside.go +++ b/inside.go @@ -44,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo := f.getOrHandshake(fwPacket.RemoteIP) + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) { + h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -54,23 +57,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } return } - ci := hostinfo.ConnectionState - - if !ci.ready { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - ci.queueLock.Lock() - if !ci.ready { - hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - ci.queueLock.Unlock() - return - } - ci.queueLock.Unlock() + + if !ready { + return } dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -109,62 +103,20 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } func (f *Interface) Handshake(vpnIp iputil.VpnIp) { - f.getOrHandshake(vpnIp) + f.getOrHandshake(vpnIp, nil) } -// getOrHandshake returns nil if the vpnIp is not routable -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { +// getOrHandshake returns nil if the vpnIp is not routable. +// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) { if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) if vpnIp == 0 { - return nil + return nil, false } } - hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) - if hostinfo == nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp) - } - ci := hostinfo.ConnectionState - - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo - } - - // Handshake is not ready, we need to grab the lock now before we start the handshake process - //TODO: move this to handshake manager - hostinfo.Lock() - defer hostinfo.Unlock() - - // Double check, now that we have the lock - ci = hostinfo.ConnectionState - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo - } - - // If we have already created the handshake packet, we don't want to call the function at all. - if !hostinfo.HandshakeReady { - ixHandshakeStage0(f, vpnIp, hostinfo) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //xx_handshakeStage0(f, ip, hostinfo) - - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger the handshake right now - _, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp] - if !doTrigger { - // Add any calculated remotes, and trigger early handshake if one found - doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp) - } - - if doTrigger { - select { - case f.handshakeManager.trigger <- vpnIp: - default: - } - } - } - - return hostinfo + return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -191,7 +143,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo := f.getOrHandshake(vpnIp) + hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) { + h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) + }) + if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", vpnIp). @@ -200,16 +155,8 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu return } - if !hostInfo.ConnectionState.ready { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - hostInfo.ConnectionState.queueLock.Lock() - if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) - hostInfo.ConnectionState.queueLock.Unlock() - return - } - hostInfo.ConnectionState.queueLock.Unlock() + if !ready { + return } f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) @@ -229,7 +176,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } -// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done +// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // to the payload for the ultimate target host, making this a useful method for sending // handshake messages to peers through relay tunnels. // via is the HostInfo through which the message is relayed. diff --git a/main.go b/main.go index 4e8448b84..883e562d6 100644 --- a/main.go +++ b/main.go @@ -235,7 +235,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg messageMetrics: messageMetrics, } - handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) + handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger serveDns := false @@ -302,7 +302,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ifce.RegisterConfigChangeCallbacks(c) ifce.reloadSendRecvError(c) - go handshakeManager.Run(ctx, ifce) + handshakeManager.f = ifce + go handshakeManager.Run(ctx) } // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept diff --git a/relay_manager.go b/relay_manager.go index 8f6365293..224135362 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -244,7 +244,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! - f.getOrHandshake(target) + f.Handshake(target) return } if peer.remote == nil { diff --git a/ssh.go b/ssh.go index c68e0820b..30f9aea3d 100644 --- a/ssh.go +++ b/ssh.go @@ -607,11 +607,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) if addr != nil { hostInfo.SetRemote(addr) } - ifce.getOrHandshake(vpnIp) return w.WriteLine("Created") }