diff --git a/internal/db/dbutil/dbutil.go b/internal/db/dbutil/dbutil.go index 1a72b1f9a899..63e417155dfa 100644 --- a/internal/db/dbutil/dbutil.go +++ b/internal/db/dbutil/dbutil.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "strconv" + "strings" "time" // Register driver @@ -110,8 +111,28 @@ func NewDB(dsn, app string) (*sql.DB, error) { return db, nil } +// injectVersionUpdate fixes the dirty state (set by golang-migrate) after a +// successful migration. If the frontend starts a migration that will turn out +// to be successful but does not stay alive for the duration of the query due to +// a startup timeout, there will be no chance to set the new version or unset +// the dirty flag. This function ensures that each successful migration sets the +// version and dirty flag itself, without requiring the frontend to be alive +// once the migration is committed. +// +// See https://github.com/golang-migrate/migrate/issues/325. +func injectVersionUpdate(f bindata.AssetFunc) bindata.AssetFunc { + return func(name string) ([]byte, error) { + oldContents, err := f(name) + if err != nil { + return nil, err + } + newContents := strings.Replace(string(oldContents), "COMMIT;", fmt.Sprintf("UPDATE schema_migrations SET dirty=false;\nCOMMIT;"), 1) + return []byte(newContents), nil + } +} + func NewMigrationSourceLoader(dataSource string) *bindata.AssetSource { - return bindata.Resource(migrations.AssetNames(), migrations.Asset) + return bindata.Resource(migrations.AssetNames(), injectVersionUpdate(migrations.Asset)) } func NewMigrate(db *sql.DB, dataSource string) (*migrate.Migrate, error) { diff --git a/internal/db/dbutil/dbutil_test.go b/internal/db/dbutil/dbutil_test.go index 1b8d626d95e3..c6e22443f42e 100644 --- a/internal/db/dbutil/dbutil_test.go +++ b/internal/db/dbutil/dbutil_test.go @@ -67,3 +67,15 @@ func TestPostgresDSN(t *testing.T) { }) } } + +func TestInjectVersionUpdate(t *testing.T) { + gotContents, err := injectVersionUpdate(func(name string) ([]byte, error) { return []byte("BEGIN;\n-- some statements...\nCOMMIT;"), nil })("migrations/100_dummy.up.sql") + if err != nil { + t.Fatal(err) + } + got := string(gotContents) + want := "BEGIN;\n-- some statements...\nUPDATE schema_migrations SET dirty=false;\nCOMMIT;" + if got != want { + t.Errorf("incorrect contents: got != want\ngot: %v\nwant: %v", got, want) + } +} diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index 006c6b16764f..152512f409aa 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -1,6 +1,7 @@ package migrations_test import ( + "io/ioutil" "path/filepath" "reflect" "sort" @@ -39,6 +40,26 @@ func TestIDConstraints(t *testing.T) { } } +// Makes sure that every migration contains exactly one `COMMIT;` so that +// `InjectVersionUpdate` in internal/db/dbutil/dbutil.go is guaranteed to succeed. +func TestTransactions(t *testing.T) { + ups, err := filepath.Glob("*.up.sql") + if err != nil { + t.Fatal(err) + } + + for _, name := range ups { + contents, err := ioutil.ReadFile(name) + if err != nil { + t.Fatalf("failed to read migration file %q: %v", name, err) + } + commitCount := strings.Count(string(contents), "COMMIT;") + if commitCount != 1 { + t.Fatalf("expected migration %q to contain exactly one COMMIT; but it contains %d", name, commitCount) + } + } +} + func TestNeedsGenerate(t *testing.T) { want, err := filepath.Glob("*.sql") if err != nil {