From f80e105a2407f1251abfc456703db6f96f960e37 Mon Sep 17 00:00:00 2001 From: Gargi Panatula Date: Wed, 25 Jun 2025 00:28:20 +0000 Subject: [PATCH] updated verification to aws sdk go v2 --- cmd/aws-iam-authenticator/root.go | 13 +-- cmd/aws-iam-authenticator/server.go | 12 +- cmd/aws-iam-authenticator/verify.go | 37 ++++-- go.mod | 2 +- go.sum | 5 +- hack/dev/describe-regions-policy.json | 10 ++ hack/e2e/aws.sh | 7 ++ hack/e2e/run.sh | 11 ++ hack/lib/dev-env.sh | 38 +++++++ hack/stop-dev-env.sh | 11 ++ pkg/arn/arn.go | 20 +++- pkg/ec2provider/ec2provider.go | 1 + pkg/ec2provider/ec2provider_mock.go | 67 +++++++++++ pkg/ec2provider/ec2provider_test.go | 54 +-------- pkg/server/server.go | 5 +- pkg/token/token.go | 101 ++++++++--------- pkg/token/token_test.go | 130 ++++++++++------------ tests/integration/testutils/testserver.go | 13 +-- 18 files changed, 312 insertions(+), 225 deletions(-) create mode 100644 hack/dev/describe-regions-policy.json create mode 100644 pkg/ec2provider/ec2provider_mock.go diff --git a/cmd/aws-iam-authenticator/root.go b/cmd/aws-iam-authenticator/root.go index d61586ded..448370975 100644 --- a/cmd/aws-iam-authenticator/root.go +++ b/cmd/aws-iam-authenticator/root.go @@ -20,12 +20,13 @@ import ( "errors" "fmt" "os" + "slices" "strings" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/mapper" - "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -157,14 +158,8 @@ func getConfig() (config.Config, error) { return cfg, errors.New("cluster ID cannot be empty") } - partitionKeys := []string{} - partitionMap := map[string]endpoints.Partition{} - for _, p := range endpoints.DefaultPartitions() { - partitionMap[p.ID()] = p - partitionKeys = append(partitionKeys, p.ID()) - } - if _, ok := partitionMap[cfg.PartitionID]; !ok { - return cfg, errors.New("Invalid partition") + if !slices.Contains(arn.PartitionKeys, cfg.PartitionID) { + return cfg, errors.New("Invalid partition when getting config") } // DynamicFile BackendMode and DynamicFilePath are mutually inclusive. diff --git a/cmd/aws-iam-authenticator/server.go b/cmd/aws-iam-authenticator/server.go index 15a98c1bb..16ea3540a 100644 --- a/cmd/aws-iam-authenticator/server.go +++ b/cmd/aws-iam-authenticator/server.go @@ -25,11 +25,11 @@ import ( "k8s.io/sample-controller/pkg/signals" "sigs.k8s.io/aws-iam-authenticator/pkg" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" "sigs.k8s.io/aws-iam-authenticator/pkg/mapper" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" "sigs.k8s.io/aws-iam-authenticator/pkg/server" - "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -67,14 +67,8 @@ var serverCmd = &cobra.Command{ } func init() { - partitionKeys := []string{} - for _, p := range endpoints.DefaultPartitions() { - partitionKeys = append(partitionKeys, p.ID()) - } - - serverCmd.Flags().String("partition", - endpoints.AwsPartitionID, - fmt.Sprintf("The AWS partition. Must be one of: %v", partitionKeys)) + serverCmd.Flags().String("partition", "aws", + fmt.Sprintf("The AWS partition. Must be one of: %v", arn.PartitionKeys)) viper.BindPFlag("server.partition", serverCmd.Flags().Lookup("partition")) serverCmd.Flags().String("generate-kubeconfig", diff --git a/cmd/aws-iam-authenticator/verify.go b/cmd/aws-iam-authenticator/verify.go index 9bb37f4f4..f0cbcf7cb 100644 --- a/cmd/aws-iam-authenticator/verify.go +++ b/cmd/aws-iam-authenticator/verify.go @@ -19,15 +19,17 @@ limitations under the License. package main import ( + "context" "encoding/json" "fmt" "os" "sigs.k8s.io/aws-iam-authenticator/pkg/token" - "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -54,14 +56,17 @@ var verifyCmd = &cobra.Command{ os.Exit(1) } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + ctx := context.Background() + instanceRegion := getInstanceRegion(ctx) + + cfg, err := config.LoadDefaultConfig(ctx) if err != nil { - fmt.Printf("[Warn] Region not found in instance metadata, err: %v", err) + fmt.Fprintf(os.Stderr, "unable to create sdk client configuration: %v\n", err) + os.Exit(1) } + ec2Client := ec2.NewFromConfig(cfg) - id, err := token.NewVerifier(clusterID, partition, instanceRegion).Verify(tok) + id, err := token.NewVerifier(ctx, clusterID, partition, instanceRegion, ec2Client).Verify(tok) if err != nil { fmt.Fprintf(os.Stderr, "could not verify token: %v\n", err) os.Exit(1) @@ -79,6 +84,24 @@ var verifyCmd = &cobra.Command{ }, } +// Uses EC2 metadata to get the region. Returns "" if no region found. +func getInstanceRegion(ctx context.Context) string { + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "[Warn] Unable to create config for metadata client, err: %v", err) + panic(err) + } + + imdsClient := imds.NewFromConfig(cfg) + getRegionOutput, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + fmt.Fprintf(os.Stderr, "[Warn] Region not found in instance metadata, err: %v\n", err) + return "" + } + + return getRegionOutput.Region +} + func init() { rootCmd.AddCommand(verifyCmd) verifyCmd.Flags().StringP("token", "t", "", "Token to verify") diff --git a/go.mod b/go.mod index 23d2dc398..43cafa68f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.29.17 github.com/aws/aws-sdk-go-v2/credentials v1.17.70 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32 + github.com/aws/aws-sdk-go-v2/service/account v1.24.2 github.com/aws/aws-sdk-go-v2/service/ec2 v1.225.2 github.com/aws/aws-sdk-go-v2/service/sts v1.34.0 github.com/aws/smithy-go v1.22.4 @@ -55,7 +56,6 @@ require ( github.com/google/gnostic-models v0.6.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.9.0 // indirect diff --git a/go.sum b/go.sum index 12b6d13b1..284d6666d 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 h1:i2vNHQiXUvKhs3quBR github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36/go.mod h1:UdyGa7Q91id/sdyHPwth+043HhmP6yP9MBHgbZM0xo8= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/account v1.24.2 h1:1ItkqDExKIDsS8NoIBq7OxQOJnQNOVjC25CYa9RzOos= +github.com/aws/aws-sdk-go-v2/service/account v1.24.2/go.mod h1:NShtay87juyMTb3c6bHN6Bai5dUFmTX7NzURY4/Jyb0= github.com/aws/aws-sdk-go-v2/service/ec2 v1.225.2 h1:IfMb3Ar8xEaWjgH/zeVHYD8izwJdQgRP5mKCTDt4GNk= github.com/aws/aws-sdk-go-v2/service/ec2 v1.225.2/go.mod h1:35jGWx7ECvCwTsApqicFYzZ7JFEnBc6oHUuOQ3xIS54= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 h1:CXV68E2dNqhuynZJPB80bhPQwAKqBWVer887figW6Jc= @@ -86,8 +88,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -226,7 +226,6 @@ gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSP gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hack/dev/describe-regions-policy.json b/hack/dev/describe-regions-policy.json new file mode 100644 index 000000000..6b70a0280 --- /dev/null +++ b/hack/dev/describe-regions-policy.json @@ -0,0 +1,10 @@ +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "ec2:DescribeRegions", + "Resource": "*" + } + ] +} diff --git a/hack/e2e/aws.sh b/hack/e2e/aws.sh index 2c2dc7da1..18ffaf8a1 100644 --- a/hack/e2e/aws.sh +++ b/hack/e2e/aws.sh @@ -28,6 +28,13 @@ function create_role() { --assume-role-policy-document "$POLICY" \ --output text \ --query 'Role.Arn') + + ## attach describe-regions policy to the role + aws iam put-role-policy \ + --region "${REGION}" \ + --role-name "$ROLE_NAME" \ + --policy-name "DescribeRegionsPolicy" \ + --policy-document "file://${BASE_DIR}/../dev/describe-regions-policy.json" else set -e loudecho "${ROLE_NAME} role already exists" >&2 diff --git a/hack/e2e/run.sh b/hack/e2e/run.sh index fe64eecb5..3a77a0fd5 100755 --- a/hack/e2e/run.sh +++ b/hack/e2e/run.sh @@ -258,7 +258,18 @@ if [[ "${CLEAN}" == true ]]; then "${CLUSTER_NAME}" \ "${KOPS_STATE_FILE}" + aws iam list-role-policies --role-name ${ADMIN_ROLE_NAME} --query "PolicyNames[]" --output text | + while read policy_name; do + echo "Deleting inline policy: $policy_name" + aws iam delete-role-policy --role-name ${ADMIN_ROLE_NAME} --policy-name "$policy_name" + done aws iam delete-role --role-name "${ADMIN_ROLE_NAME}" --region ${REGION} + + aws iam list-role-policies --role-name ${USER_ROLE_NAME} --query "PolicyNames[]" --output text | + while read policy_name; do + echo "Deleting inline policy: $policy_name" + aws iam delete-role-policy --role-name ${USER_ROLE_NAME} --policy-name "$policy_name" + done aws iam delete-role --role-name "${USER_ROLE_NAME}" --region ${REGION} else loudecho "Not cleaning" diff --git a/hack/lib/dev-env.sh b/hack/lib/dev-env.sh index 0394732a3..bd2bf6826 100644 --- a/hack/lib/dev-env.sh +++ b/hack/lib/dev-env.sh @@ -71,7 +71,11 @@ authenticator_backend_mode_dest_file="${authenticator_dynamicfile_dest_path}/bac authenticator_config_dest_dir="/etc/authenticator" authenticator_export_dest_dir="/var/authenticator/export" authenticator_state_dest_dir="/var/authenticator/state" +policies_template="${REPO_ROOT}/hack/dev/policies.template" +policies_json="${OUTPUT}/dev/authenticator/policies.json" apiserver_config_dest_dir="/etc/kubernetes/authenticator" +describe_regions_policy_json="${REPO_ROOT}/hack/dev/describe-regions-policy.json" + # Kubeconfig used when authenticator loads its mapping configuration from the API server authenticator_kubeconfig="${authenticator_config_dest_dir}/authenticator-kubeconfig.yaml" # Kubeconfig passed to the apiserver so it can kind its authentication webhook @@ -86,6 +90,10 @@ kubectl_kubeconfig="${client_dir}/kubeconfig.yaml" # Admin kubeconfig generated by kind kind_kubeconfig="${client_dir}/kind-kubeconfig.yaml" +AWS_ACCOUNT=$(aws sts get-caller-identity --query "Account" --output text) +DESCRIBEREGIONS_ROLE_NAME="authenticator-describeregions-role" +DESCRIBEREGIONS_POLICY_NAME="DescribeRegionsPolicy" + function install_kind() { if ! [[ -f "${KIND_BIN}" ]]; then if [[ "$OSTYPE" == "darwin"* ]]; then @@ -190,6 +198,33 @@ function start_authenticator_with_dynamicfile() { chmod -R 777 "${authenticator_dynamicfile_host_path}" chmod 777 "${authenticator_access_entry_host_file}" + # Create a role that can call ec2:DescribeRegions to run the tests + if ! RoleOutput=$(aws iam get-role --role-name "${DESCRIBEREGIONS_ROLE_NAME}" 2>&1); then + sed -e "s|{{AWS_ACCOUNT}}|${AWS_ACCOUNT}|g" \ + "${policies_template}" > "${policies_json}" + sleep 2 + aws iam create-role --role-name ${DESCRIBEREGIONS_ROLE_NAME} --assume-role-policy-document file://${policies_json} 1>/dev/null + echo "Waiting for IAM propagation of ${DESCRIBEREGIONS_ROLE_NAME}..." + sleep 10 + + aws iam put-role-policy \ + --role-name $DESCRIBEREGIONS_ROLE_NAME \ + --policy-name $DESCRIBEREGIONS_POLICY_NAME \ + --policy-document file://$describe_regions_policy_json + sleep 2 + fi + + # Assume the role and get its credentials + DESCRIBEREGIONS_ROLE_ARN="arn:aws:iam::$(aws sts get-caller-identity --query Account --output text):role/${DESCRIBEREGIONS_ROLE_NAME}" + ASSUME_OUTPUT=$(aws sts assume-role \ + --role-arn "$DESCRIBEREGIONS_ROLE_ARN" \ + --role-session-name "DescribeRegionsSession") + export AWS_ACCESS_KEY_ID=$(echo $ASSUME_OUTPUT | jq -r .Credentials.AccessKeyId) + export AWS_SECRET_ACCESS_KEY=$(echo $ASSUME_OUTPUT | jq -r .Credentials.SecretAccessKey) + export AWS_SESSION_TOKEN=$(echo $ASSUME_OUTPUT | jq -r .Credentials.SessionToken) + + echo "Successfully assumed role: $DESCRIBEREGIONS_ROLE_NAME" + docker run \ --detach \ --ip "${AUTHENTICATOR_IP}" \ @@ -202,6 +237,9 @@ function start_authenticator_with_dynamicfile() { --publish ${authenticator_healthz_port}:${authenticator_healthz_port} \ --publish ${AUTHENTICATOR_PORT}:${AUTHENTICATOR_PORT} \ --env AWS_REGION="us-west-2" \ + --env AWS_ACCESS_KEY_ID \ + --env AWS_SECRET_ACCESS_KEY \ + --env AWS_SESSION_TOKEN \ "${AUTHENTICATOR_IMAGE}" \ server \ --config "${authenticator_config_dest_dir}/authenticator_dynamicfile_mode.yaml" diff --git a/hack/stop-dev-env.sh b/hack/stop-dev-env.sh index e180a4272..a50f82e8b 100755 --- a/hack/stop-dev-env.sh +++ b/hack/stop-dev-env.sh @@ -34,6 +34,7 @@ set -o nounset # between them is over localhost and fixed port. REPO_ROOT="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )/.." &> /dev/null && pwd )" +DESCRIBEREGIONS_ROLE_NAME="authenticator-describeregions-role" source "${REPO_ROOT}/hack/lib/dev-env.sh" @@ -47,3 +48,13 @@ sleep 5 # Tear down network delete_network + +# Delete role used to run tests +# List inline policies +aws iam list-role-policies --role-name ${DESCRIBEREGIONS_ROLE_NAME} --query "PolicyNames[]" --output text | +while read policy_name; do + echo "Deleting inline policy: $policy_name" + aws iam delete-role-policy --role-name ${DESCRIBEREGIONS_ROLE_NAME} --policy-name "$policy_name" +done + +aws iam delete-role --role-name ${DESCRIBEREGIONS_ROLE_NAME} diff --git a/pkg/arn/arn.go b/pkg/arn/arn.go index 22900c96d..a7dbbe1f7 100644 --- a/pkg/arn/arn.go +++ b/pkg/arn/arn.go @@ -2,10 +2,10 @@ package arn import ( "fmt" + "slices" "strings" awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go/aws/endpoints" ) type PrincipalType int @@ -20,6 +20,16 @@ const ( ASSUMED_ROLE ) +var PartitionKeys = []string{ + "aws", + "aws-cn", + "aws-us-gov", + "aws-iso", + "aws-iso-b", + "aws-iso-e", + "aws-iso-f", +} + // Canonicalize validates IAM resources are appropriate for the authenticator // and converts STS assumed roles into the IAM role resource. // @@ -101,10 +111,8 @@ func StripPath(arn string) (string, error) { } func checkPartition(partition string) error { - for _, p := range endpoints.DefaultPartitions() { - if partition == p.ID() { - return nil - } + if !slices.Contains(PartitionKeys, partition) { + return fmt.Errorf("partition %s is not recognized", partition) } - return fmt.Errorf("partition %s is not recognized", partition) + return nil } diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index 4d697d06e..abbbd66d9 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -46,6 +46,7 @@ const ( // EC2API defines the interface for EC2 client operations type EC2API interface { DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeRegions(ctx context.Context, params *ec2.DescribeRegionsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRegionsOutput, error) } // Get a node name from instance ID diff --git a/pkg/ec2provider/ec2provider_mock.go b/pkg/ec2provider/ec2provider_mock.go new file mode 100644 index 000000000..38b8e4433 --- /dev/null +++ b/pkg/ec2provider/ec2provider_mock.go @@ -0,0 +1,67 @@ +package ec2provider + +import ( + "context" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +type MockEc2Client struct { + Reservations []*ec2types.Reservation + Regions []ec2types.Region +} + +const ( + DescribeDelay = 100 +) + +func newMockedEC2ProviderImpl() *ec2ProviderImpl { + dnsCache := ec2PrivateDNSCache{ + cache: make(map[string]string), + lock: sync.RWMutex{}, + } + ec2Requests := ec2Requests{ + set: make(map[string]bool), + lock: sync.RWMutex{}, + } + return &ec2ProviderImpl{ + ec2: &MockEc2Client{}, + privateDNSCache: dnsCache, + ec2Requests: ec2Requests, + instanceIdsChannel: make(chan string, maxChannelSize), + } + +} + +func (c *MockEc2Client) DescribeInstances(ctx context.Context, in *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + // simulate the time it takes for aws to return + time.Sleep(DescribeDelay * time.Millisecond) + var reservations []ec2types.Reservation + for _, res := range c.Reservations { + var reservation ec2types.Reservation + for _, inst := range res.Instances { + for _, id := range in.InstanceIds { + if id == aws.ToString(inst.InstanceId) { + reservation.Instances = append(reservation.Instances, inst) + } + } + } + if len(reservation.Instances) > 0 { + reservations = append(reservations, reservation) + } + } + return &ec2.DescribeInstancesOutput{ + Reservations: reservations, + }, nil +} + +func (c *MockEc2Client) DescribeRegions(ctx context.Context, params *ec2.DescribeRegionsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRegionsOutput, error) { + if c.Regions == nil { + return &ec2.DescribeRegionsOutput{}, nil + } + return &ec2.DescribeRegionsOutput{Regions: c.Regions}, nil +} diff --git a/pkg/ec2provider/ec2provider_test.go b/pkg/ec2provider/ec2provider_test.go index a27f78c94..deaf817b8 100644 --- a/pkg/ec2provider/ec2provider_test.go +++ b/pkg/ec2provider/ec2provider_test.go @@ -10,66 +10,16 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/prometheus/client_golang/prometheus" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) -const ( - DescribeDelay = 100 -) - -type mockEc2Client struct { - EC2API - Reservations []*ec2types.Reservation -} - -func (c *mockEc2Client) DescribeInstances(ctx context.Context, in *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { - // simulate the time it takes for aws to return - time.Sleep(DescribeDelay * time.Millisecond) - var reservations []ec2types.Reservation - for _, res := range c.Reservations { - var reservation ec2types.Reservation - for _, inst := range res.Instances { - for _, id := range in.InstanceIds { - if id == aws.ToString(inst.InstanceId) { - reservation.Instances = append(reservation.Instances, inst) - } - } - } - if len(reservation.Instances) > 0 { - reservations = append(reservations, reservation) - } - } - return &ec2.DescribeInstancesOutput{ - Reservations: reservations, - }, nil -} - -func newMockedEC2ProviderImpl() *ec2ProviderImpl { - dnsCache := ec2PrivateDNSCache{ - cache: make(map[string]string), - lock: sync.RWMutex{}, - } - ec2Requests := ec2Requests{ - set: make(map[string]bool), - lock: sync.RWMutex{}, - } - return &ec2ProviderImpl{ - ec2: &mockEc2Client{}, - privateDNSCache: dnsCache, - ec2Requests: ec2Requests, - instanceIdsChannel: make(chan string, maxChannelSize), - } - -} - func TestGetPrivateDNSName(t *testing.T) { metrics.InitMetrics(prometheus.NewRegistry()) ec2Provider := newMockedEC2ProviderImpl() - ec2Provider.ec2 = &mockEc2Client{Reservations: prepareSingleInstanceOutput()} + ec2Provider.ec2 = &MockEc2Client{Reservations: prepareSingleInstanceOutput()} go ec2Provider.StartEc2DescribeBatchProcessing(context.TODO()) dns_name, err := ec2Provider.GetPrivateDNSName(context.TODO(), "ec2-1") if err != nil { @@ -102,7 +52,7 @@ func TestGetPrivateDNSNameWithBatching(t *testing.T) { metrics.InitMetrics(prometheus.NewRegistry()) ec2Provider := newMockedEC2ProviderImpl() reservations := prepare100InstanceOutput() - ec2Provider.ec2 = &mockEc2Client{Reservations: reservations} + ec2Provider.ec2 = &MockEc2Client{Reservations: reservations} go ec2Provider.StartEc2DescribeBatchProcessing(context.TODO()) var wg sync.WaitGroup for i := 1; i < 101; i++ { diff --git a/pkg/server/server.go b/pkg/server/server.go index f5d8a29cc..0aa7beeee 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -30,6 +30,7 @@ import ( awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/errutil" @@ -214,10 +215,12 @@ func (c *Server) getHandler(ctx context.Context, backendMapper BackendMapper, ec cfg.Region = "us-east-1" } else { instanceRegion = instanceRegionOutput.Region + cfg.Region = instanceRegion } + ec2Client := ec2.NewFromConfig(cfg) h := &handler{ - verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion), + verifier: token.NewVerifier(ctx, c.ClusterID, c.PartitionID, instanceRegion, ec2Client), ec2Provider: ec2provider.New(ctx, c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), clusterID: c.ClusterID, backendMapper: backendMapper, diff --git a/pkg/token/token.go b/pkg/token/token.go index 821084c31..e83294dfd 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -35,8 +35,8 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go/aws/endpoints" smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/prometheus/client_golang/prometheus" @@ -44,8 +44,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" + "k8s.io/utils/strings/slices" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -420,77 +422,66 @@ type tokenVerifier struct { validSTShostnames map[string]bool } -func getDefaultHostNameForRegion(partition *endpoints.Partition, region, service string) (string, error) { - rep, err := partition.EndpointFor(service, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption) - if err != nil { - return "", fmt.Errorf("Error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err) - } - parsedURL, err := url.Parse(rep.URL) - if err != nil { - return "", fmt.Errorf("Error parsing STS URL %s. err: %v", rep.URL, err) - } - return parsedURL.Hostname(), nil -} - -func stsHostsForPartition(partitionID, region string) map[string]bool { +func stsHostsForPartition(ctx context.Context, partitionID, region string, ec2Client ec2provider.EC2API) map[string]bool { validSTShostnames := map[string]bool{} - var partition *endpoints.Partition - for _, p := range endpoints.DefaultPartitions() { - if partitionID == p.ID() { - partition = &p - break - } - } - if partition == nil { - logrus.Errorf("Partition %s not valid", partitionID) + var tlds []string + serviceNames := []string{"sts", "sts-fips"} + switch partitionID { + case "aws": + tlds = []string{"amazonaws.com", "api.aws"} + validSTShostnames["sts.amazonaws.com"] = true + case "aws-us-gov": + tlds = []string{"amazonaws.com", "api.aws"} + case "aws-cn": + serviceNames = []string{"sts"} + tlds = []string{"amazonaws.com.cn"} + case "aws-iso": + tlds = []string{"c2s.ic.gov"} + case "aws-iso-b": + tlds = []string{"sc2s.sgov.gov"} + case "aws-iso-e": + tlds = []string{"cloud.adc-e.uk"} + case "aws-iso-f": + tlds = []string{"csp.hci.ic.gov"} + default: + logrus.Errorf("unrecognized partition %s", partitionID) return validSTShostnames } - stsSvc, ok := partition.Services()[stsServiceID] - if !ok { - logrus.Errorf("STS service not found in partition %s", partitionID) - // Add the host of the current instances region if the service doesn't already exists in the partition - // so we don't fail if the service is not present in the go sdk but matches the instances region. - stsHostName, err := getDefaultHostNameForRegion(partition, region, stsServiceID) - if err != nil { - logrus.WithError(err).Error("Error getting default hostname") - } else { - validSTShostnames[stsHostName] = true - } + // Get a list of regions available to this account + var regions []string + regionsOutput, err := ec2Client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{ + AllRegions: aws.Bool(true), + }) + if err != nil { + logrus.Errorf("failed to get regions: %v", err) return validSTShostnames } - stsSvcEndPoints := stsSvc.Endpoints() - for epName, ep := range stsSvcEndPoints { - rep, err := ep.ResolveEndpoint(endpoints.STSRegionalEndpointOption) - if err != nil { - logrus.WithError(err).Errorf("Error resolving endpoint for %s in partition %s", epName, partitionID) - continue - } - parsedURL, err := url.Parse(rep.URL) - if err != nil { - logrus.WithError(err).Errorf("Error parsing STS URL %s", rep.URL) - continue - } - validSTShostnames[parsedURL.Hostname()] = true + for _, regionInfo := range regionsOutput.Regions { + regions = append(regions, *regionInfo.RegionName) } // Add the host of the current instances region if not already exists so we don't fail if the region is not // present in the go sdk but matches the instances region. - if _, ok := stsSvcEndPoints[region]; !ok { - stsHostName, err := getDefaultHostNameForRegion(partition, region, stsServiceID) - if err != nil { - logrus.WithError(err).Error("Error getting default hostname") - return validSTShostnames + if !slices.Contains(regions, region) { + regions = append(regions, region) + } + + for _, regionInfo := range regionsOutput.Regions { + for _, serviceName := range serviceNames { + for _, tld := range tlds { + hostname := fmt.Sprintf("%s.%s.%s", serviceName, *regionInfo.RegionName, tld) + validSTShostnames[hostname] = true + } } - validSTShostnames[stsHostName] = true } return validSTShostnames } // NewVerifier creates a Verifier that is bound to the clusterID and uses the default http client. -func NewVerifier(clusterID, partitionID, region string) Verifier { +func NewVerifier(ctx context.Context, clusterID, partitionID, region string, ec2Client ec2provider.EC2API) Verifier { // Initialize metrics if they haven't already been initialized to avoid a // nil pointer panic when setting metric values. if !metrics.Initialized() { @@ -505,7 +496,7 @@ func NewVerifier(clusterID, partitionID, region string) Verifier { Timeout: 10 * time.Second, }, clusterID: clusterID, - validSTShostnames: stsHostsForPartition(partitionID, region), + validSTShostnames: stsHostsForPartition(ctx, partitionID, region, ec2Client), } } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index bf5deb7e8..eb7118f88 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -18,14 +18,15 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" clientauthv1 "k8s.io/client-go/pkg/apis/clientauthentication/v1" clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" + "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -37,7 +38,19 @@ func TestMain(m *testing.M) { func validationErrorTest(t *testing.T, partition string, token string, expectedErr string) { t.Helper() - _, err := NewVerifier("", partition, "").(tokenVerifier).Verify(token) + regions := []ec2types.Region{ + {RegionName: aws.String("sa-east-1")}, + {RegionName: aws.String("us-east-2")}, + {RegionName: aws.String("us-west-2")}, + {RegionName: aws.String("ap-northeast-2")}, + {RegionName: aws.String("ca-central-1")}, + {RegionName: aws.String("eu-west-1")}, + {RegionName: aws.String("cn-north-1")}, + } + + ec2Client := &ec2provider.MockEc2Client{Regions: regions} + + _, err := NewVerifier(context.TODO(), "", partition, "", ec2Client).(tokenVerifier).Verify(token) errorContains(t, err, expectedErr) } @@ -89,6 +102,16 @@ func newVerifier(partition string, statusCode int, body string, err error) Verif if body != "" { rc = io.NopCloser(bytes.NewReader([]byte(body))) } + regions := []ec2types.Region{ + {RegionName: aws.String("sa-east-1")}, + {RegionName: aws.String("us-east-2")}, + {RegionName: aws.String("us-west-2")}, + {RegionName: aws.String("ap-northeast-2")}, + {RegionName: aws.String("ca-central-1")}, + {RegionName: aws.String("eu-west-1")}, + } + ec2Client := &ec2provider.MockEc2Client{Regions: regions} + return tokenVerifier{ client: &http.Client{ Transport: &roundTripper{ @@ -99,7 +122,7 @@ func newVerifier(partition string, statusCode int, body string, err error) Verif }, }, }, - validSTShostnames: stsHostsForPartition(partition, ""), + validSTShostnames: stsHostsForPartition(context.TODO(), partition, "", ec2Client), } } @@ -171,8 +194,32 @@ func TestSTSEndpoints(t *testing.T) { {"aws-not-a-partition", "sts.amazonaws.com", false, ""}, } + regions := []ec2types.Region{ + {RegionName: aws.String("cn-northwest-1")}, + {RegionName: aws.String("cn-north-1")}, + {RegionName: aws.String("us-iso-east-1")}, + {RegionName: aws.String("us-east-1")}, + {RegionName: aws.String("us-east-2")}, + {RegionName: aws.String("us-west-1")}, + {RegionName: aws.String("us-west-2")}, + {RegionName: aws.String("ap-south-1")}, + {RegionName: aws.String("ap-northeast-1")}, + {RegionName: aws.String("ap-northeast-2")}, + {RegionName: aws.String("ap-southeast-1")}, + {RegionName: aws.String("ap-southeast-2")}, + {RegionName: aws.String("ca-central-1")}, + {RegionName: aws.String("eu-central-1")}, + {RegionName: aws.String("eu-west-1")}, + {RegionName: aws.String("eu-west-2")}, + {RegionName: aws.String("eu-west-3")}, + {RegionName: aws.String("eu-north-1")}, + {RegionName: aws.String("us-gov-east-1")}, + {RegionName: aws.String("default-region")}, + } + ec2Client := &ec2provider.MockEc2Client{Regions: regions} + for _, c := range cases { - verifier := NewVerifier("", c.partition, c.region).(tokenVerifier) + verifier := NewVerifier(context.TODO(), "", c.partition, c.region, ec2Client).(tokenVerifier) if err := verifier.verifyHost(c.domain); err != nil && c.valid { t.Errorf("%s is not valid endpoint for partition %s", c.domain, c.partition) } @@ -237,7 +284,9 @@ func TestVerifyNoRedirectsFollowed(t *testing.T) { })) defer ts.Close() - tokVerifier := NewVerifier("", "aws", "").(tokenVerifier) + ec2Client := &ec2provider.MockEc2Client{Regions: []ec2types.Region{}} + + tokVerifier := NewVerifier(context.TODO(), "", "aws", "", ec2Client).(tokenVerifier) resp, err := tokVerifier.client.Get(ts.URL) if err != nil { @@ -254,6 +303,8 @@ func TestVerifyNoRedirectsFollowed(t *testing.T) { } func TestVerifyBodyReadError(t *testing.T) { + ec2Client := &ec2provider.MockEc2Client{Regions: []ec2types.Region{}} + verifier := tokenVerifier{ client: &http.Client{ Transport: &roundTripper{ @@ -264,7 +315,7 @@ func TestVerifyBodyReadError(t *testing.T) { }, }, }, - validSTShostnames: stsHostsForPartition("aws", ""), + validSTShostnames: stsHostsForPartition(context.TODO(), "aws", "", ec2Client), } _, err := verifier.Verify(validToken) errorContains(t, err, "error reading HTTP result") @@ -521,73 +572,6 @@ func response(account, userID, arn string) getCallerIdentityWrapper { return wrapper } -func Test_getDefaultHostNameForRegion(t *testing.T) { - type args struct { - partition endpoints.Partition - region string - service string - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "service doesn't exist should return default host name", - args: args{ - partition: endpoints.AwsIsoEPartition(), - region: "eu-isoe-west-1", - service: "test", - }, - want: "test.eu-isoe-west-1.cloud.adc-e.uk", - wantErr: false, - }, - { - name: "service and region doesn't exist should return default host name", - args: args{ - partition: endpoints.AwsIsoEPartition(), - region: "eu-isoe-test-1", - service: "test", - }, - want: "test.eu-isoe-test-1.cloud.adc-e.uk", - wantErr: false, - }, - { - name: "region doesn't exist should return default host name", - args: args{ - partition: endpoints.AwsIsoPartition(), - region: "us-iso-test-1", - service: "sts", - }, - want: "sts.us-iso-test-1.c2s.ic.gov", - wantErr: false, - }, - { - name: "invalid region should return error", - args: args{ - partition: endpoints.AwsIsoPartition(), - region: "test_123", - service: "sts", - }, - want: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getDefaultHostNameForRegion(&tt.args.partition, tt.args.region, tt.args.service) - if (err != nil) != tt.wantErr { - t.Errorf("getDefaultHostNameForRegion() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("getDefaultHostNameForRegion() = %v, want %v", got, tt.want) - } - }) - } -} - func TestGetWithSTS(t *testing.T) { clusterID := "test-cluster" diff --git a/tests/integration/testutils/testserver.go b/tests/integration/testutils/testserver.go index 025487416..cdfa6e439 100644 --- a/tests/integration/testutils/testserver.go +++ b/tests/integration/testutils/testserver.go @@ -7,10 +7,10 @@ import ( "net/http" "os" "path/filepath" + "slices" "testing" "time" - "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/prometheus/client_golang/prometheus" utilerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/wait" @@ -20,6 +20,7 @@ import ( "k8s.io/kubernetes/pkg/controlplane" "k8s.io/kubernetes/test/integration/framework" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/mapper" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" @@ -149,14 +150,8 @@ func testConfig(t *testing.T, setup AuthenticatorTestFrameworkSetup) (config.Con return cfg, errors.New("cluster ID cannot be empty") } - partitionKeys := []string{} - partitionMap := map[string]endpoints.Partition{} - for _, p := range endpoints.DefaultPartitions() { - partitionMap[p.ID()] = p - partitionKeys = append(partitionKeys, p.ID()) - } - if _, ok := partitionMap[cfg.PartitionID]; !ok { - return cfg, errors.New("Invalid partition") + if !slices.Contains(arn.PartitionKeys, cfg.PartitionID) { + return cfg, errors.New("Invalid partition in test config") } if errs := mapper.ValidateBackendMode(cfg.BackendMode); len(errs) > 0 {