Skip to content

Commit 1b423ec

Browse files
committed
Make Lookup return a Result
This makes it easier to extend without adding many different lookup methods.
1 parent 288208c commit 1b423ec

File tree

6 files changed

+85
-51
lines changed

6 files changed

+85
-51
lines changed

decoder_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func validateDecoding(t *testing.T, tests map[string]any) {
207207
for inputStr, expected := range tests {
208208
inputBytes, err := hex.DecodeString(inputStr)
209209
require.NoError(t, err)
210-
d := decoder{inputBytes}
210+
d := decoder{buffer: inputBytes}
211211

212212
var result any
213213
_, err = d.decode(0, reflect.ValueOf(&result), 0)
@@ -223,7 +223,7 @@ func validateDecoding(t *testing.T, tests map[string]any) {
223223
func TestPointers(t *testing.T) {
224224
bytes, err := os.ReadFile(testFile("maps-with-pointers.raw"))
225225
require.NoError(t, err)
226-
d := decoder{bytes}
226+
d := decoder{buffer: bytes}
227227

228228
expected := map[uint]map[string]string{
229229
0: {"long_key": "long_value1"},

deserializer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) {
1313
require.NoError(t, err, "unexpected error while opening database: %v", err)
1414

1515
dser := testDeserializer{}
16-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &dser)
16+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&dser)
1717
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
1818

1919
checkDecodingToInterface(t, dser.rv)

example_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func ExampleReader_Lookup_struct() {
2424
} `maxminddb:"country"`
2525
} // Or any appropriate struct
2626

27-
err = db.Lookup(addr, &record)
27+
err = db.Lookup(addr).Decode(&record)
2828
if err != nil {
2929
log.Panic(err)
3030
}
@@ -44,7 +44,7 @@ func ExampleReader_Lookup_interface() {
4444
addr := netip.MustParseAddr("81.2.69.142")
4545

4646
var record any
47-
err = db.Lookup(addr, &record)
47+
err = db.Lookup(addr).Decode(&record)
4848
if err != nil {
4949
log.Panic(err)
5050
}

reader.go

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func FromBytes(buffer []byte) (*Reader, error) {
6161
}
6262

6363
metadataStart += len(metadataStartMarker)
64-
metadataDecoder := decoder{buffer[metadataStart:]}
64+
metadataDecoder := decoder{buffer: buffer[metadataStart:]}
6565

6666
var metadata Metadata
6767

@@ -78,7 +78,7 @@ func FromBytes(buffer []byte) (*Reader, error) {
7878
return nil, newInvalidDatabaseError("the MaxMind DB contains invalid metadata")
7979
}
8080
d := decoder{
81-
buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)],
81+
buffer: buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)],
8282
}
8383

8484
nodeBuffer := buffer[:searchTreeSize]
@@ -131,15 +131,23 @@ func (r *Reader) setIPv4Start() {
131131
// because of type differences, an UnmarshalTypeError is returned. If the
132132
// database is invalid or otherwise cannot be read, an InvalidDatabaseError
133133
// is returned.
134-
func (r *Reader) Lookup(ip netip.Addr, result any) error {
134+
func (r *Reader) Lookup(ip netip.Addr) Result {
135135
if r.buffer == nil {
136-
return errors.New("cannot call Lookup on a closed database")
136+
return Result{err: errors.New("cannot call Lookup on a closed database")}
137137
}
138138
pointer, _, _, err := r.lookupPointer(ip)
139-
if pointer == 0 || err != nil {
140-
return err
139+
if err != nil {
140+
return Result{err: err}
141+
}
142+
if pointer == 0 {
143+
return Result{offset: notFound}
144+
}
145+
offset, err := r.resolveDataPointer(pointer)
146+
return Result{
147+
decoder: r.decoder,
148+
offset: uint(offset),
149+
err: err,
141150
}
142-
return r.retrieveData(pointer, result)
143151
}
144152

145153
// LookupNetwork retrieves the database record for ip and stores it in the
@@ -229,22 +237,8 @@ func (r *Reader) Decode(offset uintptr, result any) error {
229237
if r.buffer == nil {
230238
return errors.New("cannot call Decode on a closed database")
231239
}
232-
return r.decode(offset, result)
233-
}
234-
235-
func (r *Reader) decode(offset uintptr, result any) error {
236-
rv := reflect.ValueOf(result)
237-
if rv.Kind() != reflect.Ptr || rv.IsNil() {
238-
return errors.New("result param must be a pointer")
239-
}
240-
241-
if dser, ok := result.(deserializer); ok {
242-
_, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0, false)
243-
return err
244-
}
245240

246-
_, err := r.decoder.decode(uint(offset), rv, 0)
247-
return err
241+
return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result)
248242
}
249243

250244
func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) {
@@ -297,7 +291,7 @@ func (r *Reader) retrieveData(pointer uint, result any) error {
297291
if err != nil {
298292
return err
299293
}
300-
return r.decode(offset, result)
294+
return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result)
301295
}
302296

