diff --git a/BUILD.bazel b/BUILD.bazel index 2e18d68..dc16402 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -130,6 +130,7 @@ package_filegroup( "//pkg/scalarule/mocks:filegroup", "//pkg/semanticdb:filegroup", "//pkg/starlarkeval:filegroup", + "//pkg/sweep:filegroup", "//pkg/testutil:filegroup", "//pkg/wildcardimport:filegroup", "//rules:filegroup", diff --git a/language/scala/BUILD.bazel b/language/scala/BUILD.bazel index 8f7f208..4b21a6b 100644 --- a/language/scala/BUILD.bazel +++ b/language/scala/BUILD.bazel @@ -49,6 +49,7 @@ go_library( "//pkg/scalaconfig", "//pkg/scalafiles", "//pkg/scalarule", + "//pkg/sweep", "//pkg/wildcardimport", "@bazel_gazelle//config:go_default_library", "@bazel_gazelle//label:go_default_library", diff --git a/language/scala/existing_scala_rule.go b/language/scala/existing_scala_rule.go index f48e8b1..43ebb87 100644 --- a/language/scala/existing_scala_rule.go +++ b/language/scala/existing_scala_rule.go @@ -12,11 +12,10 @@ import ( "github.com/bazelbuild/bazel-gazelle/rule" "github.com/bazelbuild/buildtools/build" - "github.com/stackb/scala-gazelle/pkg/bazel" - "github.com/stackb/scala-gazelle/pkg/collections" "github.com/stackb/scala-gazelle/pkg/protobuf" "github.com/stackb/scala-gazelle/pkg/scalaconfig" "github.com/stackb/scala-gazelle/pkg/scalarule" + "github.com/stackb/scala-gazelle/pkg/sweep" sppb "github.com/stackb/scala-gazelle/build/stack/gazelle/scala/parse" ) @@ -82,11 +81,6 @@ func (s *existingScalaRuleProvider) ResolveRule(cfg *scalarule.Config, pkg scala log.Printf("skipping %s %s: unable to collect srcs: %v", r.Kind(), r.Name(), err) return nil } - // rule has no srcs. This is OK for binary rules, sometimes they only - // have a main_class. - // if !s.isBinary { - // return nil // no need to print a warning - // } } if scalaRule == nil { log.Panicln("scalaRule should not be nil!") @@ -158,112 +152,21 @@ func (s *existingScalaRule) Resolve(rctx *scalarule.ResolveContext, importsRaw i } if sc.ShouldSweepTransitive("deps") { - if !hasTransitiveComment(rctx.Rule) { - if junk, err := s.sweepTransitiveAttr("deps", rctx.Rule, rctx.From); err != nil { + if !sweep.HasTransitiveRuleComment(rctx.Rule) { + if junk, err := sweep.TransitiveAttr("deps", rctx.File, rctx.Rule, rctx.From); err != nil { log.Printf("warning: transitive sweep failed: %v", err) } else { if len(junk) > 0 { log.Println(formatBuildozerRemoveDeps(rctx.From, junk)) } } - rctx.Rule.AddComment(scalaconfig.TransitiveCommentToken) + rctx.Rule.AddComment(sweep.TransitiveCommentToken) } else { log.Println("> transitive sweep skipped (already done):", rctx.From) } } } -// sweepTransitiveDeps iterates through deps marked "UNKNOWN" and removes them -// if the target still builds without it. -func (s *existingScalaRule) sweepTransitiveAttr(attrName string, r *rule.Rule, from label.Label) ([]string, error) { - // get the File to which this Rule belongs - file := s.pkg.GenerateArgs().File - - return s.sweepTransitive(attrName, file, r, from) -} - -// sweepTransitive iterates through deps marked "UNKNOWN" and removes them if -// the target still builds without it. -func (s *existingScalaRule) sweepTransitive(attrName string, file *rule.File, r *rule.Rule, from label.Label) (junk []string, err error) { - expr := r.Attr(attrName) - if expr == nil { - return nil, nil - } - - deps, isList := expr.(*build.ListExpr) - if !isList { - return nil, nil // some other condition we can't deal with - } - - // check that the deps have at least one unknown dep - var hasTransitiveDeps bool - for _, expr := range deps.List { - if str, ok := expr.(*build.StringExpr); ok { - for _, suffix := range str.Comment().Suffix { - if suffix.Token == scalaconfig.TransitiveCommentToken { - hasTransitiveDeps = true - break - } - } - } - } - if !hasTransitiveDeps { - return nil, nil // nothing to do - } - - // target should build first time, otherwise we can't check accurately. - log.Println("🧱 transitive sweep:", from) - - if out, exitCode, _ := bazel.ExecCommand("bazel", "build", from.String()); exitCode != 0 { - log.Fatalln("sweep failed (must build cleanly on first attempt): %s", string(out)) - } - - // iterate the list backwards - for i := len(deps.List) - 1; i >= 0; i-- { - expr := deps.List[i] - - // look for transitive string dep expressions - dep, ok := expr.(*build.StringExpr) - if !ok { - continue - } - var isTransitiveDep bool - for _, suffix := range dep.Comment().Suffix { - if suffix.Token == scalaconfig.TransitiveCommentToken { - isTransitiveDep = true - break - } - } - if !isTransitiveDep { - continue - } - - // reference of original list in case it does not build - original := deps.List - // reset deps with this one spliced out - deps.List = collections.SliceRemoveIndex(deps.List, i) - // save file to reflect change - if err := file.Save(file.Path); err != nil { - return nil, err - } - // see if it still builds - if _, exitCode, _ := bazel.ExecCommand("bazel", "build", from.String()); exitCode == 0 { - log.Println("- 💩 junk:", dep.Value) - junk = append(junk, dep.Value) - } else { - log.Println("- 👑 keep:", dep.Value) - deps.List = original - } - } - - // final save with possible last change - if err := file.Save(file.Path); err != nil { - return nil, err - } - - return -} - func makeRuleComments(pb *sppb.Rule) (comments []build.Comment) { pb.ParseTimeMillis = 0 json, _ := protobuf.StableJSON(pb) // ignoring error, this isn't critical @@ -281,12 +184,3 @@ func makeRuleComments(pb *sppb.Rule) (comments []build.Comment) { func formatBuildozerRemoveDeps(from label.Label, junk []string) string { return fmt.Sprintf("buildozer 'remove deps %s' %s", strings.Join(junk, " "), from.String()) } - -func hasTransitiveComment(r *rule.Rule) bool { - for _, before := range r.Comments() { - if before == scalaconfig.TransitiveCommentToken { - return true - } - } - return false -} diff --git a/language/scala/lifecycle.go b/language/scala/lifecycle.go index 755f5b5..73555ca 100644 --- a/language/scala/lifecycle.go +++ b/language/scala/lifecycle.go @@ -40,6 +40,11 @@ func (sl *scalaLang) onEnd() { log.Fatalf("provider.OnEnd transition error %s: %v", provider.Name(), err) } } + for _, pkg := range sl.packages { + if err := pkg.OnEnd(); err != nil { + log.Fatalf("pkg.OnEnd transition error %s: %v", pkg.args.Rel, err) + } + } sl.dumpResolvedImportMap() sl.reportCoverage(log.Printf) diff --git a/language/scala/scala_package.go b/language/scala/scala_package.go index 84b6e27..875ecd7 100644 --- a/language/scala/scala_package.go +++ b/language/scala/scala_package.go @@ -326,6 +326,22 @@ func (p *scalaPackage) infof(format string, args ...any) string { return fmt.Sprintf("INFO ["+p.args.Rel+"]: "+format, args...) } +// OnEnd is a lifecycle hook that gets called when the resolve phase has +// ended. +func (p *scalaPackage) OnEnd() error { + // if p.cfg.ShouldSweepTransitiveDeps() { + // // strip off the sweep directive if we got this far in the process + // if err := sweep.RemoveSweepDirective(p.args.File); err != nil { + // return err + // } + // // flip the keep_deps to false + // if err := sweep.SetKeepDepsDirective(p.args.File, false); err != nil { + // return err + // } + // } + return nil +} + func ruleContributesToCoverage(name string) bool { switch name { case "scala_files": diff --git a/pkg/scalaconfig/BUILD.bazel b/pkg/scalaconfig/BUILD.bazel index 1a44329..f053dad 100644 --- a/pkg/scalaconfig/BUILD.bazel +++ b/pkg/scalaconfig/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//pkg/collections", "//pkg/resolver", "//pkg/scalarule", + "//pkg/sweep", "@bazel_gazelle//config:go_default_library", "@bazel_gazelle//label:go_default_library", "@bazel_gazelle//resolve:go_default_library", diff --git a/pkg/scalaconfig/config.go b/pkg/scalaconfig/config.go index 5f0d7e0..acc1c49 100644 --- a/pkg/scalaconfig/config.go +++ b/pkg/scalaconfig/config.go @@ -18,14 +18,12 @@ import ( "github.com/stackb/scala-gazelle/pkg/collections" "github.com/stackb/scala-gazelle/pkg/resolver" "github.com/stackb/scala-gazelle/pkg/scalarule" + "github.com/stackb/scala-gazelle/pkg/sweep" ) type debugAnnotation int -const ( - scalaLangName = "scala" - TransitiveCommentToken = "# TRANSITIVE" -) +const scalaLangName = "scala" const ( DebugUnknown debugAnnotation = 0 @@ -59,17 +57,6 @@ const ( // gazelle:scala_fix_wildcard_imports .scala examples.aeron.api.proto._ scalaFixWildcardImportDirective = "scala_fix_wildcard_imports" - // Flag to preserve deps if the label is not known to be needed from the - // imports (legacy migration mode). - // - // gazelle:scala_keep_unknown_deps true - scalaKeepUnknownDepsDirective = "scala_keep_unknown_deps" - - // Turn on the dep sweeper - // - // gazelle:scala_sweep_transitive_deps true - scalaSweepTransitiveDepsDirective = "scala_sweep_transitive_deps" - // Configure a scala rule // // gazelle:scala_rule RULE_NAME ATTRIBUTE VALUE @@ -147,8 +134,8 @@ func DirectiveNames() []string { scalaDebugDirective, scalaLogLevelDirective, scalaDepsCleanerDirective, - scalaKeepUnknownDepsDirective, - scalaSweepTransitiveDepsDirective, + sweep.ScalaKeepUnknownDepsDirective, + sweep.ScalaSweepTransitiveDepsDirective, scalaFixWildcardImportDirective, scalaGenerateBuildFilesDirective, scalaRuleDirective, @@ -305,11 +292,11 @@ func (c *Config) ParseDirectives(directives []rule.Directive) error { if err := c.parseScalaRuleDirective(d); err != nil { return fmt.Errorf(`invalid directive: "gazelle:%s %s": %w`, d.Key, d.Value, err) } - case scalaKeepUnknownDepsDirective: + case sweep.ScalaKeepUnknownDepsDirective: if err := c.parseKeepUnknownDepsDirective(d); err != nil { return err } - case scalaSweepTransitiveDepsDirective: + case sweep.ScalaSweepTransitiveDepsDirective: if err := c.parseSweepTransitiveDepsDirective(d); err != nil { return err } @@ -420,11 +407,11 @@ func (c *Config) parseFixWildcardImport(d rule.Directive) { func (c *Config) parseKeepUnknownDepsDirective(d rule.Directive) error { parts := strings.Fields(d.Value) if len(parts) != 1 { - return fmt.Errorf("invalid gazelle:%s directive: expected [true|false], got %v", scalaKeepUnknownDepsDirective, parts) + return fmt.Errorf("invalid gazelle:%s directive: expected [true|false], got %v", sweep.ScalaKeepUnknownDepsDirective, parts) } keepUnknownDeps, err := strconv.ParseBool(parts[0]) if err != nil { - return fmt.Errorf("invalid gazelle:%s directive: %v", scalaKeepUnknownDepsDirective, err) + return fmt.Errorf("invalid gazelle:%s directive: %v", sweep.ScalaKeepUnknownDepsDirective, err) } c.keepUnknownDeps = keepUnknownDeps return nil @@ -433,11 +420,11 @@ func (c *Config) parseKeepUnknownDepsDirective(d rule.Directive) error { func (c *Config) parseSweepTransitiveDepsDirective(d rule.Directive) error { parts := strings.Fields(d.Value) if len(parts) != 1 { - return fmt.Errorf("invalid gazelle:%s directive: expected [true|false], got %v", scalaSweepTransitiveDepsDirective, parts) + return fmt.Errorf("invalid gazelle:%s directive: expected [true|false], got %v", sweep.ScalaSweepTransitiveDepsDirective, parts) } sweepTransitiveDeps, err := strconv.ParseBool(parts[0]) if err != nil { - return fmt.Errorf("invalid gazelle:%s directive: %v", scalaSweepTransitiveDepsDirective, err) + return fmt.Errorf("invalid gazelle:%s directive: %v", sweep.ScalaSweepTransitiveDepsDirective, err) } c.sweepTransitiveDeps = sweepTransitiveDeps return nil @@ -805,11 +792,11 @@ func (c *Config) mergeDeps(attrValue build.Expr, deps map[label.Label]bool, impo if c.ShouldSweepTransitive(attrName) { // set as TRANSITIVE comment for sweeping if _, ok := expr.(*build.StringExpr); ok { - expr.Comment().Suffix = []build.Comment{{Token: TransitiveCommentToken}} + expr.Comment().Suffix = []build.Comment{sweep.MakeTransitiveComment()} } dst.List = append(dst.List, expr) } else { - if isTransitive(expr) { + if sweep.IsTransitiveDep(expr) { dst.List = append(dst.List, expr) } else { // one more caveat: preserve unmarked deps in legacy mode @@ -956,14 +943,3 @@ func setCommentPrefix(comment build.Comment, prefix string) build.Comment { comment.Token = "# " + prefix + strings.TrimSpace(strings.TrimPrefix(comment.Token, "#")) return comment } - -// isTransitive returns whether e is marked with a "# TRANSITIVE" comment. -func isTransitive(e build.Expr) bool { - for _, c := range e.Comment().Suffix { - text := strings.TrimSpace(c.Token) - if text == TransitiveCommentToken { - return true - } - } - return false -} diff --git a/pkg/sweep/BUILD.bazel b/pkg/sweep/BUILD.bazel new file mode 100644 index 0000000..7879f08 --- /dev/null +++ b/pkg/sweep/BUILD.bazel @@ -0,0 +1,25 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@build_stack_scala_gazelle//rules:package_filegroup.bzl", "package_filegroup") + +go_library( + name = "sweep", + srcs = ["sweep.go"], + importpath = "github.com/stackb/scala-gazelle/pkg/sweep", + visibility = ["//visibility:public"], + deps = [ + "//pkg/bazel", + "//pkg/collections", + "@bazel_gazelle//label:go_default_library", + "@bazel_gazelle//rule:go_default_library", + "@com_github_bazelbuild_buildtools//build:go_default_library", + ], +) + +package_filegroup( + name = "filegroup", + srcs = [ + "BUILD.bazel", + "sweep.go", + ], + visibility = ["//visibility:public"], +) diff --git a/pkg/sweep/sweep.go b/pkg/sweep/sweep.go new file mode 100644 index 0000000..1e2cdf3 --- /dev/null +++ b/pkg/sweep/sweep.go @@ -0,0 +1,160 @@ +package sweep + +import ( + "bytes" + "fmt" + "log" + "os" + "strings" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + "github.com/bazelbuild/buildtools/build" + "github.com/stackb/scala-gazelle/pkg/bazel" + "github.com/stackb/scala-gazelle/pkg/collections" +) + +const ( + // Turn on the dep sweeper + // + // gazelle:scala_sweep_transitive_deps true + ScalaSweepTransitiveDepsDirective = "scala_sweep_transitive_deps" + + // Flag to preserve deps if the label is not known to be needed from the + // imports (legacy migration mode). + // + // gazelle:scala_keep_unknown_deps true + ScalaKeepUnknownDepsDirective = "scala_keep_unknown_deps" +) + +const TransitiveCommentToken = "# TRANSITIVE" + +// TransitiveAttr iterates through deps marked "TRANSITIVE" and removes them if +// the target still builds without it. +func TransitiveAttr(attrName string, r *rule.Rule, file *rule.File, from label.Label) error { + expr := r.Attr(attrName) + if expr == nil { + return nil + } + + deps, isList := expr.(*build.ListExpr) + if !isList { + return nil // some other condition we can't deal with + } + + // target should build first time, otherwise we can't check accurately. + log.Println("🧱 transitive sweep:", from) + + if out, exitCode, _ := bazel.ExecCommand("bazel", "build", from.String()); exitCode != 0 { + log.Fatalf("sweep failed (must build cleanly on first attempt): %s", string(out)) + } + + for i := len(deps.List) - 1; i >= 0; i-- { + expr := deps.List[i] + switch t := expr.(type) { + case *build.StringExpr: + if len(t.Comments.Suffix) != 1 { + continue + } + if t.Comments.Suffix[0].Token != "# TRANSITIVE" { + continue + } + + dep, err := label.Parse(t.Value) + if err != nil { + return err + } + deps.List = collections.SliceRemoveIndex(deps.List, i) + + if err := file.Save(file.Path); err != nil { + return err + } + + if _, exitCode, _ := bazel.ExecCommand("bazel", "build", from.String()); exitCode == 0 { + log.Println("- 💩 junk:", dep) + } else { + log.Println("- 👑 keep:", dep) + deps.List = collections.SliceInsertAt(deps.List, i, expr) + } + } + + } + + if err := file.Save(file.Path); err != nil { + return err + } + + return nil +} + +func RemoveSweepDirective(file *rule.File) error { + if file == nil { + return nil + } + // if this file has the sweep directive, remove it + for _, d := range file.Directives { + if d.Key == ScalaSweepTransitiveDepsDirective && d.Value == "true" { + old := []byte(fmt.Sprintf("# gazelle:%s true\n", ScalaSweepTransitiveDepsDirective)) + new := []byte{'\n'} + file.Content = bytes.Replace(file.Content, old, new, -1) + // file.Sync() + if err := file.Save(file.Path); err != nil { + return err + } + // log.Panicln("saved it!", file.Path) + } + } + return nil +} + +func SetKeepDepsDirective(file *rule.File, value bool) error { + if file == nil { + return nil + } + // if this file has the sweep directive, remove it + for _, d := range file.Directives { + if d.Key == ScalaKeepUnknownDepsDirective { + // log.Panicln("found it!", file.Path) + old := []byte(fmt.Sprintf("# gazelle:%s %t", ScalaKeepUnknownDepsDirective, !value)) + new := []byte(fmt.Sprintf("# gazelle:%s %t", ScalaKeepUnknownDepsDirective, value)) + log.Println("OLD:", string(file.Content)) + file.Content = bytes.Replace(file.Content, old, new, -1) + // file.Sync() + log.Println("NEW:", string(file.Content)) + + if err := file.Save(file.Path); err != nil { + return err + } + stat, err := os.Stat(file.Path) + if err != nil { + return err + } + if err := os.WriteFile(file.Path, file.Content, stat.Mode()); err != nil { + return err + } + data, err := os.ReadFile(file.Path) + if err != nil { + return err + } + log.Panicln("saved it!", file.Path, "FILE DATA:\n", string(data)) + } + } + return nil +} + +func MakeTransitiveDep(dep label.Label) *build.StringExpr { + expr := &build.StringExpr{Value: dep.String()} + expr.Comment().Suffix = append(expr.Comment().Suffix, build.Comment{Token: TransitiveCommentToken}) + return expr +} + +// IsTransitiveDep returns whether e is marked with a "# TRANSITIVE" comment. +func IsTransitiveDep(e build.Expr) bool { + for _, c := range e.Comment().Suffix { + text := strings.TrimSpace(c.Token) + if text == TransitiveCommentToken { + return true + } + } + return false +}