Skip to content

Commit 220353c

Browse files
authored
Merge pull request #3037 from dolthub/updatejoin
Apply foreign key constraints to `UPDATE JOIN`
2 parents 0270726 + bd4ba29 commit 220353c

File tree

5 files changed

+95
-51
lines changed

5 files changed

+95
-51
lines changed

enginetest/queries/update_queries.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -482,10 +482,8 @@ var UpdateScriptTests = []ScriptTest{
482482
},
483483
Assertions: []ScriptTestAssertion{
484484
{
485-
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
486-
Skip: true,
487485
Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;",
488-
ExpectedErr: sql.ErrCheckConstraintViolated,
486+
ExpectedErr: sql.ErrForeignKeyChildViolation,
489487
},
490488
{
491489
Query: "SELECT * FROM orders;",
@@ -510,16 +508,12 @@ var UpdateScriptTests = []ScriptTest{
510508
},
511509
Assertions: []ScriptTestAssertion{
512510
{
513-
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
514-
Skip: true,
515511
Query: `UPDATE child1 c1
516512
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
517513
SET c1.p1_id = 999, c2.p2_id = 3;`,
518514
ExpectedErr: sql.ErrForeignKeyChildViolation,
519515
},
520516
{
521-
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
522-
Skip: true,
523517
Query: `UPDATE child1 c1
524518
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
525519
SET c1.p1_id = 3, c2.p2_id = 999;`,

sql/analyzer/apply_foreign_keys.go

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,32 +122,35 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125-
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126-
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
127-
updateDest, err := plan.GetUpdatable(n.Child)
128-
if err != nil {
129-
return nil, transform.SameTree, err
130-
}
131-
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
132-
// If foreign keys aren't supported then we return
133-
if !ok {
134-
return n, transform.SameTree, nil
125+
if n.IsJoin {
126+
uj := n.Child.(*plan.UpdateJoin)
127+
updateTargets := uj.UpdateTargets
128+
fkHandlerMap := make(map[string]sql.Node, len(updateTargets))
129+
for tableName, updateTarget := range updateTargets {
130+
fkHandlerMap[tableName] = updateTarget
131+
fkHandler, err :=
132+
getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain)
133+
if err != nil {
134+
return nil, transform.SameTree, err
135+
}
136+
if fkHandler == nil {
137+
fkHandlerMap[tableName] = updateTarget
138+
} else {
139+
fkHandlerMap[tableName] = fkHandler
140+
}
141+
}
142+
uj = plan.NewUpdateJoin(fkHandlerMap, uj.Child)
143+
nn, err := n.WithChildren(uj)
144+
return nn, transform.NewTree, err
135145
}
136-
137-
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
146+
fkHandler, err := getForeignKeyHandlerFromUpdateTarget(ctx, a, n.Child, cache, fkChain)
138147
if err != nil {
139148
return nil, transform.SameTree, err
140149
}
141-
if fkEditor == nil {
150+
if fkHandler == nil {
142151
return n, transform.SameTree, nil
143152
}
144-
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
145-
Table: fkTbl,
146-
Sch: updateDest.Schema(),
147-
OriginalNode: n.Child,
148-
Editor: fkEditor,
149-
AllUpdaters: fkChain.GetUpdaters(),
150-
})
153+
nn, err := n.WithChildren(fkHandler)
151154
return nn, transform.NewTree, err
152155
case *plan.DeleteFrom:
153156
if plan.IsEmptyTable(n.Child) {
@@ -445,6 +448,36 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
445448
return fkEditor, nil
446449
}
447450

451+
// getForeignKeyHandlerFromUpdateTarget creates a ForeignKeyHandler from a given update target Node. It is used for
452+
// applying foreign key constraints to Update nodes
453+
func getForeignKeyHandlerFromUpdateTarget(ctx *sql.Context, a *Analyzer, updateTarget sql.Node,
454+
cache *foreignKeyCache, fkChain foreignKeyChain) (*plan.ForeignKeyHandler, error) {
455+
updateDest, err := plan.GetUpdatable(updateTarget)
456+
if err != nil {
457+
return nil, err
458+
}
459+
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
460+
if !ok {
461+
return nil, nil
462+
}
463+
464+
fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
465+
if err != nil {
466+
return nil, err
467+
}
468+
if fkEditor == nil {
469+
return nil, nil
470+
}
471+
472+
return &plan.ForeignKeyHandler{
473+
Table: fkTbl,
474+
Sch: updateDest.Schema(),
475+
OriginalNode: updateTarget,
476+
Editor: fkEditor,
477+
AllUpdaters: fkChain.GetUpdaters(),
478+
}, nil
479+
}
480+
448481
// resolveSchemaDefaults resolves the default values for the schema of |table|. This is primarily needed for column
449482
// default value expressions, since those don't get resolved during the planbuilder phase and assignExecIndexes
450483
// doesn't traverse through the ForeignKeyEditors and referential actions to find all of them. In addition to resolving

sql/analyzer/assign_update_join.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
3434
return n, transform.SameTree, nil
3535
}
3636

37-
updaters, err := rowUpdatersByTable(ctx, us, jn)
37+
n.IsJoin = true
38+
updateTargets, err := getUpdateTargetsByTable(us, jn)
3839
if err != nil {
3940
return nil, transform.SameTree, err
4041
}
4142

42-
uj := plan.NewUpdateJoin(updaters, us)
43+
uj := plan.NewUpdateJoin(updateTargets, us)
4344
ret, err := n.WithChildren(uj)
4445
if err != nil {
4546
return nil, transform.SameTree, err
@@ -51,12 +52,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
5152
return n, transform.SameTree, nil
5253
}
5354

54-
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
55-
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
55+
// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node
56+
func getUpdateTargetsByTable(node sql.Node, ij sql.Node) (map[string]sql.Node, error) {
5657
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
5758
resolvedTables := getTablesByName(ij)
5859

59-
rowUpdatersByTable := make(map[string]sql.RowUpdater)
60+
updateTargets := make(map[string]sql.Node)
6061
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
6162
resolvedTable, ok := resolvedTables[tableToBeUpdated]
6263
if !ok {
@@ -76,10 +77,10 @@ func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[strin
7677
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
7778
}
7879

79-
rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
80+
updateTargets[tableToBeUpdated] = resolvedTable
8081
}
8182

82-
return rowUpdatersByTable, nil
83+
return updateTargets, nil
8384
}
8485

8586
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.

sql/plan/update_join.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ import (
2121
)
2222

2323
type UpdateJoin struct {
24-
Updaters map[string]sql.RowUpdater
24+
UpdateTargets map[string]sql.Node
2525
UnaryNode
2626
}
2727

28-
// NewUpdateJoin returns an *UpdateJoin node.
29-
func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin {
28+
// NewUpdateJoin returns a new *UpdateJoin node.
29+
func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin {
3030
return &UpdateJoin{
31-
Updaters: editorMap,
32-
UnaryNode: UnaryNode{Child: child},
31+
UpdateTargets: updateTargets,
32+
UnaryNode: UnaryNode{Child: child},
3333
}
3434
}
3535

@@ -54,14 +54,9 @@ func (u *UpdateJoin) DebugString() string {
5454

5555
// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
5656
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
57-
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
58-
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
59-
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
60-
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
61-
// We should revamp this function so that we can communicate multiple tables being updated.
6257
return &updatableJoinTable{
63-
updaters: u.Updaters,
64-
joinNode: u.Child.(*UpdateSource).Child,
58+
updateTargets: u.UpdateTargets,
59+
joinNode: u.Child.(*UpdateSource).Child,
6560
}
6661
}
6762

@@ -71,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
7166
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
7267
}
7368

74-
return NewUpdateJoin(u.Updaters, children[0]), nil
69+
return NewUpdateJoin(u.UpdateTargets, children[0]), nil
7570
}
7671

7772
func (u *UpdateJoin) IsReadOnly() bool {
@@ -83,10 +78,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
8378
return sql.GetCoercibility(ctx, u.Child)
8479
}
8580

81+
func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) {
82+
return getUpdaters(u.UpdateTargets, ctx)
83+
}
84+
85+
func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) {
86+
updaterMap := make(map[string]sql.RowUpdater)
87+
for tableName, updateTarget := range updateTargets {
88+
updatable, err := GetUpdatable(updateTarget)
89+
if err != nil {
90+
return nil, err
91+
}
92+
updaterMap[tableName] = updatable.Updater(ctx)
93+
}
94+
return updaterMap, nil
95+
}
96+
8697
// updatableJoinTable manages the update of multiple tables.
8798
type updatableJoinTable struct {
88-
updaters map[string]sql.RowUpdater
89-
joinNode sql.Node
99+
updateTargets map[string]sql.Node
100+
joinNode sql.Node
90101
}
91102

92103
var _ sql.UpdatableTable = (*updatableJoinTable)(nil)
@@ -123,8 +134,9 @@ func (u *updatableJoinTable) Collation() sql.CollationID {
123134

124135
// Updater implements the sql.UpdatableTable interface.
125136
func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
137+
updaters, _ := getUpdaters(u.updateTargets, ctx)
126138
return &updatableJoinUpdater{
127-
updaterMap: u.updaters,
139+
updaterMap: updaters,
128140
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()),
129141
joinSchema: u.joinNode.Schema(),
130142
}

sql/rowexec/dml.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row
416416
return nil, err
417417
}
418418

419+
updaters, err := n.GetUpdaters(ctx)
420+
if err != nil {
421+
return nil, err
422+
}
419423
return &updateJoinIter{
420424
updateSourceIter: ji,
421425
joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(),
422-
updaters: n.Updaters,
426+
updaters: updaters,
423427
caches: make(map[string]sql.KeyValueCache),
424428
disposals: make(map[string]sql.DisposeFunc),
425429
joinNode: n.Child.(*plan.UpdateSource).Child,

0 commit comments

Comments
 (0)