diff --git a/go/border/BUILD.bazel b/go/border/BUILD.bazel index 939cad35ef..14170d32fc 100644 --- a/go/border/BUILD.bazel +++ b/go/border/BUILD.bazel @@ -70,10 +70,12 @@ go_test( "//go/lib/overlay/conn/mock_conn:go_default_library", "//go/lib/prom:go_default_library", "//go/lib/ringbuf:go_default_library", + "//go/lib/serrors:go_default_library", "//go/lib/topology:go_default_library", "//go/lib/xtest:go_default_library", "@com_github_golang_mock//gomock:go_default_library", "@com_github_smartystreets_goconvey//convey:go_default_library", + "@com_github_stretchr_testify//assert:go_default_library", "@com_github_stretchr_testify//require:go_default_library", ], ) diff --git a/go/border/io.go b/go/border/io.go index f9fdbd1f1e..b52df34a37 100644 --- a/go/border/io.go +++ b/go/border/io.go @@ -18,8 +18,8 @@ package main import ( + "errors" "net" - "os" "syscall" "time" @@ -30,9 +30,11 @@ import ( "github.com/scionproto/scion/go/border/rpkt" "github.com/scionproto/scion/go/lib/assert" "github.com/scionproto/scion/go/lib/common" + "github.com/scionproto/scion/go/lib/fatal" "github.com/scionproto/scion/go/lib/log" "github.com/scionproto/scion/go/lib/overlay/conn" "github.com/scionproto/scion/go/lib/ringbuf" + "github.com/scionproto/scion/go/lib/serrors" ) const ( @@ -176,7 +178,7 @@ func (r *Router) posixInputRead(msgs []ipv4.Message, metas []conn.ReadMeta, // Loop until a read succeeds, or a non-trivial error occurs for { n, err := c.ReadBatch(msgs, metas) - if err != nil && isConnRefused(err) { + if err != nil && isSyscallErrno(err, syscall.ECONNREFUSED) { // As we are using a connected UDP socket for interface sockets, // any ECONNREFUSED errors that happen while sending to the // neighbouring BR show up as read errors on the socket. As these @@ -193,7 +195,6 @@ func (r *Router) posixInputRead(msgs []ipv4.Message, metas []conn.ReadMeta, func (r *Router) posixOutput(s *rctx.Sock, _, stopped chan struct{}) { defer log.LogPanicAndExit() defer close(stopped) - var ringClosed bool src := s.Conn.LocalAddr() dst := s.Conn.RemoteAddr() log.Info("posixOutput starting", "addr", src) @@ -218,7 +219,6 @@ func (r *Router) posixOutput(s *rctx.Sock, _, stopped chan struct{}) { var t float64 // Needs to be declared before goto var ok bool if epkts, ok = r.posixPrepOutput(epkts, msgs, s.Ring, dst != nil); !ok { - ringClosed = true break } toWrite := min(len(epkts), outputBatchCnt) @@ -243,7 +243,8 @@ func (r *Router) posixOutput(s *rctx.Sock, _, stopped chan struct{}) { continue } // Shutdown writer if the error is non-recoverable. - break + fatal.Fatal(serrors.WrapStr("shutdown on irrecoverable error to avoid broken state", + err, "ifid", s.Ifid)) } } t = time.Since(start).Seconds() @@ -268,18 +269,6 @@ func (r *Router) posixOutput(s *rctx.Sock, _, stopped chan struct{}) { // Release any remaining unsent pkts. releasePkts(epkts) epkts = epkts[:0] - // If the ring is not already closed, drain it until it is closed. This - // prevents writers from blocking in case of an unrecoverable error. - if !ringClosed { - for { - var ok bool - if epkts, ok = r.posixPrepOutput(epkts, msgs, s.Ring, dst != nil); !ok { - break - } - releasePkts(epkts) - epkts = epkts[:0] - } - } } // posixPrepOutput fetches new packets if epkts is empty, and sets the msgs @@ -328,31 +317,23 @@ func shiftUnwrittenPkts(epkts ringbuf.EntryList, pktsWritten int) ringbuf.EntryL // isRecoverableErr checks whether an non-temporary error is recoverable. func isRecoverableErr(err error) bool { - return isConnRefused(err) || isNetUnreachable(err) || isHostUnreachable(err) -} - -func isConnRefused(err error) bool { - return isSyscallErrno(err, syscall.ECONNREFUSED) -} - -func isNetUnreachable(err error) bool { - return isSyscallErrno(err, syscall.ENETUNREACH) -} - -func isHostUnreachable(err error) bool { - return isSyscallErrno(err, syscall.EHOSTUNREACH) + switch { + case isSyscallErrno(err, syscall.ECONNREFUSED), + isSyscallErrno(err, syscall.ENETUNREACH), + isSyscallErrno(err, syscall.EHOSTUNREACH), + isSyscallErrno(err, syscall.EPERM): + return true + default: + return false + } } func isSyscallErrno(err error, errno syscall.Errno) bool { - netErr, ok := err.(*net.OpError) - if !ok { - return false - } - osErr, ok := netErr.Err.(*os.SyscallError) - if !ok { - return false + var target syscall.Errno + if errors.As(err, &target) { + return target == errno } - return osErr.Err == errno + return false } func min(a, b int) int { diff --git a/go/border/io_test.go b/go/border/io_test.go index e504187b91..f39ed344e1 100644 --- a/go/border/io_test.go +++ b/go/border/io_test.go @@ -15,13 +15,13 @@ package main import ( - "errors" "net" "os" "syscall" "testing" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/scionproto/scion/go/border/rctx" @@ -31,9 +31,60 @@ import ( "github.com/scionproto/scion/go/lib/overlay/conn/mock_conn" "github.com/scionproto/scion/go/lib/prom" "github.com/scionproto/scion/go/lib/ringbuf" + "github.com/scionproto/scion/go/lib/serrors" "github.com/scionproto/scion/go/lib/topology" ) +func TestIsSyscallErrno(t *testing.T) { + tests := map[string]struct { + Error error + Errno syscall.Errno + Assertion assert.BoolAssertionFunc + }{ + "ECONNREFUSED": { + Error: &net.OpError{Err: &os.SyscallError{Err: syscall.ECONNREFUSED}}, + Errno: syscall.ECONNREFUSED, + Assertion: assert.True, + }, + "ENETUNREACH": { + Error: &net.OpError{Err: &os.SyscallError{Err: syscall.ENETUNREACH}}, + Errno: syscall.ENETUNREACH, + Assertion: assert.True, + }, + "EHOSTUNREACH": { + Error: &net.OpError{Err: &os.SyscallError{Err: syscall.EHOSTUNREACH}}, + Errno: syscall.EHOSTUNREACH, + Assertion: assert.True, + }, + "EPERM": { + Error: &net.OpError{Err: &os.SyscallError{Err: syscall.EPERM}}, + Errno: syscall.EPERM, + Assertion: assert.True, + }, + "Wrapped(EPERM)": { + Error: serrors.WrapStr("wrapped", syscall.EPERM), + Errno: syscall.EPERM, + Assertion: assert.True, + }, + "mismatch": { + Error: &net.OpError{Err: &os.SyscallError{Err: syscall.EHOSTUNREACH}}, + Errno: syscall.EPERM, + Assertion: assert.False, + }, + "other": { + Error: serrors.New("other"), + Errno: syscall.EPERM, + Assertion: assert.False, + }, + } + for n, tc := range tests { + name, test := n, tc + t.Run(name, func(t *testing.T) { + test.Assertion(t, isSyscallErrno(test.Error, test.Errno)) + }) + } +} + func TestPosixOutputNoLeakNoErrors(t *testing.T) { mctrl := gomock.NewController(t) defer mctrl.Finish() @@ -90,28 +141,6 @@ func TestPosixOutputNoLeakRecoverableErrors(t *testing.T) { sock.Stop() } -func TestPosixOutputNoLeakUnrecoverableErrors(t *testing.T) { - mctrl := gomock.NewController(t) - defer mctrl.Finish() - r := initTestRouter(1) - pkts, checkAllReturned := newTestPktList(t, 2*outputBatchCnt) - defer checkAllReturned(len(pkts)) - // Wait for both batches to be written. - done := make(chan struct{}, 1) - mconn := newTestConn(mctrl) - mconn.EXPECT().WriteBatch(gomock.Any()).DoAndReturn( - func(_ conn.Messages) (int, error) { - done <- struct{}{} - return 0, errors.New("unrecoverable") - }, - ) - sock := newTestSock(r, len(pkts), mconn) - sock.Start() - sock.Ring.Write(pkts, true) - <-done - sock.Stop() -} - func testSuccessfulWrite(done chan<- struct{}) func(conn.Messages) (int, error) { return func(msgs conn.Messages) (int, error) { for i, msg := range msgs {