Skip to content

Commit d4d38fc

Browse files
committed
init
1 parent 1b54244 commit d4d38fc

File tree

2 files changed

+70
-70
lines changed

2 files changed

+70
-70
lines changed

internal/provider/aws_credentials.go

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,39 @@ const (
2626
)
2727

2828
func getAWSCredentials(c *config.AWSVars) (*config.Credentials, error) {
29-
ctx := context.TODO()
29+
ctx := context.TODO()
3030

31-
// Base config with static credentials
32-
cfg, err := awsconfig.LoadDefaultConfig(ctx,
33-
awsconfig.WithRegion(c.Region),
34-
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, c.SessionToken)),
35-
)
31+
// Base config with static credentials
32+
cfg, err := awsconfig.LoadDefaultConfig(ctx,
33+
awsconfig.WithRegion(c.Region),
34+
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, c.SessionToken)),
35+
)
36+
if err != nil {
37+
return nil, err
38+
}
39+
ep, signingRegion := ResolveSTSEndpoint(c.Endpoint, c.Region)
40+
41+
// STS client with custom endpoint and signing region when needed
42+
stsClient := sts.NewFromConfig(cfg, func(o *sts.Options) {
43+
// Always set region to derived signing region
44+
o.Region = signingRegion
45+
if ep != "" {
46+
o.EndpointResolver = sts.EndpointResolverFromURL(ep)
47+
}
48+
})
49+
50+
// Assume role provider using STS client
51+
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, c.AssumeRoleARN)
52+
53+
// Secrets Manager client using the assumed role credentials
54+
smCfg := cfg
55+
smCfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
56+
smClient := secretsmanager.NewFromConfig(smCfg)
57+
58+
secretString, err := secretsManagerGetSecretValue(ctx, smClient, c.SecretName)
3659
if err != nil {
3760
return nil, err
3861
}
39-
ep, signingRegion := ResolveSTSEndpoint(c.Endpoint, c.Region)
40-
41-
// STS client with custom endpoint and signing region when needed
42-
stsClient := sts.NewFromConfig(cfg, func(o *sts.Options) {
43-
// Always set region to derived signing region
44-
o.Region = signingRegion
45-
if ep != "" {
46-
o.EndpointResolver = sts.EndpointResolverFromURL(ep)
47-
}
48-
})
49-
50-
// Assume role provider using STS client
51-
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, c.AssumeRoleARN)
52-
53-
// Secrets Manager client using the assumed role credentials
54-
smCfg := cfg
55-
smCfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
56-
smClient := secretsmanager.NewFromConfig(smCfg)
57-
58-
secretString, err := secretsManagerGetSecretValue(ctx, smClient, c.SecretName)
59-
if err != nil {
60-
return nil, err
61-
}
6262
var secret config.Credentials
6363
err = json.Unmarshal([]byte(secretString), &secret)
6464
if err != nil {
@@ -89,45 +89,45 @@ func DeriveSTSRegionFromEndpoint(ep string) string {
8989
}
9090

9191
func ResolveSTSEndpoint(stsEndpoint, secretsRegion string) (string, string) {
92-
ep := stsEndpoint
93-
if ep == "" {
94-
r := secretsRegion
95-
if r == "" {
96-
r = DefaultRegionSTS
97-
}
98-
ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", r)
99-
}
100-
signingRegion := DeriveSTSRegionFromEndpoint(ep)
101-
return ep, signingRegion
92+
ep := stsEndpoint
93+
if ep == "" {
94+
r := secretsRegion
95+
if r == "" {
96+
r = DefaultRegionSTS
97+
}
98+
ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", r)
99+
}
100+
signingRegion := DeriveSTSRegionFromEndpoint(ep)
101+
return ep, signingRegion
102102
}
103103

104104
func secretsManagerGetSecretValue(ctx context.Context, client *secretsmanager.Client, secret string) (string, error) {
105-
input := &secretsmanager.GetSecretValueInput{
106-
SecretId: aws.String(secret),
107-
VersionStage: aws.String("AWSCURRENT"),
108-
}
109-
110-
result, err := client.GetSecretValue(ctx, input)
111-
if err != nil {
112-
switch e := err.(type) {
113-
case *smtypes.ResourceNotFoundException:
114-
log.Println("ResourceNotFoundException", e.Error())
115-
case *smtypes.InvalidParameterException:
116-
log.Println("InvalidParameterException", e.Error())
117-
case *smtypes.InvalidRequestException:
118-
log.Println("InvalidRequestException", e.Error())
119-
case *smtypes.DecryptionFailure:
120-
log.Println("DecryptionFailure", e.Error())
121-
case *smtypes.InternalServiceError:
122-
log.Println("InternalServiceError", e.Error())
123-
default:
124-
log.Println(err.Error())
125-
}
126-
return "", err
127-
}
128-
129-
if result.SecretString == nil {
130-
return "", fmt.Errorf("secret string is nil for secret %s", secret)
131-
}
132-
return *result.SecretString, nil
105+
input := &secretsmanager.GetSecretValueInput{
106+
SecretId: aws.String(secret),
107+
VersionStage: aws.String("AWSCURRENT"),
108+
}
109+
110+
result, err := client.GetSecretValue(ctx, input)
111+
if err != nil {
112+
switch e := err.(type) {
113+
case *smtypes.ResourceNotFoundException:
114+
log.Println("ResourceNotFoundException", e.Error())
115+
case *smtypes.InvalidParameterException:
116+
log.Println("InvalidParameterException", e.Error())
117+
case *smtypes.InvalidRequestException:
118+
log.Println("InvalidRequestException", e.Error())
119+
case *smtypes.DecryptionFailure:
120+
log.Println("DecryptionFailure", e.Error())
121+
case *smtypes.InternalServiceError:
122+
log.Println("InternalServiceError", e.Error())
123+
default:
124+
log.Println(err.Error())
125+
}
126+
return "", err
127+
}
128+
129+
if result.SecretString == nil {
130+
return "", fmt.Errorf("secret string is nil for secret %s", secret)
131+
}
132+
return *result.SecretString, nil
133133
}

internal/provider/aws_credentials_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ func Test_resolveSTSEndpoint(t *testing.T) {
8585

8686
for testName, tc := range testCases {
8787
t.Run(testName, func(t *testing.T) {
88-
ep, sign := provider.ResolveSTSEndpoint(tc.stsEndpoint, tc.secretsRegion)
89-
assert.Equal(t, tc.expectedURL, ep)
90-
assert.Equal(t, tc.expectedSign, sign)
88+
ep, sign := provider.ResolveSTSEndpoint(tc.stsEndpoint, tc.secretsRegion)
89+
assert.Equal(t, tc.expectedURL, ep)
90+
assert.Equal(t, tc.expectedSign, sign)
9191
})
9292
}
9393
}

0 commit comments

Comments
 (0)