Skip to content

Commit 9fe2bfc

Browse files
authored
Merge pull request #42143 from DrFaust92/sagemaker-image-version-args
r/sagemaker_image_version - add args
2 parents 44c820c + 7192870 commit 9fe2bfc

File tree

4 files changed

+220
-6
lines changed

4 files changed

+220
-6
lines changed

.changelog/42143.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:enhancement
2+
resource/aws_sagemaker_image_version: Add `horovod`, `job_type`, `ml_framework`, `processor`, `programming_lang`, `release_notes`, and `vendor_guidance` arguments
3+
```

internal/service/sagemaker/image_version.go

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ import (
77
"context"
88
"log"
99

10+
"github.com/YakDriver/regexache"
1011
"github.com/aws/aws-sdk-go-v2/aws"
1112
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
1213
awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types"
1314
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
15+
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/id"
1416
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
1517
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
18+
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
1619
"github.com/hashicorp/terraform-provider-aws/internal/conns"
20+
"github.com/hashicorp/terraform-provider-aws/internal/enum"
1721
"github.com/hashicorp/terraform-provider-aws/internal/errs"
1822
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
1923
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
@@ -25,6 +29,7 @@ func resourceImageVersion() *schema.Resource {
2529
return &schema.Resource{
2630
CreateWithoutTimeout: resourceImageVersionCreate,
2731
ReadWithoutTimeout: resourceImageVersionRead,
32+
UpdateWithoutTimeout: resourceImageVersionUpdate,
2833
DeleteWithoutTimeout: resourceImageVersionDelete,
2934
Importer: &schema.ResourceImporter{
3035
StateContext: schema.ImportStatePassthroughContext,
@@ -44,6 +49,10 @@ func resourceImageVersion() *schema.Resource {
4449
Type: schema.TypeString,
4550
Computed: true,
4651
},
52+
"horovod": {
53+
Type: schema.TypeBool,
54+
Optional: true,
55+
},
4756
"image_arn": {
4857
Type: schema.TypeString,
4958
Computed: true,
@@ -53,6 +62,36 @@ func resourceImageVersion() *schema.Resource {
5362
Required: true,
5463
ForceNew: true,
5564
},
65+
"job_type": {
66+
Type: schema.TypeString,
67+
Optional: true,
68+
ValidateDiagFunc: enum.Validate[awstypes.JobType](),
69+
},
70+
"ml_framework": {
71+
Type: schema.TypeString,
72+
Optional: true,
73+
ValidateFunc: validation.StringMatch(regexache.MustCompile(`^[a-zA-Z]+ ?\d+\.\d+(\.\d+)?$`), ""),
74+
},
75+
"processor": {
76+
Type: schema.TypeString,
77+
Optional: true,
78+
ValidateDiagFunc: enum.Validate[awstypes.Processor](),
79+
},
80+
"programming_lang": {
81+
Type: schema.TypeString,
82+
Optional: true,
83+
ValidateFunc: validation.StringMatch(regexache.MustCompile(`^[a-zA-Z]+ ?\d+\.\d+(\.\d+)?$`), ""),
84+
},
85+
"release_notes": {
86+
Type: schema.TypeString,
87+
Optional: true,
88+
ValidateFunc: validation.StringLenBetween(0, 255),
89+
},
90+
"vendor_guidance": {
91+
Type: schema.TypeString,
92+
Optional: true,
93+
ValidateDiagFunc: enum.Validate[awstypes.VendorGuidance](),
94+
},
5695
names.AttrVersion: {
5796
Type: schema.TypeInt,
5897
Computed: true,
@@ -67,8 +106,37 @@ func resourceImageVersionCreate(ctx context.Context, d *schema.ResourceData, met
67106

68107
name := d.Get("image_name").(string)
69108
input := &sagemaker.CreateImageVersionInput{
70-
ImageName: aws.String(name),
71-
BaseImage: aws.String(d.Get("base_image").(string)),
109+
ImageName: aws.String(name),
110+
BaseImage: aws.String(d.Get("base_image").(string)),
111+
ClientToken: aws.String(id.UniqueId()),
112+
}
113+
114+
if v, ok := d.GetOk("job_type"); ok {
115+
input.JobType = awstypes.JobType(v.(string))
116+
}
117+
118+
if v, ok := d.GetOk("processor"); ok {
119+
input.Processor = awstypes.Processor(v.(string))
120+
}
121+
122+
if v, ok := d.GetOk("release_notes"); ok {
123+
input.ReleaseNotes = aws.String(v.(string))
124+
}
125+
126+
if v, ok := d.GetOk("vendor_guidance"); ok {
127+
input.VendorGuidance = awstypes.VendorGuidance(v.(string))
128+
}
129+
130+
if v, ok := d.GetOk("horovod"); ok {
131+
input.Horovod = aws.Bool(v.(bool))
132+
}
133+
134+
if v, ok := d.GetOk("ml_framework"); ok {
135+
input.MLFramework = aws.String(v.(string))
136+
}
137+
138+
if v, ok := d.GetOk("programming_lang"); ok {
139+
input.ProgrammingLang = aws.String(v.(string))
72140
}
73141

74142
_, err := conn.CreateImageVersion(ctx, input)
@@ -107,10 +175,61 @@ func resourceImageVersionRead(ctx context.Context, d *schema.ResourceData, meta
107175
d.Set("container_image", image.ContainerImage)
108176
d.Set(names.AttrVersion, image.Version)
109177
d.Set("image_name", d.Id())
178+
d.Set("horovod", image.Horovod)
179+
d.Set("job_type", image.JobType)
180+
d.Set("processor", image.Processor)
181+
d.Set("release_notes", image.ReleaseNotes)
182+
d.Set("vendor_guidance", image.VendorGuidance)
183+
d.Set("ml_framework", image.MLFramework)
184+
d.Set("programming_lang", image.ProgrammingLang)
110185

111186
return diags
112187
}
113188

189+
func resourceImageVersionUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
190+
var diags diag.Diagnostics
191+
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)
192+
193+
input := &sagemaker.UpdateImageVersionInput{
194+
ImageName: aws.String(d.Id()),
195+
Version: aws.Int32(int32(d.Get(names.AttrVersion).(int))),
196+
}
197+
198+
if d.HasChange("horovod") {
199+
input.Horovod = aws.Bool(d.Get("horovod").(bool))
200+
}
201+
202+
if d.HasChange("job_type") {
203+
input.JobType = awstypes.JobType(d.Get("job_type").(string))
204+
}
205+
206+
if d.HasChange("processor") {
207+
input.Processor = awstypes.Processor(d.Get("processor").(string))
208+
}
209+
210+
if d.HasChange("release_notes") {
211+
input.ReleaseNotes = aws.String(d.Get("release_notes").(string))
212+
}
213+
214+
if d.HasChange("vendor_guidance") {
215+
input.VendorGuidance = awstypes.VendorGuidance(d.Get("vendor_guidance").(string))
216+
}
217+
218+
if d.HasChange("ml_framework") {
219+
input.MLFramework = aws.String(d.Get("ml_framework").(string))
220+
}
221+
222+
if d.HasChange("programming_lang") {
223+
input.ProgrammingLang = aws.String(d.Get("programming_lang").(string))
224+
}
225+
226+
if _, err := conn.UpdateImageVersion(ctx, input); err != nil {
227+
return sdkdiag.AppendErrorf(diags, "updating SageMaker AI Image Version (%s): %s", d.Id(), err)
228+
}
229+
230+
return append(diags, resourceImageVersionRead(ctx, d, meta)...)
231+
}
232+
114233
func resourceImageVersionDelete(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
115234
var diags diag.Diagnostics
116235
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

internal/service/sagemaker/image_version_test.go

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func TestAccSageMakerImageVersion_basic(t *testing.T) {
4747
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, "image_arn", "sagemaker", fmt.Sprintf("image/%s", rName)),
4848
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, names.AttrARN, "sagemaker", fmt.Sprintf("image-version/%s/1", rName)),
4949
resource.TestCheckResourceAttrSet(resourceName, "container_image"),
50+
resource.TestCheckResourceAttr(resourceName, "horovod", acctest.CtFalse),
5051
),
5152
},
5253
{
@@ -58,6 +59,71 @@ func TestAccSageMakerImageVersion_basic(t *testing.T) {
5859
})
5960
}
6061

62+
func TestAccSageMakerImageVersion_full(t *testing.T) {
63+
ctx := acctest.Context(t)
64+
if os.Getenv("SAGEMAKER_IMAGE_VERSION_BASE_IMAGE") == "" {
65+
t.Skip("Environment variable SAGEMAKER_IMAGE_VERSION_BASE_IMAGE is not set")
66+
}
67+
68+
var image sagemaker.DescribeImageVersionOutput
69+
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
70+
rNameUpdate := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
71+
resourceName := "aws_sagemaker_image_version.test"
72+
baseImage := os.Getenv("SAGEMAKER_IMAGE_VERSION_BASE_IMAGE")
73+
74+
resource.ParallelTest(t, resource.TestCase{
75+
PreCheck: func() { acctest.PreCheck(ctx, t) },
76+
ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID),
77+
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
78+
CheckDestroy: testAccCheckImageVersionDestroy(ctx),
79+
Steps: []resource.TestStep{
80+
{
81+
Config: testAccImageVersionConfig_full(rName, baseImage, rName),
82+
Check: resource.ComposeTestCheckFunc(
83+
testAccCheckImageVersionExists(ctx, resourceName, &image),
84+
resource.TestCheckResourceAttr(resourceName, "image_name", rName),
85+
resource.TestCheckResourceAttr(resourceName, "base_image", baseImage),
86+
resource.TestCheckResourceAttr(resourceName, names.AttrVersion, "1"),
87+
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, "image_arn", "sagemaker", fmt.Sprintf("image/%s", rName)),
88+
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, names.AttrARN, "sagemaker", fmt.Sprintf("image-version/%s/1", rName)),
89+
resource.TestCheckResourceAttrSet(resourceName, "container_image"),
90+
resource.TestCheckResourceAttr(resourceName, "horovod", acctest.CtFalse),
91+
resource.TestCheckResourceAttr(resourceName, "processor", "CPU"),
92+
resource.TestCheckResourceAttr(resourceName, "vendor_guidance", "STABLE"),
93+
resource.TestCheckResourceAttr(resourceName, "release_notes", rName),
94+
resource.TestCheckResourceAttr(resourceName, "job_type", "TRAINING"),
95+
resource.TestCheckResourceAttr(resourceName, "ml_framework", "TensorFlow 1.1"),
96+
resource.TestCheckResourceAttr(resourceName, "programming_lang", "Python 3.8"),
97+
),
98+
},
99+
{
100+
ResourceName: resourceName,
101+
ImportState: true,
102+
ImportStateVerify: true,
103+
},
104+
{
105+
Config: testAccImageVersionConfig_full(rName, baseImage, rNameUpdate),
106+
Check: resource.ComposeTestCheckFunc(
107+
testAccCheckImageVersionExists(ctx, resourceName, &image),
108+
resource.TestCheckResourceAttr(resourceName, "image_name", rName),
109+
resource.TestCheckResourceAttr(resourceName, "base_image", baseImage),
110+
resource.TestCheckResourceAttr(resourceName, names.AttrVersion, "1"),
111+
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, "image_arn", "sagemaker", fmt.Sprintf("image/%s", rName)),
112+
acctest.CheckResourceAttrRegionalARN(ctx, resourceName, names.AttrARN, "sagemaker", fmt.Sprintf("image-version/%s/1", rName)),
113+
resource.TestCheckResourceAttrSet(resourceName, "container_image"),
114+
resource.TestCheckResourceAttr(resourceName, "horovod", acctest.CtFalse),
115+
resource.TestCheckResourceAttr(resourceName, "processor", "CPU"),
116+
resource.TestCheckResourceAttr(resourceName, "vendor_guidance", "STABLE"),
117+
resource.TestCheckResourceAttr(resourceName, "release_notes", rNameUpdate),
118+
resource.TestCheckResourceAttr(resourceName, "job_type", "TRAINING"),
119+
resource.TestCheckResourceAttr(resourceName, "ml_framework", "TensorFlow 1.1"),
120+
resource.TestCheckResourceAttr(resourceName, "programming_lang", "Python 3.8"),
121+
),
122+
},
123+
},
124+
})
125+
}
126+
61127
func TestAccSageMakerImageVersion_disappears(t *testing.T) {
62128
ctx := acctest.Context(t)
63129
if os.Getenv("SAGEMAKER_IMAGE_VERSION_BASE_IMAGE") == "" {
@@ -165,7 +231,7 @@ func testAccCheckImageVersionExists(ctx context.Context, n string, image *sagema
165231
}
166232
}
167233

168-
func testAccImageVersionConfig_basic(rName, baseImage string) string {
234+
func testAccImageVersionConfigBase(rName string) string {
169235
return fmt.Sprintf(`
170236
data "aws_partition" "current" {}
171237
@@ -196,10 +262,29 @@ resource "aws_sagemaker_image" "test" {
196262
197263
depends_on = [aws_iam_role_policy_attachment.test]
198264
}
265+
`, rName)
266+
}
199267

268+
func testAccImageVersionConfig_basic(rName, baseImage string) string {
269+
return testAccImageVersionConfigBase(rName) + fmt.Sprintf(`
200270
resource "aws_sagemaker_image_version" "test" {
201271
image_name = aws_sagemaker_image.test.id
202-
base_image = %[2]q
272+
base_image = %[1]q
273+
}
274+
`, baseImage)
275+
}
276+
277+
func testAccImageVersionConfig_full(rName, baseImage, notes string) string {
278+
return testAccImageVersionConfigBase(rName) + fmt.Sprintf(`
279+
resource "aws_sagemaker_image_version" "test" {
280+
image_name = aws_sagemaker_image.test.id
281+
base_image = %[1]q
282+
job_type = "TRAINING"
283+
processor = "CPU"
284+
release_notes = %[2]q
285+
vendor_guidance = "STABLE"
286+
ml_framework = "TensorFlow 1.1"
287+
programming_lang = "Python 3.8"
203288
}
204-
`, rName, baseImage)
289+
`, baseImage, notes)
205290
}

