From b16e276904ae7614c6ac8077fa9446f21ba26a5b Mon Sep 17 00:00:00 2001 From: Ain Ghazal Date: Tue, 13 Feb 2024 13:24:05 +0100 Subject: [PATCH] move session negotiation states to model to avoid circular import --- internal/controlchannel/controlchannel.go | 4 +- internal/datachannel/service.go | 2 +- internal/model/session.go | 59 ++++++++++++++++++++ internal/model/trace.go | 16 +++--- internal/packetmuxer/service.go | 10 ++-- internal/reliabletransport/receiver.go | 4 +- internal/reliabletransport/sender.go | 4 +- internal/session/manager.go | 68 ++--------------------- internal/tlssession/tlssession.go | 6 +- pkg/tracex/trace.go | 39 ++++++------- 10 files changed, 105 insertions(+), 107 deletions(-) create mode 100644 internal/model/session.go diff --git a/internal/controlchannel/controlchannel.go b/internal/controlchannel/controlchannel.go index 95648efe..b13a6087 100644 --- a/internal/controlchannel/controlchannel.go +++ b/internal/controlchannel/controlchannel.go @@ -92,10 +92,10 @@ func (ws *workersState) moveUpWorker() { // even if after the first key generation we receive two SOFT_RESET requests // back to back. - if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS { continue } - ws.sessionManager.SetNegotiationState(session.S_INITIAL) + ws.sessionManager.SetNegotiationState(model.S_INITIAL) // TODO(ainghazal): revisit this step. // when we implement key rotation. OpenVPN has // the concept of a "lame duck", i.e., the diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index 32d81aea..23354f87 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -192,7 +192,7 @@ func (ws *workersState) keyWorker(firstKeyReady chan<- any) { ws.logger.Warnf("error on key derivation: %v", err) continue } - ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS) + ws.sessionManager.SetNegotiationState(model.S_GENERATED_KEYS) once.Do(func() { close(firstKeyReady) }) diff --git a/internal/model/session.go b/internal/model/session.go new file mode 100644 index 00000000..5e181737 --- /dev/null +++ b/internal/model/session.go @@ -0,0 +1,59 @@ +package model + +// NegotiationState is the state of the session negotiation. +type NegotiationState int + +const ( + // S_ERROR means there was some form of protocol error. + S_ERROR = NegotiationState(iota) - 1 + + // S_UNDER is the undefined state. + S_UNDEF + + // S_INITIAL means we're ready to begin the three-way handshake. + S_INITIAL + + // S_PRE_START means we're waiting for acknowledgment from the remote. + S_PRE_START + + // S_START means we've done the three-way handshake. + S_START + + // S_SENT_KEY means we have sent the local part of the key_source2 random material. + S_SENT_KEY + + // S_GOT_KEY means we have got the remote part of key_source2. + S_GOT_KEY + + // S_ACTIVE means the control channel was established. + S_ACTIVE + + // S_GENERATED_KEYS means the data channel keys have been generated. + S_GENERATED_KEYS +) + +// String maps a [SessionNegotiationState] to a string. +func (sns NegotiationState) String() string { + switch sns { + case S_UNDEF: + return "S_UNDEF" + case S_INITIAL: + return "S_INITIAL" + case S_PRE_START: + return "S_PRE_START" + case S_START: + return "S_START" + case S_SENT_KEY: + return "S_SENT_KEY" + case S_GOT_KEY: + return "S_GOT_KEY" + case S_ACTIVE: + return "S_ACTIVE" + case S_GENERATED_KEYS: + return "S_GENERATED_KEYS" + case S_ERROR: + return "S_ERROR" + default: + return "S_INVALID" + } +} diff --git a/internal/model/trace.go b/internal/model/trace.go index fc564ea7..ff14cfaf 100644 --- a/internal/model/trace.go +++ b/internal/model/trace.go @@ -12,16 +12,16 @@ type HandshakeTracer interface { TimeNow() time.Time // OnStateChange is called for each transition in the state machine. - OnStateChange(state int) + OnStateChange(state NegotiationState) // OnIncomingPacket is called when a packet is received. - OnIncomingPacket(packet *Packet, stage int) + OnIncomingPacket(packet *Packet, stage NegotiationState) // OnOutgoingPacket is called when a packet is about to be sent. - OnOutgoingPacket(packet *Packet, stage int, retries int) + OnOutgoingPacket(packet *Packet, stage NegotiationState, retries int) // OnDroppedPacket is called whenever a packet is dropped (in/out) - OnDroppedPacket(direction Direction, stage int, packet *Packet) + OnDroppedPacket(direction Direction, stage NegotiationState, packet *Packet) } // Direction is one of two directions on a packet. @@ -57,16 +57,16 @@ type dummyTracer struct{} func (dt *dummyTracer) TimeNow() time.Time { return time.Now() } // OnStateChange is called for each transition in the state machine. -func (dt *dummyTracer) OnStateChange(int) {} +func (dt *dummyTracer) OnStateChange(NegotiationState) {} // OnIncomingPacket is called when a packet is received. -func (dt *dummyTracer) OnIncomingPacket(*Packet, int) {} +func (dt *dummyTracer) OnIncomingPacket(*Packet, NegotiationState) {} // OnOutgoingPacket is called when a packet is about to be sent. -func (dt *dummyTracer) OnOutgoingPacket(*Packet, int, int) {} +func (dt *dummyTracer) OnOutgoingPacket(*Packet, NegotiationState, int) {} // OnDroppedPacket is called whenever a packet is dropped (in/out) -func (dt *dummyTracer) OnDroppedPacket(Direction, int, *Packet) {} +func (dt *dummyTracer) OnDroppedPacket(Direction, NegotiationState, *Packet) {} // Assert that dummyTracer implements [model.HandshakeTracer]. var _ HandshakeTracer = &dummyTracer{} diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index 7f3e3222..050fe035 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -199,7 +199,7 @@ func (ws *workersState) startHardReset() error { ws.hardResetCount++ // reset the state to become initial again. - ws.sessionManager.SetNegotiationState(session.S_PRE_START) + ws.sessionManager.SetNegotiationState(model.S_PRE_START) // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt packet := ws.sessionManager.NewHardResetPacket() @@ -223,7 +223,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { } // handle the case where we're performing a HARD_RESET - if ws.sessionManager.NegotiationState() == session.S_PRE_START && + if ws.sessionManager.NegotiationState() == model.S_PRE_START && packet.Opcode == model.P_CONTROL_HARD_RESET_SERVER_V2 { packet.Log(ws.logger, model.DirectionIncoming) ws.hardResetTicker.Stop() @@ -238,7 +238,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { return workers.ErrShutdown } } else { - if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS { // A well-behaved server should not send us data packets // before we have a working session. Under normal operations, the // connection in the client side should pick a different port, @@ -269,7 +269,7 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error { ws.sessionManager.SetRemoteSessionID(packet.LocalSessionID) // advance the state - ws.sessionManager.SetNegotiationState(session.S_START) + ws.sessionManager.SetNegotiationState(model.S_START) // pass the packet up so that we can ack it properly select { @@ -302,7 +302,7 @@ func (ws *workersState) serializeAndEmit(packet *model.Packet) error { ws.tracer.OnOutgoingPacket( packet, - int(ws.sessionManager.NegotiationState()), + ws.sessionManager.NegotiationState(), ws.hardResetCount, ) diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index c3249d4c..91dc37e8 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -31,7 +31,7 @@ func (ws *workersState) moveUpWorker() { // or POSSIBLY BLOCK waiting for notifications select { case packet := <-ws.muxerToReliable: - ws.tracer.OnIncomingPacket(packet, int(ws.sessionManager.NegotiationState())) + ws.tracer.OnIncomingPacket(packet, ws.sessionManager.NegotiationState()) if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 { // the hard reset has already been logged by the layer below @@ -63,7 +63,7 @@ func (ws *workersState) moveUpWorker() { // TODO: add reason ws.tracer.OnDroppedPacket( model.DirectionIncoming, - int(ws.sessionManager.NegotiationState()), + ws.sessionManager.NegotiationState(), packet) ws.logger.Debugf("Dropping packet: %v", packet.ID) continue diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index c3dae808..0fb14fde 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -86,8 +86,8 @@ func (ws *workersState) blockOnTryingToSend(sender *reliableSender, ticker *time p.packet.Log(ws.logger, model.DirectionOutgoing) ws.tracer.OnOutgoingPacket( p.packet, - int(ws.sessionManager.NegotiationState()), - int(p.retries), + ws.sessionManager.NegotiationState(), + p.retries, ) select { diff --git a/internal/session/manager.go b/internal/session/manager.go index 0bffcaf2..4a4ec974 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -13,64 +13,6 @@ import ( "github.com/ooni/minivpn/internal/runtimex" ) -// SessionNegotiationState is the state of the session negotiation. -type SessionNegotiationState int - -const ( - // S_ERROR means there was some form of protocol error. - S_ERROR = SessionNegotiationState(iota) - 1 - - // S_UNDER is the undefined state. - S_UNDEF - - // S_INITIAL means we're ready to begin the three-way handshake. - S_INITIAL - - // S_PRE_START means we're waiting for acknowledgment from the remote. - S_PRE_START - - // S_START means we've done the three-way handshake. - S_START - - // S_SENT_KEY means we have sent the local part of the key_source2 random material. - S_SENT_KEY - - // S_GOT_KEY means we have got the remote part of key_source2. - S_GOT_KEY - - // S_ACTIVE means the control channel was established. - S_ACTIVE - - // S_GENERATED_KEYS means the data channel keys have been generated. - S_GENERATED_KEYS -) - -// String maps a [SessionNegotiationState] to a string. -func (sns SessionNegotiationState) String() string { - switch sns { - case S_UNDEF: - return "S_UNDEF" - case S_INITIAL: - return "S_INITIAL" - case S_PRE_START: - return "S_PRE_START" - case S_START: - return "S_START" - case S_SENT_KEY: - return "S_SENT_KEY" - case S_GOT_KEY: - return "S_GOT_KEY" - case S_ACTIVE: - return "S_ACTIVE" - case S_GENERATED_KEYS: - return "S_GENERATED_KEYS" - case S_ERROR: - return "S_ERROR" - default: - return "S_INVALID" - } -} - // Manager manages the session. The zero value is invalid. Please, construct // using [NewManager]. This struct is concurrency safe. type Manager struct { @@ -81,7 +23,7 @@ type Manager struct { localSessionID model.SessionID logger model.Logger mu sync.Mutex - negState SessionNegotiationState + negState model.NegotiationState remoteSessionID optional.Value[model.SessionID] tunnelInfo model.TunnelInfo tracer model.HandshakeTracer @@ -263,20 +205,20 @@ func (m *Manager) localControlPacketIDLocked() (model.PacketID, error) { } // NegotiationState returns the state of the negotiation. -func (m *Manager) NegotiationState() SessionNegotiationState { +func (m *Manager) NegotiationState() model.NegotiationState { defer m.mu.Unlock() m.mu.Lock() return m.negState } // SetNegotiationState sets the state of the negotiation. -func (m *Manager) SetNegotiationState(sns SessionNegotiationState) { +func (m *Manager) SetNegotiationState(sns model.NegotiationState) { defer m.mu.Unlock() m.mu.Lock() m.logger.Infof("[@] %s -> %s", m.negState, sns) - m.tracer.OnStateChange(int(sns)) + m.tracer.OnStateChange(sns) m.negState = sns - if sns == S_GENERATED_KEYS { + if sns == model.S_GENERATED_KEYS { m.Ready <- true } } diff --git a/internal/tlssession/tlssession.go b/internal/tlssession/tlssession.go index def3c809..1fb88d79 100644 --- a/internal/tlssession/tlssession.go +++ b/internal/tlssession/tlssession.go @@ -156,7 +156,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha errorch <- err return } - ws.sessionManager.SetNegotiationState(session.S_SENT_KEY) + ws.sessionManager.SetNegotiationState(model.S_SENT_KEY) // read the server's keySource and options remoteKey, serverOptions, err := ws.recvAuthReplyMessage(tlsConn) @@ -174,7 +174,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha // add the remote key to the active key activeKey.AddRemoteKey(remoteKey) - ws.sessionManager.SetNegotiationState(session.S_GOT_KEY) + ws.sessionManager.SetNegotiationState(model.S_GOT_KEY) // send the push request if err := ws.sendPushRequestMessage(tlsConn); err != nil { @@ -193,7 +193,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha ws.sessionManager.UpdateTunnelInfo(tinfo) // progress to the ACTIVE state - ws.sessionManager.SetNegotiationState(session.S_ACTIVE) + ws.sessionManager.SetNegotiationState(model.S_ACTIVE) // notify the datachannel that we've got a key pair ready to use ws.keyUp <- activeKey diff --git a/pkg/tracex/trace.go b/pkg/tracex/trace.go index 22fd2339..09979fa1 100644 --- a/pkg/tracex/trace.go +++ b/pkg/tracex/trace.go @@ -9,7 +9,6 @@ import ( "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/optional" - "github.com/ooni/minivpn/internal/session" ) const ( @@ -41,8 +40,8 @@ func (e HandshakeEventType) String() string { } } -// event is a handshake event collected by this [model.HandshakeTracer]. -type event struct { +// Event is a handshake event collected by this [model.HandshakeTracer]. +type Event struct { // EventType is the type for this event. EventType string `json:"operation"` @@ -59,8 +58,10 @@ type event struct { LoggedPacket optional.Value[LoggedPacket] `json:"packet"` } -func newEvent(etype HandshakeEventType, st session.SessionNegotiationState, t time.Time, t0 time.Time) *event { - return &event{ +type NegotiationState = model.NegotiationState + +func newEvent(etype HandshakeEventType, st NegotiationState, t time.Time, t0 time.Time) *Event { + return &Event{ EventType: etype.String(), Stage: st.String()[2:], AtTime: t.Sub(t0).Seconds(), @@ -72,7 +73,7 @@ func newEvent(etype HandshakeEventType, st session.SessionNegotiationState, t ti // Tracer implements [model.HandshakeTracer]. type Tracer struct { // events is the array of handshake events. - events []*event + events []*Event // mu guards access to the events. mu sync.Mutex @@ -94,55 +95,51 @@ func (t *Tracer) TimeNow() time.Time { } // OnStateChange is called for each transition in the state machine. -func (t *Tracer) OnStateChange(state int) { +func (t *Tracer) OnStateChange(state NegotiationState) { t.mu.Lock() defer t.mu.Unlock() - stg := session.SessionNegotiationState(state) - e := newEvent(handshakeEventStateChange, stg, t.TimeNow(), t.zeroTime) + e := newEvent(handshakeEventStateChange, state, t.TimeNow(), t.zeroTime) t.events = append(t.events, e) } // OnIncomingPacket is called when a packet is received. -func (t *Tracer) OnIncomingPacket(packet *model.Packet, stage int) { +func (t *Tracer) OnIncomingPacket(packet *model.Packet, stage NegotiationState) { t.mu.Lock() defer t.mu.Unlock() - stg := session.SessionNegotiationState(stage) - e := newEvent(handshakeEventPacketIn, stg, t.TimeNow(), t.zeroTime) + e := newEvent(handshakeEventPacketIn, stage, t.TimeNow(), t.zeroTime) e.LoggedPacket = logPacket(packet, optional.None[int](), model.DirectionIncoming) maybeAddTagsFromPacket(e, packet) t.events = append(t.events, e) } // OnOutgoingPacket is called when a packet is about to be sent. -func (t *Tracer) OnOutgoingPacket(packet *model.Packet, stage int, retries int) { +func (t *Tracer) OnOutgoingPacket(packet *model.Packet, stage NegotiationState, retries int) { t.mu.Lock() defer t.mu.Unlock() - stg := session.SessionNegotiationState(stage) - e := newEvent(handshakeEventPacketOut, stg, t.TimeNow(), t.zeroTime) + e := newEvent(handshakeEventPacketOut, stage, t.TimeNow(), t.zeroTime) e.LoggedPacket = logPacket(packet, optional.Some(retries), model.DirectionOutgoing) maybeAddTagsFromPacket(e, packet) t.events = append(t.events, e) } // OnDroppedPacket is called whenever a packet is dropped (in/out) -func (t *Tracer) OnDroppedPacket(direction model.Direction, stage int, packet *model.Packet) { +func (t *Tracer) OnDroppedPacket(direction model.Direction, stage NegotiationState, packet *model.Packet) { t.mu.Lock() defer t.mu.Unlock() - stg := session.SessionNegotiationState(stage) - e := newEvent(handshakeEventPacketDropped, stg, t.TimeNow(), t.zeroTime) + e := newEvent(handshakeEventPacketDropped, stage, t.TimeNow(), t.zeroTime) e.LoggedPacket = logPacket(packet, optional.None[int](), direction) t.events = append(t.events, e) } // Trace returns a structured log containing a copy of the array of [model.HandshakeEvent]. -func (t *Tracer) Trace() []*event { +func (t *Tracer) Trace() []*Event { t.mu.Lock() defer t.mu.Unlock() - return append([]*event{}, t.events...) + return append([]*Event{}, t.events...) } func logPacket(p *model.Packet, retries optional.Value[int], direction model.Direction) optional.Value[LoggedPacket] { @@ -178,7 +175,7 @@ type LoggedPacket struct { // maybeAddTagsFromPacket attempts to derive meaningful tags from // the packet payload, and adds it to the tag array in the passed event. -func maybeAddTagsFromPacket(e *event, packet *model.Packet) { +func maybeAddTagsFromPacket(e *Event, packet *model.Packet) { if len(packet.Payload) <= 0 { return }