diff --git a/driver.go b/driver.go index 68d78e8..7637b50 100644 --- a/driver.go +++ b/driver.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "math/rand" "net" "os" "strings" @@ -29,8 +30,8 @@ type Driver struct { cachedImage *hcloud.Image Type string cachedType *hcloud.ServerType - Location string - cachedLocation *hcloud.Location + Locations []string + cachedLocations []*hcloud.Location KeyID int cachedKey *hcloud.SSHKey IsExistingKey bool @@ -63,7 +64,7 @@ const ( flagImage = "hetzner-image" flagImageID = "hetzner-image-id" flagType = "hetzner-server-type" - flagLocation = "hetzner-server-location" + flagLocations = "hetzner-server-locations" flagExKeyID = "hetzner-existing-key-id" flagExKeyPath = "hetzner-existing-key-path" flagUserData = "hetzner-user-data" @@ -133,11 +134,11 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { Usage: "Server type to create", Value: defaultType, }, - mcnflag.StringFlag{ - EnvVar: "HETZNER_LOCATION", - Name: flagLocation, + mcnflag.StringSliceFlag{ + EnvVar: "HETZNER_LOCATIONS", + Name: flagLocations, Usage: "Location to create machine at", - Value: "", + Value: []string{}, }, mcnflag.IntFlag{ EnvVar: "HETZNER_EXISTING_KEY_ID", @@ -249,7 +250,7 @@ func (d *Driver) setConfigFromFlagsImpl(opts drivers.DriverOptions) error { d.AccessToken = opts.String(flagAPIToken) d.Image = opts.String(flagImage) d.ImageID = opts.Int(flagImageID) - d.Location = opts.String(flagLocation) + d.Locations = opts.StringSlice(flagLocations) d.Type = opts.String(flagType) d.KeyID = opts.Int(flagExKeyID) d.IsExistingKey = d.KeyID != 0 @@ -357,6 +358,7 @@ func (d *Driver) PreCreateCheck() error { key.Fingerprint != ssh.FingerprintSHA256(pubk) { return errors.Errorf("remote key %d does not match local key %s", d.KeyID, d.originalKey) } + fmt.Println("hello") } if _, err := d.getType(); err != nil { @@ -880,17 +882,42 @@ func (d *Driver) copySSHKeyPair(src string) error { return nil } -func (d *Driver) getLocation() (*hcloud.Location, error) { - if d.cachedLocation != nil { - return d.cachedLocation, nil +func (d *Driver) getLocations() ([]*hcloud.Location, error) { + if len(d.cachedLocations) > 0 { + return d.cachedLocations, nil + } + + locations := []*hcloud.Location{} + + for _, locationName := range d.Locations { + location, _, err := d.getClient().Location.GetByName(context.Background(), locationName) + + if err != nil { + return []*hcloud.Location{location}, errors.Wrap(err, "could not get location by name") + } + + locations = append(locations, location) } - location, _, err := d.getClient().Location.GetByName(context.Background(), d.Location) + d.cachedLocations = locations + return locations, nil +} + +func (d *Driver) getRandomLocation() (*hcloud.Location, error) { + locations, err := d.getLocations() if err != nil { - return location, errors.Wrap(err, "could not get location by name") + return nil, err } - d.cachedLocation = location - return location, nil + + s := rand.NewSource(time.Now().Unix()) + r := rand.New(s) + location := r.Intn(len(locations)) + + return locations[location], nil +} + +func (d *Driver) getLocation() (*hcloud.Location, error) { + return d.getRandomLocation() } func (d *Driver) getType() (*hcloud.ServerType, error) {