diff --git a/internal/verifier/check.go b/internal/verifier/check.go index 3a4bcc4a..ef50bb03 100644 --- a/internal/verifier/check.go +++ b/internal/verifier/check.go @@ -285,6 +285,19 @@ func (verifier *Verifier) CreateInitialTasks() error { verifier.logger.Error().Msgf("%s", err) return err } + // Check for database names and append their collections + if NamespacesContainsDBName(verifier.srcNamespaces) { + namespaces, dbs := ParseDBNamespaces(verifier.srcNamespaces) + + dbNamespaces, err := ListUserCollectionsForDBs(context.Background(), verifier.logger, verifier.srcClient, true /* include views */, dbs) + if err != nil { + verifier.logger.Error().Msgf("Failed to parse database namespaces: %s", err) + return err + } + + verifier.srcNamespaces = append(namespaces, dbNamespaces...) + verifier.SetNamespaceMap() + } } isPrimary, err := verifier.CheckIsPrimary() if err != nil { diff --git a/internal/verifier/list_namespaces.go b/internal/verifier/list_namespaces.go index 2ab0603a..89d92c97 100644 --- a/internal/verifier/list_namespaces.go +++ b/internal/verifier/list_namespaces.go @@ -37,8 +37,14 @@ func ListAllUserCollections(ctx context.Context, logger *logger.Logger, client * } logger.Debug().Msgf("All user databases: %+v", dbNames) + return ListUserCollectionsForDBs(ctx, logger, client, includeViews, dbNames) +} + +func ListUserCollectionsForDBs(ctx context.Context, logger *logger.Logger, client *mongo.Client, includeViews bool, + databases []string) ([]string, error) { + collectionNamespaces := []string{} - for _, dbName := range dbNames { + for _, dbName := range databases { db := client.Database(dbName) filter := bson.D{{"name", bson.D{{"$nin", bson.A{ExcludedSystemCollRegex}}}}} if !includeViews { diff --git a/internal/verifier/util.go b/internal/verifier/util.go index 6d200f4e..bf4ff53c 100644 --- a/internal/verifier/util.go +++ b/internal/verifier/util.go @@ -154,3 +154,27 @@ func GetLastOpTimeAndSyncShardClusterTime( t, i := rawOperationTime.Timestamp() return &primitive.Timestamp{T: t, I: i}, nil } + +func NamespacesContainsDBName(namespaces []string) bool { + for _, namespace := range namespaces { + if strings.Index(namespace, ".") < 0 { + return true + } + } + return false +} + +func ParseDBNamespaces(namespaces []string) ([]string, []string) { + databases := []string{} + parsedNamespaces := []string{} + // strip database names from namespaces - db collections will be appended to the namespaces + for _, name := range namespaces { + db, coll := SplitNamespace(name) + if coll == "" { + databases = append(databases, db) + } else { + parsedNamespaces = append(parsedNamespaces, name) + } + } + return parsedNamespaces, databases +}