Skip to content

Commit 5b2de05

Browse files
authored
Merge pull request #2533 from actiontech/issue-2356-1
xml解析不再进行format格式化;返回原始sql字符串
2 parents e5204f7 + 8235bce commit 5b2de05

File tree

2 files changed

+13
-35
lines changed

2 files changed

+13
-35
lines changed

sqle/api/controller/v1/sql_audit_record.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@ import (
1616

1717
javaParser "github.com/actiontech/java-sql-extractor/parser"
1818
xmlParser "github.com/actiontech/mybatis-mapper-2-sql"
19-
"github.com/actiontech/mybatis-mapper-2-sql/ast"
2019
"github.com/actiontech/sqle/sqle/api/controller"
2120
"github.com/actiontech/sqle/sqle/common"
2221
"github.com/actiontech/sqle/sqle/dms"
2322
"github.com/actiontech/sqle/sqle/driver"
24-
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
2523
"github.com/actiontech/sqle/sqle/errors"
2624
"github.com/actiontech/sqle/sqle/log"
2725
"github.com/actiontech/sqle/sqle/model"
@@ -106,7 +104,7 @@ func CreateSQLAuditRecord(c echo.Context) error {
106104
SQLsFromFormData: req.Sqls,
107105
}
108106
} else {
109-
sqls, err = getSQLFromFile(c, req.DbType)
107+
sqls, err = getSQLFromFile(c)
110108
if err != nil {
111109
return controller.JSONBaseErrorReq(c, err)
112110
}
@@ -324,7 +322,7 @@ func buildOfflineTaskForAudit(userId uint64, dbType string, sqls getSQLFromFileR
324322
}
325323

326324
// todo 此处跳过了不支持的编码格式文件
327-
func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
325+
func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
328326
file, err := c.FormFile(InputZipFileName)
329327
if err == http.ErrMissingFile {
330328
return nil, nil, false, nil
@@ -390,7 +388,7 @@ func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQ
390388
// parse xml content
391389
// xml文件需要把所有文件内容同时解析,否则会无法解析跨namespace引用的SQL
392390
{
393-
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents, dbType)
391+
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents)
394392
if err != nil {
395393
return nil, nil, false, err
396394
}
@@ -399,14 +397,8 @@ func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQ
399397

400398
return sqlsFromSQLFile, sqlsFromXML, true, nil
401399
}
402-
func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile, dbType string) ([]SQLFromXML, error) {
403-
var allStmtsFromXml []ast.StmtInfo
404-
var err error
405-
if dbType == driverV2.DriverTypePostgreSQL || dbType == driverV2.DriverTypeTBase {
406-
allStmtsFromXml, err = xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery, xmlParser.RestoreOriginSql)
407-
} else {
408-
allStmtsFromXml, err = xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery)
409-
}
400+
func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile) ([]SQLFromXML, error) {
401+
allStmtsFromXml, err := xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery, xmlParser.RestoreOriginSql)
410402
if err != nil {
411403
return nil, fmt.Errorf("parse sqls from xml failed: %v", err)
412404
}
@@ -423,7 +415,7 @@ func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile, dbType string) ([]SQ
423415
}
424416

