@@ -17,20 +17,21 @@ var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v")
17
17
var ErrInsertIntoNonexistentColumn = errors .NewKind ("invalid column name %v" )
18
18
var ErrInsertIntoNonNullableDefaultNullColumn = errors .NewKind ("column name '%v' is non-nullable but attempted to set default value of null" )
19
19
var ErrInsertIntoNonNullableProvidedNull = errors .NewKind ("column name '%v' is non-nullable but attempted to set a value of null" )
20
+ var ErrInsertIntoIncompatibleTypes = errors .NewKind ("cannot convert type %s to %s" )
20
21
21
22
// InsertInto is a node describing the insertion into some table.
22
23
type InsertInto struct {
23
24
BinaryNode
24
- Columns []string
25
- IsReplace bool
25
+ ColumnNames []string
26
+ IsReplace bool
26
27
}
27
28
28
29
// NewInsertInto creates an InsertInto node.
29
30
func NewInsertInto (dst , src sql.Node , isReplace bool , cols []string ) * InsertInto {
30
31
return & InsertInto {
31
- BinaryNode : BinaryNode {Left : dst , Right : src },
32
- Columns : cols ,
33
- IsReplace : isReplace ,
32
+ BinaryNode : BinaryNode {Left : dst , Right : src },
33
+ ColumnNames : cols ,
34
+ IsReplace : isReplace ,
34
35
}
35
36
}
36
37
@@ -83,13 +84,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
83
84
}
84
85
85
86
dstSchema := p .Left .Schema ()
86
- projExprs := make ([]sql.Expression , len (dstSchema ))
87
87
88
88
// If no columns are given, we assume the full schema in order
89
- if len (p .Columns ) == 0 {
90
- p .Columns = make ([]string , len (dstSchema ))
89
+ if len (p .ColumnNames ) == 0 {
90
+ p .ColumnNames = make ([]string , len (dstSchema ))
91
91
for i , f := range dstSchema {
92
- p .Columns [i ] = f .Name
92
+ p .ColumnNames [i ] = f .Name
93
93
}
94
94
} else {
95
95
err = p .validateColumns (dstSchema )
@@ -103,9 +103,10 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
103
103
return 0 , err
104
104
}
105
105
106
+ projExprs := make ([]sql.Expression , len (dstSchema ))
106
107
for i , f := range dstSchema {
107
108
found := false
108
- for j , col := range p .Columns {
109
+ for j , col := range p .ColumnNames {
109
110
if f .Name == col {
110
111
projExprs [i ] = expression .NewGetField (j , f .Type , f .Name , f .Nullable )
111
112
found = true
@@ -121,9 +122,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
121
122
}
122
123
}
123
124
124
- proj := NewProject (projExprs , p .Right )
125
+ rowSource , err := p .rowSource (projExprs )
126
+ if err != nil {
127
+ return 0 , err
128
+ }
125
129
126
- iter , err := proj .RowIter (ctx )
130
+ iter , err := rowSource .RowIter (ctx )
127
131
if err != nil {
128
132
return 0 , err
129
133
}
@@ -145,11 +149,11 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
145
149
return i , err
146
150
}
147
151
148
- // Convert integer values in row to specified type in schema
152
+ // Convert values to the destination schema type
149
153
for colIdx , oldValue := range row {
150
154
dstColType := projExprs [colIdx ].Type ()
151
155
152
- if sql . IsInteger ( dstColType ) && oldValue != nil {
156
+ if oldValue != nil {
153
157
newValue , err := dstColType .Convert (oldValue )
154
158
if err != nil {
155
159
return i , err
@@ -185,6 +189,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
185
189
return i , nil
186
190
}
187
191
192
+ func (p * InsertInto ) rowSource (projExprs []sql.Expression ) (sql.Node , error ) {
193
+ switch n := p .Right .(type ) {
194
+ case * Values :
195
+ return NewProject (projExprs , n ), nil
196
+ case * ResolvedTable , * Project , * InnerJoin :
197
+ if err := assertCompatibleSchemas (projExprs , n .Schema ()); err != nil {
198
+ return nil , err
199
+ }
200
+ return NewProject (projExprs , n ), nil
201
+ default :
202
+ return nil , ErrInsertIntoUnsupportedValues .New (n )
203
+ }
204
+ }
205
+
188
206
// RowIter implements the Node interface.
189
207
func (p * InsertInto ) RowIter (ctx * sql.Context ) (sql.RowIter , error ) {
190
208
n , err := p .Execute (ctx )
@@ -201,15 +219,15 @@ func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
201
219
return nil , sql .ErrInvalidChildrenNumber .New (p , len (children ), 2 )
202
220
}
203
221
204
- return NewInsertInto (children [0 ], children [1 ], p .IsReplace , p .Columns ), nil
222
+ return NewInsertInto (children [0 ], children [1 ], p .IsReplace , p .ColumnNames ), nil
205
223
}
206
224
207
225
func (p InsertInto ) String () string {
208
226
pr := sql .NewTreePrinter ()
209
227
if p .IsReplace {
210
- _ = pr .WriteNode ("Replace(%s)" , strings .Join (p .Columns , ", " ))
228
+ _ = pr .WriteNode ("Replace(%s)" , strings .Join (p .ColumnNames , ", " ))
211
229
} else {
212
- _ = pr .WriteNode ("Insert(%s)" , strings .Join (p .Columns , ", " ))
230
+ _ = pr .WriteNode ("Insert(%s)" , strings .Join (p .ColumnNames , ", " ))
213
231
}
214
232
_ = pr .WriteChildren (p .Left .String (), p .Right .String ())
215
233
return pr .String ()
@@ -219,16 +237,12 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
219
237
switch node := p .Right .(type ) {
220
238
case * Values :
221
239
for _ , exprTuple := range node .ExpressionTuples {
222
- if len (exprTuple ) != len (p .Columns ) {
240
+ if len (exprTuple ) != len (p .ColumnNames ) {
223
241
return ErrInsertIntoMismatchValueCount .New ()
224
242
}
225
243
}
226
- case * ResolvedTable :
227
- return p .assertSchemasMatch (node .Schema ())
228
- case * Project :
229
- return p .assertSchemasMatch (node .Schema ())
230
- case * InnerJoin :
231
- return p .assertSchemasMatch (node .Schema ())
244
+ case * ResolvedTable , * Project , * InnerJoin :
245
+ return p .assertColumnCountsMatch (node .Schema ())
232
246
default :
233
247
return ErrInsertIntoUnsupportedValues .New (node )
234
248
}
@@ -241,7 +255,7 @@ func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
241
255
dstColNames [dstCol .Name ] = struct {}{}
242
256
}
243
257
columnNames := make (map [string ]struct {})
244
- for _ , columnName := range p .Columns {
258
+ for _ , columnName := range p .ColumnNames {
245
259
if _ , exists := dstColNames [columnName ]; ! exists {
246
260
return ErrInsertIntoNonexistentColumn .New (columnName )
247
261
}
@@ -263,9 +277,27 @@ func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) erro
263
277
return nil
264
278
}
265
279
266
- func (p * InsertInto ) assertSchemasMatch (schema sql.Schema ) error {
267
- if len (p .Columns ) != len (schema ) {
280
+ func (p * InsertInto ) assertColumnCountsMatch (schema sql.Schema ) error {
281
+ if len (p .ColumnNames ) != len (schema ) {
268
282
return ErrInsertIntoMismatchValueCount .New ()
269
283
}
270
284
return nil
271
285
}
286
+
287
+ func assertCompatibleSchemas (projExprs []sql.Expression , schema sql.Schema ) error {
288
+ for _ , expr := range projExprs {
289
+ switch e := expr .(type ) {
290
+ case * expression.Literal :
291
+ continue
292
+ case * expression.GetField :
293
+ otherCol := schema [e .Index ()]
294
+ _ , err := otherCol .Type .Convert (expr .Type ().Zero ())
295
+ if err != nil {
296
+ return ErrInsertIntoIncompatibleTypes .New (otherCol .Type .String (), expr .Type ().String ())
297
+ }
298
+ default :
299
+ return ErrInsertIntoUnsupportedValues .New (expr )
300
+ }
301
+ }
302
+ return nil
303
+ }
0 commit comments