@@ -2682,8 +2682,7 @@ func checkDMLWithBatchInsertMaxLimits(input *RuleHandlerInput) error {
2682
2682
2683
2683
func checkWhereExistFunc (input * RuleHandlerInput ) error {
2684
2684
tables := []* ast.TableName {}
2685
- switch stmt := input .Node .(type ) {
2686
- case * ast.SelectStmt :
2685
+ hasExistFunc := func (stmt * ast.SelectStmt ) bool {
2687
2686
selectExtractor := util.SelectStmtExtractor {}
2688
2687
stmt .Accept (& selectExtractor )
2689
2688
for _ , selectStmt := range selectExtractor .SelectStmts {
@@ -2705,9 +2704,18 @@ func checkWhereExistFunc(input *RuleHandlerInput) error {
2705
2704
}
2706
2705
2707
2706
if checkExistFunc (input .Ctx , input .Rule , input .Res , tables , selectStmt .Where ) {
2708
- break
2707
+ return true
2709
2708
}
2710
2709
}
2710
+
2711
+ return false
2712
+ }
2713
+
2714
+ switch stmt := input .Node .(type ) {
2715
+ case * ast.SelectStmt :
2716
+ if hasExistFunc (stmt ) {
2717
+ break
2718
+ }
2711
2719
case * ast.UpdateStmt :
2712
2720
if stmt .Where != nil {
2713
2721
tableSources := util .GetTableSources (stmt .TableRefs .TableRefs )
@@ -2724,30 +2732,9 @@ func checkWhereExistFunc(input *RuleHandlerInput) error {
2724
2732
checkExistFunc (input .Ctx , input .Rule , input .Res , util .GetTables (stmt .TableRefs .TableRefs ), stmt .Where )
2725
2733
}
2726
2734
case * ast.UnionStmt :
2727
- outerBreaker:
2728
- for _ , selectStmtList := range stmt .SelectList .Selects {
2729
- selectExtractor := util.SelectStmtExtractor {}
2730
- selectStmtList .Accept (& selectExtractor )
2731
- for _ , selectStmt := range selectExtractor .SelectStmts {
2732
- if selectStmt .From == nil || selectStmt .Where == nil {
2733
- continue
2734
- }
2735
-
2736
- tableSources := util .GetTableSources (selectStmt .From .TableRefs )
2737
- if len (tableSources ) < 1 {
2738
- continue
2739
- }
2740
-
2741
- for _ , tableSource := range tableSources {
2742
- switch source := tableSource .Source .(type ) {
2743
- case * ast.TableName :
2744
- tables = append (tables , source )
2745
- }
2746
- }
2747
-
2748
- if checkExistFunc (input .Ctx , input .Rule , input .Res , tables , selectStmt .Where ) {
2749
- break outerBreaker
2750
- }
2735
+ for _ , selectStmt := range stmt .SelectList .Selects {
2736
+ if hasExistFunc (selectStmt ) {
2737
+ break
2751
2738
}
2752
2739
}
2753
2740
default :
@@ -2779,25 +2766,32 @@ func checkExistFunc(ctx *session.Context, rule driverV2.Rule, res *driverV2.Audi
2779
2766
}
2780
2767
2781
2768
func checkWhereColumnImplicitConversion (input * RuleHandlerInput ) error {
2782
- switch stmt := input .Node .(type ) {
2783
- case * ast.SelectStmt :
2769
+ hasWhereColumnImplicitConversionFunc := func (stmt * ast.SelectStmt ) bool {
2784
2770
selectExtractor := util.SelectStmtExtractor {}
2785
2771
stmt .Accept (& selectExtractor )
2786
2772
for _ , selectStmt := range selectExtractor .SelectStmts {
2787
- if selectStmt .Where == nil || selectStmt .From == nil {
2773
+ if selectStmt .From == nil || selectStmt .Where == nil {
2788
2774
continue
2789
2775
}
2790
2776
2791
2777
tableSources := util .GetTableSources (selectStmt .From .TableRefs )
2792
- // not select from table statement
2793
2778
if len (tableSources ) < 1 {
2794
2779
continue
2795
2780
}
2796
2781
2797
2782
if checkWhereColumnImplicitConversionFunc (input .Ctx , input .Rule , input .Res , tableSources , selectStmt .Where ) {
2798
- break
2783
+ return true
2799
2784
}
2800
2785
}
2786
+
2787
+ return false
2788
+ }
2789
+
2790
+ switch stmt := input .Node .(type ) {
2791
+ case * ast.SelectStmt :
2792
+ if hasWhereColumnImplicitConversionFunc (stmt ) {
2793
+ break
2794
+ }
2801
2795
case * ast.UpdateStmt :
2802
2796
if stmt .Where != nil {
2803
2797
tableSources := util .GetTableSources (stmt .TableRefs .TableRefs )
@@ -2809,23 +2803,9 @@ func checkWhereColumnImplicitConversion(input *RuleHandlerInput) error {
2809
2803
checkWhereColumnImplicitConversionFunc (input .Ctx , input .Rule , input .Res , tableSources , stmt .Where )
2810
2804
}
2811
2805
case * ast.UnionStmt :
2812
- outerBreaker:
2813
- for _ , selectStmtList := range stmt .SelectList .Selects {
2814
- selectExtractor := util.SelectStmtExtractor {}
2815
- selectStmtList .Accept (& selectExtractor )
2816
- for _ , selectStmt := range selectExtractor .SelectStmts {
2817
- if selectStmt .From == nil || selectStmt .Where == nil {
2818
- continue
2819
- }
2820
-
2821
- tableSources := util .GetTableSources (selectStmt .From .TableRefs )
2822
- if len (tableSources ) < 1 {
2823
- continue
2824
- }
2825
-
2826
- if checkWhereColumnImplicitConversionFunc (input .Ctx , input .Rule , input .Res , tableSources , selectStmt .Where ) {
2827
- break outerBreaker
2828
- }
2806
+ for _ , selectStmt := range stmt .SelectList .Selects {
2807
+ if hasWhereColumnImplicitConversionFunc (selectStmt ) {
2808
+ break
2829
2809
}
2830
2810
}
2831
2811
default :
0 commit comments