Skip to content

Commit 784700a

Browse files
authored
Merge pull request #2461 from actiontech/issue-2457-main
MYSQL规则 避免对条件字段使用函数操作 和 不建议在WHERE条件中使用与过滤字段不一致的数据类型 审核sql panic报错
2 parents 187e7ff + 62493fb commit 784700a

File tree

2 files changed

+78
-31
lines changed

2 files changed

+78
-31
lines changed

sqle/driver/mysql/audit_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,6 +2795,18 @@ select v1 from exist_db.exist_tb_1 where v2 = "3"
27952795
`,
27962796
newTestResult(),
27972797
)
2798+
2799+
runSingleRuleInspectCase(rule, t, "select: without from condition", DefaultMysqlInspect(), `select 1`, newTestResult())
2800+
2801+
runSingleRuleInspectCase(rule, t, "select: without where condition", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1`, newTestResult())
2802+
2803+
runSingleRuleInspectCase(rule, t, "select: next select with function", DefaultMysqlInspect(), `select * from (select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3") as t1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))
2804+
2805+
runSingleRuleInspectCase(rule, t, "select union select 1", DefaultMysqlInspect(), `select 1 union select 1`, newTestResult())
2806+
2807+
runSingleRuleInspectCase(rule, t, "select: union select", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3" union select * from exist_db.exist_tb_1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))
2808+
2809+
runSingleRuleInspectCase(rule, t, "union next select", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1 union all select * from (select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3") as t1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))
27982810
}
27992811

28002812
func Test_DDLCheckCreateTimeColumn(t *testing.T) {
@@ -3027,6 +3039,21 @@ select v1 from exist_db.exist_tb_1 where id = 3;
30273039
`,
30283040
newTestResult(),
30293041
)
3042+
3043+
runSingleRuleInspectCase(rule, t, "select: not exist from condition", DefaultMysqlInspect(), `select 1;`, newTestResult())
3044+
3045+
runSingleRuleInspectCase(rule, t, "select: not exist where condition", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1;`, newTestResult())
3046+
3047+
runSingleRuleInspectCase(rule, t, "select: nest select", DefaultMysqlInspect(), `select s.* from (select v1 from exist_db.exist_tb_1 where id = "3") s`,
3048+
newTestResult().addResult(rulepkg.DMLCheckWhereExistImplicitConversion))
3049+
3050+
runSingleRuleInspectCase(rule, t, "select: nest select", DefaultMysqlInspect(), `select s.* from (select v1 from exist_db.exist_tb_1 where id = 3) s`, newTestResult())
3051+
3052+
runSingleRuleInspectCase(rule, t, "UNION: union all select", DefaultMysqlInspect(), `select 1 union all select 1`, newTestResult())
3053+
3054+
runSingleRuleInspectCase(rule, t, "UNION: union nest select", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1 union select s.v1 from (select v1 from exist_db.exist_tb_1 where v1 = "3") s`, newTestResult())
3055+
3056+
runSingleRuleInspectCase(rule, t, "UNION: union nest select", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1 union select s.v1 from (select v1 from exist_db.exist_tb_1 where v1 = 3) s`, newTestResult().addResult(rulepkg.DMLCheckWhereExistImplicitConversion))
30303057
}
30313058

30323059
func TestCheckMultiSelectWhereExistImplicitConversion(t *testing.T) {

sqle/driver/mysql/rule/rule.go

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,21 +2682,39 @@ func checkDMLWithBatchInsertMaxLimits(input *RuleHandlerInput) error {
26822682

26832683
func checkWhereExistFunc(input *RuleHandlerInput) error {
26842684
tables := []*ast.TableName{}
2685-
switch stmt := input.Node.(type) {
2686-
case *ast.SelectStmt:
2687-
if stmt.Where != nil {
2688-
tableSources := util.GetTableSources(stmt.From.TableRefs)
2685+
hasExistFunc := func(stmt *ast.SelectStmt) bool {
2686+
selectExtractor := util.SelectStmtExtractor{}
2687+
stmt.Accept(&selectExtractor)
2688+
for _, selectStmt := range selectExtractor.SelectStmts {
2689+
if selectStmt.Where == nil || selectStmt.From == nil {
2690+
continue
2691+
}
2692+
2693+
tableSources := util.GetTableSources(selectStmt.From.TableRefs)
26892694
// not select from table statement
26902695
if len(tableSources) < 1 {
2691-
break
2696+
continue
26922697
}
2698+
26932699
for _, tableSource := range tableSources {
26942700
switch source := tableSource.Source.(type) {
26952701
case *ast.TableName:
26962702
tables = append(tables, source)
26972703
}
26982704
}
2699-
checkExistFunc(input.Ctx, input.Rule, input.Res, tables, stmt.Where)
2705+
2706+
if checkExistFunc(input.Ctx, input.Rule, input.Res, tables, selectStmt.Where) {
2707+
return true
2708+
}
2709+
}
2710+
2711+
return false
2712+
}
2713+
2714+
switch stmt := input.Node.(type) {
2715+
case *ast.SelectStmt:
2716+
if hasExistFunc(stmt) {
2717+
break
27002718
}
27012719
case *ast.UpdateStmt:
27022720
if stmt.Where != nil {
@@ -2714,18 +2732,8 @@ func checkWhereExistFunc(input *RuleHandlerInput) error {
27142732
checkExistFunc(input.Ctx, input.Rule, input.Res, util.GetTables(stmt.TableRefs.TableRefs), stmt.Where)
27152733
}
27162734
case *ast.UnionStmt:
2717-
for _, ss := range stmt.SelectList.Selects {
2718-
tableSources := util.GetTableSources(ss.From.TableRefs)
2719-
if len(tableSources) < 1 {
2720-
continue
2721-
}
2722-
for _, tableSource := range tableSources {
2723-
switch source := tableSource.Source.(type) {
2724-
case *ast.TableName:
2725-
tables = append(tables, source)
2726-
}
2727-
}
2728-
if checkExistFunc(input.Ctx, input.Rule, input.Res, tables, ss.Where) {
2735+
for _, selectStmt := range stmt.SelectList.Selects {
2736+
if hasExistFunc(selectStmt) {
27292737
break
27302738
}
27312739
}
@@ -2758,15 +2766,31 @@ func checkExistFunc(ctx *session.Context, rule driverV2.Rule, res *driverV2.Audi
27582766
}
27592767

27602768
func checkWhereColumnImplicitConversion(input *RuleHandlerInput) error {
2761-
switch stmt := input.Node.(type) {
2762-
case *ast.SelectStmt:
2763-
if stmt.Where != nil {
2764-
tableSources := util.GetTableSources(stmt.From.TableRefs)
2765-
// not select from table statement
2769+
hasWhereColumnImplicitConversionFunc := func(stmt *ast.SelectStmt) bool {
2770+
selectExtractor := util.SelectStmtExtractor{}
2771+
stmt.Accept(&selectExtractor)
2772+
for _, selectStmt := range selectExtractor.SelectStmts {
2773+
if selectStmt.From == nil || selectStmt.Where == nil {
2774+
continue
2775+
}
2776+
2777+
tableSources := util.GetTableSources(selectStmt.From.TableRefs)
27662778
if len(tableSources) < 1 {
2767-
break
2779+
continue
27682780
}
2769-
checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, stmt.Where)
2781+
2782+
if checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, selectStmt.Where) {
2783+
return true
2784+
}
2785+
}
2786+
2787+
return false
2788+
}
2789+
2790+
switch stmt := input.Node.(type) {
2791+
case *ast.SelectStmt:
2792+
if hasWhereColumnImplicitConversionFunc(stmt) {
2793+
break
27702794
}
27712795
case *ast.UpdateStmt:
27722796
if stmt.Where != nil {
@@ -2779,12 +2803,8 @@ func checkWhereColumnImplicitConversion(input *RuleHandlerInput) error {
27792803
checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, stmt.Where)
27802804
}
27812805
case *ast.UnionStmt:
2782-
for _, ss := range stmt.SelectList.Selects {
2783-
tableSources := util.GetTableSources(ss.From.TableRefs)
2784-
if len(tableSources) < 1 {
2785-
continue
2786-
}
2787-
if checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, ss.Where) {
2806+
for _, selectStmt := range stmt.SelectList.Selects {
2807+
if hasWhereColumnImplicitConversionFunc(selectStmt) {
27882808
break
27892809
}
27902810
}

0 commit comments

Comments
 (0)