diff --git a/unix/sockcmsg_unix.go b/unix/sockcmsg_unix.go index 453a942c5..3865943f6 100644 --- a/unix/sockcmsg_unix.go +++ b/unix/sockcmsg_unix.go @@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) { return msgs, nil } +// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header, +// message data (a slice of b), and the remainder of b after that single message. +// When there are no remaining messages, len(remainder) == 0. +func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) { + h, dbuf, err := socketControlMessageHeaderAndData(b) + if err != nil { + return Cmsghdr{}, nil, nil, err + } + if i := cmsgAlignOf(int(h.Len)); i < len(b) { + remainder = b[i:] + } + return *h, dbuf, remainder, nil +} + func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) { h := (*Cmsghdr)(unsafe.Pointer(&b[0])) if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) { diff --git a/unix/syscall_unix_test.go b/unix/syscall_unix_test.go index baff92eea..0517689f0 100644 --- a/unix/syscall_unix_test.go +++ b/unix/syscall_unix_test.go @@ -322,7 +322,7 @@ func passFDChild() { } } -// TestUnixRightsRoundtrip tests that UnixRights, ParseSocketControlMessage, +// TestUnixRightsRoundtrip tests that UnixRights, ParseSocketControlMessage, ParseOneSocketControlMessage, // and ParseUnixRights are able to successfully round-trip lists of file descriptors. func TestUnixRightsRoundtrip(t *testing.T) { testCases := [...][][]int{ @@ -350,6 +350,23 @@ func TestUnixRightsRoundtrip(t *testing.T) { if len(scms) != len(testCase) { t.Fatalf("expected %v SocketControlMessage; got scms = %#v", len(testCase), scms) } + + var c int + for len(b) > 0 { + hdr, data, remainder, err := unix.ParseOneSocketControlMessage(b) + if err != nil { + t.Fatalf("ParseOneSocketControlMessage: %v", err) + } + if scms[c].Header != hdr || !bytes.Equal(scms[c].Data, data) { + t.Fatal("expected SocketControlMessage header and data to match") + } + b = remainder + c++ + } + if c != len(scms) { + t.Fatalf("expected %d SocketControlMessages; got %d", len(scms), c) + } + for i, scm := range scms { gotFds, err := unix.ParseUnixRights(&scm) if err != nil {