diff --git a/codec.go b/codec.go index fe672ac..6234f6e 100644 --- a/codec.go +++ b/codec.go @@ -96,54 +96,58 @@ func validateBytes(b []byte) (err error) { return nil } -func bytesToString(b []byte) (ret string, err error) { - s := "" +func readComponent(b []byte) (int, Component, error) { + var offset int + code, n, err := ReadVarintCode(b) + if err != nil { + return 0, Component{}, err + } + offset += n - for len(b) > 0 { - code, n, err := ReadVarintCode(b) - if err != nil { - return "", err - } + p := ProtocolWithCode(code) + if p.Code == 0 { + return 0, Component{}, fmt.Errorf("no protocol with code %d", code) + } - b = b[n:] - p := ProtocolWithCode(code) - if p.Code == 0 { - return "", fmt.Errorf("no protocol with code %d", code) - } - s += "/" + p.Name + if p.Size == 0 { + return offset, Component{ + bytes: b[:offset], + offset: offset, + protocol: p, + }, nil + } - if p.Size == 0 { - continue - } + n, size, err := sizeForAddr(p, b[offset:]) + if err != nil { + return 0, Component{}, err + } - n, size, err := sizeForAddr(p, b) - if err != nil { - return "", err - } + offset += n - b = b[n:] + if len(b[offset:]) < size || size < 0 { + return 0, Component{}, fmt.Errorf("invalid value for size") + } - if len(b) < size || size < 0 { - return "", fmt.Errorf("invalid value for size") - } + return offset + size, Component{ + bytes: b[:offset+size], + protocol: p, + offset: offset, + }, nil +} - if p.Transcoder == nil { - return "", fmt.Errorf("no transcoder for %s protocol", p.Name) - } - a, err := p.Transcoder.BytesToString(b[:size]) +func bytesToString(b []byte) (ret string, err error) { + var buf strings.Builder + + for len(b) > 0 { + n, c, err := readComponent(b) if err != nil { return "", err } - if p.Path && len(a) > 0 && a[0] == '/' { - a = a[1:] - } - if len(a) > 0 { - s += "/" + a - } - b = b[size:] + b = b[n:] + c.writeTo(&buf) } - return s, nil + return buf.String(), nil } func sizeForAddr(p Protocol, b []byte) (skip, size int, err error) { diff --git a/component.go b/component.go new file mode 100644 index 0000000..e31fe5d --- /dev/null +++ b/component.go @@ -0,0 +1,132 @@ +package multiaddr + +import ( + "bytes" + "encoding/binary" + "fmt" + "strings" +) + +// Component is a single multiaddr Component. +type Component struct { + bytes []byte + protocol Protocol + offset int +} + +func (c *Component) Bytes() []byte { + return c.bytes +} + +func (c *Component) Equal(o Multiaddr) bool { + return bytes.Equal(c.bytes, o.Bytes()) +} + +func (c *Component) Protocols() []Protocol { + return []Protocol{c.protocol} +} + +func (c *Component) Decapsulate(o Multiaddr) Multiaddr { + if c.Equal(o) { + return nil + } + return c +} + +func (c *Component) Encapsulate(o Multiaddr) Multiaddr { + m := multiaddr{bytes: c.bytes} + return m.Encapsulate(o) +} + +func (c *Component) ValueForProtocol(code int) (string, error) { + if c.protocol.Code != code { + return "", ErrProtocolNotFound + } + return c.Value(), nil +} + +func (c *Component) Protocol() Protocol { + return c.protocol +} + +func (c *Component) RawValue() []byte { + return c.bytes[c.offset:] +} + +func (c *Component) Value() string { + if c.protocol.Transcoder == nil { + return "" + } + value, err := c.protocol.Transcoder.BytesToString(c.bytes[c.offset:]) + if err != nil { + // This Component must have been checked. + panic(err) + } + return value +} + +func (c *Component) String() string { + var b strings.Builder + c.writeTo(&b) + return b.String() +} + +// writeTo is an efficient, private function for string-formatting a multiaddr. +// Trust me, we tend to allocate a lot when doing this. +func (c *Component) writeTo(b *strings.Builder) { + b.WriteByte('/') + b.WriteString(c.protocol.Name) + value := c.Value() + if len(value) == 0 { + return + } + if !(c.protocol.Path && value[0] == '/') { + b.WriteByte('/') + } + b.WriteString(value) +} + +// NewComponent constructs a new multiaddr component +func NewComponent(protocol, value string) (*Component, error) { + p := ProtocolWithName(protocol) + if p.Code == 0 { + return nil, fmt.Errorf("unsupported protocol: %s", protocol) + } + if p.Transcoder != nil { + bts, err := p.Transcoder.StringToBytes(value) + if err != nil { + return nil, err + } + return newComponent(p, bts), nil + } else if value != "" { + return nil, fmt.Errorf("protocol %s doesn't take a value", p.Name) + } + return newComponent(p, nil), nil + // TODO: handle path /? +} + +func newComponent(protocol Protocol, bvalue []byte) *Component { + size := len(bvalue) + size += len(protocol.VCode) + if protocol.Size < 0 { + size += VarintSize(len(bvalue)) + } + maddr := make([]byte, size) + var offset int + offset += copy(maddr[offset:], protocol.VCode) + if protocol.Size < 0 { + offset += binary.PutUvarint(maddr[offset:], uint64(len(bvalue))) + } + copy(maddr[offset:], bvalue) + + // For debugging + if len(maddr) != offset+len(bvalue) { + panic("incorrect length") + } + + return &Component{ + bytes: maddr, + protocol: protocol, + offset: offset, + } +} diff --git a/interface.go b/interface.go index 1f46184..34bffd9 100644 --- a/interface.go +++ b/interface.go @@ -43,5 +43,8 @@ type Multiaddr interface { Decapsulate(Multiaddr) Multiaddr // ValueForProtocol returns the value (if any) following the specified protocol + // + // Note: protocols can appear multiple times in a single multiaddr. + // Consider using `ForEach` to walk over the addr manually. ValueForProtocol(code int) (string, error) } diff --git a/multiaddr.go b/multiaddr.go index 9b5c251..2c07dd3 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -127,16 +127,15 @@ func (m multiaddr) Decapsulate(o Multiaddr) Multiaddr { var ErrProtocolNotFound = fmt.Errorf("protocol not found in multiaddr") -func (m multiaddr) ValueForProtocol(code int) (string, error) { - for _, sub := range Split(m) { - p := sub.Protocols()[0] - if p.Code == code { - if p.Size == 0 { - return "", nil - } - return strings.SplitN(sub.String(), "/", 3)[2], nil +func (m multiaddr) ValueForProtocol(code int) (value string, err error) { + err = ErrProtocolNotFound + ForEach(m, func(c Component) bool { + if c.Protocol().Code == code { + value = c.Value() + err = nil + return false } - } - - return "", ErrProtocolNotFound + return true + }) + return } diff --git a/multiaddr_test.go b/multiaddr_test.go index 42c963d..261cba3 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -376,7 +376,7 @@ func TestGetValue(t *testing.T) { a = newMultiaddr(t, "/ip4/0.0.0.0/unix/a/b/c/d") // ending in a path one. assertValueForProto(t, a, P_IP4, "0.0.0.0") - assertValueForProto(t, a, P_UNIX, "a/b/c/d") + assertValueForProto(t, a, P_UNIX, "/a/b/c/d") } func TestFuzzBytes(t *testing.T) { diff --git a/util.go b/util.go index 49eff9d..f08788b 100644 --- a/util.go +++ b/util.go @@ -4,15 +4,11 @@ import "fmt" // Split returns the sub-address portions of a multiaddr. func Split(m Multiaddr) []Multiaddr { - split, err := bytesSplit(m.Bytes()) - if err != nil { - panic(fmt.Errorf("invalid multiaddr %s", m.String())) - } - - addrs := make([]Multiaddr, len(split)) - for i, addr := range split { - addrs[i] = multiaddr{bytes: addr} - } + var addrs []Multiaddr + ForEach(m, func(c Component) bool { + addrs = append(addrs, &c) + return true + }) return addrs } @@ -59,3 +55,100 @@ func StringCast(s string) Multiaddr { } return m } + +// SplitFirst returns the first component and the rest of the multiaddr. +func SplitFirst(m Multiaddr) (*Component, Multiaddr) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + n, c, err := readComponent(b) + if err != nil { + panic(err) + } + if len(b) == n { + return &c, nil + } + return &c, multiaddr{b[n:]} +} + +// SplitLast returns the rest of the multiaddr and the last component. +func SplitLast(m Multiaddr) (Multiaddr, *Component) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + + var ( + c Component + err error + offset int + ) + for { + var n int + n, c, err = readComponent(b[offset:]) + if err != nil { + panic(err) + } + if len(b) == n+offset { + // Reached end + if offset == 0 { + // Only one component + return nil, &c + } + return multiaddr{b[:offset]}, &c + } + offset += n + } +} + +// SplitFunc splits the multiaddr when the callback first returns true. The +// component on which the callback first returns will be included in the +// *second* multiaddr. +func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { + b := m.Bytes() + if len(b) == 0 { + return nil, nil + } + var ( + c Component + err error + offset int + ) + for offset < len(b) { + var n int + n, c, err = readComponent(b[offset:]) + if err != nil { + panic(err) + } + if cb(c) { + break + } + offset += n + } + switch offset { + case 0: + return nil, m + case len(b): + return m, nil + default: + return multiaddr{b[:offset]}, multiaddr{b[offset:]} + } +} + +// ForEach walks over the multiaddr, component by component. +// +// This function iterates over components *by value* to avoid allocating. +func ForEach(m Multiaddr, cb func(c Component) bool) { + b := m.Bytes() + for len(b) > 0 { + n, c, err := readComponent(b) + if err != nil { + panic(err) + } + if !cb(c) { + return + } + b = b[n:] + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..3210ca1 --- /dev/null +++ b/util_test.go @@ -0,0 +1,101 @@ +package multiaddr + +import ( + "strings" + "testing" +) + +func TestSplitFirstLast(t *testing.T) { + ipStr := "/ip4/0.0.0.0" + tcpStr := "/tcp/123" + quicStr := "/quic" + ipfsStr := "/ipfs/QmPSQnBKM9g7BaUcZCvswUJVscQ1ipjmwxN5PXCjkp9EQ7" + + for _, x := range [][]string{ + []string{ipStr, tcpStr, quicStr, ipfsStr}, + []string{ipStr, tcpStr, ipfsStr}, + []string{ipStr, tcpStr}, + []string{ipStr}, + []string{}, + } { + addr := StringCast(strings.Join(x, "")) + head, tail := SplitFirst(addr) + rest, last := SplitLast(addr) + if len(x) == 0 { + if head != nil { + t.Error("expected head to be nil") + } + if tail != nil { + t.Error("expected tail to be nil") + } + if rest != nil { + t.Error("expected rest to be nil") + } + if last != nil { + t.Error("expected last to be nil") + } + continue + } + if !head.Equal(StringCast(x[0])) { + t.Errorf("expected %s to be %s", head, x[0]) + } + if !last.Equal(StringCast(x[len(x)-1])) { + t.Errorf("expected %s to be %s", head, x[len(x)-1]) + } + if len(x) == 1 { + if tail != nil { + t.Error("expected tail to be nil") + } + if rest != nil { + t.Error("expected rest to be nil") + } + continue + } + tailExp := strings.Join(x[1:], "") + if !tail.Equal(StringCast(tailExp)) { + t.Errorf("expected %s to be %s", tail, tailExp) + } + restExp := strings.Join(x[:len(x)-1], "") + if !rest.Equal(StringCast(restExp)) { + t.Errorf("expected %s to be %s", rest, restExp) + } + } +} + +func TestSplitFunc(t *testing.T) { + ipStr := "/ip4/0.0.0.0" + tcpStr := "/tcp/123" + quicStr := "/quic" + ipfsStr := "/ipfs/QmPSQnBKM9g7BaUcZCvswUJVscQ1ipjmwxN5PXCjkp9EQ7" + + for _, x := range [][]string{ + []string{ipStr, tcpStr, quicStr, ipfsStr}, + []string{ipStr, tcpStr, ipfsStr}, + []string{ipStr, tcpStr}, + []string{ipStr}, + } { + addr := StringCast(strings.Join(x, "")) + for i, cs := range x { + target := StringCast(cs) + a, b := SplitFunc(addr, func(c Component) bool { + return c.Equal(target) + }) + if i == 0 { + if a != nil { + t.Error("expected nil addr") + } + } else { + if !a.Equal(StringCast(strings.Join(x[:i], ""))) { + t.Error("split failed") + } + if !b.Equal(StringCast(strings.Join(x[i:], ""))) { + t.Error("split failed") + } + } + } + a, b := SplitFunc(addr, func(_ Component) bool { return false }) + if !a.Equal(addr) || b != nil { + t.Error("should not have split") + } + } +}