Skip to content

Commit

Permalink
move session negotiation states to model to avoid circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
ainghazal committed Feb 13, 2024
1 parent 5c1f86a commit b16e276
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 107 deletions.
4 changes: 2 additions & 2 deletions internal/controlchannel/controlchannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ func (ws *workersState) moveUpWorker() {
// even if after the first key generation we receive two SOFT_RESET requests
// back to back.

if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS {
if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS {
continue
}
ws.sessionManager.SetNegotiationState(session.S_INITIAL)
ws.sessionManager.SetNegotiationState(model.S_INITIAL)
// TODO(ainghazal): revisit this step.
// when we implement key rotation. OpenVPN has
// the concept of a "lame duck", i.e., the
Expand Down
2 changes: 1 addition & 1 deletion internal/datachannel/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (ws *workersState) keyWorker(firstKeyReady chan<- any) {
ws.logger.Warnf("error on key derivation: %v", err)
continue
}
ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS)
ws.sessionManager.SetNegotiationState(model.S_GENERATED_KEYS)
once.Do(func() {
close(firstKeyReady)
})
Expand Down
59 changes: 59 additions & 0 deletions internal/model/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package model

// NegotiationState is the state of the session negotiation.
type NegotiationState int

const (
// S_ERROR means there was some form of protocol error.
S_ERROR = NegotiationState(iota) - 1

Check warning on line 8 in internal/model/session.go

View workflow job for this annotation

GitHub Actions / lint

don't use ALL_CAPS in Go names; use CamelCase

// S_UNDER is the undefined state.

Check warning on line 10 in internal/model/session.go

View workflow job for this annotation

GitHub Actions / lint

comment on exported const S_UNDEF should be of the form "S_UNDEF ..."
S_UNDEF

Check warning on line 11 in internal/model/session.go

View workflow job for this annotation

GitHub Actions / lint

don't use ALL_CAPS in Go names; use CamelCase

// S_INITIAL means we're ready to begin the three-way handshake.
S_INITIAL

Check warning on line 14 in internal/model/session.go

View workflow job for this annotation

GitHub Actions / lint

don't use ALL_CAPS in Go names; use CamelCase

// S_PRE_START means we're waiting for acknowledgment from the remote.
S_PRE_START

Check warning on line 17 in internal/model/session.go

View workflow job for this annotation

GitHub Actions / lint

don't use ALL_CAPS in Go names; use CamelCase

// S_START means we've done the three-way handshake.
S_START

// S_SENT_KEY means we have sent the local part of the key_source2 random material.
S_SENT_KEY

// S_GOT_KEY means we have got the remote part of key_source2.
S_GOT_KEY

// S_ACTIVE means the control channel was established.
S_ACTIVE

// S_GENERATED_KEYS means the data channel keys have been generated.
S_GENERATED_KEYS
)

// String maps a [SessionNegotiationState] to a string.
func (sns NegotiationState) String() string {
switch sns {
case S_UNDEF:
return "S_UNDEF"
case S_INITIAL:
return "S_INITIAL"
case S_PRE_START:
return "S_PRE_START"
case S_START:
return "S_START"
case S_SENT_KEY:
return "S_SENT_KEY"
case S_GOT_KEY:
return "S_GOT_KEY"
case S_ACTIVE:
return "S_ACTIVE"
case S_GENERATED_KEYS:
return "S_GENERATED_KEYS"
case S_ERROR:
return "S_ERROR"
default:
return "S_INVALID"
}
}
16 changes: 8 additions & 8 deletions internal/model/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ type HandshakeTracer interface {
TimeNow() time.Time

// OnStateChange is called for each transition in the state machine.
OnStateChange(state int)
OnStateChange(state NegotiationState)

// OnIncomingPacket is called when a packet is received.
OnIncomingPacket(packet *Packet, stage int)
OnIncomingPacket(packet *Packet, stage NegotiationState)

// OnOutgoingPacket is called when a packet is about to be sent.
OnOutgoingPacket(packet *Packet, stage int, retries int)
OnOutgoingPacket(packet *Packet, stage NegotiationState, retries int)

// OnDroppedPacket is called whenever a packet is dropped (in/out)
OnDroppedPacket(direction Direction, stage int, packet *Packet)
OnDroppedPacket(direction Direction, stage NegotiationState, packet *Packet)
}

// Direction is one of two directions on a packet.
Expand Down Expand Up @@ -57,16 +57,16 @@ type dummyTracer struct{}
func (dt *dummyTracer) TimeNow() time.Time { return time.Now() }

// OnStateChange is called for each transition in the state machine.
func (dt *dummyTracer) OnStateChange(int) {}
func (dt *dummyTracer) OnStateChange(NegotiationState) {}

// OnIncomingPacket is called when a packet is received.
func (dt *dummyTracer) OnIncomingPacket(*Packet, int) {}
func (dt *dummyTracer) OnIncomingPacket(*Packet, NegotiationState) {}

// OnOutgoingPacket is called when a packet is about to be sent.
func (dt *dummyTracer) OnOutgoingPacket(*Packet, int, int) {}
func (dt *dummyTracer) OnOutgoingPacket(*Packet, NegotiationState, int) {}

// OnDroppedPacket is called whenever a packet is dropped (in/out)
func (dt *dummyTracer) OnDroppedPacket(Direction, int, *Packet) {}
func (dt *dummyTracer) OnDroppedPacket(Direction, NegotiationState, *Packet) {}

// Assert that dummyTracer implements [model.HandshakeTracer].
var _ HandshakeTracer = &dummyTracer{}
10 changes: 5 additions & 5 deletions internal/packetmuxer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (ws *workersState) startHardReset() error {
ws.hardResetCount++

// reset the state to become initial again.
ws.sessionManager.SetNegotiationState(session.S_PRE_START)
ws.sessionManager.SetNegotiationState(model.S_PRE_START)

// emit a CONTROL_HARD_RESET_CLIENT_V2 pkt
packet := ws.sessionManager.NewHardResetPacket()
Expand All @@ -223,7 +223,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error {
}

// handle the case where we're performing a HARD_RESET
if ws.sessionManager.NegotiationState() == session.S_PRE_START &&
if ws.sessionManager.NegotiationState() == model.S_PRE_START &&
packet.Opcode == model.P_CONTROL_HARD_RESET_SERVER_V2 {
packet.Log(ws.logger, model.DirectionIncoming)
ws.hardResetTicker.Stop()
Expand All @@ -238,7 +238,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error {
return workers.ErrShutdown
}
} else {
if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS {
if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS {
// A well-behaved server should not send us data packets
// before we have a working session. Under normal operations, the
// connection in the client side should pick a different port,
Expand Down Expand Up @@ -269,7 +269,7 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error {
ws.sessionManager.SetRemoteSessionID(packet.LocalSessionID)

// advance the state
ws.sessionManager.SetNegotiationState(session.S_START)
ws.sessionManager.SetNegotiationState(model.S_START)

// pass the packet up so that we can ack it properly
select {
Expand Down Expand Up @@ -302,7 +302,7 @@ func (ws *workersState) serializeAndEmit(packet *model.Packet) error {

ws.tracer.OnOutgoingPacket(
packet,
int(ws.sessionManager.NegotiationState()),
ws.sessionManager.NegotiationState(),
ws.hardResetCount,
)

Expand Down
4 changes: 2 additions & 2 deletions internal/reliabletransport/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (ws *workersState) moveUpWorker() {
// or POSSIBLY BLOCK waiting for notifications
select {
case packet := <-ws.muxerToReliable:
ws.tracer.OnIncomingPacket(packet, int(ws.sessionManager.NegotiationState()))
ws.tracer.OnIncomingPacket(packet, ws.sessionManager.NegotiationState())

if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 {
// the hard reset has already been logged by the layer below
Expand Down Expand Up @@ -63,7 +63,7 @@ func (ws *workersState) moveUpWorker() {
// TODO: add reason
ws.tracer.OnDroppedPacket(
model.DirectionIncoming,
int(ws.sessionManager.NegotiationState()),
ws.sessionManager.NegotiationState(),
packet)
ws.logger.Debugf("Dropping packet: %v", packet.ID)
continue
Expand Down
4 changes: 2 additions & 2 deletions internal/reliabletransport/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ func (ws *workersState) blockOnTryingToSend(sender *reliableSender, ticker *time
p.packet.Log(ws.logger, model.DirectionOutgoing)
ws.tracer.OnOutgoingPacket(
p.packet,
int(ws.sessionManager.NegotiationState()),
int(p.retries),
ws.sessionManager.NegotiationState(),
p.retries,
)

select {
Expand Down
68 changes: 5 additions & 63 deletions internal/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,6 @@ import (
"github.com/ooni/minivpn/internal/runtimex"
)

// SessionNegotiationState is the state of the session negotiation.
type SessionNegotiationState int

const (
// S_ERROR means there was some form of protocol error.
S_ERROR = SessionNegotiationState(iota) - 1

// S_UNDER is the undefined state.
S_UNDEF

// S_INITIAL means we're ready to begin the three-way handshake.
S_INITIAL

// S_PRE_START means we're waiting for acknowledgment from the remote.
S_PRE_START

// S_START means we've done the three-way handshake.
S_START

// S_SENT_KEY means we have sent the local part of the key_source2 random material.
S_SENT_KEY

// S_GOT_KEY means we have got the remote part of key_source2.
S_GOT_KEY

// S_ACTIVE means the control channel was established.
S_ACTIVE

// S_GENERATED_KEYS means the data channel keys have been generated.
S_GENERATED_KEYS
)

// String maps a [SessionNegotiationState] to a string.
func (sns SessionNegotiationState) String() string {
switch sns {
case S_UNDEF:
return "S_UNDEF"
case S_INITIAL:
return "S_INITIAL"
case S_PRE_START:
return "S_PRE_START"
case S_START:
return "S_START"
case S_SENT_KEY:
return "S_SENT_KEY"
case S_GOT_KEY:
return "S_GOT_KEY"
case S_ACTIVE:
return "S_ACTIVE"
case S_GENERATED_KEYS:
return "S_GENERATED_KEYS"
case S_ERROR:
return "S_ERROR"
default:
return "S_INVALID"
}
}

// Manager manages the session. The zero value is invalid. Please, construct
// using [NewManager]. This struct is concurrency safe.
type Manager struct {
Expand All @@ -81,7 +23,7 @@ type Manager struct {
localSessionID model.SessionID
logger model.Logger
mu sync.Mutex
negState SessionNegotiationState
negState model.NegotiationState
remoteSessionID optional.Value[model.SessionID]
tunnelInfo model.TunnelInfo
tracer model.HandshakeTracer
Expand Down Expand Up @@ -263,20 +205,20 @@ func (m *Manager) localControlPacketIDLocked() (model.PacketID, error) {
}

// NegotiationState returns the state of the negotiation.
func (m *Manager) NegotiationState() SessionNegotiationState {
func (m *Manager) NegotiationState() model.NegotiationState {
defer m.mu.Unlock()
m.mu.Lock()
return m.negState
}

// SetNegotiationState sets the state of the negotiation.
func (m *Manager) SetNegotiationState(sns SessionNegotiationState) {
func (m *Manager) SetNegotiationState(sns model.NegotiationState) {
defer m.mu.Unlock()
m.mu.Lock()
m.logger.Infof("[@] %s -> %s", m.negState, sns)
m.tracer.OnStateChange(int(sns))
m.tracer.OnStateChange(sns)
m.negState = sns
if sns == S_GENERATED_KEYS {
if sns == model.S_GENERATED_KEYS {
m.Ready <- true
}
}
Expand Down
6 changes: 3 additions & 3 deletions internal/tlssession/tlssession.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha
errorch <- err
return
}
ws.sessionManager.SetNegotiationState(session.S_SENT_KEY)
ws.sessionManager.SetNegotiationState(model.S_SENT_KEY)

// read the server's keySource and options
remoteKey, serverOptions, err := ws.recvAuthReplyMessage(tlsConn)
Expand All @@ -174,7 +174,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha

// add the remote key to the active key
activeKey.AddRemoteKey(remoteKey)
ws.sessionManager.SetNegotiationState(session.S_GOT_KEY)
ws.sessionManager.SetNegotiationState(model.S_GOT_KEY)

// send the push request
if err := ws.sendPushRequestMessage(tlsConn); err != nil {
Expand All @@ -193,7 +193,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha
ws.sessionManager.UpdateTunnelInfo(tinfo)

// progress to the ACTIVE state
ws.sessionManager.SetNegotiationState(session.S_ACTIVE)
ws.sessionManager.SetNegotiationState(model.S_ACTIVE)

// notify the datachannel that we've got a key pair ready to use
ws.keyUp <- activeKey
Expand Down
Loading

0 comments on commit b16e276

Please sign in to comment.