Skip to content

Maintain consistency in transactional DBs #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (c *Cassandra) Unlock() error {
return nil
}

func (c *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -245,6 +245,10 @@ func (c *Cassandra) Drop() error {
return nil
}

func (c *Cassandra) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Cassandra type.
Expand Down
8 changes: 6 additions & 2 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (ch *ClickHouse) init() error {
return ch.ensureVersionTable()
}

func (ch *ClickHouse) Run(r io.Reader) error {
func (ch *ClickHouse) Run(r io.Reader, version int) error {
migration, err := ioutil.ReadAll(r)
if err != nil {
return err
Expand Down Expand Up @@ -193,7 +193,7 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
// if not, create the empty migration table
query = `
CREATE TABLE ` + ch.config.MigrationsTable + ` (
version Int64,
version Int64,
dirty UInt8,
sequence UInt64
) Engine=TinyLog
Expand Down Expand Up @@ -231,6 +231,10 @@ func (ch *ClickHouse) Drop() (err error) {
return nil
}

func (ch *ClickHouse) Transactional() bool {
return false
}

func (ch *ClickHouse) Lock() error { return nil }
func (ch *ClickHouse) Unlock() error { return nil }
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
6 changes: 5 additions & 1 deletion database/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (c *CockroachDb) Unlock() error {
return nil
}

func (c *CockroachDb) Run(migration io.Reader) error {
func (c *CockroachDb) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -311,6 +311,10 @@ func (c *CockroachDb) Drop() (err error) {
return nil
}

func (c *CockroachDb) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the CockroachDb type.
Expand Down
5 changes: 4 additions & 1 deletion database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Driver interface {
Unlock() error

// Run applies a migration to the database. migration is garantueed to be not nil.
Run(migration io.Reader) error
Run(migration io.Reader, version int) error

// SetVersion saves version and dirty state.
// Migrate will call this function before and after each call to Run.
Expand All @@ -78,6 +78,9 @@ type Driver interface {
// Note that this is a breaking action, a new call to Open() is necessary to
// ensure subsequent calls work as expected.
Drop() error

// Whether or not this driver supports transactions.
Transactional() bool
}

// Open returns a new driver instance.
Expand Down
6 changes: 5 additions & 1 deletion database/firebird/firebird.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (f *Firebird) Unlock() error {
return nil
}

func (f *Firebird) Run(migration io.Reader) error {
func (f *Firebird) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -210,6 +210,10 @@ func (f *Firebird) Drop() (err error) {
return nil
}

func (f *Firebird) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
func (f *Firebird) ensureVersionTable() (err error) {
if err = f.Lock(); err != nil {
Expand Down
6 changes: 5 additions & 1 deletion database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (m *Mongo) Version() (version int, dirty bool, err error) {
}
}

func (m *Mongo) Run(migration io.Reader) error {
func (m *Mongo) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -182,6 +182,10 @@ func (m *Mongo) Drop() error {
return m.db.Drop(context.TODO())
}

func (m *Mongo) Transactional() bool {
return false
}

func (m *Mongo) Lock() error {
return nil
}
Expand Down
6 changes: 5 additions & 1 deletion database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (m *Mysql) Unlock() error {
return nil
}

func (m *Mysql) Run(migration io.Reader) error {
func (m *Mysql) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -387,6 +387,10 @@ func (m *Mysql) Drop() (err error) {
return nil
}

func (m *Mysql) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Mysql type.
Expand Down
22 changes: 20 additions & 2 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,29 @@ func (p *Postgres) Unlock() error {
return nil
}

func (p *Postgres) Run(migration io.Reader) error {
func (p *Postgres) Run(migration io.Reader, targetVersion int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}

setDirty := `
TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + `;
INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES (` + strconv.Itoa(targetVersion) + `, true);`
setClean := `
TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + `;
INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES (` + strconv.Itoa(targetVersion) + `, false);`

// run migration
query := string(migr[:])
query := `
BEGIN;
` + string(migr[:]) + `
` + setClean + `
COMMIT;`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if _, errDirty := p.conn.ExecContext(context.Background(), `BEGIN;`+setDirty+`COMMIT;`); errDirty != nil {
err = multierror.Append(err, errDirty)
}
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
Expand Down Expand Up @@ -339,6 +353,10 @@ func (p *Postgres) Drop() (err error) {
return nil
}

func (p *Postgres) Transactional() bool {
return true
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Postgres type.
Expand Down
5 changes: 4 additions & 1 deletion database/ql/ql.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ func (m *Ql) Drop() (err error) {

return nil
}
func (m *Ql) Transactional() bool {
return false
}
func (m *Ql) Lock() error {
if m.isLocked {
return database.ErrLocked
Expand All @@ -174,7 +177,7 @@ func (m *Ql) Unlock() error {
m.isLocked = false
return nil
}
func (m *Ql) Run(migration io.Reader) error {
func (m *Ql) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down
6 changes: 5 additions & 1 deletion database/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (p *Redshift) Unlock() error {
return nil
}

func (p *Redshift) Run(migration io.Reader) error {
func (p *Redshift) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -295,6 +295,10 @@ func (p *Redshift) Drop() (err error) {
return nil
}

func (p *Redshift) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Redshift type.
Expand Down
6 changes: 5 additions & 1 deletion database/spanner/spanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (s *Spanner) Unlock() error {
}

// Run implements database.Driver
func (s *Spanner) Run(migration io.Reader) error {
func (s *Spanner) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -257,6 +257,10 @@ func (s *Spanner) Drop() error {
return nil
}

func (s *Spanner) Transactional() bool {
return false
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Spanner type.
Expand Down
6 changes: 5 additions & 1 deletion database/sqlite3/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ func (m *Sqlite) Drop() (err error) {
return nil
}

func (m *Sqlite) Transactional() bool {
return false
}

func (m *Sqlite) Lock() error {
if m.isLocked {
return database.ErrLocked
Expand All @@ -173,7 +177,7 @@ func (m *Sqlite) Unlock() error {
return nil
}

func (m *Sqlite) Run(migration io.Reader) error {
func (m *Sqlite) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down
6 changes: 5 additions & 1 deletion database/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (ss *SQLServer) Unlock() error {
}

// Run the migrations for the database
func (ss *SQLServer) Run(migration io.Reader) error {
func (ss *SQLServer) Run(migration io.Reader, version int) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -312,6 +312,10 @@ func (ss *SQLServer) Drop() error {
return nil
}

func (ss *SQLServer) Transactional() bool {
return false
}

func (ss *SQLServer) ensureVersionTable() (err error) {
if err = ss.Lock(); err != nil {
return err
Expand Down
6 changes: 5 additions & 1 deletion database/stub/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (s *Stub) Unlock() error {
return nil
}

func (s *Stub) Run(migration io.Reader) error {
func (s *Stub) Run(migration io.Reader, version int) error {
m, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand Down Expand Up @@ -90,6 +90,10 @@ func (s *Stub) Drop() error {
return nil
}

func (s *Stub) Transactional() bool {
return false
}

func (s *Stub) EqualSequence(seq []string) bool {
return reflect.DeepEqual(seq, s.MigrationSequence)
}
18 changes: 13 additions & 5 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ func (m *Migrate) Drop() error {
return m.unlock()
}

func (m *Migrate) Transactional() bool {
return false
}

// Run runs any migration provided by you against the database.
// It does not check any currently active version in database.
// Usually you don't need this function at all. Use Migrate,
Expand Down Expand Up @@ -737,20 +741,24 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
migr := r

// set version with dirty state
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
return err
if !m.databaseDrv.Transactional() {
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
return err
}
}

if migr.Body != nil {
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
if err := m.databaseDrv.Run(migr.BufferedBody, migr.TargetVersion); err != nil {
return err
}
}

// set clean state
if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
return err
if !m.databaseDrv.Transactional() {
if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
return err
}
}

endTime := time.Now()
Expand Down