Skip to content

Commit 17eca8e

Browse files
authored
Merge pull request #24 from liquidata-inc/zachmu/insert-into-select
Support for "INSERT INTO table1 SELECT * FROM table2" statements.
2 parents 148b4e2 + 60106c4 commit 17eca8e

File tree

3 files changed

+162
-14
lines changed

3 files changed

+162
-14
lines changed

engine_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@ var queries = []struct {
3333
"SELECT i FROM mytable;",
3434
[]sql.Row{{int64(1)}, {int64(2)}, {int64(3)}},
3535
},
36+
{
37+
"SELECT s,i FROM mytable;",
38+
[]sql.Row{
39+
{"first row", int64(1)},
40+
{"second row", int64(2)},
41+
{"third row", int64(3)}},
42+
},
43+
{
44+
"SELECT s,i FROM (select i,s from mytable) mt;",
45+
[]sql.Row{
46+
{"first row", int64(1)},
47+
{"second row", int64(2)},
48+
{"third row", int64(3)}},
49+
},
3650
{
3751
"SELECT i + 1 FROM mytable;",
3852
[]sql.Row{{int64(2)}, {int64(3)}, {int64(4)}},
@@ -2113,6 +2127,97 @@ func TestInsertInto(t *testing.T) {
21132127
"SELECT * FROM typestable WHERE id = 999;",
21142128
[]sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}},
21152129
},
2130+
{
2131+
"INSERT INTO mytable SELECT * from mytable",
2132+
[]sql.Row{{int64(3)}},
2133+
"SELECT * FROM mytable order by i",
2134+
[]sql.Row{
2135+
{int64(1), "first row"},
2136+
{int64(1), "first row"},
2137+
{int64(2), "second row"},
2138+
{int64(2), "second row"},
2139+
{int64(3), "third row"},
2140+
{int64(3), "third row"},
2141+
},
2142+
},
2143+
{
2144+
"INSERT INTO mytable(i,s) SELECT * from mytable",
2145+
[]sql.Row{{int64(3)}},
2146+
"SELECT * FROM mytable order by i",
2147+
[]sql.Row{
2148+
{int64(1), "first row"},
2149+
{int64(1), "first row"},
2150+
{int64(2), "second row"},
2151+
{int64(2), "second row"},
2152+
{int64(3), "third row"},
2153+
{int64(3), "third row"},
2154+
},
2155+
},
2156+
{
2157+
"INSERT INTO mytable (i,s) SELECT i+10, 'new' from mytable",
2158+
[]sql.Row{{int64(3)}},
2159+
"SELECT * FROM mytable order by i",
2160+
[]sql.Row{
2161+
{int64(1), "first row"},
2162+
{int64(2), "second row"},
2163+
{int64(3), "third row"},
2164+
{int64(11), "new"},
2165+
{int64(12), "new"},
2166+
{int64(13), "new"},
2167+
},
2168+
},
2169+
{
2170+
"INSERT INTO mytable SELECT i2, s2 from othertable",
2171+
[]sql.Row{{int64(3)}},
2172+
"SELECT * FROM mytable order by i,s",
2173+
[]sql.Row{
2174+
{int64(1), "first row"},
2175+
{int64(1), "third"},
2176+
{int64(2), "second"},
2177+
{int64(2), "second row"},
2178+
{int64(3), "first"},
2179+
{int64(3), "third row"},
2180+
},
2181+
},
2182+
{
2183+
"INSERT INTO mytable (s,i) SELECT * from othertable",
2184+
[]sql.Row{{int64(3)}},
2185+
"SELECT * FROM mytable order by i,s",
2186+
[]sql.Row{
2187+
{int64(1), "first row"},
2188+
{int64(1), "third"},
2189+
{int64(2), "second"},
2190+
{int64(2), "second row"},
2191+
{int64(3), "first"},
2192+
{int64(3), "third row"},
2193+
},
2194+
},
2195+
{
2196+
"INSERT INTO mytable (s,i) SELECT concat(m.s, o.s2), m.i from othertable o join mytable m on m.i=o.i2",
2197+
[]sql.Row{{int64(3)}},
2198+
"SELECT * FROM mytable order by i,s",
2199+
[]sql.Row{
2200+
{int64(1), "first row"},
2201+
{int64(1), "first rowthird"},
2202+
{int64(2), "second row"},
2203+
{int64(2), "second rowsecond"},
2204+
{int64(3), "third row"},
2205+
{int64(3), "third rowfirst"},
2206+
},
2207+
},
2208+
{
2209+
"INSERT INTO mytable (i,s) SELECT (i + 10.0) / 10.0 + 10, concat(s, ' new') from mytable",
2210+
[]sql.Row{{int64(3)}},
2211+
"SELECT * FROM mytable order by i, s",
2212+
[]sql.Row{
2213+
{int64(1), "first row"},
2214+
{int64(2), "second row"},
2215+
{int64(3), "third row"},
2216+
{int64(11), "first row new"},
2217+
{int64(11), "second row new"},
2218+
{int64(11), "third row new"},
2219+
},
2220+
},
21162221
}
21172222