303297
func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) {

reader_test.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func TestDecodingToInterface(t *testing.T) {
212212
require.NoError(t, err, "unexpected error while opening database: %v", err)
213213

214214
var recordInterface any
215-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &recordInterface)
215+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&recordInterface)
216216
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
217217

218218
checkDecodingToInterface(t, recordInterface)
@@ -299,7 +299,7 @@ func TestDecoder(t *testing.T) {
299299
{
300300
// Directly lookup and decode.
301301
var result TestType
302-
require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result))
302+
require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result))
303303
verify(result)
304304
}
305305
{
@@ -330,7 +330,7 @@ func TestStructInterface(t *testing.T) {
330330
reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
331331
require.NoError(t, err)
332332

333-
require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result))
333+
require.NoError(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result))
334334

335335
assert.True(t, result.method())
336336
}
@@ -341,7 +341,7 @@ func TestNonEmptyNilInterface(t *testing.T) {
341341
reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
342342
require.NoError(t, err)
343343

344-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)
344+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)
345345
assert.Equal(
346346
t,
347347
"maxminddb: cannot unmarshal map into type maxminddb.TestInterface",
@@ -364,7 +364,7 @@ func TestEmbeddedStructAsInterface(t *testing.T) {
364364
db, err := Open(testFile("GeoIP2-ISP-Test.mmdb"))
365365
require.NoError(t, err)
366366

367-
require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0"), &result))
367+
require.NoError(t, db.Lookup(netip.MustParseAddr("1.128.0.0")).Decode(&result))
368368
}
369369

370370
type BoolInterface interface {
@@ -390,7 +390,7 @@ func TestValueTypeInterface(t *testing.T) {
390390

391391
// although it would be nice to support cases like this, I am not sure it
392392
// is possible to do so in a general way.
393-
assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result))
393+
assert.Error(t, reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result))
394394
}
395395

396396
type NestedMapX struct {
@@ -432,7 +432,7 @@ func TestComplexStructWithNestingAndPointer(t *testing.T) {
432432

433433
var result TestPointerType
434434

435-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)
435+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)
436436
require.NoError(t, err)
437437

