diff --git a/base/main_test_bucket_pool_config.go b/base/main_test_bucket_pool_config.go index c07e2cbfa6..bbc1bd8939 100644 --- a/base/main_test_bucket_pool_config.go +++ b/base/main_test_bucket_pool_config.go @@ -15,6 +15,8 @@ import ( "strconv" "testing" "time" + + "github.com/stretchr/testify/require" ) // Bucket names start with a fixed prefix and end with a sequential bucket number and a creation timestamp for uniqueness @@ -213,3 +215,27 @@ func TestClusterPassword() string { } return password } + +// TestRunSGCollectIntegrationTests runs the tests only if a specific environment variable is set. These should always run under jenkins/github actions. +func TestRunSGCollectIntegrationTests(t *testing.T) { + env := "SG_TEST_SGCOLLECT_INTEGRATION" + val, ok := os.LookupEnv(env) + if !ok { + ciEnvVars := []string{ + "CI", // convention by github actions + "JENKINS_URL", // from jenkins + } + for _, ciEnv := range ciEnvVars { + if os.Getenv(ciEnv) != "" { + return + } + } + t.Skip("Skipping sgcollect integration tests - set " + env + "=true to run") + } + + runTests, err := strconv.ParseBool(val) + require.NoError(t, err, "Couldn't parse %s=%s as bool", env, val) + if !runTests { + t.Skip("Skipping sgcollect integration tests - set " + env + "=true to run") + } +} diff --git a/rest/admin_api.go b/rest/admin_api.go index 9eb8bd4926..99a354ffd8 100644 --- a/rest/admin_api.go +++ b/rest/admin_api.go @@ -1637,7 +1637,7 @@ func (h *handler) handleSetLogging() error { func (h *handler) handleSGCollectStatus() error { status := "stopped" - if sgcollectInstance.IsRunning() { + if h.server.SGCollect.IsRunning() { status = "running" } @@ -1647,7 +1647,7 @@ func (h *handler) handleSGCollectStatus() error { } func (h *handler) handleSGCollectCancel() error { - err := sgcollectInstance.Stop() + err := h.server.SGCollect.Stop() if err != nil { return base.HTTPErrorf(http.StatusBadRequest, "Error stopping sgcollect_info: %v", err) } @@ -1664,7 +1664,7 @@ func (h *handler) 
handleSGCollect() error { return err } - var params sgCollectOptions + var params SGCollectOptions if err = base.JSONUnmarshal(body, &params); err != nil { return base.HTTPErrorf(http.StatusBadRequest, "Unable to parse request body: %v", err) } @@ -1676,11 +1676,24 @@ func (h *handler) handleSGCollect() error { // Populate username and password used by sgcollect_info script for talking to Sync Gateway. params.syncGatewayUsername, params.syncGatewayPassword = h.getBasicAuth() - zipFilename := sgcollectFilename() + addr, err := h.server.getServerAddr(adminServer) + if err != nil { + return base.HTTPErrorf(http.StatusInternalServerError, "Error getting admin server address: %v", err) + } + if h.server.Config.API.HTTPS.TLSCertPath != "" { + addr = "https://" + addr + } else { + addr = "http://" + addr + } + params.adminURL = addr + + zipFilename := SGCollectFilename() logFilePath := h.server.Config.Logging.LogFilePath - if err := sgcollectInstance.Start(logFilePath, h.serialNumber, zipFilename, params); err != nil { + ctx := base.CorrelationIDLogCtx(context.WithoutCancel(h.ctx()), fmt.Sprintf("SGCollect-%03d", h.serialNumber)) + + if err := h.server.SGCollect.Start(ctx, logFilePath, zipFilename, params); err != nil { return base.HTTPErrorf(http.StatusInternalServerError, "Error running sgcollect_info: %v", err) } diff --git a/rest/audit_test.go b/rest/audit_test.go index 4a5bb4efb2..a90eee7d6a 100644 --- a/rest/audit_test.go +++ b/rest/audit_test.go @@ -239,7 +239,7 @@ func TestAuditLoggingFields(t *testing.T) { auditableAction: func(t testing.TB) { headers := map[string]string{ requestInfoHeaderName: `{"extra":"field"}`, - "Authorization": getBasicAuthHeader(requestUser, RestTesterDefaultUserPassword), + "Authorization": GetBasicAuthHeader(t, requestUser, RestTesterDefaultUserPassword), } RequireStatus(t, rt.SendRequestWithHeaders(http.MethodGet, "/db/", "", headers), http.StatusOK) }, @@ -442,7 +442,7 @@ func TestAuditLoggingFields(t *testing.T) { name: "metrics request 
authenticated", auditableAction: func(t testing.TB) { headers := map[string]string{ - "Authorization": getBasicAuthHeader(base.TestClusterUsername(), base.TestClusterPassword()), + "Authorization": GetBasicAuthHeader(t, base.TestClusterUsername(), base.TestClusterPassword()), } RequireStatus(t, rt.SendMetricsRequestWithHeaders(http.MethodGet, "/_metrics", "", headers), http.StatusOK) }, @@ -473,7 +473,7 @@ func TestAuditLoggingFields(t *testing.T) { t.Skip("Skipping subtest that requires admin auth to be enabled") } headers := map[string]string{ - "Authorization": getBasicAuthHeader("notauser", base.TestClusterPassword()), + "Authorization": GetBasicAuthHeader(t, "notauser", base.TestClusterPassword()), } RequireStatus(t, rt.SendMetricsRequestWithHeaders(http.MethodGet, "/_metrics", "", headers), http.StatusUnauthorized) }, @@ -724,7 +724,7 @@ func TestEffectiveUserID(t *testing.T) { ) reqHeaders := map[string]string{ "user_header": fmt.Sprintf(`{"%s": "%s", "%s":"%s"}`, headerDomain, cnfDomain, headerUser, cnfUser), - "Authorization": getBasicAuthHeader(realUser, RestTesterDefaultUserPassword), + "Authorization": GetBasicAuthHeader(t, realUser, RestTesterDefaultUserPassword), } rt := NewRestTester(t, &RestTesterConfig{ diff --git a/rest/server_context.go b/rest/server_context.go index 10681b6343..f0a87e6bfe 100644 --- a/rest/server_context.go +++ b/rest/server_context.go @@ -86,6 +86,8 @@ type ServerContext struct { DatabaseInitManager *DatabaseInitManager // Manages database initialization (index creation and readiness) independent of database stop/start/reload, when using persistent config ActiveReplicationsCounter invalidDatabaseConfigTracking invalidDatabaseConfigs + // handle sgcollect processes for a given Server + SGCollect *sgCollect } type ActiveReplicationsCounter struct { @@ -163,6 +165,7 @@ func NewServerContext(ctx context.Context, config *StartupConfig, persistentConf BootstrapContext: &bootstrapContext{sgVersion: *base.ProductVersion}, hasStarted: 
make(chan struct{}), _httpServers: map[serverType]*serverInfo{}, + SGCollect: newSGCollect(ctx), } sc.invalidDatabaseConfigTracking = invalidDatabaseConfigs{ dbNames: map[string]*invalidConfigInfo{}, diff --git a/rest/sgcollect.go b/rest/sgcollect.go index b6d3e69e21..97fdd4a2ba 100644 --- a/rest/sgcollect.go +++ b/rest/sgcollect.go @@ -21,6 +21,7 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "sync/atomic" "time" @@ -35,48 +36,52 @@ var ( ErrSGCollectInfoNotRunning = errors.New("not running") validateTicketPattern = regexp.MustCompile(`\d{1,7}`) - - sgPath, sgCollectPath, sgCollectPathErr = sgCollectPaths() - sgcollectInstance = sgCollect{ - status: base.Ptr(sgStopped), - sgPath: sgPath, - sgCollectPath: sgCollectPath, - pathError: sgCollectPathErr, - } ) const ( sgStopped uint32 = iota sgRunning - defaultSGUploadHost = "https://uploads.couchbase.com" + DefaultSGCollectUploadHost = "https://uploads.couchbase.com" ) +// sgCollectOutputStream handles stderr/stdout from a running sgcollect process. +type sgCollectOutputStream struct { + stdoutPipeWriter io.WriteCloser // Pipe writer for stdout + stderrPipeWriter io.WriteCloser // Pipe writer for stderr + stderrPipeReader io.Reader // Pipe reader for stderr + stdoutPipeReader io.Reader // Pipe reader for stdout + stdoutDoneChan chan struct{} // Channel to signal stdout processing completion + stderrDoneChan chan struct{} // Channel to signal stderr processing completion +} + +// sgCollect manages the state of a running sgcollect_info process. 
type sgCollect struct { - cancel context.CancelFunc - status *uint32 - sgPath string - sgCollectPath string - pathError error - context context.Context + cancel context.CancelFunc // Function to cancel a running sgcollect_info process, set when status == sgRunning + status *uint32 + sgPath string // Path to the Sync Gateway executable + SGCollectPath []string // Path to the sgcollect_info executable + SGCollectPathErr error // Error if sgcollect_info path could not be determined + Stdout io.Writer // test seam, is nil in production + Stderr io.Writer // test seam, is nil in production } // Start will attempt to start sgcollect_info, if another is not already running. -func (sg *sgCollect) Start(logFilePath string, ctxSerialNumber uint64, zipFilename string, params sgCollectOptions) error { +func (sg *sgCollect) Start(ctx context.Context, logFilePath string, zipFilename string, params SGCollectOptions) error { if atomic.LoadUint32(sg.status) == sgRunning { return ErrSGCollectInfoAlreadyRunning } // Return error if there is any failure while obtaining sgCollectPaths. - if sg.pathError != nil { - return sg.pathError + if sg.SGCollectPathErr != nil { + return sg.SGCollectPathErr } if params.OutputDirectory == "" { // If no output directory specified, default to the configured LogFilePath if logFilePath != "" { params.OutputDirectory = logFilePath - base.DebugfCtx(sg.context, base.KeyAdmin, "sgcollect_info: no output directory specified, using LogFilePath: %v", params.OutputDirectory) + base.DebugfCtx(ctx, base.KeyAdmin, "sgcollect_info: no output directory specified, using LogFilePath: %v", params.OutputDirectory) } else { // If LogFilePath is not set, and DefaultLogFilePath is not set via a service script, error out. 
return errors.New("no output directory or LogFilePath specified") @@ -90,69 +95,46 @@ func (sg *sgCollect) Start(logFilePath string, ctxSerialNumber uint64, zipFilena zipPath := filepath.Join(params.OutputDirectory, zipFilename) - args := params.Args() - args = append(args, "--sync-gateway-executable", sgPath) - args = append(args, zipPath) + cmdline := slices.Clone(sg.SGCollectPath) + cmdline = append(cmdline, params.Args()...) + cmdline = append(cmdline, "--sync-gateway-executable", sg.sgPath) + cmdline = append(cmdline, zipPath) - ctx := base.CorrelationIDLogCtx(context.Background(), fmt.Sprintf("SGCollect-%03d", ctxSerialNumber)) + ctx, sg.cancel = context.WithCancel(ctx) + cmd := exec.CommandContext(ctx, cmdline[0], cmdline[1:]...) - sg.context, sg.cancel = context.WithCancel(ctx) - cmd := exec.CommandContext(sg.context, sgCollectPath, args...) - - // Send command stderr/stdout to pipes - stderrPipeReader, stderrPipeWriter := io.Pipe() - cmd.Stderr = stderrPipeWriter - stdoutPipeReader, stdoutpipeWriter := io.Pipe() - cmd.Stdout = stdoutpipeWriter + outStream := newSGCollectOutputStream(ctx, sg.Stdout, sg.Stderr) + cmd.Stdout = outStream.stdoutPipeWriter + cmd.Stderr = outStream.stderrPipeWriter if err := cmd.Start(); err != nil { + outStream.Close(ctx) return err } atomic.StoreUint32(sg.status, sgRunning) startTime := time.Now() - base.InfofCtx(sg.context, base.KeyAdmin, "sgcollect_info started with args: %v", base.UD(args)) - - // Stream sgcollect_info stderr to warn logs - go func() { - scanner := bufio.NewScanner(stderrPipeReader) - for scanner.Scan() { - base.InfofCtx(sg.context, base.KeyAll, "sgcollect_info: %v", scanner.Text()) - } - if err := scanner.Err(); err != nil { - base.ErrorfCtx(sg.context, "sgcollect_info: unexpected error: %v", err) - } - }() - - // Stream sgcollect_info stdout to debug logs - go func() { - scanner := bufio.NewScanner(stdoutPipeReader) - for scanner.Scan() { - base.InfofCtx(sg.context, base.KeyAll, "sgcollect_info: %v", 
scanner.Text()) - } - if err := scanner.Err(); err != nil { - base.ErrorfCtx(sg.context, "sgcollect_info: unexpected error: %v", err) - } - }() + base.InfofCtx(ctx, base.KeyAdmin, "sgcollect_info started with output zip: %v", base.UD(zipPath)) go func() { // Blocks until command finishes err := cmd.Wait() + outStream.Close(ctx) atomic.StoreUint32(sg.status, sgStopped) duration := time.Since(startTime) if err != nil { if err.Error() == "signal: killed" { - base.InfofCtx(sg.context, base.KeyAdmin, "sgcollect_info cancelled after %v", duration) + base.InfofCtx(ctx, base.KeyAdmin, "sgcollect_info cancelled after %v", duration) return } - base.ErrorfCtx(sg.context, "sgcollect_info failed after %v with reason: %v. Check warning level logs for more information.", duration, err) + base.ErrorfCtx(ctx, "sgcollect_info failed after %v with reason: %v. Check warning level logs for more information.", duration, err) return } - base.InfofCtx(sg.context, base.KeyAdmin, "sgcollect_info finished successfully after %v", duration) + base.InfofCtx(ctx, base.KeyAdmin, "sgcollect_info finished successfully after %v", duration) }() return nil @@ -175,7 +157,7 @@ func (sg *sgCollect) IsRunning() bool { return atomic.LoadUint32(sg.status) == sgRunning } -type sgCollectOptions struct { +type SGCollectOptions struct { RedactLevel string `json:"redact_level,omitempty"` RedactSalt string `json:"redact_salt,omitempty"` OutputDirectory string `json:"output_dir,omitempty"` @@ -190,6 +172,7 @@ type sgCollectOptions struct { // We'll set them from the request's basic auth. syncGatewayUsername string syncGatewayPassword string + adminURL string // URL to the Sync Gateway admin API. } // validateOutputDirectory will check that the given path exists, and is a directory. @@ -212,8 +195,82 @@ func validateOutputDirectory(dir string) error { return nil } +// newSGCollectOutputStream creates an instance to monitor stdout and stderr. Stdout is logged at Debug and Stderr at Info. 
extraStdout and extraStderr are optional writers used for testing only. +func newSGCollectOutputStream(ctx context.Context, extraStdout io.Writer, extraStderr io.Writer) *sgCollectOutputStream { + stderrPipeReader, stderrPipeWriter := io.Pipe() + stdoutPipeReader, stdoutPipeWriter := io.Pipe() + o := &sgCollectOutputStream{ + stdoutPipeWriter: stdoutPipeWriter, + stderrPipeWriter: stderrPipeWriter, + stderrPipeReader: stderrPipeReader, + stdoutPipeReader: stdoutPipeReader, + stdoutDoneChan: make(chan struct{}), + stderrDoneChan: make(chan struct{}), + } + go func() { + defer close(o.stderrDoneChan) + scanner := bufio.NewScanner(stderrPipeReader) + for scanner.Scan() { + text := scanner.Text() + base.InfofCtx(ctx, base.KeyAll, "sgcollect_info: %v", text) + if extraStderr != nil { + _, err := extraStderr.Write([]byte(text + "\n")) + if err != nil { + base.ErrorfCtx(ctx, "sgcollect_info: failed to write to stderr pipe: %v", err) + } + } + } + if err := scanner.Err(); err != nil { + base.ErrorfCtx(ctx, "sgcollect_info: unexpected error: %v", err) + } + }() + + // Stream sgcollect_info stdout to debug logs + go func() { + defer close(o.stdoutDoneChan) + scanner := bufio.NewScanner(stdoutPipeReader) + for scanner.Scan() { + text := scanner.Text() + base.InfofCtx(ctx, base.KeyAll, "sgcollect_info: %v", text) + if extraStdout != nil { + _, err := extraStdout.Write([]byte(text + "\n")) + if err != nil { + base.ErrorfCtx(ctx, "sgcollect_info: failed to write to stdout pipe: %v", err) + } + } + } + if err := scanner.Err(); err != nil { + base.ErrorfCtx(ctx, "sgcollect_info: unexpected error: %v", err) + } + }() + return o +} + +// Close the output streams, required to close goroutines when sgCollectOutputStream is created. 
+func (o *sgCollectOutputStream) Close(ctx context.Context) { + err := o.stderrPipeWriter.Close() + if err != nil { + base.WarnfCtx(ctx, "sgcollect_info: failed to close stderr pipe writer: %v", err) + } + err = o.stdoutPipeWriter.Close() + if err != nil { + base.WarnfCtx(ctx, "sgcollect_info: failed to close stdout pipe writer: %v", err) + } + // Wait for the goroutines to finish processing the output streams, or exit after 5 seconds. + select { + case <-o.stdoutDoneChan: + case <-time.After(5 * time.Second): + base.WarnfCtx(ctx, "sgcollect_info: timed out waiting for stdout processing to finish") + } + select { + case <-o.stderrDoneChan: + case <-time.After(5 * time.Second): + base.WarnfCtx(ctx, "sgcollect_info: timed out waiting for stderr processing to finish") + } +} + // Validate ensures the options are OK to use in sgcollect_info. -func (c *sgCollectOptions) Validate() error { +func (c *SGCollectOptions) Validate() error { var errs *base.MultiError if c.OutputDirectory != "" { @@ -235,7 +292,7 @@ func (c *sgCollectOptions) Validate() error { } // Default uploading to support bucket if upload_host is not specified. if c.UploadHost == "" { - c.UploadHost = defaultSGUploadHost + c.UploadHost = DefaultSGCollectUploadHost } } else { // These fields suggest the user actually wanted to upload, @@ -259,7 +316,7 @@ func (c *sgCollectOptions) Validate() error { } // Args returns a set of arguments to pass to sgcollect_info. -func (c *sgCollectOptions) Args() []string { +func (c *SGCollectOptions) Args() []string { var args = make([]string, 0) if c.Upload { @@ -297,23 +354,26 @@ func (c *sgCollectOptions) Args() []string { if c.KeepZip { args = append(args, "--keep-zip") } - + if c.adminURL != "" { + args = append(args, "--sync-gateway-url", c.adminURL) + } return args } -// sgCollectPaths attempts to return the absolute paths to Sync Gateway and to sgcollect_info binaries. 
-func sgCollectPaths() (sgBinary, sgCollectBinary string, err error) { +func sgCollectPaths(ctx context.Context) (sgBinary string, sgCollect []string, err error) { sgBinary, err = os.Executable() if err != nil { - return "", "", err + return "", nil, err } sgBinary, err = filepath.Abs(sgBinary) if err != nil { - return "", "", err + return "", nil, err } - logCtx := context.TODO() // this is global variable at init, we can't pass it in easily hasBinDir := true sgCollectPath := filepath.Join("tools", "sgcollect_info") @@ -324,6 +384,7 @@ func sgCollectPaths() (sgBinary, sgCollectBinary string, err error) { for { + var sgCollectBinary string if hasBinDir { sgCollectBinary = filepath.Join(filepath.Dir(filepath.Dir(sgBinary)), sgCollectPath) } else { @@ -331,7 +392,7 @@ func sgCollectPaths() (sgBinary, sgCollectBinary string, err error) { } // Check sgcollect_info exists at the path we guessed. - base.DebugfCtx(logCtx, base.KeyAdmin, "Checking sgcollect_info binary exists at: %v", sgCollectBinary) + base.DebugfCtx(ctx, base.KeyAdmin, "Checking sgcollect_info binary exists at: %v", sgCollectBinary) _, err = os.Stat(sgCollectBinary) if err != nil { @@ -341,15 +402,15 @@ func sgCollectPaths() (sgBinary, sgCollectBinary string, err error) { continue } - return "", "", err + return "", nil, err } - return sgBinary, sgCollectBinary, nil + return sgBinary, []string{sgCollectBinary}, nil } } -// sgcollectFilename returns a Windows-safe filename for sgcollect_info zip files. -func sgcollectFilename() string { +// SGCollectFilename returns a Windows-safe filename for sgcollect_info zip files. 
+func SGCollectFilename() string { // get timestamp timestamp := time.Now().UTC().Format("2006-01-02t150405") @@ -371,3 +432,12 @@ func sgcollectFilename() string { return filename } + +// newSGCollect creates a new sgCollect instance. +func newSGCollect(ctx context.Context) *sgCollect { + sgCollectInstance := sgCollect{ + status: base.Ptr(sgStopped), + } + sgCollectInstance.sgPath, sgCollectInstance.SGCollectPath, sgCollectInstance.SGCollectPathErr = sgCollectPaths(ctx) + return &sgCollectInstance +} diff --git a/rest/sgcollect_test.go b/rest/sgcollect_test.go deleted file mode 100644 index 6b729c131d..0000000000 --- a/rest/sgcollect_test.go +++ /dev/null @@ -1,224 +0,0 @@ -/* -Copyright 2018-Present Couchbase, Inc. - -Use of this software is governed by the Business Source License included in -the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that -file, in accordance with the Business Source License, use of this software will -be governed by the Apache License, Version 2.0, included in the file -licenses/APL2.txt. 
-*/ - -package rest - -import ( - "fmt" - "os" - "regexp" - "strings" - "testing" - - "github.com/couchbase/sync_gateway/base" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSgcollectFilename(t *testing.T) { - filename := sgcollectFilename() - - // Check it doesn't have forbidden chars - assert.False(t, strings.ContainsAny(filename, "\\/:*?\"<>|")) - - pattern := `^sgcollectinfo\-\d{4}\-\d{2}\-\d{2}t\d{6}\-sga?@(\d{1,3}\.){4}zip$` - matched, err := regexp.Match(pattern, []byte(filename)) - assert.NoError(t, err, "unexpected regexp error") - assert.True(t, matched, fmt.Sprintf("Filename: %s did not match pattern: %s", filename, pattern)) -} - -func TestSgcollectOptionsValidateValid(t *testing.T) { - tests := []struct { - name string - options *sgCollectOptions - }{ - { - name: "defaults", - options: &sgCollectOptions{}, - }, - { - name: "upload with customer name", - options: &sgCollectOptions{Upload: true, Customer: "alice"}, - }, - { - name: "custom upload with customer name", - options: &sgCollectOptions{Upload: true, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, - }, - { - name: "directory that exists", - options: &sgCollectOptions{OutputDirectory: "."}, - }, - { - name: "valid redact level", - options: &sgCollectOptions{RedactLevel: "partial"}, - }, - { - name: "valid keep_zip option", - options: &sgCollectOptions{KeepZip: true}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - assert.Nil(t, test.options.Validate()) - }) - } -} - -func TestSgcollectOptionsValidateInvalid(t *testing.T) { - binaryPath, err := os.Executable() - assert.NoError(t, err, "unexpected error getting executable path") - - tests := []struct { - name string - options *sgCollectOptions - errContains string - }{ - { - name: "directory doesn't exist", - options: &sgCollectOptions{OutputDirectory: "/path/to/output/dir"}, - errContains: "no such file or directory", - }, - { - name: "path not 
a directory", - options: &sgCollectOptions{OutputDirectory: binaryPath}, - errContains: "not a directory", - }, - { - name: "invalid redact level", - options: &sgCollectOptions{RedactLevel: "asdf"}, - errContains: "'redact_level' must be", - }, - { - name: "no customer", - options: &sgCollectOptions{Upload: true}, - errContains: "'customer' must be set", - }, - { - name: "no customer with ticket", - options: &sgCollectOptions{Upload: true, Ticket: "12345"}, - errContains: "'customer' must be set", - }, - { - name: "customer no upload", - options: &sgCollectOptions{Upload: false, Customer: "alice"}, - errContains: "'upload' must be set to true", - }, - { - name: "ticket no upload", - options: &sgCollectOptions{Upload: false, Ticket: "12345"}, - errContains: "'upload' must be set to true", - }, - { - name: "customer upload host no upload", - options: &sgCollectOptions{Upload: false, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, - errContains: "'upload' must be set to true", - }, - { - name: "non-digit ticket number", - options: &sgCollectOptions{Upload: true, Customer: "alice", Ticket: "abc"}, - errContains: "'ticket' must be", - }, - } - - for _, test := range tests { - t.Run(test.name, func(ts *testing.T) { - errs := test.options.Validate() - require.NotNil(t, errs) - multiError, ok := errs.(*base.MultiError) - require.True(t, ok) - - // make sure we get at least one error for the given invalid options. - require.True(t, multiError.Len() > 0) - - // check each error matches the expected string. 
- for _, err := range multiError.Errors { - assert.Contains(ts, err.Error(), test.errContains) - } - }) - } - -} - -func TestSgcollectOptionsArgs(t *testing.T) { - binPath, err := os.Executable() - assert.NoError(t, err, "unexpected error getting executable path") - - tests := []struct { - options *sgCollectOptions - expectedArgs []string - }{ - { - options: &sgCollectOptions{}, - expectedArgs: []string{}, - }, - { - options: &sgCollectOptions{Upload: true}, - expectedArgs: []string{"--upload-host", defaultSGUploadHost}, - }, - { - options: &sgCollectOptions{Upload: true, Ticket: "123456", KeepZip: true}, - expectedArgs: []string{"--upload-host", defaultSGUploadHost, "--ticket", "123456", "--keep-zip"}, - }, - { - options: &sgCollectOptions{Upload: true, RedactLevel: "partial"}, - expectedArgs: []string{"--upload-host", defaultSGUploadHost, "--log-redaction-level", "partial"}, - }, - { - options: &sgCollectOptions{Upload: true, RedactLevel: "partial", RedactSalt: "asdf"}, - expectedArgs: []string{"--upload-host", defaultSGUploadHost, "--log-redaction-level", "partial", "--log-redaction-salt", "asdf"}, - }, - { - // Check that the default upload host is set - options: &sgCollectOptions{Upload: true, Customer: "alice"}, - expectedArgs: []string{"--upload-host", defaultSGUploadHost, "--customer", "alice"}, - }, - { - options: &sgCollectOptions{Upload: true, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, - expectedArgs: []string{"--upload-host", "example.org/custom-s3-bucket-url", "--customer", "alice"}, - }, - { - options: &sgCollectOptions{Upload: true, Customer: "alice", UploadHost: "https://example.org/custom-s3-bucket-url", UploadProxy: "http://proxy.example.org:8080"}, - expectedArgs: []string{"--upload-host", "https://example.org/custom-s3-bucket-url", "--upload-proxy", "http://proxy.example.org:8080", "--customer", "alice"}, - }, - { - // Upload false, so don't pass upload host through. 
same for keep zip - options: &sgCollectOptions{Upload: false, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url", KeepZip: false}, - expectedArgs: []string{"--customer", "alice"}, - }, - { - // Directory exists - options: &sgCollectOptions{OutputDirectory: "."}, - expectedArgs: []string{}, - }, - { - // Directory doesn't exist - options: &sgCollectOptions{OutputDirectory: "/path/to/output/dir"}, - expectedArgs: []string{}, - }, - { - // Path not a directory - options: &sgCollectOptions{OutputDirectory: binPath}, - expectedArgs: []string{}, - }, - } - - for i, test := range tests { - t.Run(fmt.Sprintf("%d", i), func(ts *testing.T) { - // We'll run validate to populate some default fields, - // but ignore errors raised by it for this test. - _ = test.options.Validate() - - args := test.options.Args() - assert.Equal(ts, test.expectedArgs, args) - }) - } -} diff --git a/rest/sgcollecttest/main_test.go b/rest/sgcollecttest/main_test.go new file mode 100644 index 0000000000..6b4de1b835 --- /dev/null +++ b/rest/sgcollecttest/main_test.go @@ -0,0 +1,25 @@ +/* +Copyright 2020-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. 
+*/ + +package sgcollecttest + +import ( + "context" + "testing" + + "github.com/couchbase/sync_gateway/base" + "github.com/couchbase/sync_gateway/rest" +) + +func TestMain(m *testing.M) { + ctx := context.Background() // start of test process + tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 8192, NumCollectionsPerBucket: 3} + rest.TestBucketPoolRestWithIndexes(ctx, m, tbpOptions) +} diff --git a/rest/sgcollecttest/sgcollect_test.go b/rest/sgcollecttest/sgcollect_test.go new file mode 100644 index 0000000000..e43d06c71d --- /dev/null +++ b/rest/sgcollecttest/sgcollect_test.go @@ -0,0 +1,292 @@ +/* +Copyright 2018-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. +*/ + +package sgcollecttest + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strings" + "testing" + "time" + + "github.com/couchbase/sync_gateway/base" + "github.com/couchbase/sync_gateway/rest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSgcollectFilename(t *testing.T) { + filename := rest.SGCollectFilename() + + // Check it doesn't have forbidden chars + assert.False(t, strings.ContainsAny(filename, "\\/:*?\"<>|")) + + pattern := `^sgcollectinfo\-\d{4}\-\d{2}\-\d{2}t\d{6}\-sga?@(\d{1,3}\.){4}zip$` + matched, err := regexp.Match(pattern, []byte(filename)) + assert.NoError(t, err, "unexpected regexp error") + assert.True(t, matched, fmt.Sprintf("Filename: %s did not match pattern: %s", filename, pattern)) +} + +func TestSgcollectOptionsValidateValid(t *testing.T) { + tests := []struct { + name string + options *rest.SGCollectOptions + }{ + { + name: "defaults", + options: &rest.SGCollectOptions{}, + }, + { 
+ name: "upload with customer name", + options: &rest.SGCollectOptions{Upload: true, Customer: "alice"}, + }, + { + name: "custom upload with customer name", + options: &rest.SGCollectOptions{Upload: true, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, + }, + { + name: "directory that exists", + options: &rest.SGCollectOptions{OutputDirectory: "."}, + }, + { + name: "valid redact level", + options: &rest.SGCollectOptions{RedactLevel: "partial"}, + }, + { + name: "valid keep_zip option", + options: &rest.SGCollectOptions{KeepZip: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Nil(t, test.options.Validate()) + }) + } +} + +func TestSgcollectOptionsValidateInvalid(t *testing.T) { + binaryPath, err := os.Executable() + assert.NoError(t, err, "unexpected error getting executable path") + + tests := []struct { + name string + options *rest.SGCollectOptions + errContains string + }{ + { + name: "directory doesn't exist", + options: &rest.SGCollectOptions{OutputDirectory: "/path/to/output/dir"}, + errContains: "no such file or directory", + }, + { + name: "path not a directory", + options: &rest.SGCollectOptions{OutputDirectory: binaryPath}, + errContains: "not a directory", + }, + { + name: "invalid redact level", + options: &rest.SGCollectOptions{RedactLevel: "asdf"}, + errContains: "'redact_level' must be", + }, + { + name: "no customer", + options: &rest.SGCollectOptions{Upload: true}, + errContains: "'customer' must be set", + }, + { + name: "no customer with ticket", + options: &rest.SGCollectOptions{Upload: true, Ticket: "12345"}, + errContains: "'customer' must be set", + }, + { + name: "customer no upload", + options: &rest.SGCollectOptions{Upload: false, Customer: "alice"}, + errContains: "'upload' must be set to true", + }, + { + name: "ticket no upload", + options: &rest.SGCollectOptions{Upload: false, Ticket: "12345"}, + errContains: "'upload' must be set to true", + }, + { + name: 
"customer upload host no upload", + options: &rest.SGCollectOptions{Upload: false, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, + errContains: "'upload' must be set to true", + }, + { + name: "non-digit ticket number", + options: &rest.SGCollectOptions{Upload: true, Customer: "alice", Ticket: "abc"}, + errContains: "'ticket' must be", + }, + } + + for _, test := range tests { + t.Run(test.name, func(ts *testing.T) { + errs := test.options.Validate() + require.NotNil(t, errs) + multiError, ok := errs.(*base.MultiError) + require.True(t, ok) + + // make sure we get at least one error for the given invalid options. + require.True(t, multiError.Len() > 0) + + // check each error matches the expected string. + for _, err := range multiError.Errors { + assert.Contains(ts, err.Error(), test.errContains) + } + }) + } + +} + +func TestSgcollectOptionsArgs(t *testing.T) { + binPath, err := os.Executable() + assert.NoError(t, err, "unexpected error getting executable path") + + tests := []struct { + options *rest.SGCollectOptions + expectedArgs []string + }{ + { + options: &rest.SGCollectOptions{}, + expectedArgs: []string{}, + }, + { + options: &rest.SGCollectOptions{Upload: true}, + expectedArgs: []string{"--upload-host", rest.DefaultSGCollectUploadHost}, + }, + { + options: &rest.SGCollectOptions{Upload: true, Ticket: "123456", KeepZip: true}, + expectedArgs: []string{"--upload-host", rest.DefaultSGCollectUploadHost, "--ticket", "123456", "--keep-zip"}, + }, + { + options: &rest.SGCollectOptions{Upload: true, RedactLevel: "partial"}, + expectedArgs: []string{"--upload-host", rest.DefaultSGCollectUploadHost, "--log-redaction-level", "partial"}, + }, + { + options: &rest.SGCollectOptions{Upload: true, RedactLevel: "partial", RedactSalt: "asdf"}, + expectedArgs: []string{"--upload-host", rest.DefaultSGCollectUploadHost, "--log-redaction-level", "partial", "--log-redaction-salt", "asdf"}, + }, + { + // Check that the default upload host is set + options: 
&rest.SGCollectOptions{Upload: true, Customer: "alice"}, + expectedArgs: []string{"--upload-host", rest.DefaultSGCollectUploadHost, "--customer", "alice"}, + }, + { + options: &rest.SGCollectOptions{Upload: true, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url"}, + expectedArgs: []string{"--upload-host", "example.org/custom-s3-bucket-url", "--customer", "alice"}, + }, + { + options: &rest.SGCollectOptions{Upload: true, Customer: "alice", UploadHost: "https://example.org/custom-s3-bucket-url", UploadProxy: "http://proxy.example.org:8080"}, + expectedArgs: []string{"--upload-host", "https://example.org/custom-s3-bucket-url", "--upload-proxy", "http://proxy.example.org:8080", "--customer", "alice"}, + }, + { + // Upload false, so don't pass upload host through. same for keep zip + options: &rest.SGCollectOptions{Upload: false, Customer: "alice", UploadHost: "example.org/custom-s3-bucket-url", KeepZip: false}, + expectedArgs: []string{"--customer", "alice"}, + }, + { + // Directory exists + options: &rest.SGCollectOptions{OutputDirectory: "."}, + expectedArgs: []string{}, + }, + { + // Directory doesn't exist + options: &rest.SGCollectOptions{OutputDirectory: "/path/to/output/dir"}, + expectedArgs: []string{}, + }, + { + // Path not a directory + options: &rest.SGCollectOptions{OutputDirectory: binPath}, + expectedArgs: []string{}, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(ts *testing.T) { + // We'll run validate to populate some default fields, + // but ignore errors raised by it for this test. 
+ _ = test.options.Validate() + + args := test.options.Args() + assert.Equal(ts, test.expectedArgs, args) + }) + } +} + +func TestSGCollectIntegration(t *testing.T) { + base.TestRunSGCollectIntegrationTests(t) + if runtime.GOOS == "windows" { + t.Skip("Skipping sgcollect_info integration test on Windows - currently flaky when running wmic product get name, version, which can take 7+ minutes") + } + base.LongRunningTest(t) // this test is very long, and somewhat fragile, since it involves relying on the sgcollect_info tool to run successfully, which requires system python + cwd, err := os.Getwd() + require.NoError(t, err) + config := rest.BootstrapStartupConfigForTest(t) + sc, closeFn := rest.StartServerWithConfig(t, &config) + defer closeFn() + + outputs := map[string]*strings.Builder{ + "stdout": &strings.Builder{}, + "stderr": &strings.Builder{}, + } + + sc.SGCollect.Stdout = outputs["stdout"] + sc.SGCollect.Stderr = outputs["stderr"] + python := "python3" + if runtime.GOOS == "windows" { + python = "python" + } + sc.SGCollect.SGCollectPath = []string{python, filepath.Join(cwd, "../../tools/sgcollect_info")} + sc.SGCollect.SGCollectPathErr = nil + validAuth := map[string]string{ + "Authorization": rest.GetBasicAuthHeader(t, base.TestClusterUsername(), base.TestClusterPassword()), + } + options := rest.SGCollectOptions{ + OutputDirectory: t.TempDir(), + } + resp := rest.BootstrapAdminRequestWithHeaders(t, sc, http.MethodPost, "/_sgcollect_info", string(base.MustJSONMarshal(t, options)), validAuth) + resp.RequireStatus(http.StatusOK) + + var statusResponse struct { + Status string + } + + defer func() { + if statusResponse.Status == "stopped" { + return + } + resp := rest.BootstrapAdminRequestWithHeaders(t, sc, http.MethodDelete, "/_sgcollect_info", "", validAuth) + resp.AssertStatus(http.StatusOK) + }() + + require.EventuallyWithT(t, func(c *assert.CollectT) { + resp := rest.BootstrapAdminRequestWithHeaders(t, sc, http.MethodGet, "/_sgcollect_info", "", validAuth) 
+ resp.AssertStatus(http.StatusOK) + resp.Unmarshal(&statusResponse) + assert.Equal(c, "stopped", statusResponse.Status) + }, 7*time.Minute, 2*time.Second, "sgcollect_info did not stop running in time") + + for name, stream := range outputs { + output := stream.String() + assert.NotContains(t, output, "Exception", "found in %s", name) + assert.NotContains(t, output, "WARNING", "found in %s", name) + assert.NotContains(t, output, "Error", "found in %s", name) + assert.NotContains(t, output, "Errno", "found in %s", name) + assert.NotContains(t, output, "Fail", "found in %s", name) + } +} diff --git a/rest/utilities_testing_user.go b/rest/utilities_testing_user.go index f52143dab2..e4c35a2e6b 100644 --- a/rest/utilities_testing_user.go +++ b/rest/utilities_testing_user.go @@ -76,6 +76,6 @@ func DeleteUser(t *testing.T, httpClient *http.Client, serverURL, username strin require.NoError(t, resp.Body.Close(), "Error closing response body") } -func getBasicAuthHeader(username, password string) string { +func GetBasicAuthHeader(_ testing.TB, username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } diff --git a/tools-tests/sgcollect_info_test.py b/tools-tests/sgcollect_info_test.py index 69458848e8..d12995d45a 100644 --- a/tools-tests/sgcollect_info_test.py +++ b/tools-tests/sgcollect_info_test.py @@ -7,8 +7,12 @@ # the file licenses/APL2.txt. 
import io +import os import pathlib -import unittest +import sys +import unittest.mock +import urllib.error +from typing import Optional import pytest import sgcollect @@ -30,8 +34,8 @@ def test_make_collect_logs_tasks(config, tmpdir): "sgcollect.urlopen_with_basic_auth", return_value=io.BytesIO( config.format( - tmpdir=str(tmpdir).replace("\\", "\\\\"), - log_file=str(log_file).replace("\\", "\\\\"), + tmpdir=normalize_path_for_json(tmpdir), + log_file=normalize_path_for_json(log_file), ).encode("utf-8") ), ): @@ -54,7 +58,7 @@ def test_make_collect_logs_heap_profile(tmpdir): "sgcollect.urlopen_with_basic_auth", return_value=io.BytesIO( '{{"logfilepath": "{logpath}"}}'.format( - logpath=str(tmpdir).replace("\\", "\\\\") + logpath=normalize_path_for_json(tmpdir), ).encode("utf-8") ), ): @@ -89,8 +93,8 @@ def test_make_collect_logs_tasks_duplicate_files(should_redact, tmp_path): "sgcollect.urlopen_with_basic_auth", return_value=io.BytesIO( config.format( - tmpdir1=str(tmpdir1).replace("\\", "\\\\"), - tmpdir2=str(tmpdir2).replace("\\", "\\\\"), + tmpdir1=normalize_path_for_json(tmpdir1), + tmpdir2=normalize_path_for_json(tmpdir2), ).encode("utf-8") ), ): @@ -146,3 +150,207 @@ def test_get_sgcollect_options_task(tmp_path, cmdline): ).read_text() assert "sync_gateway_password" not in output assert f"args: {args}" in output + + +def test_get_paths_from_expvars_no_url() -> None: + assert (None, None) == sgcollect.get_paths_from_expvars( + sg_url="", sg_username="", sg_password="" + ) + + +@pytest.mark.parametrize( + "expvar_output,expected_sg_path,expected_config_path", + [ + (b"", None, None), + (b"{}", None, None), + (b'{"cmdline": []}', None, None), + (b'{"cmdline": ["filename"]}', "filename", None), + ( + b'{"cmdline": ["fake_sync_gateway", "real_file.txt"]}', + "fake_sync_gateway", + None, + ), + ( + b'{"cmdline": ["fake_sync_gateway", "-json", "fake_sync_gateway_config.json"]}', + "fake_sync_gateway", + # platform difference is 
https://github.com/python/cpython/issues/82852 + "{cwd}{pathsep}fake_sync_gateway_config.json" + if sys.platform != "win32" + else "fake_sync_gateway_config.json", + ), + ( + b'{"cmdline": ["fake_sync_gateway", "-json", "{tmpdir}/real_file.json"]}', + "fake_sync_gateway", + "{tmpdir}{pathsep}real_file.json", + ), + ], +) +def test_get_paths_from_expvars( + expvar_output: bytes, + expected_sg_path: Optional[str], + expected_config_path: Optional[str], + tmpdir: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + real_file = tmpdir / "real_file.json" + real_file.write_text("This is a real file.", encoding="utf-8") + subdir = tmpdir / "subdir" + subdir.mkdir() + monkeypatch.chdir(subdir) + cwd = pathlib.Path.cwd() + expvar_output = expvar_output.replace( + b"{cwd}", normalize_path_for_json(cwd).encode("utf-8") + ) + expvar_output = expvar_output.replace( + b"{tmpdir}", normalize_path_for_json(tmpdir).encode("utf-8") + ) + + # interpolate cwd for pathlib.Path.resolve + if expected_config_path is not None: + expected_config_path = expected_config_path.format( + cwd=str(cwd), + tmpdir=str(tmpdir), + pathsep=os.sep, + ) + with unittest.mock.patch( + "sgcollect.urlopen_with_basic_auth", return_value=io.BytesIO(expvar_output) + ): + sg_path, config_path = sgcollect.get_paths_from_expvars( + sg_url="fakeurl", sg_username="", sg_password="" + ) + assert sg_path == expected_sg_path + assert config_path == expected_config_path + + +def test_discover_sg_binary_path() -> None: + parser = sgcollect.create_option_parser() + options, _ = parser.parse_args([]) + with unittest.mock.patch("os.path.exists", return_value=False): + assert ( + sgcollect.discover_sg_binary_path( + options, + sg_url="", + ) + == "" + ) + options, _ = parser.parse_args(["--sync-gateway-executable", "fake_sg"]) + with pytest.raises( + expected_exception=Exception, + match="executable passed in does not exist", + ): + sgcollect.discover_sg_binary_path(options, sg_url="") + + options, _ = 
parser.parse_args([]) + with unittest.mock.patch("os.path.exists", return_value=True): + assert ( + sgcollect.discover_sg_binary_path(options, sg_url="") + == "/opt/couchbase-sync-gateway/bin/sync_gateway" + ) + options, _ = parser.parse_args([]) + with unittest.mock.patch("os.path.exists", side_effect=[False, True]): + assert ( + sgcollect.discover_sg_binary_path(options, sg_url="") + == R"C:\Program Files (x86)\Couchbase\sync_gateway.exe" + ) # Windows (Pre-2.0) + + with unittest.mock.patch("os.path.exists", side_effect=[False, False, True]): + assert ( + sgcollect.discover_sg_binary_path(options, sg_url="") + == R"C:\Program Files\Couchbase\Sync Gateway\sync_gateway.exe" # Windows (Post-2.0) + ) + + +@pytest.mark.parametrize( + "cmdline_args,expected_calls", + [ + ( + [], + [ + unittest.mock.call( + url="http://127.0.0.1:4985", + username=None, + password=None, + ), + ], + ), + ( + ["--sync-gateway-username=myuser", "--sync-gateway-password=mypassword"], + [ + unittest.mock.call( + url="http://127.0.0.1:4985", + username="myuser", + password="mypassword", + ), + ], + ), + ( + ["--sync-gateway-url=example.com"], + [ + unittest.mock.call( + url="http://example.com", + username=None, + password=None, + ), + unittest.mock.call( + url="https://example.com", + username=None, + password=None, + ), + unittest.mock.call( + url="http://127.0.0.1:4985", + username=None, + password=None, + ), + ], + ), + ( + ["--sync-gateway-url=https://example.com:4985"], + [ + unittest.mock.call( + url="https://example.com:4985", + username=None, + password=None, + ), + unittest.mock.call( + url="http://127.0.0.1:4985", + username=None, + password=None, + ), + ], + ), + ( + ["--sync-gateway-url=http://example.com:4985"], + [ + unittest.mock.call( + url="http://example.com:4985", + username=None, + password=None, + ), + unittest.mock.call( + url="http://127.0.0.1:4985", + username=None, + password=None, + ), + ], + ), + ], +) +def test_get_sg_url( + cmdline_args: list[str], expected_calls: 
list[unittest.mock._Call] +) -> None: + parser = sgcollect.create_option_parser() + options, _ = parser.parse_args(cmdline_args) + with unittest.mock.patch( + "sgcollect.urlopen_with_basic_auth", + side_effect=urllib.error.URLError("mock error connecting"), + ) as mock_urlopen: + # this URL isn't correct but it is the fallback URL for this function + assert sgcollect.get_sg_url(options) == "https://127.0.0.1:4985" + assert mock_urlopen.mock_calls == expected_calls + + +def normalize_path_for_json(p: pathlib.Path) -> str: + """ + Convert a pathlib path to something that is OK for JSON, making all windows paths use forward slashes. + """ + return str(p).replace("\\", "\\\\") diff --git a/tools-tests/upload_test.py b/tools-tests/upload_test.py index 747f20b2bd..ce73499296 100644 --- a/tools-tests/upload_test.py +++ b/tools-tests/upload_test.py @@ -67,18 +67,20 @@ def open(self, *args, **kwargs): @pytest.mark.usefixtures("main_norun") @pytest.mark.parametrize("args", [[], ["--log-redaction-level", "none"]]) def test_main_output_exists(args): - with unittest.mock.patch("sys.argv", ["sg_collect", *args, ZIP_NAME]): - sgcollect.main() + with pytest.raises(SystemExit, check=lambda e: e.code == 0): + with unittest.mock.patch("sys.argv", ["sg_collect", *args, ZIP_NAME]): + sgcollect.main() assert pathlib.Path(ZIP_NAME).exists() assert not pathlib.Path(REDACTED_ZIP_NAME).exists() @pytest.mark.usefixtures("main_norun_redacted_zip") def test_main_output_exists_with_redacted(): - with unittest.mock.patch( - "sys.argv", ["sg_collect", "--log-redaction-level", "partial", ZIP_NAME] - ): - sgcollect.main() + with pytest.raises(SystemExit, check=lambda e: e.code == 0): + with unittest.mock.patch( + "sys.argv", ["sg_collect", "--log-redaction-level", "partial", ZIP_NAME] + ): + sgcollect.main() assert pathlib.Path(ZIP_NAME).exists() assert pathlib.Path(REDACTED_ZIP_NAME).exists() diff --git a/tools/sgcollect.py b/tools/sgcollect.py index 434ab05494..07bd5cae0a 100755 --- 
a/tools/sgcollect.py +++ b/tools/sgcollect.py @@ -13,6 +13,7 @@ # -*- python -*- import base64 import glob +import http import json import optparse import os @@ -20,14 +21,12 @@ import platform import re import ssl -import subprocess import sys import urllib.error import urllib.parse import urllib.request import uuid -from sys import platform as _platform -from typing import List, Optional +from typing import List, NoReturn, Optional import password_remover from tasks import ( @@ -307,7 +306,12 @@ def extract_element_from_logging_config(element, config): return -def urlopen_with_basic_auth(url, username, password): +def urlopen_with_basic_auth( + url: str, username: Optional[str], password: Optional[str] +) -> http.client.HTTPResponse: + """ + Open a URL with basic authentication if username and password are provided. Can raise urllib.error.URLError if there is an error. + """ if username and len(username) > 0: # Add basic auth header request = urllib.request.Request(url) @@ -513,13 +517,22 @@ def get_db_list(sg_url, sg_username, sg_password): # Server config # Each DB config def make_config_tasks( - sg_config_path: str, + sg_config_path: Optional[str], sg_url: str, sg_username: Optional[str], sg_password: Optional[str], should_redact: bool, ) -> List[PythonTask]: - collect_config_tasks = [] + """ + Return a list of tasks suitable for collecting configuration information. + + 1. sync gateway configuration file. + 2. /_config + 3. /_config?include_runtime=true + 4. Each //_config + 5. /_cluster_info + """ + collect_config_tasks: list[PythonTask] = [] # Here are the "usual suspects" to probe for finding the static config sg_config_files = [ @@ -608,31 +621,48 @@ def make_config_tasks( return collect_config_tasks -def get_config_path_from_cmdline(cmdline_args): +def get_config_path_from_cmdline(cmdline_args: list[str]) -> Optional[str]: + """ + Parse command line arguments to find the configuration file path and return an absolute path. 
May return None if the path doesn't exist. + + Example input:: + + sync_gateway -json config.json + + """ for cmdline_arg in cmdline_args: # if it has .json in the path, assume it's a config file. # ignore any config files that are URL's for now, since # they won't be handled correctly. if ".json" in cmdline_arg and "http" not in cmdline_arg: - return cmdline_arg + return str(pathlib.Path(cmdline_arg).resolve()) return None -def get_paths_from_expvars(sg_url, sg_username, sg_password): +def get_paths_from_expvars( + sg_url: str, sg_username: Optional[str], sg_password: Optional[str] +) -> tuple[Optional[str], Optional[str]]: + """ + Get the Sync Gateway binary and configuration file path from /_expvar endpoint. + """ data = None sg_binary_path = None sg_config_path = None # get content and parse into json - if sg_url: - try: - response = urlopen_with_basic_auth( - expvar_url(sg_url), sg_username, sg_password - ) - # response = urllib.request.urlopen(expvar_url(sg_url)) - data = json.load(response) - except urllib.error.URLError as e: - print("WARNING: Unable to connect to Sync Gateway: {0}".format(e)) + if not sg_url: + return None, None + + try: + response = urlopen_with_basic_auth(expvar_url(sg_url), sg_username, sg_password) + except urllib.error.URLError as e: + print("WARNING: Unable to connect to Sync Gateway: {0}".format(e)) + return None, None + try: + data = json.load(response) + except json.JSONDecodeError as e: + print(f"WARNING: Unable to deserialize expvar output: {e}") + return None, None if data is not None and "cmdline" in data: cmdline_args = data["cmdline"] @@ -640,37 +670,11 @@ def get_paths_from_expvars(sg_url, sg_username, sg_password): return (sg_binary_path, sg_config_path) sg_binary_path = cmdline_args[0] if len(cmdline_args) > 1: - try: - sg_config_path = get_absolute_path( - get_config_path_from_cmdline(cmdline_args[1:]) - ) - except Exception as e: - print( - "Exception trying to get absolute sync gateway path from expvars: {0}".format( - 
e - ) - ) - sg_config_path = get_config_path_from_cmdline(cmdline_args[1:]) + sg_config_path = get_config_path_from_cmdline(cmdline_args[1:]) return (sg_binary_path, sg_config_path) -def get_absolute_path(relative_path): - sync_gateway_cwd = "" - try: - if _platform.startswith("linux"): - sync_gateway_pid = subprocess.check_output( - ["pgrep", "sync_gateway"] - ).split()[0] - sync_gateway_cwd = subprocess.check_output( - ["readlink", "-e", "/proc/{0}/cwd".format(sync_gateway_pid)] - ).strip("\n") - except subprocess.CalledProcessError: - pass - - return os.path.join(sync_gateway_cwd, relative_path) - - def make_download_expvars_task(sg_url, sg_username, sg_password): task = make_curl_task( name="download_sg_expvars", @@ -774,7 +778,31 @@ def make_sg_tasks( return sg_tasks -def discover_sg_binary_path(options, sg_url, sg_username, sg_password): +def discover_sg_binary_path( + options: optparse.Values, + sg_url: Optional[str], +) -> str: + """ + Return the path to the sync gateway binary, returns None if the path is not found. + + 1. --sync-gateway-executable option, will through an exception if the path does not exist. + 2. If the expvars endpoint is available, it will try to get the path from there. + 3. Discover the path from a set of common locations. 
+ """ + if options.sync_gateway_executable: + if not os.path.exists(options.sync_gateway_executable): + raise Exception( + "Path to sync gateway executable passed in does not exist: {0}".format( + options.sync_gateway_executable + ) + ) + return options.sync_gateway_executable + if sg_url: + sg_binary_path, _ = get_paths_from_expvars( + sg_url, options.sync_gateway_username, options.sync_gateway_password + ) + if sg_binary_path: + return sg_binary_path sg_bin_dirs = [ "/opt/couchbase-sync-gateway/bin/sync_gateway", # Linux + OSX R"C:\Program Files (x86)\Couchbase\sync_gateway.exe", # Windows (Pre-2.0) @@ -784,26 +812,10 @@ def discover_sg_binary_path(options, sg_url, sg_username, sg_password): for sg_binary_path_candidate in sg_bin_dirs: if os.path.exists(sg_binary_path_candidate): return sg_binary_path_candidate + return "" - sg_binary_path, _ = get_paths_from_expvars(sg_url, sg_username, sg_password) - - if ( - options.sync_gateway_executable is not None - and len(options.sync_gateway_executable) > 0 - ): - if not os.path.exists(options.sync_gateway_executable): - raise Exception( - "Path to sync gateway executable passed in does not exist: {0}".format( - options.sync_gateway_executable - ) - ) - return sg_binary_path - - # fallback to whatever was specified in options - return options.sync_gateway_executable - -def main(): +def main() -> NoReturn: # ask all tools to use C locale (MB-12050) os.environ["LANG"] = "C" os.environ["LC_ALL"] = "C" @@ -826,43 +838,7 @@ def main(): if options.watch_stdin: setup_stdin_watcher() - sg_url = options.sync_gateway_url - sg_username = options.sync_gateway_username - sg_password = options.sync_gateway_password - - if not sg_url or "://" not in sg_url: - if not sg_url: - root_url = "127.0.0.1:4985" - else: - root_url = sg_url - sg_url_http = "http://" + root_url - print("Trying Sync Gateway URL: {0}".format(sg_url_http)) - - # Set sg_url to sg_url_http at this point - # If we're unable to determine which URL to use this is our 
best - # attempt. Avoids having this is 'None' later - sg_url = sg_url_http - - try: - response = urlopen_with_basic_auth(sg_url_http, sg_username, sg_password) - json.load(response) - except Exception as e: - print("Failed to communicate with: {} {}".format(sg_url_http, e)) - sg_url_https = "https://" + root_url - print("Trying Sync Gateway URL: {0}".format(sg_url_https)) - try: - response = urlopen_with_basic_auth( - sg_url_https, sg_username, sg_password - ) - json.load(response) - except Exception as e: - print( - "Failed to communicate with Sync Gateway using url {}. " - "Check that Sync Gateway is running and reachable. " - "Will attempt to continue anyway.".format(e) - ) - else: - sg_url = sg_url_https + sg_url = get_sg_url(options) # Build path to zip directory, make sure it exists zip_filename = args[0] @@ -933,13 +909,13 @@ def main(): log("Python version: %s" % sys.version) # Find path to sg binary - sg_binary_path = discover_sg_binary_path(options, sg_url, sg_username, sg_password) + sg_binary_path = discover_sg_binary_path(options, sg_url) # Run SG specific tasks for task in make_sg_tasks( sg_url, - sg_username, - sg_password, + options.sync_gateway_username, + options.sync_gateway_password, options.sync_gateway_config, options.sync_gateway_executable, should_redact, @@ -977,7 +953,7 @@ def main(): print("Zipfile built: {0}".format(zip_filename)) if not upload_url: - return + sys.exit(0) # Upload the zip to the URL to S3 if required try: if should_redact: @@ -1023,3 +999,51 @@ def should_redact_from_options(options: optparse.Values) -> bool: Returns True if the redaction level is set to 'partial' or 'full'. """ return options.redact_level != "none" + + +def get_sg_url(options: optparse.Values) -> str: + """ + Gets the Sync Gateway URL for the admin port. Returns the first valid URL it can connect to, or https://127.0.0.1:4985 if it can not find one. + + 1. --sync-gateway-url option, if provided. + a. If the URL contains "://", it is used as is. + b. 
If the URL does not contain "://", it will try to connect to http:// and https:// versions of the URL. + 2. http://127.0.0.1:4985 + 3. https://127.0.0.1:4985 + """ + possible_urls: List[str] = [] + if options.sync_gateway_url: + if "://" in options.sync_gateway_url: + possible_urls.append(options.sync_gateway_url) + else: + possible_urls.extend( + [ + "http://" + options.sync_gateway_url, + "https://" + options.sync_gateway_url, + ] + ) + possible_urls.extend(["http://127.0.0.1:4985"]) + for url in possible_urls: + if can_connect_to_sg_url( + url, options.sync_gateway_username, options.sync_gateway_password + ): + return url + # Default URL if none of the above worked. The more correct way would be to return None if this doesn't work and do + # not try to create any tasks that require connecting to admin API, but the subsequent code expects a URL to be + # returned. + return "https://127.0.0.1:4985" + + +def can_connect_to_sg_url(sg_url: str, sg_username: str, sg_password: str) -> bool: + """ + Return true if can connect to the Sync Gateway URL with the provided username and password. + """ + try: + response = urlopen_with_basic_auth( + url=sg_url, username=sg_username, password=sg_password + ) + json.load(response) + except Exception as e: + print(f"Failed to communicate with: {sg_url} {e}") + return False + return True