21182223
for _, insertion := range insertions {
@@ -2168,6 +2273,22 @@ func TestInsertIntoErrors(t *testing.T) {
21682273
"null given to non-nullable",
21692274
"INSERT INTO mytable (i, s) VALUES (null, 'y');",
21702275
},
2276+
{
2277+
"incompatible types",
2278+
"INSERT INTO mytable (i, s) select * from othertable",
2279+
},
2280+
{
2281+
"column count mismatch in select",
2282+
"INSERT INTO mytable (i) select * from othertable",
2283+
},
2284+
{
2285+
"column count mismatch in select",
2286+
"INSERT INTO mytable select s from othertable",
2287+
},
2288+
{
2289+
"column count mismatch in join select",
2290+
"INSERT INTO mytable (s,i) SELECT * from othertable o join mytable m on m.i=o.i2",
2291+
},
21712292
}
21722293

21732294
for _, expectedFailure := range expectedFailures {

sql/analyzer/prune_columns.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
1616
return n, nil
1717
}
1818

19+
// Skip pruning columns for insert statements. For inserts involving a select (INSERT INTO table1 SELECT a,b FROM
20+
// table2), all columns from the select are used for the insert, and error checking for schema compatibility
21+
// happens at execution time. Otherwise the logic below will convert a Project to a ResolvedTable for the selected
22+
// table, which can alter the column order of the select.
23+
if _, ok := n.(*plan.InsertInto); ok {
24+
return n, nil
25+
}
26+
1927
if describe, ok := n.(*plan.DescribeQuery); ok {
2028
pruned, err := pruneColumns(ctx, a, describe.Child)
2129
if err != nil {
@@ -25,16 +33,7 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
2533
return plan.NewDescribeQuery(describe.Format, pruned), nil
2634
}
2735

28-
columns := make(usedColumns)
29-
30-
// All the columns required for the output of the query must be mark as
31-
// used, otherwise the schema would change.
32-
for _, col := range n.Schema() {
33-
if _, ok := columns[col.Source]; !ok {
34-
columns[col.Source] = make(map[string]struct{})
35-
}
36-
columns[col.Source][col.Name] = struct{}{}
37-
}
36+
columns := findRequiredColumns(n)
3837

3938
findUsedColumns(columns, n)
4039

@@ -51,6 +50,21 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
5150
return fixRemainingFieldsIndexes(n)
5251
}
5352

53+
func findRequiredColumns(n sql.Node) usedColumns {
54+
columns := make(usedColumns)
55+
56+
// All the columns required for the output of the query must be mark as
57+
// used, otherwise the schema would change.
58+
for _, col := range n.Schema() {
59+
if _, ok := columns[col.Source]; !ok {
60+
columns[col.Source] = make(map[string]struct{})
61+
}
62+
columns[col.Source][col.Name] = struct{}{}
63+
}
64+
65+
return columns
66+
}
67+
5468
func pruneSubqueryColumns(
5569
ctx *sql.Context,
5670
a *Analyzer,

sql/plan/insert.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
9292
p.Columns[i] = f.Name
9393
}
9494
} else {
95-
err = p.validateColumns(ctx, dstSchema)
95+
err = p.validateColumns(dstSchema)
9696
if err != nil {
9797
return 0, err
9898
}
@@ -139,7 +139,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
139139
return i, err
140140
}
141141

142-
err = p.validateNullability(ctx, dstSchema, row)
142+
err = p.validateNullability(dstSchema, row)
143143
if err != nil {
144144
_ = iter.Close()
145145
return i, err
@@ -223,13 +223,19 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
223223
return ErrInsertIntoMismatchValueCount.New()
224224
}
225225
}
226+
case *ResolvedTable:
227+
return p.assertSchemasMatch(node.Schema())
228+
case *Project:
229+
return p.assertSchemasMatch(node.Schema())
230+
case *InnerJoin:
231+
return p.assertSchemasMatch(node.Schema())
226232
default:
227233
return ErrInsertIntoUnsupportedValues.New(node)
228234
}
229235
return nil
230236
}
231237

232-
func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) error {
238+
func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
233239
dstColNames := make(map[string]struct{})
234240
for _, dstCol := range dstSchema {
235241
dstColNames[dstCol.Name] = struct{}{}
@@ -248,11 +254,18 @@ func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) err
248254
return nil
249255
}
250256

251-
func (p *InsertInto) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error {
257+
func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) error {
252258
for i, col := range dstSchema {
253259
if !col.Nullable && row[i] == nil {
254260
return ErrInsertIntoNonNullableProvidedNull.New(col.Name)
255261
}
256262
}
257263
return nil
258264
}
265+
266+
func (p *InsertInto) assertSchemasMatch(schema sql.Schema) error {
267+
if len(p.Columns) != len(schema) {
268+
return ErrInsertIntoMismatchValueCount.New()
269+
}
270+
return nil
271+
}

0 commit comments

Comments
 (0)