Skip to content

Commit da62ba2

Browse files
committed
Switch to net/netip
The preformance of this is approximately the same as the net.IP version, except for the methods that return a network. For those, there is a slight improvement.
1 parent 616cde2 commit da62ba2

File tree

7 files changed

+207
-206
lines changed

7 files changed

+207
-206
lines changed

deserializer_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package maxminddb
22

33
import (
44
"math/big"
5-
"net"
5+
"net/netip"
66
"testing"
77

88
"github.com/stretchr/testify/require"
@@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) {
1313
require.NoError(t, err, "unexpected error while opening database: %v", err)
1414

1515
dser := testDeserializer{}
16-
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser)
16+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &dser)
1717
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
1818

1919
checkDecodingToInterface(t, dser.rv)

example_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package maxminddb_test
33
import (
44
"fmt"
55
"log"
6-
"net"
6+
"net/netip"
77

8-
"github.com/oschwald/maxminddb-golang"
8+
"github.com/oschwald/maxminddb-golang/v2"
99
)
1010

1111
// This example shows how to decode to a struct.
@@ -16,15 +16,15 @@ func ExampleReader_Lookup_struct() {
1616
}
1717
defer db.Close()
1818

19-
ip := net.ParseIP("81.2.69.142")
19+
addr := netip.MustParseAddr("81.2.69.142")
2020

2121
var record struct {
2222
Country struct {
2323
ISOCode string `maxminddb:"iso_code"`
2424
} `maxminddb:"country"`
2525
} // Or any appropriate struct
2626

27-
err = db.Lookup(ip, &record)
27+
err = db.Lookup(addr, &record)
2828
if err != nil {
2929
log.Panic(err)
3030
}
@@ -41,10 +41,10 @@ func ExampleReader_Lookup_interface() {
4141
}
4242
defer db.Close()
4343

44-
ip := net.ParseIP("81.2.69.142")
44+
addr := netip.MustParseAddr("81.2.69.142")
4545

4646
var record any
47-
err = db.Lookup(ip, &record)
47+
err = db.Lookup(addr, &record)
4848
if err != nil {
4949
log.Panic(err)
5050
}
@@ -118,12 +118,12 @@ func ExampleReader_NetworksWithin() {
118118
Domain string `maxminddb:"connection_type"`
119119
}{}
120120

121-
_, network, err := net.ParseCIDR("1.0.0.0/8")
121+
prefix, err := netip.ParsePrefix("1.0.0.0/8")
122122
if err != nil {
123123
log.Panic(err)
124124
}
125125

