Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify getting a hostinfo or starting a handshake with one #954

Merged
merged 2 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")

//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}

//If this is a static host, we don't need to wait for the HostQueryReply
//We can trigger the handshake right now
if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
}
6 changes: 3 additions & 3 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
Expand Down Expand Up @@ -138,7 +138,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
Expand Down Expand Up @@ -258,7 +258,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
pki: &PKI{},
Expand Down
2 changes: 0 additions & 2 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ type ConnectionState struct {
initiator bool
messageCounter atomic.Uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
}

func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc

switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
Expand Down
12 changes: 1 addition & 11 deletions control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,5 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
ixHandshakeStage0(c.f, vpnIp, hostinfo)

// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}
70 changes: 30 additions & 40 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,12 @@ import (

// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil {
f.lightHouse.QueryServer(vpnIp, f)
}

err := f.handshakeManager.AddIndexHostInfo(hostinfo)
func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
err := f.handshakeManager.allocateIndex(hostinfo)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
return false
}

certState := f.pki.GetCertState()
Expand All @@ -46,19 +39,19 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hsBytes, err = hs.Marshal()

if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
return false
}

h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
ci.messageCounter.Add(1)

msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
return false
}

// We are sending handshake packet 1, so we don't expect to receive
Expand All @@ -68,6 +61,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
return true
}

func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
Expand Down Expand Up @@ -428,31 +422,27 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
f.handshakeManager.DeleteHostInfo(hostinfo)

// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
newHostInfo.Lock()

// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)

// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()

// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)

// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}

// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
})

return true
}
Expand Down
Loading