From beb721f806458b718794f196720e82412dbdf49b Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Fri, 28 Jun 2024 07:37:22 -0700 Subject: [PATCH 1/3] Use subtests --- traverse_test.go | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/traverse_test.go b/traverse_test.go index 62e3547..4756f1b 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -234,28 +234,37 @@ var tests = []networkTest{ func TestNetworksWithin(t *testing.T) { for _, v := range tests { for _, recordSize := range []uint{24, 28, 32} { - fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) - reader, err := Open(fileName) - require.NoError(t, err, "unexpected error while opening database: %v", err) - - _, network, err := net.ParseCIDR(v.Network) - require.NoError(t, err) - n := reader.NetworksWithin(network, v.Options...) - var innerIPs []string + name := fmt.Sprintf( + "%s-%d: %s, options: %v", + v.Database, + recordSize, + v.Network, + len(v.Options) != 0, + ) + t.Run(name, func(t *testing.T) { + fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) + reader, err := Open(fileName) + require.NoError(t, err, "unexpected error while opening database: %v", err) - for n.Next() { - record := struct { - IP string `maxminddb:"ip"` - }{} - network, err := n.Network(&record) + _, network, err := net.ParseCIDR(v.Network) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) - } + n := reader.NetworksWithin(network, v.Options...) + var innerIPs []string - assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + require.NoError(t, err) + innerIPs = append(innerIPs, network.String()) + } - require.NoError(t, reader.Close()) + assert.Equal(t, v.Expected, innerIPs) + require.NoError(t, n.Err()) + + require.NoError(t, reader.Close()) + }) } } } From 81d3c39d0f045f99e040513391051ac1007a92b6 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Fri, 28 Jun 2024 07:47:10 -0700 Subject: [PATCH 2/3] Set network number to first IP in network when looking up a network that is more specific than the network in the database using NetworksWithin. Previously, the network number would be set to the network number of the provided *net.IPNet. This changes it to the canonical form. --- traverse.go | 4 ++++ traverse_test.go | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/traverse.go b/traverse.go index 657e2c4..faf2e70 100644 --- a/traverse.go +++ b/traverse.go @@ -95,6 +95,10 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) * } pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) + + if bit < prefixLength { + ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) + } networks.nodes = []netNode{ { ip: ip, diff --git a/traverse_test.go b/traverse_test.go index 4756f1b..0248243 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -85,6 +85,27 @@ var tests = []networkTest{ "1.1.1.1/32", }, }, + { + Network: "1.1.1.2/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, + { + Network: "1.1.1.3/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, + { + Network: "1.1.1.19/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.16/28", + }, + }, { Network: "255.255.255.0/24", Database: "ipv4", From b2df6c39f0cd1f30e8141b9d279237423128896e Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Fri, 28 Jun 2024 08:25:43 -0700 Subject: [PATCH 3/3] Handle input values that are not in canonical form --- traverse.go | 9 ++++++--- traverse_test.go | 27 ++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/traverse.go b/traverse.go index faf2e70..90073e2 100644 --- a/traverse.go +++ b/traverse.go @@ -96,9 +96,12 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) * pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) - if bit < prefixLength { - ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) - } + // We could skip this when bit >= prefixLength if we assume that the network + // passed in is in canonical form. However, given that this may not be the + // case, it is safest to always take the mask. If this is hot code at some + // point, we could eliminate the allocation of the net.IPMask by zeroing + // out the bits in ip directly. + ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) networks.nodes = []netNode{ { ip: ip, diff --git a/traverse_test.go b/traverse_test.go index 0248243..00edfce 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -3,6 +3,8 @@ package maxminddb import ( "fmt" "net" + "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -71,6 +73,8 @@ var tests = []networkTest{ }, }, { + // This is intentionally in non-canonical form to test + // that we handle it correctly. Network: "1.1.1.1/30", Database: "ipv4", Expected: []string{ @@ -78,6 +82,13 @@ var tests = []networkTest{ "1.1.1.2/31", }, }, + { + Network: "1.1.1.2/31", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, { Network: "1.1.1.1/32", Database: "ipv4", @@ -267,7 +278,21 @@ func TestNetworksWithin(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - _, network, err := net.ParseCIDR(v.Network) + // We are purposely not using net.ParseCIDR so that we can pass in + // values that aren't in canonical form. + parts := strings.Split(v.Network, "/") + ip := net.ParseIP(parts[0]) + if v := ip.To4(); v != nil { + ip = v + } + prefixLength, err := strconv.Atoi(parts[1]) + require.NoError(t, err) + mask := net.CIDRMask(prefixLength, len(ip)*8) + network := &net.IPNet{ + IP: ip, + Mask: mask, + } + require.NoError(t, err) n := reader.NetworksWithin(network, v.Options...) var innerIPs []string