126-
networks := db.NetworksWithin(network, maxminddb.SkipAliasedNetworks)
126+
networks := db.NetworksWithin(prefix, maxminddb.SkipAliasedNetworks)
127127
for networks.Next() {
128128
subnet, err := networks.Network(&record)
129129
if err != nil {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module github.com/oschwald/maxminddb-golang
1+
module github.com/oschwald/maxminddb-golang/v2
22

33
go 1.21
44

reader.go

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"bytes"
66
"errors"
77
"fmt"
8-
"net"
8+
"net/netip"
99
"reflect"
1010
)
1111

@@ -110,6 +110,7 @@ func FromBytes(buffer []byte) (*Reader, error) {
110110

111111
func (r *Reader) setIPv4Start() {
112112
if r.Metadata.IPVersion != 6 {
113+
r.ipv4StartBitDepth = 96
113114
return
114115
}
115116

@@ -130,7 +131,7 @@ func (r *Reader) setIPv4Start() {
130131
// because of type differences, an UnmarshalTypeError is returned. If the
131132
// database is invalid or otherwise cannot be read, an InvalidDatabaseError
132133
// is returned.
133-
func (r *Reader) Lookup(ip net.IP, result any) error {
134+
func (r *Reader) Lookup(ip netip.Addr, result any) error {
134135
if r.buffer == nil {
135136
return errors.New("cannot call Lookup on a closed database")
136137
}
@@ -142,7 +143,7 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
142143
}
143144

144145
// LookupNetwork retrieves the database record for ip and stores it in the
145-
// value pointed to by result. The network returned is the network associated
146+
// value pointed to by result. The prefix returned is the network associated
146147
// with the data record in the database. The ok return value indicates whether
147148
// the database contained a record for the ip.
148149
//
@@ -151,28 +152,29 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
151152
// UnmarshalTypeError is returned. If the database is invalid or otherwise
152153
// cannot be read, an InvalidDatabaseError is returned.
153154
func (r *Reader) LookupNetwork(
154-
ip net.IP,
155+
ip netip.Addr,
155156
result any,
156-
) (network *net.IPNet, ok bool, err error) {
157+
) (prefix netip.Prefix, ok bool, err error) {
157158
if r.buffer == nil {
158-
return nil, false, errors.New("cannot call Lookup on a closed database")
159+
return netip.Prefix{}, false, errors.New("cannot call Lookup on a closed database")
159160
}
160161
pointer, prefixLength, ip, err := r.lookupPointer(ip)
162+
// We return this error below as we want to return the prefix it is for
161163

162-
network = r.cidr(ip, prefixLength)
163-
if pointer == 0 || err != nil {
164-
return network, false, err
164+
prefix, errP := r.cidr(ip, prefixLength)
165+
if pointer == 0 || err != nil || errP != nil {
166+
return prefix, false, errors.Join(err, errP)
165167
}
166168

167-
return network, true, r.retrieveData(pointer, result)
169+
return prefix, true, r.retrieveData(pointer, result)
168170
}
169171

170172
// LookupOffset maps an argument net.IP to a corresponding record offset in the
171173
// database. NotFound is returned if no such record is found, and a record may
172174
// otherwise be extracted by passing the returned offset to Decode. LookupOffset
173175
// is an advanced API, which exists to provide clients with a means to cache
174176
// previously-decoded records.
175-
func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
177+
func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) {
176178
if r.buffer == nil {
177179
return 0, errors.New("cannot call LookupOffset on a closed database")
178180
}
@@ -183,22 +185,28 @@ func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
183185
return r.resolveDataPointer(pointer)
184186
}
185187

186-
func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet {
187-
// This is necessary as the node that the IPv4 start is at may
188-
// be at a bit depth that is less that 96, i.e., ipv4Start points
189-
// to a leaf node. For instance, if a record was inserted at ::/8,
190-
// the ipv4Start would point directly at the leaf node for the
191-
// record and would have a bit depth of 8. This would not happen
192-
// with databases currently distributed by MaxMind as all of them
193-
// have an IPv4 subtree that is greater than a single node.
194-
if r.Metadata.IPVersion == 6 &&
195-
len(ip) == net.IPv4len &&
196-
r.ipv4StartBitDepth != 96 {
197-
return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)}
188+
var zeroIP = netip.MustParseAddr("::")
189+
190+
func (r *Reader) cidr(ip netip.Addr, prefixLength int) (netip.Prefix, error) {
191+
if ip.Is4() {
192+
// This is necessary as the node that the IPv4 start is at may
193+
// be at a bit depth that is less that 96, i.e., ipv4Start points
194+
// to a leaf node. For instance, if a record was inserted at ::/8,
195+
// the ipv4Start would point directly at the leaf node for the
196+
// record and would have a bit depth of 8. This would not happen
197+
// with databases currently distributed by MaxMind as all of them
198+
// have an IPv4 subtree that is greater than a single node.
199+
if r.Metadata.IPVersion == 6 && r.ipv4StartBitDepth != 96 {
200+
return netip.PrefixFrom(zeroIP, r.ipv4StartBitDepth), nil
201+
}
202+
prefixLength -= 96
198203
}
199204

200-
mask := net.CIDRMask(prefixLength, len(ip)*8)
201-
return &net.IPNet{IP: ip.Mask(mask), Mask: mask}
205+
prefix, err := ip.Prefix(prefixLength)
206+
if err != nil {
207+
return netip.Prefix{}, fmt.Errorf("creating prefix from %s/%d: %w", ip, prefixLength, err)
208+
}
209+
return prefix, nil
202210
}
203211

204212
// Decode the record at |offset| into |result|. The result value pointed to
@@ -239,29 +247,15 @@ func (r *Reader) decode(offset uintptr, result any) error {
239247
return err
240248
}
241249

242-
func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
243-
if ip == nil {
244-
return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil")
245-
}
246-
247-
ipV4Address := ip.To4()
248-
if ipV4Address != nil {
249-
ip = ipV4Address
250-
}
251-
if len(ip) == 16 && r.Metadata.IPVersion == 4 {
250+
func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) {
251+
if r.Metadata.IPVersion == 4 && ip.Is6() {
252252
return 0, 0, ip, fmt.Errorf(
253253
"error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database",
254254
ip.String(),
255255
)
256256
}
257257

258-
bitCount := uint(len(ip) * 8)
259-
260-
var node uint
261-
if bitCount == 32 {
262-
node = r.ipv4Start
263-
}
264-
node, prefixLength := r.traverseTree(ip, node, bitCount)
258+
node, prefixLength := r.traverseTree(ip, 0, 128)
265259

266260
nodeCount := r.Metadata.NodeCount
267261
if node == nodeCount {
@@ -274,12 +268,18 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
274268
return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree")
275269
}
276270

277-
func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
271+
func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) {
272+
i := 0
273+
if ip.Is4() {
274+
i = r.ipv4StartBitDepth
275+
node = r.ipv4Start
276+
}
278277
nodeCount := r.Metadata.NodeCount
279278

280-
i := uint(0)
281-
for ; i < bitCount && node < nodeCount; i++ {
282-
bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8)))
279+
bytes := ip.As16()
280+
281+
for ; i < stopBit && node < nodeCount; i++ {
282+
bit := uint(1) & (uint(bytes[i>>3]) >> (7 - (i % 8)))
283283

284284
offset := node * r.nodeOffsetMult
285285
if bit == 0 {
@@ -289,7 +289,7 @@ func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
289289
}
290290
}
291291

292-
return node, int(i)
292+
return node, i
293293
}
294294

295295
func (r *Reader) retrieveData(pointer uint, result any) error {

0 commit comments

Comments
 (0)