Skip to content

Apply foreign key constraints to UPDATE JOIN #3037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions enginetest/queries/update_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,8 @@ var UpdateScriptTests = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;",
ExpectedErr: sql.ErrCheckConstraintViolated,
ExpectedErr: sql.ErrForeignKeyChildViolation,
},
{
Query: "SELECT * FROM orders;",
Expand All @@ -510,16 +508,12 @@ var UpdateScriptTests = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: `UPDATE child1 c1
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
SET c1.p1_id = 999, c2.p2_id = 3;`,
ExpectedErr: sql.ErrForeignKeyChildViolation,
},
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: `UPDATE child1 c1
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
SET c1.p1_id = 3, c2.p2_id = 999;`,
Expand Down
78 changes: 58 additions & 20 deletions sql/analyzer/apply_foreign_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,41 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
if err != nil {
return nil, transform.SameTree, err
}
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
// If foreign keys aren't supported then we return
if !ok {
return n, transform.SameTree, nil
}

fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
if err != nil {
return nil, transform.SameTree, err
}
if fkEditor == nil {
return n, transform.SameTree, nil
switch updateDest.(type) {
case *plan.UpdatableJoinTable:
updateTargets := updateDest.(*plan.UpdatableJoinTable).UpdateTargets
fkHandlerMap := make(map[string]sql.Node, len(updateTargets))
for tableName, updateTarget := range updateTargets {
fkHandlerMap[tableName] = updateTarget
updateDest, err := plan.GetUpdatable(updateTarget)
if err != nil {
return nil, transform.SameTree, err
}
fkHandler, err :=
getForeignKeyHandlerFromUpdateDestination(updateDest, ctx, a, cache, fkChain, updateTarget)
if err != nil {
return nil, transform.SameTree, err
}
if fkHandler == nil {
fkHandlerMap[tableName] = updateTarget
} else {
fkHandlerMap[tableName] = fkHandler
}
}
uj := plan.NewUpdateJoin(fkHandlerMap, n.Child.(*plan.UpdateJoin).Child)
nn, err := n.WithChildren(uj)
return nn, transform.NewTree, err
default:
fkHandler, err := getForeignKeyHandlerFromUpdateDestination(updateDest, ctx, a, cache, fkChain, n.Child)
if err != nil {
return nil, transform.SameTree, err
}
if fkHandler == nil {
return n, transform.SameTree, nil
}
nn, err := n.WithChildren(fkHandler)
return nn, transform.NewTree, err
}
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
Table: fkTbl,
Sch: updateDest.Schema(),
OriginalNode: n.Child,
Editor: fkEditor,
AllUpdaters: fkChain.GetUpdaters(),
})
return nn, transform.NewTree, err
case *plan.DeleteFrom:
if plan.IsEmptyTable(n.Child) {
return n, transform.SameTree, nil
Expand Down Expand Up @@ -445,6 +459,30 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
return fkEditor, nil
}

func getForeignKeyHandlerFromUpdateDestination(updateDest sql.UpdatableTable, ctx *sql.Context, a *Analyzer,
cache *foreignKeyCache, fkChain foreignKeyChain, originalNode sql.Node) (*plan.ForeignKeyHandler, error) {
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
if !ok {
return nil, nil
}

fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
if err != nil {
return nil, err
}
if fkEditor == nil {
return nil, nil
}

return &plan.ForeignKeyHandler{
Table: fkTbl,
Sch: updateDest.Schema(),
OriginalNode: originalNode,
Editor: fkEditor,
AllUpdaters: fkChain.GetUpdaters(),
}, nil
}

// resolveSchemaDefaults resolves the default values for the schema of |table|. This is primarily needed for column
// default value expressions, since those don't get resolved during the planbuilder phase and assignExecIndexes
// doesn't traverse through the ForeignKeyEditors and referential actions to find all of them. In addition to resolving
Expand Down
14 changes: 7 additions & 7 deletions sql/analyzer/assign_update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
return n, transform.SameTree, nil
}

updaters, err := rowUpdatersByTable(ctx, us, jn)
updateTargets, err := getUpdateTargetsByTable(us, jn)
if err != nil {
return nil, transform.SameTree, err
}

uj := plan.NewUpdateJoin(updaters, us)
uj := plan.NewUpdateJoin(updateTargets, us)
ret, err := n.WithChildren(uj)
if err != nil {
return nil, transform.SameTree, err
Expand All @@ -51,12 +51,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
return n, transform.SameTree, nil
}

// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
// getUpdateTargetsByTable maps a set of table names to their corresponding update target Node
func getUpdateTargetsByTable(node sql.Node, ij sql.Node) (map[string]sql.Node, error) {
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
resolvedTables := getTablesByName(ij)

rowUpdatersByTable := make(map[string]sql.RowUpdater)
updateTargets := make(map[string]sql.Node)
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
resolvedTable, ok := resolvedTables[tableToBeUpdated]
if !ok {
Expand All @@ -76,10 +76,10 @@ func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[strin
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
}

rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
updateTargets[tableToBeUpdated] = resolvedTable
}

return rowUpdatersByTable, nil
return updateTargets, nil
}

// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
Expand Down
66 changes: 39 additions & 27 deletions sql/plan/update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import (
)

type UpdateJoin struct {
Updaters map[string]sql.RowUpdater
UpdateTargets map[string]sql.Node
UnaryNode
}

// NewUpdateJoin returns an *UpdateJoin node.
func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin {
// NewUpdateJoin returns a new *UpdateJoin node.
func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin {
return &UpdateJoin{
Updaters: editorMap,
UnaryNode: UnaryNode{Child: child},
UpdateTargets: updateTargets,
UnaryNode: UnaryNode{Child: child},
}
}

Expand All @@ -54,14 +54,9 @@ func (u *UpdateJoin) DebugString() string {

// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
// We should revamp this function so that we can communicate multiple tables being updated.
return &updatableJoinTable{
updaters: u.Updaters,
joinNode: u.Child.(*UpdateSource).Child,
return &UpdatableJoinTable{
UpdateTargets: u.UpdateTargets,
joinNode: u.Child.(*UpdateSource).Child,
}
}

Expand All @@ -71,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
}

return NewUpdateJoin(u.Updaters, children[0]), nil
return NewUpdateJoin(u.UpdateTargets, children[0]), nil
}

func (u *UpdateJoin) IsReadOnly() bool {
Expand All @@ -83,48 +78,65 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
return sql.GetCoercibility(ctx, u.Child)
}

// updatableJoinTable manages the update of multiple tables.
type updatableJoinTable struct {
updaters map[string]sql.RowUpdater
joinNode sql.Node
func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) {
return getUpdaters(u.UpdateTargets, ctx)
}

var _ sql.UpdatableTable = (*updatableJoinTable)(nil)
func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) {
updaterMap := make(map[string]sql.RowUpdater)
for tableName, updateTarget := range updateTargets {
updatable, err := GetUpdatable(updateTarget)
if err != nil {
return nil, err
}
updaterMap[tableName] = updatable.Updater(ctx)
}
return updaterMap, nil
}

// UpdatableJoinTable manages the update of multiple tables.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for fixing foreign keys, but as we look at making triggers work correctly, it's worth thinking about if this is the right type here... UpdatableJoinTable implements sql.UpdateableTable, but many of it's methods panic and can't be used (e.g. Name() doesn't make sense when this can actually represent multiple tables with different names). That's a good sign that this isn't modeled as well as it should be. This also causes a problem in Doltgres, where the code is assuming that since it's an sql.UpdatableTable implementation, that it can act as a sql.Table and call other methods like Name().

No need to change this now, but I'd dig into cleaning this incongruence when we look at getting triggers to work across all tables.

type UpdatableJoinTable struct {
UpdateTargets map[string]sql.Node
joinNode sql.Node
}

var _ sql.UpdatableTable = (*UpdatableJoinTable)(nil)

// Partitions implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
func (u *UpdatableJoinTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
panic("this method should not be called")
}

// PartitionsRows implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
func (u *UpdatableJoinTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
panic("this method should not be called")
}

// Name implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) Name() string {
func (u *UpdatableJoinTable) Name() string {
panic("this method should not be called")
}

// String implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) String() string {
func (u *UpdatableJoinTable) String() string {
panic("this method should not be called")
}

// Schema implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) Schema() sql.Schema {
func (u *UpdatableJoinTable) Schema() sql.Schema {
return u.joinNode.Schema()
}

// Collation implements the sql.Table interface.
func (u *updatableJoinTable) Collation() sql.CollationID {
func (u *UpdatableJoinTable) Collation() sql.CollationID {
return sql.Collation_Default
}

// Updater implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
func (u *UpdatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
updaters, _ := getUpdaters(u.UpdateTargets, ctx)
return &updatableJoinUpdater{
updaterMap: u.updaters,
updaterMap: updaters,
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()),
joinSchema: u.joinNode.Schema(),
}
Expand Down
6 changes: 5 additions & 1 deletion sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row
return nil, err
}

updaters, err := n.GetUpdaters(ctx)
if err != nil {
return nil, err
}
return &updateJoinIter{
updateSourceIter: ji,
joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(),
updaters: n.Updaters,
updaters: updaters,
caches: make(map[string]sql.KeyValueCache),
disposals: make(map[string]sql.DisposeFunc),
joinNode: n.Child.(*plan.UpdateSource).Child,
Expand Down
Loading