Skip to content
25 changes: 20 additions & 5 deletions test/cmd/txsim/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,10 @@ account that can act as the master account. The command runs until all sequences
}
}

if os.Getenv(TxsimPoll) != "" && pollTime != user.DefaultPollTime {
pollTime, err = time.ParseDuration(os.Getenv(TxsimPoll))
if err != nil {
return fmt.Errorf("parsing poll time: %w", err)
}
// set pollTime: flag has priority, then env var, then default
pollTime, err = getPollTime(pollTime, os.Getenv(TxsimPoll), user.DefaultPollTime)
if err != nil {
return fmt.Errorf("parsing poll time: %w", err)
}

opts := txsim.DefaultOptions().
Expand Down Expand Up @@ -296,3 +295,19 @@ func parseUpgradeSchedule(schedule string) (map[int64]uint64, error) {
}
return scheduleMap, nil
}

// getPollTime returns the correct pollTime value based on CLI flag, environment variable, and default.
// Priority: flag > env > default.
func getPollTime(flagValue time.Duration, envValue string, defaultValue time.Duration) (time.Duration, error) {
if flagValue != defaultValue {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be simplified using cmd.Flags().Changed() which would return true if the flag has been changed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be simplified using cmd.Flags().Changed() which would return true if the flag has been changed

yes

return flagValue, nil
}
if envValue != "" {
val, err := time.ParseDuration(envValue)
if err != nil {
return 0, err
}
return val, nil
}
return defaultValue, nil
}
38 changes: 38 additions & 0 deletions test/cmd/txsim/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,44 @@ func TestTxsimDefaultKeypath(t *testing.T) {
require.NoError(t, err)
}

func TestGetPollTime(t *testing.T) {
defaultPoll := 10 * time.Second
flagPoll := 30 * time.Second

// 1. Only default (flag == default, env empty)
poll, err := getPollTime(defaultPoll, "", defaultPoll)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if poll != defaultPoll {
t.Errorf("expected default poll time, got %v", poll)
}

// 2. Only env (flag == default, env set)
poll, err = getPollTime(defaultPoll, "20s", defaultPoll)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if poll != 20*time.Second {
t.Errorf("expected env poll time, got %v", poll)
}

// 3. Flag set (flag != default, env set)
poll, err = getPollTime(flagPoll, "20s", defaultPoll)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if poll != flagPoll {
t.Errorf("expected flag poll time, got %v", poll)
}

// 4. Invalid env value
_, err = getPollTime(defaultPoll, "notaduration", defaultPoll)
if err == nil {
t.Error("expected error for invalid env value, got nil")
}
}

func setup(t testing.TB) (keyring.Keyring, string, string) {
if testing.Short() {
t.Skip("skipping tx sim in short mode.")
Expand Down