Skip to content

Commit 238507d

Browse files
committed
Maintain consistency in transactional DBs
1 parent 9e2a0cd commit 238507d

File tree

15 files changed

+97
-21
lines changed

15 files changed

+97
-21
lines changed

database/cassandra/cassandra.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func (c *Cassandra) Unlock() error {
164164
return nil
165165
}
166166

167-
func (c *Cassandra) Run(migration io.Reader) error {
167+
func (c *Cassandra) Run(migration io.Reader, version int) error {
168168
migr, err := ioutil.ReadAll(migration)
169169
if err != nil {
170170
return err
@@ -245,6 +245,10 @@ func (c *Cassandra) Drop() error {
245245
return nil
246246
}
247247

248+
func (c *Cassandra) Transactional() bool {
249+
return false
250+
}
251+
248252
// ensureVersionTable checks if versions table exists and, if not, creates it.
249253
// Note that this function locks the database, which deviates from the usual
250254
// convention of "caller locks" in the Cassandra type.

database/clickhouse/clickhouse.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (ch *ClickHouse) init() error {
9696
return ch.ensureVersionTable()
9797
}
9898

99-
func (ch *ClickHouse) Run(r io.Reader) error {
99+
func (ch *ClickHouse) Run(r io.Reader, version int) error {
100100
migration, err := ioutil.ReadAll(r)
101101
if err != nil {
102102
return err
@@ -193,7 +193,7 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
193193
// if not, create the empty migration table
194194
query = `
195195
CREATE TABLE ` + ch.config.MigrationsTable + ` (
196-
version Int64,
196+
version Int64,
197197
dirty UInt8,
198198
sequence UInt64
199199
) Engine=TinyLog
@@ -231,6 +231,10 @@ func (ch *ClickHouse) Drop() (err error) {
231231
return nil
232232
}
233233

234+
func (ch *ClickHouse) Transactional() bool {
235+
return false
236+
}
237+
234238
func (ch *ClickHouse) Lock() error { return nil }
235239
func (ch *ClickHouse) Unlock() error { return nil }
236240
func (ch *ClickHouse) Close() error { return ch.conn.Close() }

database/cockroachdb/cockroachdb.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func (c *CockroachDb) Unlock() error {
219219
return nil
220220
}
221221

222-
func (c *CockroachDb) Run(migration io.Reader) error {
222+
func (c *CockroachDb) Run(migration io.Reader, version int) error {
223223
migr, err := ioutil.ReadAll(migration)
224224
if err != nil {
225225
return err
@@ -311,6 +311,10 @@ func (c *CockroachDb) Drop() (err error) {
311311
return nil
312312
}
313313

314+
func (c *CockroachDb) Transactional() bool {
315+
return false
316+
}
317+
314318
// ensureVersionTable checks if versions table exists and, if not, creates it.
315319
// Note that this function locks the database, which deviates from the usual
316320
// convention of "caller locks" in the CockroachDb type.

database/driver.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ type Driver interface {
6262
Unlock() error
6363

6464
// Run applies a migration to the database. migration is garantueed to be not nil.
65-
Run(migration io.Reader) error
65+
Run(migration io.Reader, version int) error
6666

6767
// SetVersion saves version and dirty state.
6868
// Migrate will call this function before and after each call to Run.
@@ -78,6 +78,9 @@ type Driver interface {
7878
// Note that this is a breaking action, a new call to Open() is necessary to
7979
// ensure subsequent calls work as expected.
8080
Drop() error
81+
82+
// Whether or not this driver supports transactions.
83+
Transactional() bool
8184
}
8285

8386
// Open returns a new driver instance.

database/firebird/firebird.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (f *Firebird) Unlock() error {
118118
return nil
119119
}
120120

121-
func (f *Firebird) Run(migration io.Reader) error {
121+
func (f *Firebird) Run(migration io.Reader, version int) error {
122122
migr, err := ioutil.ReadAll(migration)
123123
if err != nil {
124124
return err
@@ -210,6 +210,10 @@ func (f *Firebird) Drop() (err error) {
210210
return nil
211211
}
212212

213+
func (f *Firebird) Transactional() bool {
214+
return false
215+
}
216+
213217
// ensureVersionTable checks if versions table exists and, if not, creates it.
214218
func (f *Firebird) ensureVersionTable() (err error) {
215219
if err = f.Lock(); err != nil {

database/mongodb/mongodb.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (m *Mongo) Version() (version int, dirty bool, err error) {
121121
}
122122
}
123123

124-
func (m *Mongo) Run(migration io.Reader) error {
124+
func (m *Mongo) Run(migration io.Reader, version int) error {
125125
migr, err := ioutil.ReadAll(migration)
126126
if err != nil {
127127
return err
@@ -182,6 +182,10 @@ func (m *Mongo) Drop() error {
182182
return m.db.Drop(context.TODO())
183183
}
184184

185+
func (m *Mongo) Transactional() bool {
186+
return false
187+
}
188+
185189
func (m *Mongo) Lock() error {
186190
return nil
187191
}

database/mysql/mysql.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func (m *Mysql) Unlock() error {
273273
return nil
274274
}
275275

276-
func (m *Mysql) Run(migration io.Reader) error {
276+
func (m *Mysql) Run(migration io.Reader, version int) error {
277277
migr, err := ioutil.ReadAll(migration)
278278
if err != nil {
279279
return err
@@ -387,6 +387,10 @@ func (m *Mysql) Drop() (err error) {
387387
return nil
388388
}
389389

390+
func (m *Mysql) Transactional() bool {
391+
return false
392+
}
393+
390394
// ensureVersionTable checks if versions table exists and, if not, creates it.
391395
// Note that this function locks the database, which deviates from the usual
392396
// convention of "caller locks" in the Mysql type.

database/postgres/postgres.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,29 @@ func (p *Postgres) Unlock() error {
182182
return nil
183183
}
184184

185-
func (p *Postgres) Run(migration io.Reader) error {
185+
func (p *Postgres) Run(migration io.Reader, targetVersion int) error {
186186
migr, err := ioutil.ReadAll(migration)
187187
if err != nil {
188188
return err
189189
}
190190

191+
setDirty := `
192+
TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + `;
193+
INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES (` + strconv.Itoa(targetVersion) + `, true);`
194+
setClean := `
195+
TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + `;
196+
INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES (` + strconv.Itoa(targetVersion) + `, false);`
197+
191198
// run migration
192-
query := string(migr[:])
199+
query := `
200+
BEGIN;
201+
` + string(migr[:]) + `
202+
` + setClean + `
203+
COMMIT;`
193204
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
205+
if _, errDirty := p.conn.ExecContext(context.Background(), `BEGIN;`+setDirty+`COMMIT;`); errDirty != nil {
206+
err = multierror.Append(err, errDirty)
207+
}
194208
if pgErr, ok := err.(*pq.Error); ok {
195209
var line uint
196210
var col uint
@@ -339,6 +353,10 @@ func (p *Postgres) Drop() (err error) {
339353
return nil
340354
}
341355

356+
func (p *Postgres) Transactional() bool {
357+
return true
358+
}
359+
342360
// ensureVersionTable checks if versions table exists and, if not, creates it.
343361
// Note that this function locks the database, which deviates from the usual
344362
// convention of "caller locks" in the Postgres type.

database/ql/ql.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ func (m *Ql) Drop() (err error) {
160160

161161
return nil
162162
}
163+
func (m *Ql) Transactional() bool {
164+
return false
165+
}
163166
func (m *Ql) Lock() error {
164167
if m.isLocked {
165168
return database.ErrLocked
@@ -174,7 +177,7 @@ func (m *Ql) Unlock() error {
174177
m.isLocked = false
175178
return nil
176179
}
177-
func (m *Ql) Run(migration io.Reader) error {
180+
func (m *Ql) Run(migration io.Reader, version int) error {
178181
migr, err := ioutil.ReadAll(migration)
179182
if err != nil {
180183
return err

database/redshift/redshift.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func (p *Redshift) Unlock() error {
138138
return nil
139139
}
140140

141-
func (p *Redshift) Run(migration io.Reader) error {
141+
func (p *Redshift) Run(migration io.Reader, version int) error {
142142
migr, err := ioutil.ReadAll(migration)
143143
if err != nil {
144144
return err
@@ -295,6 +295,10 @@ func (p *Redshift) Drop() (err error) {
295295
return nil
296296
}
297297

298+
func (p *Redshift) Transactional() bool {
299+
return false
300+
}
301+
298302
// ensureVersionTable checks if versions table exists and, if not, creates it.
299303
// Note that this function locks the database, which deviates from the usual
300304
// convention of "caller locks" in the Redshift type.

database/spanner/spanner.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func (s *Spanner) Unlock() error {
135135
}
136136

137137
// Run implements database.Driver
138-
func (s *Spanner) Run(migration io.Reader) error {
138+
func (s *Spanner) Run(migration io.Reader, version int) error {
139139
migr, err := ioutil.ReadAll(migration)
140140
if err != nil {
141141
return err
@@ -257,6 +257,10 @@ func (s *Spanner) Drop() error {
257257
return nil
258258
}
259259

260+
func (s *Spanner) Transactional() bool {
261+
return false
262+
}
263+
260264
// ensureVersionTable checks if versions table exists and, if not, creates it.
261265
// Note that this function locks the database, which deviates from the usual
262266
// convention of "caller locks" in the Spanner type.

database/sqlite3/sqlite3.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ func (m *Sqlite) Drop() (err error) {
157157
return nil
158158
}
159159

160+
func (m *Sqlite) Transactional() bool {
161+
return false
162+
}
163+
160164
func (m *Sqlite) Lock() error {
161165
if m.isLocked {
162166
return database.ErrLocked
@@ -173,7 +177,7 @@ func (m *Sqlite) Unlock() error {
173177
return nil
174178
}
175179

176-
func (m *Sqlite) Run(migration io.Reader) error {
180+
func (m *Sqlite) Run(migration io.Reader, version int) error {
177181
migr, err := ioutil.ReadAll(migration)
178182
if err != nil {
179183
return err

database/sqlserver/sqlserver.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (ss *SQLServer) Unlock() error {
201201
}
202202

203203
// Run the migrations for the database
204-
func (ss *SQLServer) Run(migration io.Reader) error {
204+
func (ss *SQLServer) Run(migration io.Reader, version int) error {
205205
migr, err := ioutil.ReadAll(migration)
206206
if err != nil {
207207
return err
@@ -312,6 +312,10 @@ func (ss *SQLServer) Drop() error {
312312
return nil
313313
}
314314

315+
func (ss *SQLServer) Transactional() bool {
316+
return false
317+
}
318+
315319
func (ss *SQLServer) ensureVersionTable() (err error) {
316320
if err = ss.Lock(); err != nil {
317321
return err

database/stub/stub.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (s *Stub) Unlock() error {
6161
return nil
6262
}
6363

64-
func (s *Stub) Run(migration io.Reader) error {
64+
func (s *Stub) Run(migration io.Reader, version int) error {
6565
m, err := ioutil.ReadAll(migration)
6666
if err != nil {
6767
return err
@@ -90,6 +90,10 @@ func (s *Stub) Drop() error {
9090
return nil
9191
}
9292

93+
func (s *Stub) Transactional() bool {
94+
return false
95+
}
96+
9397
func (s *Stub) EqualSequence(seq []string) bool {
9498
return reflect.DeepEqual(seq, s.MigrationSequence)
9599
}

migrate.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ func (m *Migrate) Drop() error {
316316
return m.unlock()
317317
}
318318

319+
func (m *Migrate) Transactional() bool {
320+
return false
321+
}
322+
319323
// Run runs any migration provided by you against the database.
320324
// It does not check any currently active version in database.
321325
// Usually you don't need this function at all. Use Migrate,
@@ -737,20 +741,24 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
737741
migr := r
738742

739743
// set version with dirty state
740-
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
741-
return err
744+
if !m.databaseDrv.Transactional() {
745+
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
746+
return err
747+
}
742748
}
743749

744750
if migr.Body != nil {
745751
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
746-
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
752+
if err := m.databaseDrv.Run(migr.BufferedBody, migr.TargetVersion); err != nil {
747753
return err
748754
}
749755
}
750756

751757
// set clean state
752-
if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
753-
return err
758+
if !m.databaseDrv.Transactional() {
759+
if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
760+
return err
761+
}
754762
}
755763

756764
endTime := time.Now()

0 commit comments

Comments
 (0)