Skip to content

Commit d1ae6a2

Browse files
authored
Merge pull request #25 from liquidata-inc/zachmu/better-type-checking
Better type checking for insert statements. Introduced a Zero() method for SQL types to assist checking type compatibility before actually inserting any values.
2 parents 17eca8e + daa164a commit d1ae6a2

File tree

3 files changed

+156
-36
lines changed

3 files changed

+156
-36
lines changed

engine_test.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,12 @@ func TestInsertInto(t *testing.T) {
20232023
"SELECT i FROM mytable WHERE s = 'x';",
20242024
[]sql.Row{{int64(999)}},
20252025
},
2026+
{
2027+
"INSERT INTO niltable (f) VALUES (10.0), (12.0);",
2028+
[]sql.Row{{int64(2)}},
2029+
"SELECT f FROM niltable WHERE f in (10.0, 12.0) order by f;",
2030+
[]sql.Row{{10.0}, {12.0}},
2031+
},
20262032
{
20272033
"INSERT INTO mytable SET s = 'x', i = 999;",
20282034
[]sql.Row{{int64(1)}},
@@ -2054,9 +2060,9 @@ func TestInsertInto(t *testing.T) {
20542060
[]sql.Row{{
20552061
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
20562062
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
2057-
float64(math.MaxFloat32), float64(math.MaxFloat64),
2063+
float32(math.MaxFloat32), float64(math.MaxFloat64),
20582064
timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"),
2059-
"random text", true, `{"key":"value"}`, "blobdata",
2065+
"random text", true, ([]byte)(`{"key":"value"}`), ([]byte)("blobdata"),
20602066
}},
20612067
},
20622068
{
@@ -2072,9 +2078,9 @@ func TestInsertInto(t *testing.T) {
20722078
[]sql.Row{{
20732079
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64),
20742080
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
2075-
float64(math.MaxFloat32), float64(math.MaxFloat64),
2081+
float32(math.MaxFloat32), float64(math.MaxFloat64),
20762082
timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"),
2077-
"random text", true, `{"key":"value"}`, "blobdata",
2083+
"random text", true, ([]byte)(`{"key":"value"}`), ([]byte)("blobdata"),
20782084
}},
20792085
},
20802086
{
@@ -2090,9 +2096,9 @@ func TestInsertInto(t *testing.T) {
20902096
[]sql.Row{{
20912097
int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
20922098
uint8(0), uint16(0), uint32(0), uint64(0),
2093-
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
2099+
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
20942100
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
2095-
"", false, ``, "",
2101+
"", false, ([]byte)(`""`), ([]byte)(""),
20962102
}},
20972103
},
20982104
{
@@ -2108,9 +2114,9 @@ func TestInsertInto(t *testing.T) {
21082114
[]sql.Row{{
21092115
int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
21102116
uint8(0), uint16(0), uint32(0), uint64(0),
2111-
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
2117+
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
21122118
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
2113-
"", false, ``, "",
2119+
"", false, ([]byte)(`""`), ([]byte)(""),
21142120
}},
21152121
},
21162122
{
@@ -3070,7 +3076,8 @@ func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q stri
30703076
rows, err := sql.RowIterToRows(iter)
30713077
require.NoError(err)
30723078

3073-
if orderBy {
3079+
// .Equal gives better error messages than .ElementsMatch, so use it when possible
3080+
if orderBy || len(rows) == 1 {
30743081
require.Equal(expected, rows)
30753082
} else {
30763083
require.ElementsMatch(expected, rows)

sql/plan/insert.go

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@ var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v")
1717
var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v")
1818
var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null")
1919
var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null")
20+
var ErrInsertIntoIncompatibleTypes = errors.NewKind("cannot convert type %s to %s")
2021

2122
// InsertInto is a node describing the insertion into some table.
2223
type InsertInto struct {
2324
BinaryNode
24-
Columns []string
25-
IsReplace bool
25+
ColumnNames []string
26+
IsReplace bool
2627
}
2728

2829
// NewInsertInto creates an InsertInto node.
2930
func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string) *InsertInto {
3031
return &InsertInto{
31-
BinaryNode: BinaryNode{Left: dst, Right: src},
32-
Columns: cols,
33-
IsReplace: isReplace,
32+
BinaryNode: BinaryNode{Left: dst, Right: src},
33+
ColumnNames: cols,
34+
IsReplace: isReplace,
3435
}
3536
}
3637

@@ -83,13 +84,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
8384
}
8485

8586
dstSchema := p.Left.Schema()
86-
projExprs := make([]sql.Expression, len(dstSchema))
8787

8888
// If no columns are given, we assume the full schema in order
89-
if len(p.Columns) == 0 {
90-
p.Columns = make([]string, len(dstSchema))
89+
if len(p.ColumnNames) == 0 {
90+
p.ColumnNames = make([]string, len(dstSchema))
9191
for i, f := range dstSchema {
92-
p.Columns[i] = f.Name
92+
p.ColumnNames[i] = f.Name
9393
}
9494
} else {
9595
err = p.validateColumns(dstSchema)
@@ -103,9 +103,10 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
103103
return 0, err
104104
}
105105

106+
projExprs := make([]sql.Expression, len(dstSchema))
106107
for i, f := range dstSchema {
107108
found := false
108-
for j, col := range p.Columns {
109+
for j, col := range p.ColumnNames {
109110
if f.Name == col {
110111
projExprs[i] = expression.NewGetField(j, f.Type, f.Name, f.Nullable)
111112
found = true
@@ -121,9 +122,12 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
121122
}
122123
}
123124

124-
proj := NewProject(projExprs, p.Right)
125+
rowSource, err := p.rowSource(projExprs)
126+
if err != nil {
127+
return 0, err
128+
}
125129

126-
iter, err := proj.RowIter(ctx)
130+
iter, err := rowSource.RowIter(ctx)
127131
if err != nil {
128132
return 0, err
129133
}
@@ -145,11 +149,11 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
145149
return i, err
146150
}
147151

148-
// Convert integer values in row to specified type in schema
152+
// Convert values to the destination schema type
149153
for colIdx, oldValue := range row {
150154
dstColType := projExprs[colIdx].Type()
151155

152-
if sql.IsInteger(dstColType) && oldValue != nil {
156+
if oldValue != nil {
153157
newValue, err := dstColType.Convert(oldValue)
154158
if err != nil {
155159
return i, err
@@ -185,6 +189,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
185189
return i, nil
186190
}
187191

192+
func (p *InsertInto) rowSource(projExprs []sql.Expression) (sql.Node, error) {
193+
switch n := p.Right.(type) {
194+
case *Values:
195+
return NewProject(projExprs, n), nil
196+
case *ResolvedTable, *Project, *InnerJoin:
197+
if err := assertCompatibleSchemas(projExprs, n.Schema()); err != nil {
198+
return nil, err
199+
}
200+
return NewProject(projExprs, n), nil
201+
default:
202+
return nil, ErrInsertIntoUnsupportedValues.New(n)
203+
}
204+
}
205+
188206
// RowIter implements the Node interface.
189207
func (p *InsertInto) RowIter(ctx *sql.Context) (sql.RowIter, error) {
190208
n, err := p.Execute(ctx)
@@ -201,15 +219,15 @@ func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
201219
return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2)
202220
}
203221

204-
return NewInsertInto(children[0], children[1], p.IsReplace, p.Columns), nil
222+
return NewInsertInto(children[0], children[1], p.IsReplace, p.ColumnNames), nil
205223
}
206224

207225
func (p InsertInto) String() string {
208226
pr := sql.NewTreePrinter()
209227
if p.IsReplace {
210-
_ = pr.WriteNode("Replace(%s)", strings.Join(p.Columns, ", "))
228+
_ = pr.WriteNode("Replace(%s)", strings.Join(p.ColumnNames, ", "))
211229
} else {
212-
_ = pr.WriteNode("Insert(%s)", strings.Join(p.Columns, ", "))
230+
_ = pr.WriteNode("Insert(%s)", strings.Join(p.ColumnNames, ", "))
213231
}
214232
_ = pr.WriteChildren(p.Left.String(), p.Right.String())
215233
return pr.String()
@@ -219,16 +237,12 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
219237
switch node := p.Right.(type) {
220238
case *Values:
221239
for _, exprTuple := range node.ExpressionTuples {
222-
if len(exprTuple) != len(p.Columns) {
240+
if len(exprTuple) != len(p.ColumnNames) {
223241
return ErrInsertIntoMismatchValueCount.New()
224242
}
225243
}
226-
case *ResolvedTable:
227-
return p.assertSchemasMatch(node.Schema())
228-
case *Project:
229-
return p.assertSchemasMatch(node.Schema())
230-
case *InnerJoin:
231-
return p.assertSchemasMatch(node.Schema())
244+
case *ResolvedTable, *Project, *InnerJoin:
245+
return p.assertColumnCountsMatch(node.Schema())
232246
default:
233247
return ErrInsertIntoUnsupportedValues.New(node)
234248
}
@@ -241,7 +255,7 @@ func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
241255
dstColNames[dstCol.Name] = struct{}{}
242256
}
243257
columnNames := make(map[string]struct{})
244-
for _, columnName := range p.Columns {
258+
for _, columnName := range p.ColumnNames {
245259
if _, exists := dstColNames[columnName]; !exists {
246260
return ErrInsertIntoNonexistentColumn.New(columnName)
247261
}
@@ -263,9 +277,27 @@ func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) erro
263277
return nil
264278
}
265279

266-
func (p *InsertInto) assertSchemasMatch(schema sql.Schema) error {
267-
if len(p.Columns) != len(schema) {
280+
func (p *InsertInto) assertColumnCountsMatch(schema sql.Schema) error {
281+
if len(p.ColumnNames) != len(schema) {
268282
return ErrInsertIntoMismatchValueCount.New()
269283
}
270284
return nil
271285
}
286+
287+
func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error {
288+
for _, expr := range projExprs {
289+
switch e := expr.(type) {
290+
case *expression.Literal:
291+
continue
292+
case *expression.GetField:
293+
otherCol := schema[e.Index()]
294+
_, err := otherCol.Type.Convert(expr.Type().Zero())
295+
if err != nil {
296+
return ErrInsertIntoIncompatibleTypes.New(otherCol.Type.String(), expr.Type().String())
297+
}
298+
default:
299+
return ErrInsertIntoUnsupportedValues.New(expr)
300+
}
301+
}
302+
return nil
303+
}

0 commit comments

Comments
 (0)