website/docs/r/sagemaker_image_version.html.markdown

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,21 @@ This resource supports the following arguments:
2727

2828
* `image_name` - (Required) The name of the image. Must be unique to your account.
2929
* `base_image` - (Required) The registry path of the container image on which this image version is based.
30+
* `horovod` - (Optional) Indicates Horovod compatibility.
31+
* `job_type` - (Optional) Indicates SageMaker AI job type compatibility. Valid values are: `TRAINING`, `INFERENCE`, and `NOTEBOOK_KERNEL`.
32+
* `ml_framework` - (Optional) The machine learning framework vended in the image version.
33+
* `processor` - (Optional) Indicates CPU or GPU compatibility. Valid values are: `CPU` and `GPU`.
34+
* `programming_lang` - (Optional) The supported programming language and its version.
35+
* `release_notes` - (Optional) The maintainer description of the image version.
36+
* `vendor_guidance` - (Optional) The stability of the image version, specified by the maintainer. Valid values are: `NOT_PROVIDED`, `STABLE`, `TO_BE_ARCHIVED`, and `ARCHIVED`.
3037

3138
## Attribute Reference
3239

3340
This resource exports the following attributes in addition to the arguments above:
3441

3542
* `id` - The name of the Image.
3643
* `arn` - The Amazon Resource Name (ARN) assigned by AWS to this Image Version.
37-
* `image_arn`- The Amazon Resource Name (ARN) of the image the version is based on.
44+
* `version`- The version of the image. If not specified, the latest version is described.
3845
* `container_image` - The registry path of the container image that contains this image version.
3946

4047
## Import

0 commit comments

Comments
 (0)