5
5
"bytes"
6
6
"errors"
7
7
"fmt"
8
- "net"
8
+ "net/netip "
9
9
"reflect"
10
10
)
11
11
@@ -110,6 +110,7 @@ func FromBytes(buffer []byte) (*Reader, error) {
110
110
111
111
func (r * Reader ) setIPv4Start () {
112
112
if r .Metadata .IPVersion != 6 {
113
+ r .ipv4StartBitDepth = 96
113
114
return
114
115
}
115
116
@@ -130,7 +131,7 @@ func (r *Reader) setIPv4Start() {
130
131
// because of type differences, an UnmarshalTypeError is returned. If the
131
132
// database is invalid or otherwise cannot be read, an InvalidDatabaseError
132
133
// is returned.
133
- func (r * Reader ) Lookup (ip net. IP , result any ) error {
134
+ func (r * Reader ) Lookup (ip netip. Addr , result any ) error {
134
135
if r .buffer == nil {
135
136
return errors .New ("cannot call Lookup on a closed database" )
136
137
}
@@ -142,7 +143,7 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
142
143
}
143
144
144
145
// LookupNetwork retrieves the database record for ip and stores it in the
145
- // value pointed to by result. The network returned is the network associated
146
+ // value pointed to by result. The prefix returned is the network associated
146
147
// with the data record in the database. The ok return value indicates whether
147
148
// the database contained a record for the ip.
148
149
//
@@ -151,28 +152,29 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
151
152
// UnmarshalTypeError is returned. If the database is invalid or otherwise
152
153
// cannot be read, an InvalidDatabaseError is returned.
153
154
func (r * Reader ) LookupNetwork (
154
- ip net. IP ,
155
+ ip netip. Addr ,
155
156
result any ,
156
- ) (network * net. IPNet , ok bool , err error ) {
157
+ ) (prefix netip. Prefix , ok bool , err error ) {
157
158
if r .buffer == nil {
158
- return nil , false , errors .New ("cannot call Lookup on a closed database" )
159
+ return netip. Prefix {} , false , errors .New ("cannot call Lookup on a closed database" )
159
160
}
160
161
pointer , prefixLength , ip , err := r .lookupPointer (ip )
162
+ // We return this error below as we want to return the prefix it is for
161
163
162
- network = r .cidr (ip , prefixLength )
163
- if pointer == 0 || err != nil {
164
- return network , false , err
164
+ prefix , errP : = r .cidr (ip , prefixLength )
165
+ if pointer == 0 || err != nil || errP != nil {
166
+ return prefix , false , errors . Join ( err , errP )
165
167
}
166
168
167
- return network , true , r .retrieveData (pointer , result )
169
+ return prefix , true , r .retrieveData (pointer , result )
168
170
}
169
171
170
172
// LookupOffset maps an argument net.IP to a corresponding record offset in the
171
173
// database. NotFound is returned if no such record is found, and a record may
172
174
// otherwise be extracted by passing the returned offset to Decode. LookupOffset
173
175
// is an advanced API, which exists to provide clients with a means to cache
174
176
// previously-decoded records.
175
- func (r * Reader ) LookupOffset (ip net. IP ) (uintptr , error ) {
177
+ func (r * Reader ) LookupOffset (ip netip. Addr ) (uintptr , error ) {
176
178
if r .buffer == nil {
177
179
return 0 , errors .New ("cannot call LookupOffset on a closed database" )
178
180
}
@@ -183,22 +185,28 @@ func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
183
185
return r .resolveDataPointer (pointer )
184
186
}
185
187
186
- func (r * Reader ) cidr (ip net.IP , prefixLength int ) * net.IPNet {
187
- // This is necessary as the node that the IPv4 start is at may
188
- // be at a bit depth that is less that 96, i.e., ipv4Start points
189
- // to a leaf node. For instance, if a record was inserted at ::/8,
190
- // the ipv4Start would point directly at the leaf node for the
191
- // record and would have a bit depth of 8. This would not happen
192
- // with databases currently distributed by MaxMind as all of them
193
- // have an IPv4 subtree that is greater than a single node.
194
- if r .Metadata .IPVersion == 6 &&
195
- len (ip ) == net .IPv4len &&
196
- r .ipv4StartBitDepth != 96 {
197
- return & net.IPNet {IP : net .ParseIP ("::" ), Mask : net .CIDRMask (r .ipv4StartBitDepth , 128 )}
188
+ var zeroIP = netip .MustParseAddr ("::" )
189
+
190
+ func (r * Reader ) cidr (ip netip.Addr , prefixLength int ) (netip.Prefix , error ) {
191
+ if ip .Is4 () {
192
+ // This is necessary as the node that the IPv4 start is at may
193
+ // be at a bit depth that is less that 96, i.e., ipv4Start points
194
+ // to a leaf node. For instance, if a record was inserted at ::/8,
195
+ // the ipv4Start would point directly at the leaf node for the
196
+ // record and would have a bit depth of 8. This would not happen
197
+ // with databases currently distributed by MaxMind as all of them
198
+ // have an IPv4 subtree that is greater than a single node.
199
+ if r .Metadata .IPVersion == 6 && r .ipv4StartBitDepth != 96 {
200
+ return netip .PrefixFrom (zeroIP , r .ipv4StartBitDepth ), nil
201
+ }
202
+ prefixLength -= 96
198
203
}
199
204
200
- mask := net .CIDRMask (prefixLength , len (ip )* 8 )
201
- return & net.IPNet {IP : ip .Mask (mask ), Mask : mask }
205
+ prefix , err := ip .Prefix (prefixLength )
206
+ if err != nil {
207
+ return netip.Prefix {}, fmt .Errorf ("creating prefix from %s/%d: %w" , ip , prefixLength , err )
208
+ }
209
+ return prefix , nil
202
210
}
203
211
204
212
// Decode the record at |offset| into |result|. The result value pointed to
@@ -239,29 +247,15 @@ func (r *Reader) decode(offset uintptr, result any) error {
239
247
return err
240
248
}
241
249
242
- func (r * Reader ) lookupPointer (ip net.IP ) (uint , int , net.IP , error ) {
243
- if ip == nil {
244
- return 0 , 0 , nil , errors .New ("IP passed to Lookup cannot be nil" )
245
- }
246
-
247
- ipV4Address := ip .To4 ()
248
- if ipV4Address != nil {
249
- ip = ipV4Address
250
- }
251
- if len (ip ) == 16 && r .Metadata .IPVersion == 4 {
250
+ func (r * Reader ) lookupPointer (ip netip.Addr ) (uint , int , netip.Addr , error ) {
251
+ if r .Metadata .IPVersion == 4 && ip .Is6 () {
252
252
return 0 , 0 , ip , fmt .Errorf (
253
253
"error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database" ,
254
254
ip .String (),
255
255
)
256
256
}
257
257
258
- bitCount := uint (len (ip ) * 8 )
259
-
260
- var node uint
261
- if bitCount == 32 {
262
- node = r .ipv4Start
263
- }
264
- node , prefixLength := r .traverseTree (ip , node , bitCount )
258
+ node , prefixLength := r .traverseTree (ip , 0 , 128 )
265
259
266
260
nodeCount := r .Metadata .NodeCount
267
261
if node == nodeCount {
@@ -274,12 +268,18 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
274
268
return 0 , prefixLength , ip , newInvalidDatabaseError ("invalid node in search tree" )
275
269
}
276
270
277
- func (r * Reader ) traverseTree (ip net.IP , node , bitCount uint ) (uint , int ) {
271
+ func (r * Reader ) traverseTree (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
272
+ i := 0
273
+ if ip .Is4 () {
274
+ i = r .ipv4StartBitDepth
275
+ node = r .ipv4Start
276
+ }
278
277
nodeCount := r .Metadata .NodeCount
279
278
280
- i := uint (0 )
281
- for ; i < bitCount && node < nodeCount ; i ++ {
282
- bit := uint (1 ) & (uint (ip [i >> 3 ]) >> (7 - (i % 8 )))
279
+ bytes := ip .As16 ()
280
+
281
+ for ; i < stopBit && node < nodeCount ; i ++ {
282
+ bit := uint (1 ) & (uint (bytes [i >> 3 ]) >> (7 - (i % 8 )))
283
283
284
284
offset := node * r .nodeOffsetMult
285
285
if bit == 0 {
@@ -289,7 +289,7 @@ func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
289
289
}
290
290
}
291
291
292
- return node , int ( i )
292
+ return node , i
293
293
}
294
294
295
295
func (r * Reader ) retrieveData (pointer uint , result any ) error {
0 commit comments