From f3ff534c54dbfee2b83008ec815b2dd7c4514cd2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Apr 2025 09:44:25 +0800 Subject: [PATCH 01/22] Implement Generics API --- generics.go | 254 +++++++++++++++++++++++++++++++++++++++++ tests/generics_test.go | 173 ++++++++++++++++++++++++++++ 2 files changed, 427 insertions(+) create mode 100644 generics.go create mode 100644 tests/generics_test.go diff --git a/generics.go b/generics.go new file mode 100644 index 000000000..c165cc14c --- /dev/null +++ b/generics.go @@ -0,0 +1,254 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" +) + +type Interface[T any] interface { + Raw(sql string, values ...interface{}) ExecInterface[T] + Exec(ctx context.Context, sql string, values ...interface{}) error + CreateInterface[T] +} + +type CreateInterface[T any] interface { + ChainInterface[T] + Table(name string, args ...interface{}) CreateInterface[T] + Create(ctx context.Context, r *T) error + CreateInBatches(ctx context.Context, r *[]T, batchSize int) error +} + +type ChainInterface[T any] interface { + ExecInterface[T] + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query string, args ...interface{}) ChainInterface[T] + InnerJoins(query string, args ...interface{}) ChainInterface[T] + Select(query string, args ...interface{}) ChainInterface[T] + Omit(columns ...string) ChainInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Preload(query string, args ...interface{}) ChainInterface[T] + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) +} + +type ExecInterface[T any] interface { + Scan(ctx context.Context, r interface{}) error + First(context.Context) (T, error) + Last(ctx context.Context) (T, error) + Find(ctx context.Context) ([]T, error) + FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error + Row(ctx context.Context) *sql.Row + Rows(ctx context.Context) (*sql.Rows, error) +} + +func G[T any](db *DB, opts ...clause.Expression) Interface[T] { + v := &g[T]{ + db: db.Session(&Session{NewDB: true}).Clauses(opts...), + opts: opts, + } + + v.createG = &createG[T]{ + chainG: chainG[T]{ + execG: execG[T]{g: v}, + }, + } + return v +} + +type g[T any] struct { + *createG[T] + db *DB + opts []clause.Expression +} + +func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { + g.db = g.db.Raw(sql, values...) + return &g.execG +} + +func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { + return g.db.WithContext(ctx).Exec(sql, values...).Error +} + +type createG[T any] struct { + chainG[T] +} + +func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { + g.g.db = g.g.db.Table(name, args...) + return g +} + +func (g *createG[T]) Create(ctx context.Context, r *T) error { + return g.g.db.WithContext(ctx).Create(r).Error +} + +func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { + return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error +} + +type chainG[T any] struct { + execG[T] +} + +func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { + for _, fc := range scopes { + fc(g.g.db.Statement) + } + return g +} + +func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Where(query, args...) + return g +} + +func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Not(query, args...) + return g +} + +func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Or(query, args...) + return g +} + +func (g *chainG[T]) Limit(offset int) ChainInterface[T] { + g.g.db = g.g.db.Limit(offset) + return g +} + +func (g *chainG[T]) Offset(offset int) ChainInterface[T] { + g.g.db = g.g.db.Offset(offset) + return g +} + +func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Joins(query, args...) + return g +} + +func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.InnerJoins(query, args...) + return g +} + +func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Select(query, args...) + return g +} + +func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] { + g.g.db = g.g.db.Omit(columns...) + return g +} + +func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { + g.g.db = g.g.db.MapColumns(m) + return g +} + +func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Distinct(args...) + return g +} + +func (g *chainG[T]) Group(name string) ChainInterface[T] { + g.g.db = g.g.db.Group(name) + return g +} + +func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Having(query, args...) + return g +} + +func (g *chainG[T]) Order(value interface{}) ChainInterface[T] { + g.g.db = g.g.db.Order(value) + return g +} + +func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Preload(query, args...) + return g +} + +func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { + r := new(T) + res := g.g.db.WithContext(ctx).Delete(r) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { + var r T + res := g.g.db.WithContext(ctx).Model(r).Update(name, value) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { + res := g.g.db.WithContext(ctx).Updates(t) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { + var r T + err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error + return +} + +type execG[T any] struct { + g *g[T] +} + +func (g *execG[T]) First(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).First(&r).Error + return r, err +} + +func (g *execG[T]) Scan(ctx context.Context, result interface{}) error { + var r T + err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error + return err +} + +func (g *execG[T]) Last(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).Last(&r).Error + return r, err +} + +func (g *execG[T]) Find(ctx context.Context) ([]T, error) { + var r []T + err := g.g.db.WithContext(ctx).Find(&r).Error + return r, err +} + +func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { + var data []T + return g.g.db.WithContext(ctx).FindInBatches(data, batchSize, func(tx *DB, batch int) error { + return fc(data, batch) + }).Error +} + +func (g *execG[T]) Row(ctx context.Context) *sql.Row { + return g.g.db.WithContext(ctx).Row() +} + +func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { + return g.g.db.WithContext(ctx).Rows() +} diff --git a/tests/generics_test.go b/tests/generics_test.go new file mode 100644 index 000000000..4d42e9538 --- /dev/null +++ b/tests/generics_test.go @@ -0,0 +1,173 @@ +package tests_test + +import ( + "context" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestGenericsCreate(t *testing.T) { + generic := gorm.G[User](DB) + ctx := context.Background() + + user := User{Name: "TestGenericsCreate"} + err := generic.Create(ctx, &user) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if user.ID == 0 { + t.Fatalf("no primary key found for %v", user) + } + + if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != user.Name || u.ID != user.ID { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + + result := struct { + ID int + Name string + }{} + if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { + t.Fatalf("failed to scan user, got error: %v", err) + } else if result.Name != user.Name || uint(result.ID) != user.ID { + t.Errorf("found invalid user, got %v, expect %v", result, user) + } +} + +func TestGenericsCreateInBatches(t *testing.T) { + batch := []User{ + {Name: "GenericsCreateInBatches1"}, + {Name: "GenericsCreateInBatches2"}, + {Name: "GenericsCreateInBatches3"}, + } + ctx := context.Background() + + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + for _, u := range batch { + if u.ID == 0 { + t.Fatalf("no primary key found for %v", u) + } + } + + count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 3 { + t.Errorf("expected 3 records, got %d", count) + } + + found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) + if len(found) != len(batch) { + t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) + } +} + +func TestGenericsExecAndUpdate(t *testing.T) { + ctx := context.Background() + + name := "GenericsExec" + if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil { + t.Fatalf("Exec insert failed: %v", err) + } + + u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != name || u.ID == 0 { + t.Errorf("found invalid user, got %v", u) + } + + name += "Update" + rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name) + if rows != 1 { + t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) + } + + nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if nu.Name != name || u.ID != nu.ID { + t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) + } + + rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18}) + if rows != 1 { + t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) + } + + nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID { + t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) + } +} + +func TestGenericsRow(t *testing.T) { + ctx := context.Background() + + user := User{Name: "GenericsRow"} + if err := gorm.G[User](DB).Create(ctx, &user); err != nil { + t.Fatalf("Create failed: %v", err) + } + + row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx) + var name string + if err := row.Scan(&name); err != nil { + t.Fatalf("Row scan failed: %v", err) + } + if name != user.Name { + t.Errorf("expected %s, got %s", user.Name, name) + } + + user2 := User{Name: "GenericsRow2"} + if err := gorm.G[User](DB).Create(ctx, &user2); err != nil { + t.Fatalf("Create failed: %v", err) + } + rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx) + if err != nil { + t.Fatalf("Rows failed: %v", err) + } + + count := 0 + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + count++ + } + if count != 2 { + t.Errorf("expected 2 rows, got %d", count) + } +} + +func TestGenericsDelete(t *testing.T) { + ctx := context.Background() + + u := User{Name: "GenericsDelete"} + if err := gorm.G[User](DB).Create(ctx, &u); err != nil { + t.Fatalf("Create failed: %v", err) + } + + rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx) + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + if rows != 1 { + t.Errorf("expected 1 row deleted, got %d", rows) + } + + _, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx) + if err != gorm.ErrRecordNotFound { + t.Fatalf("User after delete failed: %v", err) + } +} From 0fbe4f6dc47c76af9a2385b25e000999c5996d59 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Apr 2025 20:16:51 +0800 Subject: [PATCH 02/22] Add more generics tests --- generics.go | 2 +- tests/generics_test.go | 136 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/generics.go b/generics.go index c165cc14c..fc5dfaff1 100644 --- a/generics.go +++ b/generics.go @@ -240,7 +240,7 @@ func (g *execG[T]) Find(ctx context.Context) ([]T, error) { func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { var data []T - return g.g.db.WithContext(ctx).FindInBatches(data, batchSize, func(tx *DB, batch int) error { + return g.g.db.WithContext(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { return fc(data, batch) }).Error } diff --git a/tests/generics_test.go b/tests/generics_test.go index 4d42e9538..83686f317 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "fmt" "testing" "gorm.io/gorm" @@ -12,7 +13,7 @@ func TestGenericsCreate(t *testing.T) { generic := gorm.G[User](DB) ctx := context.Background() - user := User{Name: "TestGenericsCreate"} + user := User{Name: "TestGenericsCreate", Age: 18} err := generic.Create(ctx, &user) if err != nil { t.Fatalf("Create failed: %v", err) @@ -27,6 +28,18 @@ func TestGenericsCreate(t *testing.T) { t.Errorf("found invalid user, got %v, expect %v", u, user) } + if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != user.Name || u.Age != 0 { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + + if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != "" || u.Age != u.Age { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + result := struct { ID int Name string @@ -36,6 +49,11 @@ func TestGenericsCreate(t *testing.T) { } else if result.Name != user.Name || uint(result.ID) != user.ID { t.Errorf("found invalid user, got %v, expect %v", result, user) } + + mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).First(ctx) + if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { + t.Errorf("failed to find map results, got %v", mapResult) + } } func TestGenericsCreateInBatches(t *testing.T) { @@ -68,6 +86,16 @@ func TestGenericsCreateInBatches(t *testing.T) { if len(found) != len(batch) { t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) } + + found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx) + if len(found) != 2 { + t.Errorf("expected %d from Raw Find, got %d", 2, len(found)) + } + + found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx) + if len(found) != 1 { + t.Errorf("expected %d from Raw Find, got %d", 1, len(found)) + } } func TestGenericsExecAndUpdate(t *testing.T) { @@ -78,7 +106,7 @@ func TestGenericsExecAndUpdate(t *testing.T) { t.Fatalf("Exec insert failed: %v", err) } - u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) + u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx) if err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != name || u.ID == 0 { @@ -171,3 +199,107 @@ func TestGenericsDelete(t *testing.T) { t.Fatalf("User after delete failed: %v", err) } } + +func TestGenericsFindInBatches(t *testing.T) { + ctx := context.Background() + + users := []User{ + {Name: "GenericsFindBatchA"}, + {Name: "GenericsFindBatchB"}, + {Name: "GenericsFindBatchC"}, + {Name: "GenericsFindBatchD"}, + {Name: "GenericsFindBatchE"}, + } + if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + total := 0 + err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error { + if len(chunk) > 2 { + t.Errorf("batch size exceed 2: got %d", len(chunk)) + } + + total += len(chunk) + return nil + }) + if err != nil { + t.Fatalf("FindInBatches failed: %v", err) + } + + if total != len(users) { + t.Errorf("expected total %d, got %d", len(users), total) + } +} + +func TestGenericsScopes(t *testing.T) { + ctx := context.Background() + + users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}} + err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)) + if err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + filterName1 := func(stmt *gorm.Statement) { + stmt.Where("name = ?", "GenericsScopes1") + } + + results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx) + if err != nil { + t.Fatalf("Scopes failed: %v", err) + } + if len(results) != 1 || results[0].Name != "GenericsScopes1" { + t.Fatalf("Scopes expected 1, got %d", len(results)) + } + + notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx) + if len(notResult) != 2 { + t.Fatalf("expected 2 results, got %d", len(notResult)) + } else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" { + t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name) + } + + orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx) + if len(orResult) != 2 { + t.Fatalf("expected 2 results, got %d", len(notResult)) + } else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" { + t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name) + } +} + +func TestGenericsJoinsAndPreload(t *testing.T) { + ctx := context.Background() + + u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} + DB.Create(&u) + + // LEFT JOIN + WHERE + result, err := gorm.G[User](DB).Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // INNER JOIN + Inline WHERE + result2, err := gorm.G[User](DB).InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) + if err != nil { + t.Fatalf("InnerJoins failed: %v", err) + } + if result2.Name != u.Name || result2.Company.Name != u.Company.Name { + t.Errorf("InnerJoins expected , got %+v", result2) + } + + // Preload + result3, err := gorm.G[User](DB).Preload("Company").Where("name = ?", u.Name).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result3.Name != u.Name || result3.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } +} + +// Distinct, Group, Having From ba27874dcd589f2078946263b42928cbef76b8db Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Apr 2025 22:15:10 +0800 Subject: [PATCH 03/22] Add more tests and Take method --- generics.go | 7 +++++ tests/generics_test.go | 66 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/generics.go b/generics.go index fc5dfaff1..f40c73be9 100644 --- a/generics.go +++ b/generics.go @@ -49,6 +49,7 @@ type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) Last(ctx context.Context) (T, error) + Take(context.Context) (T, error) Find(ctx context.Context) ([]T, error) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error Row(ctx context.Context) *sql.Row @@ -232,6 +233,12 @@ func (g *execG[T]) Last(ctx context.Context) (T, error) { return r, err } +func (g *execG[T]) Take(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).Take(&r).Error + return r, err +} + func (g *execG[T]) Find(ctx context.Context) ([]T, error) { var r []T err := g.g.db.WithContext(ctx).Find(&r).Error diff --git a/tests/generics_test.go b/tests/generics_test.go index 83686f317..9e047a559 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -3,6 +3,8 @@ package tests_test import ( "context" "fmt" + "reflect" + "sort" "testing" "gorm.io/gorm" @@ -28,6 +30,12 @@ func TestGenericsCreate(t *testing.T) { t.Errorf("found invalid user, got %v, expect %v", u, user) } + if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != user.Name || u.ID != user.ID { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != user.Name || u.Age != 0 { @@ -50,7 +58,7 @@ func TestGenericsCreate(t *testing.T) { t.Errorf("found invalid user, got %v, expect %v", result, user) } - mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).First(ctx) + mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { t.Errorf("failed to find map results, got %v", mapResult) } @@ -302,4 +310,58 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } } -// Distinct, Group, Having +func TestGenericsDistinct(t *testing.T) { + ctx := context.Background() + + batch := []User{ + {Name: "GenericsDistinctDup"}, + {Name: "GenericsDistinctDup"}, + {Name: "GenericsDistinctUnique"}, + } + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx) + if err != nil { + t.Fatalf("Distinct Find failed: %v", err) + } + + if len(results) != 2 { + t.Errorf("expected 2 distinct names, got %d", len(results)) + } + + var names []string + for _, u := range results { + names = append(names, u.Name) + } + sort.Strings(names) + expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"} + if !reflect.DeepEqual(names, expected) { + t.Errorf("expected names %v, got %v", expected, names) + } +} + +func TestGenericsGroupHaving(t *testing.T) { + ctx := context.Background() + + batch := []User{ + {Name: "GenericsGroupHavingMulti"}, + {Name: "GenericsGroupHavingMulti"}, + {Name: "GenericsGroupHavingSingle"}, + } + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + grouped, err := gorm.G[User](DB).Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(*) > ?", 1).Find(ctx) + if err != nil { + t.Fatalf("Group+Having Find failed: %v", err) + } + + if len(grouped) != 1 { + t.Errorf("expected 1 group with count>1, got %d", len(grouped)) + } else if grouped[0].Name != "GenericsGroupHavingMulti" { + t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) + } +} From 2d6d7f94859e31a26bb15bb6f29acb9ec1615764 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Apr 2025 18:18:56 +0800 Subject: [PATCH 04/22] =?UTF-8?q?use=20delayed=E2=80=91ops=20pipeline=20fo?= =?UTF-8?q?r=20generics=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generics.go | 231 +++++++++++++++++++++++++---------------- tests/generics_test.go | 15 +-- 2 files changed, 150 insertions(+), 96 deletions(-) diff --git a/generics.go b/generics.go index f40c73be9..5930a6ce1 100644 --- a/generics.go +++ b/generics.go @@ -56,10 +56,18 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } +type op func(*DB) *DB + func G[T any](db *DB, opts ...clause.Expression) Interface[T] { v := &g[T]{ - db: db.Session(&Session{NewDB: true}).Clauses(opts...), - opts: opts, + db: db.Session(&Session{NewDB: true}), + ops: make([]op, 0, 5), + } + + if len(opts) > 0 { + v.ops = append(v.ops, func(db *DB) *DB { + return db.Clauses(opts...) + }) } v.createG = &createG[T]{ @@ -72,142 +80,187 @@ func G[T any](db *DB, opts ...clause.Expression) Interface[T] { type g[T any] struct { *createG[T] - db *DB - opts []clause.Expression + db *DB + ops []op +} + +func (g *g[T]) apply(ctx context.Context) *DB { + db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + for _, op := range g.ops { + db = op(db) + } + return db } -func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { - g.db = g.db.Raw(sql, values...) - return &g.execG +func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { + return execG[T]{g: &g[T]{ + db: c.db, + ops: append(c.ops, func(db *DB) *DB { + return db.Raw(sql, values...) + }), + }} } -func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return g.db.WithContext(ctx).Exec(sql, values...).Error +func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { + return c.apply(ctx).Exec(sql, values...).Error } type createG[T any] struct { chainG[T] } -func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { - g.g.db = g.g.db.Table(name, args...) - return g +func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Table(name, args...) + })} } -func (g *createG[T]) Create(ctx context.Context, r *T) error { - return g.g.db.WithContext(ctx).Create(r).Error +func (c createG[T]) Create(ctx context.Context, r *T) error { + return c.g.apply(ctx).Create(r).Error } -func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { - return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error +func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { + return c.g.apply(ctx).CreateInBatches(r, batchSize).Error } type chainG[T any] struct { execG[T] } -func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { - for _, fc := range scopes { - fc(g.g.db.Statement) +func (c chainG[T]) with(op op) chainG[T] { + return chainG[T]{ + execG: execG[T]{g: &g[T]{ + db: c.g.db, + ops: append(c.g.ops, op), + }}, } - return g } -func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Where(query, args...) - return g +func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { + return c.with(func(db *DB) *DB { + for _, fc := range scopes { + fc(db.Statement) + } + return db + }) +} + +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + +func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Where(query, args...) + }) } -func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Not(query, args...) - return g +func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Not(query, args...) + }) } -func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Or(query, args...) - return g +func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Or(query, args...) + }) } -func (g *chainG[T]) Limit(offset int) ChainInterface[T] { - g.g.db = g.g.db.Limit(offset) - return g +func (c chainG[T]) Limit(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Limit(offset) + }) } -func (g *chainG[T]) Offset(offset int) ChainInterface[T] { - g.g.db = g.g.db.Offset(offset) - return g +func (c chainG[T]) Offset(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Offset(offset) + }) } -func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Joins(query, args...) - return g +func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Joins(query, args...) + }) } -func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.InnerJoins(query, args...) - return g +func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.InnerJoins(query, args...) + }) } -func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Select(query, args...) - return g +func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Select(query, args...) + }) } -func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] { - g.g.db = g.g.db.Omit(columns...) - return g +func (c chainG[T]) Omit(columns ...string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Omit(columns...) + }) } -func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { - g.g.db = g.g.db.MapColumns(m) - return g +func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.MapColumns(m) + }) } -func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Distinct(args...) - return g +func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Distinct(args...) + }) } -func (g *chainG[T]) Group(name string) ChainInterface[T] { - g.g.db = g.g.db.Group(name) - return g +func (c chainG[T]) Group(name string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Group(name) + }) } -func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Having(query, args...) - return g +func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Having(query, args...) + }) } -func (g *chainG[T]) Order(value interface{}) ChainInterface[T] { - g.g.db = g.g.db.Order(value) - return g +func (c chainG[T]) Order(value interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Order(value) + }) } -func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Preload(query, args...) - return g +func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Preload(query, args...) + }) } -func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { +func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { r := new(T) - res := g.g.db.WithContext(ctx).Delete(r) + res := c.g.apply(ctx).Delete(r) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { +func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { var r T - res := g.g.db.WithContext(ctx).Model(r).Update(name, value) + res := c.g.apply(ctx).Model(r).Update(name, value) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { - res := g.g.db.WithContext(ctx).Updates(t) +func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { + res := c.g.apply(ctx).Updates(t) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { +func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { var r T - err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error + err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error return } @@ -215,47 +268,47 @@ type execG[T any] struct { g *g[T] } -func (g *execG[T]) First(ctx context.Context) (T, error) { +func (g execG[T]) First(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).First(&r).Error + err := g.g.apply(ctx).First(&r).Error return r, err } -func (g *execG[T]) Scan(ctx context.Context, result interface{}) error { +func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(&result).Error return err } -func (g *execG[T]) Last(ctx context.Context) (T, error) { +func (g execG[T]) Last(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).Last(&r).Error + err := g.g.apply(ctx).Last(&r).Error return r, err } -func (g *execG[T]) Take(ctx context.Context) (T, error) { +func (g execG[T]) Take(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).Take(&r).Error + err := g.g.apply(ctx).Take(&r).Error return r, err } -func (g *execG[T]) Find(ctx context.Context) ([]T, error) { +func (g execG[T]) Find(ctx context.Context) ([]T, error) { var r []T - err := g.g.db.WithContext(ctx).Find(&r).Error + err := g.g.apply(ctx).Find(&r).Error return r, err } -func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { +func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { var data []T - return g.g.db.WithContext(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { + return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { return fc(data, batch) }).Error } -func (g *execG[T]) Row(ctx context.Context) *sql.Row { - return g.g.db.WithContext(ctx).Row() +func (g execG[T]) Row(ctx context.Context) *sql.Row { + return g.g.apply(ctx).Row() } -func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { - return g.g.db.WithContext(ctx).Rows() +func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { + return g.g.apply(ctx).Rows() } diff --git a/tests/generics_test.go b/tests/generics_test.go index 9e047a559..bceb2f9eb 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -12,11 +12,10 @@ import ( ) func TestGenericsCreate(t *testing.T) { - generic := gorm.G[User](DB) ctx := context.Background() user := User{Name: "TestGenericsCreate", Age: 18} - err := generic.Create(ctx, &user) + err := gorm.G[User](DB).Create(ctx, &user) if err != nil { t.Fatalf("Create failed: %v", err) } @@ -60,7 +59,7 @@ func TestGenericsCreate(t *testing.T) { mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { - t.Errorf("failed to find map results, got %v", mapResult) + t.Errorf("failed to find map results, got %v, err %v", mapResult, err) } } @@ -92,6 +91,7 @@ func TestGenericsCreateInBatches(t *testing.T) { found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) if len(found) != len(batch) { + fmt.Println(found) t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) } @@ -278,12 +278,13 @@ func TestGenericsScopes(t *testing.T) { func TestGenericsJoinsAndPreload(t *testing.T) { ctx := context.Background() + db := gorm.G[User](DB) u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} - DB.Create(&u) + db.Create(ctx, &u) // LEFT JOIN + WHERE - result, err := gorm.G[User](DB).Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) + result, err := db.Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -292,7 +293,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // INNER JOIN + Inline WHERE - result2, err := gorm.G[User](DB).InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) + result2, err := db.InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) if err != nil { t.Fatalf("InnerJoins failed: %v", err) } @@ -301,7 +302,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Preload - result3, err := gorm.G[User](DB).Preload("Company").Where("name = ?", u.Name).First(ctx) + result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } From 3de6d0b2f98dddea196f96270b7df92caaea2a03 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Apr 2025 20:28:21 +0800 Subject: [PATCH 05/22] fix generics tests for mysql --- tests/connpool_test.go | 1 + tests/generics_test.go | 2 +- tests/go.mod | 20 ++++++++++---------- tests/joins_test.go | 6 +++--- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 21a2bad03..32492896b 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) { }() db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) + db.Logger = DB.Logger if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/generics_test.go b/tests/generics_test.go index bceb2f9eb..036f1cf93 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -355,7 +355,7 @@ func TestGenericsGroupHaving(t *testing.T) { t.Fatalf("CreateInBatches failed: %v", err) } - grouped, err := gorm.G[User](DB).Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(*) > ?", 1).Find(ctx) + grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx) if err != nil { t.Fatalf("Group+Having Find failed: %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 301434332..c87ccca1f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,15 +1,15 @@ module gorm.io/gorm/tests -go 1.18 +go 1.18.0 require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 gorm.io/driver/mysql v1.5.7 - gorm.io/driver/postgres v1.5.10 - gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlserver v1.5.4 gorm.io/gorm v1.25.12 ) @@ -17,20 +17,20 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-sql-driver/mysql v1.9.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.1 // indirect + github.com/jackc/pgx/v5 v5.7.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect - github.com/microsoft/go-mssqldb v1.7.2 // indirect + github.com/mattn/go-sqlite3 v1.14.28 // indirect + github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.29.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/text v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/joins_test.go b/tests/joins_test.go index 497f81467..64a9e407b 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -419,7 +419,7 @@ func TestJoinsPreload_Issue7013(t *testing.T) { var entries []User assert.NotPanics(t, func() { assert.NoError(t, - DB.Debug().Preload("Manager.Team"). + DB.Preload("Manager.Team"). Joins("Manager.Company"). Find(&entries).Error) }) @@ -456,7 +456,7 @@ func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) { var entries []Building assert.NotPanics(t, func() { assert.NoError(t, - DB.Debug().Preload("Owner.Furnitures"). + DB.Preload("Owner.Furnitures"). Joins("Owner.Company"). Find(&entries).Error) }) @@ -468,7 +468,7 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { var entries []User assert.NotPanics(t, func() { assert.NoError(t, - DB.Debug().Preload("Manager.Team"). + DB.Preload("Manager.Team"). Joins("Manager.Company"). Where("1 <> 1"). Find(&entries).Error) From 797a557cc82f98bae17ed23fd29516334b10244b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 20 Apr 2025 10:47:00 +0800 Subject: [PATCH 06/22] Support SubQuery for Generics --- generics.go | 5 +++++ statement.go | 11 ++++++----- tests/generics_test.go | 32 ++++++++++++++++++++++++++++++++ tests/go.mod | 2 +- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/generics.go b/generics.go index 5930a6ce1..9dd1af7d3 100644 --- a/generics.go +++ b/generics.go @@ -127,6 +127,11 @@ type chainG[T any] struct { execG[T] } +func (c chainG[T]) getInstance() *DB { + var r T + return c.g.apply(context.Background()).Model(r).getInstance() +} + func (c chainG[T]) with(op op) chainG[T] { return chainG[T]{ execG: execG[T]{g: &g[T]{ diff --git a/statement.go b/statement.go index 39e05d093..11791c3a1 100644 --- a/statement.go +++ b/statement.go @@ -205,19 +205,20 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { writer.WriteString("(NULL)") } - case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - if v.Statement.SQL.Len() > 0 { + case interface{ getInstance() *DB }: + cv := v.getInstance() + subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if cv.Statement.SQL.Len() > 0 { var ( vars = subdb.Statement.Vars - sql = v.Statement.SQL.String() + sql = cv.Statement.SQL.String() ) subdb.Statement.Vars = make([]interface{}, 0, len(vars)) for _, vv := range vars { subdb.Statement.Vars = append(subdb.Statement.Vars, vv) bindvar := strings.Builder{} - v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + cv.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) sql = strings.Replace(sql, bindvar.String(), "?", 1) } diff --git a/tests/generics_test.go b/tests/generics_test.go index 036f1cf93..1587d0909 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -366,3 +366,35 @@ func TestGenericsGroupHaving(t *testing.T) { t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) } } + +func TestGenericsSubQuery(t *testing.T) { + ctx := context.Background() + users := []User{ + {Name: "GenericsSubquery_1", Age: 10}, + {Name: "GenericsSubquery_2", Age: 20}, + {Name: "GenericsSubquery_3", Age: 30}, + {Name: "GenericsSubquery_4", Age: 40}, + } + + if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx) + if err != nil { + t.Fatalf("got error: %v", err) + } + + if len(results) != 4 { + t.Errorf("Four users should be found, instead found %d", len(results)) + } + + results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx) + if err != nil { + t.Fatalf("got error: %v", err) + } + + if len(results) != 3 { + t.Errorf("Three users should be found, instead found %d", len(results)) + } +} diff --git a/tests/go.mod b/tests/go.mod index c87ccca1f..2d647b08a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.18.0 +go 1.23.0 require ( github.com/google/uuid v1.6.0 From 4fcd909f2e04d9dd03b46d1fd04ffc8131d41c10 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 20 Apr 2025 20:08:07 +0800 Subject: [PATCH 07/22] Add clause.JoinTable helper method --- clause/joins.go | 8 ++++++++ tests/generics_test.go | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/clause/joins.go b/clause/joins.go index 879892be4..b0f0359cf 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -1,5 +1,7 @@ package clause +import "gorm.io/gorm/utils" + type JoinType string const ( @@ -18,6 +20,12 @@ type Join struct { Expression Expression } +func JoinTable(names ...string) Table { + return Table{ + Name: utils.JoinNestedRelationNames(names), + } +} + func (join Join) Build(builder Builder) { if join.Expression != nil { join.Expression.Build(builder) diff --git a/tests/generics_test.go b/tests/generics_test.go index 1587d0909..c467f9ff1 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -8,6 +8,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -284,7 +285,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db.Create(ctx, &u) // LEFT JOIN + WHERE - result, err := db.Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) + result, err := db.Joins("Company").Where("?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -293,7 +294,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // INNER JOIN + Inline WHERE - result2, err := db.InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) + result2, err := db.InnerJoins("Company", "?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) if err != nil { t.Fatalf("InnerJoins failed: %v", err) } From 7095605cd0b7feafc1ac0ebf8287b979b4b1dd0f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Apr 2025 17:33:34 +0800 Subject: [PATCH 08/22] Fix golangci-lint error --- statement.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 11791c3a1..88d76dc3f 100644 --- a/statement.go +++ b/statement.go @@ -218,7 +218,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for _, vv := range vars { subdb.Statement.Vars = append(subdb.Statement.Vars, vv) bindvar := strings.Builder{} - cv.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + cv.BindVarTo(&bindvar, subdb.Statement, vv) sql = strings.Replace(sql, bindvar.String(), "?", 1) } From 05925b2fc08f9be5a3ca6c384454c7a43febe920 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 May 2025 17:16:10 +0800 Subject: [PATCH 09/22] Complete the design and implementation of generic version Join --- chainable_api.go | 7 ++++--- clause/joins.go | 24 ++++++++++++++++++++++++ generics.go | 16 ++++------------ tests/generics_test.go | 27 +++++++++++++++++++++------ tests/go.mod | 4 ++-- 5 files changed, 55 insertions(+), 23 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 8953413d5..8a6aea343 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: -// var users []User -// db.Unscoped().Find(&users) -// // Retrieves all users, including deleted ones. +// +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/clause/joins.go b/clause/joins.go index b0f0359cf..ddb2a5a90 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -11,6 +11,30 @@ const ( RightJoin JoinType = "RIGHT" ) +type JoinTarget struct { + Type JoinType + Association string + Subquery Expression + Table string +} + +func Has(name string) JoinTarget { + return JoinTarget{Type: LeftJoin, Association: name} +} + +func (jt JoinType) Association(name string) JoinTarget { + return JoinTarget{Type: jt, Association: name} +} + +func (jt JoinType) Subquery(subquery Expression) JoinTarget { + return JoinTarget{Type: jt, Subquery: subquery} +} + +func (jt JoinTarget) As(name string) JoinTarget { + jt.Table = name + return jt +} + // Join clause for from type Join struct { Type JoinType diff --git a/generics.go b/generics.go index 9dd1af7d3..95d98100a 100644 --- a/generics.go +++ b/generics.go @@ -28,8 +28,7 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query string, args ...interface{}) ChainInterface[T] - InnerJoins(query string, args ...interface{}) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -186,16 +185,9 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Joins(query, args...) - }) -} - -func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.InnerJoins(query, args...) - }) +func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { + // TODO + return nil } func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { diff --git a/tests/generics_test.go b/tests/generics_test.go index c467f9ff1..2d69d22e7 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -285,7 +285,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db.Create(ctx, &u) // LEFT JOIN + WHERE - result, err := db.Joins("Company").Where("?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) + result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return db.Where("?.name = ?", joinTable, u.Company.Name) + }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -293,13 +295,26 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // INNER JOIN + Inline WHERE - result2, err := db.InnerJoins("Company", "?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) + // JOIN + result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return nil + }).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // Left JOIN + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return nil + }).First(ctx) if err != nil { - t.Fatalf("InnerJoins failed: %v", err) + t.Fatalf("Joins failed: %v", err) } - if result2.Name != u.Name || result2.Company.Name != u.Company.Name { - t.Errorf("InnerJoins expected , got %+v", result2) + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Preload diff --git a/tests/go.mod b/tests/go.mod index 2d647b08a..7f4d84f7e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -29,8 +29,8 @@ require ( github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.37.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/text v0.25.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) From d073805c86836f1bb9f0c3927ea9c5434ace3cea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 19:44:57 +0800 Subject: [PATCH 10/22] improve generics version Joins support --- callbacks/query.go | 17 +++++- clause/joins.go | 6 +- generics.go | 126 +++++++++++++++++++++++++++++++++++++++-- scan.go | 7 +++ statement.go | 43 ++++++++------ tests/generics_test.go | 53 +++++++++++++---- 6 files changed, 213 insertions(+), 39 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index bbf238a9f..56a5944a0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) { if isRelations { genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + tableAliasName := join.Alias + + if tableAliasName == "" { + tableAliasName = relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } } columnStmt := gorm.Statement{ @@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.Expression != nil { + return clause.Join{ + Type: join.JoinType, + Expression: join.Expression, + } + } + exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { diff --git a/clause/joins.go b/clause/joins.go index ddb2a5a90..a6f13e55c 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -19,15 +19,15 @@ type JoinTarget struct { } func Has(name string) JoinTarget { - return JoinTarget{Type: LeftJoin, Association: name} + return JoinTarget{Type: InnerJoin, Association: name} } func (jt JoinType) Association(name string) JoinTarget { return JoinTarget{Type: jt, Association: name} } -func (jt JoinType) Subquery(subquery Expression) JoinTarget { - return JoinTarget{Type: jt, Subquery: subquery} +func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { + return JoinTarget{Type: jt, Association: name, Subquery: subquery} } func (jt JoinTarget) As(name string) JoinTarget { diff --git a/generics.go b/generics.go index 95d98100a..43c7223a9 100644 --- a/generics.go +++ b/generics.go @@ -3,8 +3,11 @@ package gorm import ( "context" "database/sql" + "fmt" + "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" ) type Interface[T any] interface { @@ -28,7 +31,7 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -38,6 +41,8 @@ type ChainInterface[T any] interface { Order(value interface{}) ChainInterface[T] Preload(query string, args ...interface{}) ChainInterface[T] + Build(builder clause.Builder) + Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) @@ -55,6 +60,14 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } +type QueryInterface interface { + Select(...string) QueryInterface + Omit(...string) QueryInterface + Where(query interface{}, args ...interface{}) QueryInterface + Not(query interface{}, args ...interface{}) QueryInterface + Or(query interface{}, args ...interface{}) QueryInterface +} + type op func(*DB) *DB func G[T any](db *DB, opts ...clause.Expression) Interface[T] { @@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { - // TODO - return nil +type query struct { + db *DB +} + +func (q query) Where(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Or(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Not(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Select(columns ...string) QueryInterface { + q.db.Select(columns) + return q +} + +func (q query) Omit(columns ...string) QueryInterface { + q.db.Omit(columns...) + return q +} + +func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { + return c.with(func(db *DB) *DB { + if jt.Table == "" { + jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name + } + + q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} + if args != nil { + args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) + } + + j := join{ + Name: jt.Association, + Alias: jt.Table, + Selects: q.db.Statement.Selects, + Omits: q.db.Statement.Omits, + JoinType: jt.Type, + } + + if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + + if jt.Subquery != nil { + joinType := j.JoinType + if joinType == "" { + joinType = clause.LeftJoin + } + + expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} + + if j.On != nil { + expr.SQL += " ON ?" + expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) + } + + j.Expression = expr + } + + db.Statement.Joins = append(db.Statement.Joins, j) + return db + }) } func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { @@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err return } +func (c chainG[T]) Build(builder clause.Builder) { + subdb := c.getInstance() + subdb.Logger = logger.Discard + subdb.DryRun = true + + if stmt, ok := builder.(*Statement); ok { + if subdb.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = subdb.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + subdb.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + + builder.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + } +} + type execG[T any] struct { g *g[T] } diff --git a/scan.go b/scan.go index 6dc55f623..624f822fa 100644 --- a/scan.go +++ b/scan.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "reflect" + "strings" "time" "gorm.io/gorm/schema" @@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) { matchedFieldCount[column] = 1 } } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + for _, join := range db.Statement.Joins { + if join.Alias == names[0] { + names = append(strings.Split(join.Name, "."), names[len(names)-1]) + } + } + if rel, ok := sch.Relationships.Relations[names[0]]; ok { subNameCount := len(names) // nested relation fields diff --git a/statement.go b/statement.go index 88d76dc3f..63f78006e 100644 --- a/statement.go +++ b/statement.go @@ -50,12 +50,14 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where - Selects []string - Omits []string - JoinType clause.JoinType + Name string + Alias string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + Expression clause.Expression + JoinType clause.JoinType } // StatementModifier statement modifier interface @@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] arg, _ = valuer.Value() } + curTable := stmt.Table + if curTable == "" { + curTable = clause.CurrentTable + } + switch v := arg.(type) { case clause.Expression: conds = append(conds, v) @@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] sort.Strings(keys) for _, key := range keys { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + column := clause.Column{Name: key, Table: curTable} + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: keys := make([]string, 0, len(v)) @@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + column := clause.Column{Name: key, Table: curTable} switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else if _, ok := v[key].(Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else { // optimize reflect value length valueLen := reflectValue.Len() @@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: column, Values: values}) } default: - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } } default: @@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) return []clause.Expression{clause.And(conds...)} } return nil } } - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) } } } diff --git a/tests/generics_test.go b/tests/generics_test.go index 2d69d22e7..1e1bf7116 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) { if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) - } else if u.Name != "" || u.Age != u.Age { + } else if u.Name != "" || u.Age != user.Age { t.Errorf("found invalid user, got %v, expect %v", u, user) } @@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) { found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) if len(found) != len(batch) { - fmt.Println(found) t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) } @@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db := gorm.G[User](DB) u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} - db.Create(ctx, &u) + u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}} + u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}} + db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) - // LEFT JOIN + WHERE - result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + // Inner JOIN + WHERE + result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { return db.Where("?.name = ?", joinTable, u.Company.Name) }).First(ctx) if err != nil { @@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // JOIN - result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { - return nil + // Inner JOIN + WHERE with map + result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + return db.Where(map[string]any{"name": u.Company.Name}) }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // Left JOIN - result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { - return nil - }).First(ctx) + // Left JOIN w/o WHERE + result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } + // Left JOIN + Alias WHERE + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + if joinTable.Name != "t" { + t.Fatalf("Join table should be t, but got %v", joinTable.Name) + } + return db.Where("?.name = ?", joinTable, u.Company.Name) + }).Where(map[string]any{"name": u.Name}).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // Raw Subquery JOIN + WHERE + result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), + func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + if joinTable.Name != "t" { + t.Fatalf("Join table should be t, but got %v", joinTable.Name) + } + return db.Where("?.name = ?", joinTable, u.Company.Name) + }, + ).Where(map[string]any{"name": u2.Name}).First(ctx) + if err != nil { + t.Fatalf("Raw subquery join failed: %v", err) + } + if result.Name != u2.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + // Preload result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) if err != nil { From ba94e4eb2f4abc99c9d24d2af6f5ae7ce888d22a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 20:02:04 +0800 Subject: [PATCH 11/22] allow configuring select/omit columns for joins via subqueries --- generics.go | 10 ++++++++++ tests/generics_test.go | 18 +++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/generics.go b/generics.go index 43c7223a9..d1d1a6e54 100644 --- a/generics.go +++ b/generics.go @@ -256,6 +256,16 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, join joinType = clause.LeftJoin } + if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok { + stmt := db.getInstance().Statement + if len(j.Selects) == 0 { + j.Selects = stmt.Selects + } + if len(j.Omits) == 0 { + j.Omits = stmt.Omits + } + } + expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} if j.On != nil { diff --git a/tests/generics_test.go b/tests/generics_test.go index 1e1bf7116..313b6baeb 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -342,7 +342,23 @@ func TestGenericsJoinsAndPreload(t *testing.T) { if err != nil { t.Fatalf("Raw subquery join failed: %v", err) } - if result.Name != u2.Name || result.Company.Name != u.Company.Name { + if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // Raw Subquery JOIN + WHERE + Select + result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), + func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + if joinTable.Name != "t" { + t.Fatalf("Join table should be t, but got %v", joinTable.Name) + } + return db.Where("?.name = ?", joinTable, u.Company.Name) + }, + ).Where(map[string]any{"name": u2.Name}).First(ctx) + if err != nil { + t.Fatalf("Raw subquery join failed: %v", err) + } + if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } From 46946735264c3ae59aab517b097745d373f664ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 22:46:24 +0800 Subject: [PATCH 12/22] finish generic version Preload --- callbacks/preload.go | 2 +- generics.go | 102 +++++++++++++++++++++++++++++++++-------- tests/generics_test.go | 48 +++++++++++++------ 3 files changed, 120 insertions(+), 32 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index fd8214bb2..4a6f2b794 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati return gorm.ErrInvalidData } } else { - tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { diff --git a/generics.go b/generics.go index d1d1a6e54..4953c7588 100644 --- a/generics.go +++ b/generics.go @@ -31,7 +31,8 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -39,7 +40,6 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] - Preload(query string, args ...interface{}) ChainInterface[T] Build(builder clause.Builder) @@ -60,12 +60,24 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } -type QueryInterface interface { - Select(...string) QueryInterface - Omit(...string) QueryInterface - Where(query interface{}, args ...interface{}) QueryInterface - Not(query interface{}, args ...interface{}) QueryInterface - Or(query interface{}, args ...interface{}) QueryInterface +type JoinBuilder interface { + Select(...string) JoinBuilder + Omit(...string) JoinBuilder + Where(query interface{}, args ...interface{}) JoinBuilder + Not(query interface{}, args ...interface{}) JoinBuilder + Or(query interface{}, args ...interface{}) JoinBuilder +} + +type PreloadBuilder interface { + Select(...string) PreloadBuilder + Omit(...string) PreloadBuilder + Where(query interface{}, args ...interface{}) PreloadBuilder + Not(query interface{}, args ...interface{}) PreloadBuilder + Or(query interface{}, args ...interface{}) PreloadBuilder + Limit(offset int) PreloadBuilder + Offset(offset int) PreloadBuilder + Order(value interface{}) PreloadBuilder + Scopes(scopes ...func(db *Statement)) PreloadBuilder } type op func(*DB) *DB @@ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -type query struct { +type joinBuilder struct { db *DB } -func (q query) Where(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Or(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Not(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Select(columns ...string) QueryInterface { +func (q joinBuilder) Select(columns ...string) JoinBuilder { q.db.Select(columns) return q } -func (q query) Omit(columns ...string) QueryInterface { +func (q joinBuilder) Omit(columns ...string) JoinBuilder { q.db.Omit(columns...) return q } -func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { +type preloadBuilder struct { + db *DB +} + +func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Select(columns ...string) PreloadBuilder { + q.db.Select(columns) + return q +} + +func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { + q.db.Omit(columns...) + return q +} + +func (q preloadBuilder) Limit(limit int) PreloadBuilder { + q.db.Limit(limit) + return q +} +func (q preloadBuilder) Offset(offset int) PreloadBuilder { + q.db.Offset(offset) + return q +} +func (q preloadBuilder) Order(value interface{}) PreloadBuilder { + q.db.Order(value) + return q +} +func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { + for _, fc := range scopes { + fc(q.db.Statement) + } + return q +} + +func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { return c.with(func(db *DB) *DB { if jt.Table == "" { jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name } - q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} + q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} if args != nil { args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) } @@ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { }) } -func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { +func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] { return c.with(func(db *DB) *DB { - return db.Preload(query, args...) + return db.Preload(association, func(db *DB) *DB { + q := preloadBuilder{db: db} + if args != nil { + args(q) + } + return q.db + }) }) } diff --git a/tests/generics_test.go b/tests/generics_test.go index 313b6baeb..2efaacdc8 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) // Inner JOIN + WHERE - result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { - return db.Where("?.name = ?", joinTable, u.Company.Name) + result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Inner JOIN + WHERE with map - result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { - return db.Where(map[string]any{"name": u.Company.Name}) + result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + db.Where(map[string]any{"name": u.Company.Name}) + return nil }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Left JOIN + Alias WHERE - result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { // Raw Subquery JOIN + WHERE result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), - func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }, ).Where(map[string]any{"name": u2.Name}).First(ctx) if err != nil { @@ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { // Raw Subquery JOIN + WHERE + Select result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), - func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }, ).Where(map[string]any{"name": u2.Name}).First(ctx) if err != nil { @@ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Preload - result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) + result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) if err != nil { - t.Fatalf("Joins failed: %v", err) + t.Fatalf("Preload failed: %v", err) } if result3.Name != u.Name || result3.Company.Name != u.Company.Name { - t.Fatalf("Joins expected %s, got %+v", u.Name, result) + t.Fatalf("Preload expected %s, got %+v", u.Name, result) + } + + results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { + db.Where("name = ?", u.Company.Name) + return nil + }).Find(ctx) + if err != nil { + t.Fatalf("Preload failed: %v", err) + } + for _, result := range results { + if result.Name == u.Name { + if result.Company.Name != u.Company.Name { + t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name) + } + } else if result.Company.Name != "" { + t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) + } } } From e330694e262a94523e4c330961a39c0ba221bc62 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 22:52:55 +0800 Subject: [PATCH 13/22] handle error of generics Joins/Preload --- generics.go | 26 +++++++++++++++----------- tests/generics_test.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/generics.go b/generics.go index 4953c7588..230f07f51 100644 --- a/generics.go +++ b/generics.go @@ -31,8 +31,8 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] - Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -287,15 +287,17 @@ func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { return q } -func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { +func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { return c.with(func(db *DB) *DB { if jt.Table == "" { jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name } - q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} - if args != nil { - args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) + q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} + if on != nil { + if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { + db.AddError(err) + } } j := join{ @@ -383,12 +385,14 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { }) } -func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] { +func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { return c.with(func(db *DB) *DB { - return db.Preload(association, func(db *DB) *DB { - q := preloadBuilder{db: db} - if args != nil { - args(q) + return db.Preload(association, func(tx *DB) *DB { + q := preloadBuilder{db: tx} + if query != nil { + if err := query(q); err != nil { + db.AddError(err) + } } return q.db }) diff --git a/tests/generics_test.go b/tests/generics_test.go index 2efaacdc8..2e0dbc28e 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "errors" "fmt" "reflect" "sort" @@ -367,6 +368,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } + _, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + return errors.New("join error") + }).First(ctx) + if err == nil { + t.Fatalf("Joins should got error, but got nil") + } + // Preload result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) if err != nil { @@ -392,6 +400,13 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) } } + + _, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { + return errors.New("preload error") + }).Find(ctx) + if err == nil { + t.Fatalf("Preload should failed, but got nil") + } } func TestGenericsDistinct(t *testing.T) { From 9b1ce2b7635a4819bf99184a9e6823a20f2bc21c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 23:02:40 +0800 Subject: [PATCH 14/22] fix tests --- callbacks/preload.go | 2 +- generics.go | 2 +- tests/go.mod | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 4a6f2b794..fd8214bb2 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati return gorm.ErrInvalidData } } else { - tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { diff --git a/generics.go b/generics.go index 230f07f51..7c7257f6e 100644 --- a/generics.go +++ b/generics.go @@ -388,7 +388,7 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Preload(association, func(tx *DB) *DB { - q := preloadBuilder{db: tx} + q := preloadBuilder{db: tx.getInstance()} if query != nil { if err := query(q); err != nil { db.AddError(err) diff --git a/tests/go.mod b/tests/go.mod index 7f4d84f7e..1c644f314 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -11,7 +11,7 @@ require ( gorm.io/driver/postgres v1.5.11 gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlserver v1.5.4 - gorm.io/gorm v1.25.12 + gorm.io/gorm v1.26.1 ) require ( @@ -22,11 +22,11 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.4 // indirect + github.com/jackc/pgx/v5 v5.7.5 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-sqlite3 v1.14.28 // indirect - github.com/microsoft/go-mssqldb v1.8.0 // indirect + github.com/microsoft/go-mssqldb v1.8.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect golang.org/x/crypto v0.38.0 // indirect From 6307f69f18d6c9fe372a4ff03ab0b7d3352fbba0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 May 2025 18:04:20 +0800 Subject: [PATCH 15/22] Add LimitPerRecord for generic version Preload --- callbacks/preload.go | 8 +++- generics.go | 103 ++++++++++++++++++++++++++++++++--------- statement.go | 1 + tests/generics_test.go | 94 ++++++++++++++++++++++++++++++++++--- 4 files changed, 178 insertions(+), 28 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index fd8214bb2..607c22bcd 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) if len(values) != 0 { + tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) + for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { tx = fc(tx) @@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + if len(inlineConds) > 0 { + tx = tx.Where(inlineConds[0], inlineConds[1:]...) + } + + if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { return err } } diff --git a/generics.go b/generics.go index 7c7257f6e..1fab7078f 100644 --- a/generics.go +++ b/generics.go @@ -77,7 +77,7 @@ type PreloadBuilder interface { Limit(offset int) PreloadBuilder Offset(offset int) PreloadBuilder Order(value interface{}) PreloadBuilder - Scopes(scopes ...func(db *Statement)) PreloadBuilder + LimitPerRecord(num int) PreloadBuilder } type op func(*DB) *DB @@ -214,76 +214,78 @@ type joinBuilder struct { db *DB } -func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { +func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q joinBuilder) Select(columns ...string) JoinBuilder { +func (q *joinBuilder) Select(columns ...string) JoinBuilder { q.db.Select(columns) return q } -func (q joinBuilder) Omit(columns ...string) JoinBuilder { +func (q *joinBuilder) Omit(columns ...string) JoinBuilder { q.db.Omit(columns...) return q } type preloadBuilder struct { - db *DB + limitPerRecord int + db *DB } -func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { +func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } -func (q preloadBuilder) Select(columns ...string) PreloadBuilder { +func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { q.db.Select(columns) return q } -func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { +func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { q.db.Omit(columns...) return q } -func (q preloadBuilder) Limit(limit int) PreloadBuilder { +func (q *preloadBuilder) Limit(limit int) PreloadBuilder { q.db.Limit(limit) return q } -func (q preloadBuilder) Offset(offset int) PreloadBuilder { + +func (q *preloadBuilder) Offset(offset int) PreloadBuilder { q.db.Offset(offset) return q } -func (q preloadBuilder) Order(value interface{}) PreloadBuilder { + +func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { q.db.Order(value) return q } -func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { - for _, fc := range scopes { - fc(q.db.Statement) - } + +func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { + q.limitPerRecord = num return q } @@ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} if on != nil { - if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { + if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { db.AddError(err) } } @@ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err return db.Preload(association, func(tx *DB) *DB { q := preloadBuilder{db: tx.getInstance()} if query != nil { - if err := query(q); err != nil { + if err := query(&q); err != nil { db.AddError(err) } } + + relation, ok := db.Statement.Schema.Relationships.Relations[association] + if !ok { + db.AddError(fmt.Errorf("relation %s not found", association)) + } + + if q.limitPerRecord > 0 { + if relation.JoinTable != nil { + err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association) + tx.AddError(err) + return tx + } + + refColumns := []clause.Column{} + for _, rel := range relation.References { + if rel.OwnPrimaryKey { + refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) + } + } + + if len(refColumns) != 0 { + selects := q.db.Statement.Selects + selectExpr := clause.CommaExpression{} + if len(selects) == 0 { + selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} + } else { + for _, column := range selects { + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) + } + } + + partitionBy := clause.CommaExpression{} + for _, column := range refColumns { + partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) + } + + rnnColumn := clause.Column{Name: "gorm_preload_rnn"} + sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" + vars := []interface{}{partitionBy} + if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { + vars = append(vars, orderBy) + } else { + vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ + Columns: []clause.OrderByColumn{ + {Column: clause.PrimaryColumn, Desc: false}, + }, + }}) + } + vars = append(vars, rnnColumn) + + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) + + q.db.Clauses(clause.Select{ + Expression: selectExpr, + }) + + return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) + } + } return q.db }) }) diff --git a/statement.go b/statement.go index 63f78006e..19cdbbafe 100644 --- a/statement.go +++ b/statement.go @@ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } case interface{ getInstance() *DB }: cv := v.getInstance() + subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() if cv.Statement.SQL.Len() > 0 { var ( diff --git a/tests/generics_test.go b/tests/generics_test.go index 2e0dbc28e..2f0f722be 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) { } } -func TestGenericsJoinsAndPreload(t *testing.T) { +func TestGenericsJoins(t *testing.T) { ctx := context.Background() db := gorm.G[User](DB) @@ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) { if err == nil { t.Fatalf("Joins should got error, but got nil") } +} + +func TestGenericsPreloads(t *testing.T) { + ctx := context.Background() + db := gorm.G[User](DB) + + u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7}) + u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5}) + u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3}) + names := []string{u.Name, u2.Name, u3.Name} - // Preload - result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) + db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) + + result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx) if err != nil { t.Fatalf("Preload failed: %v", err) } - if result3.Name != u.Name || result3.Company.Name != u.Company.Name { + + if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) { t.Fatalf("Preload expected %s, got %+v", u.Name, result) } results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { db.Where("name = ?", u.Company.Name) return nil - }).Find(ctx) + }).Where("name in ?", names).Find(ctx) if err != nil { t.Fatalf("Preload failed: %v", err) } @@ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) { _, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { return errors.New("preload error") - }).Find(ctx) + }).Where("name in ?", names).Find(ctx) if err == nil { t.Fatalf("Preload should failed, but got nil") } + + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(5) + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + } + + if DB.Dialector.Name() == "sqlserver" { + // sqlserver doesn't support order by in subquery + return + } + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.Order("name desc").LimitPerRecord(5) + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name < result.Pets[i].Name { + t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + } + + results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { + db.Order("name").LimitPerRecord(5) + return nil + }).Preload("Friends", func(db gorm.PreloadBuilder) error { + db.Order("name") + return nil + }).Where("name in ?", names).Find(ctx) + + for _, result := range results { + if result.Name == u.Name { + if len(result.Pets) != len(u.Pets) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + if len(result.Friends) != len(u.Friends) { + t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) + } + } else if len(result.Pets) != 5 || len(result.Friends) == 0 { + t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name > result.Pets[i].Name { + t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + for i := 1; i < len(result.Pets); i++ { + if result.Pets[i-1].Name > result.Pets[i].Name { + t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) + } + } + } } func TestGenericsDistinct(t *testing.T) { From 304baabb12e2cac43605b3c7245a1ccef452b28e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 May 2025 22:07:20 +0800 Subject: [PATCH 16/22] fix tests for mysql 5.7 --- generics.go | 23 +++++++++-------------- tests/generics_test.go | 8 ++++++++ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/generics.go b/generics.go index 1fab7078f..0b4d48b88 100644 --- a/generics.go +++ b/generics.go @@ -404,8 +404,7 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err if q.limitPerRecord > 0 { if relation.JoinTable != nil { - err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association) - tx.AddError(err) + tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)) return tx } @@ -417,14 +416,13 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err } if len(refColumns) != 0 { - selects := q.db.Statement.Selects selectExpr := clause.CommaExpression{} - if len(selects) == 0 { + for _, column := range q.db.Statement.Selects { + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) + } + + if len(selectExpr.Exprs) == 0 { selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} - } else { - for _, column := range selects { - selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) - } } partitionBy := clause.CommaExpression{} @@ -439,22 +437,19 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err vars = append(vars, orderBy) } else { vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ - Columns: []clause.OrderByColumn{ - {Column: clause.PrimaryColumn, Desc: false}, - }, + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, }}) } vars = append(vars, rnnColumn) selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) - q.db.Clauses(clause.Select{ - Expression: selectExpr, - }) + q.db.Clauses(clause.Select{Expression: selectExpr}) return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) } } + return q.db }) }) diff --git a/tests/generics_test.go b/tests/generics_test.go index 2f0f722be..32881ce53 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -6,8 +6,10 @@ import ( "fmt" "reflect" "sort" + "strings" "testing" + "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -420,6 +422,12 @@ func TestGenericsPreloads(t *testing.T) { t.Fatalf("Preload should failed, but got nil") } + if DB.Dialector.Name() == "mysql" { + // mysql 5.7 doesn't support row_number() + if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { + return + } + } results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { db.LimitPerRecord(5) return nil From 774d957089bb15c18223c926d7d18e049b11374e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 19:45:34 +0800 Subject: [PATCH 17/22] test for nested generic version Join/Preload --- callbacks/query.go | 42 ++++++------ generics.go | 21 +++++- scan.go | 4 +- tests/generics_test.go | 143 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 25 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 56a5944a0..c8632cc56 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -110,7 +110,7 @@ func BuildQuerySQL(db *gorm.DB) { } } - specifiedRelationsName := make(map[string]interface{}) + specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} for _, join := range db.Statement.Joins { if db.Statement.Schema != nil { var isRelations bool // is relations or raw sql @@ -124,12 +124,12 @@ func BuildQuerySQL(db *gorm.DB) { nestedJoinNames := strings.Split(join.Name, ".") if len(nestedJoinNames) > 1 { isNestedJoin := true - gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) currentRelations := db.Statement.Schema.Relationships.Relations for _, relname := range nestedJoinNames { // incomplete match, only treated as raw sql if relation, ok = currentRelations[relname]; ok { - gussNestedRelations = append(gussNestedRelations, relation) + guessNestedRelations = append(guessNestedRelations, relation) currentRelations = relation.FieldSchema.Relationships.Relations } else { isNestedJoin = false @@ -139,22 +139,13 @@ func BuildQuerySQL(db *gorm.DB) { if isNestedJoin { isRelations = true - relations = gussNestedRelations + relations = guessNestedRelations } } } if isRelations { - genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := join.Alias - - if tableAliasName == "" { - tableAliasName = relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) - } - } - + genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { columnStmt := gorm.Statement{ Table: tableAliasName, DB: db, Schema: relation.FieldSchema, Selects: join.Selects, Omits: join.Omits, @@ -237,19 +228,24 @@ func BuildQuerySQL(db *gorm.DB) { } parentTableName := clause.CurrentTable - for _, rel := range relations { + for idx, rel := range relations { // joins table alias like "Manager, Company, Manager__Company" - nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) - if _, ok := specifiedRelationsName[nestedAlias]; !ok { - fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) - specifiedRelationsName[nestedAlias] = nil + curAliasName := rel.Name + if parentTableName != clause.CurrentTable { + curAliasName = utils.NestedRelationName(parentTableName, curAliasName) } - if parentTableName != clause.CurrentTable { - parentTableName = utils.NestedRelationName(parentTableName, rel.Name) - } else { - parentTableName = rel.Name + if _, ok := specifiedRelationsName[curAliasName]; !ok { + aliasName := curAliasName + if idx == len(relations)-1 && join.Alias != "" { + aliasName = join.Alias + } + + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) + specifiedRelationsName[curAliasName] = aliasName } + + parentTableName = curAliasName } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ diff --git a/generics.go b/generics.go index 0b4d48b88..f2863dac7 100644 --- a/generics.go +++ b/generics.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "sort" "strings" "gorm.io/gorm/clause" @@ -341,6 +342,9 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable } db.Statement.Joins = append(db.Statement.Joins, j) + sort.Slice(db.Statement.Joins, func(i, j int) bool { + return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name + }) return db }) } @@ -399,7 +403,22 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { - db.AddError(fmt.Errorf("relation %s not found", association)) + if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { + relationships := db.Statement.Schema.Relationships + for _, field := range preloadFields { + var ok bool + relation, ok = relationships.Relations[field] + if ok { + relationships = relation.FieldSchema.Relationships + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } + } + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } } if q.limitPerRecord > 0 { diff --git a/scan.go b/scan.go index 624f822fa..9a99d0244 100644 --- a/scan.go +++ b/scan.go @@ -245,9 +245,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { matchedFieldCount[column] = 1 } } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) for _, join := range db.Statement.Joins { - if join.Alias == names[0] { + if join.Alias == aliasName { names = append(strings.Split(join.Name, "."), names[len(names)-1]) + break } } diff --git a/tests/generics_test.go b/tests/generics_test.go index 32881ce53..876c7409e 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "sort" + "strconv" "strings" "testing" @@ -378,6 +379,82 @@ func TestGenericsJoins(t *testing.T) { } } +func TestGenericsNestedJoins(t *testing.T) { + users := []User{ + { + Name: "generics-nested-joins-1", + Manager: &User{ + Name: "generics-nested-joins-manager-1", + Company: Company{ + Name: "generics-nested-joins-manager-company-1", + }, + NamedPet: &Pet{ + Name: "generics-nested-joins-manager-namepet-1", + Toy: Toy{ + Name: "generics-nested-joins-manager-namepet-toy-1", + }, + }, + }, + NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}}, + }, + { + Name: "generics-nested-joins-2", + Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}}, + }, + } + + ctx := context.Background() + db := gorm.G[User](DB) + db.CreateInBatches(ctx, &users, 100) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil). + Joins(clause.LeftJoin.Association("Manager.Company"), nil). + Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil). + Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil). + Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil). + Where(map[string]any{"id": userIDs}).Find(ctx) + + if err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} + func TestGenericsPreloads(t *testing.T) { ctx := context.Background() db := gorm.G[User](DB) @@ -499,6 +576,35 @@ func TestGenericsPreloads(t *testing.T) { } } +func TestGenericsNestedPreloads(t *testing.T) { + user := *GetUser("generics_nested_preload", Config{Pets: 2}) + user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})} + + ctx := context.Background() + db := gorm.G[User](DB) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := db.Create(ctx, &user); err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(3) + return nil + }).Where(user.ID).Take(ctx) + if err != nil { + t.Errorf("failed to nested preload user") + } + CheckUser(t, user2, user) + + if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { + t.Errorf("failed to nested preload with limit per record") + } +} + func TestGenericsDistinct(t *testing.T) { ctx := context.Background() @@ -586,3 +692,40 @@ func TestGenericsSubQuery(t *testing.T) { t.Errorf("Three users should be found, instead found %d", len(results)) } } + +func TestGenericsUpsert(t *testing.T) { + ctx := context.Background() + lang := Language{Code: "upsert", Name: "Upsert"} + + if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + lang2 := Language{Code: "upsert", Name: "Upsert"} + if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx) + if err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := gorm.G[Language](DB, clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(ctx, &lang3); err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } +} From ddaee81548c8ffe79d9e67e56c56788324d6788b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 20:19:23 +0800 Subject: [PATCH 18/22] Add WithResult support for generics API --- callbacks/create.go | 10 ++++++++++ callbacks/delete.go | 10 ++++++++++ callbacks/query.go | 4 ++++ callbacks/raw.go | 5 +++++ callbacks/update.go | 9 +++++++++ generics.go | 19 ++++++++++++++++++- statement.go | 2 ++ tests/generics_test.go | 15 +++++++++++++++ 8 files changed, 73 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8b7846b63..d8701f511 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } + if db.RowsAffected == 0 { return } diff --git a/callbacks/delete.go b/callbacks/delete.go index 84f446a3f..07ed6feef 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } db.AddError(rows.Close()) } } diff --git a/callbacks/query.go b/callbacks/query.go index c8632cc56..548bf7092 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -25,6 +25,10 @@ func Query(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, 0) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } diff --git a/callbacks/raw.go b/callbacks/raw.go index 013e638cb..3bb647c43 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } diff --git a/callbacks/update.go b/callbacks/update.go index 7cde7f619..8e2782e16 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { gorm.Scan(rows, db, mode) db.Statement.Dest = dest db.AddError(rows.Close()) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } diff --git a/generics.go b/generics.go index f2863dac7..54ccfca01 100644 --- a/generics.go +++ b/generics.go @@ -11,6 +11,23 @@ import ( "gorm.io/gorm/logger" ) +type result struct { + Result sql.Result + RowsAffected int64 +} + +func (info *result) ModifyStatement(stmt *Statement) { + stmt.Result = info +} + +// Build implements clause.Expression interface +func (result) Build(clause.Builder) { +} + +func WithResult() *result { + return &result{} +} + type Interface[T any] interface { Raw(sql string, values ...interface{}) ExecInterface[T] Exec(ctx context.Context, sql string, values ...interface{}) error @@ -85,7 +102,7 @@ type op func(*DB) *DB func G[T any](db *DB, opts ...clause.Expression) Interface[T] { v := &g[T]{ - db: db.Session(&Session{NewDB: true}), + db: db, ops: make([]op, 0, 5), } diff --git a/statement.go b/statement.go index 19cdbbafe..c6183724e 100644 --- a/statement.go +++ b/statement.go @@ -47,6 +47,7 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + Result *result } type join struct { @@ -532,6 +533,7 @@ func (stmt *Statement) clone() *Statement { Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, SkipHooks: stmt.SkipHooks, + Result: stmt.Result, } if stmt.SQL.Len() > 0 { diff --git a/tests/generics_test.go b/tests/generics_test.go index 876c7409e..5ab76ae70 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -729,3 +729,18 @@ func TestGenericsUpsert(t *testing.T) { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } } + +func TestGenericsWithResult(t *testing.T) { + ctx := context.Background() + users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}} + + result := gorm.WithResult() + err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2) + if err != nil { + t.Errorf("failed to create users WithResult") + } + + if result.RowsAffected != 2 { + t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) + } +} From 8ced5498dfe209e3c88839885368fce602c42f7e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 22:31:40 +0800 Subject: [PATCH 19/22] test reuse generics db conditions --- tests/generics_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/generics_test.go b/tests/generics_test.go index 5ab76ae70..f89678b92 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -8,6 +8,7 @@ import ( "sort" "strconv" "strings" + "sync" "testing" "gorm.io/driver/mysql" @@ -744,3 +745,42 @@ func TestGenericsWithResult(t *testing.T) { t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) } } + +func TestGenericsReuse(t *testing.T) { + ctx := context.Background() + users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}} + + err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2) + if err != nil { + t.Errorf("failed to create users") + } + + reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%") + + sg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + sg.Add(1) + + go func() { + if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil { + t.Errorf("failed to find user, got error: %v", err) + } else if u1.Name != users[0].Name || u1.ID != users[0].ID { + t.Errorf("found invalid user, got %v, expect %v", u1, users[0]) + } + + if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil { + t.Errorf("failed to find user, got error: %v", err) + } else if u2.Name != users[1].Name || u2.ID != users[1].ID { + t.Errorf("found invalid user, got %v, expect %v", u2, users[1]) + } + + if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil { + t.Errorf("failed to find user, got error: %v", err) + } else if len(users) != 2 { + t.Errorf("should find 2 users, but got %d", len(users)) + } + sg.Done() + }() + } + sg.Wait() +} From 0305e0d63e0d0312ea79eb182d39d41cd8a4fa41 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 22:57:43 +0800 Subject: [PATCH 20/22] fix data race --- generics.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generics.go b/generics.go index 54ccfca01..52492c8f3 100644 --- a/generics.go +++ b/generics.go @@ -174,11 +174,11 @@ func (c chainG[T]) getInstance() *DB { return c.g.apply(context.Background()).Model(r).getInstance() } -func (c chainG[T]) with(op op) chainG[T] { +func (c chainG[T]) with(v op) chainG[T] { return chainG[T]{ execG: execG[T]{g: &g[T]{ db: c.g.db, - ops: append(c.g.ops, op), + ops: append(append([]op(nil), c.g.ops...), v), }}, } } From 4ee59e1d87a031a752f42f61ce7a1be1c92f10ca Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 22 May 2025 23:03:21 +0800 Subject: [PATCH 21/22] remove ExampleLRU test --- tests/lru_test.go | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/tests/lru_test.go b/tests/lru_test.go index 3eaef5de5..1c8f1afd9 100644 --- a/tests/lru_test.go +++ b/tests/lru_test.go @@ -520,38 +520,6 @@ func TestLRURemoveOldest(t *testing.T) { } } -func ExampleLRU() { - // make cache with 10ms TTL and 5 max keys - cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10) - - // set value under key1. - cache.Add("key1", "val1") - - // get value under key1 - r, ok := cache.Get("key1") - - // check for OK value - if ok { - fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r) - } - - // wait for cache to expire - time.Sleep(time.Millisecond * 100) - - // get value under key1 after key expiration - r, ok = cache.Get("key1") - fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r) - - // set value under key2, would evict old entry because it is already expired. - cache.Add("key2", "val2") - - fmt.Printf("Cache len: %d\n", cache.Len()) - // Output: - // value before expiration is found: true, value: "val1" - // value after expiration is found: false, value: "" - // Cache len: 1 -} - func getRand(tb testing.TB) int64 { out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) if err != nil { From 4db3fde9c568f1486b066ada8d829e3b01a8a159 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 23 May 2025 18:06:34 +0800 Subject: [PATCH 22/22] Add default transaction timeout support --- finisher_api.go | 12 +++++-- generics.go | 6 +++- gorm.go | 6 ++-- tests/generics_test.go | 70 +++++++++++++++++++++++++++++++++++++-- tests/transaction_test.go | 16 +++++++-- 5 files changed, 101 insertions(+), 9 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6802945cc..57809d17a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } + ctx := tx.Statement.Context + if _, ok := ctx.Deadline(); !ok { + if db.Config.DefaultTransactionTimeout > 0 { + ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + } + } + switch beginner := tx.Statement.ConnPool.(type) { case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) default: err = ErrInvalidTransaction } diff --git a/generics.go b/generics.go index 52492c8f3..ad2d063f4 100644 --- a/generics.go +++ b/generics.go @@ -127,7 +127,11 @@ type g[T any] struct { } func (g *g[T]) apply(ctx context.Context) *DB { - db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + db := g.db + if !db.DryRun { + db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + } + for _, op := range g.ops { db = op(db) } diff --git a/gorm.go b/gorm.go index 63a28b37f..27e4caa0a 100644 --- a/gorm.go +++ b/gorm.go @@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt" type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can disable it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool + SkipDefaultTransaction bool + DefaultTransactionTimeout time.Duration + // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer // FullSaveAssociations full save associations @@ -519,7 +521,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) diff --git a/tests/generics_test.go b/tests/generics_test.go index f89678b92..39decb3f6 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -593,15 +594,37 @@ func TestGenericsNestedPreloads(t *testing.T) { } user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { - db.LimitPerRecord(3) return nil }).Where(user.ID).Take(ctx) if err != nil { t.Errorf("failed to nested preload user") } CheckUser(t, user2, user) + if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 { + t.Fatalf("failed to nested preload") + } + + if DB.Dialector.Name() == "mysql" { + // mysql 5.7 doesn't support row_number() + if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { + return + } + } + if DB.Dialector.Name() == "sqlserver" { + // sqlserver doesn't support order by in subquery + return + } - if len(user2.Friends) != 1 || len(user2.Friends[0].Pets) != 3 { + user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { + db.LimitPerRecord(3) + return nil + }).Where(user.ID).Take(ctx) + if err != nil { + t.Errorf("failed to nested preload user") + } + CheckUser(t, user3, user) + + if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 { t.Errorf("failed to nested preload with limit per record") } } @@ -784,3 +807,46 @@ func TestGenericsReuse(t *testing.T) { } sg.Wait() } + +func TestGenericsWithTransaction(t *testing.T) { + ctx := context.Background() + tx := DB.Begin() + if tx.Error != nil { + t.Fatalf("failed to begin transaction: %v", tx.Error) + } + + users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}} + err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2) + + count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 2 { + t.Errorf("expected 2 records, got %d", count) + } + + if err := tx.Rollback().Error; err != nil { + t.Fatalf("failed to rollback transaction: %v", err) + } + + count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count2 != 0 { + t.Errorf("expected 0 records after rollback, got %d", count2) + } +} + +func TestGenericsToSQL(t *testing.T) { + ctx := context.Background() + sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + gorm.G[User](tx).Limit(10).Find(ctx) + return tx + }) + + if !regexp.MustCompile("SELECT \\* FROM `users`.* 10").MatchString(sql) { + t.Errorf("ToSQL: got wrong sql with Generics API %v", sql) + } +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 9f0f067c8..80d3a7fcb 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) { return tx2.Scan(&User{}).Error }) }) - if err != nil { t.Error(err) } @@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) { return tx3.Where("user_id", user.ID).Delete(&Account{}).Error }) }) - if err != nil { t.Error(err) } } + +func TestTransactionWithDefaultTimeout(t *testing.T) { + db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + tx := db.Begin() + time.Sleep(3 * time.Second) + if err = tx.Find(&User{}).Error; err == nil { + t.Errorf("should return error when transaction timeout, got error %v", err) + } +}