438438
assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, *result.Array)
@@ -464,7 +464,7 @@ func TestNestedMapDecode(t *testing.T) {
464464

465465
var r map[string]map[string]any
466466

467-
require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128"), &r))
467+
require.NoError(t, db.Lookup(netip.MustParseAddr("89.160.20.128")).Decode(&r))
468468

469469
assert.Equal(
470470
t,
@@ -564,7 +564,7 @@ func TestDecodingUint16IntoInt(t *testing.T) {
564564
var result struct {
565565
Uint16 int `maxminddb:"uint16"`
566566
}
567-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &result)
567+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(&result)
568568
require.NoError(t, err)
569569

570570
assert.Equal(t, 100, result.Uint16)
@@ -575,7 +575,7 @@ func TestIpv6inIpv4(t *testing.T) {
575575
require.NoError(t, err, "unexpected error while opening database: %v", err)
576576

577577
var result TestType
578-
err = reader.Lookup(netip.MustParseAddr("2001::"), &result)
578+
err = reader.Lookup(netip.MustParseAddr("2001::")).Decode(&result)
579579

580580
var emptyResult TestType
581581
assert.Equal(t, emptyResult, result)
@@ -592,7 +592,7 @@ func TestBrokenDoubleDatabase(t *testing.T) {
592592
require.NoError(t, err, "unexpected error while opening database: %v", err)
593593

594594
var result any
595-
err = reader.Lookup(netip.MustParseAddr("2001:220::"), &result)
595+
err = reader.Lookup(netip.MustParseAddr("2001:220::")).Decode(&result)
596596

597597
expected := newInvalidDatabaseError(
598598
"the MaxMind DB file's data section contains bad data (float 64 size of 2)",
@@ -625,7 +625,7 @@ func TestDecodingToNonPointer(t *testing.T) {
625625
require.NoError(t, err)
626626

627627
var recordInterface any
628-
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), recordInterface)
628+
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0")).Decode(recordInterface)
629629
assert.Equal(t, "result param must be a pointer", err.Error())
630630
require.NoError(t, reader.Close(), "error on close")
631631
}
@@ -635,7 +635,7 @@ func TestDecodingToNonPointer(t *testing.T) {
635635
// require.NoError(t, err)
636636

637637
// var recordInterface any
638-
// err = reader.Lookup(nil, recordInterface)
638+
// err = reader.Lookup(nil).Decode( recordInterface)
639639
// assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error())
640640
// require.NoError(t, reader.Close(), "error on close")
641641
// }
@@ -647,7 +647,7 @@ func TestUsingClosedDatabase(t *testing.T) {
647647

648648
var recordInterface any
649649
addr := netip.MustParseAddr("::")
650-
err = reader.Lookup(addr, recordInterface)
650+
err = reader.Lookup(addr).Decode(recordInterface)
651651
assert.Equal(t, "cannot call Lookup on a closed database", err.Error())
652652

653653
_, err = reader.LookupOffset(addr)
@@ -688,7 +688,7 @@ func checkIpv4(t *testing.T, reader *Reader) {
688688
ip := netip.MustParseAddr(address)
689689

690690
var result map[string]string
691-
err := reader.Lookup(ip, &result)
691+
err := reader.Lookup(ip).Decode(&result)
692692
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
693693
assert.Equal(t, map[string]string{"ip": address}, result)
694694
}
@@ -708,7 +708,7 @@ func checkIpv4(t *testing.T, reader *Reader) {
708708
ip := netip.MustParseAddr(keyAddress)
709709

710710
var result map[string]string
711-
err := reader.Lookup(ip, &result)
711+
err := reader.Lookup(ip).Decode(&result)
712712
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
713713
assert.Equal(t, data, result)
714714
}
@@ -717,7 +717,7 @@ func checkIpv4(t *testing.T, reader *Reader) {
717717
ip := netip.MustParseAddr(address)
718718

719719
var result map[string]string
720-
err := reader.Lookup(ip, &result)
720+
err := reader.Lookup(ip).Decode(&result)
721721
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
722722
assert.Nil(t, result)
723723
}
@@ -731,7 +731,7 @@ func checkIpv6(t *testing.T, reader *Reader) {
731731

732732
for _, address := range subnets {
733733
var result map[string]string
734-
err := reader.Lookup(netip.MustParseAddr(address), &result)
734+
err := reader.Lookup(netip.MustParseAddr(address)).Decode(&result)
735735
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
736736
assert.Equal(t, map[string]string{"ip": address}, result)
737737
}
@@ -750,14 +750,14 @@ func checkIpv6(t *testing.T, reader *Reader) {
750750
for keyAddress, valueAddress := range pairs {
751751
data := map[string]string{"ip": valueAddress}
752752
var result map[string]string
753-
err := reader.Lookup(netip.MustParseAddr(keyAddress), &result)
753+
err := reader.Lookup(netip.MustParseAddr(keyAddress)).Decode(&result)
754754
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
755755
assert.Equal(t, data, result)
756756
}
757757

758758
for _, address := range []string{"1.1.1.33", "255.254.253.123", "89fa::"} {
759759
var result map[string]string
760-
err := reader.Lookup(netip.MustParseAddr(address), &result)
760+
err := reader.Lookup(netip.MustParseAddr(address)).Decode(&result)
761761
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
762762
assert.Nil(t, result)
763763
}
@@ -787,7 +787,7 @@ func BenchmarkInterfaceLookup(b *testing.B) {
787787
s := make(net.IP, 4)
788788
for i := 0; i < b.N; i++ {
789789
ip := randomIPv4Address(r, s)
790-
err = db.Lookup(ip, &result)
790+
err = db.Lookup(ip).Decode(&result)
791791
if err != nil {
792792
b.Error(err)
793793
}
@@ -875,7 +875,7 @@ func BenchmarkCityLookup(b *testing.B) {
875875
s := make(net.IP, 4)
876876
for i := 0; i < b.N; i++ {
877877
ip := randomIPv4Address(r, s)
878-
err = db.Lookup(ip, &result)
878+
err = db.Lookup(ip).Decode(&result)
879879
if err != nil {
880880
b.Error(err)
881881
}
@@ -919,7 +919,7 @@ func BenchmarkCountryCode(b *testing.B) {
919919
s := make(net.IP, 4)
920920
for i := 0; i < b.N; i++ {
921921
ip := randomIPv4Address(r, s)
922-
err = db.Lookup(ip, &result)
922+
err = db.Lookup(ip).Decode(&result)
923923
if err != nil {
924924
b.Error(err)
925925
}

result.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package maxminddb
2+
3+
import (
4+
"errors"
5+
"math"
6+
"reflect"
7+
)
8+
9+
const notFound uint = math.MaxUint
10+
11+
type Result struct {
12+
decoder decoder
13+
offset uint
14+
err error
15+
}
16+
17+
func (r Result) Decode(v any) error {
18+
if r.err != nil {
19+
return r.err
20+
}
21+
if r.offset == notFound {
22+
return nil
23+
}
24+
rv := reflect.ValueOf(v)
25+
if rv.Kind() != reflect.Ptr || rv.IsNil() {
26+
return errors.New("result param must be a pointer")
27+
}
28+
29+
if dser, ok := v.(deserializer); ok {
30+
_, err := r.decoder.decodeToDeserializer(r.offset, dser, 0, false)
31+
return err
32+
}
33+
34+
_, err := r.decoder.decode(r.offset, rv, 0)
35+
return err
36+
}
37+
38+
func (d Result) Found() bool {
39+
return d.err == nil && d.offset != notFound
40+
}

0 commit comments

Comments
 (0)