425417
// todo 此处跳过了不支持的编码格式文件
426-
func getSqlsFromGit(c echo.Context, dbType string) (sqlsFromSQLFiles, sqlsFromJavaFiles []SQLsFromSQLFile, sqlsFromXMLs []SQLFromXML, exist bool, err error) {
418+
func getSqlsFromGit(c echo.Context) (sqlsFromSQLFiles, sqlsFromJavaFiles []SQLsFromSQLFile, sqlsFromXMLs []SQLFromXML, exist bool, err error) {
427419
// make a temp dir and clean up befor return
428420
dir, err := os.MkdirTemp("./", "git-repo-")
429421
if err != nil {
@@ -528,7 +520,7 @@ func getSqlsFromGit(c echo.Context, dbType string) (sqlsFromSQLFiles, sqlsFromJa
528520

529521
// parse xml content
530522
// xml文件需要把所有文件内容同时解析,否则会无法解析跨namespace引用的SQL
531-
sqlsFromXMLs, err = parseXMLsWithFilePath(xmlContents, dbType)
523+
sqlsFromXMLs, err = parseXMLsWithFilePath(xmlContents)
532524
if err != nil {
533525
return nil, nil, nil, false, err
534526
}

sqle/api/controller/v1/task.go

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@ import (
1818

1919
dmsV1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
2020
mybatis_parser "github.com/actiontech/mybatis-mapper-2-sql"
21-
"github.com/actiontech/mybatis-mapper-2-sql/ast"
2221
"github.com/actiontech/sqle/sqle/api/controller"
2322
"github.com/actiontech/sqle/sqle/common"
2423
"github.com/actiontech/sqle/sqle/config"
2524
"github.com/actiontech/sqle/sqle/dms"
26-
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
2725
"github.com/actiontech/sqle/sqle/errors"
2826
"github.com/actiontech/sqle/sqle/log"
2927
"github.com/actiontech/sqle/sqle/model"
@@ -109,7 +107,7 @@ const (
109107
ZIPFileExtension = ".zip"
110108
)
111109

112-
func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
110+
func getSQLFromFile(c echo.Context) (getSQLFromFileResp, error) {
113111
// Read it from sql file.
114112
fileName, sqlsFromSQLFile, exist, err := controller.ReadFile(c, InputSQLFileName)
115113
if err != nil {
@@ -130,13 +128,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
130128
return getSQLFromFileResp{}, err
131129
}
132130
if exist {
133-
var sqls []ast.StmtInfo
134-
var err error
135-
if dbType == driverV2.DriverTypePostgreSQL || dbType == driverV2.DriverTypeTBase {
136-
sqls, err = mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery, mybatis_parser.RestoreOriginSql)
137-
} else {
138-
sqls, err = mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery)
139-
}
131+
sqls, err := mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery, mybatis_parser.RestoreOriginSql)
140132
if err != nil {
141133
return getSQLFromFileResp{}, errors.New(errors.ParseMyBatisXMLFileError, err)
142134
}
@@ -155,7 +147,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
155147
}
156148

157149
// If mybatis xml file is not exist, read it from zip file.
158-
sqlsFromSQLFiles, sqlsFromXML, exist, err := getSqlsFromZip(c, dbType)
150+
sqlsFromSQLFiles, sqlsFromXML, exist, err := getSqlsFromZip(c)
159151
if err != nil {
160152
return getSQLFromFileResp{}, err
161153
}
@@ -168,7 +160,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
168160
}
169161

170162
// If zip file is not exist, read it from git repository
171-
sqlsFromSQLFiles, sqlsFromJavaFiles, sqlsFromXMLs, exist, err := getSqlsFromGit(c, dbType)
163+
sqlsFromSQLFiles, sqlsFromJavaFiles, sqlsFromXMLs, exist, err := getSqlsFromGit(c)
172164
if err != nil {
173165
return getSQLFromFileResp{}, err
174166
}
@@ -314,20 +306,14 @@ func CreateAndAuditTask(c echo.Context) error {
314306
if err != nil {
315307
return controller.JSONBaseErrorReq(c, err)
316308
}
317-
instance, exist, err := dms.GetInstanceInProjectByName(c.Request().Context(), projectUid, req.InstanceName)
318-
if !exist {
319-
return controller.JSONBaseErrorReq(c, ErrInstanceNotExist)
320-
} else if err != nil {
321-
return controller.JSONBaseErrorReq(c, errors.New(errors.DataConflict, err))
322-
}
323309

324310
if req.Sql != "" {
325311
sqls = getSQLFromFileResp{
326312
SourceType: model.TaskSQLSourceFromFormData,
327313
SQLsFromFormData: req.Sql,
328314
}
329315
} else {
330-
sqls, err = getSQLFromFile(c, instance.DbType)
316+
sqls, err = getSQLFromFile(c)
331317
if err != nil {
332318
return controller.JSONBaseErrorReq(c, err)
333319
}
@@ -1001,7 +987,7 @@ func AuditTaskGroupV1(c echo.Context) error {
1001987
SQLsFromFormData: req.Sql,
1002988
}
1003989
} else {
1004-
sqls, err = getSQLFromFile(c, dbType)
990+
sqls, err = getSQLFromFile(c)
1005991
if err != nil {
1006992
return controller.JSONBaseErrorReq(c, err)
1007993
}

0 commit comments

Comments
 (0)