Skip to content

Commit 72707a6

Browse files
committed
optimize ScanColumnsToStruct
1 parent 10b47d4 commit 72707a6

File tree

2 files changed

+82
-32
lines changed

2 files changed

+82
-32
lines changed

dml_select.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ func (b *SelectBuilder) SelectStructWithTable(s any, table string) *SelectBuilde
217217
defer ttlock.Unlock()
218218

219219
columntables = typetables.Load().(map[typetable][]string)
220-
columns, ok = columntables[key]
221-
if !ok {
220+
if columns, ok = columntables[key]; !ok {
222221
columns = b.getColumnsFromStruct(s, table)
223222

224223
_columntables := make(map[typetable][]string, len(columntables)+1)

dml_select_row_scanner_struct.go

Lines changed: 81 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
package sqlx
1616

1717
import (
18+
"database/sql"
1819
"database/sql/driver"
1920
"fmt"
21+
"maps"
2022
"reflect"
2123
"strings"
24+
"sync"
25+
"sync/atomic"
2226
"time"
2327
)
2428

@@ -33,41 +37,86 @@ func ScanColumnsToStruct(scan func(...any) error, columns []string, s any) (err
3337
panic("sqlx.ScanColumnsToStruct: no selected columns")
3438
}
3539

36-
fields := getFields(s)
37-
vs := make([]any, len(columns))
38-
for i, c := range columns {
39-
if _, ok := fields[c]; ok {
40-
vs[i] = fields[c].Addr().Interface()
41-
} else {
42-
vs[i] = new(any)
40+
value := reflect.ValueOf(s)
41+
extract := getscannerextractfunc(value.Type())
42+
scanners := make([]any, len(columns))
43+
extract(value, scanners, columns)
44+
return scan(scanners...)
45+
}
46+
47+
func getscannerextractfunc(vtype reflect.Type) scannerExtractFunc {
48+
extract, ok := _scannedstructmaps.Load().(map[reflect.Type]scannerExtractFunc)[vtype]
49+
if !ok {
50+
_scannedstructlock.Lock()
51+
defer _scannedstructlock.Unlock()
52+
53+
types := _scannedstructmaps.Load().(map[reflect.Type]scannerExtractFunc)
54+
if extract, ok = types[vtype]; !ok {
55+
extract = getScannerFieldsFromStruct(vtype)
56+
57+
newtypes := make(map[reflect.Type]scannerExtractFunc, len(types)+1)
58+
maps.Copy(newtypes, types)
59+
newtypes[vtype] = extract
60+
61+
_scannedstructmaps.Store(newtypes)
4362
}
4463
}
45-
return scan(vs...)
64+
return extract
4665
}
4766

48-
func getFields(s any) map[string]reflect.Value {
49-
v := reflect.ValueOf(s)
50-
if v.Kind() != reflect.Ptr {
51-
panic("not a pointer to struct")
52-
} else if v = v.Elem(); v.Kind() != reflect.Struct {
53-
panic("not a pointer to struct")
67+
type scannerExtractFunc func(value reflect.Value, scanners []any, columns []string)
68+
69+
var (
70+
_scannedstructlock sync.Mutex
71+
_scannedstructmaps atomic.Value // map[reflect.Type]scannerExtractFunc
72+
)
73+
74+
func init() {
75+
_scannedstructmaps.Store(map[reflect.Type]scannerExtractFunc(nil))
76+
}
77+
78+
func getScannerFieldsFromStruct(vtype reflect.Type) scannerExtractFunc {
79+
if vtype.Kind() != reflect.Ptr {
80+
panic("sqlx.ScanColumnsToStruct: not a pointer to struct")
81+
} else if vtype = vtype.Elem(); vtype.Kind() != reflect.Struct {
82+
panic("sqlx.ScanColumnsToStruct: not a pointer to struct")
5483
}
5584

56-
vs := make(map[string]reflect.Value, v.NumField())
57-
getFieldsFromStruct("", v, vs)
58-
return vs
85+
fields := make(map[string]scannedfield, 16)
86+
_getScannerFieldsFromStruct(fields, vtype, "", nil)
87+
88+
return func(value reflect.Value, scanners []any, columns []string) {
89+
value = value.Elem()
90+
for i, column := range columns {
91+
if field, ok := fields[column]; ok && column != "deleted_at" {
92+
scanners[i] = field.Scanner(value)
93+
} else {
94+
scanners[i] = nullScanner{}
95+
}
96+
}
97+
}
98+
}
99+
100+
type scannedfield struct {
101+
Indexes []int
59102
}
60103

61-
func getFieldsFromStruct(prefix string, v reflect.Value, vs map[string]reflect.Value) {
62-
vt := v.Type()
63-
_len := v.NumField()
104+
func (f scannedfield) Scanner(value reflect.Value) sql.Scanner {
105+
for _, index := range f.Indexes {
106+
value = value.Field(index)
107+
}
108+
return nullScanner{Value: value.Addr().Interface()}
109+
}
110+
111+
func _getScannerFieldsFromStruct(fields map[string]scannedfield, vtype reflect.Type, prefix string, indexes []int) {
112+
_len := vtype.NumField()
64113

65114
LOOP:
66115
for i := 0; i < _len; i++ {
67-
vft := vt.Field(i)
116+
ftype := vtype.Field(i)
68117

69118
var targs string
70-
tname := vft.Tag.Get("sql")
119+
tname := ftype.Tag.Get("sql")
71120
if index := strings.IndexByte(tname, ','); index > -1 {
72121
targs = tname[index+1:]
73122
tname = strings.TrimSpace(tname[:index])
@@ -77,29 +126,31 @@ LOOP:
77126
continue
78127
}
79128

80-
name := vft.Name
129+
name := ftype.Name
81130
if tname != "" {
82131
name = tname
83132
}
84133

85-
vf := v.Field(i)
86-
if vft.Type.Kind() == reflect.Struct {
134+
_indexes := make([]int, 0, len(indexes)+1)
135+
_indexes = append(_indexes, indexes...)
136+
_indexes = append(_indexes, i)
137+
138+
if ftype.Type.Kind() == reflect.Struct {
87139
if tagContainAttr(targs, "notpropagate") {
88140
continue
89141
}
90142

91-
switch vf.Interface().(type) {
143+
fvalue := reflect.New(ftype.Type).Elem()
144+
switch fvalue.Interface().(type) {
92145
case time.Time:
93146
case driver.Valuer:
94147
default:
95-
getFieldsFromStruct(formatFieldName(prefix, tname), vf, vs)
148+
_getScannerFieldsFromStruct(fields, ftype.Type, formatFieldName(prefix, tname), _indexes)
96149
continue LOOP
97150
}
98151
}
99152

100-
if vf.CanSet() {
101-
vs[formatFieldName(prefix, name)] = v.Field(i)
102-
}
153+
fields[formatFieldName(prefix, name)] = scannedfield{Indexes: _indexes}
103154
}
104155
}
105156

0 commit comments

Comments
 (0)