Skip to content

Commit

Permalink
ttrpc: use os.Getuid/os.Getgid directly
Browse files Browse the repository at this point in the history
Because of issues with glibc, using the `os/user` package can cause when
calling `user.Current()`. Neither the Go maintainers or glibc developers
could be bothered to fix it, so we have to work around it by calling the
uid and gid functions directly. This is probably better because we don't
actually use much of the data provided in the `user.User` struct.

This required some refactoring to have better control over when the uid
and gid are resolved. Rather than checking the current user on every
connection, we now resolve it once at initialization. To test that this
provided an improvement in performance, a benchmark was added.
Unfortunately, this exposed a regression in the performance of unix
sockets in Go when `(*UnixConn).File` is called. The underlying culprit
of this performance regression is still at large.

The following open issues describe the underlying problem in more
detail:

golang/go#13470
https://sourceware.org/bugzilla/show_bug.cgi?id=19341

In better news, I now have an entire herd of shaved yaks.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
  • Loading branch information
stevvooe committed Dec 1, 2017
1 parent af6e749 commit 256c17b
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 69 deletions.
10 changes: 10 additions & 0 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ type Handshaker interface {
// client-side.
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
}

type handshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)

func (fn handshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
return fn(ctx, conn)
}

func noopHandshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
return conn, nil, nil
}
22 changes: 10 additions & 12 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ttrpc

import (
"context"
"io"
"math/rand"
"net"
"sync"
Expand Down Expand Up @@ -55,10 +56,15 @@ func (s *Server) Serve(l net.Listener) error {
defer s.closeListener(l)

var (
ctx = context.Background()
backoff time.Duration
ctx = context.Background()
backoff time.Duration
handshaker = s.config.handshaker
)

if handshaker == nil {
handshaker = handshakerFunc(noopHandshake)
}

for {
conn, err := l.Accept()
if err != nil {
Expand Down Expand Up @@ -92,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error {

backoff = 0

approved, handshake, err := s.handshake(ctx, conn)
approved, handshake, err := handshaker.Handshake(ctx, conn)
if err != nil {
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
conn.Close()
Expand Down Expand Up @@ -150,14 +156,6 @@ func (s *Server) Close() error {
return err
}

func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
if s.config.handshaker == nil {
return conn, nil, nil
}

return s.config.handshaker.Handshake(ctx, conn)
}

func (s *Server) addListener(l net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -433,7 +431,7 @@ func (c *serverConn) run(sctx context.Context) {
// branch. Basically, it means that we are no longer receiving
// requests due to a terminal error.
recvErr = nil // connection is now "closing"
if err != nil {
if err != nil && err != io.EOF {
log.L.WithError(err).Error("error receiving message")
}
case <-shutdown:
Expand Down
90 changes: 61 additions & 29 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,34 +106,6 @@ func TestServer(t *testing.T) {
}
}

func BenchmarkRoundTrip(b *testing.B) {
var (
ctx = context.Background()
server = mustServer(b)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(b)
client, cleanup = newTestClient(b, addr)
tclient = newTestingClient(client)
)

defer listener.Close()
defer cleanup()

registerTestingService(server, testImpl)

go server.Serve(listener)
defer server.Shutdown(ctx)

var tp testPayload
b.ResetTimer()

for i := 0; i < b.N; i++ {
if _, err := tclient.Test(ctx, &tp); err != nil {
b.Fatal(err)
}
}
}

func TestServerNotFound(t *testing.T) {
var (
ctx = context.Background()
Expand Down Expand Up @@ -363,7 +335,7 @@ func TestClientEOF(t *testing.T) {
func TestUnixSocketHandshake(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
Expand All @@ -383,6 +355,66 @@ func TestUnixSocketHandshake(t *testing.T) {
}
}

func BenchmarkRoundTrip(b *testing.B) {
var (
ctx = context.Background()
server = mustServer(b)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(b)
client, cleanup = newTestClient(b, addr)
tclient = newTestingClient(client)
)

defer listener.Close()
defer cleanup()

registerTestingService(server, testImpl)

go server.Serve(listener)
defer server.Shutdown(ctx)

var tp testPayload
b.ResetTimer()

for i := 0; i < b.N; i++ {
if _, err := tclient.Test(ctx, &tp); err != nil {
b.Fatal(err)
}
}
}

func BenchmarkRoundTripUnixSocketCreds(b *testing.B) {
// TODO(stevvooe): Right now, there is a 5x performance decrease when using
// unix socket credentials. See (UnixCredentialsFunc).Handshake for
// details.

var (
ctx = context.Background()
server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
testImpl = &testingServer{}
addr, listener = newTestListener(b)
client, cleanup = newTestClient(b, addr)
tclient = newTestingClient(client)
)

defer listener.Close()
defer cleanup()

registerTestingService(server, testImpl)

go server.Serve(listener)
defer server.Shutdown(ctx)

var tp testPayload
b.ResetTimer()

for i := 0; i < b.N; i++ {
if _, err := tclient.Test(ctx, &tp); err != nil {
b.Fatal(err)
}
}
}

func checkServerShutdown(t *testing.T, server *Server) {
t.Helper()
server.mu.Lock()
Expand Down
48 changes: 20 additions & 28 deletions unixcreds.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,13 @@ package ttrpc
import (
"context"
"net"
"os/user"
"strconv"
"os"
"syscall"

"github.com/pkg/errors"
"golang.org/x/sys/unix"
)

var (
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
)

type UnixCredentialsFunc func(*unix.Ucred) error

func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
Expand All @@ -26,6 +20,9 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
}

// TODO(stevvooe): Calling (*UnixConn).File causes a 5x performance
// decrease vs just accessing the fd directly. Need to do some more
// troubleshooting to isolate this to Go runtime or kernel.
fp, err := uc.File()
if err != nil {
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
Expand All @@ -44,37 +41,32 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
return uc, ucred, nil
}

func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc {
return func(ucred *unix.Ucred) error {
return requireUidGid(ucred, uid, gid)
}
}

func requireRoot(ucred *unix.Ucred) error {
return requireUidGid(ucred, 0, 0)
func UnixSocketRequireRoot() UnixCredentialsFunc {
return UnixSocketRequireUidGid(0, 0)
}

func requireSameUser(ucred *unix.Ucred) error {
u, err := user.Current()
if err != nil {
return errors.Wrapf(err, "could not resolve current user")
}

uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
}

gid, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
}
// UnixSocketRequireSameUser resolves the current unix user and returns a
// UnixCredentialsFunc that will validate incoming unix connections against the
// current credentials.
//
// This is useful when using abstract sockets that are accessible by all users.
func UnixSocketRequireSameUser() UnixCredentialsFunc {
uid, gid := os.Getuid(), os.Getgid()
return UnixSocketRequireUidGid(uid, gid)
}

return requireUidGid(ucred, uint32(uid), uint32(gid))
func requireRoot(ucred *unix.Ucred) error {
return requireUidGid(ucred, 0, 0)
}

func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
if (uid != ucred.Uid) || (gid != ucred.Gid) {
func requireUidGid(ucred *unix.Ucred, uid, gid int) error {
if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) {
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
}
return nil
Expand Down

0 comments on commit 256c17b

Please sign in to comment.