diff --git a/THIRD-PARTY-LICENSES b/THIRD-PARTY-LICENSES index f8565727ee..605217e687 100644 --- a/THIRD-PARTY-LICENSES +++ b/THIRD-PARTY-LICENSES @@ -180,7 +180,7 @@ Copyright © 2015 Steve Francia ----- -** aws/aws-sdk-go; version 1.15.7 -- https://github.com/aws/aws-sdk-go/ +** aws/aws-sdk-go-v2; version 1.24.4 -- https://github.com/aws/aws-sdk-go-v2/ ** Etcd; version v3.1.0-alpha.1 -- https://github.com/coreos/etcd/tree/v3.1.0-alpha.1 ** github.com/coreos/go-semver; version 0.2 -- https://github.com/coreos/go-semver ** github.com/coreos/go-systemd/; version 10 -- https://github.com/coreos/go-systemd/ @@ -412,9 +412,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -* For aws/aws-sdk-go see also this required NOTICE: +* For aws/aws-sdk-go-v2 see also this required NOTICE: AWS SDK for Go -Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. Copyright 2014-2015 Stripe, Inc. * For Etcd see also this required NOTICE: CoreOS Project diff --git a/go.mod b/go.mod index f4d550df57..913e549996 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,17 @@ module k8s.io/cloud-provider-aws go 1.24.4 require ( - github.com/aws/aws-sdk-go v1.55.5 - github.com/aws/aws-sdk-go-v2 v1.32.5 - github.com/aws/aws-sdk-go-v2/config v1.28.0 + github.com/aws/aws-sdk-go-v2 v1.36.5 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/service/autoscaling v1.53.3 github.com/aws/aws-sdk-go-v2/service/ecr v1.36.2 github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.27.2 + github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3 + github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2 + github.com/aws/aws-sdk-go-v2/service/kms v1.41.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 golang.org/x/time v0.6.0 gopkg.in/gcfg.v1 v1.2.3 k8s.io/api v0.31.0 @@ -25,23 +28,28 @@ require ( k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 ) +require ( + github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect + github.com/onsi/ginkgo/v2 v2.23.0 // indirect + github.com/onsi/gomega v1.36.2 // indirect +) + require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect - github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.41 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/service/ec2 v1.186.0 - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 - github.com/aws/smithy-go v1.22.1 + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.31 + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0 + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 + github.com/aws/smithy-go v1.22.4 github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -101,20 +109,20 @@ require ( go.opentelemetry.io/proto/otlp v1.3.1 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect - golang.org/x/crypto v0.24.0 // indirect + golang.org/x/crypto v0.33.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.26.0 // indirect + golang.org/x/mod v0.23.0 // indirect + golang.org/x/net v0.35.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/term v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/term v0.29.0 // indirect + golang.org/x/text v0.22.0 // indirect + golang.org/x/tools v0.30.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/grpc v1.65.0 // indirect - google.golang.org/protobuf v1.34.2 // indirect + google.golang.org/protobuf v1.36.1 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect diff --git a/go.sum b/go.sum index 9ebc19794f..aad58bd3c4 100644 --- a/go.sum +++ b/go.sum @@ -4,42 +4,48 @@ github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cq github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= -github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= -github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= -github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo= -github.com/aws/aws-sdk-go-v2 v1.32.5/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= -github.com/aws/aws-sdk-go-v2/config v1.28.0 h1:FosVYWcqEtWNxHn8gB/Vs6jOlNwSoyOCA/g/sxyySOQ= -github.com/aws/aws-sdk-go-v2/config v1.28.0/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41 h1:7gXo+Axmp+R4Z+AK8YFQO0ZV3L0gizGINCOWxSLY9W8= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 h1:TMH3f/SCAWdNtXXVPPu5D6wrr4G5hI1rAxbcocKfC7Q= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24 h1:4usbeaes3yJnCFC7kfeyhkdkPtoRYPa/hTmCqMpKpLI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24/go.mod h1:5CI1JemjVwde8m2WG3cz23qHKPOxbpkq0HaoreEgLIY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24 h1:N1zsICrQglfzaBnrfM0Ys00860C+QFwu6u/5+LomP+o= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24/go.mod h1:dCn9HbJ8+K31i8IQ8EWmWj0EiIk0+vKiHNMxTTYveAg= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.186.0 h1:n2l2WeV+lEABrGwG/4MsE0WFEbd3j7yKsmZzbnEm5CY= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.186.0/go.mod h1:kYXaB4FzyhEJjvrJ84oPnMElLiEAjGxxUunVW2tBSng= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/aws/aws-sdk-go-v2 v1.36.5 h1:0OF9RiEMEdDdZEMqF9MRjevyxAQcf6gY+E7vwBILFj0= +github.com/aws/aws-sdk-go-v2 v1.36.5/go.mod h1:EYrzvCCN9CMUTa5+6lf6MM4tq3Zjp8UhSGR/cBsjai0= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.31 h1:oQWSGexYasNpYp4epLGZxxjsDo8BMBh6iNWkTXQvkwk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.31/go.mod h1:nc332eGUU+djP3vrMI6blS0woaCfHTe3KiSQUVTMRq0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 h1:SsytQyTMHMDPspp+spo7XwXTP44aJZZAC7fBV2C5+5s= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36/go.mod h1:Q1lnJArKRXkenyog6+Y+zr7WDpk4e6XlR6gs20bbeNo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 h1:i2vNHQiXUvKhs3quBR6aqlgJaiaexz/aNvdCktW/kAM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36/go.mod h1:UdyGa7Q91id/sdyHPwth+043HhmP6yP9MBHgbZM0xo8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.53.3 h1:spHCGHuTPi/QaPd6tADKBTGO/ZTbB0rfGDB0V4jXE9g= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.53.3/go.mod h1:6U/Xm5bBkZGCTxH3NE9+hPKEpCFCothGn/gwytsr1Mk= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0 h1:QPYsTfcPpPhkF+37pxLcl3xbQz2SRxsShQNB6VCkvLo= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0/go.mod h1:ouvGEfHbLaIlWwpDpOVWPWR+YwO0HDv3vm5tYLq8ImY= github.com/aws/aws-sdk-go-v2/service/ecr v1.36.2 h1:VDQaVwGOokbd3VUbHF+wupiffdrbAZPdQnr5XZMJqrs= github.com/aws/aws-sdk-go-v2/service/ecr v1.36.2/go.mod h1:lvUlMghKYmSxSfv0vU7pdU/8jSY+s0zpG8xXhaGKCw0= github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.27.2 h1:Zru9Iy2JPM5+uRnFnoqeOZzi8JIVIHJ0ua6JdeDHcyg= github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.27.2/go.mod h1:PtQC3XjutCYFCn1+i8+wtpDaXvEK+vXF2gyLIKAmh4A= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 h1:wtpJ4zcwrSbwhECWQoI/g6WM9zqCcSpHDJIWSbMLOu4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5/go.mod h1:qu/W9HXQbbQ4+1+JcZp0ZNPV31ym537ZJN+fiS7Ti8E= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 h1:bSYXVyUzoTHoKalBmwaZxs97HU9DWWI3ehHSAMa7xOk= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 h1:AhmO1fHINP9vFYUE0LHzCWg/LfUWUF+zFPEcY9QXb7o= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 h1:6SZUVRQNvExYlMLbHdlKB48x0fLbc2iVROyaNEwBHbU= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.1/go.mod h1:GqWyYCwLXnlUB1lOAXQyNSPqPLQJvmo8J0DWBzp9mtg= -github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= -github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3 h1:DpyV8LeDf0y7iDaGZ3h1Y+Nh5IaBOR+xj44vVgEEegY= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3/go.mod h1:H232HdqVlSUoqy0cMJYW1TKjcxvGFGFZ20xQG8fOAPw= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2 h1:vX70Z4lNSr7XsioU0uJq5yvxgI50sB66MvD+V/3buS4= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2/go.mod h1:xnCC3vFBfOKpU6PcsCKL2ktgBTZfOwTGxj6V8/X3IS4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/kms v1.41.0 h1:2jKyib9msVrAVn+lngwlSplG13RpUZmzVte2yDao5nc= +github.com/aws/aws-sdk-go-v2/service/kms v1.41.0/go.mod h1:RyhzxkWGcfixlkieewzpO3D4P4fTMxhIDqDZWsh0u/4= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw= +github.com/aws/smithy-go v1.22.4/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -107,8 +113,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM= -github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -155,10 +161,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= -github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= -github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= -github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= +github.com/onsi/ginkgo/v2 v2.23.0 h1:FA1xjp8ieYDzlgS5ABTpdUDB7wtngggONc8a7ku2NqQ= +github.com/onsi/ginkgo/v2 v2.23.0/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -194,8 +200,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 h1:6fotK7otjonDflCTK0BCfls4SPy3NcCVb5dqqmbRknE= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75/go.mod h1:KO6IkyS8Y3j8OdNO85qEYBsRPuteD+YciPomcXdrMnk= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= @@ -253,16 +259,16 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -270,8 +276,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -281,22 +287,22 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= +golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -305,8 +311,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -319,8 +325,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/controllers/tagging/tagging_controller.go b/pkg/controllers/tagging/tagging_controller.go index 909c8237eb..5d0f0071fb 100644 --- a/pkg/controllers/tagging/tagging_controller.go +++ b/pkg/controllers/tagging/tagging_controller.go @@ -14,6 +14,7 @@ limitations under the License. package tagging import ( + "context" "crypto/md5" "fmt" "sort" @@ -45,7 +46,7 @@ func init() { // workItem contains the node and an action for that node type workItem struct { node *v1.Node - action func(node *v1.Node) error + action func(ctx context.Context, node *v1.Node) error requeuingCount int enqueueTime time.Time } @@ -176,33 +177,33 @@ func NewTaggingController( // Run will start the controller to tag resources attached to the cluster // and untag resources detached from the cluster. -func (tc *Controller) Run(stopCh <-chan struct{}) { +func (tc *Controller) Run(ctx context.Context) { defer utilruntime.HandleCrash() defer tc.workqueue.ShutDown() // Wait for the caches to be synced before starting workers klog.Info("Waiting for informer caches to sync") - if ok := cache.WaitForCacheSync(stopCh, tc.nodesSynced); !ok { + if ok := cache.WaitForCacheSync(ctx.Done(), tc.nodesSynced); !ok { klog.Errorf("failed to wait for caches to sync") return } klog.Infof("Starting the tagging controller") - go wait.Until(tc.work, tc.nodeMonitorPeriod, stopCh) + go wait.UntilWithContext(ctx, func(ctx context.Context) { tc.work(ctx) }, tc.nodeMonitorPeriod) - <-stopCh + <-ctx.Done() } // work is a long-running function that continuously // call process() for each message on the workqueue -func (tc *Controller) work() { - for tc.process() { +func (tc *Controller) work(ctx context.Context) { + for tc.process(ctx) { } } // process reads each message in the queue and performs either // tag or untag function on the Node object -func (tc *Controller) process() bool { +func (tc *Controller) process(ctx context.Context) bool { obj, shutdown := tc.workqueue.Get() if shutdown { return false @@ -240,7 +241,7 @@ func (tc *Controller) process() bool { return nil } - err = workItem.action(workItem.node) + err = workItem.action(ctx, workItem.node) if err != nil { if workItem.requeuingCount < maxRequeuingCount { @@ -275,11 +276,11 @@ func (tc *Controller) process() bool { // tagNodesResources tag node resources // If we want to tag more resources, modify this function appropriately -func (tc *Controller) tagNodesResources(node *v1.Node) error { +func (tc *Controller) tagNodesResources(ctx context.Context, node *v1.Node) error { for _, resource := range tc.resources { switch resource { case opt.Instance: - err := tc.tagEc2Instance(node) + err := tc.tagEc2Instance(ctx, node) if err != nil { return err } @@ -291,7 +292,7 @@ func (tc *Controller) tagNodesResources(node *v1.Node) error { // tagEc2Instances applies the provided tags to each EC2 instance in // the cluster. -func (tc *Controller) tagEc2Instance(node *v1.Node) error { +func (tc *Controller) tagEc2Instance(ctx context.Context, node *v1.Node) error { if !tc.isTaggingRequired(node) { klog.Infof("Skip tagging node %s since it was already tagged earlier.", node.GetName()) return nil @@ -299,7 +300,7 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() - err := tc.cloud.TagResource(string(instanceID), tc.tags) + err := tc.cloud.TagResource(ctx, string(instanceID), tc.tags) if err != nil { if awsv1.IsAWSErrorInstanceNotFound(err) { @@ -332,11 +333,11 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { // untagNodeResources untag node resources // If we want to untag more resources, modify this function appropriately -func (tc *Controller) untagNodeResources(node *v1.Node) error { +func (tc *Controller) untagNodeResources(ctx context.Context, node *v1.Node) error { for _, resource := range tc.resources { switch resource { case opt.Instance: - err := tc.untagEc2Instance(node) + err := tc.untagEc2Instance(ctx, node) if err != nil { return err } @@ -348,10 +349,10 @@ func (tc *Controller) untagNodeResources(node *v1.Node) error { // untagEc2Instances deletes the provided tags to each EC2 instances in // the cluster. -func (tc *Controller) untagEc2Instance(node *v1.Node) error { +func (tc *Controller) untagEc2Instance(ctx context.Context, node *v1.Node) error { instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() - err := tc.cloud.UntagResource(string(instanceID), tc.tags) + err := tc.cloud.UntagResource(ctx, string(instanceID), tc.tags) if err != nil { klog.Errorf("Error in untagging EC2 instance %s for node %s, error: %v", instanceID, node.GetName(), err) @@ -365,7 +366,7 @@ func (tc *Controller) untagEc2Instance(node *v1.Node) error { // enqueueNode takes in the object and an // action for the object for a workitem and enqueue to the workqueue -func (tc *Controller) enqueueNode(node *v1.Node, action func(node *v1.Node) error) { +func (tc *Controller) enqueueNode(node *v1.Node, action func(ctx context.Context, node *v1.Node) error) { item := &workItem{ node: node, action: action, diff --git a/pkg/controllers/tagging/tagging_controller_test.go b/pkg/controllers/tagging/tagging_controller_test.go index 89142bec12..f0b42221ee 100644 --- a/pkg/controllers/tagging/tagging_controller_test.go +++ b/pkg/controllers/tagging/tagging_controller_test.go @@ -237,7 +237,7 @@ func Test_NodesJoiningAndLeaving(t *testing.T) { } for tc.workqueue.Len() > 0 { - tc.process() + tc.process(context.TODO()) // sleep briefly because of exponential backoff when requeueing failed workitem // resulting in workqueue to be empty if checked immediately diff --git a/pkg/controllers/tagging/tagging_controller_wrapper.go b/pkg/controllers/tagging/tagging_controller_wrapper.go index e44181e168..1e0ee7167d 100644 --- a/pkg/controllers/tagging/tagging_controller_wrapper.go +++ b/pkg/controllers/tagging/tagging_controller_wrapper.go @@ -55,7 +55,7 @@ func (tc *ControllerWrapper) startTaggingController(ctx context.Context, initCon return nil, false, nil } - go taggingcontroller.Run(ctx.Done()) + go taggingcontroller.Run(ctx) return nil, true, nil } diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index ec126fba42..f8bd3d4fb7 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -18,6 +18,7 @@ package aws import ( "context" + "errors" "fmt" "io" "net" @@ -27,19 +28,17 @@ import ( "strings" "time" - stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/smithy-go" "gopkg.in/gcfg.v1" v1 "k8s.io/api/core/v1" @@ -61,7 +60,6 @@ import ( "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" "k8s.io/cloud-provider-aws/pkg/providers/v1/variant" _ "k8s.io/cloud-provider-aws/pkg/providers/v1/variant/fargate" // ensure the fargate variant gets registered - "k8s.io/cloud-provider-aws/pkg/resourcemanagers" "k8s.io/cloud-provider-aws/pkg/services" ) @@ -287,76 +285,74 @@ const MaxReadThenCreateRetries = 30 // Services is an abstraction over AWS, to allow mocking/other implementations type Services interface { - Compute(region string) (iface.EC2, error) - LoadBalancing(region string) (ELB, error) - LoadBalancingV2(region string) (ELBV2, error) - Metadata() (config.EC2Metadata, error) - KeyManagement(region string) (KMS, error) + Compute(ctx context.Context, region string, assumeRoleProvider *stscreds.AssumeRoleProvider) (iface.EC2, error) + LoadBalancing(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (ELB, error) + LoadBalancingV2(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (ELBV2, error) + Metadata(ctx context.Context) (config.EC2Metadata, error) + KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (KMS, error) } // ELB is a simple pass-through of AWS' ELB client interface, which allows for testing type ELB interface { - CreateLoadBalancer(*elb.CreateLoadBalancerInput) (*elb.CreateLoadBalancerOutput, error) - DeleteLoadBalancer(*elb.DeleteLoadBalancerInput) (*elb.DeleteLoadBalancerOutput, error) - DescribeLoadBalancers(*elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) - AddTags(*elb.AddTagsInput) (*elb.AddTagsOutput, error) - RegisterInstancesWithLoadBalancer(*elb.RegisterInstancesWithLoadBalancerInput) (*elb.RegisterInstancesWithLoadBalancerOutput, error) - DeregisterInstancesFromLoadBalancer(*elb.DeregisterInstancesFromLoadBalancerInput) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) - CreateLoadBalancerPolicy(*elb.CreateLoadBalancerPolicyInput) (*elb.CreateLoadBalancerPolicyOutput, error) - SetLoadBalancerPoliciesForBackendServer(*elb.SetLoadBalancerPoliciesForBackendServerInput) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) - SetLoadBalancerPoliciesOfListener(input *elb.SetLoadBalancerPoliciesOfListenerInput) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) - DescribeLoadBalancerPolicies(input *elb.DescribeLoadBalancerPoliciesInput) (*elb.DescribeLoadBalancerPoliciesOutput, error) + CreateLoadBalancer(ctx context.Context, input *elb.CreateLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerOutput, error) + DeleteLoadBalancer(ctx context.Context, input *elb.DeleteLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.DeleteLoadBalancerOutput, error) + DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) + AddTags(ctx context.Context, input *elb.AddTagsInput, optFns ...func(*elb.Options)) (*elb.AddTagsOutput, error) + RegisterInstancesWithLoadBalancer(ctx context.Context, input *elb.RegisterInstancesWithLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.RegisterInstancesWithLoadBalancerOutput, error) + DeregisterInstancesFromLoadBalancer(ctx context.Context, input *elb.DeregisterInstancesFromLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) + CreateLoadBalancerPolicy(ctx context.Context, input *elb.CreateLoadBalancerPolicyInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerPolicyOutput, error) + SetLoadBalancerPoliciesForBackendServer(ctx context.Context, input *elb.SetLoadBalancerPoliciesForBackendServerInput, optFns ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) + SetLoadBalancerPoliciesOfListener(ctx context.Context, input *elb.SetLoadBalancerPoliciesOfListenerInput, optFns ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) + DescribeLoadBalancerPolicies(ctx context.Context, input *elb.DescribeLoadBalancerPoliciesInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancerPoliciesOutput, error) - DetachLoadBalancerFromSubnets(*elb.DetachLoadBalancerFromSubnetsInput) (*elb.DetachLoadBalancerFromSubnetsOutput, error) - AttachLoadBalancerToSubnets(*elb.AttachLoadBalancerToSubnetsInput) (*elb.AttachLoadBalancerToSubnetsOutput, error) + DetachLoadBalancerFromSubnets(ctx context.Context, input *elb.DetachLoadBalancerFromSubnetsInput, optFns ...func(*elb.Options)) (*elb.DetachLoadBalancerFromSubnetsOutput, error) + AttachLoadBalancerToSubnets(ctx context.Context, input *elb.AttachLoadBalancerToSubnetsInput, optFns ...func(*elb.Options)) (*elb.AttachLoadBalancerToSubnetsOutput, error) - CreateLoadBalancerListeners(*elb.CreateLoadBalancerListenersInput) (*elb.CreateLoadBalancerListenersOutput, error) - DeleteLoadBalancerListeners(*elb.DeleteLoadBalancerListenersInput) (*elb.DeleteLoadBalancerListenersOutput, error) + CreateLoadBalancerListeners(ctx context.Context, input *elb.CreateLoadBalancerListenersInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerListenersOutput, error) + DeleteLoadBalancerListeners(ctx context.Context, input *elb.DeleteLoadBalancerListenersInput, optFns ...func(*elb.Options)) (*elb.DeleteLoadBalancerListenersOutput, error) - ApplySecurityGroupsToLoadBalancer(*elb.ApplySecurityGroupsToLoadBalancerInput) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) + ApplySecurityGroupsToLoadBalancer(ctx context.Context, input *elb.ApplySecurityGroupsToLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) - ConfigureHealthCheck(*elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) + ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, optFns ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) - DescribeLoadBalancerAttributes(*elb.DescribeLoadBalancerAttributesInput) (*elb.DescribeLoadBalancerAttributesOutput, error) - ModifyLoadBalancerAttributes(*elb.ModifyLoadBalancerAttributesInput) (*elb.ModifyLoadBalancerAttributesOutput, error) + DescribeLoadBalancerAttributes(ctx context.Context, input *elb.DescribeLoadBalancerAttributesInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancerAttributesOutput, error) + ModifyLoadBalancerAttributes(ctx context.Context, input *elb.ModifyLoadBalancerAttributesInput, optFns ...func(*elb.Options)) (*elb.ModifyLoadBalancerAttributesOutput, error) } // ELBV2 is a simple pass-through of AWS' ELBV2 client interface, which allows for testing type ELBV2 interface { - AddTags(input *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) + AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) - CreateLoadBalancer(*elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) - DescribeLoadBalancers(*elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) - DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) + CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) + DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) + DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) - ModifyLoadBalancerAttributes(*elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) - DescribeLoadBalancerAttributes(*elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) + ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) + DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) - CreateTargetGroup(*elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) - DescribeTargetGroups(*elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) - ModifyTargetGroup(*elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) - DeleteTargetGroup(*elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) + CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) + DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) + ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) + DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) - DescribeTargetHealth(input *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) + DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) - DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) - ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) + DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) + ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) - RegisterTargets(*elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) - DeregisterTargets(*elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) + RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) + DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) - CreateListener(*elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) - DescribeListeners(*elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) - DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) - ModifyListener(*elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) - - WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error + CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) + DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) + DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) + ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) } // KMS is a simple pass-through of the Key Management Service client interface, // which allows for testing. type KMS interface { - DescribeKey(*kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) + DescribeKey(ctx context.Context, input *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) } var _ cloudprovider.Interface = (*Cloud)(nil) @@ -384,7 +380,7 @@ type Cloud struct { instanceCache instanceCache zoneCache zoneCache - instanceTopologyManager resourcemanagers.InstanceTopologyManager + instanceTopologyManager InstanceTopologyManager clientBuilder cloudprovider.ControllerClientBuilder kubeClient clientset.Interface @@ -399,7 +395,15 @@ type Cloud struct { // Interface to make the CloudConfig immutable for awsSDKProvider type awsCloudConfigProvider interface { - GetResolver() endpoints.ResolverFunc + GetEC2EndpointOpts(region string) []func(*ec2.Options) + GetCustomEC2Resolver() ec2.EndpointResolverV2 + GetELBEndpointOpts(region string) []func(*elb.Options) + GetCustomELBResolver() elb.EndpointResolverV2 + GetELBV2EndpointOpts(region string) []func(*elbv2.Options) + GetCustomELBV2Resolver() elbv2.EndpointResolverV2 + GetKMSEndpointOpts(region string) []func(*kms.Options) + GetCustomKMSResolver() kms.EndpointResolverV2 + GetIMDSEndpointOpts() []func(*imds.Options) } // InstanceIDIndexFunc indexes based on a Node's instance ID found in its spec.providerID @@ -432,12 +436,12 @@ func (c *Cloud) SetInformers(informerFactory informers.SharedInformerFactory) { }) } -func newEc2Filter(name string, values ...string) *ec2.Filter { - filter := &ec2.Filter{ +func newEc2Filter(name string, values ...string) ec2types.Filter { + filter := ec2types.Filter{ Name: aws.String(name), } for _, value := range values { - filter.Values = append(filter.Values, aws.String(value)) + filter.Values = append(filter.Values, value) } return filter } @@ -465,70 +469,28 @@ func init() { return nil, fmt.Errorf("unable to validate custom endpoint overrides: %v", err) } - metadata, err := newAWSSDKProvider(nil, cfg).Metadata() + metadata, err := newAWSSDKProvider(nil, cfg).Metadata(ctx) if err != nil { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - regionName, err := getRegionFromMetadata(*cfg, metadata) + regionName, err := getRegionFromMetadata(ctx, *cfg, metadata) if err != nil { return nil, err } - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *aws.NewConfig().WithRegion(regionName).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint), - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) - } - - var creds *credentials.Credentials - var credsV2 *stscredsv2.AssumeRoleProvider + var creds *stscreds.AssumeRoleProvider if cfg.Global.RoleARN != "" { - stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN) - if err != nil { - return nil, fmt.Errorf("unable to create sts client, %v", err) - } - creds = credentials.NewChainCredentials( - []credentials.Provider{ - &credentials.EnvProvider{}, - assumeRoleProvider(&stscreds.AssumeRoleProvider{ - Client: stsClient, - RoleARN: cfg.Global.RoleARN, - }), - }) - - stsClientv2, err := services.NewStsV2Client(ctx, regionName, cfg.Global.RoleARN, cfg.Global.SourceARN) + stsClient, err := services.NewStsClient(ctx, regionName, cfg.Global.RoleARN, cfg.Global.SourceARN) if err != nil { return nil, fmt.Errorf("unable to create sts v2 client: %v", err) } - credsV2 = stscredsv2.NewAssumeRoleProvider(stsClientv2, cfg.Global.RoleARN) + creds = stscreds.NewAssumeRoleProvider(stsClient, cfg.Global.RoleARN) } aws := newAWSSDKProvider(creds, cfg) - return newAWSCloud2(*cfg, aws, aws, creds, credsV2) - }) -} - -func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) { - klog.Infof("Using AWS assumed role %v", roleARN) - stsClient := sts.New(sess) - sourceAcct, err := GetSourceAccount(roleARN) - if err != nil { - return nil, err - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - } - if sourceARN != "" { - reqHeaders[headerSourceArn] = sourceARN - } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) + return newAWSCloud2(*cfg, aws, aws, creds) }) - klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders) - return stsClient, nil } // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. @@ -563,48 +525,43 @@ func azToRegion(az string) (string, error) { } func newAWSCloud(cfg config.CloudConfig, awsServices Services) (*Cloud, error) { - return newAWSCloud2(cfg, awsServices, nil, nil, nil) + return newAWSCloud2(cfg, awsServices, nil, nil) } // newAWSCloud creates a new instance of AWSCloud. // AWSProvider and instanceId are primarily for tests -func newAWSCloud2(cfg config.CloudConfig, awsServices Services, provider config.SDKProvider, credentials *credentials.Credentials, credentialsV2 *stscredsv2.AssumeRoleProvider) (*Cloud, error) { +func newAWSCloud2(cfg config.CloudConfig, awsServices Services, provider config.SDKProvider, credentials *stscreds.AssumeRoleProvider) (*Cloud, error) { ctx := context.Background() // We have some state in the Cloud object // Log so that if we are building multiple Cloud objects, it is obvious! klog.Infof("Building AWS cloudprovider") - metadata, err := awsServices.Metadata() + metadata, err := awsServices.Metadata(ctx) if err != nil { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - regionName, err := getRegionFromMetadata(cfg, metadata) + regionName, err := getRegionFromMetadata(ctx, cfg, metadata) if err != nil { return nil, err } - ec2, err := awsServices.Compute(regionName) + ec2, err := awsServices.Compute(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS EC2 client: %v", err) } - ec2v2, err := services.NewEc2SdkV2(ctx, regionName, credentialsV2) - if err != nil { - return nil, fmt.Errorf("error creating AWS EC2v2 client: %v", err) - } - - elb, err := awsServices.LoadBalancing(regionName) + elb, err := awsServices.LoadBalancing(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS ELB client: %v", err) } - elbv2, err := awsServices.LoadBalancingV2(regionName) + elbv2, err := awsServices.LoadBalancingV2(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS ELBV2 client: %v", err) } - kms, err := awsServices.KeyManagement(regionName) + kms, err := awsServices.KeyManagement(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS key management client: %v", err) } @@ -620,7 +577,7 @@ func newAWSCloud2(cfg config.CloudConfig, awsServices Services, provider config. } awsCloud.instanceCache.cloud = awsCloud awsCloud.zoneCache.cloud = awsCloud - awsCloud.instanceTopologyManager = resourcemanagers.NewInstanceTopologyManager(ec2v2, &cfg) + awsCloud.instanceTopologyManager = NewInstanceTopologyManager(ec2, &cfg) tagged := cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" if cfg.Global.VPC != "" && (cfg.Global.SubnetID != "" || cfg.Global.RoleARN != "") && tagged { @@ -634,7 +591,7 @@ func newAWSCloud2(cfg config.CloudConfig, awsServices Services, provider config. } awsCloud.vpcID = cfg.Global.VPC } else { - selfAWSInstance, err := awsCloud.buildSelfAWSInstance() + selfAWSInstance, err := awsCloud.buildSelfAWSInstance(ctx) if err != nil { return nil, err } @@ -648,7 +605,7 @@ func newAWSCloud2(cfg config.CloudConfig, awsServices Services, provider config. } } else { // TODO: Clean up double-API query - info, err := awsCloud.selfAWSInstance.describeInstance() + info, err := awsCloud.selfAWSInstance.describeInstance(ctx) if err != nil { return nil, err } @@ -745,7 +702,7 @@ func (c *Cloud) NodeAddresses(ctx context.Context, name types.NodeName) ([]v1.No // extractIPv4NodeAddresses maps the instance information from EC2 to an array of NodeAddresses. // This function will extract private and public IP addresses and their corresponding DNS names. -func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { +func extractIPv4NodeAddresses(instance *ec2types.Instance) ([]v1.NodeAddress, error) { // Not clear if the order matters here, but we might as well indicate a sensible preference order if instance == nil { @@ -764,21 +721,21 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) return true } - return aws.Int64Value(instance.NetworkInterfaces[i].Attachment.DeviceIndex) < aws.Int64Value(instance.NetworkInterfaces[j].Attachment.DeviceIndex) + return aws.ToInt32(instance.NetworkInterfaces[i].Attachment.DeviceIndex) < aws.ToInt32(instance.NetworkInterfaces[j].Attachment.DeviceIndex) }) // handle internal network interfaces for _, networkInterface := range instance.NetworkInterfaces { // skip network interfaces that are not currently in use - if aws.StringValue(networkInterface.Status) != ec2.NetworkInterfaceStatusInUse { + if networkInterface.Status != ec2types.NetworkInterfaceStatusInUse { continue } for _, internalIP := range networkInterface.PrivateIpAddresses { - if ipAddress := aws.StringValue(internalIP.PrivateIpAddress); ipAddress != "" { + if ipAddress := aws.ToString(internalIP.PrivateIpAddress); ipAddress != "" { ip := netutils.ParseIPSloppy(ipAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%q)", aws.StringValue(instance.InstanceId), ipAddress) + return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%q)", aws.ToString(instance.InstanceId), ipAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) } @@ -786,22 +743,22 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) } // TODO: Other IP addresses (multiple ips)? - publicIPAddress := aws.StringValue(instance.PublicIpAddress) + publicIPAddress := aws.ToString(instance.PublicIpAddress) if publicIPAddress != "" { ip := netutils.ParseIPSloppy(publicIPAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.StringValue(instance.InstanceId), publicIPAddress) + return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.ToString(instance.InstanceId), publicIPAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) } - privateDNSName := aws.StringValue(instance.PrivateDnsName) + privateDNSName := aws.ToString(instance.PrivateDnsName) if privateDNSName != "" { addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName}) addresses = append(addresses, v1.NodeAddress{Type: v1.NodeHostName, Address: privateDNSName}) } - publicDNSName := aws.StringValue(instance.PublicDnsName) + publicDNSName := aws.ToString(instance.PublicDnsName) if publicDNSName != "" { addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName}) } @@ -811,7 +768,7 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // extractIPv6NodeAddresses maps the instance information from EC2 to an array of NodeAddresses // All IPv6 addresses are considered internal even if they are publicly routable. There are no instance DNS names associated with IPv6. -func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { +func extractIPv6NodeAddresses(instance *ec2types.Instance) ([]v1.NodeAddress, error) { // Not clear if the order matters here, but we might as well indicate a sensible preference order if instance == nil { @@ -823,15 +780,15 @@ func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // handle internal network interfaces with IPv6 addresses for _, networkInterface := range instance.NetworkInterfaces { // skip network interfaces that are not currently in use - if aws.StringValue(networkInterface.Status) != ec2.NetworkInterfaceStatusInUse || len(networkInterface.Ipv6Addresses) == 0 { + if networkInterface.Status != ec2types.NetworkInterfaceStatusInUse || len(networkInterface.Ipv6Addresses) == 0 { continue } // return only the "first" address for each ENI - internalIPv6 := aws.StringValue(networkInterface.Ipv6Addresses[0].Ipv6Address) + internalIPv6 := aws.ToString(networkInterface.Ipv6Addresses[0].Ipv6Address) ip := net.ParseIP(internalIPv6) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid IPv6 address: %s (%q)", aws.StringValue(instance.InstanceId), internalIPv6) + return nil, fmt.Errorf("EC2 instance had invalid IPv6 address: %s (%q)", aws.ToString(instance.InstanceId), internalIPv6) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) } @@ -849,10 +806,10 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string } if v := variant.GetVariant(string(instanceID)); v != nil { - return v.NodeAddresses(string(instanceID), c.vpcID) + return v.NodeAddresses(ctx, string(instanceID), c.vpcID) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(ctx, c.ec2, instanceID) if err != nil { return nil, err } @@ -888,14 +845,14 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin } if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceExists(string(instanceID), c.vpcID) + return v.InstanceExists(ctx, string(instanceID), c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { // if err is InstanceNotFound, return false with no error if IsAWSErrorInstanceNotFound(err) { @@ -911,7 +868,7 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin } state := instances[0].State.Name - if *state == ec2.InstanceStateNameTerminated { + if state == ec2types.InstanceStateNameTerminated { klog.Warningf("the instance %s is terminated", instanceID) return false, nil } @@ -927,14 +884,14 @@ func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID str } if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceShutdown(string(instanceID), c.vpcID) + return v.InstanceShutdown(ctx, string(instanceID), c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return false, err } @@ -950,8 +907,8 @@ func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID str instance := instances[0] if instance.State != nil { - state := aws.StringValue(instance.State.Name) - if state == ec2.InstanceStateNameStopped { + state := instance.State.Name + if state == ec2types.InstanceStateNameStopped { return true, nil } } @@ -965,7 +922,7 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string if c.selfAWSInstance.nodeName == nodeName { return "/" + c.selfAWSInstance.availabilityZone + "/" + c.selfAWSInstance.awsID, nil } - inst, err := c.getInstanceByNodeName(nodeName) + inst, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { if err == cloudprovider.InstanceNotFound { // The Instances interface requires that we return InstanceNotFound (without wrapping) @@ -973,7 +930,7 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string } return "", fmt.Errorf("getInstanceByNodeName failed for %q with %q", nodeName, err) } - return "/" + aws.StringValue(inst.Placement.AvailabilityZone) + "/" + aws.StringValue(inst.InstanceId), nil + return "/" + aws.ToString(inst.Placement.AvailabilityZone) + "/" + aws.ToString(inst.InstanceId), nil } // InstanceTypeByProviderID returns the cloudprovider instance type of the node with the specified unique providerID @@ -989,12 +946,12 @@ func (c *Cloud) InstanceTypeByProviderID(ctx context.Context, providerID string) return v.InstanceTypeByProviderID(string(instanceID)) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(ctx, c.ec2, instanceID) if err != nil { return "", err } - return aws.StringValue(instance.InstanceType), nil + return string(instance.InstanceType), nil } // InstanceType returns the type of the node with the specified nodeName. @@ -1002,11 +959,11 @@ func (c *Cloud) InstanceType(ctx context.Context, nodeName types.NodeName) (stri if c.selfAWSInstance.nodeName == nodeName { return c.selfAWSInstance.instanceType, nil } - inst, err := c.getInstanceByNodeName(nodeName) + inst, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return "", fmt.Errorf("getInstanceByNodeName failed for %q with %q", nodeName, err) } - return aws.StringValue(inst.InstanceType), nil + return string(inst.InstanceType), nil } // GetZone implements Zones.GetZone @@ -1027,10 +984,10 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo } if v := variant.GetVariant(string(instanceID)); v != nil { - return v.GetZone(string(instanceID), c.vpcID, c.region) + return v.GetZone(ctx, string(instanceID), c.vpcID, c.region) } - instance, err := c.getInstanceByID(string(instanceID)) + instance, err := c.getInstanceByID(ctx, string(instanceID)) if err != nil { return cloudprovider.Zone{}, err } @@ -1047,7 +1004,7 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo // This is particularly useful in external cloud providers where the kubelet // does not initialize node data. func (c *Cloud) GetZoneByNodeName(ctx context.Context, nodeName types.NodeName) (cloudprovider.Zone, error) { - instance, err := c.getInstanceByNodeName(nodeName) + instance, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return cloudprovider.Zone{}, err } @@ -1066,11 +1023,10 @@ func IsAWSErrorInstanceNotFound(err error) bool { return false } - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == ec2.UnsuccessfulInstanceCreditSpecificationErrorCodeInvalidInstanceIdNotFound { - return true - } - } else if strings.Contains(err.Error(), ec2.UnsuccessfulInstanceCreditSpecificationErrorCodeInvalidInstanceIdNotFound) { + var ae smithy.APIError + if errors.As(err, &ae) { + return ae.ErrorCode() == string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound) + } else if strings.Contains(err.Error(), string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound)) { // In places like https://github.com/kubernetes/cloud-provider-aws/blob/1c6194aad0122ab44504de64187e3d1a7415b198/pkg/providers/v1/aws.go#L1007, // the error has been transformed into something else so check the error string to see if it contains the error code we're looking for. return true @@ -1081,11 +1037,11 @@ func IsAWSErrorInstanceNotFound(err error) bool { // Builds the awsInstance for the EC2 instance on which we are running. // This is called when the AWSCloud is initialized, and should not be called otherwise (because the awsInstance for the local instance is a singleton with drive mapping state) -func (c *Cloud) buildSelfAWSInstance() (*awsInstance, error) { +func (c *Cloud) buildSelfAWSInstance(ctx context.Context) (*awsInstance, error) { if c.selfAWSInstance != nil { panic("do not call buildSelfAWSInstance directly") } - instanceID, err := c.metadata.GetMetadata("instance-id") + instanceIDMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: "instance-id"}) if err != nil { return nil, fmt.Errorf("error fetching instance-id from ec2 metadata service: %q", err) } @@ -1098,42 +1054,51 @@ func (c *Cloud) buildSelfAWSInstance() (*awsInstance, error) { // information from the instance returned by the EC2 API - it is a // single API call to get all the information, and it means we don't // have two code paths. - instance, err := c.getInstanceByID(instanceID) + instanceIDBytes, err := io.ReadAll(instanceIDMetadata.Content) + if err != nil { + return nil, fmt.Errorf("unable to parse instance id: %q", err) + } + defer instanceIDMetadata.Content.Close() + + instance, err := c.getInstanceByID(ctx, string(instanceIDBytes)) if err != nil { - return nil, fmt.Errorf("error finding instance %s: %q", instanceID, err) + return nil, fmt.Errorf("error finding instance %s: %q", string(instanceIDBytes), err) } return newAWSInstance(c.ec2, instance), nil } // Gets the current load balancer state -func (c *Cloud) describeLoadBalancer(name string) (*elb.LoadBalancerDescription, error) { +func (c *Cloud) describeLoadBalancer(ctx context.Context, name string) (*elbtypes.LoadBalancerDescription, error) { request := &elb.DescribeLoadBalancersInput{} - request.LoadBalancerNames = []*string{&name} + request.LoadBalancerNames = []string{name} + + response, err := c.elb.DescribeLoadBalancers(ctx, request) - response, err := c.elb.DescribeLoadBalancers(request) if err != nil { - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "LoadBalancerNotFound" { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "LoadBalancerNotFound" { return nil, nil } } + return nil, err } - var ret *elb.LoadBalancerDescription + var ret *elbtypes.LoadBalancerDescription for _, loadBalancer := range response.LoadBalancerDescriptions { if ret != nil { klog.Errorf("Found multiple load balancers with name: %s", name) } - ret = loadBalancer + ret = &loadBalancer } return ret, nil } -func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[string]string) error { - var tags []*elb.Tag +func (c *Cloud) addLoadBalancerTags(ctx context.Context, loadBalancerName string, requested map[string]string) error { + var tags []elbtypes.Tag for k, v := range requested { - tag := &elb.Tag{ + tag := elbtypes.Tag{ Key: aws.String(k), Value: aws.String(v), } @@ -1141,10 +1106,10 @@ func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[strin } request := &elb.AddTagsInput{} - request.LoadBalancerNames = []*string{&loadBalancerName} + request.LoadBalancerNames = []string{loadBalancerName} request.Tags = tags - _, err := c.elb.AddTags(request) + _, err := c.elb.AddTags(ctx, request) if err != nil { return fmt.Errorf("error adding tags to load balancer: %v", err) } @@ -1152,25 +1117,24 @@ func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[strin } // Gets the current load balancer state -func (c *Cloud) describeLoadBalancerv2(name string) (*elbv2.LoadBalancer, error) { +func (c *Cloud) describeLoadBalancerv2(ctx context.Context, name string) (*elbv2types.LoadBalancer, error) { request := &elbv2.DescribeLoadBalancersInput{ - Names: []*string{aws.String(name)}, + Names: []string{name}, } - response, err := c.elbv2.DescribeLoadBalancers(request) + response, err := c.elbv2.DescribeLoadBalancers(ctx, request) if err != nil { - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == elbv2.ErrCodeLoadBalancerNotFoundException { - return nil, nil - } + var notFoundErr *elbv2types.LoadBalancerNotFoundException + if errors.As(err, ¬FoundErr) { + return nil, nil } return nil, fmt.Errorf("error describing load balancer: %q", err) } // AWS will not return 2 load balancers with the same name _and_ type. for i := range response.LoadBalancers { - if aws.StringValue(response.LoadBalancers[i].Type) == elbv2.LoadBalancerTypeEnumNetwork { - return response.LoadBalancers[i], nil + if response.LoadBalancers[i].Type == elbv2types.LoadBalancerTypeEnumNetwork { + return &response.LoadBalancers[i], nil } } @@ -1178,11 +1142,17 @@ func (c *Cloud) describeLoadBalancerv2(name string) (*elbv2.LoadBalancer, error) } // Retrieves instance's vpc id from metadata -func (c *Cloud) findVPCID() (string, error) { - macs, err := c.metadata.GetMetadata("network/interfaces/macs/") +func (c *Cloud) findVPCID(ctx context.Context) (string, error) { + macsMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: "network/interfaces/macs/"}) if err != nil { return "", fmt.Errorf("could not list interfaces of the instance: %q", err) } + macsBytes, err := io.ReadAll(macsMetadata.Content) + if err != nil { + return "", fmt.Errorf("unable to parse macs: %q", err) + } + defer macsMetadata.Content.Close() + macs := string(macsBytes) // loop over interfaces, first vpc id returned wins for _, macPath := range strings.Split(macs, "\n") { @@ -1190,23 +1160,28 @@ func (c *Cloud) findVPCID() (string, error) { continue } url := fmt.Sprintf("network/interfaces/macs/%svpc-id", macPath) - vpcID, err := c.metadata.GetMetadata(url) + vpcIDMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: url}) if err != nil { continue } - return vpcID, nil + vpcIDBytes, err := io.ReadAll(vpcIDMetadata.Content) + if err != nil { + continue + } + defer vpcIDMetadata.Content.Close() + return string(vpcIDBytes), nil } return "", fmt.Errorf("could not find VPC ID in instance metadata") } // Retrieves the specified security group from the AWS API, or returns nil if not found -func (c *Cloud) findSecurityGroup(securityGroupID string) (*ec2.SecurityGroup, error) { +func (c *Cloud) findSecurityGroup(ctx context.Context, securityGroupID string) (*ec2types.SecurityGroup, error) { describeSecurityGroupsRequest := &ec2.DescribeSecurityGroupsInput{ - GroupIds: []*string{&securityGroupID}, + GroupIds: []string{securityGroupID}, } // We don't apply our tag filters because we are retrieving by ID - groups, err := c.ec2.DescribeSecurityGroups(describeSecurityGroupsRequest) + groups, err := c.ec2.DescribeSecurityGroups(ctx, describeSecurityGroupsRequest) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return nil, err @@ -1220,10 +1195,10 @@ func (c *Cloud) findSecurityGroup(securityGroupID string) (*ec2.SecurityGroup, e return nil, fmt.Errorf("multiple security groups found with same id %q", securityGroupID) } group := groups[0] - return group, nil + return &group, nil } -func isEqualIntPointer(l, r *int64) bool { +func isEqualIntPointer(l, r *int32) bool { if l == nil { return r == nil } @@ -1243,7 +1218,7 @@ func isEqualStringPointer(l, r *string) bool { return *l == *r } -func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupUserIDs bool) bool { +func ipPermissionExists(newPermission, existing *ec2types.IpPermission, compareGroupUserIDs bool) bool { if !isEqualIntPointer(newPermission.FromPort, existing.FromPort) { return false } @@ -1276,7 +1251,7 @@ func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupU for _, leftPair := range newPermission.UserIdGroupPairs { found := false for _, rightPair := range existing.UserIdGroupPairs { - if isEqualUserGroupPair(leftPair, rightPair, compareGroupUserIDs) { + if isEqualUserGroupPair(&leftPair, &rightPair, compareGroupUserIDs) { found = true break } @@ -1289,7 +1264,7 @@ func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupU return true } -func isEqualUserGroupPair(l, r *ec2.UserIdGroupPair, compareGroupUserIDs bool) bool { +func isEqualUserGroupPair(l, r *ec2types.UserIdGroupPair, compareGroupUserIDs bool) bool { klog.V(2).Infof("Comparing %v to %v", *l.GroupId, *r.GroupId) if isEqualStringPointer(l.GroupId, r.GroupId) { if compareGroupUserIDs { @@ -1307,8 +1282,8 @@ func isEqualUserGroupPair(l, r *ec2.UserIdGroupPair, compareGroupUserIDs bool) b // Makes sure the security group ingress is exactly the specified permissions // Returns true if and only if changes were made // The security group must already exist -func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPermissionSet) (bool, error) { - group, err := c.findSecurityGroup(securityGroupID) +func (c *Cloud) setSecurityGroupIngress(ctx context.Context, securityGroupID string, permissions IPPermissionSet) (bool, error) { + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group %q", err) return false, err @@ -1354,7 +1329,7 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe request := &ec2.AuthorizeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = add.List() - _, err = c.ec2.AuthorizeSecurityGroupIngress(request) + _, err = c.ec2.AuthorizeSecurityGroupIngress(ctx, request) if err != nil { return false, fmt.Errorf("error authorizing security group ingress: %q", err) } @@ -1365,7 +1340,7 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe request := &ec2.RevokeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = remove.List() - _, err = c.ec2.RevokeSecurityGroupIngress(request) + _, err = c.ec2.RevokeSecurityGroupIngress(ctx, request) if err != nil { return false, fmt.Errorf("error revoking security group ingress: %q", err) } @@ -1377,13 +1352,13 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe // Makes sure the security group includes the specified permissions // Returns true if and only if changes were made // The security group must already exist -func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions []*ec2.IpPermission) (bool, error) { +func (c *Cloud) addSecurityGroupIngress(ctx context.Context, securityGroupID string, addPermissions []ec2types.IpPermission) (bool, error) { // We do not want to make changes to the Global defined SG if securityGroupID == c.cfg.Global.ElbSecurityGroup { return false, nil } - group, err := c.findSecurityGroup(securityGroupID) + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return false, err @@ -1395,7 +1370,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ klog.V(2).Infof("Existing security group ingress: %s %v", securityGroupID, group.IpPermissions) - changes := []*ec2.IpPermission{} + changes := []ec2types.IpPermission{} for _, addPermission := range addPermissions { hasUserID := false for i := range addPermission.UserIdGroupPairs { @@ -1406,7 +1381,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ found := false for _, groupPermission := range group.IpPermissions { - if ipPermissionExists(addPermission, groupPermission, hasUserID) { + if ipPermissionExists(&addPermission, &groupPermission, hasUserID) { found = true break } @@ -1426,7 +1401,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ request := &ec2.AuthorizeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = changes - _, err = c.ec2.AuthorizeSecurityGroupIngress(request) + _, err = c.ec2.AuthorizeSecurityGroupIngress(ctx, request) if err != nil { klog.Warningf("Error authorizing security group ingress %q", err) return false, fmt.Errorf("error authorizing security group ingress: %q", err) @@ -1438,13 +1413,13 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ // Makes sure the security group no longer includes the specified permissions // Returns true if and only if changes were made // If the security group no longer exists, will return (false, nil) -func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermissions []*ec2.IpPermission) (bool, error) { +func (c *Cloud) removeSecurityGroupIngress(ctx context.Context, securityGroupID string, removePermissions []ec2types.IpPermission) (bool, error) { // We do not want to make changes to the Global defined SG if securityGroupID == c.cfg.Global.ElbSecurityGroup { return false, nil } - group, err := c.findSecurityGroup(securityGroupID) + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return false, err @@ -1455,7 +1430,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss return false, nil } - changes := []*ec2.IpPermission{} + changes := []ec2types.IpPermission{} for _, removePermission := range removePermissions { hasUserID := false for i := range removePermission.UserIdGroupPairs { @@ -1464,16 +1439,16 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss } } - var found *ec2.IpPermission + var found *ec2types.IpPermission for _, groupPermission := range group.IpPermissions { - if ipPermissionExists(removePermission, groupPermission, hasUserID) { - found = removePermission + if ipPermissionExists(&removePermission, &groupPermission, hasUserID) { + found = &removePermission break } } if found != nil { - changes = append(changes, found) + changes = append(changes, *found) } } @@ -1486,7 +1461,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss request := &ec2.RevokeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = changes - _, err = c.ec2.RevokeSecurityGroupIngress(request) + _, err = c.ec2.RevokeSecurityGroupIngress(ctx, request) if err != nil { klog.Warningf("Error revoking security group ingress: %q", err) return false, err @@ -1499,7 +1474,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss // For multi-cluster isolation, name must be globally unique, for example derived from the service UUID. // Additional tags can be specified // Returns the security group id or error -func (c *Cloud) ensureSecurityGroup(name string, description string, additionalTags map[string]string) (string, error) { +func (c *Cloud) ensureSecurityGroup(ctx context.Context, name string, description string, additionalTags map[string]string) (string, error) { groupID := "" attempt := 0 for { @@ -1511,12 +1486,12 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT // If it has a different cluster's tags, that is an error. // This shouldn't happen because name is expected to be globally unique (UUID derived) request := &ec2.DescribeSecurityGroupsInput{} - request.Filters = []*ec2.Filter{ + request.Filters = []ec2types.Filter{ newEc2Filter("group-name", name), newEc2Filter("vpc-id", c.vpcID), } - securityGroups, err := c.ec2.DescribeSecurityGroups(request) + securityGroups, err := c.ec2.DescribeSecurityGroups(ctx, request) if err != nil { return "", err } @@ -1525,14 +1500,14 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT if len(securityGroups) > 1 { klog.Warningf("Found multiple security groups with name: %q", name) } - err := c.tagging.readRepairClusterTags( - c.ec2, aws.StringValue(securityGroups[0].GroupId), + err := c.tagging.readRepairClusterTags(ctx, + c.ec2, aws.ToString(securityGroups[0].GroupId), ResourceLifecycleOwned, nil, securityGroups[0].Tags) if err != nil { return "", err } - return aws.StringValue(securityGroups[0].GroupId), nil + return aws.ToString(securityGroups[0].GroupId), nil } createRequest := &ec2.CreateSecurityGroupInput{} @@ -1540,27 +1515,27 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT createRequest.GroupName = &name createRequest.Description = &description tags := c.tagging.buildTags(ResourceLifecycleOwned, additionalTags) - var awsTags []*ec2.Tag + var awsTags []ec2types.Tag for k, v := range tags { - tag := &ec2.Tag{ + tag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } awsTags = append(awsTags, tag) } - createRequest.TagSpecifications = []*ec2.TagSpecification{ + createRequest.TagSpecifications = []ec2types.TagSpecification{ { - ResourceType: aws.String(ec2.ResourceTypeSecurityGroup), + ResourceType: ec2types.ResourceTypeSecurityGroup, Tags: awsTags, }, } - createResponse, err := c.ec2.CreateSecurityGroup(createRequest) + createResponse, err := c.ec2.CreateSecurityGroup(ctx, createRequest) if err != nil { ignore := false - switch err := err.(type) { - case awserr.Error: - if err.Code() == "InvalidGroup.Duplicate" && attempt < MaxReadThenCreateRetries { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "InvalidGroup.Duplicate" && attempt < MaxReadThenCreateRetries { klog.V(2).Infof("Got InvalidGroup.Duplicate while creating security group (race?); will retry") ignore = true } @@ -1571,7 +1546,7 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT } time.Sleep(1 * time.Second) } else { - groupID = aws.StringValue(createResponse.GroupId) + groupID = aws.ToString(createResponse.GroupId) break } } @@ -1583,10 +1558,10 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT } // Finds the value for a given tag. -func findTag(tags []*ec2.Tag, key string) (string, bool) { +func findTag(tags []ec2types.Tag, key string) (string, bool) { for _, tag := range tags { - if aws.StringValue(tag.Key) == key { - return aws.StringValue(tag.Value), true + if aws.ToString(tag.Key) == key { + return aws.ToString(tag.Value), true } } return "", false @@ -1595,16 +1570,16 @@ func findTag(tags []*ec2.Tag, key string) (string, bool) { // Finds the subnets associated with the cluster, by matching cluster tags if present. // For maximal backwards compatibility, if no subnets are tagged, it will fall-back to the current subnet. // However, in future this will likely be treated as an error. -func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { +func (c *Cloud) findSubnets(ctx context.Context) ([]ec2types.Subnet, error) { request := &ec2.DescribeSubnetsInput{} - request.Filters = []*ec2.Filter{newEc2Filter("vpc-id", c.vpcID)} + request.Filters = []ec2types.Filter{newEc2Filter("vpc-id", c.vpcID)} - subnets, err := c.ec2.DescribeSubnets(request) + subnets, err := c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error describing subnets: %q", err) } - var matches []*ec2.Subnet + var matches []ec2types.Subnet for _, subnet := range subnets { if c.tagging.hasClusterTag(subnet.Tags) { matches = append(matches, subnet) @@ -1621,9 +1596,9 @@ func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { klog.Warningf("No tagged subnets found; will fall-back to the current subnet only. This is likely to be an error in a future version of k8s.") request = &ec2.DescribeSubnetsInput{} - request.Filters = []*ec2.Filter{newEc2Filter("subnet-id", c.selfAWSInstance.subnetID)} + request.Filters = []ec2types.Filter{newEc2Filter("subnet-id", c.selfAWSInstance.subnetID)} - subnets, err = c.ec2.DescribeSubnets(request) + subnets, err = c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error describing subnets: %q", err) } @@ -1634,25 +1609,25 @@ func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { // Finds the subnets to use for an ELB we are creating. // Normal (Internet-facing) ELBs must use public subnets, so we skip private subnets. // Internal ELBs can use public or private subnets, but if we have a private subnet we should prefer that. -func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { +func (c *Cloud) findELBSubnets(ctx context.Context, internalELB bool) ([]string, error) { vpcIDFilter := newEc2Filter("vpc-id", c.vpcID) - subnets, err := c.findSubnets() + subnets, err := c.findSubnets(ctx) if err != nil { return nil, err } rRequest := &ec2.DescribeRouteTablesInput{} - rRequest.Filters = []*ec2.Filter{vpcIDFilter} - rt, err := c.ec2.DescribeRouteTables(rRequest) + rRequest.Filters = []ec2types.Filter{vpcIDFilter} + rt, err := c.ec2.DescribeRouteTables(ctx, rRequest) if err != nil { return nil, fmt.Errorf("error describe route table: %q", err) } - subnetsByAZ := make(map[string]*ec2.Subnet) + subnetsByAZ := make(map[string]ec2types.Subnet) for _, subnet := range subnets { - az := aws.StringValue(subnet.AvailabilityZone) - id := aws.StringValue(subnet.SubnetId) + az := aws.ToString(subnet.AvailabilityZone) + id := aws.ToString(subnet.SubnetId) if az == "" || id == "" { klog.Warningf("Ignoring subnet with empty az/id: %v", subnet) continue @@ -1667,8 +1642,8 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { continue } - existing := subnetsByAZ[az] - if existing == nil { + existing, exists := subnetsByAZ[az] + if !exists { subnetsByAZ[az] = subnet continue } @@ -1719,7 +1694,7 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { sort.Strings(azNames) - zoneNameToDetails, err := c.zoneCache.getZoneDetailsByNames(azNames) + zoneNameToDetails, err := c.zoneCache.getZoneDetailsByNames(ctx, azNames) if err != nil { return nil, fmt.Errorf("error get availability zone types: %q", err) } @@ -1733,7 +1708,7 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { // does not support NLB/CLB for the moment, only ALB. continue } - subnetIDs = append(subnetIDs, aws.StringValue(subnetsByAZ[zone].SubnetId)) + subnetIDs = append(subnetIDs, aws.ToString(subnetsByAZ[zone].SubnetId)) } return subnetIDs, nil @@ -1762,15 +1737,15 @@ func parseStringSliceAnnotation(annotations map[string]string, annotation string return true } -func (c *Cloud) getLoadBalancerSubnets(service *v1.Service, internalELB bool) ([]string, error) { +func (c *Cloud) getLoadBalancerSubnets(ctx context.Context, service *v1.Service, internalELB bool) ([]string, error) { var rawSubnetNameOrIDs []string if exists := parseStringSliceAnnotation(service.Annotations, ServiceAnnotationLoadBalancerSubnets, &rawSubnetNameOrIDs); exists { - return c.resolveSubnetNameOrIDs(rawSubnetNameOrIDs) + return c.resolveSubnetNameOrIDs(ctx, rawSubnetNameOrIDs) } - return c.findELBSubnets(internalELB) + return c.findELBSubnets(ctx, internalELB) } -func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, error) { +func (c *Cloud) resolveSubnetNameOrIDs(ctx context.Context, subnetNameOrIDs []string) ([]string, error) { var subnetIDs []string var subnetNames []string if len(subnetNameOrIDs) == 0 { @@ -1783,12 +1758,12 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro subnetNames = append(subnetNames, nameOrID) } } - var resolvedSubnets []*ec2.Subnet + var resolvedSubnets []ec2types.Subnet if len(subnetIDs) > 0 { req := &ec2.DescribeSubnetsInput{ - SubnetIds: aws.StringSlice(subnetIDs), + SubnetIds: subnetIDs, } - subnets, err := c.ec2.DescribeSubnets(req) + subnets, err := c.ec2.DescribeSubnets(ctx, req) if err != nil { return []string{}, err } @@ -1796,18 +1771,18 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro } if len(subnetNames) > 0 { req := &ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{ + Filters: []ec2types.Filter{ { Name: aws.String("tag:Name"), - Values: aws.StringSlice(subnetNames), + Values: subnetNames, }, { Name: aws.String("vpc-id"), - Values: aws.StringSlice([]string{c.vpcID}), + Values: []string{c.vpcID}, }, }, } - subnets, err := c.ec2.DescribeSubnets(req) + subnets, err := c.ec2.DescribeSubnets(ctx, req) if err != nil { return []string{}, err } @@ -1818,17 +1793,17 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro } var subnets []string for _, subnet := range resolvedSubnets { - subnets = append(subnets, aws.StringValue(subnet.SubnetId)) + subnets = append(subnets, aws.ToString(subnet.SubnetId)) } return subnets, nil } -func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { - var subnetTable *ec2.RouteTable +func isSubnetPublic(rt []ec2types.RouteTable, subnetID string) (bool, error) { + var subnetTable *ec2types.RouteTable for _, table := range rt { for _, assoc := range table.Associations { - if aws.StringValue(assoc.SubnetId) == subnetID { - subnetTable = table + if aws.ToString(assoc.SubnetId) == subnetID { + subnetTable = &table break } } @@ -1839,10 +1814,10 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { // associated with the VPC's main routing table. for _, table := range rt { for _, assoc := range table.Associations { - if aws.BoolValue(assoc.Main) == true { + if aws.ToBool(assoc.Main) == true { klog.V(4).Infof("Assuming implicit use of main routing table %s for %s", - aws.StringValue(table.RouteTableId), subnetID) - subnetTable = table + aws.ToString(table.RouteTableId), subnetID) + subnetTable = &table break } } @@ -1860,7 +1835,7 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { // from the default in-subnet route which is called "local" // or other virtual gateway (starting with vgv) // or vpc peering connections (starting with pcx). - if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw") { + if strings.HasPrefix(aws.ToString(route.GatewayId), "igw") { return true, nil } } @@ -1869,8 +1844,8 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { } type portSets struct { - names sets.String - numbers sets.Int64 + names sets.Set[string] + numbers sets.Set[int32] } // getPortSets returns a portSets structure representing port names and numbers @@ -1879,8 +1854,8 @@ type portSets struct { func getPortSets(annotation string) (ports *portSets) { if annotation != "" && annotation != "*" { ports = &portSets{ - sets.NewString(), - sets.NewInt64(), + sets.New[string](), + sets.New[int32](), } portStringSlice := strings.Split(annotation, ",") for _, item := range portStringSlice { @@ -1888,7 +1863,7 @@ func getPortSets(annotation string) (ports *portSets) { if err != nil { ports.names.Insert(item) } else { - ports.numbers.Insert(int64(port)) + ports.numbers.Insert(int32(port)) } } } @@ -1913,7 +1888,7 @@ func getSGListFromAnnotation(annotatedSG string) []string { // Extra groups can be specified via annotation, as can extra tags for any // new groups. The annotation "ServiceAnnotationLoadBalancerSecurityGroups" allows for // setting the security groups specified. -func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, loadBalancerName string, annotations map[string]string) ([]string, bool, error) { +func (c *Cloud) buildELBSecurityGroupList(ctx context.Context, serviceName types.NamespacedName, loadBalancerName string, annotations map[string]string) ([]string, bool, error) { var err error var securityGroupID string // We do not want to make changes to a Global defined SG @@ -1929,7 +1904,7 @@ func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, load // Create a security group for the load balancer sgName := "k8s-elb-" + loadBalancerName sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", loadBalancerName, serviceName) - securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription, getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags)) + securityGroupID, err = c.ensureSecurityGroup(ctx, sgName, sgDescription, getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags)) if err != nil { klog.Errorf("Error creating load balancer security group: %q", err) return nil, setupSg, err @@ -1978,16 +1953,16 @@ func (c *Cloud) sortELBSecurityGroupList(securityGroupIDs []string, annotations // buildListener creates a new listener from the given port, adding an SSL certificate // if indicated by the appropriate annotations. -func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts *portSets) (*elb.Listener, error) { - loadBalancerPort := int64(port.Port) +func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts *portSets) (elbtypes.Listener, error) { + loadBalancerPort := port.Port portName := strings.ToLower(port.Name) - instancePort := int64(port.NodePort) + instancePort := port.NodePort protocol := strings.ToLower(string(port.Protocol)) instanceProtocol := protocol - listener := &elb.Listener{} + listener := elbtypes.Listener{} listener.InstancePort = &instancePort - listener.LoadBalancerPort = &loadBalancerPort + listener.LoadBalancerPort = loadBalancerPort certID := annotations[ServiceAnnotationLoadBalancerCertificate] if certID != "" && (sslPorts == nil || sslPorts.numbers.Has(loadBalancerPort) || sslPorts.names.Has(portName)) { instanceProtocol = annotations[ServiceAnnotationLoadBalancerBEProtocol] @@ -1997,7 +1972,7 @@ func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts } else { protocol = backendProtocolMapping[instanceProtocol] if protocol == "" { - return nil, fmt.Errorf("Invalid backend protocol %s for %s in %s", instanceProtocol, certID, ServiceAnnotationLoadBalancerBEProtocol) + return elbtypes.Listener{}, fmt.Errorf("Invalid backend protocol %s for %s in %s", instanceProtocol, certID, ServiceAnnotationLoadBalancerBEProtocol) } } listener.SSLCertificateId = &certID @@ -2012,13 +1987,13 @@ func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts return listener, nil } -func (c *Cloud) getSubnetCidrs(subnetIDs []string) ([]string, error) { +func (c *Cloud) getSubnetCidrs(ctx context.Context, subnetIDs []string) ([]string, error) { request := &ec2.DescribeSubnetsInput{} for _, subnetID := range subnetIDs { - request.SubnetIds = append(request.SubnetIds, aws.String(subnetID)) + request.SubnetIds = append(request.SubnetIds, subnetID) } - subnets, err := c.ec2.DescribeSubnets(request) + subnets, err := c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error querying Subnet for ELB: %q", err) } @@ -2028,7 +2003,7 @@ func (c *Cloud) getSubnetCidrs(subnetIDs []string) ([]string, error) { cidrs := make([]string, 0, len(subnets)) for _, subnet := range subnets { - cidrs = append(cidrs, aws.StringValue(subnet.CidrBlock)) + cidrs = append(cidrs, aws.ToString(subnet.CidrBlock)) } return cidrs, nil } @@ -2041,9 +2016,10 @@ func parseStringAnnotation(annotations map[string]string, annotation string, val return false } -func parseInt64Annotation(annotations map[string]string, annotation string, value *int64) (bool, error) { +func parseInt32Annotation(annotations map[string]string, annotation string, value *int32) (bool, error) { if v, ok := annotations[annotation]; ok { - parsed, err := strconv.ParseInt(v, 10, 0) + parsed64, err := strconv.ParseInt(v, 10, 0) + parsed := int32(parsed64) if err != nil { return true, fmt.Errorf("failed to parse annotation %v=%v", annotation, v) } @@ -2057,7 +2033,7 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo hc := healthCheckConfig{ Port: defaultHealthCheckPort, Path: defaultHealthCheckPath, - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: defaultNlbHealthCheckInterval, Timeout: defaultNlbHealthCheckTimeout, HealthyThreshold: defaultNlbHealthCheckThreshold, @@ -2068,20 +2044,22 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo hc = healthCheckConfig{ Port: strconv.Itoa(int(port)), Path: path, - Protocol: elbv2.ProtocolEnumHttp, + Protocol: elbv2types.ProtocolEnumHttp, Interval: 10, Timeout: 10, HealthyThreshold: 2, UnhealthyThreshold: 2, } } - if parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckProtocol, &hc.Protocol) { - hc.Protocol = strings.ToUpper(hc.Protocol) + + var protocolStr string = string(hc.Protocol) + if parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckProtocol, &protocolStr) { + hc.Protocol = elbv2types.ProtocolEnum(strings.ToUpper(protocolStr)) } switch hc.Protocol { - case elbv2.ProtocolEnumHttp, elbv2.ProtocolEnumHttps: + case elbv2types.ProtocolEnumHttp, elbv2types.ProtocolEnumHttps: parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckPath, &hc.Path) - case elbv2.ProtocolEnumTcp: + case elbv2types.ProtocolEnumTcp: hc.Path = "" default: return healthCheckConfig{}, fmt.Errorf("Unsupported health check protocol %v", hc.Protocol) @@ -2089,16 +2067,16 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckPort, &hc.Port) - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCInterval, &hc.Interval); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCInterval, &hc.Interval); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCTimeout, &hc.Timeout); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCTimeout, &hc.Timeout); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCHealthyThreshold, &hc.HealthyThreshold); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCHealthyThreshold, &hc.HealthyThreshold); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCUnhealthyThreshold, &hc.UnhealthyThreshold); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCUnhealthyThreshold, &hc.UnhealthyThreshold); err != nil { return healthCheckConfig{}, err } @@ -2131,7 +2109,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, err } // Figure out what mappings we want on the load balancer - listeners := []*elb.Listener{} + listeners := []elbtypes.Listener{} v2Mappings := []nlbPortMapping{} sslPorts := getPortSets(annotations[ServiceAnnotationLoadBalancerSSLPorts]) @@ -2147,10 +2125,10 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if isNLB(annotations) { portMapping := nlbPortMapping{ - FrontendPort: int64(port.Port), - FrontendProtocol: string(port.Protocol), - TrafficPort: int64(port.NodePort), - TrafficProtocol: string(port.Protocol), + FrontendPort: int32(port.Port), + FrontendProtocol: elbv2types.ProtocolEnum(port.Protocol), + TrafficPort: int32(port.NodePort), + TrafficProtocol: elbv2types.ProtocolEnum(port.Protocol), } var err error if portMapping.HealthCheckConfig, err = c.buildNLBHealthCheckConfiguration(apiService); err != nil { @@ -2158,13 +2136,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } certificateARN := annotations[ServiceAnnotationLoadBalancerCertificate] - if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) { - portMapping.FrontendProtocol = elbv2.ProtocolEnumTls + if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(port.Port) || sslPorts.names.Has(port.Name)) { + portMapping.FrontendProtocol = elbv2types.ProtocolEnumTls portMapping.SSLCertificateARN = certificateARN portMapping.SSLPolicy = annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy] if backendProtocol := annotations[ServiceAnnotationLoadBalancerBEProtocol]; backendProtocol == "ssl" { - portMapping.TrafficProtocol = elbv2.ProtocolEnumTls + portMapping.TrafficProtocol = elbv2types.ProtocolEnumTls } } @@ -2182,7 +2160,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, fmt.Errorf("LoadBalancerIP cannot be specified for AWS ELB") } - instances, err := c.findInstancesForELB(nodes, annotations) + instances, err := c.findInstancesForELB(ctx, nodes, annotations) if err != nil { return nil, err } @@ -2203,7 +2181,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if isNLB(annotations) { // Find the subnets that the ELB will live in - discoveredSubnetIDs, err := c.getLoadBalancerSubnets(apiService, internalELB) + discoveredSubnetIDs, err := c.getLoadBalancerSubnets(ctx, apiService, internalELB) if err != nil { klog.Errorf("Error listing subnets in VPC: %q", err) return nil, err @@ -2221,7 +2199,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS instanceIDs = append(instanceIDs, string(id)) } - v2LoadBalancer, err := c.ensureLoadBalancerv2( + v2LoadBalancer, err := c.ensureLoadBalancerv2(ctx, serviceName, loadBalancerName, v2Mappings, @@ -2243,7 +2221,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if len(ensuredSubnetIDs) == 0 { return nil, fmt.Errorf("did not find ensured subnets on LB %s", loadBalancerName) } - subnetCidrs, err = c.getSubnetCidrs(ensuredSubnetIDs) + subnetCidrs, err = c.getSubnetCidrs(ctx, ensuredSubnetIDs) if err != nil { klog.Errorf("Error getting subnet cidrs: %q", err) return nil, err @@ -2257,7 +2235,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS sourceRangeCidrs = append(sourceRangeCidrs, "0.0.0.0/0") } - err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, subnetCidrs, sourceRangeCidrs, v2Mappings) + err = c.updateInstanceSecurityGroupsForNLB(ctx, loadBalancerName, instances, subnetCidrs, sourceRangeCidrs, v2Mappings) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err @@ -2281,24 +2259,24 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } // Some load balancer attributes are required, so defaults are set. These can be overridden by annotations. - loadBalancerAttributes := &elb.LoadBalancerAttributes{ - AccessLog: &elb.AccessLog{Enabled: aws.Bool(false)}, - ConnectionDraining: &elb.ConnectionDraining{Enabled: aws.Bool(false)}, - ConnectionSettings: &elb.ConnectionSettings{IdleTimeout: aws.Int64(60)}, - CrossZoneLoadBalancing: &elb.CrossZoneLoadBalancing{Enabled: aws.Bool(false)}, + loadBalancerAttributes := &elbtypes.LoadBalancerAttributes{ + AccessLog: &elbtypes.AccessLog{Enabled: false}, + ConnectionDraining: &elbtypes.ConnectionDraining{Enabled: false}, + ConnectionSettings: &elbtypes.ConnectionSettings{IdleTimeout: aws.Int32(60)}, + CrossZoneLoadBalancing: &elbtypes.CrossZoneLoadBalancing{Enabled: false}, } // Determine if an access log emit interval has been specified accessLogEmitIntervalAnnotation := annotations[ServiceAnnotationLoadBalancerAccessLogEmitInterval] if accessLogEmitIntervalAnnotation != "" { - accessLogEmitInterval, err := strconv.ParseInt(accessLogEmitIntervalAnnotation, 10, 64) + accessLogEmitInterval, err := strconv.ParseInt(accessLogEmitIntervalAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerAccessLogEmitInterval, accessLogEmitIntervalAnnotation, ) } - loadBalancerAttributes.AccessLog.EmitInterval = &accessLogEmitInterval + loadBalancerAttributes.AccessLog.EmitInterval = aws.Int32(int32(accessLogEmitInterval)) } // Determine if access log enabled/disabled has been specified @@ -2311,7 +2289,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS accessLogEnabledAnnotation, ) } - loadBalancerAttributes.AccessLog.Enabled = &accessLogEnabled + loadBalancerAttributes.AccessLog.Enabled = accessLogEnabled } // Determine if access log s3 bucket name has been specified @@ -2336,33 +2314,33 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS connectionDrainingEnabledAnnotation, ) } - loadBalancerAttributes.ConnectionDraining.Enabled = &connectionDrainingEnabled + loadBalancerAttributes.ConnectionDraining.Enabled = connectionDrainingEnabled } // Determine if connection draining timeout has been specified connectionDrainingTimeoutAnnotation := annotations[ServiceAnnotationLoadBalancerConnectionDrainingTimeout] if connectionDrainingTimeoutAnnotation != "" { - connectionDrainingTimeout, err := strconv.ParseInt(connectionDrainingTimeoutAnnotation, 10, 64) + connectionDrainingTimeout, err := strconv.ParseInt(connectionDrainingTimeoutAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerConnectionDrainingTimeout, connectionDrainingTimeoutAnnotation, ) } - loadBalancerAttributes.ConnectionDraining.Timeout = &connectionDrainingTimeout + loadBalancerAttributes.ConnectionDraining.Timeout = aws.Int32(int32(connectionDrainingTimeout)) } // Determine if connection idle timeout has been specified connectionIdleTimeoutAnnotation := annotations[ServiceAnnotationLoadBalancerConnectionIdleTimeout] if connectionIdleTimeoutAnnotation != "" { - connectionIdleTimeout, err := strconv.ParseInt(connectionIdleTimeoutAnnotation, 10, 64) + connectionIdleTimeout, err := strconv.ParseInt(connectionIdleTimeoutAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerConnectionIdleTimeout, connectionIdleTimeoutAnnotation, ) } - loadBalancerAttributes.ConnectionSettings.IdleTimeout = &connectionIdleTimeout + loadBalancerAttributes.ConnectionSettings.IdleTimeout = aws.Int32(int32(connectionIdleTimeout)) } // Determine if cross zone load balancing enabled/disabled has been specified @@ -2375,11 +2353,11 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS crossZoneLoadBalancingEnabledAnnotation, ) } - loadBalancerAttributes.CrossZoneLoadBalancing.Enabled = &crossZoneLoadBalancingEnabled + loadBalancerAttributes.CrossZoneLoadBalancing.Enabled = crossZoneLoadBalancingEnabled } // Find the subnets that the ELB will live in - subnetIDs, err := c.getLoadBalancerSubnets(apiService, internalELB) + subnetIDs, err := c.getLoadBalancerSubnets(ctx, apiService, internalELB) if err != nil { klog.Errorf("Error listing subnets in VPC: %q", err) return nil, err @@ -2392,7 +2370,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, apiService) serviceName := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name} - securityGroupIDs, setupSg, err := c.buildELBSecurityGroupList(serviceName, loadBalancerName, annotations) + securityGroupIDs, setupSg, err := c.buildELBSecurityGroupList(ctx, serviceName, loadBalancerName, annotations) if err != nil { return nil, err } @@ -2401,19 +2379,18 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } if setupSg { - ec2SourceRanges := []*ec2.IpRange{} + ec2SourceRanges := []ec2types.IpRange{} for _, sourceRange := range sourceRanges.StringSlice() { - ec2SourceRanges = append(ec2SourceRanges, &ec2.IpRange{CidrIp: aws.String(sourceRange)}) + ec2SourceRanges = append(ec2SourceRanges, ec2types.IpRange{CidrIp: aws.String(sourceRange)}) } permissions := NewIPPermissionSet() for _, port := range apiService.Spec.Ports { - portInt64 := int64(port.Port) protocol := strings.ToLower(string(port.Protocol)) - permission := &ec2.IpPermission{} - permission.FromPort = &portInt64 - permission.ToPort = &portInt64 + permission := ec2types.IpPermission{} + permission.FromPort = aws.Int32(port.Port) + permission.ToPort = aws.Int32(port.Port) permission.IpRanges = ec2SourceRanges permission.IpProtocol = &protocol @@ -2422,23 +2399,23 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS // Allow ICMP fragmentation packets, important for MTU discovery { - permission := &ec2.IpPermission{ + permission := ec2types.IpPermission{ IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), + FromPort: aws.Int32(3), + ToPort: aws.Int32(4), IpRanges: ec2SourceRanges, } permissions.Insert(permission) } - _, err = c.setSecurityGroupIngress(securityGroupIDs[0], permissions) + _, err = c.setSecurityGroupIngress(ctx, securityGroupIDs[0], permissions) if err != nil { return nil, err } } // Build the load balancer itself - loadBalancer, err := c.ensureLoadBalancer( + loadBalancer, err := c.ensureLoadBalancer(ctx, serviceName, loadBalancerName, listeners, @@ -2454,13 +2431,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } if sslPolicyName, ok := annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy]; ok { - err := c.ensureSSLNegotiationPolicy(loadBalancer, sslPolicyName) + err := c.ensureSSLNegotiationPolicy(ctx, loadBalancer, sslPolicyName) if err != nil { return nil, err } for _, port := range c.getLoadBalancerTLSPorts(loadBalancer) { - err := c.setSSLNegotiationPolicy(loadBalancerName, sslPolicyName, port) + err := c.setSSLNegotiationPolicy(ctx, loadBalancerName, sslPolicyName, port) if err != nil { return nil, err } @@ -2481,7 +2458,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if annotations[ServiceAnnotationLoadBalancerHealthCheckPort] == defaultHealthCheckPort { healthCheckNodePort = tcpHealthCheckPort } - err = c.ensureLoadBalancerHealthCheck(loadBalancer, "HTTP", healthCheckNodePort, path, annotations) + err = c.ensureLoadBalancerHealthCheck(ctx, loadBalancer, "HTTP", healthCheckNodePort, path, annotations) if err != nil { return nil, fmt.Errorf("Failed to ensure health check for localized service %v on node port %v: %q", loadBalancerName, healthCheckNodePort, err) } @@ -2495,25 +2472,25 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS hcProtocol = "TCP" } // there must be no path on TCP health check - err = c.ensureLoadBalancerHealthCheck(loadBalancer, hcProtocol, tcpHealthCheckPort, "", annotations) + err = c.ensureLoadBalancerHealthCheck(ctx, loadBalancer, hcProtocol, tcpHealthCheckPort, "", annotations) if err != nil { return nil, err } } - err = c.updateInstanceSecurityGroupsForLoadBalancer(loadBalancer, instances, annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, loadBalancer, instances, annotations) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err } - err = c.ensureLoadBalancerInstances(aws.StringValue(loadBalancer.LoadBalancerName), loadBalancer.Instances, instances) + err = c.ensureLoadBalancerInstances(ctx, aws.ToString(loadBalancer.LoadBalancerName), loadBalancer.Instances, instances) if err != nil { klog.Warningf("Error registering instances with the load balancer: %q", err) return nil, err } - klog.V(1).Infof("Loadbalancer %s (%v) has DNS name %s", loadBalancerName, serviceName, aws.StringValue(loadBalancer.DNSName)) + klog.V(1).Infof("Loadbalancer %s (%v) has DNS name %s", loadBalancerName, serviceName, aws.ToString(loadBalancer.DNSName)) // TODO: Wait for creation? @@ -2529,7 +2506,7 @@ func (c *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, service loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return nil, false, err } @@ -2539,7 +2516,7 @@ func (c *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, service return v2toStatus(lb), true, nil } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return nil, false, err } @@ -2558,19 +2535,19 @@ func (c *Cloud) GetLoadBalancerName(ctx context.Context, clusterName string, ser return cloudprovider.DefaultLoadBalancerName(service) } -func toStatus(lb *elb.LoadBalancerDescription) *v1.LoadBalancerStatus { +func toStatus(lb *elbtypes.LoadBalancerDescription) *v1.LoadBalancerStatus { status := &v1.LoadBalancerStatus{} - if aws.StringValue(lb.DNSName) != "" { + if aws.ToString(lb.DNSName) != "" { var ingress v1.LoadBalancerIngress - ingress.Hostname = aws.StringValue(lb.DNSName) + ingress.Hostname = aws.ToString(lb.DNSName) status.Ingress = []v1.LoadBalancerIngress{ingress} } return status } -func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { +func v2toStatus(lb *elbv2types.LoadBalancer) *v1.LoadBalancerStatus { status := &v1.LoadBalancerStatus{} if lb == nil { klog.Error("[BUG] v2toStatus got nil input, this is a Kubernetes bug, please report") @@ -2578,10 +2555,10 @@ func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { } // We check for Active or Provisioning, the only successful statuses - if aws.StringValue(lb.DNSName) != "" && (aws.StringValue(lb.State.Code) == elbv2.LoadBalancerStateEnumActive || - aws.StringValue(lb.State.Code) == elbv2.LoadBalancerStateEnumProvisioning) { + if aws.ToString(lb.DNSName) != "" && (lb.State.Code == elbv2types.LoadBalancerStateEnumActive || + lb.State.Code == elbv2types.LoadBalancerStateEnumProvisioning) { var ingress v1.LoadBalancerIngress - ingress.Hostname = aws.StringValue(lb.DNSName) + ingress.Hostname = aws.ToString(lb.DNSName) status.Ingress = []v1.LoadBalancerIngress{ingress} } @@ -2592,13 +2569,13 @@ func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { // We only create instances with one security group, so we don't expect multiple security groups. // However, if there are multiple security groups, we will choose the one tagged with our cluster filter. // Otherwise we will return an error. -func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups map[string]*ec2.SecurityGroup) (*ec2.GroupIdentifier, error) { - instanceID := aws.StringValue(instance.InstanceId) +func findSecurityGroupForInstance(instance *ec2types.Instance, taggedSecurityGroups map[string]*ec2types.SecurityGroup) (*ec2types.GroupIdentifier, error) { + instanceID := aws.ToString(instance.InstanceId) - var tagged []*ec2.GroupIdentifier - var untagged []*ec2.GroupIdentifier + var tagged []ec2types.GroupIdentifier + var untagged []ec2types.GroupIdentifier for _, group := range instance.SecurityGroups { - groupID := aws.StringValue(group.GroupId) + groupID := aws.ToString(group.GroupId) if groupID == "" { klog.Warningf("Ignoring security group without id for instance %q: %v", instanceID, group) continue @@ -2621,7 +2598,7 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m } return nil, fmt.Errorf("Multiple tagged security groups found for instance %s; ensure only the k8s security group is tagged; the tagged groups were %v", instanceID, taggedGroups) } - return tagged[0], nil + return &tagged[0], nil } if len(untagged) > 0 { @@ -2629,7 +2606,7 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m if len(untagged) != 1 { return nil, fmt.Errorf("Multiple untagged security groups found for instance %s; ensure the k8s security group is tagged", instanceID) } - return untagged[0], nil + return &untagged[0], nil } klog.Warningf("No security group found for instance %q", instanceID) @@ -2637,52 +2614,52 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m } // Return all the security groups that are tagged as being part of our cluster -func (c *Cloud) getTaggedSecurityGroups() (map[string]*ec2.SecurityGroup, error) { +func (c *Cloud) getTaggedSecurityGroups(ctx context.Context) (map[string]*ec2types.SecurityGroup, error) { request := &ec2.DescribeSecurityGroupsInput{} - groups, err := c.ec2.DescribeSecurityGroups(request) + groups, err := c.ec2.DescribeSecurityGroups(ctx, request) if err != nil { return nil, fmt.Errorf("error querying security groups: %q", err) } - m := make(map[string]*ec2.SecurityGroup) + m := make(map[string]*ec2types.SecurityGroup) for _, group := range groups { if !c.tagging.hasClusterTag(group.Tags) { continue } - id := aws.StringValue(group.GroupId) + id := aws.ToString(group.GroupId) if id == "" { klog.Warningf("Ignoring group without id: %v", group) continue } - m[id] = group + m[id] = &group } return m, nil } // Open security group ingress rules on the instances so that the load balancer can talk to them // Will also remove any security groups ingress rules for the load balancer that are _not_ needed for allInstances -func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[InstanceID]*ec2.Instance, annotations map[string]string) error { +func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(ctx context.Context, lb *elbtypes.LoadBalancerDescription, instances map[InstanceID]*ec2types.Instance, annotations map[string]string) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } // Determine the load balancer security group id - lbSecurityGroupIDs := aws.StringValueSlice(lb.SecurityGroups) + lbSecurityGroupIDs := lb.SecurityGroups if len(lbSecurityGroupIDs) == 0 { - return fmt.Errorf("could not determine security group for load balancer: %s", aws.StringValue(lb.LoadBalancerName)) + return fmt.Errorf("could not determine security group for load balancer: %s", aws.ToString(lb.LoadBalancerName)) } c.sortELBSecurityGroupList(lbSecurityGroupIDs, annotations) loadBalancerSecurityGroupID := lbSecurityGroupIDs[0] // Get the actual list of groups that allow ingress from the load-balancer - var actualGroups []*ec2.SecurityGroup + var actualGroups []*ec2types.SecurityGroup { describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ + describeRequest.Filters = []ec2types.Filter{ newEc2Filter("ip-permission.group-id", loadBalancerSecurityGroupID), } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) + response, err := c.ec2.DescribeSecurityGroups(ctx, describeRequest) if err != nil { return fmt.Errorf("error querying security groups for ELB: %q", err) } @@ -2690,11 +2667,11 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer if !c.tagging.hasClusterTag(sg.Tags) { continue } - actualGroups = append(actualGroups, sg) + actualGroups = append(actualGroups, &sg) } } - taggedSecurityGroups, err := c.getTaggedSecurityGroups() + taggedSecurityGroups, err := c.getTaggedSecurityGroups(ctx) if err != nil { return fmt.Errorf("error querying for tagged security groups: %q", err) } @@ -2715,10 +2692,10 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer } if securityGroup == nil { - klog.Warning("Ignoring instance without security group: ", aws.StringValue(instance.InstanceId)) + klog.Warning("Ignoring instance without security group: ", aws.ToString(instance.InstanceId)) continue } - id := aws.StringValue(securityGroup.GroupId) + id := aws.ToString(securityGroup.GroupId) if id == "" { klog.Warningf("found security group without id: %v", securityGroup) continue @@ -2729,7 +2706,7 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer // Compare to actual groups for _, actualGroup := range actualGroups { - actualGroupID := aws.StringValue(actualGroup.GroupId) + actualGroupID := aws.ToString(actualGroup.GroupId) if actualGroupID == "" { klog.Warning("Ignoring group without ID: ", actualGroup) continue @@ -2751,19 +2728,19 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer } else { klog.V(2).Infof("Removing rule for traffic from the load balancer (%s) to instance (%s)", loadBalancerSecurityGroupID, instanceSecurityGroupID) } - sourceGroupID := &ec2.UserIdGroupPair{} + sourceGroupID := ec2types.UserIdGroupPair{} sourceGroupID.GroupId = &loadBalancerSecurityGroupID allProtocols := "-1" - permission := &ec2.IpPermission{} + permission := ec2types.IpPermission{} permission.IpProtocol = &allProtocols - permission.UserIdGroupPairs = []*ec2.UserIdGroupPair{sourceGroupID} + permission.UserIdGroupPairs = []ec2types.UserIdGroupPair{sourceGroupID} - permissions := []*ec2.IpPermission{permission} + permissions := []ec2types.IpPermission{permission} if add { - changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, permissions) + changed, err := c.addSecurityGroupIngress(ctx, instanceSecurityGroupID, permissions) if err != nil { return err } @@ -2771,7 +2748,7 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) } } else { - changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, permissions) + changed, err := c.removeSecurityGroupIngress(ctx, instanceSecurityGroupID, permissions) if err != nil { return err } @@ -2792,7 +2769,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return err } @@ -2813,14 +2790,14 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin // * Clean up SecurityGroupRules { - targetGroups, err := c.elbv2.DescribeTargetGroups( + targetGroups, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{LoadBalancerArn: lb.LoadBalancerArn}, ) if err != nil { return fmt.Errorf("error listing target groups before deleting load balancer: %q", err) } - _, err = c.elbv2.DeleteLoadBalancer( + _, err = c.elbv2.DeleteLoadBalancer(ctx, &elbv2.DeleteLoadBalancerInput{LoadBalancerArn: lb.LoadBalancerArn}, ) if err != nil { @@ -2828,7 +2805,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } for _, group := range targetGroups.TargetGroups { - _, err := c.elbv2.DeleteTargetGroup( + _, err := c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{TargetGroupArn: group.TargetGroupArn}, ) if err != nil { @@ -2837,10 +2814,10 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } } - return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil, nil) + return c.updateInstanceSecurityGroupsForNLB(ctx, loadBalancerName, nil, nil, nil, nil) } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return err } @@ -2852,7 +2829,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin { // De-authorize the load balancer security group from the instances security group - err = c.updateInstanceSecurityGroupsForLoadBalancer(lb, nil, service.Annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, lb, nil, service.Annotations) if err != nil { klog.Errorf("Error deregistering load balancer from instance security groups: %q", err) return err @@ -2864,7 +2841,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin request := &elb.DeleteLoadBalancerInput{} request.LoadBalancerName = lb.LoadBalancerName - _, err = c.elb.DeleteLoadBalancer(request) + _, err = c.elb.DeleteLoadBalancer(ctx, request) if err != nil { // TODO: Check if error was because load balancer was concurrently deleted klog.Errorf("Error deleting load balancer: %q", err) @@ -2877,13 +2854,13 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin // Note that this is annoying: the load balancer disappears from the API immediately, but it is still // deleting in the background. We get a DependencyViolation until the load balancer has deleted itself - var loadBalancerSGs = aws.StringValueSlice(lb.SecurityGroups) + var loadBalancerSGs = lb.SecurityGroups describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ + describeRequest.Filters = []ec2types.Filter{ newEc2Filter("group-id", loadBalancerSGs...), } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) + response, err := c.ec2.DescribeSecurityGroups(ctx, describeRequest) if err != nil { return fmt.Errorf("error querying security groups for ELB: %q", err) } @@ -2900,7 +2877,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } for _, sg := range response { - sgID := aws.StringValue(sg.GroupId) + sgID := aws.ToString(sg.GroupId) if sgID == c.cfg.Global.ElbSecurityGroup { //We don't want to delete a security group that was defined in the Cloud Configuration. @@ -2931,13 +2908,14 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin for securityGroupID := range securityGroupIDs { request := &ec2.DeleteSecurityGroupInput{} request.GroupId = &securityGroupID - _, err := c.ec2.DeleteSecurityGroup(request) + _, err := c.ec2.DeleteSecurityGroup(ctx, request) if err == nil { delete(securityGroupIDs, securityGroupID) } else { ignore := false - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "DependencyViolation" { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "DependencyViolation" { klog.V(2).Infof("Ignoring DependencyViolation while deleting load-balancer security group (%s), assuming because LB is in process of deleting", securityGroupID) ignore = true } @@ -2976,13 +2954,13 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv if isLBExternal(service.Annotations) { return cloudprovider.ImplementedElsewhere } - instances, err := c.findInstancesForELB(nodes, service.Annotations) + instances, err := c.findInstancesForELB(ctx, nodes, service.Annotations) if err != nil { return err } loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return err } @@ -2992,7 +2970,7 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv _, err = c.EnsureLoadBalancer(ctx, clusterName, service, nodes) return err } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return err } @@ -3002,25 +2980,25 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv } if sslPolicyName, ok := service.Annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy]; ok { - err := c.ensureSSLNegotiationPolicy(lb, sslPolicyName) + err := c.ensureSSLNegotiationPolicy(ctx, lb, sslPolicyName) if err != nil { return err } for _, port := range c.getLoadBalancerTLSPorts(lb) { - err := c.setSSLNegotiationPolicy(loadBalancerName, sslPolicyName, port) + err := c.setSSLNegotiationPolicy(ctx, loadBalancerName, sslPolicyName, port) if err != nil { return err } } } - err = c.ensureLoadBalancerInstances(aws.StringValue(lb.LoadBalancerName), lb.Instances, instances) + err = c.ensureLoadBalancerInstances(ctx, aws.ToString(lb.LoadBalancerName), lb.Instances, instances) if err != nil { klog.Warningf("Error registering/deregistering instances with the load balancer: %q", err) return err } - err = c.updateInstanceSecurityGroupsForLoadBalancer(lb, instances, service.Annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, lb, instances, service.Annotations) if err != nil { return err } @@ -3029,8 +3007,8 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv } // Returns the instance with the specified ID -func (c *Cloud) getInstanceByID(instanceID string) (*ec2.Instance, error) { - instances, err := c.getInstancesByIDs([]*string{&instanceID}) +func (c *Cloud) getInstanceByID(ctx context.Context, instanceID string) (*ec2types.Instance, error) { + instances, err := c.getInstancesByIDs(ctx, []string{instanceID}) if err != nil { return nil, err } @@ -3045,8 +3023,8 @@ func (c *Cloud) getInstanceByID(instanceID string) (*ec2.Instance, error) { return instances[instanceID], nil } -func (c *Cloud) getInstancesByIDs(instanceIDs []*string) (map[string]*ec2.Instance, error) { - instancesByID := make(map[string]*ec2.Instance) +func (c *Cloud) getInstancesByIDs(ctx context.Context, instanceIDs []string) (map[string]*ec2types.Instance, error) { + instancesByID := make(map[string]*ec2types.Instance) if len(instanceIDs) == 0 { return instancesByID, nil } @@ -3055,46 +3033,45 @@ func (c *Cloud) getInstancesByIDs(instanceIDs []*string) (map[string]*ec2.Instan InstanceIds: instanceIDs, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return nil, err } for _, instance := range instances { - instanceID := aws.StringValue(instance.InstanceId) + instanceID := aws.ToString(instance.InstanceId) if instanceID == "" { continue } - instancesByID[instanceID] = instance + instancesByID[instanceID] = &instance } return instancesByID, nil } -func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([]*ec2.Instance, error) { - names := aws.StringSlice(nodeNames) - ec2Instances := []*ec2.Instance{} +func (c *Cloud) getInstancesByNodeNames(ctx context.Context, nodeNames []string, states ...string) ([]*ec2types.Instance, error) { + ec2Instances := []*ec2types.Instance{} - for i := 0; i < len(names); i += filterNodeLimit { + for i := 0; i < len(nodeNames); i += filterNodeLimit { end := i + filterNodeLimit - if end > len(names) { - end = len(names) + if end > len(nodeNames) { + end = len(nodeNames) } - nameSlice := names[i:end] + nameSlice := nodeNames[i:end] - nodeNameFilter := &ec2.Filter{ + nodeNameFilter := ec2types.Filter{ Name: aws.String("private-dns-name"), Values: nameSlice, } - filters := []*ec2.Filter{nodeNameFilter} + filters := []ec2types.Filter{nodeNameFilter} if len(states) > 0 { filters = append(filters, newEc2Filter("instance-state-name", states...)) } - instances, err := c.describeInstances(filters) + instances, err := c.describeInstances(ctx, filters) if err != nil { klog.V(2).Infof("Failed to describe instances %v", nodeNames) return nil, err @@ -3110,20 +3087,20 @@ func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([ } // TODO: Move to instanceCache -func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) { +func (c *Cloud) describeInstances(ctx context.Context, filters []ec2types.Filter) ([]*ec2types.Instance, error) { request := &ec2.DescribeInstancesInput{ Filters: filters, } - response, err := c.ec2.DescribeInstances(request) + response, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return nil, err } - var matches []*ec2.Instance + var matches []*ec2types.Instance for _, instance := range response { if c.tagging.hasClusterTag(instance.Tags) { - matches = append(matches, instance) + matches = append(matches, &instance) } } return matches, nil @@ -3163,29 +3140,29 @@ func mapNodeNameToPrivateDNSName(nodeName types.NodeName) string { // // Deprecated: use instanceIDToNodeName instead. See // mapNodeNameToPrivateDNSName for details. -func mapInstanceToNodeName(i *ec2.Instance) types.NodeName { - return types.NodeName(aws.StringValue(i.PrivateDnsName)) +func mapInstanceToNodeName(i *ec2types.Instance) types.NodeName { + return types.NodeName(aws.ToString(i.PrivateDnsName)) } var aliveFilter = []string{ - ec2.InstanceStateNamePending, - ec2.InstanceStateNameRunning, - ec2.InstanceStateNameShuttingDown, - ec2.InstanceStateNameStopping, - ec2.InstanceStateNameStopped, + string(ec2types.InstanceStateNamePending), + string(ec2types.InstanceStateNameRunning), + string(ec2types.InstanceStateNameShuttingDown), + string(ec2types.InstanceStateNameStopping), + string(ec2types.InstanceStateNameStopped), } // Returns the instance with the specified node name // Returns nil if it does not exist -func (c *Cloud) findInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, error) { +func (c *Cloud) findInstanceByNodeName(ctx context.Context, nodeName types.NodeName) (*ec2types.Instance, error) { privateDNSName := mapNodeNameToPrivateDNSName(nodeName) - filters := []*ec2.Filter{ + filters := []ec2types.Filter{ newEc2Filter("private-dns-name", privateDNSName), // exclude instances in "terminated" state newEc2Filter("instance-state-name", aliveFilter...), } - instances, err := c.describeInstances(filters) + instances, err := c.describeInstances(ctx, filters) if err != nil { return nil, err } @@ -3201,8 +3178,8 @@ func (c *Cloud) findInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, // Returns the instance with the specified node name // Like findInstanceByNodeName, but returns error if node not found -func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, error) { - var instance *ec2.Instance +func (c *Cloud) getInstanceByNodeName(ctx context.Context, nodeName types.NodeName) (*ec2types.Instance, error) { + var instance *ec2types.Instance // we leverage node cache to try to retrieve node's instance id first, as // get instance by instance id is way more efficient than by filters in @@ -3210,9 +3187,9 @@ func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, e awsID, err := c.nodeNameToInstanceID(nodeName) if err != nil { klog.V(3).Infof("Unable to convert node name %q to aws instanceID, fall back to findInstanceByNodeName: %v", nodeName, err) - instance, err = c.findInstanceByNodeName(nodeName) + instance, err = c.findInstanceByNodeName(ctx, nodeName) } else { - instance, err = c.getInstanceByID(string(awsID)) + instance, err = c.getInstanceByID(ctx, string(awsID)) } if err == nil && instance == nil { return nil, cloudprovider.InstanceNotFound @@ -3220,12 +3197,12 @@ func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, e return instance, err } -func (c *Cloud) getFullInstance(nodeName types.NodeName) (*awsInstance, *ec2.Instance, error) { +func (c *Cloud) getFullInstance(ctx context.Context, nodeName types.NodeName) (*awsInstance, *ec2types.Instance, error) { if nodeName == "" { - instance, err := c.getInstanceByID(c.selfAWSInstance.awsID) + instance, err := c.getInstanceByID(ctx, c.selfAWSInstance.awsID) return c.selfAWSInstance, instance, err } - instance, err := c.getInstanceByNodeName(nodeName) + instance, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return nil, nil, err } @@ -3315,7 +3292,7 @@ func checkProtocol(port v1.ServicePort, annotations map[string]string) error { return fmt.Errorf("Protocol %s not supported by LoadBalancer", port.Protocol) } -func getRegionFromMetadata(cfg config.CloudConfig, metadata config.EC2Metadata) (string, error) { +func getRegionFromMetadata(ctx context.Context, cfg config.CloudConfig, metadata config.EC2Metadata) (string, error) { // For backwards compatibility reasons, keeping this check to avoid breaking possible // cases where Zone was set to override the region configuration. Otherwise, fall back // to getting region the standard way. @@ -3326,5 +3303,5 @@ func getRegionFromMetadata(cfg config.CloudConfig, metadata config.EC2Metadata) return azToRegion(zone) } - return cfg.GetRegion(metadata) + return cfg.GetRegion(ctx, metadata) } diff --git a/pkg/providers/v1/aws_assumerole_provider.go b/pkg/providers/v1/aws_assumerole_provider.go deleted file mode 100644 index ad5a63b4c7..0000000000 --- a/pkg/providers/v1/aws_assumerole_provider.go +++ /dev/null @@ -1,62 +0,0 @@ -/* -Copyright 2014 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -const ( - invalidateCredsAfter = 1 * time.Second -) - -// assumeRoleProviderWithRateLimiting makes sure we call the underlying provider only -// once after `invalidateCredsAfter` period -type assumeRoleProviderWithRateLimiting struct { - provider credentials.Provider - invalidateCredsAfter time.Duration - sync.RWMutex - lastError error - lastValue credentials.Value - lastRetrieveTime time.Time -} - -func assumeRoleProvider(provider credentials.Provider) credentials.Provider { - return &assumeRoleProviderWithRateLimiting{provider: provider, - invalidateCredsAfter: invalidateCredsAfter} -} - -func (l *assumeRoleProviderWithRateLimiting) Retrieve() (credentials.Value, error) { - l.Lock() - defer l.Unlock() - if time.Since(l.lastRetrieveTime) < l.invalidateCredsAfter { - if l.lastError != nil { - return credentials.Value{}, l.lastError - } - return l.lastValue, nil - } - l.lastValue, l.lastError = l.provider.Retrieve() - l.lastRetrieveTime = time.Now() - return l.lastValue, l.lastError -} - -func (l *assumeRoleProviderWithRateLimiting) IsExpired() bool { - return l.provider.IsExpired() -} diff --git a/pkg/providers/v1/aws_assumerole_provider_test.go b/pkg/providers/v1/aws_assumerole_provider_test.go deleted file mode 100644 index db5af7355a..0000000000 --- a/pkg/providers/v1/aws_assumerole_provider_test.go +++ /dev/null @@ -1,132 +0,0 @@ -/* -Copyright 2014 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "fmt" - "reflect" - "sync" - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -func Test_assumeRoleProviderWithRateLimiting_Retrieve(t *testing.T) { - type fields struct { - provider credentials.Provider - invalidateCredsAfter time.Duration - RWMutex sync.RWMutex - lastError error - lastValue credentials.Value - lastRetrieveTime time.Time - } - tests := []struct { - name string - fields *fields - want credentials.Value - wantProviderCalled bool - sleepBeforeCallingProvider time.Duration - wantErr bool - wantErrString string - }{{ - name: "Call assume role provider and verify access ID returned", - fields: &fields{provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}}, - want: credentials.Value{AccessKeyID: "fakeID"}, - wantProviderCalled: true, - }, { - name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last value", - fields: &fields{ - provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}, - invalidateCredsAfter: 100 * time.Millisecond, - lastValue: credentials.Value{AccessKeyID: "fakeID1"}, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{AccessKeyID: "fakeID1"}, - wantProviderCalled: false, - sleepBeforeCallingProvider: 10 * time.Millisecond, - }, { - name: "Assume role provider returns an error when trying to assume a role", - fields: &fields{ - provider: &fakeAssumeRoleProvider{err: fmt.Errorf("can't assume fake role")}, - invalidateCredsAfter: 10 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - wantProviderCalled: true, - wantErr: true, - wantErrString: "can't assume fake role", - sleepBeforeCallingProvider: 15 * time.Millisecond, - }, { - name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last error value", - fields: &fields{ - provider: &fakeAssumeRoleProvider{}, - invalidateCredsAfter: 100 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{}, - wantProviderCalled: false, - wantErr: true, - wantErrString: "can't assume fake role", - }, { - name: "Delayed call to assume role API, should call the underlying provider", - fields: &fields{ - provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID2"}, - invalidateCredsAfter: 20 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{AccessKeyID: "fakeID2"}, - wantProviderCalled: true, - sleepBeforeCallingProvider: 25 * time.Millisecond, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - l := &assumeRoleProviderWithRateLimiting{ - provider: tt.fields.provider, - invalidateCredsAfter: tt.fields.invalidateCredsAfter, - lastError: tt.fields.lastError, - lastValue: tt.fields.lastValue, - lastRetrieveTime: tt.fields.lastRetrieveTime, - } - time.Sleep(tt.sleepBeforeCallingProvider) - got, err := l.Retrieve() - if (err != nil) != tt.wantErr && (tt.wantErr && reflect.DeepEqual(err, tt.wantErrString)) { - t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() got = %v, want %v", got, tt.want) - return - } - if tt.wantProviderCalled != tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled { - t.Errorf("provider called %v, want %v", tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled, tt.wantProviderCalled) - } - }) - } -} - -type fakeAssumeRoleProvider struct { - accesskeyID string - err error - providerCalled bool -} - -func (f *fakeAssumeRoleProvider) Retrieve() (credentials.Value, error) { - f.providerCalled = true - return credentials.Value{AccessKeyID: f.accesskeyID}, f.err -} - -func (f *fakeAssumeRoleProvider) IsExpired() bool { return true } diff --git a/pkg/providers/v1/aws_ec2.go b/pkg/providers/v1/aws_ec2.go index 5c9ce3f483..7ba9fb0034 100644 --- a/pkg/providers/v1/aws_ec2.go +++ b/pkg/providers/v1/aws_ec2.go @@ -17,34 +17,57 @@ limitations under the License. package aws import ( + "context" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) +// EC2API is an interface to satisfy the ec2.Client API. +// More details about this pattern: https://docs.aws.amazon.com/sdk-for-go/v2/developer-guide/unit-testing.html +type EC2API interface { + AuthorizeSecurityGroupIngress(ctx context.Context, params *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + CreateRoute(ctx context.Context, params *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) + CreateSecurityGroup(ctx context.Context, params *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + DeleteRoute(ctx context.Context, params *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) + DeleteSecurityGroup(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) + DeleteTags(ctx context.Context, params *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) + DescribeAvailabilityZones(ctx context.Context, params *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFuns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeInstanceTopology(ctx context.Context, params *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTopologyOutput, error) + DescribeNetworkInterfaces(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) + DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) + DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) + DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) + ModifyInstanceAttribute(ctx context.Context, params *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) +} + // awsSdkEC2 is an implementation of the EC2 interface, backed by aws-sdk-go type awsSdkEC2 struct { - ec2 ec2iface.EC2API + ec2 EC2API } // Implementation of EC2.Instances -func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) { +func (s *awsSdkEC2) DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) { // Instances are paged - results := []*ec2.Instance{} + results := []ec2types.Instance{} var nextToken *string requestTime := time.Now() if request.MaxResults == nil && len(request.InstanceIds) == 0 { // MaxResults must be set in order for pagination to work // MaxResults cannot be set with InstanceIds - request.MaxResults = aws.Int64(1000) + request.MaxResults = aws.Int32(1000) } for { - response, err := s.ec2.DescribeInstances(request) + response, err := s.ec2.DescribeInstances(ctx, request) if err != nil { recordAWSMetric("describe_instance", 0, err) return nil, fmt.Errorf("error listing AWS instances: %q", err) @@ -55,7 +78,7 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e } nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -65,23 +88,38 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e return results, nil } +func (s *awsSdkEC2) DescribeInstanceTopology(ctx context.Context, input *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) { + var topologies []ec2types.InstanceTopology + + paginator := ec2.NewDescribeInstanceTopologyPaginator(s.ec2, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + topologies = append(topologies, output.Instances...) + } + + return topologies, nil +} + // DescribeNetworkInterfaces describes network interface provided in the input. -func (s *awsSdkEC2) DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) { +func (s *awsSdkEC2) DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DescribeNetworkInterfaces(input) + resp, err := s.ec2.DescribeNetworkInterfaces(ctx, input) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("describe_network_interfaces", timeTaken, err) return resp, err } // Implements EC2.DescribeSecurityGroups -func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { +func (s *awsSdkEC2) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { // Security groups are paged - results := []*ec2.SecurityGroup{} + results := []ec2types.SecurityGroup{} var nextToken *string requestTime := time.Now() for { - response, err := s.ec2.DescribeSecurityGroups(request) + response, err := s.ec2.DescribeSecurityGroups(ctx, request) if err != nil { recordAWSMetric("describe_security_groups", 0, err) return nil, fmt.Errorf("error listing AWS security groups: %q", err) @@ -90,7 +128,7 @@ func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsIn results = append(results, response.SecurityGroups...) nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -100,62 +138,62 @@ func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsIn return results, nil } -func (s *awsSdkEC2) DescribeSubnets(request *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { +func (s *awsSdkEC2) DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) { // Subnets are not paged - response, err := s.ec2.DescribeSubnets(request) + response, err := s.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error listing AWS subnets: %q", err) } return response.Subnets, nil } -func (s *awsSdkEC2) DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) { +func (s *awsSdkEC2) DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) { // AZs are not paged - response, err := s.ec2.DescribeAvailabilityZones(request) + response, err := s.ec2.DescribeAvailabilityZones(ctx, request) if err != nil { return nil, fmt.Errorf("error listing AWS availability zones: %q", err) } return response.AvailabilityZones, err } -func (s *awsSdkEC2) CreateSecurityGroup(request *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { - return s.ec2.CreateSecurityGroup(request) +func (s *awsSdkEC2) CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) { + return s.ec2.CreateSecurityGroup(ctx, request) } -func (s *awsSdkEC2) DeleteSecurityGroup(request *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { - return s.ec2.DeleteSecurityGroup(request) +func (s *awsSdkEC2) DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { + return s.ec2.DeleteSecurityGroup(ctx, request) } -func (s *awsSdkEC2) AuthorizeSecurityGroupIngress(request *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { - return s.ec2.AuthorizeSecurityGroupIngress(request) +func (s *awsSdkEC2) AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { + return s.ec2.AuthorizeSecurityGroupIngress(ctx, request) } -func (s *awsSdkEC2) RevokeSecurityGroupIngress(request *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { - return s.ec2.RevokeSecurityGroupIngress(request) +func (s *awsSdkEC2) RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + return s.ec2.RevokeSecurityGroupIngress(ctx, request) } -func (s *awsSdkEC2) CreateTags(request *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { +func (s *awsSdkEC2) CreateTags(ctx context.Context, request *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { requestTime := time.Now() - resp, err := s.ec2.CreateTags(request) + resp, err := s.ec2.CreateTags(ctx, request) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("create_tags", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DeleteTags(request *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { +func (s *awsSdkEC2) DeleteTags(ctx context.Context, request *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DeleteTags(request) + resp, err := s.ec2.DeleteTags(ctx, request) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("delete_tags", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { - results := []*ec2.RouteTable{} +func (s *awsSdkEC2) DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) { + results := []ec2types.RouteTable{} var nextToken *string requestTime := time.Now() for { - response, err := s.ec2.DescribeRouteTables(request) + response, err := s.ec2.DescribeRouteTables(ctx, request) if err != nil { recordAWSMetric("describe_route_tables", 0, err) return nil, fmt.Errorf("error listing AWS route tables: %q", err) @@ -164,7 +202,7 @@ func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ( results = append(results, response.RouteTables...) nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -174,18 +212,18 @@ func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ( return results, nil } -func (s *awsSdkEC2) CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) { - return s.ec2.CreateRoute(request) +func (s *awsSdkEC2) CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { + return s.ec2.CreateRoute(ctx, request) } -func (s *awsSdkEC2) DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) { - return s.ec2.DeleteRoute(request) +func (s *awsSdkEC2) DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) { + return s.ec2.DeleteRoute(ctx, request) } -func (s *awsSdkEC2) ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) { - return s.ec2.ModifyInstanceAttribute(request) +func (s *awsSdkEC2) ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) { + return s.ec2.ModifyInstanceAttribute(ctx, request) } -func (s *awsSdkEC2) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return s.ec2.DescribeVpcs(request) +func (s *awsSdkEC2) DescribeVpcs(ctx context.Context, request *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { + return s.ec2.DescribeVpcs(ctx, request) } diff --git a/pkg/providers/v1/aws_fakes.go b/pkg/providers/v1/aws_fakes.go index 18ba2431a9..025afc51fe 100644 --- a/pkg/providers/v1/aws_fakes.go +++ b/pkg/providers/v1/aws_fakes.go @@ -17,19 +17,23 @@ limitations under the License. package aws import ( + "context" "errors" "fmt" + "io" "sort" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" "k8s.io/klog/v2" "k8s.io/cloud-provider-aws/pkg/providers/v1/config" @@ -39,8 +43,8 @@ import ( // FakeAWSServices is an fake AWS session used for testing type FakeAWSServices struct { region string - instances []*ec2.Instance - selfInstance *ec2.Instance + instances []*ec2types.Instance + selfInstance *ec2types.Instance networkInterfacesMacs []string networkInterfacesPrivateIPs [][]string networkInterfacesVpcIDs []string @@ -69,23 +73,23 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { s.networkInterfacesMacs = []string{"aa:bb:cc:dd:ee:00", "aa:bb:cc:dd:ee:01"} s.networkInterfacesVpcIDs = []string{"vpc-mac0", "vpc-mac1"} - selfInstance := &ec2.Instance{} + selfInstance := &ec2types.Instance{} selfInstance.InstanceId = aws.String("i-self") - selfInstance.Placement = &ec2.Placement{ + selfInstance.Placement = &ec2types.Placement{ AvailabilityZone: aws.String("us-west-2a"), } selfInstance.PrivateDnsName = aws.String("ip-172-20-0-100.ec2.internal") selfInstance.PrivateIpAddress = aws.String("192.168.0.1") selfInstance.PublicIpAddress = aws.String("1.2.3.4") s.selfInstance = selfInstance - s.instances = []*ec2.Instance{selfInstance} + s.instances = []*ec2types.Instance{selfInstance} - selfInstance.NetworkInterfaces = []*ec2.InstanceNetworkInterface{ + selfInstance.NetworkInterfaces = []ec2types.InstanceNetworkInterface{ { - Attachment: &ec2.InstanceNetworkInterfaceAttachment{ - DeviceIndex: aws.Int64(1), + Attachment: &ec2types.InstanceNetworkInterfaceAttachment{ + DeviceIndex: aws.Int32(1), }, - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { Primary: aws.Bool(true), PrivateDnsName: aws.String("ip-172-20-1-100.ec2.internal"), @@ -97,13 +101,13 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { PrivateIpAddress: aws.String("172.20.1.2"), }, }, - Status: aws.String(ec2.NetworkInterfaceStatusInUse), + Status: ec2types.NetworkInterfaceStatusInUse, }, { - Attachment: &ec2.InstanceNetworkInterfaceAttachment{ - DeviceIndex: aws.Int64(0), + Attachment: &ec2types.InstanceNetworkInterfaceAttachment{ + DeviceIndex: aws.Int32(0), }, - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { Primary: aws.Bool(true), PrivateDnsName: aws.String("ip-172-20-0-100.ec2.internal"), @@ -115,15 +119,14 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { PrivateIpAddress: aws.String("172.20.0.101"), }, }, - Status: aws.String(ec2.NetworkInterfaceStatusInUse), + Status: ec2types.NetworkInterfaceStatusInUse, }, } - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(clusterID) - selfInstance.Tags = []*ec2.Tag{&tag} - + selfInstance.Tags = []ec2types.Tag{tag} s.callCounts = make(map[string]int) return s @@ -132,7 +135,7 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { // WithAz sets the ec2 placement availability zone func (s *FakeAWSServices) WithAz(az string) *FakeAWSServices { if s.selfInstance.Placement == nil { - s.selfInstance.Placement = &ec2.Placement{} + s.selfInstance.Placement = &ec2types.Placement{} } s.selfInstance.Placement.AvailabilityZone = aws.String(az) return s @@ -154,51 +157,51 @@ func (s *FakeAWSServices) countCall(service string, api string, resourceID strin } // Compute returns a fake EC2 client -func (s *FakeAWSServices) Compute(region string) (iface.EC2, error) { +func (s *FakeAWSServices) Compute(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (iface.EC2, error) { return s.ec2, nil } // LoadBalancing returns a fake ELB client -func (s *FakeAWSServices) LoadBalancing(region string) (ELB, error) { +func (s *FakeAWSServices) LoadBalancing(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELB, error) { return s.elb, nil } // LoadBalancingV2 returns a fake ELBV2 client -func (s *FakeAWSServices) LoadBalancingV2(region string) (ELBV2, error) { +func (s *FakeAWSServices) LoadBalancingV2(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELBV2, error) { return s.elbv2, nil } // Metadata returns a fake EC2Metadata client -func (s *FakeAWSServices) Metadata() (config.EC2Metadata, error) { +func (s *FakeAWSServices) Metadata(ctx context.Context) (config.EC2Metadata, error) { return s.metadata, nil } // KeyManagement returns a fake KMS client -func (s *FakeAWSServices) KeyManagement(region string) (KMS, error) { +func (s *FakeAWSServices) KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (KMS, error) { return s.kms, nil } // FakeEC2 is a fake EC2 client used for testing type FakeEC2 interface { iface.EC2 - CreateSubnet(*ec2.Subnet) (*ec2.CreateSubnetOutput, error) + CreateSubnet(*ec2types.Subnet) (*ec2.CreateSubnetOutput, error) RemoveSubnets() - CreateRouteTable(*ec2.RouteTable) (*ec2.CreateRouteTableOutput, error) + CreateRouteTable(*ec2types.RouteTable) (*ec2.CreateRouteTableOutput, error) RemoveRouteTables() } // FakeEC2Impl is an implementation of the FakeEC2 interface used for testing type FakeEC2Impl struct { aws *FakeAWSServices - Subnets []*ec2.Subnet + Subnets []ec2types.Subnet DescribeSubnetsInput *ec2.DescribeSubnetsInput - RouteTables []*ec2.RouteTable + RouteTables []ec2types.RouteTable DescribeRouteTablesInput *ec2.DescribeRouteTablesInput } // DescribeInstances returns fake instance descriptions -func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) { - matches := []*ec2.Instance{} +func (ec2i *FakeEC2Impl) DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) { + matches := []ec2types.Instance{} for _, instance := range ec2i.aws.instances { if request.InstanceIds != nil { if instance.InstanceId == nil { @@ -208,7 +211,7 @@ func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) found := false for _, instanceID := range request.InstanceIds { - if *instanceID == *instance.InstanceId { + if instanceID == *instance.InstanceId { found = true break } @@ -229,29 +232,34 @@ func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) continue } } - matches = append(matches, instance) + matches = append(matches, *instance) } return matches, nil } +// DescribeInstanceTopology is not implemented but is required for interface conformance +func (ec2i *FakeEC2Impl) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) { + panic("Not implemented") +} + // AttachVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) AttachVolume(request *ec2.AttachVolumeInput) (resp *ec2.VolumeAttachment, err error) { +func (ec2i *FakeEC2Impl) AttachVolume(request *ec2.AttachVolumeInput) (resp *ec2types.VolumeAttachment, err error) { panic("Not implemented") } // DetachVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DetachVolume(request *ec2.DetachVolumeInput) (resp *ec2.VolumeAttachment, err error) { +func (ec2i *FakeEC2Impl) DetachVolume(request *ec2.DetachVolumeInput) (resp *ec2types.VolumeAttachment, err error) { panic("Not implemented") } // DescribeVolumes is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { +func (ec2i *FakeEC2Impl) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2types.Volume, error) { panic("Not implemented") } // CreateVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) CreateVolume(request *ec2.CreateVolumeInput) (resp *ec2.Volume, err error) { +func (ec2i *FakeEC2Impl) CreateVolume(request *ec2.CreateVolumeInput) (resp *ec2types.Volume, err error) { panic("Not implemented") } @@ -262,37 +270,37 @@ func (ec2i *FakeEC2Impl) DeleteVolume(request *ec2.DeleteVolumeInput) (resp *ec2 // DescribeSecurityGroups is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { +func (ec2i *FakeEC2Impl) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { panic("Not implemented") } // CreateSecurityGroup is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { +func (ec2i *FakeEC2Impl) CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) { panic("Not implemented") } // DeleteSecurityGroup is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DeleteSecurityGroup(*ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { +func (ec2i *FakeEC2Impl) DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { panic("Not implemented") } // AuthorizeSecurityGroupIngress is not implemented but is required for // interface conformance -func (ec2i *FakeEC2Impl) AuthorizeSecurityGroupIngress(*ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { +func (ec2i *FakeEC2Impl) AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { panic("Not implemented") } // RevokeSecurityGroupIngress is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) RevokeSecurityGroupIngress(*ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { +func (ec2i *FakeEC2Impl) RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { panic("Not implemented") } // DescribeVolumeModifications is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DescribeVolumeModifications(*ec2.DescribeVolumesModificationsInput) ([]*ec2.VolumeModification, error) { +func (ec2i *FakeEC2Impl) DescribeVolumeModifications(*ec2.DescribeVolumesModificationsInput) ([]*ec2types.VolumeModification, error) { panic("Not implemented") } @@ -302,8 +310,8 @@ func (ec2i *FakeEC2Impl) ModifyVolume(*ec2.ModifyVolumeInput) (*ec2.ModifyVolume } // CreateSubnet creates fake subnets -func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2.Subnet) (*ec2.CreateSubnetOutput, error) { - ec2i.Subnets = append(ec2i.Subnets, request) +func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2types.Subnet) (*ec2.CreateSubnetOutput, error) { + ec2i.Subnets = append(ec2i.Subnets, *request) response := &ec2.CreateSubnetOutput{ Subnet: request, } @@ -311,7 +319,7 @@ func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2.Subnet) (*ec2.CreateSubnetOut } // DescribeSubnets returns fake subnet descriptions -func (ec2i *FakeEC2Impl) DescribeSubnets(request *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { +func (ec2i *FakeEC2Impl) DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) { ec2i.DescribeSubnetsInput = request return ec2i.Subnets, nil } @@ -323,8 +331,8 @@ func (ec2i *FakeEC2Impl) RemoveSubnets() { // DescribeAvailabilityZones returns fake availability zones // For every input returns a hardcoded list of fake availability zones for the moment -func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) { - return []*ec2.AvailabilityZone{ +func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) { + return []ec2types.AvailabilityZone{ { ZoneName: aws.String("us-west-2a"), ZoneType: aws.String("availability-zone"), @@ -354,25 +362,25 @@ func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(request *ec2.DescribeAvailabi } // CreateTags is a mock for CreateTags from EC2 -func (ec2i *FakeEC2Impl) CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { +func (ec2i *FakeEC2Impl) CreateTags(ctx context.Context, input *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { for _, id := range input.Resources { - callCount := ec2i.aws.countCall("ec2", "CreateTags", *id) - if *id == "i-error" { + callCount := ec2i.aws.countCall("ec2", "CreateTags", id) + if id == "i-error" { return nil, errors.New("Unable to tag") } - if *id == "i-not-found" { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + if id == "i-not-found" { + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } // return an Instance not found error for the first `n` calls // instance ID should be of the format `i-not-found-count-$N-$SUFFIX` - if strings.HasPrefix(*id, "i-not-found-count-") { - notFoundCount, err := strconv.Atoi(strings.Split(*id, "-")[4]) + if strings.HasPrefix(id, "i-not-found-count-") { + notFoundCount, err := strconv.Atoi(strings.Split(id, "-")[4]) if err != nil { panic(err) } if callCount < notFoundCount { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } } } @@ -380,28 +388,28 @@ func (ec2i *FakeEC2Impl) CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTags } // DeleteTags is a mock for DeleteTags from EC2 -func (ec2i *FakeEC2Impl) DeleteTags(input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { +func (ec2i *FakeEC2Impl) DeleteTags(ctx context.Context, input *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) { for _, id := range input.Resources { - if *id == "i-error" { + if id == "i-error" { return nil, errors.New("Unable to remove tag") } - if *id == "i-not-found" { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + if id == "i-not-found" { + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } } return &ec2.DeleteTagsOutput{}, nil } // DescribeRouteTables returns fake route table descriptions -func (ec2i *FakeEC2Impl) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { +func (ec2i *FakeEC2Impl) DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) { ec2i.DescribeRouteTablesInput = request return ec2i.RouteTables, nil } // CreateRouteTable creates fake route tables -func (ec2i *FakeEC2Impl) CreateRouteTable(request *ec2.RouteTable) (*ec2.CreateRouteTableOutput, error) { - ec2i.RouteTables = append(ec2i.RouteTables, request) +func (ec2i *FakeEC2Impl) CreateRouteTable(request *ec2types.RouteTable) (*ec2.CreateRouteTableOutput, error) { + ec2i.RouteTables = append(ec2i.RouteTables, *request) response := &ec2.CreateRouteTableOutput{ RouteTable: request, } @@ -414,24 +422,24 @@ func (ec2i *FakeEC2Impl) RemoveRouteTables() { } // CreateRoute is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) { +func (ec2i *FakeEC2Impl) CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { panic("Not implemented") } // DeleteRoute is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) { +func (ec2i *FakeEC2Impl) DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) { panic("Not implemented") } // ModifyInstanceAttribute is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) { +func (ec2i *FakeEC2Impl) ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) { panic("Not implemented") } // DescribeVpcs returns fake VPC descriptions -func (ec2i *FakeEC2Impl) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return &ec2.DescribeVpcsOutput{Vpcs: []*ec2.Vpc{{CidrBlock: aws.String("172.20.0.0/16")}}}, nil +func (ec2i *FakeEC2Impl) DescribeVpcs(ctx context.Context, request *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { + return &ec2.DescribeVpcsOutput{Vpcs: []ec2types.Vpc{{CidrBlock: aws.String("172.20.0.0/16")}}}, nil } // FakeMetadata is a fake EC2 metadata service client used for testing @@ -440,25 +448,26 @@ type FakeMetadata struct { } // GetMetadata returns fake EC2 metadata for testing -func (m *FakeMetadata) GetMetadata(key string) (string, error) { +func (m *FakeMetadata) GetMetadata(ctx context.Context, input *imds.GetMetadataInput, optFns ...func(*imds.Options)) (*imds.GetMetadataOutput, error) { + key := input.Path networkInterfacesPrefix := "network/interfaces/macs/" i := m.aws.selfInstance if key == "placement/availability-zone" { az := "" if i.Placement != nil { - az = aws.StringValue(i.Placement.AvailabilityZone) + az = aws.ToString(i.Placement.AvailabilityZone) } - return az, nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(az))}, nil } else if key == "instance-id" { - return aws.StringValue(i.InstanceId), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.InstanceId))}, nil } else if key == "local-hostname" { - return aws.StringValue(i.PrivateDnsName), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PrivateDnsName))}, nil } else if key == "public-hostname" { - return aws.StringValue(i.PublicDnsName), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PublicDnsName))}, nil } else if key == "local-ipv4" { - return aws.StringValue(i.PrivateIpAddress), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PrivateIpAddress))}, nil } else if key == "public-ipv4" { - return aws.StringValue(i.PublicIpAddress), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PublicIpAddress))}, nil } else if strings.HasPrefix(key, networkInterfacesPrefix) { if key == networkInterfacesPrefix { // Return the MACs sorted lexically rather than in device-number @@ -467,7 +476,8 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { macs := make([]string, len(m.aws.networkInterfacesMacs)) copy(macs, m.aws.networkInterfacesMacs) sort.Strings(macs) - return strings.Join(macs, "/\n") + "/\n", nil + value := strings.Join(macs, "/\n") + "/\n" + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } keySplit := strings.Split(key, "/") @@ -475,7 +485,7 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { if len(keySplit) == 5 && keySplit[4] == "vpc-id" { for i, macElem := range m.aws.networkInterfacesMacs { if macParam == macElem { - return m.aws.networkInterfacesVpcIDs[i], nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(m.aws.networkInterfacesVpcIDs[i]))}, nil } } } @@ -487,27 +497,29 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { // Introduce an artificial gap, just to test eg: [eth0, eth2] n++ } - return fmt.Sprintf("%d\n", n), nil + value := fmt.Sprintf("%d\n", n) + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } } } if len(keySplit) == 5 && keySplit[4] == "local-ipv4s" { for i, macElem := range m.aws.networkInterfacesMacs { if macParam == macElem { - return strings.Join(m.aws.networkInterfacesPrivateIPs[i], "/\n"), nil + value := strings.Join(m.aws.networkInterfacesPrivateIPs[i], "/\n") + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } } } - return "", nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(""))}, nil } - return "", nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(""))}, nil } -// Region returns AWS region -func (m *FakeMetadata) Region() (string, error) { - return m.aws.region, nil +// GetRegion returns AWS region +func (m *FakeMetadata) GetRegion(ctx context.Context, params *imds.GetRegionInput, optFns ...func(*imds.Options)) (*imds.GetRegionOutput, error) { + return &imds.GetRegionOutput{Region: m.aws.region}, nil } // FakeELB is a fake ELB client used for testing @@ -517,108 +529,108 @@ type FakeELB struct { // CreateLoadBalancer is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancer(*elb.CreateLoadBalancerInput) (*elb.CreateLoadBalancerOutput, error) { +func (elb *FakeELB) CreateLoadBalancer(ctx context.Context, input *elb.CreateLoadBalancerInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerOutput, error) { panic("Not implemented") } // DeleteLoadBalancer is not implemented but is required for interface // conformance -func (elb *FakeELB) DeleteLoadBalancer(input *elb.DeleteLoadBalancerInput) (*elb.DeleteLoadBalancerOutput, error) { +func (elb *FakeELB) DeleteLoadBalancer(ctx context.Context, input *elb.DeleteLoadBalancerInput, opts ...func(*elb.Options)) (*elb.DeleteLoadBalancerOutput, error) { panic("Not implemented") } // DescribeLoadBalancers is not implemented but is required for interface // conformance -func (elb *FakeELB) DescribeLoadBalancers(input *elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) { +func (elb *FakeELB) DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) { panic("Not implemented") } // AddTags is not implemented but is required for interface conformance -func (elb *FakeELB) AddTags(input *elb.AddTagsInput) (*elb.AddTagsOutput, error) { +func (elb *FakeELB) AddTags(ctx context.Context, input *elb.AddTagsInput, opts ...func(*elb.Options)) (*elb.AddTagsOutput, error) { panic("Not implemented") } // RegisterInstancesWithLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) RegisterInstancesWithLoadBalancer(*elb.RegisterInstancesWithLoadBalancerInput) (*elb.RegisterInstancesWithLoadBalancerOutput, error) { +func (elb *FakeELB) RegisterInstancesWithLoadBalancer(ctx context.Context, input *elb.RegisterInstancesWithLoadBalancerInput, opts ...func(*elb.Options)) (*elb.RegisterInstancesWithLoadBalancerOutput, error) { panic("Not implemented") } // DeregisterInstancesFromLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) DeregisterInstancesFromLoadBalancer(*elb.DeregisterInstancesFromLoadBalancerInput) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) { +func (elb *FakeELB) DeregisterInstancesFromLoadBalancer(ctx context.Context, input *elb.DeregisterInstancesFromLoadBalancerInput, opts ...func(*elb.Options)) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) { panic("Not implemented") } // DetachLoadBalancerFromSubnets is not implemented but is required for // interface conformance -func (elb *FakeELB) DetachLoadBalancerFromSubnets(*elb.DetachLoadBalancerFromSubnetsInput) (*elb.DetachLoadBalancerFromSubnetsOutput, error) { +func (elb *FakeELB) DetachLoadBalancerFromSubnets(ctx context.Context, input *elb.DetachLoadBalancerFromSubnetsInput, opts ...func(*elb.Options)) (*elb.DetachLoadBalancerFromSubnetsOutput, error) { panic("Not implemented") } // AttachLoadBalancerToSubnets is not implemented but is required for interface // conformance -func (elb *FakeELB) AttachLoadBalancerToSubnets(*elb.AttachLoadBalancerToSubnetsInput) (*elb.AttachLoadBalancerToSubnetsOutput, error) { +func (elb *FakeELB) AttachLoadBalancerToSubnets(ctx context.Context, input *elb.AttachLoadBalancerToSubnetsInput, opts ...func(*elb.Options)) (*elb.AttachLoadBalancerToSubnetsOutput, error) { panic("Not implemented") } // CreateLoadBalancerListeners is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancerListeners(*elb.CreateLoadBalancerListenersInput) (*elb.CreateLoadBalancerListenersOutput, error) { +func (elb *FakeELB) CreateLoadBalancerListeners(ctx context.Context, input *elb.CreateLoadBalancerListenersInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerListenersOutput, error) { panic("Not implemented") } // DeleteLoadBalancerListeners is not implemented but is required for interface // conformance -func (elb *FakeELB) DeleteLoadBalancerListeners(*elb.DeleteLoadBalancerListenersInput) (*elb.DeleteLoadBalancerListenersOutput, error) { +func (elb *FakeELB) DeleteLoadBalancerListeners(ctx context.Context, input *elb.DeleteLoadBalancerListenersInput, opts ...func(*elb.Options)) (*elb.DeleteLoadBalancerListenersOutput, error) { panic("Not implemented") } // ApplySecurityGroupsToLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) ApplySecurityGroupsToLoadBalancer(*elb.ApplySecurityGroupsToLoadBalancerInput) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) { +func (elb *FakeELB) ApplySecurityGroupsToLoadBalancer(ctx context.Context, input *elb.ApplySecurityGroupsToLoadBalancerInput, opts ...func(*elb.Options)) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) { panic("Not implemented") } // ConfigureHealthCheck is not implemented but is required for interface // conformance -func (elb *FakeELB) ConfigureHealthCheck(*elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) { +func (elb *FakeELB) ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, opts ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) { panic("Not implemented") } // CreateLoadBalancerPolicy is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancerPolicy(*elb.CreateLoadBalancerPolicyInput) (*elb.CreateLoadBalancerPolicyOutput, error) { +func (elb *FakeELB) CreateLoadBalancerPolicy(ctx context.Context, input *elb.CreateLoadBalancerPolicyInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerPolicyOutput, error) { panic("Not implemented") } // SetLoadBalancerPoliciesForBackendServer is not implemented but is required // for interface conformance -func (elb *FakeELB) SetLoadBalancerPoliciesForBackendServer(*elb.SetLoadBalancerPoliciesForBackendServerInput) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) { +func (elb *FakeELB) SetLoadBalancerPoliciesForBackendServer(ctx context.Context, input *elb.SetLoadBalancerPoliciesForBackendServerInput, opts ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) { panic("Not implemented") } // SetLoadBalancerPoliciesOfListener is not implemented but is required for // interface conformance -func (elb *FakeELB) SetLoadBalancerPoliciesOfListener(input *elb.SetLoadBalancerPoliciesOfListenerInput) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) { +func (elb *FakeELB) SetLoadBalancerPoliciesOfListener(ctx context.Context, input *elb.SetLoadBalancerPoliciesOfListenerInput, opts ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) { panic("Not implemented") } // DescribeLoadBalancerPolicies is not implemented but is required for // interface conformance -func (elb *FakeELB) DescribeLoadBalancerPolicies(input *elb.DescribeLoadBalancerPoliciesInput) (*elb.DescribeLoadBalancerPoliciesOutput, error) { +func (elb *FakeELB) DescribeLoadBalancerPolicies(ctx context.Context, input *elb.DescribeLoadBalancerPoliciesInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancerPoliciesOutput, error) { panic("Not implemented") } // DescribeLoadBalancerAttributes is not implemented but is required for // interface conformance -func (elb *FakeELB) DescribeLoadBalancerAttributes(*elb.DescribeLoadBalancerAttributesInput) (*elb.DescribeLoadBalancerAttributesOutput, error) { +func (elb *FakeELB) DescribeLoadBalancerAttributes(ctx context.Context, input *elb.DescribeLoadBalancerAttributesInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancerAttributesOutput, error) { panic("Not implemented") } // ModifyLoadBalancerAttributes is not implemented but is required for // interface conformance -func (elb *FakeELB) ModifyLoadBalancerAttributes(*elb.ModifyLoadBalancerAttributesInput) (*elb.ModifyLoadBalancerAttributesOutput, error) { +func (elb *FakeELB) ModifyLoadBalancerAttributes(ctx context.Context, input *elb.ModifyLoadBalancerAttributesInput, opts ...func(*elb.Options)) (*elb.ModifyLoadBalancerAttributesOutput, error) { panic("Not implemented") } @@ -628,117 +640,97 @@ type FakeELBV2 struct { } // AddTags is not implemented but is required for interface conformance -func (elb *FakeELBV2) AddTags(input *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) { +func (elb *FakeELBV2) AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) { panic("Not implemented") } -// CreateLoadBalancer is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) CreateLoadBalancer(*elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) { +// CreateLoadBalancer is not implemented but is required for interface conformance +func (elb *FakeELBV2) CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) { panic("Not implemented") } -// DescribeLoadBalancers is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeLoadBalancers(*elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) { +// DescribeLoadBalancers is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) { panic("Not implemented") } -// DeleteLoadBalancer is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) { +// DeleteLoadBalancer is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) { panic("Not implemented") } -// ModifyLoadBalancerAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) ModifyLoadBalancerAttributes(*elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { +// ModifyLoadBalancerAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { panic("Not implemented") } -// DescribeLoadBalancerAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) DescribeLoadBalancerAttributes(*elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { +// DescribeLoadBalancerAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { panic("Not implemented") } -// CreateTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) CreateTargetGroup(*elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) { +// CreateTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) { panic("Not implemented") } -// DescribeTargetGroups is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeTargetGroups(*elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) { +// DescribeTargetGroups is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { panic("Not implemented") } -// ModifyTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) ModifyTargetGroup(*elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) { +// ModifyTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) { panic("Not implemented") } -// DeleteTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeleteTargetGroup(*elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) { +// DeleteTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) { panic("Not implemented") } -// DescribeTargetHealth is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeTargetHealth(input *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) { +// DescribeTargetHealth is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { panic("Not implemented") } -// DescribeTargetGroupAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) { +// DescribeTargetGroupAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { panic("Not implemented") } -// ModifyTargetGroupAttributes is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) { +// ModifyTargetGroupAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) { panic("Not implemented") } // RegisterTargets is not implemented but is required for interface conformance -func (elb *FakeELBV2) RegisterTargets(*elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) { +func (elb *FakeELBV2) RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { panic("Not implemented") } -// DeregisterTargets is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeregisterTargets(*elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) { +// DeregisterTargets is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { panic("Not implemented") } // CreateListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) CreateListener(*elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) { +func (elb *FakeELBV2) CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) { panic("Not implemented") } -// DescribeListeners is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeListeners(*elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) { +// DescribeListeners is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) { panic("Not implemented") } // DeleteListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) { +func (elb *FakeELBV2) DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) { panic("Not implemented") } // ModifyListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) ModifyListener(*elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) { - panic("Not implemented") -} - -// WaitUntilLoadBalancersDeleted is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error { +func (elb *FakeELBV2) ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) { panic("Not implemented") } @@ -765,26 +757,26 @@ type FakeKMS struct { } // DescribeKey is not implemented but is required for interface conformance -func (kms *FakeKMS) DescribeKey(*kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { +func (kms *FakeKMS) DescribeKey(ctx context.Context, input *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { panic("Not implemented") } -func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { +func instanceMatchesFilter(instance *ec2types.Instance, filter ec2types.Filter) bool { name := *filter.Name if name == "private-dns-name" { if instance.PrivateDnsName == nil { return false } - return contains(filter.Values, *instance.PrivateDnsName) + return contains(filter.Values, aws.ToString(instance.PrivateDnsName)) } if name == "instance-state-name" { - return contains(filter.Values, *instance.State.Name) + return contains(filter.Values, string(instance.State.Name)) } if name == "tag-key" { for _, instanceTag := range instance.Tags { - if contains(filter.Values, aws.StringValue(instanceTag.Key)) { + if contains(filter.Values, aws.ToString(instanceTag.Key)) { return true } } @@ -794,7 +786,7 @@ func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { if strings.HasPrefix(name, "tag:") { tagName := name[4:] for _, instanceTag := range instance.Tags { - if aws.StringValue(instanceTag.Key) == tagName && contains(filter.Values, aws.StringValue(instanceTag.Value)) { + if aws.ToString(instanceTag.Key) == tagName && contains(filter.Values, aws.ToString(instanceTag.Value)) { return true } } @@ -804,10 +796,10 @@ func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { panic("Unknown filter name: " + name) } -func contains(haystack []*string, needle string) bool { +func contains(haystack []string, needle string) bool { for _, s := range haystack { // (deliberately panic if s == nil) - if needle == *s { + if needle == s { return true } } @@ -815,29 +807,29 @@ func contains(haystack []*string, needle string) bool { } // DescribeNetworkInterfaces returns list of ENIs for testing -func (ec2i *FakeEC2Impl) DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) { +func (ec2i *FakeEC2Impl) DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { fargateNodeNamePrefix := "fargate-" - networkInterface := []*ec2.NetworkInterface{ + networkInterface := []ec2types.NetworkInterface{ { PrivateIpAddress: aws.String("1.2.3.4"), AvailabilityZone: aws.String("us-west-2c"), }, } for _, filter := range input.Filters { - if strings.HasPrefix(*filter.Values[0], fargateNodeNamePrefix) { + if strings.HasPrefix(filter.Values[0], fargateNodeNamePrefix) { // verify filter doesn't have fargate prefix - panic(fmt.Sprintf("invalid endpoint specified for DescribeNetworkInterface call %s", *filter.Values[0])) - } else if strings.HasPrefix(*filter.Values[0], "not-found") { + panic(fmt.Sprintf("invalid endpoint specified for DescribeNetworkInterface call %s", filter.Values[0])) + } else if strings.HasPrefix(filter.Values[0], "not-found") { // for negative testing return &ec2.DescribeNetworkInterfacesOutput{}, nil } - if strings.Contains(*filter.Values[0], "return.private.dns.name") { + if strings.Contains(filter.Values[0], "return.private.dns.name") { networkInterface[0].PrivateDnsName = aws.String("ip-1-2-3-4.compute.amazon.com") } - if *filter.Values[0] == "return.private.dns.name.ipv6" { - networkInterface[0].Ipv6Addresses = []*ec2.NetworkInterfaceIpv6Address{ + if filter.Values[0] == "return.private.dns.name.ipv6" { + networkInterface[0].Ipv6Addresses = []ec2types.NetworkInterfaceIpv6Address{ { Ipv6Address: aws.String("2001:db8:3333:4444:5555:6666:7777:8888"), }, diff --git a/pkg/providers/v1/aws_instance.go b/pkg/providers/v1/aws_instance.go index e7e8b152a1..d5dc8d1f09 100644 --- a/pkg/providers/v1/aws_instance.go +++ b/pkg/providers/v1/aws_instance.go @@ -17,10 +17,12 @@ limitations under the License. package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "k8s.io/apimachinery/pkg/types" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "k8s.io/apimachinery/pkg/types" "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" ) @@ -47,25 +49,25 @@ type awsInstance struct { } // newAWSInstance creates a new awsInstance object -func newAWSInstance(ec2Service iface.EC2, instance *ec2.Instance) *awsInstance { +func newAWSInstance(ec2Service iface.EC2, instance *ec2types.Instance) *awsInstance { az := "" if instance.Placement != nil { - az = aws.StringValue(instance.Placement.AvailabilityZone) + az = aws.ToString(instance.Placement.AvailabilityZone) } self := &awsInstance{ ec2: ec2Service, - awsID: aws.StringValue(instance.InstanceId), + awsID: aws.ToString(instance.InstanceId), nodeName: mapInstanceToNodeName(instance), availabilityZone: az, - instanceType: aws.StringValue(instance.InstanceType), - vpcID: aws.StringValue(instance.VpcId), - subnetID: aws.StringValue(instance.SubnetId), + instanceType: string(instance.InstanceType), + vpcID: aws.ToString(instance.VpcId), + subnetID: aws.ToString(instance.SubnetId), } return self } // Gets the full information about this instance from the EC2 API -func (i *awsInstance) describeInstance() (*ec2.Instance, error) { - return describeInstance(i.ec2, InstanceID(i.awsID)) +func (i *awsInstance) describeInstance(ctx context.Context) (*ec2types.Instance, error) { + return describeInstance(ctx, i.ec2, InstanceID(i.awsID)) } diff --git a/pkg/providers/v1/aws_loadbalancer.go b/pkg/providers/v1/aws_loadbalancer.go index c39ea3de37..b82d378c2e 100644 --- a/pkg/providers/v1/aws_loadbalancer.go +++ b/pkg/providers/v1/aws_loadbalancer.go @@ -17,8 +17,10 @@ limitations under the License. package aws import ( + "context" "crypto/sha1" "encoding/hex" + "errors" "fmt" "reflect" "regexp" @@ -26,11 +28,13 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" @@ -58,13 +62,13 @@ const ( var ( // Defaults for ELB Healthcheck - defaultElbHCHealthyThreshold = int64(2) - defaultElbHCUnhealthyThreshold = int64(6) - defaultElbHCTimeout = int64(5) - defaultElbHCInterval = int64(10) - defaultNlbHealthCheckInterval = int64(30) - defaultNlbHealthCheckTimeout = int64(10) - defaultNlbHealthCheckThreshold = int64(3) + defaultElbHCHealthyThreshold = int32(2) + defaultElbHCUnhealthyThreshold = int32(6) + defaultElbHCTimeout = int32(5) + defaultElbHCInterval = int32(10) + defaultNlbHealthCheckInterval = int32(30) + defaultNlbHealthCheckTimeout = int32(10) + defaultNlbHealthCheckThreshold = int32(3) defaultHealthCheckPort = "traffic-port" defaultHealthCheckPath = "/" @@ -90,19 +94,19 @@ func isLBExternal(annotations map[string]string) bool { type healthCheckConfig struct { Port string Path string - Protocol string - Interval int64 - Timeout int64 - HealthyThreshold int64 - UnhealthyThreshold int64 + Protocol elbv2types.ProtocolEnum + Interval int32 + Timeout int32 + HealthyThreshold int32 + UnhealthyThreshold int32 } type nlbPortMapping struct { - FrontendPort int64 - FrontendProtocol string + FrontendPort int32 + FrontendProtocol elbv2types.ProtocolEnum - TrafficPort int64 - TrafficProtocol string + TrafficPort int32 + TrafficProtocol elbv2types.ProtocolEnum SSLCertificateARN string SSLPolicy string @@ -138,8 +142,8 @@ func getKeyValuePropertiesFromAnnotation(annotations map[string]string, annotati } // ensureLoadBalancerv2 ensures a v2 load balancer is created -func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBalancerName string, mappings []nlbPortMapping, instanceIDs, discoveredSubnetIDs []string, internalELB bool, annotations map[string]string) (*elbv2.LoadBalancer, error) { - loadBalancer, err := c.describeLoadBalancerv2(loadBalancerName) +func (c *Cloud) ensureLoadBalancerv2(ctx context.Context, namespacedName types.NamespacedName, loadBalancerName string, mappings []nlbPortMapping, instanceIDs, discoveredSubnetIDs []string, internalELB bool, annotations map[string]string) (*elbv2types.LoadBalancer, error) { + loadBalancer, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return nil, err } @@ -155,11 +159,11 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if loadBalancer == nil { // Create the LB createRequest := &elbv2.CreateLoadBalancerInput{ - Type: aws.String(elbv2.LoadBalancerTypeEnumNetwork), + Type: elbv2types.LoadBalancerTypeEnumNetwork, Name: aws.String(loadBalancerName), } if internalELB { - createRequest.Scheme = aws.String("internal") + createRequest.Scheme = elbv2types.LoadBalancerSchemeEnumInternal } var allocationIDs []string @@ -175,27 +179,27 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa createRequest.SubnetMappings = createSubnetMappings(discoveredSubnetIDs, allocationIDs) for k, v := range tags { - createRequest.Tags = append(createRequest.Tags, &elbv2.Tag{ + createRequest.Tags = append(createRequest.Tags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } klog.Infof("Creating load balancer for %v with name: %s", namespacedName, loadBalancerName) - createResponse, err := c.elbv2.CreateLoadBalancer(createRequest) + createResponse, err := c.elbv2.CreateLoadBalancer(ctx, createRequest) if err != nil { return nil, fmt.Errorf("error creating load balancer: %q", err) } - loadBalancer = createResponse.LoadBalancers[0] + loadBalancer = &createResponse.LoadBalancers[0] for i := range mappings { // It is easier to keep track of updates by having possibly // duplicate target groups where the backend port is the same - _, err := c.createListenerV2(createResponse.LoadBalancers[0].LoadBalancerArn, mappings[i], namespacedName, instanceIDs, *createResponse.LoadBalancers[0].VpcId, tags) + _, err := c.createListenerV2(ctx, createResponse.LoadBalancers[0].LoadBalancerArn, mappings[i], namespacedName, instanceIDs, *createResponse.LoadBalancers[0].VpcId, tags) if err != nil { return nil, fmt.Errorf("error creating listener: %q", err) } } - if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil { + if err := c.reconcileLBAttributes(ctx, aws.ToString(loadBalancer.LoadBalancerArn), annotations); err != nil { return nil, err } } else { @@ -203,7 +207,7 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa // sync mappings { - listenerDescriptions, err := c.elbv2.DescribeListeners( + listenerDescriptions, err := c.elbv2.DescribeListeners(ctx, &elbv2.DescribeListenersInput{ LoadBalancerArn: loadBalancer.LoadBalancerArn, }, @@ -213,15 +217,15 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } // actual maps FrontendPort to an elbv2.Listener - actual := map[int64]map[string]*elbv2.Listener{} + actual := map[int32]map[elbv2types.ProtocolEnum]*elbv2types.Listener{} for _, listener := range listenerDescriptions.Listeners { if actual[*listener.Port] == nil { - actual[*listener.Port] = map[string]*elbv2.Listener{} + actual[*listener.Port] = map[elbv2types.ProtocolEnum]*elbv2types.Listener{} } - actual[*listener.Port][*listener.Protocol] = listener + actual[*listener.Port][listener.Protocol] = &listener } - actualTargetGroups, err := c.elbv2.DescribeTargetGroups( + actualTargetGroups, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{ LoadBalancerArn: loadBalancer.LoadBalancerArn, }, @@ -230,9 +234,9 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa return nil, fmt.Errorf("error listing target groups: %q", err) } - nodePortTargetGroup := map[int64]*elbv2.TargetGroup{} + nodePortTargetGroup := map[int32]*elbv2types.TargetGroup{} for _, targetGroup := range actualTargetGroups.TargetGroups { - nodePortTargetGroup[*targetGroup.Port] = targetGroup + nodePortTargetGroup[*targetGroup.Port] = &targetGroup } // Handle additions/modifications @@ -244,22 +248,22 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if listener, ok := actual[frontendPort][frontendProtocol]; ok { listenerNeedsModification := false - if aws.StringValue(listener.Protocol) != mapping.FrontendProtocol { + if listener.Protocol != mapping.FrontendProtocol { listenerNeedsModification = true } switch mapping.FrontendProtocol { - case elbv2.ProtocolEnumTls: + case elbv2types.ProtocolEnumTls: { - if aws.StringValue(listener.SslPolicy) != mapping.SSLPolicy { + if aws.ToString(listener.SslPolicy) != mapping.SSLPolicy { listenerNeedsModification = true } - if len(listener.Certificates) == 0 || aws.StringValue(listener.Certificates[0].CertificateArn) != mapping.SSLCertificateARN { + if len(listener.Certificates) == 0 || aws.ToString(listener.Certificates[0].CertificateArn) != mapping.SSLCertificateARN { listenerNeedsModification = true } } - case elbv2.ProtocolEnumTcp: + case elbv2types.ProtocolEnumTcp: { - if aws.StringValue(listener.SslPolicy) != "" { + if aws.ToString(listener.SslPolicy) != "" { listenerNeedsModification = true } if len(listener.Certificates) != 0 { @@ -273,14 +277,14 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa targetGroupRecreated := false targetGroup, ok := nodePortTargetGroup[nodePort] - if targetGroup != nil && (!strings.EqualFold(mapping.HealthCheckConfig.Protocol, aws.StringValue(targetGroup.HealthCheckProtocol)) || - mapping.HealthCheckConfig.Interval != aws.Int64Value(targetGroup.HealthCheckIntervalSeconds)) { + if targetGroup != nil && (!strings.EqualFold(string(mapping.HealthCheckConfig.Protocol), string(targetGroup.HealthCheckProtocol)) || + mapping.HealthCheckConfig.Interval != aws.ToInt32(targetGroup.HealthCheckIntervalSeconds)) { healthCheckModified = true } - if !ok || aws.StringValue(targetGroup.Protocol) != mapping.TrafficProtocol || healthCheckModified { + if !ok || targetGroup.Protocol != mapping.TrafficProtocol || healthCheckModified { // create new target group - targetGroup, err = c.ensureTargetGroup( + targetGroup, err = c.ensureTargetGroup(ctx, nil, namespacedName, mapping, @@ -298,38 +302,38 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if listenerNeedsModification { modifyListenerInput := &elbv2.ModifyListenerInput{ ListenerArn: listener.ListenerArn, - Port: aws.Int64(frontendPort), - Protocol: aws.String(mapping.FrontendProtocol), - DefaultActions: []*elbv2.Action{{ + Port: aws.Int32(frontendPort), + Protocol: mapping.FrontendProtocol, + DefaultActions: []elbv2types.Action{{ TargetGroupArn: targetGroup.TargetGroupArn, - Type: aws.String("forward"), + Type: elbv2types.ActionTypeEnumForward, }}, } - if mapping.FrontendProtocol == elbv2.ProtocolEnumTls { + if mapping.FrontendProtocol == elbv2types.ProtocolEnumTls { if mapping.SSLPolicy != "" { modifyListenerInput.SslPolicy = aws.String(mapping.SSLPolicy) } - modifyListenerInput.Certificates = []*elbv2.Certificate{ + modifyListenerInput.Certificates = []elbv2types.Certificate{ { CertificateArn: aws.String(mapping.SSLCertificateARN), }, } } - if _, err := c.elbv2.ModifyListener(modifyListenerInput); err != nil { + if _, err := c.elbv2.ModifyListener(ctx, modifyListenerInput); err != nil { return nil, fmt.Errorf("error updating load balancer listener: %q", err) } } // Delete old targetGroup if needed if targetGroupRecreated { - if _, err := c.elbv2.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{ + if _, err := c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{ TargetGroupArn: listener.DefaultActions[0].TargetGroupArn, }); err != nil { return nil, fmt.Errorf("error deleting old target group: %q", err) } } else { // Run ensureTargetGroup to make sure instances in service are up-to-date - _, err = c.ensureTargetGroup( + _, err = c.ensureTargetGroup(ctx, targetGroup, namespacedName, mapping, @@ -346,17 +350,17 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } // Additions - _, err := c.createListenerV2(loadBalancer.LoadBalancerArn, mapping, namespacedName, instanceIDs, *loadBalancer.VpcId, tags) + _, err := c.createListenerV2(ctx, loadBalancer.LoadBalancerArn, mapping, namespacedName, instanceIDs, *loadBalancer.VpcId, tags) if err != nil { return nil, err } dirty = true } - frontEndPorts := map[int64]map[string]bool{} + frontEndPorts := map[int32]map[elbv2types.ProtocolEnum]bool{} for i := range mappings { if frontEndPorts[mappings[i].FrontendPort] == nil { - frontEndPorts[mappings[i].FrontendPort] = map[string]bool{} + frontEndPorts[mappings[i].FrontendPort] = map[elbv2types.ProtocolEnum]bool{} } frontEndPorts[mappings[i].FrontendPort][mappings[i].FrontendProtocol] = true } @@ -365,7 +369,7 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa for port := range actual { for protocol := range actual[port] { if _, ok := frontEndPorts[port][protocol]; !ok { - err := c.deleteListenerV2(actual[port][protocol]) + err := c.deleteListenerV2(ctx, actual[port][protocol]) if err != nil { return nil, err } @@ -374,29 +378,29 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } } } - if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil { + if err := c.reconcileLBAttributes(ctx, aws.ToString(loadBalancer.LoadBalancerArn), annotations); err != nil { return nil, err } // Subnets cannot be modified on NLBs if dirty { - loadBalancers, err := c.elbv2.DescribeLoadBalancers( + loadBalancers, err := c.elbv2.DescribeLoadBalancers(ctx, &elbv2.DescribeLoadBalancersInput{ - LoadBalancerArns: []*string{ - loadBalancer.LoadBalancerArn, + LoadBalancerArns: []string{ + aws.ToString(loadBalancer.LoadBalancerArn), }, }, ) if err != nil { return nil, fmt.Errorf("error retrieving load balancer after update: %q", err) } - loadBalancer = loadBalancers.LoadBalancers[0] + loadBalancer = &loadBalancers.LoadBalancers[0] } } return loadBalancer, nil } -func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[string]string) error { +func (c *Cloud) reconcileLBAttributes(ctx context.Context, loadBalancerArn string, annotations map[string]string) error { desiredLoadBalancerAttributes := map[string]string{} desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] = "false" @@ -435,25 +439,25 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] = annotations[ServiceAnnotationLoadBalancerAccessLogS3BucketPrefix] currentLoadBalancerAttributes := map[string]string{} - describeAttributesOutput, err := c.elbv2.DescribeLoadBalancerAttributes(&elbv2.DescribeLoadBalancerAttributesInput{ + describeAttributesOutput, err := c.elbv2.DescribeLoadBalancerAttributes(ctx, &elbv2.DescribeLoadBalancerAttributesInput{ LoadBalancerArn: aws.String(loadBalancerArn), }) if err != nil { return fmt.Errorf("unable to retrieve load balancer attributes during attribute sync: %q", err) } for _, attr := range describeAttributesOutput.Attributes { - currentLoadBalancerAttributes[aws.StringValue(attr.Key)] = aws.StringValue(attr.Value) + currentLoadBalancerAttributes[aws.ToString(attr.Key)] = aws.ToString(attr.Value) } - var changedAttributes []*elbv2.LoadBalancerAttribute + var changedAttributes []elbv2types.LoadBalancerAttribute if desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] != currentLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrLoadBalancingCrossZoneEnabled), Value: aws.String(desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled]), }) } if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Enabled), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled]), }) @@ -462,13 +466,13 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st // ELBV2 API forbids us to set bucket to an empty bucket, so we keep it unchanged if AccessLogsS3Enabled==false. if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] == "true" { if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Bucket] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Bucket] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Bucket), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Bucket]), }) } if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Prefix), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix]), }) @@ -478,7 +482,7 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st if len(changedAttributes) > 0 { klog.V(2).Infof("updating load-balancer attributes for %q", loadBalancerArn) - _, err = c.elbv2.ModifyLoadBalancerAttributes(&elbv2.ModifyLoadBalancerAttributesInput{ + _, err = c.elbv2.ModifyLoadBalancerAttributes(ctx, &elbv2.ModifyLoadBalancerAttributesInput{ LoadBalancerArn: aws.String(loadBalancerArn), Attributes: changedAttributes, }) @@ -494,17 +498,17 @@ var invalidELBV2NameRegex = regexp.MustCompile("[^[:alnum:]]") // buildTargetGroupName will build unique name for targetGroup of service & port. // the name is in format k8s-{namespace:8}-{name:8}-{uuid:10} (chosen to benefit most common use cases). // Note: nodePort & targetProtocol & targetType are included since they cannot be modified on existing targetGroup. -func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePort int64, nodePort int64, targetProtocol string, targetType string, mapping nlbPortMapping) string { +func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePort int32, nodePort int32, targetProtocol elbv2types.ProtocolEnum, targetType elbv2types.TargetTypeEnum, mapping nlbPortMapping) string { hasher := sha1.New() _, _ = hasher.Write([]byte(c.tagging.clusterID())) _, _ = hasher.Write([]byte(serviceName.Namespace)) _, _ = hasher.Write([]byte(serviceName.Name)) - _, _ = hasher.Write([]byte(strconv.FormatInt(servicePort, 10))) - _, _ = hasher.Write([]byte(strconv.FormatInt(nodePort, 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(servicePort), 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(nodePort), 10))) _, _ = hasher.Write([]byte(targetProtocol)) _, _ = hasher.Write([]byte(targetType)) _, _ = hasher.Write([]byte(mapping.HealthCheckConfig.Protocol)) - _, _ = hasher.Write([]byte(strconv.FormatInt(mapping.HealthCheckConfig.Interval, 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(mapping.HealthCheckConfig.Interval), 10))) tgUUID := hex.EncodeToString(hasher.Sum(nil)) sanitizedNamespace := invalidELBV2NameRegex.ReplaceAllString(serviceName.Namespace, "") @@ -512,8 +516,8 @@ func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePo return fmt.Sprintf("k8s-%.8s-%.8s-%.10s", sanitizedNamespace, sanitizedServiceName, tgUUID) } -func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping, namespacedName types.NamespacedName, instanceIDs []string, vpcID string, tags map[string]string) (listener *elbv2.Listener, err error) { - target, err := c.ensureTargetGroup( +func (c *Cloud) createListenerV2(ctx context.Context, loadBalancerArn *string, mapping nlbPortMapping, namespacedName types.NamespacedName, instanceIDs []string, vpcID string, tags map[string]string) (listener *elbv2types.Listener, err error) { + target, err := c.ensureTargetGroup(ctx, nil, namespacedName, mapping, @@ -525,9 +529,9 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping return nil, err } - elbTags := []*elbv2.Tag{} + elbTags := []elbv2types.Tag{} for k, v := range tags { - elbTags = append(elbTags, &elbv2.Tag{ + elbTags = append(elbTags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) @@ -535,11 +539,11 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping createListernerInput := &elbv2.CreateListenerInput{ LoadBalancerArn: loadBalancerArn, - Port: aws.Int64(mapping.FrontendPort), - Protocol: aws.String(mapping.FrontendProtocol), - DefaultActions: []*elbv2.Action{{ + Port: aws.Int32(mapping.FrontendPort), + Protocol: mapping.FrontendProtocol, + DefaultActions: []elbv2types.Action{{ TargetGroupArn: target.TargetGroupArn, - Type: aws.String(elbv2.ActionTypeEnumForward), + Type: elbv2types.ActionTypeEnumForward, }}, Tags: elbTags, } @@ -547,7 +551,7 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping if mapping.SSLPolicy != "" { createListernerInput.SslPolicy = aws.String(mapping.SSLPolicy) } - createListernerInput.Certificates = []*elbv2.Certificate{ + createListernerInput.Certificates = []elbv2types.Certificate{ { CertificateArn: aws.String(mapping.SSLCertificateARN), }, @@ -555,20 +559,20 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping } klog.Infof("Creating load balancer listener for %v", namespacedName) - createListenerOutput, err := c.elbv2.CreateListener(createListernerInput) + createListenerOutput, err := c.elbv2.CreateListener(ctx, createListernerInput) if err != nil { return nil, fmt.Errorf("error creating load balancer listener: %q", err) } - return createListenerOutput.Listeners[0], nil + return &createListenerOutput.Listeners[0], nil } // cleans up listener and corresponding target group -func (c *Cloud) deleteListenerV2(listener *elbv2.Listener) error { - _, err := c.elbv2.DeleteListener(&elbv2.DeleteListenerInput{ListenerArn: listener.ListenerArn}) +func (c *Cloud) deleteListenerV2(ctx context.Context, listener *elbv2types.Listener) error { + _, err := c.elbv2.DeleteListener(ctx, &elbv2.DeleteListenerInput{ListenerArn: listener.ListenerArn}) if err != nil { return fmt.Errorf("error deleting load balancer listener: %q", err) } - _, err = c.elbv2.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{TargetGroupArn: listener.DefaultActions[0].TargetGroupArn}) + _, err = c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{TargetGroupArn: listener.DefaultActions[0].TargetGroupArn}) if err != nil { return fmt.Errorf("error deleting load balancer target group: %q", err) } @@ -576,41 +580,41 @@ func (c *Cloud) deleteListenerV2(listener *elbv2.Listener) error { } // ensureTargetGroup creates a target group with a set of instances. -func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName types.NamespacedName, mapping nlbPortMapping, instances []string, vpcID string, tags map[string]string) (*elbv2.TargetGroup, error) { +func (c *Cloud) ensureTargetGroup(ctx context.Context, targetGroup *elbv2types.TargetGroup, serviceName types.NamespacedName, mapping nlbPortMapping, instances []string, vpcID string, tags map[string]string) (*elbv2types.TargetGroup, error) { dirty := false expectedTargets := c.computeTargetGroupExpectedTargets(instances, mapping.TrafficPort) if targetGroup == nil { - targetType := "instance" + targetType := elbv2types.TargetTypeEnumInstance name := c.buildTargetGroupName(serviceName, mapping.FrontendPort, mapping.TrafficPort, mapping.TrafficProtocol, targetType, mapping) klog.Infof("Creating load balancer target group for %v with name: %s", serviceName, name) input := &elbv2.CreateTargetGroupInput{ VpcId: aws.String(vpcID), Name: aws.String(name), - Port: aws.Int64(mapping.TrafficPort), - Protocol: aws.String(mapping.TrafficProtocol), - TargetType: aws.String(targetType), - HealthCheckIntervalSeconds: aws.Int64(mapping.HealthCheckConfig.Interval), + Port: aws.Int32(mapping.TrafficPort), + Protocol: mapping.TrafficProtocol, + TargetType: targetType, + HealthCheckIntervalSeconds: aws.Int32(mapping.HealthCheckConfig.Interval), HealthCheckPort: aws.String(mapping.HealthCheckConfig.Port), - HealthCheckProtocol: aws.String(mapping.HealthCheckConfig.Protocol), - HealthyThresholdCount: aws.Int64(mapping.HealthCheckConfig.HealthyThreshold), - UnhealthyThresholdCount: aws.Int64(mapping.HealthCheckConfig.UnhealthyThreshold), + HealthCheckProtocol: mapping.HealthCheckConfig.Protocol, + HealthyThresholdCount: aws.Int32(mapping.HealthCheckConfig.HealthyThreshold), + UnhealthyThresholdCount: aws.Int32(mapping.HealthCheckConfig.UnhealthyThreshold), // HealthCheckTimeoutSeconds: Currently not configurable, 6 seconds for HTTP, 10 for TCP/HTTPS } - if mapping.HealthCheckConfig.Protocol != elbv2.ProtocolEnumTcp { + if mapping.HealthCheckConfig.Protocol != elbv2types.ProtocolEnumTcp { input.HealthCheckPath = aws.String(mapping.HealthCheckConfig.Path) } if len(tags) != 0 { - targetGroupTags := make([]*elbv2.Tag, 0, len(tags)) + targetGroupTags := make([]elbv2types.Tag, 0, len(tags)) for k, v := range tags { - targetGroupTags = append(targetGroupTags, &elbv2.Tag{ + targetGroupTags = append(targetGroupTags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } input.Tags = targetGroupTags } - result, err := c.elbv2.CreateTargetGroup(input) + result, err := c.elbv2.CreateTargetGroup(ctx, input) if err != nil { return nil, fmt.Errorf("error creating load balancer target group: %q", err) } @@ -619,21 +623,21 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty } tg := result.TargetGroups[0] - tgARN := aws.StringValue(tg.TargetGroupArn) - if err := c.ensureTargetGroupTargets(tgARN, expectedTargets, nil); err != nil { + tgARN := aws.ToString(tg.TargetGroupArn) + if err := c.ensureTargetGroupTargets(ctx, tgARN, expectedTargets, nil); err != nil { return nil, err } - return tg, nil + return &tg, nil } // handle instances in service { - tgARN := aws.StringValue(targetGroup.TargetGroupArn) - actualTargets, err := c.obtainTargetGroupActualTargets(tgARN) + tgARN := aws.ToString(targetGroup.TargetGroupArn) + actualTargets, err := c.obtainTargetGroupActualTargets(ctx, tgARN) if err != nil { return nil, err } - if err := c.ensureTargetGroupTargets(tgARN, expectedTargets, actualTargets); err != nil { + if err := c.ensureTargetGroupTargets(ctx, tgARN, expectedTargets, actualTargets); err != nil { return nil, err } } @@ -645,24 +649,24 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty input := &elbv2.ModifyTargetGroupInput{ TargetGroupArn: targetGroup.TargetGroupArn, } - if mapping.HealthCheckConfig.Port != aws.StringValue(targetGroup.HealthCheckPort) { + if mapping.HealthCheckConfig.Port != aws.ToString(targetGroup.HealthCheckPort) { input.HealthCheckPort = aws.String(mapping.HealthCheckConfig.Port) dirtyHealthCheck = true } - if mapping.HealthCheckConfig.HealthyThreshold != aws.Int64Value(targetGroup.HealthyThresholdCount) { + if mapping.HealthCheckConfig.HealthyThreshold != aws.ToInt32(targetGroup.HealthyThresholdCount) { dirtyHealthCheck = true - input.HealthyThresholdCount = aws.Int64(mapping.HealthCheckConfig.HealthyThreshold) - input.UnhealthyThresholdCount = aws.Int64(mapping.HealthCheckConfig.UnhealthyThreshold) + input.HealthyThresholdCount = aws.Int32(mapping.HealthCheckConfig.HealthyThreshold) + input.UnhealthyThresholdCount = aws.Int32(mapping.HealthCheckConfig.UnhealthyThreshold) } - if !strings.EqualFold(mapping.HealthCheckConfig.Protocol, elbv2.ProtocolEnumTcp) { - if mapping.HealthCheckConfig.Path != aws.StringValue(input.HealthCheckPath) { + if !strings.EqualFold(string(mapping.HealthCheckConfig.Protocol), string(elbv2types.ProtocolEnumTcp)) { + if mapping.HealthCheckConfig.Path != aws.ToString(input.HealthCheckPath) { input.HealthCheckPath = aws.String(mapping.HealthCheckConfig.Path) dirtyHealthCheck = true } } if dirtyHealthCheck { - _, err := c.elbv2.ModifyTargetGroup(input) + _, err := c.elbv2.ModifyTargetGroup(ctx, input) if err != nil { return nil, fmt.Errorf("error modifying target group health check: %q", err) } @@ -672,19 +676,19 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty } if dirty { - result, err := c.elbv2.DescribeTargetGroups(&elbv2.DescribeTargetGroupsInput{ - TargetGroupArns: []*string{targetGroup.TargetGroupArn}, + result, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{ + TargetGroupArns: []string{aws.ToString(targetGroup.TargetGroupArn)}, }) if err != nil { return nil, fmt.Errorf("error retrieving target group after creation/update: %q", err) } - targetGroup = result.TargetGroups[0] + targetGroup = &result.TargetGroups[0] } return targetGroup, nil } -func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2.TargetDescription, actualTargets []*elbv2.TargetDescription) error { +func (c *Cloud) ensureTargetGroupTargets(ctx context.Context, tgARN string, expectedTargets []*elbv2types.TargetDescription, actualTargets []*elbv2types.TargetDescription) error { targetsToRegister, targetsToDeregister := c.diffTargetGroupTargets(expectedTargets, actualTargets) if len(targetsToRegister) > 0 { targetsToRegisterChunks := c.chunkTargetDescriptions(targetsToRegister, defaultRegisterTargetsChunkSize) @@ -693,7 +697,7 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. TargetGroupArn: aws.String(tgARN), Targets: targetsChunk, } - if _, err := c.elbv2.RegisterTargets(req); err != nil { + if _, err := c.elbv2.RegisterTargets(ctx, req); err != nil { return fmt.Errorf("error trying to register targets in target group: %q", err) } } @@ -705,7 +709,7 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. TargetGroupArn: aws.String(tgARN), Targets: targetsChunk, } - if _, err := c.elbv2.DeregisterTargets(req); err != nil { + if _, err := c.elbv2.DeregisterTargets(ctx, req); err != nil { return fmt.Errorf("error trying to deregister targets in target group: %q", err) } } @@ -713,28 +717,28 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. return nil } -func (c *Cloud) computeTargetGroupExpectedTargets(instanceIDs []string, port int64) []*elbv2.TargetDescription { - expectedTargets := make([]*elbv2.TargetDescription, 0, len(instanceIDs)) +func (c *Cloud) computeTargetGroupExpectedTargets(instanceIDs []string, port int32) []*elbv2types.TargetDescription { + expectedTargets := make([]*elbv2types.TargetDescription, 0, len(instanceIDs)) for _, instanceID := range instanceIDs { - expectedTargets = append(expectedTargets, &elbv2.TargetDescription{ + expectedTargets = append(expectedTargets, &elbv2types.TargetDescription{ Id: aws.String(instanceID), - Port: aws.Int64(port), + Port: aws.Int32(port), }) } return expectedTargets } -func (c *Cloud) obtainTargetGroupActualTargets(tgARN string) ([]*elbv2.TargetDescription, error) { +func (c *Cloud) obtainTargetGroupActualTargets(ctx context.Context, tgARN string) ([]*elbv2types.TargetDescription, error) { req := &elbv2.DescribeTargetHealthInput{ TargetGroupArn: aws.String(tgARN), } - resp, err := c.elbv2.DescribeTargetHealth(req) + resp, err := c.elbv2.DescribeTargetHealth(ctx, req) if err != nil { return nil, fmt.Errorf("error describing target group health: %q", err) } - actualTargets := make([]*elbv2.TargetDescription, 0, len(resp.TargetHealthDescriptions)) + actualTargets := make([]*elbv2types.TargetDescription, 0, len(resp.TargetHealthDescriptions)) for _, targetDesc := range resp.TargetHealthDescriptions { - if targetDesc.TargetHealth.Reason != nil && aws.StringValue(targetDesc.TargetHealth.Reason) == elbv2.TargetHealthReasonEnumTargetDeregistrationInProgress { + if targetDesc.TargetHealth.Reason == elbv2types.TargetHealthReasonEnumDeregistrationInProgress { continue } actualTargets = append(actualTargets, targetDesc.Target) @@ -743,16 +747,16 @@ func (c *Cloud) obtainTargetGroupActualTargets(tgARN string) ([]*elbv2.TargetDes } // diffTargetGroupTargets computes the targets to register and targets to deregister based on existingTargets and desired instances. -func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2.TargetDescription, actualTargets []*elbv2.TargetDescription) (targetsToRegister []*elbv2.TargetDescription, targetsToDeregister []*elbv2.TargetDescription) { - expectedTargetsByUID := make(map[string]*elbv2.TargetDescription, len(expectedTargets)) +func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2types.TargetDescription, actualTargets []*elbv2types.TargetDescription) (targetsToRegister []elbv2types.TargetDescription, targetsToDeregister []elbv2types.TargetDescription) { + expectedTargetsByUID := make(map[string]elbv2types.TargetDescription, len(expectedTargets)) for _, target := range expectedTargets { - targetUID := fmt.Sprintf("%v:%v", aws.StringValue(target.Id), aws.Int64Value(target.Port)) - expectedTargetsByUID[targetUID] = target + targetUID := fmt.Sprintf("%v:%v", aws.ToString(target.Id), aws.ToInt32(target.Port)) + expectedTargetsByUID[targetUID] = *target } - actualTargetsByUID := make(map[string]*elbv2.TargetDescription, len(actualTargets)) + actualTargetsByUID := make(map[string]elbv2types.TargetDescription, len(actualTargets)) for _, target := range actualTargets { - targetUID := fmt.Sprintf("%v:%v", aws.StringValue(target.Id), aws.Int64Value(target.Port)) - actualTargetsByUID[targetUID] = target + targetUID := fmt.Sprintf("%v:%v", aws.ToString(target.Id), aws.ToInt32(target.Port)) + actualTargetsByUID[targetUID] = *target } expectedTargetsUIDs := sets.StringKeySet(expectedTargetsByUID) @@ -767,8 +771,8 @@ func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2.TargetDescriptio } // chunkTargetDescriptions will split slice of TargetDescription into chunks -func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chunkSize int) [][]*elbv2.TargetDescription { - var chunks [][]*elbv2.TargetDescription +func (c *Cloud) chunkTargetDescriptions(targets []elbv2types.TargetDescription, chunkSize int) [][]elbv2types.TargetDescription { + var chunks [][]elbv2types.TargetDescription for i := 0; i < len(targets); i += chunkSize { end := i + chunkSize if end > len(targets) { @@ -781,12 +785,12 @@ func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chun // updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. // TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. -func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { +func (c *Cloud) updateInstanceSecurityGroupsForNLB(ctx context.Context, lbName string, instances map[InstanceID]*ec2types.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } - clusterSGs, err := c.getTaggedSecurityGroups() + clusterSGs, err := c.getTaggedSecurityGroups(ctx) if err != nil { return fmt.Errorf("error querying for tagged security groups: %q", err) } @@ -798,17 +802,17 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ return err } if sg == nil { - klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId)) + klog.Warningf("Ignoring instance without security group: %s", aws.ToString(instance.InstanceId)) continue } - desiredSGIDs.Insert(aws.StringValue(sg.GroupId)) + desiredSGIDs.Insert(aws.ToString(sg.GroupId)) } // TODO(@M00nF1sh): do we really needs to support SG without cluster tag at current version? // findSecurityGroupForInstance might return SG that are not tagged. { for sgID := range desiredSGIDs.Difference(sets.StringKeySet(clusterSGs)) { - sg, err := c.findSecurityGroup(sgID) + sg, err := c.findSecurityGroup(ctx, sgID) if err != nil { return fmt.Errorf("error finding instance group: %q", err) } @@ -817,20 +821,21 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ } { - clientPorts := sets.Int64{} + clientPorts := sets.Set[int32]{} clientProtocol := "tcp" - healthCheckPorts := sets.Int64{} + healthCheckPorts := sets.Set[int32]{} for _, port := range portMappings { clientPorts.Insert(port.TrafficPort) hcPort := port.TrafficPort if port.HealthCheckConfig.Port != defaultHealthCheckPort { - var err error - if hcPort, err = strconv.ParseInt(port.HealthCheckConfig.Port, 10, 0); err != nil { + hcPort64, err := strconv.ParseInt(port.HealthCheckConfig.Port, 10, 0) + if err != nil { return fmt.Errorf("Invalid health check port %v", port.HealthCheckConfig.Port) } + hcPort = int32(hcPort64) } healthCheckPorts.Insert(hcPort) - if port.TrafficProtocol == string(v1.ProtocolUDP) { + if port.TrafficProtocol == elbv2types.ProtocolEnumUdp { clientProtocol = "udp" } } @@ -842,23 +847,23 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ // If the client rule is 1) all addresses 2) tcp and 3) has same ports as the healthcheck, // then the health rules are a subset of the client rule and are not needed. if len(clientCIDRs) != 1 || clientCIDRs[0] != "0.0.0.0/0" || clientProtocol != "tcp" || !healthCheckPorts.Equal(clientPorts) { - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, subnetCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, subnetCIDRs); err != nil { return err } } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil { return err } } else { - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { return err } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil { return err } } if !sgPerms.Equal(NewIPPermissionSet(sg.IpPermissions...).Ungroup()) { - if err := c.updateInstanceSecurityGroupForNLBMTU(sgID, sgPerms); err != nil { + if err := c.updateInstanceSecurityGroupForNLBMTU(ctx, sgID, sgPerms); err != nil { return err } } @@ -869,15 +874,15 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ // updateInstanceSecurityGroupForNLBTraffic will manage permissions set(identified by ruleDesc) on securityGroup to match desired set(allow protocol traffic from ports/cidr). // Note: sgPerms will be updated to reflect the current permission set on SG after update. -func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Int64, cidrs []string) error { +func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(ctx context.Context, sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Set[int32], cidrs []string) error { desiredPerms := NewIPPermissionSet() for port := range ports { for _, cidr := range cidrs { - desiredPerms.Insert(&ec2.IpPermission{ + desiredPerms.Insert(ec2types.IpPermission{ IpProtocol: aws.String(protocol), - FromPort: aws.Int64(port), - ToPort: aws.Int64(port), - IpRanges: []*ec2.IpRange{ + FromPort: aws.Int32(int32(port)), + ToPort: aws.Int32(int32(port)), + IpRanges: []ec2types.IpRange{ { CidrIp: aws.String(cidr), Description: aws.String(ruleDesc), @@ -892,7 +897,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{ruleDesc}}) if len(permsToRevoke) > 0 { permsToRevokeList := permsToRevoke.List() - changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) + changed, err := c.removeSecurityGroupIngress(ctx, sgID, permsToRevokeList) if err != nil { klog.Warningf("Error remove traffic permission from security group: %q", err) return err @@ -904,7 +909,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP } if len(permsToGrant) > 0 { permsToGrantList := permsToGrant.List() - changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + changed, err := c.addSecurityGroupIngress(ctx, sgID, permsToGrantList) if err != nil { klog.Warningf("Error add traffic permission to security group: %q", err) return err @@ -918,16 +923,16 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP } // Note: sgPerms will be updated to reflect the current permission set on SG after update. -func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPermissionSet) error { +func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(ctx context.Context, sgID string, sgPerms IPPermissionSet) error { desiredPerms := NewIPPermissionSet() for _, perm := range sgPerms { for _, ipRange := range perm.IpRanges { - if strings.Contains(aws.StringValue(ipRange.Description), NLBClientRuleDescription) { - desiredPerms.Insert(&ec2.IpPermission{ + if strings.Contains(aws.ToString(ipRange.Description), NLBClientRuleDescription) { + desiredPerms.Insert(ec2types.IpPermission{ IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), - IpRanges: []*ec2.IpRange{ + FromPort: aws.Int32(3), + ToPort: aws.Int32(4), + IpRanges: []ec2types.IpRange{ { CidrIp: ipRange.CidrIp, Description: aws.String(NLBMtuDiscoveryRuleDescription), @@ -943,7 +948,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{NLBMtuDiscoveryRuleDescription}}) if len(permsToRevoke) > 0 { permsToRevokeList := permsToRevoke.List() - changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) + changed, err := c.removeSecurityGroupIngress(ctx, sgID, permsToRevokeList) if err != nil { klog.Warningf("Error remove MTU permission from security group: %q", err) return err @@ -956,7 +961,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm } if len(permsToGrant) > 0 { permsToGrantList := permsToGrant.List() - changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + changed, err := c.addSecurityGroupIngress(ctx, sgID, permsToGrantList) if err != nil { klog.Warningf("Error add MTU permission to security group: %q", err) return err @@ -969,8 +974,8 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm return nil } -func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBalancerName string, listeners []*elb.Listener, subnetIDs []string, securityGroupIDs []string, internalELB, proxyProtocol bool, loadBalancerAttributes *elb.LoadBalancerAttributes, annotations map[string]string) (*elb.LoadBalancerDescription, error) { - loadBalancer, err := c.describeLoadBalancer(loadBalancerName) +func (c *Cloud) ensureLoadBalancer(ctx context.Context, namespacedName types.NamespacedName, loadBalancerName string, listeners []elbtypes.Listener, subnetIDs []string, securityGroupIDs []string, internalELB, proxyProtocol bool, loadBalancerAttributes *elbtypes.LoadBalancerAttributes, annotations map[string]string) (*elbtypes.LoadBalancerDescription, error) { + loadBalancer, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return nil, err } @@ -992,13 +997,13 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if subnetIDs == nil { createRequest.Subnets = nil } else { - createRequest.Subnets = aws.StringSlice(subnetIDs) + createRequest.Subnets = subnetIDs } if securityGroupIDs == nil { createRequest.SecurityGroups = nil } else { - createRequest.SecurityGroups = aws.StringSlice(securityGroupIDs) + createRequest.SecurityGroups = securityGroupIDs } // Get additional tags set by the user @@ -1009,26 +1014,26 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala tags = c.tagging.buildTags(ResourceLifecycleOwned, tags) for k, v := range tags { - createRequest.Tags = append(createRequest.Tags, &elb.Tag{ + createRequest.Tags = append(createRequest.Tags, elbtypes.Tag{ Key: aws.String(k), Value: aws.String(v), }) } klog.Infof("Creating load balancer for %v with name: %s", namespacedName, loadBalancerName) - _, err := c.elb.CreateLoadBalancer(createRequest) + _, err := c.elb.CreateLoadBalancer(ctx, createRequest) if err != nil { return nil, err } if proxyProtocol { - err = c.createProxyProtocolPolicy(loadBalancerName) + err = c.createProxyProtocolPolicy(ctx, loadBalancerName) if err != nil { return nil, err } for _, listener := range listeners { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to true", *listener.InstancePort) - err := c.setBackendPolicies(loadBalancerName, *listener.InstancePort, []*string{aws.String(ProxyProtocolPolicyName)}) + err := c.setBackendPolicies(ctx, loadBalancerName, listener.InstancePort, []*string{aws.String(ProxyProtocolPolicyName)}) if err != nil { return nil, err } @@ -1041,8 +1046,8 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { // Sync subnets - expected := sets.NewString(subnetIDs...) - actual := stringSetFromPointers(loadBalancer.Subnets) + expected := sets.New[string](subnetIDs...) + actual := sets.New[string](loadBalancer.Subnets...) additions := expected.Difference(actual) removals := actual.Difference(expected) @@ -1050,9 +1055,9 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if removals.Len() != 0 { request := &elb.DetachLoadBalancerFromSubnetsInput{} request.LoadBalancerName = aws.String(loadBalancerName) - request.Subnets = stringSetToPointers(removals) + request.Subnets = stringSetToList(removals) klog.V(2).Info("Detaching load balancer from removed subnets") - _, err := c.elb.DetachLoadBalancerFromSubnets(request) + _, err := c.elb.DetachLoadBalancerFromSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error detaching AWS loadbalancer from subnets: %q", err) } @@ -1062,9 +1067,9 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if additions.Len() != 0 { request := &elb.AttachLoadBalancerToSubnetsInput{} request.LoadBalancerName = aws.String(loadBalancerName) - request.Subnets = stringSetToPointers(additions) + request.Subnets = stringSetToList(additions) klog.V(2).Info("Attaching load balancer to added subnets") - _, err := c.elb.AttachLoadBalancerToSubnets(request) + _, err := c.elb.AttachLoadBalancerToSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error attaching AWS loadbalancer to subnets: %q", err) } @@ -1074,8 +1079,8 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { // Sync security groups - expected := sets.NewString(securityGroupIDs...) - actual := stringSetFromPointers(loadBalancer.SecurityGroups) + expected := sets.New[string](securityGroupIDs...) + actual := stringSetFromList(loadBalancer.SecurityGroups) if !expected.Equal(actual) { // This call just replaces the security groups, unlike e.g. subnets (!) @@ -1084,10 +1089,10 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if securityGroupIDs == nil { request.SecurityGroups = nil } else { - request.SecurityGroups = aws.StringSlice(securityGroupIDs) + request.SecurityGroups = securityGroupIDs } klog.V(2).Info("Applying updated security groups to load balancer") - _, err := c.elb.ApplySecurityGroupsToLoadBalancer(request) + _, err := c.elb.ApplySecurityGroupsToLoadBalancer(ctx, request) if err != nil { return nil, fmt.Errorf("error applying AWS loadbalancer security groups: %q", err) } @@ -1103,7 +1108,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala request.LoadBalancerName = aws.String(loadBalancerName) request.LoadBalancerPorts = removals klog.V(2).Info("Deleting removed load balancer listeners") - if _, err := c.elb.DeleteLoadBalancerListeners(request); err != nil { + if _, err := c.elb.DeleteLoadBalancerListeners(ctx, request); err != nil { return nil, fmt.Errorf("error deleting AWS loadbalancer listeners: %q", err) } dirty = true @@ -1114,7 +1119,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala request.LoadBalancerName = aws.String(loadBalancerName) request.Listeners = additions klog.V(2).Info("Creating added load balancer listeners") - if _, err := c.elb.CreateLoadBalancerListeners(request); err != nil { + if _, err := c.elb.CreateLoadBalancerListeners(ctx, request); err != nil { return nil, fmt.Errorf("error creating AWS loadbalancer listeners: %q", err) } dirty = true @@ -1132,7 +1137,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala // back if a policy of the same name already exists. However, the aws-sdk does not // seem to return an error to us in these cases. Therefore, this will issue an API // request every time. - err := c.createProxyProtocolPolicy(loadBalancerName) + err := c.createProxyProtocolPolicy(ctx, loadBalancerName) if err != nil { return nil, err } @@ -1140,11 +1145,11 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala proxyPolicies = append(proxyPolicies, aws.String(ProxyProtocolPolicyName)) } - foundBackends := make(map[int64]bool) - proxyProtocolBackends := make(map[int64]bool) + foundBackends := make(map[int32]bool) + proxyProtocolBackends := make(map[int32]bool) for _, backendListener := range loadBalancer.BackendServerDescriptions { - foundBackends[*backendListener.InstancePort] = false - proxyProtocolBackends[*backendListener.InstancePort] = proxyProtocolEnabled(backendListener) + foundBackends[aws.ToInt32(backendListener.InstancePort)] = false + proxyProtocolBackends[aws.ToInt32(backendListener.InstancePort)] = proxyProtocolEnabled(backendListener) } for _, listener := range listeners { @@ -1165,7 +1170,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if setPolicy { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to %t", instancePort, proxyProtocol) - err := c.setBackendPolicies(loadBalancerName, instancePort, proxyPolicies) + err := c.setBackendPolicies(ctx, loadBalancerName, aws.Int32(instancePort), proxyPolicies) if err != nil { return nil, err } @@ -1179,7 +1184,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala for instancePort, found := range foundBackends { if !found { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to false", instancePort) - err := c.setBackendPolicies(loadBalancerName, instancePort, []*string{}) + err := c.setBackendPolicies(ctx, loadBalancerName, aws.Int32(instancePort), []*string{}) if err != nil { return nil, err } @@ -1193,7 +1198,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala klog.V(2).Infof("Creating additional load balancer tags for %s", loadBalancerName) tags := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags) if len(tags) > 0 { - err := c.addLoadBalancerTags(loadBalancerName, tags) + err := c.addLoadBalancerTags(ctx, loadBalancerName, tags) if err != nil { return nil, fmt.Errorf("unable to create additional load balancer tags: %v", err) } @@ -1207,7 +1212,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { describeAttributesRequest := &elb.DescribeLoadBalancerAttributesInput{} describeAttributesRequest.LoadBalancerName = aws.String(loadBalancerName) - describeAttributesOutput, err := c.elb.DescribeLoadBalancerAttributes(describeAttributesRequest) + describeAttributesOutput, err := c.elb.DescribeLoadBalancerAttributes(ctx, describeAttributesRequest) if err != nil { klog.Warning("Unable to retrieve load balancer attributes during attribute sync") return nil, err @@ -1222,7 +1227,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala modifyAttributesRequest := &elb.ModifyLoadBalancerAttributesInput{} modifyAttributesRequest.LoadBalancerName = aws.String(loadBalancerName) modifyAttributesRequest.LoadBalancerAttributes = loadBalancerAttributes - _, err = c.elb.ModifyLoadBalancerAttributes(modifyAttributesRequest) + _, err = c.elb.ModifyLoadBalancerAttributes(ctx, modifyAttributesRequest) if err != nil { return nil, fmt.Errorf("Unable to update load balancer attributes during attribute sync: %q", err) } @@ -1231,7 +1236,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala } if dirty { - loadBalancer, err = c.describeLoadBalancer(loadBalancerName) + loadBalancer, err = c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { klog.Warning("Unable to retrieve load balancer after creation/update") return nil, err @@ -1245,10 +1250,10 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala // NOTE: there exists an O(nlgn) implementation for this function. However, as the default limit of // // listeners per elb is 100, this implementation is reduced from O(m*n) => O(n). -func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listenerDescriptions []*elb.ListenerDescription) ([]*elb.Listener, []*int64) { +func syncElbListeners(loadBalancerName string, listeners []elbtypes.Listener, listenerDescriptions []elbtypes.ListenerDescription) ([]elbtypes.Listener, []int32) { foundSet := make(map[int]bool) - removals := []*int64{} - additions := []*elb.Listener{} + removals := []int32{} + additions := []elbtypes.Listener{} for _, listenerDescription := range listenerDescriptions { actual := listenerDescription.Listener @@ -1259,11 +1264,7 @@ func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listen found := false for i, expected := range listeners { - if expected == nil { - klog.Warning("Ignoring empty desired listener for loadbalancer: ", loadBalancerName) - continue - } - if elbListenersAreEqual(actual, expected) { + if elbListenersAreEqual(*actual, expected) { // The current listener on the actual // elb is in the set of desired listeners. foundSet[i] = true @@ -1285,17 +1286,17 @@ func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listen return additions, removals } -func elbListenersAreEqual(actual, expected *elb.Listener) bool { +func elbListenersAreEqual(actual, expected elbtypes.Listener) bool { if !elbProtocolsAreEqual(actual.Protocol, expected.Protocol) { return false } if !elbProtocolsAreEqual(actual.InstanceProtocol, expected.InstanceProtocol) { return false } - if aws.Int64Value(actual.InstancePort) != aws.Int64Value(expected.InstancePort) { + if aws.ToInt32(actual.InstancePort) != aws.ToInt32(expected.InstancePort) { return false } - if aws.Int64Value(actual.LoadBalancerPort) != aws.Int64Value(expected.LoadBalancerPort) { + if actual.LoadBalancerPort != expected.LoadBalancerPort { return false } if !awsArnEquals(actual.SSLCertificateId, expected.SSLCertificateId) { @@ -1304,11 +1305,11 @@ func elbListenersAreEqual(actual, expected *elb.Listener) bool { return true } -func createSubnetMappings(subnetIDs []string, allocationIDs []string) []*elbv2.SubnetMapping { - response := []*elbv2.SubnetMapping{} +func createSubnetMappings(subnetIDs []string, allocationIDs []string) []elbv2types.SubnetMapping { + response := []elbv2types.SubnetMapping{} for index, id := range subnetIDs { - sm := &elbv2.SubnetMapping{SubnetId: aws.String(id)} + sm := elbv2types.SubnetMapping{SubnetId: aws.String(id)} if len(allocationIDs) > 0 { sm.AllocationId = aws.String(allocationIDs[index]) } @@ -1324,7 +1325,7 @@ func elbProtocolsAreEqual(l, r *string) bool { if l == nil || r == nil { return l == r } - return strings.EqualFold(aws.StringValue(l), aws.StringValue(r)) + return strings.EqualFold(aws.ToString(l), aws.ToString(r)) } // awsArnEquals checks if two ARN strings are considered the same @@ -1333,23 +1334,23 @@ func awsArnEquals(l, r *string) bool { if l == nil || r == nil { return l == r } - return strings.EqualFold(aws.StringValue(l), aws.StringValue(r)) + return strings.EqualFold(aws.ToString(l), aws.ToString(r)) } // getExpectedHealthCheck returns an elb.Healthcheck for the provided target // and using either sensible defaults or overrides via Service annotations -func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]string) (*elb.HealthCheck, error) { - healthcheck := &elb.HealthCheck{Target: &target} - getOrDefault := func(annotation string, defaultValue int64) (*int64, error) { - i64 := defaultValue - var err error +func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]string) (*elbtypes.HealthCheck, error) { + healthcheck := &elbtypes.HealthCheck{Target: &target} + getOrDefault := func(annotation string, defaultValue int32) (*int32, error) { + i32 := defaultValue if s, ok := annotations[annotation]; ok { - i64, err = strconv.ParseInt(s, 10, 0) + i64, err := strconv.ParseInt(s, 10, 0) if err != nil { return nil, fmt.Errorf("failed parsing health check annotation value: %v", err) } + i32 = int32(i64) } - return &i64, nil + return &i32, nil } var err error healthcheck.HealthyThreshold, err = getOrDefault(ServiceAnnotationLoadBalancerHCHealthyThreshold, defaultElbHCHealthyThreshold) @@ -1368,15 +1369,15 @@ func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]str if err != nil { return nil, err } - if err = healthcheck.Validate(); err != nil { + if err = ValidateHealthCheck(healthcheck); err != nil { return nil, fmt.Errorf("some of the load balancer health check parameters are invalid: %v", err) } return healthcheck, nil } // Makes sure that the health check for an ELB matches the configured health check node port -func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDescription, protocol string, port int32, path string, annotations map[string]string) error { - name := aws.StringValue(loadBalancer.LoadBalancerName) +func (c *Cloud) ensureLoadBalancerHealthCheck(ctx context.Context, loadBalancer *elbtypes.LoadBalancerDescription, protocol string, port int32, path string, annotations map[string]string) error { + name := aws.ToString(loadBalancer.LoadBalancerName) actual := loadBalancer.HealthCheck // Override healthcheck protocol, port and path based on annotations @@ -1409,11 +1410,11 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc // comparing attributes 1 by 1 to avoid breakage in case a new field is // added to the HC which breaks the equality - if aws.StringValue(expected.Target) == aws.StringValue(actual.Target) && - aws.Int64Value(expected.HealthyThreshold) == aws.Int64Value(actual.HealthyThreshold) && - aws.Int64Value(expected.UnhealthyThreshold) == aws.Int64Value(actual.UnhealthyThreshold) && - aws.Int64Value(expected.Interval) == aws.Int64Value(actual.Interval) && - aws.Int64Value(expected.Timeout) == aws.Int64Value(actual.Timeout) { + if aws.ToString(expected.Target) == aws.ToString(actual.Target) && + aws.ToInt32(expected.HealthyThreshold) == aws.ToInt32(actual.HealthyThreshold) && + aws.ToInt32(expected.UnhealthyThreshold) == aws.ToInt32(actual.UnhealthyThreshold) && + aws.ToInt32(expected.Interval) == aws.ToInt32(actual.Interval) && + aws.ToInt32(expected.Timeout) == aws.ToInt32(actual.Timeout) { return nil } @@ -1421,7 +1422,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc request.HealthCheck = expected request.LoadBalancerName = loadBalancer.LoadBalancerName - _, err = c.elb.ConfigureHealthCheck(request) + _, err = c.elb.ConfigureHealthCheck(ctx, request) if err != nil { return fmt.Errorf("error configuring load balancer health check for %q: %q", name, err) } @@ -1430,7 +1431,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc } // Makes sure that exactly the specified hosts are registered as instances with the load balancer -func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[InstanceID]*ec2.Instance) error { +func (c *Cloud) ensureLoadBalancerInstances(ctx context.Context, loadBalancerName string, lbInstances []elbtypes.Instance, instanceIDs map[InstanceID]*ec2types.Instance) error { expected := sets.NewString() for id := range instanceIDs { expected.Insert(string(id)) @@ -1438,22 +1439,22 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances actual := sets.NewString() for _, lbInstance := range lbInstances { - actual.Insert(aws.StringValue(lbInstance.InstanceId)) + actual.Insert(aws.ToString(lbInstance.InstanceId)) } additions := expected.Difference(actual) removals := actual.Difference(expected) - addInstances := []*elb.Instance{} + addInstances := []elbtypes.Instance{} for _, instanceID := range additions.List() { - addInstance := &elb.Instance{} + addInstance := elbtypes.Instance{} addInstance.InstanceId = aws.String(instanceID) addInstances = append(addInstances, addInstance) } - removeInstances := []*elb.Instance{} + removeInstances := []elbtypes.Instance{} for _, instanceID := range removals.List() { - removeInstance := &elb.Instance{} + removeInstance := elbtypes.Instance{} removeInstance.InstanceId = aws.String(instanceID) removeInstances = append(removeInstances, removeInstance) } @@ -1462,7 +1463,7 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances registerRequest := &elb.RegisterInstancesWithLoadBalancerInput{} registerRequest.Instances = addInstances registerRequest.LoadBalancerName = aws.String(loadBalancerName) - _, err := c.elb.RegisterInstancesWithLoadBalancer(registerRequest) + _, err := c.elb.RegisterInstancesWithLoadBalancer(ctx, registerRequest) if err != nil { return err } @@ -1473,7 +1474,7 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances deregisterRequest := &elb.DeregisterInstancesFromLoadBalancerInput{} deregisterRequest.Instances = removeInstances deregisterRequest.LoadBalancerName = aws.String(loadBalancerName) - _, err := c.elb.DeregisterInstancesFromLoadBalancer(deregisterRequest) + _, err := c.elb.DeregisterInstancesFromLoadBalancer(ctx, deregisterRequest) if err != nil { return err } @@ -1483,33 +1484,30 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances return nil } -func (c *Cloud) getLoadBalancerTLSPorts(loadBalancer *elb.LoadBalancerDescription) []int64 { +func (c *Cloud) getLoadBalancerTLSPorts(loadBalancer *elbtypes.LoadBalancerDescription) []int64 { ports := []int64{} for _, listenerDescription := range loadBalancer.ListenerDescriptions { - protocol := aws.StringValue(listenerDescription.Listener.Protocol) + protocol := aws.ToString(listenerDescription.Listener.Protocol) if protocol == "SSL" || protocol == "HTTPS" { - ports = append(ports, aws.Int64Value(listenerDescription.Listener.LoadBalancerPort)) + ports = append(ports, int64(listenerDescription.Listener.LoadBalancerPort)) } } return ports } -func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescription, policyName string) error { +func (c *Cloud) ensureSSLNegotiationPolicy(ctx context.Context, loadBalancer *elbtypes.LoadBalancerDescription, policyName string) error { klog.V(2).Info("Describing load balancer policies on load balancer") - result, err := c.elb.DescribeLoadBalancerPolicies(&elb.DescribeLoadBalancerPoliciesInput{ + result, err := c.elb.DescribeLoadBalancerPolicies(ctx, &elb.DescribeLoadBalancerPoliciesInput{ LoadBalancerName: loadBalancer.LoadBalancerName, - PolicyNames: []*string{ - aws.String(fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)), + PolicyNames: []string{ + fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName), }, }) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case elb.ErrCodePolicyNotFoundException: - default: - return fmt.Errorf("error describing security policies on load balancer: %q", err) - } + var notFoundErr *elbtypes.PolicyNotFoundException + if !errors.As(err, ¬FoundErr) { + return fmt.Errorf("error describing security policies on load balancer: %q", err) } } @@ -1520,11 +1518,11 @@ func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescrip klog.V(2).Infof("Creating SSL negotiation policy '%s' on load balancer", fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)) // there is an upper limit of 98 policies on an ELB, we're pretty safe from // running into it - _, err = c.elb.CreateLoadBalancerPolicy(&elb.CreateLoadBalancerPolicyInput{ + _, err = c.elb.CreateLoadBalancerPolicy(ctx, &elb.CreateLoadBalancerPolicyInput{ LoadBalancerName: loadBalancer.LoadBalancerName, PolicyName: aws.String(fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)), PolicyTypeName: aws.String("SSLNegotiationPolicyType"), - PolicyAttributes: []*elb.PolicyAttribute{ + PolicyAttributes: []elbtypes.PolicyAttribute{ { AttributeName: aws.String("Reference-Security-Policy"), AttributeValue: aws.String(policyName), @@ -1537,29 +1535,27 @@ func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescrip return nil } -func (c *Cloud) setSSLNegotiationPolicy(loadBalancerName, sslPolicyName string, port int64) error { +func (c *Cloud) setSSLNegotiationPolicy(ctx context.Context, loadBalancerName, sslPolicyName string, port int64) error { policyName := fmt.Sprintf(SSLNegotiationPolicyNameFormat, sslPolicyName) request := &elb.SetLoadBalancerPoliciesOfListenerInput{ LoadBalancerName: aws.String(loadBalancerName), - LoadBalancerPort: aws.Int64(port), - PolicyNames: []*string{ - aws.String(policyName), - }, + LoadBalancerPort: int32(port), + PolicyNames: []string{policyName}, } klog.V(2).Infof("Setting SSL negotiation policy '%s' on load balancer", policyName) - _, err := c.elb.SetLoadBalancerPoliciesOfListener(request) + _, err := c.elb.SetLoadBalancerPoliciesOfListener(ctx, request) if err != nil { return fmt.Errorf("error setting SSL negotiation policy '%s' on load balancer: %q", policyName, err) } return nil } -func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { +func (c *Cloud) createProxyProtocolPolicy(ctx context.Context, loadBalancerName string) error { request := &elb.CreateLoadBalancerPolicyInput{ LoadBalancerName: aws.String(loadBalancerName), PolicyName: aws.String(ProxyProtocolPolicyName), PolicyTypeName: aws.String("ProxyProtocolPolicyType"), - PolicyAttributes: []*elb.PolicyAttribute{ + PolicyAttributes: []elbtypes.PolicyAttribute{ { AttributeName: aws.String("ProxyProtocol"), AttributeValue: aws.String("true"), @@ -1567,7 +1563,7 @@ func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { }, } klog.V(2).Info("Creating proxy protocol policy on load balancer") - _, err := c.elb.CreateLoadBalancerPolicy(request) + _, err := c.elb.CreateLoadBalancerPolicy(ctx, request) if err != nil { return fmt.Errorf("error creating proxy protocol policy on load balancer: %q", err) } @@ -1575,18 +1571,18 @@ func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { return nil } -func (c *Cloud) setBackendPolicies(loadBalancerName string, instancePort int64, policies []*string) error { +func (c *Cloud) setBackendPolicies(ctx context.Context, loadBalancerName string, instancePort *int32, policies []*string) error { request := &elb.SetLoadBalancerPoliciesForBackendServerInput{ - InstancePort: aws.Int64(instancePort), + InstancePort: instancePort, LoadBalancerName: aws.String(loadBalancerName), - PolicyNames: policies, + PolicyNames: aws.ToStringSlice(policies), } if len(policies) > 0 { klog.V(2).Infof("Adding AWS loadbalancer backend policies on node port %d", instancePort) } else { klog.V(2).Infof("Removing AWS loadbalancer backend policies on node port %d", instancePort) } - _, err := c.elb.SetLoadBalancerPoliciesForBackendServer(request) + _, err := c.elb.SetLoadBalancerPoliciesForBackendServer(ctx, request) if err != nil { return fmt.Errorf("error adjusting AWS loadbalancer backend policies: %q", err) } @@ -1594,9 +1590,9 @@ func (c *Cloud) setBackendPolicies(loadBalancerName string, instancePort int64, return nil } -func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { +func proxyProtocolEnabled(backend elbtypes.BackendServerDescription) bool { for _, policy := range backend.PolicyNames { - if aws.StringValue(policy) == ProxyProtocolPolicyName { + if policy == ProxyProtocolPolicyName { return true } } @@ -1607,7 +1603,7 @@ func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { // findInstancesForELB gets the EC2 instances corresponding to the Nodes, for setting up an ELB // We ignore Nodes (with a log message) where the instanceid cannot be determined from the provider, // and we ignore instances which are not found -func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2.Instance, error) { +func (c *Cloud) findInstancesForELB(ctx context.Context, nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2types.Instance, error) { targetNodes := filterTargetNodes(nodes, annotations) @@ -1618,7 +1614,7 @@ func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]str MaxAge: defaultEC2InstanceCacheMaxAge, HasInstances: instanceIDs, // Refresh if any of the instance ids are missing } - snapshot, err := c.instanceCache.describeAllInstancesCached(cacheCriteria) + snapshot, err := c.instanceCache.describeAllInstancesCached(ctx, cacheCriteria) if err != nil { return nil, err } @@ -1660,3 +1656,48 @@ func filterTargetNodes(nodes []*v1.Node, annotations map[string]string) []*v1.No return targetNodes } + +// ValidateHealthCheck replaces ELB.HealthCheck.Validate() from AWS SDK Go V1, which has been deprecated in V2 +// V1 implementation: https://github.com/aws/aws-sdk-go/blob/v1.55.7/service/elb/api.go#L5346 +func ValidateHealthCheck(s *elbtypes.HealthCheck) error { + var validationErrors []string + + if s == nil { + validationErrors = append(validationErrors, "HealthCheck is nil") + return fmt.Errorf("HealthCheck validation errors: %s", strings.Join(validationErrors, "; ")) + } + + if s.HealthyThreshold == nil { + validationErrors = append(validationErrors, "HealthyThreshold is required") + } else if *s.HealthyThreshold < 2 { + validationErrors = append(validationErrors, "HealthyThreshold must be at least 2") + } + + if s.Interval == nil { + validationErrors = append(validationErrors, "Interval is required") + } else if *s.Interval < 5 { + validationErrors = append(validationErrors, "Interval must be at least 5") + } + + if s.Target == nil { + validationErrors = append(validationErrors, "Target is required") + } + + if s.Timeout == nil { + validationErrors = append(validationErrors, "Timeout is required") + } else if *s.Timeout < 2 { + validationErrors = append(validationErrors, "Timeout must be at least 2") + } + + if s.UnhealthyThreshold == nil { + validationErrors = append(validationErrors, "UnhealthyThreshold is required") + } else if *s.UnhealthyThreshold < 2 { + validationErrors = append(validationErrors, "UnhealthyThreshold must be at least 2") + } + + if len(validationErrors) > 0 { + return fmt.Errorf("HealthCheck validation errors: %s", strings.Join(validationErrors, "; ")) + } + + return nil +} diff --git a/pkg/providers/v1/aws_loadbalancer_test.go b/pkg/providers/v1/aws_loadbalancer_test.go index 309d9eb209..45cb710fa7 100644 --- a/pkg/providers/v1/aws_loadbalancer_test.go +++ b/pkg/providers/v1/aws_loadbalancer_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "fmt" "reflect" "testing" @@ -26,10 +27,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/stretchr/testify/assert" "k8s.io/cloud-provider-aws/pkg/providers/v1/config" @@ -215,76 +216,70 @@ func TestSyncElbListeners(t *testing.T) { tests := []struct { name string loadBalancerName string - listeners []*elb.Listener - listenerDescriptions []*elb.ListenerDescription - toCreate []*elb.Listener - toDelete []*int64 + listeners []elbtypes.Listener + listenerDescriptions []elbtypes.ListenerDescription + toCreate []elbtypes.Listener + toDelete []int32 }{ { name: "no edge cases", loadBalancerName: "lb_one", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, - {InstancePort: aws.Int64(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(8443), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + {InstancePort: aws.Int32(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 8443, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}}, - {Listener: &elb.Listener{InstancePort: aws.Int64(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(8443), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}}, + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 8443, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}}, }, - toDelete: []*int64{ - aws.Int64(80), - }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + toDelete: []int32{80}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, }, { name: "no listeners to delete", loadBalancerName: "lb_two", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - toDelete: []*int64{}, + toDelete: []int32{}, }, { name: "no listeners to create", loadBalancerName: "lb_three", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}}, - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, }, - toDelete: []*int64{ - aws.Int64(80), + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}}, + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, }, - toCreate: []*elb.Listener{}, + toDelete: []int32{80}, + toCreate: []elbtypes.Listener{}, }, { name: "nil actual listener", loadBalancerName: "lb_four", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, {Listener: nil}, }, - toDelete: []*int64{ - aws.Int64(443), - }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP")}, + toDelete: []int32{443}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP")}, }, }, } @@ -301,37 +296,37 @@ func TestSyncElbListeners(t *testing.T) { func TestElbListenersAreEqual(t *testing.T) { tests := []struct { name string - expected, actual *elb.Listener + expected, actual elbtypes.Listener equal bool }{ { name: "should be equal", - expected: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: true, }, { name: "instance port should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "instance protocol should be different", - expected: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "load balancer port should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 443, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "protocol should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("HTTP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("HTTP")}, equal: false, }, } @@ -346,10 +341,10 @@ func TestElbListenersAreEqual(t *testing.T) { func TestBuildTargetGroupName(t *testing.T) { type args struct { serviceName types.NamespacedName - servicePort int64 - nodePort int64 - targetProtocol string - targetType string + servicePort int32 + nodePort int32 + targetProtocol elbv2types.ProtocolEnum + targetType elbv2types.TargetTypeEnum nlbConfig nlbPortMapping } tests := []struct { @@ -365,8 +360,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-7fa2e07508", @@ -378,8 +373,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-719ee635da", @@ -391,8 +386,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "another", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-another-servicea-f66e09847d", @@ -404,8 +399,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-b"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-serviceb-196c19c881", @@ -417,8 +412,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 9090, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-06876706cb", @@ -430,8 +425,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 9090, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-119f844ec0", @@ -443,8 +438,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "UDP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumUdp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-3868761686", @@ -456,8 +451,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "ip", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumIp, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-0fa31f4b0f", @@ -469,8 +464,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "ip", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumIp, nlbConfig: nlbPortMapping{ HealthCheckConfig: healthCheckConfig{ Protocol: "HTTP", @@ -547,11 +542,11 @@ func TestFilterTargetNodes(t *testing.T) { } } -func makeNodeInstancePair(offset int) (*v1.Node, *ec2.Instance) { +func makeNodeInstancePair(offset int) (*v1.Node, *ec2types.Instance) { instanceID := fmt.Sprintf("i-%x", int64(0x03bcc3496da09f78e)+int64(offset)) - instance := &ec2.Instance{ + instance := &ec2types.Instance{ InstanceId: aws.String(instanceID), - Placement: &ec2.Placement{ + Placement: &ec2types.Placement{ AvailabilityZone: aws.String("us-east-1b"), }, PrivateDnsName: aws.String(fmt.Sprintf("ip-192-168-32-%d.ec2.internal", 101+offset)), @@ -559,10 +554,10 @@ func makeNodeInstancePair(offset int) (*v1.Node, *ec2.Instance) { PublicIpAddress: aws.String(fmt.Sprintf("1.2.3.%d", 1+offset)), } - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, TestClusterID)) tag.Value = aws.String("owned") - instance.Tags = []*ec2.Tag{&tag} + instance.Tags = []ec2types.Tag{tag} node := &v1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -592,26 +587,26 @@ func TestCloud_findInstancesForELB(t *testing.T) { return } - want := map[InstanceID]*ec2.Instance{ + want := map[InstanceID]*ec2types.Instance{ "i-self": awsServices.selfInstance, } - got, err := c.findInstancesForELB([]*v1.Node{defaultNode}, nil) + got, err := c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) // Add a new EC2 instance awsServices.instances = append(awsServices.instances, newInstance) - want = map[InstanceID]*ec2.Instance{ + want = map[InstanceID]*ec2types.Instance{ "i-self": awsServices.selfInstance, - InstanceID(aws.StringValue(newInstance.InstanceId)): newInstance, + InstanceID(aws.ToString(newInstance.InstanceId)): newInstance, } - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) // Verify existing instance cache gets used cacheExpiryOld := c.instanceCache.snapshot.timestamp - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) cacheExpiryNew := c.instanceCache.snapshot.timestamp @@ -620,7 +615,7 @@ func TestCloud_findInstancesForELB(t *testing.T) { // Force cache expiry and verify cache gets updated with new timestamp cacheExpiryOld = c.instanceCache.snapshot.timestamp c.instanceCache.snapshot.timestamp = c.instanceCache.snapshot.timestamp.Add(-(defaultEC2InstanceCacheMaxAge + 1*time.Second)) - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) cacheExpiryNew = c.instanceCache.snapshot.timestamp @@ -629,56 +624,56 @@ func TestCloud_findInstancesForELB(t *testing.T) { func TestCloud_chunkTargetDescriptions(t *testing.T) { type args struct { - targets []*elbv2.TargetDescription + targets []elbv2types.TargetDescription chunkSize int } tests := []struct { name string args args - want [][]*elbv2.TargetDescription + want [][]elbv2types.TargetDescription }{ { name: "can be evenly chunked", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 2, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, { { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -686,46 +681,46 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "cannot be evenly chunked", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 3, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, { { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -733,43 +728,43 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunkSize equal to total count", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 4, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -777,43 +772,43 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunkSize greater than total count", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 10, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -829,7 +824,7 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunk empty slice", args: args{ - targets: []*elbv2.TargetDescription{}, + targets: []elbv2types.TargetDescription{}, chunkSize: 2, }, want: nil, @@ -846,38 +841,38 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { func TestCloud_diffTargetGroupTargets(t *testing.T) { type args struct { - expectedTargets []*elbv2.TargetDescription - actualTargets []*elbv2.TargetDescription + expectedTargets []*elbv2types.TargetDescription + actualTargets []*elbv2types.TargetDescription } tests := []struct { name string args args - wantTargetsToRegister []*elbv2.TargetDescription - wantTargetsToDeregister []*elbv2.TargetDescription + wantTargetsToRegister []elbv2types.TargetDescription + wantTargetsToDeregister []elbv2types.TargetDescription }{ { name: "all targets to register", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, actualTargets: nil, }, - wantTargetsToRegister: []*elbv2.TargetDescription{ + wantTargetsToRegister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, wantTargetsToDeregister: nil, @@ -886,79 +881,79 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { name: "all targets to deregister", args: args{ expectedTargets: nil, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, wantTargetsToRegister: nil, - wantTargetsToDeregister: []*elbv2.TargetDescription{ + wantTargetsToDeregister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, { name: "some targets to register and deregister", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef5"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, - wantTargetsToRegister: []*elbv2.TargetDescription{ + wantTargetsToRegister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef5"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - wantTargetsToDeregister: []*elbv2.TargetDescription{ + wantTargetsToDeregister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -974,32 +969,32 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { { name: "expected and actual targets equals", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -1020,12 +1015,12 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { type args struct { instanceIDs []string - port int64 + port int32 } tests := []struct { name string args args - want []*elbv2.TargetDescription + want []*elbv2types.TargetDescription }{ { name: "no instance", @@ -1033,7 +1028,7 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: nil, port: 8080, }, - want: []*elbv2.TargetDescription{}, + want: []*elbv2types.TargetDescription{}, }, { name: "one instance", @@ -1041,10 +1036,10 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: []string{"i-abcdef1"}, port: 8080, }, - want: []*elbv2.TargetDescription{ + want: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -1054,18 +1049,18 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: []string{"i-abcdef1", "i-abcdef2", "i-abcdef3"}, port: 8080, }, - want: []*elbv2.TargetDescription{ + want: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, diff --git a/pkg/providers/v1/aws_routes.go b/pkg/providers/v1/aws_routes.go index e3e7c5b7a4..2fea34553e 100644 --- a/pkg/providers/v1/aws_routes.go +++ b/pkg/providers/v1/aws_routes.go @@ -20,22 +20,23 @@ import ( "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "k8s.io/klog/v2" cloudprovider "k8s.io/cloud-provider" ) -func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { +func (c *Cloud) findRouteTable(ctx context.Context, clusterName string) (*ec2types.RouteTable, error) { // This should be unnecessary (we already filter on TagNameKubernetesCluster, // and something is broken if cluster name doesn't match, but anyway... // TODO: All clouds should be cluster-aware by default - var tables []*ec2.RouteTable + var tables []ec2types.RouteTable if c.cfg.Global.RouteTableID != "" { - request := &ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{newEc2Filter("route-table-id", c.cfg.Global.RouteTableID)}} - response, err := c.ec2.DescribeRouteTables(request) + request := &ec2.DescribeRouteTablesInput{Filters: []ec2types.Filter{newEc2Filter("route-table-id", c.cfg.Global.RouteTableID)}} + response, err := c.ec2.DescribeRouteTables(ctx, request) if err != nil { return nil, err } @@ -43,7 +44,7 @@ func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { tables = response } else { request := &ec2.DescribeRouteTablesInput{} - response, err := c.ec2.DescribeRouteTables(request) + response, err := c.ec2.DescribeRouteTables(ctx, request) if err != nil { return nil, err } @@ -62,37 +63,37 @@ func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { if len(tables) != 1 { return nil, fmt.Errorf("found multiple matching AWS route tables for AWS cluster: %s", clusterName) } - return tables[0], nil + return &tables[0], nil } // ListRoutes implements Routes.ListRoutes // List all routes that match the filter func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudprovider.Route, error) { - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return nil, err } var routes []*cloudprovider.Route - var instanceIDs []*string + var instanceIDs []string for _, r := range table.Routes { - instanceID := aws.StringValue(r.InstanceId) + instanceID := aws.ToString(r.InstanceId) if instanceID == "" { continue } - instanceIDs = append(instanceIDs, &instanceID) + instanceIDs = append(instanceIDs, instanceID) } - instances, err := c.getInstancesByIDs(instanceIDs) + instances, err := c.getInstancesByIDs(ctx, instanceIDs) if err != nil { return nil, err } for _, r := range table.Routes { - destinationCIDR := aws.StringValue(r.DestinationCidrBlock) + destinationCIDR := aws.ToString(r.DestinationCidrBlock) if destinationCIDR == "" { continue } @@ -103,14 +104,14 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro } // Capture blackhole routes - if aws.StringValue(r.State) == ec2.RouteStateBlackhole { + if r.State == ec2types.RouteStateBlackhole { route.Blackhole = true routes = append(routes, route) continue } // Capture instance routes - instanceID := aws.StringValue(r.InstanceId) + instanceID := aws.ToString(r.InstanceId) if instanceID != "" { _, found := instances[instanceID] if found { @@ -130,12 +131,12 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro } // Sets the instance attribute "source-dest-check" to the specified value -func (c *Cloud) configureInstanceSourceDestCheck(instanceID string, sourceDestCheck bool) error { +func (c *Cloud) configureInstanceSourceDestCheck(ctx context.Context, instanceID string, sourceDestCheck bool) error { request := &ec2.ModifyInstanceAttributeInput{} request.InstanceId = aws.String(instanceID) - request.SourceDestCheck = &ec2.AttributeBooleanValue{Value: aws.Bool(sourceDestCheck)} + request.SourceDestCheck = &ec2types.AttributeBooleanValue{Value: aws.Bool(sourceDestCheck)} - _, err := c.ec2.ModifyInstanceAttribute(request) + _, err := c.ec2.ModifyInstanceAttribute(ctx, request) if err != nil { return fmt.Errorf("error configuring source-dest-check on instance %s: %q", instanceID, err) } @@ -145,46 +146,46 @@ func (c *Cloud) configureInstanceSourceDestCheck(instanceID string, sourceDestCh // CreateRoute implements Routes.CreateRoute // Create the described route func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint string, route *cloudprovider.Route) error { - instance, err := c.getInstanceByNodeName(route.TargetNode) + instance, err := c.getInstanceByNodeName(ctx, route.TargetNode) if err != nil { return err } // In addition to configuring the route itself, we also need to configure the instance to accept that traffic // On AWS, this requires turning source-dest checks off - err = c.configureInstanceSourceDestCheck(aws.StringValue(instance.InstanceId), false) + err = c.configureInstanceSourceDestCheck(ctx, aws.ToString(instance.InstanceId), false) if err != nil { return err } - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return err } - var deleteRoute *ec2.Route + var deleteRoute *ec2types.Route for _, r := range table.Routes { - destinationCIDR := aws.StringValue(r.DestinationCidrBlock) + destinationCIDR := aws.ToString(r.DestinationCidrBlock) if destinationCIDR != route.DestinationCIDR { continue } - if aws.StringValue(r.State) == ec2.RouteStateBlackhole { - deleteRoute = r + if r.State == ec2types.RouteStateBlackhole { + deleteRoute = &r } } if deleteRoute != nil { - klog.Infof("deleting blackholed route: %s", aws.StringValue(deleteRoute.DestinationCidrBlock)) + klog.Infof("deleting blackholed route: %s", aws.ToString(deleteRoute.DestinationCidrBlock)) request := &ec2.DeleteRouteInput{} request.DestinationCidrBlock = deleteRoute.DestinationCidrBlock request.RouteTableId = table.RouteTableId - _, err = c.ec2.DeleteRoute(request) + _, err = c.ec2.DeleteRoute(ctx, request) if err != nil { - return fmt.Errorf("error deleting blackholed AWS route (%s): %q", aws.StringValue(deleteRoute.DestinationCidrBlock), err) + return fmt.Errorf("error deleting blackholed AWS route (%s): %q", aws.ToString(deleteRoute.DestinationCidrBlock), err) } } @@ -194,7 +195,7 @@ func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint st request.InstanceId = instance.InstanceId request.RouteTableId = table.RouteTableId - _, err = c.ec2.CreateRoute(request) + _, err = c.ec2.CreateRoute(ctx, request) if err != nil { return fmt.Errorf("error creating AWS route (%s): %q", route.DestinationCIDR, err) } @@ -205,7 +206,7 @@ func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint st // DeleteRoute implements Routes.DeleteRoute // Delete the specified route func (c *Cloud) DeleteRoute(ctx context.Context, clusterName string, route *cloudprovider.Route) error { - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return err } @@ -214,7 +215,7 @@ func (c *Cloud) DeleteRoute(ctx context.Context, clusterName string, route *clou request.DestinationCidrBlock = aws.String(route.DestinationCIDR) request.RouteTableId = table.RouteTableId - _, err = c.ec2.DeleteRoute(request) + _, err = c.ec2.DeleteRoute(ctx, request) if err != nil { return fmt.Errorf("error deleting AWS route (%s): %q", route.DestinationCIDR, err) } diff --git a/pkg/providers/v1/aws_sdk.go b/pkg/providers/v1/aws_sdk.go index bc41b4cbbc..a2a482b058 100644 --- a/pkg/providers/v1/aws_sdk.go +++ b/pkg/providers/v1/aws_sdk.go @@ -17,34 +17,38 @@ limitations under the License. package aws import ( + "context" "fmt" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/kms" - "k8s.io/client-go/pkg/version" - "k8s.io/klog/v2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/retry" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" + smithymiddleware "github.com/aws/smithy-go/middleware" + + "k8s.io/client-go/pkg/version" "k8s.io/cloud-provider-aws/pkg/providers/v1/config" "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" + "k8s.io/klog/v2" ) type awsSDKProvider struct { - creds *credentials.Credentials + creds aws.CredentialsProvider cfg awsCloudConfigProvider mutex sync.Mutex regionDelayers map[string]*CrossRequestRetryDelay } -func newAWSSDKProvider(creds *credentials.Credentials, cfg *config.CloudConfig) *awsSDKProvider { +func newAWSSDKProvider(creds aws.CredentialsProvider, cfg *config.CloudConfig) *awsSDKProvider { return &awsSDKProvider{ creds: creds, cfg: cfg, @@ -52,43 +56,38 @@ func newAWSSDKProvider(creds *credentials.Credentials, cfg *config.CloudConfig) } } -func (p *awsSDKProvider) AddHandlers(regionName string, h *request.Handlers) { - h.Build.PushFrontNamed(request.NamedHandler{ - Name: "k8s/user-agent", - Fn: request.MakeAddToUserAgentHandler("kubernetes", version.Get().String()), - }) - - h.Sign.PushFrontNamed(request.NamedHandler{ - Name: "k8s/logger", - Fn: awsHandlerLogger, - }) +// Adds middleware to AWS SDK Go V2 clients. +func (p *awsSDKProvider) AddMiddleware(ctx context.Context, regionName string, cfg *aws.Config) { + cfg.APIOptions = append(cfg.APIOptions, + middleware.AddUserAgentKeyValue("kubernetes", version.Get().String()), + func(stack *smithymiddleware.Stack) error { + return stack.Finalize.Add(awsHandlerLoggerMiddleware(), smithymiddleware.Before) + }, + ) delayer := p.getCrossRequestRetryDelay(regionName) if delayer != nil { - h.Sign.PushFrontNamed(request.NamedHandler{ - Name: "k8s/delay-presign", - Fn: delayer.BeforeSign, - }) - - h.AfterRetry.PushFrontNamed(request.NamedHandler{ - Name: "k8s/delay-afterretry", - Fn: delayer.AfterRetry, - }) + cfg.APIOptions = append(cfg.APIOptions, + func(stack *smithymiddleware.Stack) error { + stack.Finalize.Add(delayPreSign(delayer), smithymiddleware.Before) + stack.Finalize.Insert(delayAfterRetry(delayer), "Retry", smithymiddleware.Before) + return nil + }, + ) } - p.addAPILoggingHandlers(h) + p.addAPILoggingMiddleware(cfg) } -func (p *awsSDKProvider) addAPILoggingHandlers(h *request.Handlers) { - h.Send.PushBackNamed(request.NamedHandler{ - Name: "k8s/api-request", - Fn: awsSendHandlerLogger, - }) - - h.ValidateResponse.PushFrontNamed(request.NamedHandler{ - Name: "k8s/api-validate-response", - Fn: awsValidateResponseHandlerLogger, - }) +// Adds logging middleware for AWS SDK Go V2 clients +func (p *awsSDKProvider) addAPILoggingMiddleware(cfg *aws.Config) { + cfg.APIOptions = append(cfg.APIOptions, + func(stack *smithymiddleware.Stack) error { + stack.Serialize.Add(awsSendHandlerLoggerMiddleware(), smithymiddleware.After) + stack.Deserialize.Add(awsValidateResponseHandlerLoggerMiddleware(), smithymiddleware.Before) + return nil + }, + ) } // Get a CrossRequestRetryDelay, scoped to the region, not to the request. @@ -112,84 +111,105 @@ func (p *awsSDKProvider) getCrossRequestRetryDelay(regionName string) *CrossRequ return delayer } -func (p *awsSDKProvider) Compute(regionName string) (iface.EC2, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) Compute(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (iface.EC2, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.GetResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - service := ec2.New(sess) - p.AddHandlers(regionName, &service.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*ec2.Options) = p.cfg.GetEC2EndpointOpts(regionName) + opts = append(opts, func(o *ec2.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomEC2Resolver() + }) + + ec2Client := ec2.NewFromConfig(cfg, opts...) ec2 := &awsSdkEC2{ - ec2: service, + ec2: ec2Client, } return ec2, nil } -func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) LoadBalancing(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELB, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.GetResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - elbClient := elb.New(sess) - p.AddHandlers(regionName, &elbClient.Handlers) + + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*elb.Options) = p.cfg.GetELBEndpointOpts(regionName) + opts = append(opts, func(o *elb.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomELBResolver() + }) + + elbClient := elb.NewFromConfig(cfg, opts...) return elbClient, nil } -func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) LoadBalancingV2(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELBV2, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.GetResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - elbClient := elbv2.New(sess) - p.AddHandlers(regionName, &elbClient.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*elbv2.Options) = p.cfg.GetELBV2EndpointOpts(regionName) + opts = append(opts, func(o *elbv2.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomELBV2Resolver() + }) - return elbClient, nil + elbv2Client := elbv2.NewFromConfig(cfg, opts...) + + return elbv2Client, nil } -func (p *awsSDKProvider) Metadata() (config.EC2Metadata, error) { - sess, err := session.NewSession(&aws.Config{ - EndpointResolver: p.cfg.GetResolver(), - }) +func (p *awsSDKProvider) Metadata(ctx context.Context) (config.EC2Metadata, error) { + cfg, err := awsConfig.LoadDefaultConfig(context.TODO(), awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion)) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - client := ec2metadata.New(sess) - p.addAPILoggingHandlers(&client.Handlers) - identity, err := client.GetInstanceIdentityDocument() + p.addAPILoggingMiddleware(&cfg) + + // Unlike other SDK clients, the IMDS client does not support signing, so any overrides of the signing region and name + // from awsSDKProvider.cfg will not be recognized. + // Standard SDK clients use SigV4: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html + // But IMDS uses a different request pattern: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + var opts []func(*imds.Options) = p.cfg.GetIMDSEndpointOpts() + imdsClient := imds.NewFromConfig(cfg, opts...) + // opts = append(opts, func(o *imds.Options) { + // o.ClientEnableState = imds.ClientEnabled + // }) + + getInstanceIdentityDocumentOutput, err := imdsClient.GetInstanceIdentityDocument(ctx, &imds.GetInstanceIdentityDocumentInput{}) if err == nil { + identity := getInstanceIdentityDocumentOutput.InstanceIdentityDocument klog.InfoS("instance metadata identity", "region", identity.Region, "availability-zone", identity.AvailabilityZone, @@ -200,26 +220,30 @@ func (p *awsSDKProvider) Metadata() (config.EC2Metadata, error) { "account-id", identity.AccountID, "image-id", identity.ImageID) } - return client, nil + return imdsClient, nil } -func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (KMS, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.GetResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - kmsClient := kms.New(sess) - p.AddHandlers(regionName, &kmsClient.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*kms.Options) = p.cfg.GetKMSEndpointOpts(regionName) + opts = append(opts, func(o *kms.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomKMSResolver() + }) + + kmsClient := kms.NewFromConfig(cfg, opts...) return kmsClient, nil } diff --git a/pkg/providers/v1/aws_sdk_test.go b/pkg/providers/v1/aws_sdk_test.go new file mode 100644 index 0000000000..5ce40863f0 --- /dev/null +++ b/pkg/providers/v1/aws_sdk_test.go @@ -0,0 +1,636 @@ +package aws + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "regexp" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" + + "github.com/stretchr/testify/assert" + "k8s.io/cloud-provider-aws/pkg/providers/v1/config" +) + +type requestInfo struct { + usedCustomEndpoint bool + credential string +} + +// Given an override, a custom endpoint should be used when making API requests +func TestClientsEndpointOverride(t *testing.T) { + reqInfo := requestInfo{} // stores information about requests, should be reset between API calls + // Dummy server that sets usedCustomEndpoint when called, and collects information about the request + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqInfo.usedCustomEndpoint = true + // Extract credential from auth header + auth := r.Header.Get("Authorization") + credRe := regexp.MustCompile(`Credential=([^,]+)`) + credMatch := credRe.FindStringSubmatch(auth) + if len(credMatch) == 2 { // true when it's able to find exactly one match for the Credential header + reqInfo.credential = credMatch[1] + } + })) + defer testServer.Close() + + // Clients should be able to have their default signing region and name overridden + t.Run("With overridden URL, signing region, and signing name", func(t *testing.T) { + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "EC2: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "EC2: signing region was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELB: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELB: signing region was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELBV2: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELBV2: signing region was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "KMS: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "KMS: signing region was not properly overridden") + }) + + // When the signing name is overridden but not the signing region, the signing name should be + // whatever is configured in the override, and the signing region should fall back to the request region. + t.Run("With overridden signing name and default region", func(t *testing.T) { + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "EC2: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "EC2: signing name was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELB: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELB: signing name was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELBV2: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELBV2: signing name was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "KMS: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "KMS: signing name was not properly overridden") + }) + + // When the signing region is overridden but not the signing name, the signing region should be + // whatever is configured in the override, and the signing name should fall back to the client's service name. + t.Run("With overriden signing region and default name", func(t *testing.T) { + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(ec2.ServiceID)), "EC2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "EC2: signing region was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + // remove whitespace due to multi-word service name + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELB: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELB: signing region was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + // ELB and ELBV2 use the same default signing name (https://docs.aws.amazon.com/general/latest/gr/elb.html) + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELBV2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELBV2: signing region was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(kms.ServiceID)), "KMS: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "KMS: signing region was not properly overridden") + }) + + // When only the URL is overridden, and not the signing region or name, the URL should be whatever is configured in + // the override, the region should fall back to the request region, and the name should fall back to the client's + // service name. + t.Run("Only URL override", func(t *testing.T) { + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "5": { + Service: imds.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(ec2.ServiceID)), "EC2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "EC2: blank signing region should fall back to request region") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELB: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELB: blank signing region should fall back to request region") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + // ELB and ELBV2 use the same default signing name (https://docs.aws.amazon.com/general/latest/gr/elb.html) + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELBV2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELBV2: blank signing region should fall back to request region") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(kms.ServiceID)), "KMS: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "KMS: blank signing region should fall back to request region") + + val := os.Getenv("AWS_EC2_METADATA_DISABLED") + // Test Metadata client. This client only supports overriding the URL, not the signing name and region. + reqInfo = requestInfo{} + // This client can only successfully make requests when AWS_EC2_METADATA_DISABLED = false. + // https://docs.aws.amazon.com/sdkref/latest/guide/feature-imds-credentials.html + os.Setenv("AWS_EC2_METADATA_DISABLED", "false") + // Call Metadata(), which both creates the client and uses it to make a request. + mockProvider.Metadata(context.TODO()) + assert.True(t, reqInfo.usedCustomEndpoint, "IMDS: custom endpoint was not used") + + // Create the client with AWS_EC2_METADATA_DISABLED = true to make sure that requests are not made. + os.Setenv("AWS_EC2_METADATA_DISABLED", "true") + reqInfo = requestInfo{} + mockProvider.Metadata(context.TODO()) + assert.False(t, reqInfo.usedCustomEndpoint, "IMDS: request was completed despite setting AWS_EC2_METADATA_DISABLED=true") + + // reset AWS_EC2_METADATA_DISABLED + os.Setenv("AWS_EC2_METADATA_DISABLED", val) + }) +} + +// Test whether SDK clients refrain from retrying an API request when given a nonRetryableError. +func TestClientsNoRetry(t *testing.T) { + attemptCount := 0 + // Dummy server that counts attempts and returns a nonRetryableError + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert the nonRetryableError error message + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, nonRetryableError) + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + // Override service endpoints with dummy server URL + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // EC2 Client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{}) + // Ensure that only 1 attempt was made, signifying no retries + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for EC2 client, got %d", attemptCount)) + + // ELB Client + attemptCount = 0 // reset attempt count for next request + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for ELB client, got %d", attemptCount)) + + // ELBV2 Client + attemptCount = 0 + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for ELBV2 client, got %d", attemptCount)) + + // KMS Client + attemptCount = 0 + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for KMS client, got %d", attemptCount)) +} + +// Test whether SDK clients retry an API request when given a retryable error code. +func TestClientsWithRetry(t *testing.T) { + attemptCount := 0 + // Dummy server that counts attempts and returns a retryable error + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // 500 status codes are retried by SDK (see https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/retry) + http.Error(w, "RequestTimeout", 500) + })) + + // Override service endpoints with dummy server URL + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // EC2 Client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{}) + // Ensure that more than 1 attempt was made, signifying retries + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for EC2 client, got %d", attemptCount)) + + // ELB Client + attemptCount = 0 // Reset the attempt count before the next request + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for ELB client, got %d", attemptCount)) + + // ELBV2 Client + attemptCount = 0 + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for ELB client, got %d", attemptCount)) + + // KMS Client + attemptCount = 0 + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for KMS client, got %d", attemptCount)) +} diff --git a/pkg/providers/v1/aws_test.go b/pkg/providers/v1/aws_test.go index 577f5d72cf..517aec9c96 100644 --- a/pkg/providers/v1/aws_test.go +++ b/pkg/providers/v1/aws_test.go @@ -27,12 +27,16 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + + "github.com/aws/aws-sdk-go-v2/aws" + + "github.com/aws/smithy-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -56,20 +60,28 @@ type MockedFakeEC2 struct { } func (m *MockedFakeEC2) expectDescribeSecurityGroups(clusterID, groupName string) { - tags := []*ec2.Tag{ + tags := []ec2types.Tag{ {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(clusterID)}, {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, clusterID)), Value: aws.String(ResourceLifecycleOwned)}, } - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{ + m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []ec2types.Filter{ newEc2Filter("group-name", groupName), newEc2Filter("vpc-id", ""), - }}).Return([]*ec2.SecurityGroup{{Tags: tags}}) + }}).Return([]ec2types.SecurityGroup{{Tags: tags}}) } -func (m *MockedFakeEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { +func (m *MockedFakeEC2) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { args := m.Called(request) - return args.Get(0).([]*ec2.SecurityGroup), nil + return args.Get(0).([]ec2types.SecurityGroup), nil +} + +func (m *MockedFakeEC2) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) { + args := m.Called(ctx, request) + if args.Get(1) != nil { + return nil, args.Get(1).(error) + } + return args.Get(0).([]ec2types.InstanceTopology), nil } type MockedFakeELB struct { @@ -77,23 +89,23 @@ type MockedFakeELB struct { mock.Mock } -func (m *MockedFakeELB) DescribeLoadBalancers(input *elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) { +func (m *MockedFakeELB) DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) { args := m.Called(input) return args.Get(0).(*elb.DescribeLoadBalancersOutput), nil } func (m *MockedFakeELB) expectDescribeLoadBalancers(loadBalancerName string) { - m.On("DescribeLoadBalancers", &elb.DescribeLoadBalancersInput{LoadBalancerNames: []*string{aws.String(loadBalancerName)}}).Return(&elb.DescribeLoadBalancersOutput{ - LoadBalancerDescriptions: []*elb.LoadBalancerDescription{{}}, + m.On("DescribeLoadBalancers", &elb.DescribeLoadBalancersInput{LoadBalancerNames: []string{loadBalancerName}}).Return(&elb.DescribeLoadBalancersOutput{ + LoadBalancerDescriptions: []elbtypes.LoadBalancerDescription{{}}, }) } -func (m *MockedFakeELB) AddTags(input *elb.AddTagsInput) (*elb.AddTagsOutput, error) { +func (m *MockedFakeELB) AddTags(ctx context.Context, input *elb.AddTagsInput, optFns ...func(*elb.Options)) (*elb.AddTagsOutput, error) { args := m.Called(input) return args.Get(0).(*elb.AddTagsOutput), nil } -func (m *MockedFakeELB) ConfigureHealthCheck(input *elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) { +func (m *MockedFakeELB) ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, optFns ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) { args := m.Called(input) if args.Get(0) == nil { return nil, args.Error(1) @@ -101,7 +113,7 @@ func (m *MockedFakeELB) ConfigureHealthCheck(input *elb.ConfigureHealthCheckInpu return args.Get(0).(*elb.ConfigureHealthCheckOutput), args.Error(1) } -func (m *MockedFakeELB) expectConfigureHealthCheck(loadBalancerName *string, expectedHC *elb.HealthCheck, returnErr error) { +func (m *MockedFakeELB) expectConfigureHealthCheck(loadBalancerName *string, expectedHC *elbtypes.HealthCheck, returnErr error) { expected := &elb.ConfigureHealthCheckInput{HealthCheck: expectedHC, LoadBalancerName: loadBalancerName} call := m.On("ConfigureHealthCheck", expected) if returnErr != nil { @@ -160,16 +172,15 @@ type ServiceDescriptor struct { signingName string } -func TestOverridesActiveConfig(t *testing.T) { +func TestValidateOverridesActiveConfig(t *testing.T) { tests := []struct { name string reader io.Reader aws Services - expectError bool - active bool - servicesOverridden []ServiceDescriptor + expectError bool + active bool }{ { "No overrides", @@ -178,7 +189,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, false, - []ServiceDescriptor{}, }, { "Missing Service Name", @@ -193,7 +203,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing Service Region", @@ -208,7 +217,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing URL", @@ -223,7 +231,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing Signing Region", @@ -238,7 +245,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Active Overrides", @@ -254,7 +260,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "sregion", signingRegion: "sregion", signingMethod: "v4"}}, }, { "Multiple Overridden Services", @@ -277,8 +282,6 @@ func TestOverridesActiveConfig(t *testing.T) { SigningMethod = v4`), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "sregion1", signingRegion: "sregion1", signingMethod: "v4"}, - {name: "ec2", region: "sregion2", signingRegion: "sregion2", signingMethod: "v4"}}, }, { "Duplicate Services", @@ -301,7 +304,6 @@ func TestOverridesActiveConfig(t *testing.T) { SigningMethod = sign`), nil, true, false, - []ServiceDescriptor{}, }, { "Multiple Overridden Services in Multiple regions", @@ -323,8 +325,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: ""}, - {name: "ec2", region: "region2", signingRegion: "sregion", signingMethod: "v4"}}, }, { "Multiple regions, Same Service", @@ -348,8 +348,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: "v3"}, - {name: "s3", region: "region2", signingRegion: "sregion1", signingMethod: "v4", signingName: "name"}}, }, } @@ -367,71 +365,6 @@ func TestOverridesActiveConfig(t *testing.T) { if err != nil { t.Errorf("Should succeed for case: %s, got %v", test.name, err) } - - if len(cfg.ServiceOverride) != len(test.servicesOverridden) { - t.Errorf("Expected %d overridden services, received %d for case %s", - len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) - } else { - for _, sd := range test.servicesOverridden { - var found *struct { - Service string - Region string - URL string - SigningRegion string - SigningMethod string - SigningName string - } - for _, v := range cfg.ServiceOverride { - if v.Service == sd.name && v.Region == sd.region { - found = v - break - } - } - if found == nil { - t.Errorf("Missing override for service %s in case %s", - sd.name, test.name) - } else { - if found.SigningRegion != sd.signingRegion { - t.Errorf("Expected signing region '%s', received '%s' for case %s", - sd.signingRegion, found.SigningRegion, test.name) - } - if found.SigningMethod != sd.signingMethod { - t.Errorf("Expected signing method '%s', received '%s' for case %s", - sd.signingMethod, found.SigningRegion, test.name) - } - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if found.URL != targetName { - t.Errorf("Expected Endpoint '%s', received '%s' for case %s", - targetName, found.URL, test.name) - } - if found.SigningName != sd.signingName { - t.Errorf("Expected signing name '%s', received '%s' for case %s", - sd.signingName, found.SigningName, test.name) - } - - fn := cfg.GetResolver() - ep1, e := fn(sd.name, sd.region, nil) - if e != nil { - t.Errorf("Expected a valid endpoint for %s in case %s", - sd.name, test.name) - } else { - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if ep1.URL != targetName { - t.Errorf("Expected endpoint url: %s, received %s in case %s", - targetName, ep1.URL, test.name) - } - if ep1.SigningRegion != sd.signingRegion { - t.Errorf("Expected signing region '%s', received '%s' in case %s", - sd.signingRegion, ep1.SigningRegion, test.name) - } - if ep1.SigningMethod != sd.signingMethod { - t.Errorf("Expected signing method '%s', received '%s' in case %s", - sd.signingMethod, ep1.SigningRegion, test.name) - } - } - } - } - } } } } @@ -481,7 +414,7 @@ func TestNewAWSCloud(t *testing.T) { } } -func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*Cloud, *FakeAWSServices) { +func mockInstancesResp(selfInstance *ec2types.Instance, instances []*ec2types.Instance) (*Cloud, *FakeAWSServices) { awsServices := newMockedFakeAWSServices(TestClusterID) awsServices.instances = instances awsServices.selfInstance = selfInstance @@ -526,34 +459,34 @@ func testHasNodeAddress(t *testing.T, addrs []v1.NodeAddress, addressType v1.Nod t.Errorf("Did not find expected address: %s:%s in %v", addressType, address, addrs) } -func makeMinimalInstance(instanceID string) ec2.Instance { +func makeMinimalInstance(instanceID string) ec2types.Instance { return makeInstance(instanceID, "", "", "", "", nil, false) } -func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, publicDNSName string, ipv6s []string, setNetInterface bool) ec2.Instance { - var tag ec2.Tag +func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, publicDNSName string, ipv6s []string, setNetInterface bool) ec2types.Instance { + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} - instance := ec2.Instance{ + instance := ec2types.Instance{ InstanceId: &instanceID, PrivateDnsName: aws.String(privateDNSName), PrivateIpAddress: aws.String(privateIP), PublicDnsName: aws.String(publicDNSName), PublicIpAddress: aws.String(publicIP), - InstanceType: aws.String("c3.large"), + InstanceType: ec2types.InstanceTypeC3Large, Tags: tags, - Placement: &ec2.Placement{AvailabilityZone: aws.String("us-west-2a")}, - State: &ec2.InstanceState{ - Name: aws.String("running"), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-west-2a")}, + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } if setNetInterface == true { - instance.NetworkInterfaces = []*ec2.InstanceNetworkInterface{ + instance.NetworkInterfaces = []ec2types.InstanceNetworkInterface{ { - Status: aws.String(ec2.NetworkInterfaceStatusInUse), - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + Status: ec2types.NetworkInterfaceStatusInUse, + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { PrivateIpAddress: aws.String(privateIP), }, @@ -561,7 +494,7 @@ func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, public }, } if len(ipv6s) > 0 { - instance.NetworkInterfaces[0].Ipv6Addresses = []*ec2.InstanceIpv6Address{ + instance.NetworkInterfaces[0].Ipv6Addresses = []ec2types.InstanceIpv6Address{ { Ipv6Address: aws.String(ipv6s[0]), }, @@ -625,7 +558,7 @@ func TestNodeAddressesByProviderID(t *testing.T) { } { t.Run(tc.Name, func(t *testing.T) { instance := makeInstance(tc.InstanceID, tc.PrivateIP, tc.PublicIP, tc.PrivateDNSName, tc.PublicDNSName, tc.Ipv6s, tc.SetNetInterface) - aws1, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) + aws1, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) _, err := aws1.NodeAddressesByProviderID(context.TODO(), "i-xxx") if err == nil { t.Errorf("Should error when no instance found") @@ -731,7 +664,7 @@ func TestNodeAddresses(t *testing.T) { } { t.Run(tc.Name, func(t *testing.T) { instance := makeInstance(tc.InstanceID, tc.PrivateIP, tc.PublicIP, tc.PrivateDNSName, tc.PublicDNSName, tc.Ipv6s, tc.SetNetInterface) - aws1, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) + aws1, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) _, err := aws1.NodeAddresses(context.TODO(), "instance-mismatch.ec2.internal") if err == nil { t.Errorf("Should error when no instance found") @@ -795,7 +728,7 @@ func TestFindVPCID(t *testing.T) { t.Errorf("Error building aws cloud: %v", err) return } - vpcID, err := c.findVPCID() + vpcID, err := c.findVPCID(context.TODO()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -804,7 +737,7 @@ func TestFindVPCID(t *testing.T) { } } -func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2.Subnet) { +func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2types.Subnet) { for i := range subnetsIn { subnetsOut = append( subnetsOut, @@ -817,18 +750,18 @@ func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2.Su return } -func constructSubnet(id string, az string) *ec2.Subnet { - return &ec2.Subnet{ +func constructSubnet(id string, az string) *ec2types.Subnet { + return &ec2types.Subnet{ SubnetId: &id, AvailabilityZone: &az, } } -func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2.RouteTable) { +func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2types.RouteTable) { routeTablesOut = append(routeTablesOut, - &ec2.RouteTable{ - Associations: []*ec2.RouteTableAssociation{{Main: aws.Bool(true)}}, - Routes: []*ec2.Route{{ + &ec2types.RouteTable{ + Associations: []ec2types.RouteTableAssociation{{Main: aws.Bool(true)}}, + Routes: []ec2types.Route{{ DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String("igw-main"), }}, @@ -846,16 +779,16 @@ func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2. return } -func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { +func constructRouteTable(subnetID string, public bool) *ec2types.RouteTable { var gatewayID string if public { gatewayID = "igw-" + subnetID[len(subnetID)-8:8] } else { gatewayID = "vgw-" + subnetID[len(subnetID)-8:8] } - return &ec2.RouteTable{ - Associations: []*ec2.RouteTableAssociation{{SubnetId: aws.String(subnetID)}}, - Routes: []*ec2.Route{{ + return &ec2types.RouteTable{ + Associations: []ec2types.RouteTableAssociation{{SubnetId: aws.String(subnetID)}}, + Routes: []ec2types.Route{{ DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String(gatewayID), }}, @@ -869,30 +802,30 @@ func Test_findELBSubnets(t *testing.T) { t.Errorf("Error building aws cloud: %v", err) return } - subnetA0000001 := &ec2.Subnet{ + subnetA0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameSubnetPublicELB), Value: aws.String("1"), }, }, } - subnetA0000002 := &ec2.Subnet{ + subnetA0000002 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000002"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameSubnetPublicELB), Value: aws.String("1"), }, }, } - subnetA0000003 := &ec2.Subnet{ + subnetA0000003 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000003"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -903,10 +836,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetB0000001 := &ec2.Subnet{ + subnetB0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2b"), SubnetId: aws.String("subnet-b0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -917,10 +850,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetB0000002 := &ec2.Subnet{ + subnetB0000002 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2b"), SubnetId: aws.String("subnet-b0000002"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -931,10 +864,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetC0000001 := &ec2.Subnet{ + subnetC0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-c0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -945,10 +878,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetOther := &ec2.Subnet{ + subnetOther := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-other"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameKubernetesClusterPrefix + "clusterid.other"), Value: aws.String("owned"), @@ -959,24 +892,24 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetNoTag := &ec2.Subnet{ + subnetNoTag := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-notag"), } - subnetLocalZone := &ec2.Subnet{ + subnetLocalZone := &ec2types.Subnet{ AvailabilityZone: aws.String("az-local"), SubnetId: aws.String("subnet-in-local-zone"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), }, }, } - subnetWavelengthZone := &ec2.Subnet{ + subnetWavelengthZone := &ec2types.Subnet{ AvailabilityZone: aws.String("az-wavelength"), SubnetId: aws.String("subnet-in-wavelength-zone"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -986,7 +919,7 @@ func Test_findELBSubnets(t *testing.T) { tests := []struct { name string - subnets []*ec2.Subnet + subnets []*ec2types.Subnet routeTables map[string]bool internal bool want []string @@ -996,7 +929,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "single tagged subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, }, routeTables: map[string]bool{ @@ -1007,7 +940,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "no matching public subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000002, }, routeTables: map[string]bool{ @@ -1017,7 +950,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer role over cluster tag", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000003, }, @@ -1029,7 +962,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer cluster tag", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetC0000001, subnetNoTag, }, @@ -1037,7 +970,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "include untagged", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetNoTag, }, @@ -1049,7 +982,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "ignore some other cluster owned subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetB0000001, subnetOther, }, @@ -1061,7 +994,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer matching role", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetB0000001, subnetB0000002, }, @@ -1074,7 +1007,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "choose lexicographic order", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000002, }, @@ -1086,7 +1019,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "everything", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000002, subnetB0000001, @@ -1108,7 +1041,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "exclude subnets from local and wavelenght zones", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetB0000001, subnetC0000001, @@ -1129,7 +1062,7 @@ func Test_findELBSubnets(t *testing.T) { for _, rt := range routeTables { awsServices.ec2.CreateRouteTable(rt) } - got, _ := c.findELBSubnets(tt.internal) + got, _ := c.findELBSubnets(context.TODO(), tt.internal) sort.Strings(tt.want) sort.Strings(got) assert.Equal(t, tt.want, got) @@ -1147,7 +1080,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { tests := []struct { name string service *v1.Service - subnets []*ec2.Subnet + subnets []*ec2types.Subnet internalELB bool want []string wantErr error @@ -1169,7 +1102,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "subnet ids", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1190,7 +1123,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "subnet names", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1211,7 +1144,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "unable to find all subnets", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1233,7 +1166,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { for _, subnet := range tt.subnets { awsServices.ec2.CreateSubnet(subnet) } - got, err := c.getLoadBalancerSubnets(tt.service, tt.internalELB) + got, err := c.getLoadBalancerSubnets(context.TODO(), tt.service, tt.internalELB) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -1279,7 +1212,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err := c.findELBSubnets(false) + result, err := c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1309,7 +1242,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1355,7 +1288,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1366,7 +1299,7 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - expected := []*string{aws.String("subnet-a0000001"), aws.String("subnet-b0000001"), aws.String("subnet-c0000000")} + expected := []string{"subnet-a0000001", "subnet-b0000001", "subnet-c0000000"} for _, s := range result { if !contains(expected, s) { t.Errorf("Unexpected subnet '%s' found", s) @@ -1402,7 +1335,7 @@ func TestSubnetIDsinVPC(t *testing.T) { for _, rt := range constructedRouteTables { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1413,7 +1346,7 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - expected = []*string{aws.String("subnet-c0000000"), aws.String("subnet-d0000001"), aws.String("subnet-d0000002")} + expected = []string{"subnet-c0000000", "subnet-d0000001", "subnet-d0000002"} for _, s := range result { if !contains(expected, s) { t.Errorf("Unexpected subnet '%s' found", s) @@ -1423,22 +1356,22 @@ func TestSubnetIDsinVPC(t *testing.T) { } func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { - oldIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + oldIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId")}, {GroupId: aws.String("secondGroupId")}, {GroupId: aws.String("thirdGroupId")}, }, } - existingIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + existingIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId")}, }, } - newIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("fourthGroupId")}, }, } @@ -1454,8 +1387,8 @@ func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { } // The first pair matches, but the second does not - newIPPermission2 := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission2 := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId")}, {GroupId: aws.String("fourthGroupId")}, }, @@ -1468,29 +1401,29 @@ func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { func TestIpPermissionExistsHandlesRangeSubsets(t *testing.T) { // Two existing scenarios we'll test against - emptyIPPermission := ec2.IpPermission{} + emptyIPPermission := ec2types.IpPermission{} - oldIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + oldIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/8")}, {CidrIp: aws.String("192.168.1.0/24")}, }, } // Two already existing ranges and a new one - existingIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + existingIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/8")}, }, } - existingIPPermission2 := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + existingIPPermission2 := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("192.168.1.0/24")}, }, } - newIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + newIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("172.16.0.0/16")}, }, } @@ -1524,22 +1457,22 @@ func TestIpPermissionExistsHandlesRangeSubsets(t *testing.T) { } func TestIpPermissionExistsHandlesMultipleGroupIdsWithUserIds(t *testing.T) { - oldIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + oldIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId"), UserId: aws.String("firstUserId")}, {GroupId: aws.String("secondGroupId"), UserId: aws.String("secondUserId")}, {GroupId: aws.String("thirdGroupId"), UserId: aws.String("thirdUserId")}, }, } - existingIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + existingIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId"), UserId: aws.String("secondUserId")}, }, } - newIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId"), UserId: aws.String("anotherUserId")}, }, } @@ -1557,35 +1490,35 @@ func TestIpPermissionExistsHandlesMultipleGroupIdsWithUserIds(t *testing.T) { func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { awsStates := []struct { - id int64 - state string + id int32 + state ec2types.InstanceStateName expected bool }{ - {0, ec2.InstanceStateNamePending, true}, - {16, ec2.InstanceStateNameRunning, true}, - {32, ec2.InstanceStateNameShuttingDown, true}, - {48, ec2.InstanceStateNameTerminated, false}, - {64, ec2.InstanceStateNameStopping, true}, - {80, ec2.InstanceStateNameStopped, true}, + {0, ec2types.InstanceStateNamePending, true}, + {16, ec2types.InstanceStateNameRunning, true}, + {32, ec2types.InstanceStateNameShuttingDown, true}, + {48, ec2types.InstanceStateNameTerminated, false}, + {64, ec2types.InstanceStateNameStopping, true}, + {80, ec2types.InstanceStateNameStopped, true}, } awsServices := newMockedFakeAWSServices(TestClusterID) nodeName := types.NodeName("my-dns.internal") - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} - var testInstance ec2.Instance + var testInstance ec2types.Instance testInstance.PrivateDnsName = aws.String(string(nodeName)) testInstance.Tags = tags awsDefaultInstances := awsServices.instances for _, awsState := range awsStates { - id := "i-" + awsState.state + id := string("i-" + awsState.state) testInstance.InstanceId = aws.String(id) - testInstance.State = &ec2.InstanceState{Code: aws.Int64(awsState.id), Name: aws.String(awsState.state)} + testInstance.State = &ec2types.InstanceState{Code: aws.Int32(awsState.id), Name: awsState.state} awsServices.instances = append(awsDefaultInstances, &testInstance) @@ -1595,7 +1528,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { return } - resultInstance, err := c.findInstanceByNodeName(nodeName) + resultInstance, err := c.findInstanceByNodeName(context.TODO(), nodeName) if awsState.expected { if err != nil || resultInstance == nil { @@ -1619,25 +1552,24 @@ func TestGetInstanceByNodeNameBatching(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) c, err := newAWSCloud(config.CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterPrefix + TestClusterID) tag.Value = aws.String("") - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} nodeNames := []string{} for i := 0; i < 200; i++ { nodeName := fmt.Sprintf("ip-171-20-42-%d.ec2.internal", i) nodeNames = append(nodeNames, nodeName) - ec2Instance := &ec2.Instance{} + ec2Instance := &ec2types.Instance{} instanceID := fmt.Sprintf("i-abcedf%d", i) ec2Instance.InstanceId = aws.String(instanceID) ec2Instance.PrivateDnsName = aws.String(nodeName) - ec2Instance.State = &ec2.InstanceState{Code: aws.Int64(48), Name: aws.String("running")} + ec2Instance.State = &ec2types.InstanceState{Code: aws.Int32(48), Name: ec2types.InstanceStateNameRunning} ec2Instance.Tags = tags awsServices.instances = append(awsServices.instances, ec2Instance) - } - instances, err := c.getInstancesByNodeNames(nodeNames) + instances, err := c.getInstancesByNodeNames(context.TODO(), nodeNames) assert.Nil(t, err, "Error getting instances by nodeNames %v: %v", nodeNames, err) assert.NotEmpty(t, instances) assert.Equal(t, 200, len(instances), "Expected 200 but got less") @@ -1785,9 +1717,9 @@ func TestBuildListener(t *testing.T) { tests := []struct { name string - lbPort int64 + lbPort int32 portName string - instancePort int64 + instancePort int32 backendProtocolAnnotation string certAnnotation string sslPortAnnotation string @@ -1897,10 +1829,10 @@ func TestBuildListener(t *testing.T) { if test.certID != "" { cert = &test.certID } - expected := &elb.Listener{ + expected := elbtypes.Listener{ InstancePort: &test.instancePort, InstanceProtocol: &test.instanceProtocol, - LoadBalancerPort: &test.lbPort, + LoadBalancerPort: test.lbPort, Protocol: &test.lbProtocol, SSLCertificateId: cert, } @@ -1914,27 +1846,25 @@ func TestBuildListener(t *testing.T) { } func TestProxyProtocolEnabled(t *testing.T) { - policies := sets.NewString(ProxyProtocolPolicyName, "FooBarFoo") - fakeBackend := &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), - PolicyNames: stringSetToPointers(policies), + policies := []string{ProxyProtocolPolicyName, "FooBarFoo"} + fakeBackend := elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), + PolicyNames: policies, } result := proxyProtocolEnabled(fakeBackend) assert.True(t, result, "expected to find %s in %s", ProxyProtocolPolicyName, policies) - policies = sets.NewString("FooBarFoo") - fakeBackend = &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), - PolicyNames: []*string{ - aws.String("FooBarFoo"), - }, + policies = []string{"FooBarFoo"} + fakeBackend = elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), + PolicyNames: []string{"FooBarFoo"}, } result = proxyProtocolEnabled(fakeBackend) assert.False(t, result, "did not expect to find %s in %s", ProxyProtocolPolicyName, policies) - policies = sets.NewString() - fakeBackend = &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), + policies = []string{} + fakeBackend = elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), } result = proxyProtocolEnabled(fakeBackend) assert.False(t, result, "did not expect to find %s in %s", ProxyProtocolPolicyName, policies) @@ -2031,7 +1961,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { t.Run(test.name, func(t *testing.T) { serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} - sgList, setupSg, err := c.buildELBSecurityGroupList(serviceName, "aid", test.annotations) + sgList, setupSg, err := c.buildELBSecurityGroupList(context.TODO(), serviceName, "aid", test.annotations) assert.NoError(t, err, "buildELBSecurityGroupList failed") extraSGs := sgList[1:] assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(extraSGs...)), @@ -2065,7 +1995,7 @@ func TestLBSecurityGroupsAnnotation(t *testing.T) { t.Run(test.name, func(t *testing.T) { serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} - sgList, setupSg, err := c.buildELBSecurityGroupList(serviceName, "aid", test.annotations) + sgList, setupSg, err := c.buildELBSecurityGroupList(context.TODO(), serviceName, "aid", test.annotations) assert.NoError(t, err, "buildELBSecurityGroupList failed") assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(sgList...)), "Security Groups expected=%q , returned=%q", test.expectedSGs, sgList) @@ -2084,8 +2014,8 @@ func TestAddLoadBalancerTags(t *testing.T) { want["tag1"] = "val1" expectedAddTagsRequest := &elb.AddTagsInput{ - LoadBalancerNames: []*string{&loadBalancerName}, - Tags: []*elb.Tag{ + LoadBalancerNames: []string{loadBalancerName}, + Tags: []elbtypes.Tag{ { Key: aws.String("tag1"), Value: aws.String("val1"), @@ -2094,7 +2024,7 @@ func TestAddLoadBalancerTags(t *testing.T) { } awsServices.elb.(*MockedFakeELB).On("AddTags", expectedAddTagsRequest).Return(&elb.AddTagsOutput{}) - err := c.addLoadBalancerTags(loadBalancerName, want) + err := c.addLoadBalancerTags(context.TODO(), loadBalancerName, want) assert.Nil(t, err, "Error adding load balancer tags: %v", err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) } @@ -2103,60 +2033,60 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { tests := []struct { name string annotations map[string]string - want elb.HealthCheck + want elbtypes.HealthCheck }{ { name: "falls back to HC defaults", annotations: map[string]string{}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "healthy threshold override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCHealthyThreshold: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(7), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(7), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "unhealthy threshold override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCUnhealthyThreshold: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(7), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(7), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "timeout override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(7), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(7), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "interval override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCInterval: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(7), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(7), Target: aws.String("TCP:8080"), }, }, @@ -2165,11 +2095,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { annotations: map[string]string{ ServiceAnnotationLoadBalancerHealthCheckPort: "2122", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:2122"), }, }, @@ -2178,11 +2108,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { annotations: map[string]string{ ServiceAnnotationLoadBalancerHealthCheckProtocol: "HTTP", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("HTTP:8080/"), }, }, @@ -2193,11 +2123,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckPath: "/healthz", ServiceAnnotationLoadBalancerHealthCheckPort: "31224", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("HTTPS:31224/healthz"), }, }, @@ -2208,11 +2138,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckPath: "/healthz", ServiceAnnotationLoadBalancerHealthCheckPort: "3124", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("SSL:3124"), }, }, @@ -2222,11 +2152,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckProtocol: "TCP", ServiceAnnotationLoadBalancerHealthCheckPort: "traffic-port", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, @@ -2234,15 +2164,15 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { lbName := "myLB" // this HC will always differ from the expected HC and thus it is expected an // API call will be made to update it - currentHC := &elb.HealthCheck{} - elbDesc := &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: currentHC} - defaultHealthyThreshold := int64(2) - defaultUnhealthyThreshold := int64(6) - defaultTimeout := int64(5) - defaultInterval := int64(10) + currentHC := &elbtypes.HealthCheck{} + elbDesc := &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: currentHC} + defaultHealthyThreshold := int32(2) + defaultUnhealthyThreshold := int32(6) + defaultTimeout := int32(5) + defaultInterval := int32(10) protocol, path, port := "TCP", "", int32(8080) target := "TCP:8080" - defaultHC := &elb.HealthCheck{ + defaultHC := &elbtypes.HealthCheck{ HealthyThreshold: &defaultHealthyThreshold, UnhealthyThreshold: &defaultUnhealthyThreshold, Timeout: &defaultTimeout, @@ -2257,7 +2187,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { expectedHC := test.want awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, &expectedHC, nil) - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, test.annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, test.annotations) require.NoError(t, err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) @@ -2269,20 +2199,20 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { c, err := newAWSCloud(config.CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC - timeout := int64(3) - expectedHC.Timeout = &timeout + timeout := int32(3) + expectedHC.Timeout = aws.Int32(timeout) annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3"} - var currentHC elb.HealthCheck + var currentHC elbtypes.HealthCheck currentHC = expectedHC // NOTE no call expectations are set on the ELB mock // test default HC - elbDesc := &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: defaultHC} - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, map[string]string{}) + elbDesc := &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: defaultHC} + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, map[string]string{}) assert.NoError(t, err) // test HC with override - elbDesc = &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: ¤tHC} - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + elbDesc = &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: ¤tHC} + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) assert.NoError(t, err) }) @@ -2291,13 +2221,13 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { c, err := newAWSCloud(config.CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC - invalidThreshold := int64(1) - expectedHC.HealthyThreshold = &invalidThreshold - require.Error(t, expectedHC.Validate()) // confirm test precondition + invalidThreshold := int32(1) + expectedHC.HealthyThreshold = aws.Int32(invalidThreshold) + require.Error(t, ValidateHealthCheck(&expectedHC)) // confirm test precondition annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "1"} // NOTE no call expectations are set on the ELB mock - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) require.Error(t, err) }) @@ -2309,7 +2239,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3.3"} // NOTE no call expectations are set on the ELB mock - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) require.Error(t, err) }) @@ -2321,7 +2251,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { returnErr := fmt.Errorf("throttling error") awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, defaultHC, returnErr) - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, map[string]string{}) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, map[string]string{}) require.Error(t, err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) @@ -2329,8 +2259,8 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { } func TestFindSecurityGroupForInstance(t *testing.T) { - groups := map[string]*ec2.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} - id, err := findSecurityGroupForInstance(&ec2.Instance{SecurityGroups: []*ec2.GroupIdentifier{{GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}}}, groups) + groups := map[string]*ec2types.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} + id, err := findSecurityGroupForInstance(&ec2types.Instance{SecurityGroups: []ec2types.GroupIdentifier{{GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}}}, groups) if err != nil { t.Error() } @@ -2339,9 +2269,9 @@ func TestFindSecurityGroupForInstance(t *testing.T) { } func TestFindSecurityGroupForInstanceMultipleTagged(t *testing.T) { - groups := map[string]*ec2.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} - _, err := findSecurityGroupForInstance(&ec2.Instance{ - SecurityGroups: []*ec2.GroupIdentifier{ + groups := map[string]*ec2types.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} + _, err := findSecurityGroupForInstance(&ec2types.Instance{ + SecurityGroups: []ec2types.GroupIdentifier{ {GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}, {GroupId: aws.String("sg123"), GroupName: aws.String("another_group")}, }, @@ -2498,39 +2428,39 @@ func informerNotSynced() bool { } type MockedFakeELBV2 struct { - LoadBalancers []*elbv2.LoadBalancer - TargetGroups []*elbv2.TargetGroup - Listeners []*elbv2.Listener + LoadBalancers []elbv2types.LoadBalancer + TargetGroups []elbv2types.TargetGroup + Listeners []elbv2types.Listener // keys on all of these maps are ARNs LoadBalancerAttributes map[string]map[string]string - Tags map[string][]elbv2.Tag + Tags map[string][]elbv2types.Tag RegisteredInstances map[string][]string // value is list of instance IDs } -func (m *MockedFakeELBV2) AddTags(request *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) { - for _, arn := range request.ResourceArns { - for _, tag := range request.Tags { - m.Tags[aws.StringValue(arn)] = append(m.Tags[aws.StringValue(arn)], *tag) +func (m *MockedFakeELBV2) AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) { + for _, arn := range input.ResourceArns { + for _, tag := range input.Tags { + m.Tags[arn] = append(m.Tags[arn], tag) } } return &elbv2.AddTagsOutput{}, nil } -func (m *MockedFakeELBV2) CreateLoadBalancer(request *elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) { +func (m *MockedFakeELBV2) CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-west-2:%d:loadbalancer/net/%x/%x", accountID, rand.Uint64(), rand.Uint32()) - newLB := &elbv2.LoadBalancer{ + newLB := elbv2types.LoadBalancer{ LoadBalancerArn: aws.String(arn), - LoadBalancerName: request.Name, - Type: aws.String(elbv2.LoadBalancerTypeEnumNetwork), + LoadBalancerName: input.Name, + Type: elbv2types.LoadBalancerTypeEnumNetwork, VpcId: aws.String("vpc-abc123def456abc78"), - AvailabilityZones: []*elbv2.AvailabilityZone{ + AvailabilityZones: []elbv2types.AvailabilityZone{ { ZoneName: aws.String("us-west-2a"), SubnetId: aws.String("subnet-abc123de"), @@ -2540,35 +2470,35 @@ func (m *MockedFakeELBV2) CreateLoadBalancer(request *elbv2.CreateLoadBalancerIn m.LoadBalancers = append(m.LoadBalancers, newLB) return &elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{newLB}, + LoadBalancers: []elbv2types.LoadBalancer{newLB}, }, nil } -func (m *MockedFakeELBV2) DescribeLoadBalancers(request *elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) { +func (m *MockedFakeELBV2) DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) { findMeNames := make(map[string]bool) - for _, name := range request.Names { - findMeNames[aws.StringValue(name)] = true + for _, name := range input.Names { + findMeNames[name] = true } findMeARNs := make(map[string]bool) - for _, arn := range request.LoadBalancerArns { - findMeARNs[aws.StringValue(arn)] = true + for _, arn := range input.LoadBalancerArns { + findMeARNs[arn] = true } - result := []*elbv2.LoadBalancer{} + result := []elbv2types.LoadBalancer{} for _, lb := range m.LoadBalancers { - if _, present := findMeNames[aws.StringValue(lb.LoadBalancerName)]; present { + if _, present := findMeNames[aws.ToString(lb.LoadBalancerName)]; present { result = append(result, lb) - delete(findMeNames, aws.StringValue(lb.LoadBalancerName)) - } else if _, present := findMeARNs[aws.StringValue(lb.LoadBalancerArn)]; present { + delete(findMeNames, aws.ToString(lb.LoadBalancerName)) + } else if _, present := findMeARNs[aws.ToString(lb.LoadBalancerArn)]; present { result = append(result, lb) - delete(findMeARNs, aws.StringValue(lb.LoadBalancerArn)) + delete(findMeARNs, aws.ToString(lb.LoadBalancerArn)) } } if len(findMeNames) > 0 || len(findMeARNs) > 0 { - return nil, awserr.New(elbv2.ErrCodeLoadBalancerNotFoundException, "not found", nil) + return nil, &elbv2types.LoadBalancerNotFoundException{Message: aws.String("not found")} } return &elbv2.DescribeLoadBalancersOutput{ @@ -2576,33 +2506,33 @@ func (m *MockedFakeELBV2) DescribeLoadBalancers(request *elbv2.DescribeLoadBalan }, nil } -func (m *MockedFakeELBV2) DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) { +func (m *MockedFakeELBV2) DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyLoadBalancerAttributes(request *elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { - attrMap, present := m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)] +func (m *MockedFakeELBV2) ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { + attrMap, present := m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)] if !present { attrMap = make(map[string]string) - m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)] = attrMap + m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)] = attrMap } - for _, attr := range request.Attributes { - attrMap[aws.StringValue(attr.Key)] = aws.StringValue(attr.Value) + for _, attr := range input.Attributes { + attrMap[aws.ToString(attr.Key)] = aws.ToString(attr.Value) } return &elbv2.ModifyLoadBalancerAttributesOutput{ - Attributes: request.Attributes, + Attributes: input.Attributes, }, nil } -func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(request *elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { - attrs := []*elbv2.LoadBalancerAttribute{} +func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { + attrs := []elbv2types.LoadBalancerAttribute{} - if lbAttrs, present := m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)]; present { + if lbAttrs, present := m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)]; present { for key, value := range lbAttrs { - attrs = append(attrs, &elbv2.LoadBalancerAttribute{ + attrs = append(attrs, elbv2types.LoadBalancerAttribute{ Key: aws.String(key), Value: aws.String(value), }) @@ -2614,65 +2544,65 @@ func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(request *elbv2.Describe }, nil } -func (m *MockedFakeELBV2) CreateTargetGroup(request *elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) { +func (m *MockedFakeELBV2) CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-west-2:%d:targetgroup/%x/%x", accountID, rand.Uint64(), rand.Uint32()) - newTG := &elbv2.TargetGroup{ + newTG := elbv2types.TargetGroup{ TargetGroupArn: aws.String(arn), - TargetGroupName: request.Name, - Port: request.Port, - Protocol: request.Protocol, - HealthCheckProtocol: request.HealthCheckProtocol, - HealthCheckPath: request.HealthCheckPath, - HealthCheckPort: request.HealthCheckPort, - HealthCheckTimeoutSeconds: request.HealthCheckTimeoutSeconds, - HealthCheckIntervalSeconds: request.HealthCheckIntervalSeconds, - HealthyThresholdCount: request.HealthyThresholdCount, - UnhealthyThresholdCount: request.UnhealthyThresholdCount, + TargetGroupName: input.Name, + Port: input.Port, + Protocol: input.Protocol, + HealthCheckProtocol: input.HealthCheckProtocol, + HealthCheckPath: input.HealthCheckPath, + HealthCheckPort: input.HealthCheckPort, + HealthCheckTimeoutSeconds: input.HealthCheckTimeoutSeconds, + HealthCheckIntervalSeconds: input.HealthCheckIntervalSeconds, + HealthyThresholdCount: input.HealthyThresholdCount, + UnhealthyThresholdCount: input.UnhealthyThresholdCount, } m.TargetGroups = append(m.TargetGroups, newTG) return &elbv2.CreateTargetGroupOutput{ - TargetGroups: []*elbv2.TargetGroup{newTG}, + TargetGroups: []elbv2types.TargetGroup{newTG}, }, nil } -func (m *MockedFakeELBV2) DescribeTargetGroups(request *elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) { - var targetGroups []*elbv2.TargetGroup +func (m *MockedFakeELBV2) DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { + var targetGroups []elbv2types.TargetGroup - if request.LoadBalancerArn != nil { - targetGroups = []*elbv2.TargetGroup{} + if input.LoadBalancerArn != nil { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { for _, lbArn := range tg.LoadBalancerArns { - if aws.StringValue(lbArn) == aws.StringValue(request.LoadBalancerArn) { + if lbArn == aws.ToString(input.LoadBalancerArn) { targetGroups = append(targetGroups, tg) break } } } - } else if len(request.Names) != 0 { - targetGroups = []*elbv2.TargetGroup{} + } else if len(input.Names) != 0 { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - for _, name := range request.Names { - if aws.StringValue(tg.TargetGroupName) == aws.StringValue(name) { + for _, name := range input.Names { + if aws.ToString(tg.TargetGroupName) == name { targetGroups = append(targetGroups, tg) break } } } - } else if len(request.TargetGroupArns) != 0 { - targetGroups = []*elbv2.TargetGroup{} + } else if len(input.TargetGroupArns) != 0 { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - for _, arn := range request.TargetGroupArns { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(arn) { + for _, arn := range input.TargetGroupArns { + if aws.ToString(tg.TargetGroupArn) == arn { targetGroups = append(targetGroups, tg) break } @@ -2687,46 +2617,46 @@ func (m *MockedFakeELBV2) DescribeTargetGroups(request *elbv2.DescribeTargetGrou }, nil } -func (m *MockedFakeELBV2) ModifyTargetGroup(request *elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) { - var matchingTargetGroup *elbv2.TargetGroup - dirtyGroups := []*elbv2.TargetGroup{} +func (m *MockedFakeELBV2) ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) { + var matchingTargetGroup *elbv2types.TargetGroup + dirtyGroups := []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(request.TargetGroupArn) { - matchingTargetGroup = tg + if aws.ToString(tg.TargetGroupArn) == aws.ToString(input.TargetGroupArn) { + matchingTargetGroup = &tg break } } if matchingTargetGroup != nil { - dirtyGroups = append(dirtyGroups, matchingTargetGroup) + dirtyGroups = append(dirtyGroups, *matchingTargetGroup) - if request.HealthCheckEnabled != nil { - matchingTargetGroup.HealthCheckEnabled = request.HealthCheckEnabled + if input.HealthCheckEnabled != nil { + matchingTargetGroup.HealthCheckEnabled = input.HealthCheckEnabled } - if request.HealthCheckIntervalSeconds != nil { - matchingTargetGroup.HealthCheckIntervalSeconds = request.HealthCheckIntervalSeconds + if input.HealthCheckIntervalSeconds != nil { + matchingTargetGroup.HealthCheckIntervalSeconds = input.HealthCheckIntervalSeconds } - if request.HealthCheckPath != nil { - matchingTargetGroup.HealthCheckPath = request.HealthCheckPath + if input.HealthCheckPath != nil { + matchingTargetGroup.HealthCheckPath = input.HealthCheckPath } - if request.HealthCheckPort != nil { - matchingTargetGroup.HealthCheckPort = request.HealthCheckPort + if input.HealthCheckPort != nil { + matchingTargetGroup.HealthCheckPort = input.HealthCheckPort } - if request.HealthCheckProtocol != nil { - matchingTargetGroup.HealthCheckProtocol = request.HealthCheckProtocol + if string(input.HealthCheckProtocol) != "" { + matchingTargetGroup.HealthCheckProtocol = input.HealthCheckProtocol } - if request.HealthCheckTimeoutSeconds != nil { - matchingTargetGroup.HealthCheckTimeoutSeconds = request.HealthCheckTimeoutSeconds + if input.HealthCheckTimeoutSeconds != nil { + matchingTargetGroup.HealthCheckTimeoutSeconds = input.HealthCheckTimeoutSeconds } - if request.HealthyThresholdCount != nil { - matchingTargetGroup.HealthyThresholdCount = request.HealthyThresholdCount + if input.HealthyThresholdCount != nil { + matchingTargetGroup.HealthyThresholdCount = input.HealthyThresholdCount } - if request.Matcher != nil { - matchingTargetGroup.Matcher = request.Matcher + if input.Matcher != nil { + matchingTargetGroup.Matcher = input.Matcher } - if request.UnhealthyThresholdCount != nil { - matchingTargetGroup.UnhealthyThresholdCount = request.UnhealthyThresholdCount + if input.UnhealthyThresholdCount != nil { + matchingTargetGroup.UnhealthyThresholdCount = input.UnhealthyThresholdCount } } @@ -2735,44 +2665,44 @@ func (m *MockedFakeELBV2) ModifyTargetGroup(request *elbv2.ModifyTargetGroupInpu }, nil } -func (m *MockedFakeELBV2) DeleteTargetGroup(request *elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) { - newTargetGroups := []*elbv2.TargetGroup{} +func (m *MockedFakeELBV2) DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) { + newTargetGroups := []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) != aws.StringValue(request.TargetGroupArn) { + if aws.ToString(tg.TargetGroupArn) != aws.ToString(input.TargetGroupArn) { newTargetGroups = append(newTargetGroups, tg) } } m.TargetGroups = newTargetGroups - delete(m.RegisteredInstances, aws.StringValue(request.TargetGroupArn)) + delete(m.RegisteredInstances, aws.ToString(input.TargetGroupArn)) return &elbv2.DeleteTargetGroupOutput{}, nil } -func (m *MockedFakeELBV2) DescribeTargetHealth(request *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) { - healthDescriptions := []*elbv2.TargetHealthDescription{} +func (m *MockedFakeELBV2) DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { + healthDescriptions := []elbv2types.TargetHealthDescription{} - var matchingTargetGroup *elbv2.TargetGroup + var matchingTargetGroup elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(request.TargetGroupArn) { + if aws.ToString(tg.TargetGroupArn) == aws.ToString(input.TargetGroupArn) { matchingTargetGroup = tg break } } - if registeredTargets, present := m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)]; present { + if registeredTargets, present := m.RegisteredInstances[aws.ToString(input.TargetGroupArn)]; present { for _, target := range registeredTargets { - healthDescriptions = append(healthDescriptions, &elbv2.TargetHealthDescription{ + healthDescriptions = append(healthDescriptions, elbv2types.TargetHealthDescription{ HealthCheckPort: matchingTargetGroup.HealthCheckPort, - Target: &elbv2.TargetDescription{ + Target: &elbv2types.TargetDescription{ Id: aws.String(target), Port: matchingTargetGroup.Port, }, - TargetHealth: &elbv2.TargetHealth{ - State: aws.String("healthy"), + TargetHealth: &elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumHealthy, }, }) } @@ -2783,46 +2713,46 @@ func (m *MockedFakeELBV2) DescribeTargetHealth(request *elbv2.DescribeTargetHeal }, nil } -func (m *MockedFakeELBV2) DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) { +func (m *MockedFakeELBV2) DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) { +func (m *MockedFakeELBV2) ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) RegisterTargets(request *elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) { - arn := aws.StringValue(request.TargetGroupArn) +func (m *MockedFakeELBV2) RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { + arn := aws.ToString(input.TargetGroupArn) alreadyExists := make(map[string]bool) for _, targetID := range m.RegisteredInstances[arn] { alreadyExists[targetID] = true } - for _, target := range request.Targets { - if !alreadyExists[aws.StringValue(target.Id)] { - m.RegisteredInstances[arn] = append(m.RegisteredInstances[arn], aws.StringValue(target.Id)) + for _, target := range input.Targets { + if !alreadyExists[aws.ToString(target.Id)] { + m.RegisteredInstances[arn] = append(m.RegisteredInstances[arn], aws.ToString(target.Id)) } } return &elbv2.RegisterTargetsOutput{}, nil } -func (m *MockedFakeELBV2) DeregisterTargets(request *elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) { +func (m *MockedFakeELBV2) DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { removeMe := make(map[string]bool) - for _, target := range request.Targets { - removeMe[aws.StringValue(target.Id)] = true + for _, target := range input.Targets { + removeMe[aws.ToString(target.Id)] = true } newRegisteredInstancesForArn := []string{} - for _, targetID := range m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)] { + for _, targetID := range m.RegisteredInstances[aws.ToString(input.TargetGroupArn)] { if !removeMe[targetID] { newRegisteredInstancesForArn = append(newRegisteredInstancesForArn, targetID) } } - m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)] = newRegisteredInstancesForArn + m.RegisteredInstances[aws.ToString(input.TargetGroupArn)] = newRegisteredInstancesForArn return &elbv2.DeregisterTargetsOutput{}, nil } -func (m *MockedFakeELBV2) CreateListener(request *elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) { +func (m *MockedFakeELBV2) CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-west-2:%d:listener/net/%x/%x/%x", accountID, @@ -2830,40 +2760,40 @@ func (m *MockedFakeELBV2) CreateListener(request *elbv2.CreateListenerInput) (*e rand.Uint32(), rand.Uint32()) - newListener := &elbv2.Listener{ + newListener := elbv2types.Listener{ ListenerArn: aws.String(arn), - Port: request.Port, - Protocol: request.Protocol, - DefaultActions: request.DefaultActions, - LoadBalancerArn: request.LoadBalancerArn, + Port: input.Port, + Protocol: input.Protocol, + DefaultActions: input.DefaultActions, + LoadBalancerArn: input.LoadBalancerArn, } m.Listeners = append(m.Listeners, newListener) for _, tg := range m.TargetGroups { - for _, action := range request.DefaultActions { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - tg.LoadBalancerArns = append(tg.LoadBalancerArns, request.LoadBalancerArn) + for _, action := range input.DefaultActions { + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + tg.LoadBalancerArns = append(tg.LoadBalancerArns, aws.ToString(input.LoadBalancerArn)) break } } } return &elbv2.CreateListenerOutput{ - Listeners: []*elbv2.Listener{newListener}, + Listeners: []elbv2types.Listener{newListener}, }, nil } -func (m *MockedFakeELBV2) DescribeListeners(request *elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) { - if len(request.ListenerArns) == 0 && request.LoadBalancerArn == nil { +func (m *MockedFakeELBV2) DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) { + if len(input.ListenerArns) == 0 && input.LoadBalancerArn == nil { return &elbv2.DescribeListenersOutput{ Listeners: m.Listeners, }, nil - } else if len(request.ListenerArns) == 0 { - listeners := []*elbv2.Listener{} + } else if len(input.ListenerArns) == 0 { + listeners := []elbv2types.Listener{} for _, lb := range m.Listeners { - if aws.StringValue(lb.LoadBalancerArn) == aws.StringValue(request.LoadBalancerArn) { + if aws.ToString(lb.LoadBalancerArn) == aws.ToString(input.LoadBalancerArn) { listeners = append(listeners, lb) } } @@ -2875,31 +2805,32 @@ func (m *MockedFakeELBV2) DescribeListeners(request *elbv2.DescribeListenersInpu panic("Not implemented") } -func (m *MockedFakeELBV2) DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) { +func (m *MockedFakeELBV2) DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) { - modifiedListeners := []*elbv2.Listener{} +func (m *MockedFakeELBV2) ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) { - for _, listener := range m.Listeners { - if aws.StringValue(listener.ListenerArn) == aws.StringValue(request.ListenerArn) { - if request.DefaultActions != nil { + modifiedListeners := []elbv2types.Listener{} + for i := range m.Listeners { + listener := &m.Listeners[i] + if aws.ToString(listener.ListenerArn) == aws.ToString(input.ListenerArn) { + if input.DefaultActions != nil { // for each old action, find the corresponding target group, and remove the listener's LB ARN from its list for _, action := range listener.DefaultActions { - var targetGroupForAction *elbv2.TargetGroup + var targetGroupForAction *elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - targetGroupForAction = tg + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + targetGroupForAction = &tg break } } if targetGroupForAction != nil { - newLoadBalancerARNs := []*string{} + newLoadBalancerARNs := []string{} for _, lbArn := range targetGroupForAction.LoadBalancerArns { - if aws.StringValue(lbArn) != aws.StringValue(listener.LoadBalancerArn) { + if lbArn != aws.ToString(listener.LoadBalancerArn) { newLoadBalancerARNs = append(newLoadBalancerARNs, lbArn) } } @@ -2908,33 +2839,34 @@ func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*e } } - listener.DefaultActions = request.DefaultActions + listener.DefaultActions = input.DefaultActions // for each new action, add the listener's LB ARN to that action's target groups' lists - for _, action := range request.DefaultActions { - var targetGroupForAction *elbv2.TargetGroup + for _, action := range input.DefaultActions { + var targetGroupForAction *elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - targetGroupForAction = tg + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + targetGroupForAction = &tg break } } if targetGroupForAction != nil { - targetGroupForAction.LoadBalancerArns = append(targetGroupForAction.LoadBalancerArns, listener.LoadBalancerArn) + targetGroupForAction.LoadBalancerArns = append(targetGroupForAction.LoadBalancerArns, aws.ToString(listener.LoadBalancerArn)) } } } - if request.Port != nil { - listener.Port = request.Port + if input.Port != nil { + listener.Port = input.Port } - if request.Protocol != nil { - listener.Protocol = request.Protocol + if string(input.Protocol) != "" { + listener.Protocol = input.Protocol } - modifiedListeners = append(modifiedListeners, listener) + modifiedListeners = append(modifiedListeners, *listener) } + } return &elbv2.ModifyListenerOutput{ @@ -2942,34 +2874,30 @@ func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*e }, nil } -func (m *MockedFakeELBV2) WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error { - panic("Not implemented") -} - func (m *MockedFakeEC2) maybeExpectDescribeSecurityGroups(clusterID, groupName string) { - tags := []*ec2.Tag{ + tags := []ec2types.Tag{ {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(clusterID)}, {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, clusterID)), Value: aws.String(ResourceLifecycleOwned)}, } - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{ + m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []ec2types.Filter{ newEc2Filter("group-name", groupName), newEc2Filter("vpc-id", ""), - }}).Maybe().Return([]*ec2.SecurityGroup{{Tags: tags}}) + }}).Maybe().Return([]ec2types.SecurityGroup{{Tags: tags}}) - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{}).Maybe().Return([]*ec2.SecurityGroup{{Tags: tags}}) + m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{}).Maybe().Return([]ec2types.SecurityGroup{{Tags: tags}}) } func TestNLBNodeRegistration(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - awsServices.elbv2 = &MockedFakeELBV2{Tags: make(map[string][]elbv2.Tag), RegisteredInstances: make(map[string][]string), LoadBalancerAttributes: make(map[string]map[string]string)} + awsServices.elbv2 = &MockedFakeELBV2{Tags: make(map[string][]elbv2types.Tag), RegisteredInstances: make(map[string][]string), LoadBalancerAttributes: make(map[string]map[string]string)} c, _ := newAWSCloud(config.CloudConfig{}, awsServices) - awsServices.ec2.(*MockedFakeEC2).Subnets = []*ec2.Subnet{ + awsServices.ec2.(*MockedFakeEC2).Subnets = []ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-abc123de"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -2978,9 +2906,9 @@ func TestNLBNodeRegistration(t *testing.T) { }, } - awsServices.ec2.(*MockedFakeEC2).RouteTables = []*ec2.RouteTable{ + awsServices.ec2.(*MockedFakeEC2).RouteTables = []ec2types.RouteTable{ { - Associations: []*ec2.RouteTableAssociation{ + Associations: []ec2types.RouteTableAssociation{ { Main: aws.Bool(true), RouteTableAssociationId: aws.String("rtbassoc-abc123def456abc78"), @@ -2989,11 +2917,11 @@ func TestNLBNodeRegistration(t *testing.T) { }, }, RouteTableId: aws.String("rtb-abc123def456abc78"), - Routes: []*ec2.Route{ + Routes: []ec2types.Route{ { DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String("igw-abc123def456abc78"), - State: aws.String("active"), + State: ec2types.RouteStateActive, }, }, }, @@ -3055,29 +2983,29 @@ func TestNLBNodeRegistration(t *testing.T) { } fauxService.Annotations[ServiceAnnotationLoadBalancerHealthCheckProtocol] = "http" - tgARN := aws.StringValue(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn) + tgARN := aws.ToString(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn) _, err = c.EnsureLoadBalancer(context.TODO(), TestClusterName, fauxService, nodes) if err != nil { t.Errorf("EnsureLoadBalancer returned an error: %v", err) } assert.Equal(t, 1, len(awsServices.elbv2.(*MockedFakeELBV2).Listeners)) - assert.NotEqual(t, tgARN, aws.StringValue(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn)) + assert.NotEqual(t, tgARN, aws.ToString(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn)) } func makeNamedNode(s *FakeAWSServices, offset int, name string) *v1.Node { instanceID := fmt.Sprintf("i-%x", int64(0x02bce90670bb0c7cd)+int64(offset)) - instance := &ec2.Instance{} + instance := &ec2types.Instance{} instance.InstanceId = aws.String(instanceID) - instance.Placement = &ec2.Placement{ + instance.Placement = &ec2types.Placement{ AvailabilityZone: aws.String("us-west-2c"), } instance.PrivateDnsName = aws.String(fmt.Sprintf("ip-172-20-0-%d.ec2.internal", 101+offset)) instance.PrivateIpAddress = aws.String(fmt.Sprintf("192.168.0.%d", 1+offset)) - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - instance.Tags = []*ec2.Tag{&tag} + instance.Tags = []ec2types.Tag{tag} s.instances = append(s.instances, instance) @@ -3208,7 +3136,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 30, Timeout: 10, HealthyThreshold: 3, @@ -3242,7 +3170,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { want: healthCheckConfig{ Port: "32213", Path: "/healthz", - Protocol: elbv2.ProtocolEnumHttp, + Protocol: elbv2types.ProtocolEnumHttp, Interval: 10, Timeout: 10, HealthyThreshold: 2, @@ -3387,7 +3315,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 23, Timeout: 10, HealthyThreshold: 3, @@ -3445,7 +3373,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 30, Timeout: 10, HealthyThreshold: 7, @@ -3600,7 +3528,7 @@ func TestInstanceExistsByProviderIDForInstanceNotFound(t *testing.T) { mockedEC2API := newMockedEC2API() c := &Cloud{ec2: &awsSdkEC2{ec2: mockedEC2API}} - mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil)) + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, errors.New("InvalidInstanceID.NotFound: Instance not found")) instanceExists, err := c.InstanceExistsByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/i-not-found") assert.Nil(t, err) @@ -3657,32 +3585,40 @@ func TestGetRegionFromMetadata(t *testing.T) { // Returns region from zone if set cfg := config.CloudConfig{} cfg.Global.Zone = "us-west-2a" - region, err := getRegionFromMetadata(cfg, awsServices.metadata) + region, err := getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.NoError(t, err) assert.Equal(t, "us-west-2", region) // Returns error if can map to region cfg = config.CloudConfig{} cfg.Global.Zone = "some-fake-zone" - _, err = getRegionFromMetadata(cfg, awsServices.metadata) + _, err = getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.Error(t, err) // Returns region from metadata if zone unset cfg = config.CloudConfig{} - region, err = getRegionFromMetadata(cfg, awsServices.metadata) + region, err = getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.NoError(t, err) assert.Equal(t, "us-west-2", region) } type MockedEC2API struct { - ec2iface.EC2API + EC2API mock.Mock } -func (m *MockedEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { +func (m *MockedEC2API) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { args := m.Called(input) return args.Get(0).(*ec2.DescribeInstancesOutput), args.Error(1) } -func (m *MockedEC2API) DescribeAvailabilityZones(input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { +func (m *MockedEC2API) DescribeInstanceTopology(ctx context.Context, params *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTopologyOutput, error) { + args := m.Called(ctx, params) + if args.Get(1) != nil { + return nil, args.Get(1).(error) + } + return args.Get(0).(*ec2.DescribeInstanceTopologyOutput), nil +} + +func (m *MockedEC2API) DescribeAvailabilityZones(ctx context.Context, input *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) { args := m.Called(input) return args.Get(0).(*ec2.DescribeAvailabilityZonesOutput), args.Error(1) } @@ -3695,17 +3631,17 @@ func TestDescribeInstances(t *testing.T) { tests := []struct { name string input *ec2.DescribeInstancesInput - expect func(ec2iface.EC2API) + expect func(EC2API) isError bool }{ { "MaxResults set on empty DescribeInstancesInput and NextToken respected.", &ec2.DescribeInstancesInput{}, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(1000), + MaxResults: aws.Int32(1000), }, ).Return( &ec2.DescribeInstancesOutput{ @@ -3715,7 +3651,7 @@ func TestDescribeInstances(t *testing.T) { ) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(1000), + MaxResults: aws.Int32(1000), NextToken: aws.String("asdf"), }, ).Return( @@ -3728,13 +3664,13 @@ func TestDescribeInstances(t *testing.T) { { "MaxResults only set if empty DescribeInstancesInput", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(3), + MaxResults: aws.Int32(3), }, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(3), + MaxResults: aws.Int32(3), }, ).Return( &ec2.DescribeInstancesOutput{}, @@ -3746,13 +3682,13 @@ func TestDescribeInstances(t *testing.T) { { "MaxResults not set if instance IDs are provided", &ec2.DescribeInstancesInput{ - InstanceIds: []*string{aws.String("i-1234")}, + InstanceIds: []string{"i-1234"}, }, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - InstanceIds: []*string{aws.String("i-1234")}, + InstanceIds: []string{"i-1234"}, }, ).Return( &ec2.DescribeInstancesOutput{}, @@ -3770,7 +3706,7 @@ func TestDescribeInstances(t *testing.T) { fakeEC2 := awsSdkEC2{ ec2: mockedEC2API, } - _, err := fakeEC2.DescribeInstances(test.input) + _, err := fakeEC2.DescribeInstances(context.TODO(), test.input) if !test.isError { assert.NoError(t, err) } @@ -3833,3 +3769,32 @@ func TestInstanceIDIndexFunc(t *testing.T) { }) } } + +func TestIsAWSErrorInstanceNotFound(t *testing.T) { + mockedEC2API := newMockedEC2API() + ec2Client := &awsSdkEC2{ + ec2: mockedEC2API, + } + + // API error + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, error(&smithy.GenericAPIError{ + Code: string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound), + Message: "test", + })) + _, err := ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + assert.True(t, IsAWSErrorInstanceNotFound(err)) + + // Wrapped error + _, err = ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + err = fmt.Errorf("error listing AWS instances: %q", err) + assert.True(t, IsAWSErrorInstanceNotFound(err)) + + // Expect false for nil and any other errors + assert.False(t, IsAWSErrorInstanceNotFound(nil)) + + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesInput{}, &smithy.GenericAPIError{ + Code: string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeIncorrectInstanceState), + }) + _, err = ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + assert.False(t, IsAWSErrorInstanceNotFound(nil)) +} diff --git a/pkg/providers/v1/aws_utils.go b/pkg/providers/v1/aws_utils.go index 621731ed1c..44cea6378e 100644 --- a/pkg/providers/v1/aws_utils.go +++ b/pkg/providers/v1/aws_utils.go @@ -19,30 +19,29 @@ package aws import ( "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws/arn" "k8s.io/apimachinery/pkg/util/sets" ) -func stringSetToPointers(in sets.String) []*string { +func stringSetToList(in sets.Set[string]) []string { if in == nil { return nil } - out := make([]*string, 0, len(in)) + out := make([]string, 0, len(in)) for k := range in { - out = append(out, aws.String(k)) + out = append(out, k) } return out } -func stringSetFromPointers(in []*string) sets.String { +func stringSetFromList(in []string) sets.Set[string] { if in == nil { return nil } - out := sets.NewString() + out := sets.New[string]() for i := range in { - out.Insert(aws.StringValue(in[i])) + out.Insert(in[i]) } return out } diff --git a/pkg/providers/v1/config/config.go b/pkg/providers/v1/config/config.go index bc8f5e9e57..4aa456f6b8 100644 --- a/pkg/providers/v1/config/config.go +++ b/pkg/providers/v1/config/config.go @@ -1,12 +1,20 @@ package config import ( + "context" "fmt" + "net/url" + "strings" - "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" - "github.com/aws/aws-sdk-go/aws/endpoints" + smithyendpoints "github.com/aws/smithy-go/endpoints" "k8s.io/klog/v2" ) @@ -67,7 +75,7 @@ type CloudConfig struct { // Override to regex validating whether or not instance types require instance topology // to get a definitive response. This will impact whether or not the node controller will // block on getting instance topology information for nodes. - // See pkg/resourcemanagers/topology.go for more details. + // See pkg/providers/v1/topology.go for more details. // // WARNING: Updating the default behavior and corresponding unit tests would be a much safer option. SupportedTopologyInstanceTypePattern string `json:"supportedTopologyInstanceTypePattern,omitempty" yaml:"supportedTopologyInstanceTypePattern,omitempty"` @@ -98,25 +106,25 @@ type CloudConfig struct { // EC2Metadata is an abstraction over the AWS metadata service. type EC2Metadata interface { // Query the EC2 metadata service (used to discover instance-id etc) - GetMetadata(path string) (string, error) - Region() (string, error) + GetMetadata(ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options)) (*imds.GetMetadataOutput, error) + GetRegion(ctx context.Context, params *imds.GetRegionInput, optFns ...func(*imds.Options)) (*imds.GetRegionOutput, error) } // GetRegion returns the AWS region from the config, if set, or gets it from the metadata // service if unset and sets in config -func (cfg *CloudConfig) GetRegion(metadata EC2Metadata) (string, error) { +func (cfg *CloudConfig) GetRegion(ctx context.Context, metadata EC2Metadata) (string, error) { if cfg.Global.Region != "" { return cfg.Global.Region, nil } klog.Info("Loading region from metadata service") - region, err := metadata.Region() + region, err := metadata.GetRegion(ctx, &imds.GetRegionInput{}) if err != nil { return "", err } - cfg.Global.Region = region - return region, nil + cfg.Global.Region = region.Region + return region.Region, nil } // ValidateOverrides ensures overrides are correct @@ -158,34 +166,224 @@ func (cfg *CloudConfig) ValidateOverrides() error { return nil } -// GetResolver computes the correct resolver to use -func (cfg *CloudConfig) GetResolver() endpoints.ResolverFunc { - defaultResolver := endpoints.DefaultResolver() - defaultResolverFn := func(service, region string, - optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - return defaultResolver.EndpointFor(service, region, optFns...) +// GetEC2EndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetEC2EndpointOpts(region string) []func(*ec2.Options) { + opts := []func(*ec2.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == ec2.ServiceID && override.Region == region { + opts = append(opts, + ec2.WithSigV4SigningName(override.SigningName), + ec2.WithSigV4SigningRegion(override.SigningRegion), + ) + } } - if len(cfg.ServiceOverride) == 0 { - return defaultResolverFn - } - - return func(service, region string, - optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - for _, override := range cfg.ServiceOverride { - if override.Service == service && override.Region == region { - return endpoints.ResolvedEndpoint{ - URL: override.URL, - SigningRegion: override.SigningRegion, - SigningMethod: override.SigningMethod, - SigningName: override.SigningName, - }, nil + return opts +} + +// GetCustomEC2Resolver returns an endpoint resolver for EC2 Clients +func (cfg *CloudConfig) GetCustomEC2Resolver() ec2.EndpointResolverV2 { + return &EC2Resolver{ + Resolver: ec2.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// EC2Resolver overrides the endpoint for an AWS SDK Go V2 EC2 Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type EC2Resolver struct { + Resolver ec2.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *EC2Resolver) ResolveEndpoint( + ctx context.Context, params ec2.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == ec2.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetELBEndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetELBEndpointOpts(region string) []func(*elb.Options) { + opts := []func(*elb.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == elb.ServiceID && override.Region == region { + opts = append(opts, + elb.WithSigV4SigningName(override.SigningName), + elb.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomELBResolver returns an endpoint resolver for ELB Clients +func (cfg *CloudConfig) GetCustomELBResolver() elb.EndpointResolverV2 { + return &ELBResolver{ + Resolver: elb.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// ELBResolver overrides the endpoint for an AWS SDK Go V2 ELB Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type ELBResolver struct { + Resolver elb.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *ELBResolver) ResolveEndpoint( + ctx context.Context, params elb.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == elb.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetELBV2EndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetELBV2EndpointOpts(region string) []func(*elbv2.Options) { + opts := []func(*elbv2.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == elbv2.ServiceID && override.Region == region { + opts = append(opts, + elbv2.WithSigV4SigningName(override.SigningName), + elbv2.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomELBV2Resolver returns an endpoint resolver for ELB Clients +func (cfg *CloudConfig) GetCustomELBV2Resolver() elbv2.EndpointResolverV2 { + return &ELBV2Resolver{ + Resolver: elbv2.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// ELBV2Resolver overrides the endpoint for an AWS SDK Go V2 ELB Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type ELBV2Resolver struct { + Resolver elbv2.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *ELBV2Resolver) ResolveEndpoint( + ctx context.Context, params elbv2.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == elbv2.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetKMSEndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetKMSEndpointOpts(region string) []func(*kms.Options) { + opts := []func(*kms.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == kms.ServiceID && override.Region == region { + opts = append(opts, + kms.WithSigV4SigningName(override.SigningName), + kms.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomKMSResolver returns an endpoint resolver for KMS Clients +func (cfg *CloudConfig) GetCustomKMSResolver() kms.EndpointResolverV2 { + return &KMSResolver{ + Resolver: kms.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// KMSResolver overrides the endpoint for an AWS SDK Go V2 KMS Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type KMSResolver struct { + Resolver kms.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *KMSResolver) ResolveEndpoint( + ctx context.Context, params kms.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == kms.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetIMDSEndpointOpts overrides the endpoint URL for IMDS clients +func (cfg *CloudConfig) GetIMDSEndpointOpts() []func(*imds.Options) { + opts := []func(*imds.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == imds.ServiceID { + opts = append(opts, func(o *imds.Options) { + o.Endpoint = override.URL + }) } - return defaultResolver.EndpointFor(service, region, optFns...) } + return opts } // SDKProvider can be used by variants to add their own handlers type SDKProvider interface { - AddHandlers(regionName string, h *request.Handlers) + AddMiddleware(ctx context.Context, regionName string, cfg *aws.Config) } diff --git a/pkg/providers/v1/iface/types.go b/pkg/providers/v1/iface/types.go index 451ffecc36..c064db4998 100644 --- a/pkg/providers/v1/iface/types.go +++ b/pkg/providers/v1/iface/types.go @@ -1,7 +1,10 @@ package iface import ( - "github.com/aws/aws-sdk-go/service/ec2" + "context" + + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) // EC2 is an abstraction over AWS', to allow mocking/other implementations @@ -9,30 +12,31 @@ import ( // TODO: Should we rename this to AWS (EBS & ELB are not technically part of EC2) type EC2 interface { // Query EC2 for instances matching the filter - DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) + DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) + DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) - DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) + DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) - CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) - DeleteSecurityGroup(request *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) + CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) - AuthorizeSecurityGroupIngress(*ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) - RevokeSecurityGroupIngress(*ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) + AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) - DescribeSubnets(*ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) + DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) - DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) + DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) - CreateTags(*ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) - DeleteTags(input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) + CreateTags(ctx context.Context, request *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + DeleteTags(ctx context.Context, request *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) - DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) - CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) - DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) + DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) + CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) + DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) - ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) + ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) - DescribeVpcs(input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) + DescribeVpcs(ctx context.Context, input *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) - DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) + DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) } diff --git a/pkg/providers/v1/instances.go b/pkg/providers/v1/instances.go index 08ae3aff21..1f98a2f23b 100644 --- a/pkg/providers/v1/instances.go +++ b/pkg/providers/v1/instances.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "fmt" "net/url" "regexp" @@ -24,8 +25,9 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" @@ -125,12 +127,12 @@ func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { } // Gets the full information about this instance from the EC2 API -func describeInstance(ec2Client iface.EC2, instanceID InstanceID) (*ec2.Instance, error) { +func describeInstance(ctx context.Context, ec2Client iface.EC2, instanceID InstanceID) (*ec2types.Instance, error) { request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := ec2Client.DescribeInstances(request) + instances, err := ec2Client.DescribeInstances(ctx, request) if err != nil { return nil, err } @@ -140,7 +142,7 @@ func describeInstance(ec2Client iface.EC2, instanceID InstanceID) (*ec2.Instance if len(instances) > 1 { return nil, fmt.Errorf("multiple instances found for instance: %s", instanceID) } - return instances[0], nil + return &instances[0], nil } // instanceCache manages the cache of DescribeInstances @@ -154,20 +156,20 @@ type instanceCache struct { // Gets the full information about these instance from the EC2 API. Caller must have acquired c.mutex before // calling describeAllInstancesUncached. -func (c *instanceCache) describeAllInstancesUncached() (*allInstancesSnapshot, error) { +func (c *instanceCache) describeAllInstancesUncached(ctx context.Context) (*allInstancesSnapshot, error) { now := time.Now() klog.V(4).Infof("EC2 DescribeInstances - fetching all instances") - var filters []*ec2.Filter - instances, err := c.cloud.describeInstances(filters) + var filters []ec2types.Filter + instances, err := c.cloud.describeInstances(ctx, filters) if err != nil { return nil, err } - m := make(map[InstanceID]*ec2.Instance) + m := make(map[InstanceID]*ec2types.Instance) for _, i := range instances { - id := InstanceID(aws.StringValue(i.InstanceId)) + id := InstanceID(aws.ToString(i.InstanceId)) m[id] = i } @@ -194,7 +196,7 @@ type cacheCriteria struct { } // describeAllInstancesCached returns all instances, using cached results if applicable -func (c *instanceCache) describeAllInstancesCached(criteria cacheCriteria) (*allInstancesSnapshot, error) { +func (c *instanceCache) describeAllInstancesCached(ctx context.Context, criteria cacheCriteria) (*allInstancesSnapshot, error) { c.mutex.Lock() defer c.mutex.Unlock() if c.snapshot != nil && c.snapshot.MeetsCriteria(criteria) { @@ -202,7 +204,7 @@ func (c *instanceCache) describeAllInstancesCached(criteria cacheCriteria) (*all return c.snapshot, nil } - return c.describeAllInstancesUncached() + return c.describeAllInstancesUncached(ctx) } // olderThan is a simple helper to encapsulate timestamp comparison @@ -238,12 +240,12 @@ func (s *allInstancesSnapshot) MeetsCriteria(criteria cacheCriteria) bool { // along with the timestamp for cache-invalidation purposes type allInstancesSnapshot struct { timestamp time.Time - instances map[InstanceID]*ec2.Instance + instances map[InstanceID]*ec2types.Instance } // FindInstances returns the instances corresponding to the specified ids. If an id is not found, it is ignored. -func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2.Instance { - m := make(map[InstanceID]*ec2.Instance) +func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2types.Instance { + m := make(map[InstanceID]*ec2types.Instance) for _, id := range ids { instance := s.instances[id] if instance != nil { diff --git a/pkg/providers/v1/instances_test.go b/pkg/providers/v1/instances_test.go index ac431c6cf6..9a9866946f 100644 --- a/pkg/providers/v1/instances_test.go +++ b/pkg/providers/v1/instances_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" ) @@ -150,8 +150,8 @@ func TestSnapshotMeetsCriteria(t *testing.T) { t.Errorf("Snapshot did not honor HasInstances with missing instances") } - snapshot.instances = make(map[InstanceID]*ec2.Instance) - snapshot.instances[InstanceID("i-12345678")] = &ec2.Instance{} + snapshot.instances = make(map[InstanceID]*ec2types.Instance) + snapshot.instances[InstanceID("i-12345678")] = &ec2types.Instance{} if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with matching instances") @@ -177,14 +177,14 @@ func TestOlderThan(t *testing.T) { func TestSnapshotFindInstances(t *testing.T) { snapshot := &allInstancesSnapshot{} - snapshot.instances = make(map[InstanceID]*ec2.Instance) + snapshot.instances = make(map[InstanceID]*ec2types.Instance) { id := InstanceID("i-12345678") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + snapshot.instances[id] = &ec2types.Instance{InstanceId: id.awsString()} } { id := InstanceID("i-23456789") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + snapshot.instances[id] = &ec2types.Instance{InstanceId: id.awsString()} } instances := snapshot.FindInstances([]InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789"), InstanceID("i-00000000")}) @@ -198,7 +198,7 @@ func TestSnapshotFindInstances(t *testing.T) { t.Errorf("findInstances did not return %s", id) continue } - if aws.StringValue(i.InstanceId) != string(id) { + if aws.ToString(i.InstanceId) != string(id) { t.Errorf("findInstances did not return expected instanceId for %s", id) } if i != snapshot.instances[id] { diff --git a/pkg/providers/v1/instances_v2.go b/pkg/providers/v1/instances_v2.go index 260da352f8..a045410fc2 100644 --- a/pkg/providers/v1/instances_v2.go +++ b/pkg/providers/v1/instances_v2.go @@ -72,7 +72,7 @@ func (c *Cloud) getAdditionalLabels(ctx context.Context, zoneName string, instan // If zone ID label is already set, skip. if _, ok := existingLabels[LabelZoneID]; !ok { // Add the zone ID to the additional labels - zoneID, err := c.zoneCache.getZoneIDByZoneName(zoneName) + zoneID, err := c.zoneCache.getZoneIDByZoneName(ctx, zoneName) if err != nil { return nil, err } diff --git a/pkg/providers/v1/instances_v2_test.go b/pkg/providers/v1/instances_v2_test.go index 945cb578d3..35e5df5abf 100644 --- a/pkg/providers/v1/instances_v2_test.go +++ b/pkg/providers/v1/instances_v2_test.go @@ -18,18 +18,17 @@ package aws import ( "context" + "errors" "fmt" "testing" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" v1 "k8s.io/api/core/v1" - "k8s.io/cloud-provider-aws/pkg/resourcemanagers" "k8s.io/cloud-provider-aws/pkg/services" ) @@ -61,7 +60,7 @@ func TestGetProviderId(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { instance := makeMinimalInstance(tc.instanceID) - c, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) + c, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) result, err := c.getProviderID(context.TODO(), &tc.node) if err != nil { @@ -79,7 +78,7 @@ func TestInstanceExists(t *testing.T) { for _, tc := range []struct { name string instanceExists bool - instanceState string + instanceState ec2types.InstanceStateName expectedExists bool }{ { @@ -91,13 +90,13 @@ func TestInstanceExists(t *testing.T) { { name: "Should return true when instance is found and running", instanceExists: true, - instanceState: ec2.InstanceStateNameRunning, + instanceState: ec2types.InstanceStateNameRunning, expectedExists: true, }, { name: "Should return false when instance is found but terminated", instanceExists: true, - instanceState: ec2.InstanceStateNameTerminated, + instanceState: ec2types.InstanceStateNameTerminated, expectedExists: false, }, } { @@ -124,25 +123,25 @@ func TestInstanceShutdown(t *testing.T) { for _, tc := range []struct { name string instanceExists bool - instanceState string + instanceState ec2types.InstanceStateName expectedShutdown bool }{ { name: "Should return false when instance is found and running", instanceExists: true, - instanceState: ec2.InstanceStateNameRunning, + instanceState: ec2types.InstanceStateNameRunning, expectedShutdown: false, }, { name: "Should return false when instance is found and terminated", instanceExists: true, - instanceState: ec2.InstanceStateNameTerminated, + instanceState: ec2types.InstanceStateNameTerminated, expectedShutdown: false, }, { name: "Should return true when instance is found and stopped", instanceExists: true, - instanceState: ec2.InstanceStateNameStopped, + instanceState: ec2types.InstanceStateNameStopped, expectedShutdown: true, }, } { @@ -168,16 +167,16 @@ func TestInstanceShutdown(t *testing.T) { func TestInstanceMetadata(t *testing.T) { t.Run("Should return populated InstanceMetadata", func(t *testing.T) { instance := makeInstance("i-00000000000000000", "192.168.0.1", "1.2.3.4", "instance-same.ec2.internal", "instance-same.ec2.external", nil, true) - c, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) - var mockedTopologyManager resourcemanagers.MockedInstanceTopologyManager + c, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) + var mockedTopologyManager MockedInstanceTopologyManager c.instanceTopologyManager = &mockedTopologyManager - mockedTopologyManager.On("GetNodeTopology", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&types.InstanceTopology{ - AvailabilityZone: awsv2.String("us-west-2b"), + mockedTopologyManager.On("GetNodeTopology", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&ec2types.InstanceTopology{ + AvailabilityZone: aws.String("us-west-2b"), GroupName: new(string), - InstanceId: awsv2.String("i-123456789"), + InstanceId: aws.String("i-123456789"), InstanceType: new(string), NetworkNodes: []string{"nn-123456789", "nn-234567890", "nn-345678901"}, - ZoneId: awsv2.String("az2"), + ZoneId: aws.String("az2"), }, nil) node := &v1.Node{ Spec: v1.NodeSpec{ @@ -212,8 +211,8 @@ func TestInstanceMetadata(t *testing.T) { t.Run("Should skip additional labels if already set", func(t *testing.T) { instance := makeInstance("i-00000000000000000", "192.168.0.1", "1.2.3.4", "instance-same.ec2.internal", "instance-same.ec2.external", nil, true) - c, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) - var mockedTopologyManager resourcemanagers.MockedInstanceTopologyManager + c, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) + var mockedTopologyManager MockedInstanceTopologyManager c.instanceTopologyManager = &mockedTopologyManager node := &v1.Node{ Spec: v1.NodeSpec{ @@ -240,8 +239,8 @@ func TestInstanceMetadata(t *testing.T) { t.Run("Should swallow errors if getting node topology fails if instance type not expected to be supported", func(t *testing.T) { instance := makeInstance("i-00000000000000000", "192.168.0.1", "1.2.3.4", "instance-same.ec2.internal", "instance-same.ec2.external", nil, true) - c, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) - var mockedTopologyManager resourcemanagers.MockedInstanceTopologyManager + c, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) + var mockedTopologyManager MockedInstanceTopologyManager c.instanceTopologyManager = &mockedTopologyManager mockedTopologyManager.On("GetNodeTopology", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("InvalidParameterValue", "Nope.")) @@ -265,8 +264,8 @@ func TestInstanceMetadata(t *testing.T) { t.Run("Should not swallow errors if getting node topology fails if instance type is expected to be supported", func(t *testing.T) { instance := makeInstance("i-00000000000000000", "192.168.0.1", "1.2.3.4", "instance-same.ec2.internal", "instance-same.ec2.external", nil, true) - c, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) - var mockedTopologyManager resourcemanagers.MockedInstanceTopologyManager + c, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) + var mockedTopologyManager MockedInstanceTopologyManager c.instanceTopologyManager = &mockedTopologyManager mockedTopologyManager.On("GetNodeTopology", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("InvalidParameterValue", "Nope.")) @@ -286,20 +285,20 @@ func TestInstanceMetadata(t *testing.T) { }) } -func getCloudWithMockedDescribeInstances(instanceExists bool, instanceState string) *Cloud { +func getCloudWithMockedDescribeInstances(instanceExists bool, instanceState ec2types.InstanceStateName) *Cloud { mockedEC2API := newMockedEC2API() c := &Cloud{ec2: &awsSdkEC2{ec2: mockedEC2API}} if !instanceExists { - mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil)) + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, errors.New("InvalidInstanceID.NotFound: Instance not found")) } else { mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ + Reservations: []ec2types.Reservation{ { - Instances: []*ec2.Instance{ + Instances: []ec2types.Instance{ { - State: &ec2.InstanceState{ - Name: aws.String(instanceState), + State: &ec2types.InstanceState{ + Name: instanceState, }, }, }, diff --git a/pkg/providers/v1/log_handler.go b/pkg/providers/v1/log_handler.go index bf0e45664a..ff15339ebe 100644 --- a/pkg/providers/v1/log_handler.go +++ b/pkg/providers/v1/log_handler.go @@ -17,32 +17,70 @@ limitations under the License. package aws import ( - "github.com/aws/aws-sdk-go/aws/request" + "context" + "fmt" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" "k8s.io/klog/v2" ) -// Handler for aws-sdk-go that logs all requests -func awsHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS request: %s %s", service, name) +// Middleware for AWS SDK Go V2 clients. Logs requests at the Finalize stage. +func awsHandlerLoggerMiddleware() middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/logger", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + service, name := awsServiceAndName(ctx) + + klog.V(4).Infof("AWS request: %s %s", service, name) + return next.HandleFinalize(ctx, in) + }, + ) } -func awsSendHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS API Send: %s %s %v %v", service, name, req.Operation, req.Params) +// Logs details about the response at the Deserialization stage +func awsValidateResponseHandlerLoggerMiddleware() middleware.DeserializeMiddleware { + return middleware.DeserializeMiddlewareFunc( + "k8s/api-validate-response", + func(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out, metadata, err = next.HandleDeserialize(ctx, in) + response, ok := out.RawResponse.(*http.Response) + if !ok { + return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)} + } + service, name := awsServiceAndName(ctx) + klog.V(4).Infof("AWS API ValidateResponse: %s %s %d", service, name, response.StatusCode) + return out, metadata, err + }, + ) } -func awsValidateResponseHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS API ValidateResponse: %s %s %v %v %s", service, name, req.Operation, req.Params, req.HTTPResponse.Status) +// Logs details about the request at the Serialize stage +func awsSendHandlerLoggerMiddleware() middleware.SerializeMiddleware { + return middleware.SerializeMiddlewareFunc( + "k8s/api-request", + func(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, + ) { + service, name := awsServiceAndName(ctx) + klog.V(4).Infof("AWS API Send: %s %s %v", service, name, in.Parameters) + return next.HandleSerialize(ctx, in) + }, + ) } -func awsServiceAndName(req *request.Request) (string, string) { - service := req.ClientInfo.ServiceName +// Gets the service and operation name from AWS SDK Go V2 client requests. +func awsServiceAndName(ctx context.Context) (string, string) { + service := middleware.GetServiceID(ctx) name := "?" - if req.Operation != nil { - name = req.Operation.Name + if opName := middleware.GetOperationName(ctx); opName != "" { + name = opName } return service, name } diff --git a/pkg/providers/v1/retry_handler.go b/pkg/providers/v1/retry_handler.go index 8023596dad..18643c2dd9 100644 --- a/pkg/providers/v1/retry_handler.go +++ b/pkg/providers/v1/retry_handler.go @@ -17,16 +17,27 @@ limitations under the License. package aws import ( + "context" + "errors" "math" + "strings" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" "k8s.io/klog/v2" ) +// nonRetryableError is the code for errors coming from API requests that should not be retried. This +// exists to replicate behavior from AWS SDK Go V1, where requests were marked as non-retryable +// in certain cases. +// In AWS SDK Go V2, an error with this error code is thrown in those same cases, and then +// caught during the IsErrorRetryable check by customRetryer. +var nonRetryableError = "non-retryable error" + const ( decayIntervalSeconds = 20 decayFraction = 0.8 @@ -47,60 +58,6 @@ func NewCrossRequestRetryDelay() *CrossRequestRetryDelay { return c } -// BeforeSign is added to the Sign chain; called before each request -func (c *CrossRequestRetryDelay) BeforeSign(r *request.Request) { - now := time.Now() - delay := c.backoff.ComputeDelayForRequest(now) - if delay > 0 { - klog.Warningf("Inserting delay before AWS request (%s) to avoid RequestLimitExceeded: %s", - describeRequest(r), delay.String()) - - if sleepFn := r.Config.SleepDelay; sleepFn != nil { - // Support SleepDelay for backwards compatibility - sleepFn(delay) - } else if err := aws.SleepWithContext(r.Context(), delay); err != nil { - r.Error = awserr.New(request.CanceledErrorCode, "request context canceled", err) - r.Retryable = aws.Bool(false) - return - } - - // Avoid clock skew problems - r.Time = now - } -} - -// Return the operation name, for use in log messages and metrics -func operationName(r *request.Request) string { - name := "?" - if r.Operation != nil { - name = r.Operation.Name - } - return name -} - -// Return a user-friendly string describing the request, for use in log messages -func describeRequest(r *request.Request) string { - service := r.ClientInfo.ServiceName - return service + "::" + operationName(r) -} - -// AfterRetry is added to the AfterRetry chain; called after any error -func (c *CrossRequestRetryDelay) AfterRetry(r *request.Request) { - if r.Error == nil { - return - } - awsError, ok := r.Error.(awserr.Error) - if !ok { - return - } - if awsError.Code() == "RequestLimitExceeded" { - c.backoff.ReportError() - recordAWSThrottlesMetric(operationName(r)) - klog.Warningf("Got RequestLimitExceeded error on AWS request (%s)", - describeRequest(r)) - } -} - // Backoff manages a backoff that varies based on the recently observed failures type Backoff struct { decayIntervalSeconds int64 @@ -170,6 +127,104 @@ func (b *Backoff) ComputeDelayForRequest(now time.Time) time.Duration { func (b *Backoff) ReportError() { b.mutex.Lock() defer b.mutex.Unlock() - b.countErrorsRequestLimit += 1.0 } + +// Standard retry implementation, except that it doesn't retry NON_RETRYABLE_ERROR errors. +// This works in tandem with (l *delayPrerequest) HandleFinalize, which will throw the error +// in certain cases as part of the middleware. +type customRetryer struct { + aws.Retryer +} + +func (r customRetryer) IsErrorRetryable(err error) bool { + if strings.Contains(err.Error(), nonRetryableError) { + return false + } + return r.Retryer.IsErrorRetryable(err) +} + +// Middleware for AWS SDK Go V2 clients +// Throws nonRetryableError if the request context was canceled, to preserve behavior from AWS +// SDK Go V1, where requests were marked as non-retryable under the same conditions. +// This works in tandem with customRetryer, which will not retry nonRetryableErrors. +func delayPreSign(delayer *CrossRequestRetryDelay) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/delay-presign", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + now := time.Now() + delay := delayer.backoff.ComputeDelayForRequest(now) + + if delay > 0 { + klog.Warningf("Inserting delay before AWS request (%s) to avoid RequestLimitExceeded: %s", + describeRequest(ctx), delay.String()) + + if err := sleepWithContext(ctx, delay); err != nil { + return middleware.FinalizeOutput{}, middleware.Metadata{}, errors.New(nonRetryableError) + } + } + + service, name := awsServiceAndName(ctx) + request, ok := in.Request.(*http.Request) + if ok { + klog.V(4).Infof("AWS API Send: %s %s %s %s", service, name, request.Request.Method, request.Request.URL.Path) + } + return next.HandleFinalize(ctx, in) + }, + ) +} + +func delayAfterRetry(delayer *CrossRequestRetryDelay) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/delay-afterretry", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + finOutput, finMetadata, finErr := next.HandleFinalize(ctx, in) + if finErr == nil { + return finOutput, finMetadata, finErr + } + + var ae smithy.APIError + if errors.As(finErr, &ae) && strings.Contains(ae.Error(), "RequestLimitExceeded") { + delayer.backoff.ReportError() + recordAWSThrottlesMetric(operationName(ctx)) + klog.Warningf("Got RequestLimitExceeded error on AWS request (%s)", + describeRequest(ctx)) + } + return finOutput, finMetadata, finErr + }, + ) +} + +// Return the operation name, for use in log messages and metrics +func operationName(ctx context.Context) string { + name := "?" + if opName := middleware.GetOperationName(ctx); opName != "" { + name = opName + } + return name +} + +// Return a user-friendly string describing the request, for use in log messages. +func describeRequest(ctx context.Context) string { + service := middleware.GetServiceID(ctx) + + return service + "::" + operationName(ctx) +} + +func sleepWithContext(ctx context.Context, dur time.Duration) error { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + break + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} diff --git a/pkg/providers/v1/retry_handler_test.go b/pkg/providers/v1/retry_handler_test.go index 27b18c6005..079ed1a825 100644 --- a/pkg/providers/v1/retry_handler_test.go +++ b/pkg/providers/v1/retry_handler_test.go @@ -17,8 +17,20 @@ limitations under the License. package aws import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" "time" + + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "k8s.io/cloud-provider-aws/pkg/providers/v1/config" ) // There follows a group of tests for the backoff logic. There's nothing @@ -133,3 +145,205 @@ func TestBackoffRecovers(t *testing.T) { now = now.Add(time.Second) } } + +// Make sure that nonRetryableErrors, which are thrown by AWS SDK Go V2 clients +// when the request context is canceled, are not retried with customRetryer is used. +func TestNonRetryableError(t *testing.T) { + mockedEC2API := newMockedEC2API() + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, errors.New(nonRetryableError)) + + ec2Client := &awsSdkEC2{ + ec2: mockedEC2API, + } + _, err := ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + + // Verify that the custom retryer can recognize when a nonRetryableError is thrown + retryer := &customRetryer{ + retry.NewStandard(), + } + if retryer.IsErrorRetryable(err) { + t.Errorf("Expected nonRetryableError error to be non-retryable") + } +} + +// Tests delayPresign to ensure that it delays the request +func TestDelayPresign(t *testing.T) { + // This test forces certain results from ComputeDelayForRequest() and sleepWithContext() + // to trigger a delay from delayPresign(). + // Dummy server to make sure the client request doesn't actually hit the API. + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + // Create a dummy delayer that sets a delay of 1 second for ComputeDelayForRequest() + delayer := NewCrossRequestRetryDelay() + delayer.backoff.countRequests = 1 + delayer.backoff.countErrorsRequestLimit = 20000 + delayer.backoff.maxDelay = 100000 + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + startTime := time.Now() + _, _ = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + endTime := time.Now() + diff := endTime.Sub(startTime).Seconds() + assert.True(t, diff > 1, fmt.Sprintf("expected a delay of at least 1 second, got %f", diff)) +} + +// Tests that delayAfterRetry() recognizes RequestLimitExceeded errors and counts them towards the backoff +func TestDelayAfterRetry(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert the RequestLimitExceeded error message + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, "RequestLimitExceeded") + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + delayer := NewCrossRequestRetryDelay() + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + preDelayErrorCount := delayer.backoff.countErrorsRequestLimit + _, err = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + postDelayErrorCount := delayer.backoff.countErrorsRequestLimit + + // Verify that a RequestLimitExceeded error was thrown + assert.Error(t, err) + assert.Contains(t, err.Error(), "RequestLimitExceeded") + + // In the event that delayAfterRetry() catches a RequestLimitExceeded error, it will + // update the error count in the delayer. This count is used to verify that this case + // was entered. + diff := (int)(postDelayErrorCount - preDelayErrorCount) + assert.True(t, diff == 1, fmt.Sprintf("expected an update to the backoff count of %d, got %d", 1, diff)) +} + +// Tests that delayAfterRetry() does not update the backoff in case of an error other than RequestLimitExceeded +func TestDelayAfterRetryNoDelay(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert a dummy error message that's not RequestLimitExceeded + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, "DummyError") + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + cfgWithServiceOverride := config.CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + delayer := NewCrossRequestRetryDelay() + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + preDelayErrorCount := delayer.backoff.countErrorsRequestLimit + _, err = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + postDelayErrorCount := delayer.backoff.countErrorsRequestLimit + + // Verify that a RequestLimitExceeded error wasn't thrown + assert.Error(t, err) + assert.NotContains(t, err.Error(), "RequestLimitExceeded") + + // In the event that delayAfterRetry() catches a RequestLimitExceeded error, it will + // update the error count in the delayer. This count is used to verify that this case + // was not entered. + diff := (int)(postDelayErrorCount - preDelayErrorCount) + assert.True(t, diff == 0, fmt.Sprintf("expected an update to the backoff count of %d, got %d", 0, diff)) +} diff --git a/pkg/providers/v1/sets_ippermissions.go b/pkg/providers/v1/sets_ippermissions.go index a304deedd5..72e99ec2ea 100644 --- a/pkg/providers/v1/sets_ippermissions.go +++ b/pkg/providers/v1/sets_ippermissions.go @@ -20,21 +20,21 @@ import ( "encoding/json" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) // IPPermissionSet maps IP strings of strings to EC2 IpPermissions -type IPPermissionSet map[string]*ec2.IpPermission +type IPPermissionSet map[string]ec2types.IpPermission // IPPermissionPredicate is an predicate to test whether IPPermission matches some condition. type IPPermissionPredicate interface { // Test checks whether specified IPPermission matches condition. - Test(perm *ec2.IpPermission) bool + Test(perm ec2types.IpPermission) bool } // NewIPPermissionSet creates a new IPPermissionSet -func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { +func NewIPPermissionSet(items ...ec2types.IpPermission) IPPermissionSet { s := make(IPPermissionSet) s.Insert(items...) return s @@ -44,44 +44,44 @@ func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { // EC2 will combine permissions with the same port but different SourceRanges together, for example // We ungroup them so we can process them func (s IPPermissionSet) Ungroup() IPPermissionSet { - l := []*ec2.IpPermission{} + l := []ec2types.IpPermission{} for _, p := range s.List() { if len(p.IpRanges) <= 1 { l = append(l, p) continue } for _, ipRange := range p.IpRanges { - c := &ec2.IpPermission{} - *c = *p - c.IpRanges = []*ec2.IpRange{ipRange} + c := ec2types.IpPermission{} + c = p + c.IpRanges = []ec2types.IpRange{ipRange} l = append(l, c) } } - l2 := []*ec2.IpPermission{} + l2 := []ec2types.IpPermission{} for _, p := range l { if len(p.UserIdGroupPairs) <= 1 { l2 = append(l2, p) continue } for _, u := range p.UserIdGroupPairs { - c := &ec2.IpPermission{} - *c = *p - c.UserIdGroupPairs = []*ec2.UserIdGroupPair{u} + c := ec2types.IpPermission{} + c = p + c.UserIdGroupPairs = []ec2types.UserIdGroupPair{u} l2 = append(l2, c) } } - l3 := []*ec2.IpPermission{} + l3 := []ec2types.IpPermission{} for _, p := range l2 { if len(p.PrefixListIds) <= 1 { l3 = append(l3, p) continue } for _, v := range p.PrefixListIds { - c := &ec2.IpPermission{} - *c = *p - c.PrefixListIds = []*ec2.PrefixListId{v} + c := ec2types.IpPermission{} + c = p + c.PrefixListIds = []ec2types.PrefixListId{v} l3 = append(l3, c) } } @@ -90,7 +90,7 @@ func (s IPPermissionSet) Ungroup() IPPermissionSet { } // Insert adds items to the set. -func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { +func (s IPPermissionSet) Insert(items ...ec2types.IpPermission) { for _, p := range items { k := keyForIPPermission(p) s[k] = p @@ -98,7 +98,7 @@ func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { } // Delete delete permission from the set. -func (s IPPermissionSet) Delete(items ...*ec2.IpPermission) { +func (s IPPermissionSet) Delete(items ...ec2types.IpPermission) { for _, p := range items { k := keyForIPPermission(p) delete(s, k) @@ -115,8 +115,8 @@ func (s IPPermissionSet) DeleteIf(predicate IPPermissionPredicate) { } // List returns the contents as a slice. Order is not defined. -func (s IPPermissionSet) List() []*ec2.IpPermission { - res := make([]*ec2.IpPermission, 0, len(s)) +func (s IPPermissionSet) List() []ec2types.IpPermission { + res := make([]ec2types.IpPermission, 0, len(s)) for _, v := range s { res = append(res, v) } @@ -163,7 +163,7 @@ func (s IPPermissionSet) Len() int { return len(s) } -func keyForIPPermission(p *ec2.IpPermission) string { +func keyForIPPermission(p ec2types.IpPermission) string { v, err := json.Marshal(p) if err != nil { panic(fmt.Sprintf("error building JSON representation of ec2.IpPermission: %v", err)) @@ -179,24 +179,24 @@ type IPPermissionMatchDesc struct { } // Test whether specific IPPermission contains description. -func (p IPPermissionMatchDesc) Test(perm *ec2.IpPermission) bool { +func (p IPPermissionMatchDesc) Test(perm ec2types.IpPermission) bool { for _, v4Range := range perm.IpRanges { - if aws.StringValue(v4Range.Description) == p.Description { + if aws.ToString(v4Range.Description) == p.Description { return true } } for _, v6Range := range perm.Ipv6Ranges { - if aws.StringValue(v6Range.Description) == p.Description { + if aws.ToString(v6Range.Description) == p.Description { return true } } for _, prefixListID := range perm.PrefixListIds { - if aws.StringValue(prefixListID.Description) == p.Description { + if aws.ToString(prefixListID.Description) == p.Description { return true } } for _, group := range perm.UserIdGroupPairs { - if aws.StringValue(group.Description) == p.Description { + if aws.ToString(group.Description) == p.Description { return true } } @@ -211,6 +211,6 @@ type IPPermissionNotMatch struct { } // Test whether specific IPPermission not match the embed predicate. -func (p IPPermissionNotMatch) Test(perm *ec2.IpPermission) bool { +func (p IPPermissionNotMatch) Test(perm ec2types.IpPermission) bool { return !p.Predicate.Test(perm) } diff --git a/pkg/providers/v1/sets_ippermissions_test.go b/pkg/providers/v1/sets_ippermissions_test.go index 0680b29b1e..4e4f3a54fd 100644 --- a/pkg/providers/v1/sets_ippermissions_test.go +++ b/pkg/providers/v1/sets_ippermissions_test.go @@ -3,8 +3,8 @@ package aws import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) func TestUngroup(t *testing.T) { @@ -17,67 +17,67 @@ func TestUngroup(t *testing.T) { { "Single IP range in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, ), }, { "Three ip ranges in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/16")}, {CidrIp: aws.String("10.1.0.0/16")}, {CidrIp: aws.String("10.2.0.0/16")}, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.1.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.1.0.0/16")}}, + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.2.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.2.0.0/16")}}, + ToPort: aws.Int32(2), }, ), }, { "Three UserIdGroupPairs in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/16")}, }, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("1"), GroupName: aws.String("group-1"), @@ -97,15 +97,15 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("1"), GroupName: aws.String("group-1"), @@ -113,13 +113,13 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("2"), GroupName: aws.String("group-2"), @@ -127,13 +127,13 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("3"), GroupName: aws.String("group-3"), @@ -141,7 +141,7 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), }, diff --git a/pkg/providers/v1/tags.go b/pkg/providers/v1/tags.go index 8c5ee23153..be5c18126d 100644 --- a/pkg/providers/v1/tags.go +++ b/pkg/providers/v1/tags.go @@ -17,11 +17,13 @@ limitations under the License. package aws import ( + "context" "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "k8s.io/klog/v2" "k8s.io/apimachinery/pkg/util/wait" @@ -87,7 +89,7 @@ func (t *awsTagging) init(legacyClusterID string, clusterID string) error { // Extracts a clusterID from the given tags, if one is present // If no clusterID is found, returns "", nil // If multiple (different) clusterIDs are found, returns an error -func (t *awsTagging) initFromTags(tags []*ec2.Tag) error { +func (t *awsTagging) initFromTags(tags []ec2types.Tag) error { legacyClusterID, newClusterID, err := findClusterIDs(tags) if err != nil { return err @@ -102,12 +104,12 @@ func (t *awsTagging) initFromTags(tags []*ec2.Tag) error { // Extracts the legacy & new cluster ids from the given tags, if they are present // If duplicate tags are found, returns an error -func findClusterIDs(tags []*ec2.Tag) (string, string, error) { +func findClusterIDs(tags []ec2types.Tag) (string, string, error) { legacyClusterID := "" newClusterID := "" for _, tag := range tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) if strings.HasPrefix(tagKey, TagNameKubernetesClusterPrefix) { id := strings.TrimPrefix(tagKey, TagNameKubernetesClusterPrefix) if newClusterID != "" { @@ -117,7 +119,7 @@ func findClusterIDs(tags []*ec2.Tag) (string, string, error) { } if tagKey == TagNameKubernetesClusterLegacy { - id := aws.StringValue(tag.Value) + id := aws.ToString(tag.Value) if legacyClusterID != "" { return "", "", fmt.Errorf("Found multiple %s tags (%q and %q)", TagNameKubernetesClusterLegacy, legacyClusterID, id) } @@ -132,17 +134,17 @@ func (t *awsTagging) clusterTagKey() string { return TagNameKubernetesClusterPrefix + t.ClusterID } -func (t *awsTagging) hasClusterTag(tags []*ec2.Tag) bool { +func (t *awsTagging) hasClusterTag(tags []ec2types.Tag) bool { // if the clusterID is not configured -- we consider all instances. if len(t.ClusterID) == 0 { return true } clusterTagKey := t.clusterTagKey() for _, tag := range tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) // For 1.6, we continue to recognize the legacy tags, for the 1.5 -> 1.6 upgrade // Note that we want to continue traversing tag list if we see a legacy tag with value != ClusterID - if (tagKey == TagNameKubernetesClusterLegacy) && (aws.StringValue(tag.Value) == t.ClusterID) { + if (tagKey == TagNameKubernetesClusterLegacy) && (aws.ToString(tag.Value) == t.ClusterID) { return true } if tagKey == clusterTagKey { @@ -152,9 +154,9 @@ func (t *awsTagging) hasClusterTag(tags []*ec2.Tag) bool { return false } -func (t *awsTagging) hasNoClusterPrefixTag(tags []*ec2.Tag) bool { +func (t *awsTagging) hasNoClusterPrefixTag(tags []ec2types.Tag) bool { for _, tag := range tags { - if strings.HasPrefix(aws.StringValue(tag.Key), TagNameKubernetesClusterPrefix) { + if strings.HasPrefix(aws.ToString(tag.Key), TagNameKubernetesClusterPrefix) { return false } } @@ -164,10 +166,10 @@ func (t *awsTagging) hasNoClusterPrefixTag(tags []*ec2.Tag) bool { // Ensure that a resource has the correct tags // If it has no tags, we assume that this was a problem caused by an error in between creation and tagging, // and we add the tags. If it has a different cluster's tags, that is an error. -func (t *awsTagging) readRepairClusterTags(client iface.EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string, observedTags []*ec2.Tag) error { +func (t *awsTagging) readRepairClusterTags(ctx context.Context, client iface.EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string, observedTags []ec2types.Tag) error { actualTagMap := make(map[string]string) for _, tag := range observedTags { - actualTagMap[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + actualTagMap[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } expectedTags := t.buildTags(lifecycle, additionalTags) @@ -190,7 +192,7 @@ func (t *awsTagging) readRepairClusterTags(client iface.EC2, resourceID string, return nil } - if err := t.createTags(client, resourceID, lifecycle, addTags); err != nil { + if err := t.createTags(ctx, client, resourceID, lifecycle, addTags); err != nil { return fmt.Errorf("error adding missing tags to resource %q: %q", resourceID, err) } @@ -200,16 +202,16 @@ func (t *awsTagging) readRepairClusterTags(client iface.EC2, resourceID string, // createTags calls EC2 CreateTags, but adds retry-on-failure logic // We retry mainly because if we create an object, we cannot tag it until it is "fully created" (eventual consistency) // The error code varies though (depending on what we are tagging), so we simply retry on all errors -func (t *awsTagging) createTags(client iface.EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string) error { +func (t *awsTagging) createTags(ctx context.Context, client iface.EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string) error { tags := t.buildTags(lifecycle, additionalTags) if tags == nil || len(tags) == 0 { return nil } - var awsTags []*ec2.Tag + var awsTags []ec2types.Tag for k, v := range tags { - tag := &ec2.Tag{ + tag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } @@ -222,12 +224,12 @@ func (t *awsTagging) createTags(client iface.EC2, resourceID string, lifecycle R Steps: createTagSteps, } request := &ec2.CreateTagsInput{} - request.Resources = []*string{&resourceID} + request.Resources = []string{resourceID} request.Tags = awsTags var lastErr error err := wait.ExponentialBackoff(backoff, func() (bool, error) { - _, err := client.CreateTags(request) + _, err := client.CreateTags(ctx, request) if err == nil { return true, nil } @@ -247,7 +249,7 @@ func (t *awsTagging) createTags(client iface.EC2, resourceID string, lifecycle R // Add additional filters, to match on our tags // This lets us run multiple k8s clusters in a single EC2 AZ -func (t *awsTagging) addFilters(filters []*ec2.Filter) []*ec2.Filter { +func (t *awsTagging) addFilters(filters []*ec2types.Filter) []*ec2types.Filter { // if there are no clusterID configured - no filtering by special tag names // should be applied to revert to legacy behaviour. if len(t.ClusterID) == 0 { @@ -260,7 +262,7 @@ func (t *awsTagging) addFilters(filters []*ec2.Filter) []*ec2.Filter { } f := newEc2Filter("tag-key", t.clusterTagKey()) - filters = append(filters, f) + filters = append(filters, &f) return filters } @@ -268,7 +270,7 @@ func (t *awsTagging) addFilters(filters []*ec2.Filter) []*ec2.Filter { // 1.5 -> 1.6 clusters and exists for backwards compatibility // // This lets us run multiple k8s clusters in a single EC2 AZ -func (t *awsTagging) addLegacyFilters(filters []*ec2.Filter) []*ec2.Filter { +func (t *awsTagging) addLegacyFilters(filters []*ec2types.Filter) []*ec2types.Filter { // if there are no clusterID configured - no filtering by special tag names // should be applied to revert to legacy behaviour. if len(t.ClusterID) == 0 { @@ -284,7 +286,7 @@ func (t *awsTagging) addLegacyFilters(filters []*ec2.Filter) []*ec2.Filter { // We can't pass a zero-length Filters to AWS (it's an error) // So if we end up with no filters; we need to return nil - filters = append(filters, f) + filters = append(filters, &f) return filters } @@ -315,13 +317,13 @@ func (t *awsTagging) clusterID() string { // TagResource calls EC2 and tag the resource associated to resourceID // with the supplied tags -func (c *Cloud) TagResource(resourceID string, tags map[string]string) error { +func (c *Cloud) TagResource(ctx context.Context, resourceID string, tags map[string]string) error { request := &ec2.CreateTagsInput{ - Resources: []*string{aws.String(resourceID)}, + Resources: []string{resourceID}, Tags: buildAwsTags(tags), } - output, err := c.ec2.CreateTags(request) + output, err := c.ec2.CreateTags(ctx, request) if err != nil { klog.Errorf("Error occurred trying to tag resources, %v", err) @@ -335,13 +337,13 @@ func (c *Cloud) TagResource(resourceID string, tags map[string]string) error { // UntagResource calls EC2 and tag the resource associated to resourceID // with the supplied tags -func (c *Cloud) UntagResource(resourceID string, tags map[string]string) error { +func (c *Cloud) UntagResource(ctx context.Context, resourceID string, tags map[string]string) error { request := &ec2.DeleteTagsInput{ - Resources: []*string{aws.String(resourceID)}, + Resources: []string{resourceID}, Tags: buildAwsTags(tags), } - output, err := c.ec2.DeleteTags(request) + output, err := c.ec2.DeleteTags(ctx, request) if err != nil { // An instance not found should not fail the untagging workflow as it @@ -359,10 +361,10 @@ func (c *Cloud) UntagResource(resourceID string, tags map[string]string) error { return nil } -func buildAwsTags(tags map[string]string) []*ec2.Tag { - var awsTags []*ec2.Tag +func buildAwsTags(tags map[string]string) []ec2types.Tag { + var awsTags []ec2types.Tag for k, v := range tags { - newTag := &ec2.Tag{ + newTag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } diff --git a/pkg/providers/v1/tags_test.go b/pkg/providers/v1/tags_test.go index d7b8dad58a..79a4ad58fc 100644 --- a/pkg/providers/v1/tags_test.go +++ b/pkg/providers/v1/tags_test.go @@ -18,14 +18,15 @@ package aws import ( "bytes" + "context" "errors" "flag" "os" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/stretchr/testify/assert" "k8s.io/klog/v2" @@ -96,9 +97,9 @@ func TestFindClusterID(t *testing.T) { }, } for _, g := range grid { - var ec2Tags []*ec2.Tag + var ec2Tags []ec2types.Tag for k, v := range g.Tags { - ec2Tags = append(ec2Tags, &ec2.Tag{Key: aws.String(k), Value: aws.String(v)}) + ec2Tags = append(ec2Tags, ec2types.Tag{Key: aws.String(k), Value: aws.String(v)}) } actualLegacy, actualNew, err := findClusterIDs(ec2Tags) if g.ExpectError { @@ -179,9 +180,9 @@ func TestHasClusterTag(t *testing.T) { }, } for _, g := range grid { - var ec2Tags []*ec2.Tag + var ec2Tags []ec2types.Tag for k, v := range g.Tags { - ec2Tags = append(ec2Tags, &ec2.Tag{Key: aws.String(k), Value: aws.String(v)}) + ec2Tags = append(ec2Tags, ec2types.Tag{Key: aws.String(k), Value: aws.String(v)}) } result := c.tagging.hasClusterTag(ec2Tags) if result != g.Expected { @@ -199,7 +200,7 @@ func TestHasNoClusterPrefixTag(t *testing.T) { } tests := []struct { name string - tags []*ec2.Tag + tags []ec2types.Tag want bool }{ { @@ -208,7 +209,7 @@ func TestHasNoClusterPrefixTag(t *testing.T) { }, { name: "no cluster tags", - tags: []*ec2.Tag{ + tags: []ec2types.Tag{ { Key: aws.String("not a cluster tag"), Value: aws.String("true"), @@ -218,7 +219,7 @@ func TestHasNoClusterPrefixTag(t *testing.T) { }, { name: "contains cluster tags", - tags: []*ec2.Tag{ + tags: []ec2types.Tag{ { Key: aws.String("tag1"), Value: aws.String("value1"), @@ -271,7 +272,7 @@ func TestTagResource(t *testing.T) { { name: "tagging failed due to resource not found error", instanceID: "i-not-found", - err: awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil), + err: errors.New("InvalidInstanceID.NotFound: Instance not found"), expectedMessage: "Error occurred trying to tag resources", }, } @@ -284,7 +285,7 @@ func TestTagResource(t *testing.T) { klog.SetOutput(os.Stderr) }() - err := c.TagResource(tt.instanceID, nil) + err := c.TagResource(context.TODO(), tt.instanceID, nil) assert.Equal(t, tt.err, err) assert.Contains(t, logBuf.String(), tt.expectedMessage) }) @@ -336,7 +337,7 @@ func TestUntagResource(t *testing.T) { klog.SetOutput(os.Stderr) }() - err := c.UntagResource(tt.instanceID, nil) + err := c.UntagResource(context.TODO(), tt.instanceID, nil) assert.Equal(t, tt.err, err) assert.Contains(t, logBuf.String(), tt.expectedMessage) }) diff --git a/pkg/resourcemanagers/topology.go b/pkg/providers/v1/topology.go similarity index 96% rename from pkg/resourcemanagers/topology.go rename to pkg/providers/v1/topology.go index 7cb0e5427c..989785d126 100644 --- a/pkg/resourcemanagers/topology.go +++ b/pkg/providers/v1/topology.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package resourcemanagers +package aws import ( "context" @@ -28,7 +28,7 @@ import ( "github.com/aws/smithy-go" "k8s.io/client-go/tools/cache" "k8s.io/cloud-provider-aws/pkg/providers/v1/config" - "k8s.io/cloud-provider-aws/pkg/services" + "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" "k8s.io/klog/v2" ) @@ -66,13 +66,13 @@ type InstanceTopologyManager interface { // instanceTopologyManager manages getting instance topology for nodes. type instanceTopologyManager struct { - ec2 services.Ec2SdkV2 + ec2 iface.EC2 unsupportedKeyStore cache.Store supportedTopologyInstanceTypePattern *regexp.Regexp } // NewInstanceTopologyManager generates a new InstanceTopologyManager. -func NewInstanceTopologyManager(ec2 services.Ec2SdkV2, cfg *config.CloudConfig) InstanceTopologyManager { +func NewInstanceTopologyManager(ec2 iface.EC2, cfg *config.CloudConfig) InstanceTopologyManager { var supportedTopologyInstanceTypePattern *regexp.Regexp if cfg.Global.SupportedTopologyInstanceTypePattern != "" { supportedTopologyInstanceTypePattern = regexp.MustCompile(cfg.Global.SupportedTopologyInstanceTypePattern) diff --git a/pkg/resourcemanagers/topology_mock.go b/pkg/providers/v1/topology_mock.go similarity index 98% rename from pkg/resourcemanagers/topology_mock.go rename to pkg/providers/v1/topology_mock.go index 5c10f38573..8fc06d4c24 100644 --- a/pkg/resourcemanagers/topology_mock.go +++ b/pkg/providers/v1/topology_mock.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package resourcemanagers +package aws import ( "context" diff --git a/pkg/resourcemanagers/topology_test.go b/pkg/providers/v1/topology_test.go similarity index 74% rename from pkg/resourcemanagers/topology_test.go rename to pkg/providers/v1/topology_test.go index 4469d74eba..47e0b0a3fe 100644 --- a/pkg/resourcemanagers/topology_test.go +++ b/pkg/providers/v1/topology_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package resourcemanagers +package aws import ( "context" @@ -69,8 +69,8 @@ func TestDoesInstanceTypeRequireResponse(t *testing.T) { func TestGetNodeTopology(t *testing.T) { t.Run("Should skip nodes that don't have instance type set", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) // Loop multiple times to check cache use topology, err := topologyManager.GetNodeTopology(context.TODO(), "" /* empty instance type */, "some-region", "some-id") if err != nil { @@ -80,14 +80,14 @@ func TestGetNodeTopology(t *testing.T) { t.Errorf("Should not be returning a topology: %v", topology) } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 0) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 0) }) t.Run("Should handle unsupported regions and utilize cache", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("UnsupportedOperation", "Not supported in region")) // Loop multiple times to check cache use @@ -101,14 +101,14 @@ func TestGetNodeTopology(t *testing.T) { } } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) }) t.Run("Should handle unsupported instance types and utilize cache", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return([]types.InstanceTopology{}, nil) + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return([]types.InstanceTopology{}, nil) // Loop multiple times to check cache use for i := 0; i < 2; i++ { @@ -121,14 +121,14 @@ func TestGetNodeTopology(t *testing.T) { } } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) }) t.Run("Should handle unsupported instance IDs and utilize cache", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return([]types.InstanceTopology{}, nil) + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return([]types.InstanceTopology{}, nil) // Loop multiple times to check cache use for i := 0; i < 2; i++ { @@ -142,14 +142,14 @@ func TestGetNodeTopology(t *testing.T) { } } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) }) t.Run("Should handle missing permissions to call DescribeInstanceTopology", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("UnauthorizedOperation", "Update your perms")) // Loop multiple times to check cache use @@ -163,14 +163,14 @@ func TestGetNodeTopology(t *testing.T) { } } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) }) t.Run("Should return error when exceeding request limits for DescribeInstanceTopology", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("RequestLimitExceeded", "Slow down!")) // Loop multiple times to check cache use @@ -181,14 +181,14 @@ func TestGetNodeTopology(t *testing.T) { } } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 2) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 2) }) t.Run("Should return unhandled errors", func(t *testing.T) { - mockedEc2SdkV2 := services.MockedEc2SdkV2{} - topologyManager := NewInstanceTopologyManager(&mockedEc2SdkV2, &config.CloudConfig{}) + mockedEC2 := MockedFakeEC2{} + topologyManager := NewInstanceTopologyManager(&mockedEC2, &config.CloudConfig{}) - mockedEc2SdkV2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, + mockedEC2.On("DescribeInstanceTopology", mock.Anything, mock.Anything).Return(nil, services.NewMockAPIError("NOPE", "Nice try.")) _, err := topologyManager.GetNodeTopology(context.TODO(), "some-type", "some-region", "some-id") @@ -196,6 +196,6 @@ func TestGetNodeTopology(t *testing.T) { t.Errorf("Should have gotten an error") } - mockedEc2SdkV2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) + mockedEC2.AssertNumberOfCalls(t, "DescribeInstanceTopology", 1) }) } diff --git a/pkg/providers/v1/variant/fargate/fargate.go b/pkg/providers/v1/variant/fargate/fargate.go index f4d7174603..93866bc7e2 100644 --- a/pkg/providers/v1/variant/fargate/fargate.go +++ b/pkg/providers/v1/variant/fargate/fargate.go @@ -1,12 +1,13 @@ package fargate import ( + "context" "fmt" "strings" - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" v1 "k8s.io/api/core/v1" cloudprovider "k8s.io/cloud-provider" @@ -25,11 +26,11 @@ const ( type fargateVariant struct { cloudConfig *config.CloudConfig ec2API iface.EC2 - credentials *credentials.Credentials + credentials aws.CredentialsProvider provider config.SDKProvider } -func (f *fargateVariant) Initialize(cloudConfig *config.CloudConfig, credentials *credentials.Credentials, provider config.SDKProvider, ec2API iface.EC2, region string) error { +func (f *fargateVariant) Initialize(cloudConfig *config.CloudConfig, credentials aws.CredentialsProvider, provider config.SDKProvider, ec2API iface.EC2, region string) error { f.cloudConfig = cloudConfig f.ec2API = ec2API f.credentials = credentials @@ -41,8 +42,8 @@ func (f *fargateVariant) InstanceTypeByProviderID(instanceID string) (string, er return "", nil } -func (f *fargateVariant) GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) GetZone(ctx context.Context, instanceID, vpcID, region string) (cloudprovider.Zone, error) { + eni, err := f.DescribeNetworkInterfaces(ctx, f.ec2API, instanceID, vpcID) if eni == nil || err != nil { return cloudprovider.Zone{}, err } @@ -56,8 +57,8 @@ func (f *fargateVariant) IsSupportedNode(nodeName string) bool { return strings.HasPrefix(nodeName, fargateNodeNamePrefix) } -func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) NodeAddresses(ctx context.Context, instanceID, vpcID string) ([]v1.NodeAddress, error) { + eni, err := f.DescribeNetworkInterfaces(ctx, f.ec2API, instanceID, vpcID) if eni == nil || err != nil { return nil, err } @@ -68,7 +69,7 @@ func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddre for _, family := range f.cloudConfig.Global.NodeIPFamilies { switch family { case "ipv4": - nodeAddresses := getNodeAddressesForFargateNode(awssdk.StringValue(eni.PrivateDnsName), awssdk.StringValue(eni.PrivateIpAddress)) + nodeAddresses := getNodeAddressesForFargateNode(aws.ToString(eni.PrivateDnsName), aws.ToString(eni.PrivateIpAddress)) addresses = append(addresses, nodeAddresses...) case "ipv6": if eni.Ipv6Addresses == nil || len(eni.Ipv6Addresses) == 0 { @@ -76,29 +77,29 @@ func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddre continue } internalIPv6Address := eni.Ipv6Addresses[0].Ipv6Address - nodeAddresses := getNodeAddressesForFargateNode(awssdk.StringValue(eni.PrivateDnsName), awssdk.StringValue(internalIPv6Address)) + nodeAddresses := getNodeAddressesForFargateNode(aws.ToString(eni.PrivateDnsName), aws.ToString(internalIPv6Address)) addresses = append(addresses, nodeAddresses...) } } return addresses, nil } -func (f *fargateVariant) InstanceExists(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceExists(ctx context.Context, instanceID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(ctx, f.ec2API, instanceID, vpcID) return eni != nil, err } -func (f *fargateVariant) InstanceShutdown(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceShutdown(ctx context.Context, instanceID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(ctx, f.ec2API, instanceID, vpcID) return eni != nil, err } -func newEc2Filter(name string, values ...string) *ec2.Filter { - filter := &ec2.Filter{ - Name: awssdk.String(name), +func newEc2Filter(name string, values ...string) ec2types.Filter { + filter := ec2types.Filter{ + Name: aws.String(name), } for _, value := range values { - filter.Values = append(filter.Values, awssdk.String(value)) + filter.Values = append(filter.Values, value) } return filter } @@ -116,10 +117,10 @@ func nodeNameToIPAddress(nodeName string) string { } // DescribeNetworkInterfaces returns network interface information for the given DNS name. -func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, instanceID, vpcID string) (*ec2.NetworkInterface, error) { +func (f *fargateVariant) DescribeNetworkInterfaces(ctx context.Context, ec2API iface.EC2, instanceID, vpcID string) (*ec2types.NetworkInterface, error) { eniEndpoint := strings.TrimPrefix(instanceID, fargateNodeNamePrefix) - filters := []*ec2.Filter{ + filters := []ec2types.Filter{ newEc2Filter("attachment.status", "attached"), newEc2Filter("vpc-id", vpcID), } @@ -137,7 +138,7 @@ func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, instanceID, Filters: filters, } - eni, err := ec2API.DescribeNetworkInterfaces(request) + eni, err := ec2API.DescribeNetworkInterfaces(ctx, request) if err != nil { return nil, err } @@ -146,9 +147,9 @@ func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, instanceID, } if len(eni.NetworkInterfaces) != 1 { // This should not be possible - ids should be unique - return nil, fmt.Errorf("multiple interfaces found with same id %q", eni.NetworkInterfaces) + return nil, fmt.Errorf("multiple interfaces found with same id %+v", eni.NetworkInterfaces) } - return eni.NetworkInterfaces[0], nil + return &eni.NetworkInterfaces[0], nil } func init() { diff --git a/pkg/providers/v1/variant/variant.go b/pkg/providers/v1/variant/variant.go index 39df86b795..222141ae40 100644 --- a/pkg/providers/v1/variant/variant.go +++ b/pkg/providers/v1/variant/variant.go @@ -1,13 +1,14 @@ package variant import ( + "context" "fmt" "sync" v1 "k8s.io/api/core/v1" cloudprovider "k8s.io/cloud-provider" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" "k8s.io/cloud-provider-aws/pkg/providers/v1/config" "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" @@ -18,13 +19,13 @@ var variants = make(map[string]Variant) // Variant is a slightly different type of node type Variant interface { - Initialize(cloudConfig *config.CloudConfig, credentials *credentials.Credentials, + Initialize(cloudConfig *config.CloudConfig, credentials aws.CredentialsProvider, provider config.SDKProvider, ec2API iface.EC2, region string) error IsSupportedNode(nodeName string) bool - NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) - GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) - InstanceExists(instanceID, vpcID string) (bool, error) - InstanceShutdown(instanceID, vpcID string) (bool, error) + NodeAddresses(ctx context.Context, instanceID, vpcID string) ([]v1.NodeAddress, error) + GetZone(ctx context.Context, instanceID, vpcID, region string) (cloudprovider.Zone, error) + InstanceExists(ctx context.Context, instanceID, vpcID string) (bool, error) + InstanceShutdown(ctx context.Context, instanceID, vpcID string) (bool, error) InstanceTypeByProviderID(id string) (string, error) } diff --git a/pkg/providers/v1/zones.go b/pkg/providers/v1/zones.go index 92d6fd2744..f8141021ee 100644 --- a/pkg/providers/v1/zones.go +++ b/pkg/providers/v1/zones.go @@ -17,11 +17,12 @@ limitations under the License. package aws import ( + "context" "fmt" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "k8s.io/klog/v2" ) @@ -37,8 +38,8 @@ type zoneCache struct { zoneNameToDetails map[string]zoneDetails } -func (z *zoneCache) getZoneIDByZoneName(zoneName string) (string, error) { - zoneNameToDetails, err := z.getZoneDetailsByNames([]string{zoneName}) +func (z *zoneCache) getZoneIDByZoneName(ctx context.Context, zoneName string) (string, error) { + zoneNameToDetails, err := z.getZoneDetailsByNames(ctx, []string{zoneName}) if err != nil { return "", err } @@ -53,14 +54,14 @@ func (z *zoneCache) getZoneIDByZoneName(zoneName string) (string, error) { // Get the zone details by zone names and load from the cache if available as // zone information should never change. -func (z *zoneCache) getZoneDetailsByNames(zoneNames []string) (map[string]zoneDetails, error) { +func (z *zoneCache) getZoneDetailsByNames(ctx context.Context, zoneNames []string) (map[string]zoneDetails, error) { if len(zoneNames) == 0 { return map[string]zoneDetails{}, nil } if z.shouldPopulateCache(zoneNames) { // Populate the cache if it hasn't been populated yet - err := z.populate() + err := z.populate(ctx) if err != nil { return nil, err } @@ -103,12 +104,12 @@ func (z *zoneCache) shouldPopulateCache(zoneNames []string) bool { // Populates the zone cache. If cache is already populated, it will overwrite entries, // which is useful when accounts get access to new zones. -func (z *zoneCache) populate() error { +func (z *zoneCache) populate(ctx context.Context) error { z.mutex.Lock() defer z.mutex.Unlock() azRequest := &ec2.DescribeAvailabilityZonesInput{} - zones, err := z.cloud.ec2.DescribeAvailabilityZones(azRequest) + zones, err := z.cloud.ec2.DescribeAvailabilityZones(ctx, azRequest) if err != nil { return fmt.Errorf("error describe availability zones: %q", err) } @@ -119,11 +120,11 @@ func (z *zoneCache) populate() error { } for _, zone := range zones { - name := aws.StringValue(zone.ZoneName) + name := aws.ToString(zone.ZoneName) z.zoneNameToDetails[name] = zoneDetails{ name: name, - id: aws.StringValue(zone.ZoneId), - zoneType: aws.StringValue(zone.ZoneType), + id: aws.ToString(zone.ZoneId), + zoneType: aws.ToString(zone.ZoneType), } } diff --git a/pkg/providers/v1/zones_test.go b/pkg/providers/v1/zones_test.go index d45ae332b0..111e58149f 100644 --- a/pkg/providers/v1/zones_test.go +++ b/pkg/providers/v1/zones_test.go @@ -17,11 +17,15 @@ limitations under the License. package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" ) func TestGetZoneIDByZoneName(t *testing.T) { @@ -47,7 +51,7 @@ func TestGetZoneIDByZoneName(t *testing.T) { t.Run(tc.name, func(t *testing.T) { c, _ := getCloudWithMockedDescribeAvailabilityZones() - result, err := c.zoneCache.getZoneIDByZoneName(tc.zoneName) + result, err := c.zoneCache.getZoneIDByZoneName(context.TODO(), tc.zoneName) if tc.expectError { if err == nil { t.Error("Expected to see an error") @@ -107,7 +111,7 @@ func TestGetZoneDetailsByNames(t *testing.T) { t.Run(tc.name, func(t *testing.T) { c, mockedEC2API := getCloudWithMockedDescribeAvailabilityZones() - result, err := c.zoneCache.getZoneDetailsByNames(tc.zones) + result, err := c.zoneCache.getZoneDetailsByNames(context.TODO(), tc.zones) if err != nil { t.Errorf("Should not error getting zone details: %s", err) } @@ -115,7 +119,7 @@ func TestGetZoneDetailsByNames(t *testing.T) { assert.Equal(t, tc.expectedResult, result, "Should return the expected zones") // Call again to verify expected caching behavior - result, err = c.zoneCache.getZoneDetailsByNames(tc.zones) + result, err = c.zoneCache.getZoneDetailsByNames(context.TODO(), tc.zones) if err != nil { t.Errorf("Should not error getting zone details: %s", err) } @@ -130,7 +134,7 @@ func getCloudWithMockedDescribeAvailabilityZones() (*Cloud, *MockedEC2API) { c.zoneCache = zoneCache{cloud: c} mockedEC2API.On("DescribeAvailabilityZones", mock.Anything).Return(&ec2.DescribeAvailabilityZonesOutput{ - AvailabilityZones: []*ec2.AvailabilityZone{ + AvailabilityZones: []ec2types.AvailabilityZone{ { ZoneName: aws.String("az1"), ZoneId: aws.String("az1-id"), diff --git a/pkg/services/aws_ec2.go b/pkg/services/aws_ec2.go deleted file mode 100644 index a2848bac36..0000000000 --- a/pkg/services/aws_ec2.go +++ /dev/null @@ -1,77 +0,0 @@ -/* -Copyright 2024 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package services - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go-v2/service/ec2/types" -) - -// EC2ClientV2 is an interface to allow it to be mocked. -type EC2ClientV2 interface { - DescribeInstanceTopology(ctx context.Context, params *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTopologyOutput, error) -} - -// Ec2SdkV2 is an implementation of the EC2 v2 interface, backed by aws-sdk-go-v2 -type Ec2SdkV2 interface { - DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput) ([]types.InstanceTopology, error) -} - -// ec2SdkV2 is an implementation of the EC2 v2 interface, backed by aws-sdk-go-v2 -type ec2SdkV2 struct { - Ec2 EC2ClientV2 -} - -// NewEc2SdkV2 is a constructor for Ec2SdkV2 that creates a default EC2 client. -func NewEc2SdkV2(ctx context.Context, region string, assumeRoleProvider *stscreds.AssumeRoleProvider) (Ec2SdkV2, error) { - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return nil, err - } - - // Don't override the default creds if the assume role provider isn't set. - if assumeRoleProvider != nil { - cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) - } - - client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { - o.Region = region - }) - - return &ec2SdkV2{Ec2: client}, nil -} - -// DescribeInstanceTopology paginates calls to EC2 DescribeInstanceTopology API. -func (s *ec2SdkV2) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput) ([]types.InstanceTopology, error) { - var topologies []types.InstanceTopology - - paginator := ec2.NewDescribeInstanceTopologyPaginator(s.Ec2, request) - for paginator.HasMorePages() { - output, err := paginator.NextPage(ctx) - if err != nil { - return nil, err - } - topologies = append(topologies, output.Instances...) - } - - return topologies, nil -} diff --git a/pkg/services/aws_ec2_mock.go b/pkg/services/aws_ec2_mock.go deleted file mode 100644 index a037561cf9..0000000000 --- a/pkg/services/aws_ec2_mock.go +++ /dev/null @@ -1,55 +0,0 @@ -/* -Copyright 2024 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package services - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/stretchr/testify/mock" -) - -// MockedEC2ClientV2 mocks EC2ClientV2. -type MockedEC2ClientV2 struct { - EC2ClientV2 - mock.Mock -} - -// DescribeInstanceTopology mocks EC2ClientV2.DescribeInstanceTopology. -func (m *MockedEC2ClientV2) DescribeInstanceTopology(ctx context.Context, params *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTopologyOutput, error) { - args := m.Called(ctx, params) - if args.Get(1) != nil { - return nil, args.Get(1).(error) - } - return args.Get(0).(*ec2.DescribeInstanceTopologyOutput), nil -} - -// MockedEc2SdkV2 is an implementation of the EC2 v2 interface, backed by aws-sdk-go-v2 -type MockedEc2SdkV2 struct { - Ec2SdkV2 - mock.Mock -} - -// DescribeInstanceTopology mocks EC2ClientV2.DescribeInstanceTopology. -func (m *MockedEc2SdkV2) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput) ([]types.InstanceTopology, error) { - args := m.Called(ctx, request) - if args.Get(1) != nil { - return nil, args.Get(1).(error) - } - return args.Get(0).([]types.InstanceTopology), nil -} diff --git a/pkg/services/aws_sts.go b/pkg/services/aws_sts.go index 5f96968616..f62cf01e90 100644 --- a/pkg/services/aws_sts.go +++ b/pkg/services/aws_sts.go @@ -64,8 +64,8 @@ func WithStsHeadersMiddleware(headers map[string]string) func(*sts.Options) { } } -// NewStsV2Client provides a new STS client. -func NewStsV2Client(ctx context.Context, region, roleARN, sourceARN string) (*sts.Client, error) { +// NewStsClient provides a new STS client. +func NewStsClient(ctx context.Context, region, roleARN, sourceARN string) (*sts.Client, error) { klog.Infof("Using AWS assumed role %v", roleARN) cfg, err := config.LoadDefaultConfig(ctx) if err != nil { diff --git a/version.txt b/version.txt index c26dc96979..c00278b9b9 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.31.7 +1.31.8