Skip to content

Commit e132b9f

Browse files
glebteterinJames Naylor
authored andcommitted
Support MSSQL batch statements (Resolves #652)
1 parent 2788339 commit e132b9f

File tree

3 files changed

+79
-15
lines changed

3 files changed

+79
-15
lines changed

database/sqlserver/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
| `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. |
1818
| `app+name` || The application name (default is go-mssqldb). |
1919
| `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported.
20+
| `x-batch` | | Enable batch statements (default: false) |
2021

2122
See https://github.com/microsoft/go-mssqldb for full parameter list.
2223

database/sqlserver/sqlserver.go

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/golang-migrate/migrate/v4/database"
1717
"github.com/hashicorp/go-multierror"
1818
mssql "github.com/microsoft/go-mssqldb" // mssql support
19+
"github.com/microsoft/go-mssqldb/batch"
1920
)
2021

2122
func init() {
@@ -30,7 +31,7 @@ var (
3031
ErrNoDatabaseName = fmt.Errorf("no database name")
3132
ErrNoSchema = fmt.Errorf("no schema")
3233
ErrDatabaseDirty = fmt.Errorf("database is dirty")
33-
ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
34+
ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed")
3435
)
3536

3637
var lockErrorMap = map[int]string{
@@ -42,9 +43,10 @@ var lockErrorMap = map[int]string{
4243

4344
// Config for database
4445
type Config struct {
45-
MigrationsTable string
46-
DatabaseName string
47-
SchemaName string
46+
MigrationsTable string
47+
DatabaseName string
48+
SchemaName string
49+
BatchStatementEnabled bool
4850
}
4951

5052
// SQL Server connection
@@ -168,9 +170,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {
168170

169171
migrationsTable := purl.Query().Get("x-migrations-table")
170172

173+
batchStatementEnabled := false
174+
if s := purl.Query().Get("x-batch"); len(s) > 0 {
175+
batchStatementEnabled, err = strconv.ParseBool(s)
176+
if err != nil {
177+
return nil, fmt.Errorf("unable to parse option x-batch: %w", err)
178+
}
179+
}
180+
171181
px, err := WithInstance(db, &Config{
172-
DatabaseName: purl.Path,
173-
MigrationsTable: migrationsTable,
182+
DatabaseName: purl.Path,
183+
MigrationsTable: migrationsTable,
184+
BatchStatementEnabled: batchStatementEnabled,
174185
})
175186

176187
if err != nil {
@@ -247,15 +258,23 @@ func (ss *SQLServer) Run(migration io.Reader) error {
247258

248259
// run migration
249260
query := string(migr[:])
250-
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
251-
if msErr, ok := err.(mssql.Error); ok {
252-
message := fmt.Sprintf("migration failed: %s", msErr.Message)
253-
if msErr.ProcName != "" {
254-
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
261+
scripts := []string{query}
262+
263+
if ss.config.BatchStatementEnabled {
264+
scripts = batch.Split(query, "go")
265+
}
266+
267+
for _, script := range scripts {
268+
if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
269+
if msErr, ok := err.(mssql.Error); ok {
270+
message := fmt.Sprintf("migration failed: %s", msErr.Message)
271+
if msErr.ProcName != "" {
272+
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
273+
}
274+
return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)}
255275
}
256-
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
276+
return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)}
257277
}
258-
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
259278
}
260279

261280
return nil

database/sqlserver/sqlserver_test.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ var (
3737
}
3838
)
3939

40-
func msConnectionString(host, port string) string {
41-
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
40+
func msConnectionString(host, port string, options ...string) string {
41+
options = append(options, "database=master")
42+
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&"))
4243
}
4344

4445
func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
@@ -191,6 +192,49 @@ func testMultiStatement(t *testing.T) {
191192
})
192193
}
193194

195+
func testBatchedStatement(t *testing.T) {
196+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
197+
ip, port, err := c.Port(defaultPort)
198+
if err != nil {
199+
t.Fatal(err)
200+
}
201+
202+
addr := msConnectionString(ip, port, "x-batch=true")
203+
ms := &SQLServer{}
204+
d, err := ms.Open(addr)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
defer func() {
209+
if err := d.Close(); err != nil {
210+
t.Error(err)
211+
}
212+
}()
213+
if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA
214+
AS
215+
BEGIN
216+
SELECT 1;
217+
END;
218+
GO
219+
CREATE PROCEDURE uspB
220+
AS
221+
BEGIN
222+
SELECT 2;
223+
END`)); err != nil {
224+
t.Fatalf("expected err to be nil, got %v", err)
225+
}
226+
227+
// make sure second proc exists
228+
var exists int
229+
if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil {
230+
t.Fatal(err)
231+
}
232+
if exists != 1 {
233+
t.Fatalf("expected proc uspB to exist")
234+
}
235+
})
236+
}
237+
194238
func testErrorParsing(t *testing.T) {
195239
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
196240
SkipIfUnsupportedArch(t, c)

0 commit comments

Comments
 (0)