diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f86ab5c9..4995f6a8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,15 +42,3 @@ jobs: go-version: '1.20' - name: Ensure coverage threshold run: make test-coverage-threshold - - integration: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: setup go - uses: actions/setup-go@v2 - with: - go-version: '1.20' - - name: run integration tests - run: go test -v ./tests/integration - diff --git a/internal/controlchannel/controlchannel.go b/internal/controlchannel/controlchannel.go index 22c1ab44..b13a6087 100644 --- a/internal/controlchannel/controlchannel.go +++ b/internal/controlchannel/controlchannel.go @@ -1,3 +1,5 @@ +// Package controlchannel implements the control channel logic. The control channel sits +// above the reliable transport and below the TLS layer. package controlchannel import ( @@ -36,12 +38,12 @@ type Service struct { // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md func (svc *Service) StartWorkers( - logger model.Logger, + config *model.Config, workersManager *workers.Manager, sessionManager *session.Manager, ) { ws := &workersState{ - logger: logger, + logger: config.Logger(), notifyTLS: *svc.NotifyTLS, controlToReliable: *svc.ControlToReliable, reliableToControl: svc.ReliableToControl, @@ -90,10 +92,10 @@ func (ws *workersState) moveUpWorker() { // even if after the first key generation we receive two SOFT_RESET requests // back to back. - if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS { continue } - ws.sessionManager.SetNegotiationState(session.S_INITIAL) + ws.sessionManager.SetNegotiationState(model.S_INITIAL) // TODO(ainghazal): revisit this step. // when we implement key rotation. OpenVPN has // the concept of a "lame duck", i.e., the diff --git a/internal/datachannel/controller.go b/internal/datachannel/controller.go index 40b5ec8a..addcb2c5 100644 --- a/internal/datachannel/controller.go +++ b/internal/datachannel/controller.go @@ -25,7 +25,7 @@ type dataChannelHandler interface { // DataChannel represents the data "channel", that will encrypt and decrypt the tunnel payloads. // data implements the dataHandler interface. type DataChannel struct { - options *model.Options + options *model.OpenVPNOptions sessionManager *session.Manager state *dataChannelState decodeFn func(model.Logger, []byte, *session.Manager, *dataChannelState) (*encryptedData, error) @@ -39,7 +39,7 @@ var _ dataChannelHandler = &DataChannel{} // Ensure that we implement dataChanne // NewDataChannelFromOptions returns a new data object, initialized with the // options given. it also returns any error raised. func NewDataChannelFromOptions(log model.Logger, - opt *model.Options, + opt *model.OpenVPNOptions, sessionManager *session.Manager) (*DataChannel, error) { runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") runtimex.Assert(opt != nil, "openvpn datachannel: opts cannot be nil") diff --git a/internal/datachannel/read.go b/internal/datachannel/read.go index 4d25be2d..75992c20 100644 --- a/internal/datachannel/read.go +++ b/internal/datachannel/read.go @@ -97,7 +97,7 @@ func decodeEncryptedPayloadNonAEAD(log model.Logger, buf []byte, session *sessio // modes are supported at the moment, so no real decompression is done. It // returns a byte array, and an error if the operation could not be completed // successfully. -func maybeDecompress(b []byte, st *dataChannelState, opt *model.Options) ([]byte, error) { +func maybeDecompress(b []byte, st *dataChannelState, opt *model.OpenVPNOptions) ([]byte, error) { if st == nil || st.dataCipher == nil { return []byte{}, fmt.Errorf("%w:%s", errBadInput, "bad state") } diff --git a/internal/datachannel/service.go b/internal/datachannel/service.go index 00837a9b..23354f87 100644 --- a/internal/datachannel/service.go +++ b/internal/datachannel/service.go @@ -50,14 +50,13 @@ type Service struct { // 3. keyWorker BLOCKS on keyUp to read a dataChannelKey and // initializes the internal state with the resulting key; func (s *Service) StartWorkers( - logger model.Logger, + config *model.Config, workersManager *workers.Manager, sessionManager *session.Manager, - options *model.Options, ) { - dc, err := NewDataChannelFromOptions(logger, options, sessionManager) + dc, err := NewDataChannelFromOptions(config.Logger(), config.OpenVPNOptions(), sessionManager) if err != nil { - logger.Warnf("cannot initialize channel %v", err) + config.Logger().Warnf("cannot initialize channel %v", err) return } ws := &workersState{ @@ -65,7 +64,7 @@ func (s *Service) StartWorkers( dataOrControlToMuxer: *s.DataOrControlToMuxer, dataToTUN: s.DataToTUN, keyReady: s.KeyReady, - logger: logger, + logger: config.Logger(), muxerToData: s.MuxerToData, sessionManager: sessionManager, tunToData: s.TUNToData, @@ -193,7 +192,7 @@ func (ws *workersState) keyWorker(firstKeyReady chan<- any) { ws.logger.Warnf("error on key derivation: %v", err) continue } - ws.sessionManager.SetNegotiationState(session.S_GENERATED_KEYS) + ws.sessionManager.SetNegotiationState(model.S_GENERATED_KEYS) once.Do(func() { close(firstKeyReady) }) diff --git a/internal/model/config.go b/internal/model/config.go new file mode 100644 index 00000000..d050e388 --- /dev/null +++ b/internal/model/config.go @@ -0,0 +1,96 @@ +package model + +import ( + "net" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/runtimex" +) + +// Config contains options to initialize the OpenVPN tunnel. +type Config struct { + // openVPNOptions contains options related to openvpn. + openvpnOptions *OpenVPNOptions + + // logger will be used to log events. + logger Logger + + // if a tracer is provided, it will be used to trace the openvpn handshake. + tracer HandshakeTracer +} + +// NewConfig returns a Config ready to intialize a vpn tunnel. +func NewConfig(options ...Option) *Config { + cfg := &Config{ + openvpnOptions: &OpenVPNOptions{}, + logger: log.Log, + tracer: &dummyTracer{}, + } + for _, opt := range options { + opt(cfg) + } + return cfg +} + +// Option is an option you can pass to initialize minivpn. +type Option func(config *Config) + +// WithConfigFile configures OpenVPNOptions parsed from the given file. +func WithConfigFile(configPath string) Option { + return func(config *Config) { + openvpnOpts, err := ReadConfigFile(configPath) + runtimex.PanicOnError(err, "cannot parse config file") + runtimex.PanicIfFalse(openvpnOpts.HasAuthInfo(), "missing auth info") + config.openvpnOptions = openvpnOpts + } +} + +// WithLogger configures the passed [Logger]. +func WithLogger(logger Logger) Option { + return func(config *Config) { + config.logger = logger + } +} + +// WithHandshakeTracer configures the passed [HandshakeTracer]. +func WithHandshakeTracer(tracer HandshakeTracer) Option { + return func(config *Config) { + config.tracer = tracer + } +} + +// Logger returns the configured logger. +func (c *Config) Logger() Logger { + return c.logger +} + +// Tracer returns the handshake tracer. +func (c *Config) Tracer() HandshakeTracer { + return c.tracer +} + +// OpenVPNOptions returns the configured openvpn options. +func (c *Config) OpenVPNOptions() *OpenVPNOptions { + return c.openvpnOptions +} + +// Remote returns the OpenVPN remote. +func (c *Config) Remote() *Remote { + return &Remote{ + IPAddr: c.openvpnOptions.Remote, + Endpoint: net.JoinHostPort(c.openvpnOptions.Remote, c.openvpnOptions.Port), + Protocol: c.openvpnOptions.Proto.String(), + } +} + +// Remote has info about the OpenVPN remote. +type Remote struct { + // IPAddr is the IP Address for the remote. + IPAddr string + + // Endpoint is in the form ip:port. + Endpoint string + + // Protocol is either "tcp" or "udp" + Protocol string +} diff --git a/internal/model/packet.go b/internal/model/packet.go index f1256a4b..51409db1 100644 --- a/internal/model/packet.go +++ b/internal/model/packet.go @@ -334,13 +334,8 @@ func (p *Packet) IsData() bool { return p.Opcode.IsData() } -const ( - DirectionIncoming = iota - DirectionOutgoing -) - // Log writes an entry in the passed logger with a representation of this packet. -func (p *Packet) Log(logger Logger, direction int) { +func (p *Packet) Log(logger Logger, direction Direction) { var dir string switch direction { case DirectionIncoming: diff --git a/internal/model/session.go b/internal/model/session.go new file mode 100644 index 00000000..5e181737 --- /dev/null +++ b/internal/model/session.go @@ -0,0 +1,59 @@ +package model + +// NegotiationState is the state of the session negotiation. +type NegotiationState int + +const ( + // S_ERROR means there was some form of protocol error. + S_ERROR = NegotiationState(iota) - 1 + + // S_UNDER is the undefined state. + S_UNDEF + + // S_INITIAL means we're ready to begin the three-way handshake. + S_INITIAL + + // S_PRE_START means we're waiting for acknowledgment from the remote. + S_PRE_START + + // S_START means we've done the three-way handshake. + S_START + + // S_SENT_KEY means we have sent the local part of the key_source2 random material. + S_SENT_KEY + + // S_GOT_KEY means we have got the remote part of key_source2. + S_GOT_KEY + + // S_ACTIVE means the control channel was established. + S_ACTIVE + + // S_GENERATED_KEYS means the data channel keys have been generated. + S_GENERATED_KEYS +) + +// String maps a [SessionNegotiationState] to a string. +func (sns NegotiationState) String() string { + switch sns { + case S_UNDEF: + return "S_UNDEF" + case S_INITIAL: + return "S_INITIAL" + case S_PRE_START: + return "S_PRE_START" + case S_START: + return "S_START" + case S_SENT_KEY: + return "S_SENT_KEY" + case S_GOT_KEY: + return "S_GOT_KEY" + case S_ACTIVE: + return "S_ACTIVE" + case S_GENERATED_KEYS: + return "S_GENERATED_KEYS" + case S_ERROR: + return "S_ERROR" + default: + return "S_INVALID" + } +} diff --git a/internal/model/trace.go b/internal/model/trace.go new file mode 100644 index 00000000..ff14cfaf --- /dev/null +++ b/internal/model/trace.go @@ -0,0 +1,72 @@ +package model + +import ( + "fmt" + "time" +) + +// HandshakeTracer allows to collect traces for a given OpenVPN handshake. A HandshakeTracer can be optionally +// added to the top-level TUN constructor, and it will be propagated to any layer that needs to register an event. +type HandshakeTracer interface { + // TimeNow allows to inject time for deterministic tests. + TimeNow() time.Time + + // OnStateChange is called for each transition in the state machine. + OnStateChange(state NegotiationState) + + // OnIncomingPacket is called when a packet is received. + OnIncomingPacket(packet *Packet, stage NegotiationState) + + // OnOutgoingPacket is called when a packet is about to be sent. + OnOutgoingPacket(packet *Packet, stage NegotiationState, retries int) + + // OnDroppedPacket is called whenever a packet is dropped (in/out) + OnDroppedPacket(direction Direction, stage NegotiationState, packet *Packet) +} + +// Direction is one of two directions on a packet. +type Direction int + +const ( + // DirectionIncoming marks received packets. + DirectionIncoming = Direction(iota) + + // DirectionOutgoing marks packets to be sent. + DirectionOutgoing +) + +var _ fmt.Stringer = Direction(0) + +// String implements fmt.Stringer +func (d Direction) String() string { + switch d { + case DirectionIncoming: + return "read" + case DirectionOutgoing: + return "write" + default: + return "undefined" + } +} + +// dummyTracer is a no-op implementation of [model.HandshakeTracer] that does nothing +// but can be safely passed as a default implementation. +type dummyTracer struct{} + +// TimeNow allows to manipulate time for deterministic tests. +func (dt *dummyTracer) TimeNow() time.Time { return time.Now() } + +// OnStateChange is called for each transition in the state machine. +func (dt *dummyTracer) OnStateChange(NegotiationState) {} + +// OnIncomingPacket is called when a packet is received. +func (dt *dummyTracer) OnIncomingPacket(*Packet, NegotiationState) {} + +// OnOutgoingPacket is called when a packet is about to be sent. +func (dt *dummyTracer) OnOutgoingPacket(*Packet, NegotiationState, int) {} + +// OnDroppedPacket is called whenever a packet is dropped (in/out) +func (dt *dummyTracer) OnDroppedPacket(Direction, NegotiationState, *Packet) {} + +// Assert that dummyTracer implements [model.HandshakeTracer]. +var _ HandshakeTracer = &dummyTracer{} diff --git a/internal/model/options.go b/internal/model/vpnoptions.go similarity index 89% rename from internal/model/options.go rename to internal/model/vpnoptions.go index bd46c6e5..c10e8ddc 100644 --- a/internal/model/options.go +++ b/internal/model/vpnoptions.go @@ -5,7 +5,7 @@ package model // // Mostly, this file conforms to the format in the reference implementation. // However, there are some additions that are specific. To avoid feature creep -// and fat dependencies, the main `vpn` module only supports mainline +// and fat dependencies, the internal implementation only supports mainline // capabilities. It is still useful to carry all options in a single type, // so it's up to the user of this library to do something useful with // such options. The `extra` package provides some of these extra features, like @@ -92,10 +92,10 @@ var SupportedAuth = []string{ "SHA512", } -// Options make all the relevant configuration options accessible to the +// OpenVPNOptions make all the relevant openvpn configuration options accessible to the // different modules that need it. -type Options struct { - // These options have the same name of OpenVPN options: +type OpenVPNOptions struct { + // These options have the same name of OpenVPN options referenced in the official documentation: Remote string Port string Proto Proto @@ -111,7 +111,9 @@ type Options struct { Auth string TLSMaxVer string - // Below are options that do not conform to the OpenVPN configuration format: + // Below are options that do not conform strictly to the OpenVPN configuration format, but still can + // be understood by us in a configuration file: + Compress Compression ProxyOBFS4 string } @@ -119,7 +121,7 @@ type Options struct { // ReadConfigFile expects a string with a path to a valid config file, // and returns a pointer to a Options struct after parsing the file, and an // error if the operation could not be completed. -func ReadConfigFile(filePath string) (*Options, error) { +func ReadConfigFile(filePath string) (*OpenVPNOptions, error) { lines, err := getLinesFromFile(filePath) dir, _ := filepath.Split(filePath) if err != nil { @@ -130,7 +132,7 @@ func ReadConfigFile(filePath string) (*Options, error) { // ShouldLoadCertsFromPath returns true when the options object is configured to load // certificates from paths; false when we have inline certificates. -func (o *Options) ShouldLoadCertsFromPath() bool { +func (o *OpenVPNOptions) ShouldLoadCertsFromPath() bool { return o.CertPath != "" && o.KeyPath != "" && o.CAPath != "" } @@ -138,7 +140,7 @@ func (o *Options) ShouldLoadCertsFromPath() bool { // - we have paths for cert, key and ca; or // - we have inline byte arrays for cert, key and ca; or // - we have username + password info. -func (o *Options) HasAuthInfo() bool { +func (o *OpenVPNOptions) HasAuthInfo() bool { if o.CertPath != "" && o.KeyPath != "" && o.CAPath != "" { return true } @@ -156,7 +158,7 @@ const clientOptions = "V4,dev-type tun,link-mtu 1549,tun-mtu 1500,proto %sv4,cip // ServerOptionsString produces a comma-separated representation of the options, in the same // order and format that the OpenVPN server expects from us. -func (o *Options) ServerOptionsString() string { +func (o *OpenVPNOptions) ServerOptionsString() string { if o.Cipher == "" { return "" } @@ -258,7 +260,7 @@ func PushedOptionsAsMap(pushedOptions []byte) map[string][]string { return optMap } -func parseProto(p []string, o *Options) error { +func parseProto(p []string, o *OpenVPNOptions) error { if len(p) != 1 { return fmt.Errorf("%w: %s", ErrBadConfig, "proto needs one arg") } @@ -277,7 +279,7 @@ func parseProto(p []string, o *Options) error { // TODO(ainghazal): all these little functions can be better tested if we return the options object too -func parseRemote(p []string, o *Options) error { +func parseRemote(p []string, o *OpenVPNOptions) error { if len(p) != 2 { return fmt.Errorf("%w: %s", ErrBadConfig, "remote needs two args") } @@ -285,7 +287,7 @@ func parseRemote(p []string, o *Options) error { return nil } -func parseCipher(p []string, o *Options) error { +func parseCipher(p []string, o *OpenVPNOptions) error { if len(p) != 1 { return fmt.Errorf("%w: %s", ErrBadConfig, "cipher expects one arg") } @@ -297,7 +299,7 @@ func parseCipher(p []string, o *Options) error { return nil } -func parseAuth(p []string, o *Options) error { +func parseAuth(p []string, o *OpenVPNOptions) error { if len(p) != 1 { return fmt.Errorf("%w: %s", ErrBadConfig, "invalid auth entry") } @@ -309,7 +311,7 @@ func parseAuth(p []string, o *Options) error { return nil } -func parseCA(p []string, o *Options, basedir string) error { +func parseCA(p []string, o *OpenVPNOptions, basedir string) error { e := fmt.Errorf("%w: %s", ErrBadConfig, "ca expects a valid file") if len(p) != 1 { return e @@ -325,7 +327,7 @@ func parseCA(p []string, o *Options, basedir string) error { return nil } -func parseCert(p []string, o *Options, basedir string) error { +func parseCert(p []string, o *OpenVPNOptions, basedir string) error { e := fmt.Errorf("%w: %s", ErrBadConfig, "cert expects a valid file") if len(p) != 1 { return e @@ -341,7 +343,7 @@ func parseCert(p []string, o *Options, basedir string) error { return nil } -func parseKey(p []string, o *Options, basedir string) error { +func parseKey(p []string, o *OpenVPNOptions, basedir string) error { e := fmt.Errorf("%w: %s", ErrBadConfig, "key expects a valid file") if len(p) != 1 { return e @@ -360,7 +362,7 @@ func parseKey(p []string, o *Options, basedir string) error { // parseAuthUser reads credentials from a given file, according to the openvpn // format (user and pass on a line each). To avoid path traversal / LFI, the // credentials file is expected to be in a subdirectory of the base dir. -func parseAuthUser(p []string, o *Options, basedir string) error { +func parseAuthUser(p []string, o *OpenVPNOptions, basedir string) error { e := fmt.Errorf("%w: %s", ErrBadConfig, "auth-user-pass expects a valid file") if len(p) != 1 { return e @@ -380,7 +382,7 @@ func parseAuthUser(p []string, o *Options, basedir string) error { return nil } -func parseCompress(p []string, o *Options) error { +func parseCompress(p []string, o *OpenVPNOptions) error { if len(p) > 1 { return fmt.Errorf("%w: %s", ErrBadConfig, "compress: only empty/stub options supported") } @@ -395,7 +397,7 @@ func parseCompress(p []string, o *Options) error { return fmt.Errorf("%w: %s", ErrBadConfig, "compress: only empty/stub options supported") } -func parseCompLZO(p []string, o *Options) error { +func parseCompLZO(p []string, o *OpenVPNOptions) error { if p[0] != "no" { return fmt.Errorf("%w: %s", ErrBadConfig, "comp-lzo: compression not supported") } @@ -405,7 +407,7 @@ func parseCompLZO(p []string, o *Options) error { // parseTLSVerMax sets the maximum TLS version. This is currently ignored // because we're using uTLS to parrot the Client Hello. -func parseTLSVerMax(p []string, o *Options) error { +func parseTLSVerMax(p []string, o *OpenVPNOptions) error { if len(p) == 0 { o.TLSMaxVer = "1.3" return nil @@ -416,7 +418,7 @@ func parseTLSVerMax(p []string, o *Options) error { return nil } -func parseProxyOBFS4(p []string, o *Options) error { +func parseProxyOBFS4(p []string, o *OpenVPNOptions) error { if len(p) != 1 { return fmt.Errorf("%w: %s", ErrBadConfig, "proto-obfs4: need a properly configured proxy") } @@ -443,15 +445,15 @@ var pMapDir = map[string]interface{}{ "auth-user-pass": parseAuthUser, } -func parseOption(o *Options, dir, key string, p []string, lineno int) error { +func parseOption(o *OpenVPNOptions, dir, key string, p []string, lineno int) error { switch key { case "proto", "remote", "cipher", "auth", "compress", "comp-lzo", "tls-version-max", "proxy-obfs4": - fn := pMap[key].(func([]string, *Options) error) + fn := pMap[key].(func([]string, *OpenVPNOptions) error) if e := fn(p, o); e != nil { return e } case "ca", "cert", "key", "auth-user-pass": - fn := pMapDir[key].(func([]string, *Options, string) error) + fn := pMapDir[key].(func([]string, *OpenVPNOptions, string) error) if e := fn(p, o, dir); e != nil { return e } @@ -464,8 +466,8 @@ func parseOption(o *Options, dir, key string, p []string, lineno int) error { // getOptionsFromLines tries to parse all the lines coming from a config file // and raises validation errors if the values do not conform to the expected // format. The config file supports inline file inclusion for , and . -func getOptionsFromLines(lines []string, dir string) (*Options, error) { - opt := &Options{} +func getOptionsFromLines(lines []string, dir string) (*OpenVPNOptions, error) { + opt := &OpenVPNOptions{} // tag and inlineBuf are used to parse inline files. // these follow the format used by the reference openvpn implementation. @@ -561,7 +563,7 @@ func parseTag(tag string) string { } // parseInlineTag -func parseInlineTag(o *Options, tag string, buf *bytes.Buffer) error { +func parseInlineTag(o *OpenVPNOptions, tag string, buf *bytes.Buffer) error { b := buf.Bytes() if len(b) == 0 { return fmt.Errorf("%w: empty inline tag: %d", ErrBadConfig, len(b)) diff --git a/internal/networkio/service.go b/internal/networkio/service.go index f10d5bf4..19ae5c4b 100644 --- a/internal/networkio/service.go +++ b/internal/networkio/service.go @@ -26,13 +26,13 @@ type Service struct { // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md func (svc *Service) StartWorkers( - logger model.Logger, + config *model.Config, manager *workers.Manager, conn FramingConn, ) { ws := &workersState{ conn: conn, - logger: logger, + logger: config.Logger(), manager: manager, muxerToNetwork: svc.MuxerToNetwork, networkToMuxer: *svc.NetworkToMuxer, diff --git a/internal/optional/optional.go b/internal/optional/optional.go index f330b812..8c479c54 100644 --- a/internal/optional/optional.go +++ b/internal/optional/optional.go @@ -1,6 +1,9 @@ +// Package optional implements optional values. package optional import ( + "bytes" + "encoding/json" "reflect" "github.com/ooni/minivpn/internal/runtimex" @@ -56,3 +59,43 @@ func (v Value[T]) UnwrapOr(fallback T) T { } return v.Unwrap() } + +var _ json.Unmarshaler = &Value[int]{} + +// UnmarshalJSON implements json.Unmarshaler. Note that a `null` JSON +// value always leads to an empty Value. +func (v *Value[T]) UnmarshalJSON(data []byte) error { + // A `null` underlying value should always be equivalent to + // invoking the None constructor of for T. While this is not + // what the [json] package recommends doing for this case, + // it is consistent with initializing an optional. + if bytes.Equal(data, []byte(`null`)) { + v.indirect = nil + return nil + } + + // Otherwise, let's try to unmarshal into a real value + var value T + if err := json.Unmarshal(data, &value); err != nil { + return err + } + + // Enforce the same semantics of the Some constructor: treat + // pointer types specially to avoid the case where we have + // a Value that is wrapping a nil pointer but for which the + // IsNone check actually returns false. (Maybe this check is + // redundant but it seems better to enforce it anyway.) + maybeSetFromValue(v, value) + return nil +} + +var _ json.Marshaler = Value[int]{} + +// MarshalJSON implements json.Marshaler. An empty value serializes +// to `null` and otherwise we serialize the underluing value. +func (v Value[T]) MarshalJSON() ([]byte, error) { + if v.indirect == nil { + return json.Marshal(nil) + } + return json.Marshal(*v.indirect) +} diff --git a/internal/packetmuxer/service.go b/internal/packetmuxer/service.go index c67d3c53..050fe035 100644 --- a/internal/packetmuxer/service.go +++ b/internal/packetmuxer/service.go @@ -50,12 +50,12 @@ type Service struct { // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md func (s *Service) StartWorkers( - logger model.Logger, + config *model.Config, workersManager *workers.Manager, sessionManager *session.Manager, ) { ws := &workersState{ - logger: logger, + logger: config.Logger(), hardReset: s.HardReset, // initialize to a sufficiently long time from now hardResetTicker: time.NewTicker(longWakeup), @@ -66,6 +66,7 @@ func (s *Service) StartWorkers( muxerToNetwork: *s.MuxerToNetwork, networkToMuxer: s.NetworkToMuxer, sessionManager: sessionManager, + tracer: config.Tracer(), workersManager: workersManager, } workersManager.StartWorker(ws.moveUpWorker) @@ -107,6 +108,9 @@ type workersState struct { // sessionManager manages the OpenVPN session. sessionManager *session.Manager + // tracer is a [model.HandshakeTracer]. + tracer model.HandshakeTracer + // workersManager controls the workers lifecycle. workersManager *workers.Manager } @@ -128,7 +132,8 @@ func (ws *workersState) moveUpWorker() { case rawPacket := <-ws.networkToMuxer: if err := ws.handleRawPacket(rawPacket); err != nil { // error already printed - return + // TODO(ainghazal): trace malformed input + continue } case <-ws.hardResetTicker.C: @@ -190,7 +195,11 @@ func (ws *workersState) moveDownWorker() { // startHardReset is invoked when we need to perform a HARD RESET. func (ws *workersState) startHardReset() error { - ws.hardResetCount += 1 + // increment the hard reset counter for retries + ws.hardResetCount++ + + // reset the state to become initial again. + ws.sessionManager.SetNegotiationState(model.S_PRE_START) // emit a CONTROL_HARD_RESET_CLIENT_V2 pkt packet := ws.sessionManager.NewHardResetPacket() @@ -201,11 +210,6 @@ func (ws *workersState) startHardReset() error { // resend if not received the server's reply in 2 seconds. ws.hardResetTicker.Reset(time.Second * 2) - // reset the state to become initial again. - ws.sessionManager.SetNegotiationState(session.S_PRE_START) - - // TODO: any other change to apply in this case? - return nil } @@ -219,7 +223,7 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { } // handle the case where we're performing a HARD_RESET - if ws.sessionManager.NegotiationState() == session.S_PRE_START && + if ws.sessionManager.NegotiationState() == model.S_PRE_START && packet.Opcode == model.P_CONTROL_HARD_RESET_SERVER_V2 { packet.Log(ws.logger, model.DirectionIncoming) ws.hardResetTicker.Stop() @@ -234,13 +238,20 @@ func (ws *workersState) handleRawPacket(rawPacket []byte) error { return workers.ErrShutdown } } else { - if ws.sessionManager.NegotiationState() < session.S_GENERATED_KEYS { + if ws.sessionManager.NegotiationState() < model.S_GENERATED_KEYS { // A well-behaved server should not send us data packets // before we have a working session. Under normal operations, the // connection in the client side should pick a different port, // so that data sent from previous sessions will not be delivered. - // However, it does not harm to be defensive here. - return errors.New("not ready to handle data") + // However, it does not harm to be defensive here. One such case + // is that we get injected packets intended to mess with the handshake. + // In this case, the caller will drop and log/trace the event. + if packet.IsData() { + ws.logger.Warnf("packetmuxer: moveUpWorker: cannot handle data yet") + return errors.New("not ready to handle data") + } + ws.logger.Warnf("malformed input") + return errors.New("malformed input") } select { case ws.muxerToData <- packet: @@ -258,7 +269,7 @@ func (ws *workersState) finishThreeWayHandshake(packet *model.Packet) error { ws.sessionManager.SetRemoteSessionID(packet.LocalSessionID) // advance the state - ws.sessionManager.SetNegotiationState(session.S_START) + ws.sessionManager.SetNegotiationState(model.S_START) // pass the packet up so that we can ack it properly select { @@ -289,6 +300,12 @@ func (ws *workersState) serializeAndEmit(packet *model.Packet) error { return err } + ws.tracer.OnOutgoingPacket( + packet, + ws.sessionManager.NegotiationState(), + ws.hardResetCount, + ) + // emit the packet. Possibly BLOCK writing to the networkio layer. select { case ws.muxerToNetwork <- rawPacket: diff --git a/internal/reliabletransport/common_test.go b/internal/reliabletransport/common_test.go index 63581f48..bd3f3f9b 100644 --- a/internal/reliabletransport/common_test.go +++ b/internal/reliabletransport/common_test.go @@ -6,6 +6,7 @@ import ( "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/runtimex" "github.com/ooni/minivpn/internal/session" + "github.com/ooni/minivpn/internal/vpntest" "github.com/ooni/minivpn/internal/workers" ) @@ -16,7 +17,7 @@ import ( // initManagers initializes a workers manager and a session manager. func initManagers() (*workers.Manager, *session.Manager) { w := workers.NewManager(log.Log) - s, err := session.NewManager(log.Log) + s, err := session.NewManager(model.NewConfig(model.WithLogger(log.Log))) runtimex.PanicOnError(err, "cannot create session manager") return w, s } @@ -45,3 +46,10 @@ func ackSetFromRange(start, total int) *ackSet { } return newACKSet(acks...) } + +func initializeSessionIDForWriter(writer *vpntest.PacketWriter, session *session.Manager) { + peerSessionID := newRandomSessionID() + writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + writer.LocalSessionID = peerSessionID + session.SetRemoteSessionID(peerSessionID) +} diff --git a/internal/reliabletransport/model.go b/internal/reliabletransport/model.go index b121eee0..b04d23e2 100644 --- a/internal/reliabletransport/model.go +++ b/internal/reliabletransport/model.go @@ -4,18 +4,6 @@ import ( "github.com/ooni/minivpn/internal/model" ) -// sequentialPacket is a packet that can return a [model.PacketID]. -type sequentialPacket interface { - ID() model.PacketID - ExtractACKs() []model.PacketID - Packet() *model.Packet -} - -// retransmissionPacket is a packet that can be scheduled for retransmission. -type retransmissionPacket interface { - ScheduleForRetransmission() -} - type outgoingPacketWriter interface { // TryInsertOutgoingPacket attempts to insert a packet into the // inflight queue. If return value is false, insertion was not successful (e.g., too many diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 1cb3991a..1e6893d0 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -19,7 +19,7 @@ type inFlightPacket struct { packet *model.Packet // retries is a monotonically increasing counter for retransmission. - retries uint8 + retries int } func newInFlightPacket(p *model.Packet) *inFlightPacket { diff --git a/internal/reliabletransport/packets_test.go b/internal/reliabletransport/packets_test.go index 4b071dc4..298387cc 100644 --- a/internal/reliabletransport/packets_test.go +++ b/internal/reliabletransport/packets_test.go @@ -10,7 +10,7 @@ import ( func Test_inFlightPacket_backoff(t *testing.T) { type fields struct { - retries uint8 + retries int } tests := []struct { name string diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 9c2a927b..91dc37e8 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -7,6 +7,7 @@ import ( "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/optional" + "github.com/ooni/minivpn/internal/session" ) // moveUpWorker moves packets up the stack (receiver). @@ -30,9 +31,10 @@ func (ws *workersState) moveUpWorker() { // or POSSIBLY BLOCK waiting for notifications select { case packet := <-ws.muxerToReliable: + ws.tracer.OnIncomingPacket(packet, ws.sessionManager.NegotiationState()) + if packet.Opcode != model.P_CONTROL_HARD_RESET_SERVER_V2 { // the hard reset has already been logged by the layer below - // TODO: move logging here? packet.Log(ws.logger, model.DirectionIncoming) } @@ -40,17 +42,12 @@ func (ws *workersState) moveUpWorker() { // I'm not sure that's a valid behavior for a server. // We should be able to deterministically test how this affects the state machine. - // drop a packet that is not for our session - if !bytes.Equal(packet.RemoteSessionID[:], ws.sessionManager.LocalSessionID()) { - ws.logger.Warnf( - "%s: packet with invalid RemoteSessionID: expected %x; got %x", - workerName, - ws.sessionManager.LocalSessionID(), - packet.RemoteSessionID, - ) + // sanity check incoming packet + if ok := incomingSanityChecks(ws.logger, workerName, packet, ws.sessionManager); !ok { continue } + // notify seen packet to the sender using the lateral channel. seen := receiver.newIncomingPacketSeen(packet) ws.incomingSeen <- seen @@ -63,6 +60,11 @@ func (ws *workersState) moveUpWorker() { if inserted := receiver.MaybeInsertIncoming(packet); !inserted { // this packet was not inserted in the queue: we drop it + // TODO: add reason + ws.tracer.OnDroppedPacket( + model.DirectionIncoming, + ws.sessionManager.NegotiationState(), + packet) ws.logger.Debugf("Dropping packet: %v", packet.ID) continue } @@ -83,6 +85,36 @@ func (ws *workersState) moveUpWorker() { } } +func incomingSanityChecks(logger model.Logger, workerName string, packet *model.Packet, session *session.Manager) bool { + // drop a packet from a remote session we don't know about. + if !bytes.Equal(packet.LocalSessionID[:], session.RemoteSessionID()) { + logger.Warnf( + "%s: packet with invalid LocalSessionID: got %x; expected %x", + workerName, + packet.LocalSessionID, + session.RemoteSessionID(), + ) + return false + } + + if len(packet.ACKs) == 0 { + return true + } + + // only if we get incoming ACKs we can also check that the remote session id matches our own + // (packets with no ack array do not include remoteSessionID) + if !bytes.Equal(packet.RemoteSessionID[:], session.LocalSessionID()) { + logger.Warnf( + "%s: packet with invalid RemoteSessionID: got %x; expected %x", + workerName, + packet.RemoteSessionID, + session.LocalSessionID(), + ) + return false + } + return true +} + // // incomingPacketHandler implementation. // diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index 59767974..f0be8548 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -128,16 +128,12 @@ func TestReliable_ACK(t *testing.T) { t0 := time.Now() // let the workers pump up the jam! - s.StartWorkers(log.Log, workers, session) + s.StartWorkers(model.NewConfig(model.WithLogger(log.Log)), workers, session) writer := vpntest.NewPacketWriter(dataIn) // initialize a mock session ID for our peer - peerSessionID := newRandomSessionID() - - writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) - writer.LocalSessionID = peerSessionID - session.SetRemoteSessionID(peerSessionID) + initializeSessionIDForWriter(writer, session) go writer.WriteSequence(tt.args.inputSequence) diff --git a/internal/reliabletransport/reliable_loss_test.go b/internal/reliabletransport/reliable_loss_test.go index 7b429bf7..330c55a3 100644 --- a/internal/reliabletransport/reliable_loss_test.go +++ b/internal/reliabletransport/reliable_loss_test.go @@ -250,7 +250,7 @@ func TestReliable_WithLoss(t *testing.T) { t0 := time.Now() // let the workers pump up the jam! - s.StartWorkers(log.Log, workers, session) + s.StartWorkers(model.NewConfig(model.WithLogger(log.Log)), workers, session) writer := vpntest.NewPacketWriter(dataIn) go writer.WriteSequenceWithFixedPayload(tt.args.inputSequence, tt.args.inputPayload, 3) diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index bd28a096..3f49cbd7 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -123,12 +123,10 @@ func TestReliable_Reordering_UP(t *testing.T) { t0 := time.Now() // let the workers pump up the jam! - s.StartWorkers(log.Log, workers, session) + s.StartWorkers(model.NewConfig(model.WithLogger(log.Log)), workers, session) writer := vpntest.NewPacketWriter(dataIn) - - writer.RemoteSessionID = model.SessionID(session.LocalSessionID()) - writer.LocalSessionID = newRandomSessionID() + initializeSessionIDForWriter(writer, session) go writer.WriteSequence(tt.args.inputSequence) diff --git a/internal/reliabletransport/sender.go b/internal/reliabletransport/sender.go index 989e81f9..0fb14fde 100644 --- a/internal/reliabletransport/sender.go +++ b/internal/reliabletransport/sender.go @@ -82,8 +82,13 @@ func (ws *workersState) blockOnTryingToSend(sender *reliableSender, ticker *time // append any pending ACKs p.packet.ACKs = sender.NextPacketIDsToACK() - // log the packet + // log and trace the packet p.packet.Log(ws.logger, model.DirectionOutgoing) + ws.tracer.OnOutgoingPacket( + p.packet, + ws.sessionManager.NegotiationState(), + p.retries, + ) select { case ws.dataOrControlToMuxer <- p.packet: diff --git a/internal/reliabletransport/service.go b/internal/reliabletransport/service.go index b2a1bd75..6ef12708 100644 --- a/internal/reliabletransport/service.go +++ b/internal/reliabletransport/service.go @@ -31,20 +31,21 @@ type Service struct { // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md func (s *Service) StartWorkers( - logger model.Logger, + config *model.Config, workersManager *workers.Manager, sessionManager *session.Manager, ) { + // incomingSeen is a buffered channel to avoid losing packets if we're busy + // processing in the sender goroutine. ws := &workersState{ - logger: logger, - // incomingSeen is a buffered channel to avoid losing packets if we're busy - // processing in the sender goroutine. - incomingSeen: make(chan incomingPacketSeen, 100), - dataOrControlToMuxer: *s.DataOrControlToMuxer, controlToReliable: s.ControlToReliable, + dataOrControlToMuxer: *s.DataOrControlToMuxer, + incomingSeen: make(chan incomingPacketSeen, 100), + logger: config.Logger(), muxerToReliable: s.MuxerToReliable, reliableToControl: *s.ReliableToControl, sessionManager: sessionManager, + tracer: config.Tracer(), workersManager: workersManager, } workersManager.StartWorker(ws.moveUpWorker) @@ -53,17 +54,17 @@ func (s *Service) StartWorkers( // workersState contains the reliable workers state type workersState struct { - // logger is the logger to use - logger model.Logger - - // incomingSeen ins the shared channel to connect sender and receiver goroutines. - incomingSeen chan incomingPacketSeen + // controlToReliable is the channel from which we read packets going down the stack. + controlToReliable <-chan *model.Packet // dataOrControlToMuxer is the channel where we write packets going down the stack. dataOrControlToMuxer chan<- *model.Packet - // controlToReliable is the channel from which we read packets going down the stack. - controlToReliable <-chan *model.Packet + // incomingSeen is the shared channel to connect sender and receiver goroutines. + incomingSeen chan incomingPacketSeen + + // logger is the logger to use + logger model.Logger // muxerToReliable is the channel from which we read packets going up the stack. muxerToReliable <-chan *model.Packet @@ -74,6 +75,9 @@ type workersState struct { // sessionManager manages the OpenVPN session. sessionManager *session.Manager + // tracer is a handshake tracer. + tracer model.HandshakeTracer + // workersManager controls the workers lifecycle. workersManager *workers.Manager } diff --git a/internal/reliabletransport/service_test.go b/internal/reliabletransport/service_test.go index 3555fd7f..aeb4bc49 100644 --- a/internal/reliabletransport/service_test.go +++ b/internal/reliabletransport/service_test.go @@ -18,7 +18,7 @@ func TestService_StartWorkers(t *testing.T) { ReliableToControl *chan *model.Packet } type args struct { - logger model.Logger + config *model.Config workersManager *workers.Manager sessionManager *session.Manager } @@ -42,10 +42,10 @@ func TestService_StartWorkers(t *testing.T) { }(), }, args: args{ - logger: log.Log, + config: model.NewConfig(model.WithLogger(log.Log)), workersManager: workers.NewManager(log.Log), sessionManager: func() *session.Manager { - m, _ := session.NewManager(log.Log) + m, _ := session.NewManager(model.NewConfig(model.WithLogger(log.Log))) return m }(), }, @@ -59,7 +59,7 @@ func TestService_StartWorkers(t *testing.T) { MuxerToReliable: tt.fields.MuxerToReliable, ReliableToControl: tt.fields.ReliableToControl, } - s.StartWorkers(tt.args.logger, tt.args.workersManager, tt.args.sessionManager) + s.StartWorkers(tt.args.config, tt.args.workersManager, tt.args.sessionManager) tt.args.workersManager.StartShutdown() tt.args.workersManager.WaitWorkersShutdown() }) diff --git a/internal/session/manager.go b/internal/session/manager.go index 1e74dae2..4a4ec974 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -13,64 +13,6 @@ import ( "github.com/ooni/minivpn/internal/runtimex" ) -// SessionNegotiationState is the state of the session negotiation. -type SessionNegotiationState int - -const ( - // S_ERROR means there was some form of protocol error. - S_ERROR = SessionNegotiationState(iota) - 1 - - // S_UNDER is the undefined state. - S_UNDEF - - // S_INITIAL means we're ready to begin the three-way handshake. - S_INITIAL - - // S_PRE_START means we're waiting for acknowledgment from the remote. - S_PRE_START - - // S_START means we've done the three-way handshake. - S_START - - // S_SENT_KEY means we have sent the local part of the key_source2 random material. - S_SENT_KEY - - // S_GOT_KEY means we have got the remote part of key_source2. - S_GOT_KEY - - // S_ACTIVE means the control channel was established. - S_ACTIVE - - // S_GENERATED_KEYS means the data channel keys have been generated. - S_GENERATED_KEYS -) - -// String maps a [SessionNegotiationState] to a string. -func (sns SessionNegotiationState) String() string { - switch sns { - case S_UNDEF: - return "S_UNDEF" - case S_INITIAL: - return "S_INITIAL" - case S_PRE_START: - return "S_PRE_START" - case S_START: - return "S_START" - case S_SENT_KEY: - return "S_SENT_KEY" - case S_GOT_KEY: - return "S_GOT_KEY" - case S_ACTIVE: - return "S_ACTIVE" - case S_GENERATED_KEYS: - return "S_GENERATED_KEYS" - case S_ERROR: - return "S_ERROR" - default: - return "S_INVALID" - } -} - // Manager manages the session. The zero value is invalid. Please, construct // using [NewManager]. This struct is concurrency safe. type Manager struct { @@ -81,9 +23,10 @@ type Manager struct { localSessionID model.SessionID logger model.Logger mu sync.Mutex - negState SessionNegotiationState + negState model.NegotiationState remoteSessionID optional.Value[model.SessionID] tunnelInfo model.TunnelInfo + tracer model.HandshakeTracer // Ready is a channel where we signal that we can start accepting data, because we've // successfully generated key material for the data channel. @@ -91,7 +34,7 @@ type Manager struct { } // NewManager returns a [Manager] ready to be used. -func NewManager(logger model.Logger) (*Manager, error) { +func NewManager(config *model.Config) (*Manager, error) { key0 := &DataChannelKey{} sessionManager := &Manager{ keyID: 0, @@ -99,11 +42,12 @@ func NewManager(logger model.Logger) (*Manager, error) { // localControlPacketID should be initialized to 1 because we handle hard-reset as special cases localControlPacketID: 1, localSessionID: [8]byte{}, - logger: logger, + logger: config.Logger(), mu: sync.Mutex{}, negState: 0, remoteSessionID: optional.None[model.SessionID](), tunnelInfo: model.TunnelInfo{}, + tracer: config.Tracer(), // empirically, it seems that the reference OpenVPN server misbehaves if we initialize // the data packet ID counter to zero. @@ -261,19 +205,20 @@ func (m *Manager) localControlPacketIDLocked() (model.PacketID, error) { } // NegotiationState returns the state of the negotiation. -func (m *Manager) NegotiationState() SessionNegotiationState { +func (m *Manager) NegotiationState() model.NegotiationState { defer m.mu.Unlock() m.mu.Lock() return m.negState } // SetNegotiationState sets the state of the negotiation. -func (m *Manager) SetNegotiationState(sns SessionNegotiationState) { +func (m *Manager) SetNegotiationState(sns model.NegotiationState) { defer m.mu.Unlock() m.mu.Lock() m.logger.Infof("[@] %s -> %s", m.negState, sns) + m.tracer.OnStateChange(sns) m.negState = sns - if sns == S_GENERATED_KEYS { + if sns == model.S_GENERATED_KEYS { m.Ready <- true } } diff --git a/internal/tlssession/controlmsg.go b/internal/tlssession/controlmsg.go index d59ad945..8fd84546 100644 --- a/internal/tlssession/controlmsg.go +++ b/internal/tlssession/controlmsg.go @@ -23,7 +23,7 @@ import ( // encodeClientControlMessage returns a byte array with the payload for a control channel packet. // This is the packet that the client sends to the server with the key // material, local options and credentials (if username+password authentication is used). -func encodeClientControlMessageAsBytes(k *session.KeySource, o *model.Options) ([]byte, error) { +func encodeClientControlMessageAsBytes(k *session.KeySource, o *model.OpenVPNOptions) ([]byte, error) { opt, err := bytesx.EncodeOptionStringToBytes(o.ServerOptionsString()) if err != nil { return nil, err diff --git a/internal/tlssession/tlshandshake.go b/internal/tlssession/tlshandshake.go index 3f3bbfb0..dfdc9e26 100644 --- a/internal/tlssession/tlshandshake.go +++ b/internal/tlssession/tlshandshake.go @@ -118,7 +118,7 @@ type certConfig struct { // newCertConfigFromOptions is a constructor that returns a certConfig object initialized // from the paths specified in the passed Options object, and an error if it // could not be properly built. -func newCertConfigFromOptions(o *model.Options) (*certConfig, error) { +func newCertConfigFromOptions(o *model.OpenVPNOptions) (*certConfig, error) { var cfg *certConfig var err error if o.ShouldLoadCertsFromPath() { diff --git a/internal/tlssession/tlssession.go b/internal/tlssession/tlssession.go index 95016bfa..1fb88d79 100644 --- a/internal/tlssession/tlssession.go +++ b/internal/tlssession/tlssession.go @@ -42,16 +42,15 @@ type Service struct { // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md func (svc *Service) StartWorkers( - logger model.Logger, + config *model.Config, workersManager *workers.Manager, sessionManager *session.Manager, - options *model.Options, ) { ws := &workersState{ - logger: logger, - notifyTLS: svc.NotifyTLS, - options: options, keyUp: *svc.KeyUp, + logger: config.Logger(), + notifyTLS: svc.NotifyTLS, + options: config.OpenVPNOptions(), tlsRecordDown: *svc.TLSRecordDown, tlsRecordUp: svc.TLSRecordUp, sessionManager: sessionManager, @@ -64,7 +63,7 @@ func (svc *Service) StartWorkers( type workersState struct { logger model.Logger notifyTLS <-chan *model.Notification - options *model.Options + options *model.OpenVPNOptions tlsRecordDown chan<- []byte tlsRecordUp <-chan []byte keyUp chan<- *session.DataChannelKey @@ -157,7 +156,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha errorch <- err return } - ws.sessionManager.SetNegotiationState(session.S_SENT_KEY) + ws.sessionManager.SetNegotiationState(model.S_SENT_KEY) // read the server's keySource and options remoteKey, serverOptions, err := ws.recvAuthReplyMessage(tlsConn) @@ -175,7 +174,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha // add the remote key to the active key activeKey.AddRemoteKey(remoteKey) - ws.sessionManager.SetNegotiationState(session.S_GOT_KEY) + ws.sessionManager.SetNegotiationState(model.S_GOT_KEY) // send the push request if err := ws.sendPushRequestMessage(tlsConn); err != nil { @@ -194,7 +193,7 @@ func (ws *workersState) doTLSAuth(conn net.Conn, config *tls.Config, errorch cha ws.sessionManager.UpdateTunnelInfo(tinfo) // progress to the ACTIVE state - ws.sessionManager.SetNegotiationState(session.S_ACTIVE) + ws.sessionManager.SetNegotiationState(model.S_ACTIVE) // notify the datachannel that we've got a key pair ready to use ws.keyUp <- activeKey diff --git a/internal/tun/setup.go b/internal/tun/setup.go index 11e9d0bd..073d1258 100644 --- a/internal/tun/setup.go +++ b/internal/tun/setup.go @@ -25,10 +25,11 @@ func connectChannel[T any](signal chan T, slot **chan T) { // file for more information about the workers. // // [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md -func startWorkers(logger model.Logger, sessionManager *session.Manager, - tunDevice *TUN, conn networkio.FramingConn, options *model.Options) *workers.Manager { +func startWorkers(config *model.Config, conn networkio.FramingConn, + sessionManager *session.Manager, tunDevice *TUN) *workers.Manager { + // create a workers manager - workersManager := workers.NewManager(logger) + workersManager := workers.NewManager(config.Logger()) // create the networkio service. nio := &networkio.Service{ @@ -109,12 +110,12 @@ func startWorkers(logger model.Logger, sessionManager *session.Manager, connectChannel(tlsx.NotifyTLS, &muxer.NotifyTLS) // start all the workers - nio.StartWorkers(logger, workersManager, conn) - muxer.StartWorkers(logger, workersManager, sessionManager) - rel.StartWorkers(logger, workersManager, sessionManager) - ctrl.StartWorkers(logger, workersManager, sessionManager) - datach.StartWorkers(logger, workersManager, sessionManager, options) - tlsx.StartWorkers(logger, workersManager, sessionManager, options) + nio.StartWorkers(config, workersManager, conn) + muxer.StartWorkers(config, workersManager, sessionManager) + rel.StartWorkers(config, workersManager, sessionManager) + ctrl.StartWorkers(config, workersManager, sessionManager) + datach.StartWorkers(config, workersManager, sessionManager) + tlsx.StartWorkers(config, workersManager, sessionManager) // tell the packetmuxer that it should handshake ASAP muxer.HardReset <- true diff --git a/internal/tun/tun.go b/internal/tun/tun.go index 5d4ff7b2..5e6c9dc6 100644 --- a/internal/tun/tun.go +++ b/internal/tun/tun.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/apex/log" "github.com/ooni/minivpn/internal/model" "github.com/ooni/minivpn/internal/networkio" "github.com/ooni/minivpn/internal/session" @@ -22,18 +21,18 @@ var ( // StartTUN initializes and starts the TUN device over the vpn. // If the passed context expires before the TUN device is ready, -func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options, logger model.Logger) (*TUN, error) { +func StartTUN(ctx context.Context, conn networkio.FramingConn, config *model.Config) (*TUN, error) { // create a session - sessionManager, err := session.NewManager(logger) + sessionManager, err := session.NewManager(config) if err != nil { return nil, err } // create the TUN that will OWN the connection - tunnel := newTUN(logger, conn, sessionManager) + tunnel := newTUN(config.Logger(), conn, sessionManager) // start all the workers - workers := startWorkers(log.Log, sessionManager, tunnel, conn, options) + workers := startWorkers(config, conn, sessionManager, tunnel) tunnel.whenDone(func() { workers.StartShutdown() workers.WaitWorkersShutdown() @@ -50,7 +49,7 @@ func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Op return tunnel, nil case <-tlsTimeout.C: defer func() { - log.Log.Info("tls timeout") + config.Logger().Info("tls timeout") tunnel.Close() }() return nil, errors.New("tls timeout") @@ -128,6 +127,8 @@ func (t *TUN) whenDone(fn func()) { t.whenDoneFn = fn } +// Close is an idempotent method that closes the underlying connection (owned by us) and +// potentially executes any registed callback. func (t *TUN) Close() error { t.closeOnce.Do(func() { close(t.hangup) @@ -139,6 +140,7 @@ func (t *TUN) Close() error { return nil } +// Read implements net.Conn func (t *TUN) Read(data []byte) (int, error) { for { count, _ := t.readBuffer.Read(data) @@ -160,6 +162,7 @@ func (t *TUN) Read(data []byte) (int, error) { } } +// Write implements net.Conn func (t *TUN) Write(data []byte) (int, error) { if isClosedChan(t.writeDeadline.wait()) { return 0, os.ErrDeadlineExceeded @@ -174,27 +177,32 @@ func (t *TUN) Write(data []byte) (int, error) { } } +// LocalAddr implements net.Conn func (t *TUN) LocalAddr() net.Addr { ip := t.session.TunnelInfo().IP return &tunBioAddr{ip, t.network} } +// RemoteAddr implements net.Conn func (t *TUN) RemoteAddr() net.Addr { gw := t.session.TunnelInfo().GW return &tunBioAddr{gw, t.network} } +// SetDeadline implements net.Conn func (t *TUN) SetDeadline(tm time.Time) error { t.readDeadline.set(tm) t.writeDeadline.set(tm) return nil } +// SetReadDeadline implements net.Conn func (t *TUN) SetReadDeadline(tm time.Time) error { t.readDeadline.set(tm) return nil } +// SetWriteDeadline implements net.Conn func (t *TUN) SetWriteDeadline(tm time.Time) error { t.writeDeadline.set(tm) return nil @@ -219,6 +227,7 @@ func (t *tunBioAddr) String() string { return t.addr } +// NetMask returns the configured net mask for the TUN interface. func (t *TUN) NetMask() net.IPMask { return net.IPMask(net.ParseIP(t.session.TunnelInfo().NetMask)) } diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 1e53a03d..f2e74ccb 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -46,7 +46,6 @@ func (pw *PacketWriter) WriteSequence(seq []string) { // possibly expand a input sequence in range notation for the packet ids [1..10] func maybeExpand(input string) []string { - fmt.Println("maybe expand") items := []string{} pattern := `^\[(\d+)\.\.(\d+)\] (.+)` regexpPattern := regexp.MustCompile(pattern) diff --git a/pkg/README.md b/pkg/README.md new file mode 100644 index 00000000..6fbe1f04 --- /dev/null +++ b/pkg/README.md @@ -0,0 +1 @@ +This folder contains public go packages. diff --git a/pkg/tracex/trace.go b/pkg/tracex/trace.go new file mode 100644 index 00000000..09979fa1 --- /dev/null +++ b/pkg/tracex/trace.go @@ -0,0 +1,191 @@ +// Package tracex implements a handshake tracer that can be passed to the TUN constructor to +// observe handshake events. +package tracex + +import ( + "fmt" + "sync" + "time" + + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/optional" +) + +const ( + handshakeEventStateChange = iota + handshakeEventPacketIn + handshakeEventPacketOut + handshakeEventPacketDropped +) + +// HandshakeEventType indicates which event we logged. +type HandshakeEventType int + +// Ensure that it implements the Stringer interface. +var _ fmt.Stringer = HandshakeEventType(0) + +// String implements fmt.Stringer +func (e HandshakeEventType) String() string { + switch e { + case handshakeEventStateChange: + return "state" + case handshakeEventPacketIn: + return "packet_in" + case handshakeEventPacketOut: + return "packet_out" + case handshakeEventPacketDropped: + return "packet_dropped" + default: + return "unknown" + } +} + +// Event is a handshake event collected by this [model.HandshakeTracer]. +type Event struct { + // EventType is the type for this event. + EventType string `json:"operation"` + + // Stage is the stage of the handshake negotiation we're in. + Stage string `json:"stage"` + + // AtTime is the time for this event, relative to the start time. + AtTime float64 `json:"t"` + + // Tags is an array of tags that can be useful to interpret this event, like the contents of the packet. + Tags []string `json:"tags"` + + // LoggedPacket is an optional packet metadata. + LoggedPacket optional.Value[LoggedPacket] `json:"packet"` +} + +type NegotiationState = model.NegotiationState + +func newEvent(etype HandshakeEventType, st NegotiationState, t time.Time, t0 time.Time) *Event { + return &Event{ + EventType: etype.String(), + Stage: st.String()[2:], + AtTime: t.Sub(t0).Seconds(), + Tags: make([]string, 0), + LoggedPacket: optional.None[LoggedPacket](), + } +} + +// Tracer implements [model.HandshakeTracer]. +type Tracer struct { + // events is the array of handshake events. + events []*Event + + // mu guards access to the events. + mu sync.Mutex + + // zeroTime is the time when we started a packet trace. + zeroTime time.Time +} + +// NewTracer returns a Tracer with the passed start time. +func NewTracer(start time.Time) *Tracer { + return &Tracer{ + zeroTime: start, + } +} + +// TimeNow allows to manipulate time for deterministic tests. +func (t *Tracer) TimeNow() time.Time { + return time.Now() +} + +// OnStateChange is called for each transition in the state machine. +func (t *Tracer) OnStateChange(state NegotiationState) { + t.mu.Lock() + defer t.mu.Unlock() + + e := newEvent(handshakeEventStateChange, state, t.TimeNow(), t.zeroTime) + t.events = append(t.events, e) +} + +// OnIncomingPacket is called when a packet is received. +func (t *Tracer) OnIncomingPacket(packet *model.Packet, stage NegotiationState) { + t.mu.Lock() + defer t.mu.Unlock() + + e := newEvent(handshakeEventPacketIn, stage, t.TimeNow(), t.zeroTime) + e.LoggedPacket = logPacket(packet, optional.None[int](), model.DirectionIncoming) + maybeAddTagsFromPacket(e, packet) + t.events = append(t.events, e) +} + +// OnOutgoingPacket is called when a packet is about to be sent. +func (t *Tracer) OnOutgoingPacket(packet *model.Packet, stage NegotiationState, retries int) { + t.mu.Lock() + defer t.mu.Unlock() + + e := newEvent(handshakeEventPacketOut, stage, t.TimeNow(), t.zeroTime) + e.LoggedPacket = logPacket(packet, optional.Some(retries), model.DirectionOutgoing) + maybeAddTagsFromPacket(e, packet) + t.events = append(t.events, e) +} + +// OnDroppedPacket is called whenever a packet is dropped (in/out) +func (t *Tracer) OnDroppedPacket(direction model.Direction, stage NegotiationState, packet *model.Packet) { + t.mu.Lock() + defer t.mu.Unlock() + + e := newEvent(handshakeEventPacketDropped, stage, t.TimeNow(), t.zeroTime) + e.LoggedPacket = logPacket(packet, optional.None[int](), direction) + t.events = append(t.events, e) +} + +// Trace returns a structured log containing a copy of the array of [model.HandshakeEvent]. +func (t *Tracer) Trace() []*Event { + t.mu.Lock() + defer t.mu.Unlock() + return append([]*Event{}, t.events...) +} + +func logPacket(p *model.Packet, retries optional.Value[int], direction model.Direction) optional.Value[LoggedPacket] { + logged := LoggedPacket{ + Opcode: p.Opcode.String(), + ID: p.ID, + ACKs: optional.None[[]model.PacketID](), + Direction: direction.String(), + PayloadSize: len(p.Payload), + Retries: retries, + } + if len(p.ACKs) != 0 { + logged.ACKs = optional.Some(p.ACKs) + } + return optional.Some(logged) +} + +// LoggedPacket tracks metadata about a packet useful to build traces. +type LoggedPacket struct { + Direction string `json:"operation"` + + // the only fields of the packet we want to log. + Opcode string `json:"opcode"` + ID model.PacketID `json:"id"` + ACKs optional.Value[[]model.PacketID] `json:"acks"` + + // PayloadSize is the size of the payload in bytes + PayloadSize int `json:"payload_size"` + + // Retries keeps track of packet retransmission (only for outgoing packets). + Retries optional.Value[int] `json:"send_attempts"` +} + +// maybeAddTagsFromPacket attempts to derive meaningful tags from +// the packet payload, and adds it to the tag array in the passed event. +func maybeAddTagsFromPacket(e *Event, packet *model.Packet) { + if len(packet.Payload) <= 0 { + return + } + p := packet.Payload + if p[0] == 0x16 && p[5] == 0x01 { + e.Tags = append(e.Tags, "client_hello") + return + } + if p[0] == 0x16 && p[5] == 0x02 { + e.Tags = append(e.Tags, "server_hello") + return + } +}