diff --git a/internal/session/command.go b/internal/session/command.go index 75390b37..89d424fa 100644 --- a/internal/session/command.go +++ b/internal/session/command.go @@ -1,9 +1,11 @@ package session import ( + "bytes" "context" "encoding/base64" "fmt" + "github.com/sirupsen/logrus" "unicode/utf8" "github.com/bradenaw/juniper/xslices" @@ -27,6 +29,14 @@ func (s *Session) startCommandReader(ctx context.Context, del string) <-chan com logging.GoAnnotated(ctx, func(ctx context.Context) { defer close(cmdCh) + tlsHeaders := [][]byte{ + {0x16, 0x03, 0x01}, // 1.0 + {0x16, 0x03, 0x02}, // 1.1 + {0x16, 0x03, 0x03}, // 1.2 + {0x16, 0x03, 0x04}, // 1.3 + {0x16, 0x00, 0x00}, // 0.0 + } + for { line, literals, err := s.liner.Read(func() error { return response.Continuation().Send(s) }) if err != nil { @@ -37,6 +47,14 @@ func (s *Session) startCommandReader(ctx context.Context, del string) <-chan com return fmt.Sprintf("%v: '%s'", k, literals[k]) })...) + // check if we are receiving raw TLS requests, if so skip. + for _, tlsHeader := range tlsHeaders { + if bytes.HasPrefix(line, tlsHeader) { + logrus.Errorf("TLS Handshake detected while not running with TLS/SSL") + return + } + } + // If the input is not valid UTF-8, we drop the connection. if !utf8.Valid(line) { reporter.MessageWithContext(ctx, diff --git a/reporter/mock_reporter/mock_reporter.go b/reporter/mock_reporter/mock_reporter.go new file mode 100644 index 00000000..7e174383 --- /dev/null +++ b/reporter/mock_reporter/mock_reporter.go @@ -0,0 +1,91 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../reporter.go + +// Package mock_reporter is a generated GoMock package. +package mock_reporter + +import ( + reflect "reflect" + + reporter "github.com/ProtonMail/gluon/reporter" + gomock "github.com/golang/mock/gomock" +) + +// MockReporter is a mock of Reporter interface. +type MockReporter struct { + ctrl *gomock.Controller + recorder *MockReporterMockRecorder +} + +// MockReporterMockRecorder is the mock recorder for MockReporter. +type MockReporterMockRecorder struct { + mock *MockReporter +} + +// NewMockReporter creates a new mock instance. +func NewMockReporter(ctrl *gomock.Controller) *MockReporter { + mock := &MockReporter{ctrl: ctrl} + mock.recorder = &MockReporterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReporter) EXPECT() *MockReporterMockRecorder { + return m.recorder +} + +// ReportException mocks base method. +func (m *MockReporter) ReportException(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReportException", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReportException indicates an expected call of ReportException. +func (mr *MockReporterMockRecorder) ReportException(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportException", reflect.TypeOf((*MockReporter)(nil).ReportException), arg0) +} + +// ReportExceptionWithContext mocks base method. +func (m *MockReporter) ReportExceptionWithContext(arg0 any, arg1 reporter.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReportExceptionWithContext", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReportExceptionWithContext indicates an expected call of ReportExceptionWithContext. +func (mr *MockReporterMockRecorder) ReportExceptionWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportExceptionWithContext", reflect.TypeOf((*MockReporter)(nil).ReportExceptionWithContext), arg0, arg1) +} + +// ReportMessage mocks base method. +func (m *MockReporter) ReportMessage(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReportMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReportMessage indicates an expected call of ReportMessage. +func (mr *MockReporterMockRecorder) ReportMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportMessage", reflect.TypeOf((*MockReporter)(nil).ReportMessage), arg0) +} + +// ReportMessageWithContext mocks base method. +func (m *MockReporter) ReportMessageWithContext(arg0 string, arg1 reporter.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReportMessageWithContext", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReportMessageWithContext indicates an expected call of ReportMessageWithContext. +func (mr *MockReporterMockRecorder) ReportMessageWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportMessageWithContext", reflect.TypeOf((*MockReporter)(nil).ReportMessageWithContext), arg0, arg1) +} diff --git a/tests/bad_test.go b/tests/bad_test.go deleted file mode 100644 index 7192e07b..00000000 --- a/tests/bad_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package tests - -import ( - "crypto/tls" - "reflect" - "testing" - - "github.com/bradenaw/juniper/xslices" - "github.com/stretchr/testify/require" -) - -// nolint:gosec -func TestNonUTF8(t *testing.T) { - runOneToOneTest(t, defaultServerOptions(t), func(_ *testConnection, s *testSession) { - // Create a new connection. - c := s.newConnection() - - // Things work fine when the command is valid UTF-8. - c.C("tag capability").OK("tag") - - // Performing a TLS handshake should fail; the server will drop the connection. - require.Error(t, tls.Client(c.conn, &tls.Config{InsecureSkipVerify: true}).Handshake()) - - // We should have reported the bad UTF-8 command. - require.True(t, xslices.Any(s.reporter.getReports(), func(report report) bool { - return reflect.DeepEqual(report.val, "Received invalid UTF-8 command") - })) - }) -} diff --git a/tests/non_utf8_test.go b/tests/non_utf8_test.go new file mode 100644 index 00000000..39eb18fa --- /dev/null +++ b/tests/non_utf8_test.go @@ -0,0 +1,44 @@ +package tests + +import ( + "github.com/ProtonMail/gluon/reporter/mock_reporter" + "github.com/emersion/go-imap/client" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "testing" + "unicode/utf8" +) + +func TestSSLConnectionOverStartTLS(t *testing.T) { + ctrl := gomock.NewController(t) + reporter := mock_reporter.NewMockReporter(ctrl) + + defer ctrl.Finish() + + // Ensure the nothing is reported when connecting via TLS connection if we are not running with TLS + runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withReporter(reporter)), func(_ *client.Client, session *testSession) { + _, err := client.DialTLS(session.listener.Addr().String(), nil) + require.Error(t, err) + }) +} + +func TestNonUtf8CommandTriggersReporter(t *testing.T) { + ctrl := gomock.NewController(t) + reporter := mock_reporter.NewMockReporter(ctrl) + + defer ctrl.Finish() + + reporter.EXPECT().ReportMessageWithContext("Received invalid UTF-8 command", gomock.Any()).Return(nil).Times(1) + + // Ensure the nothing is reported when connecting via TLS connection if we are not running with TLS + runOneToOneTestWithAuth(t, defaultServerOptions(t, withReporter(reporter)), func(c *testConnection, session *testSession) { + // Encode "ééé" as ISO-8859-1. + b := enc("ééé", "ISO-8859-1") + + // Assert that b is no longer valid UTF-8. + require.False(t, utf8.Valid(b)) + + // This will fail and produce a report + c.Cf(`TAG SEARCH CHARSET ISO-8859-1 BODY ` + string(b)) + }) +} diff --git a/tests/server_test.go b/tests/server_test.go index 674cd78d..e64cfca8 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/hex" "fmt" + "github.com/ProtonMail/gluon/reporter" "net" "path/filepath" "testing" @@ -74,6 +75,7 @@ type serverOptions struct { connectorBuilder connectorBuilder disableParallelism bool imapLimits limits.IMAP + reporter reporter.Reporter } func (s *serverOptions) defaultUsername() string { @@ -164,6 +166,14 @@ func (m imapLimits) apply(options *serverOptions) { options.imapLimits = m.limits } +type reporterOption struct { + reporter reporter.Reporter +} + +func (r reporterOption) apply(options *serverOptions) { + options.reporter = r.reporter +} + func withIdleBulkTime(idleBulkTime time.Duration) serverOption { return &idleBulkTimeOption{idleBulkTime: idleBulkTime} } @@ -196,6 +206,10 @@ func withIMAPLimits(limits limits.IMAP) serverOption { return &imapLimits{limits: limits} } +func withReporter(reporter reporter.Reporter) serverOption { + return &reporterOption{reporter: reporter} +} + func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptions { options := &serverOptions{ credentials: []credentials{{ @@ -264,6 +278,10 @@ func runServer(tb testing.TB, options *serverOptions, tests func(session *testSe gluonOptions = append(gluonOptions, gluon.WithDisableParallelism()) } + if options.reporter != nil { + gluonOptions = append(gluonOptions, gluon.WithReporter(options.reporter)) + } + // Create a new gluon server. server, err := gluon.New(gluonOptions...) require.NoError(tb, err)