diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index f1beb277ac..bcf8219127 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,7 @@ ### Bug Fixes + * Support updating all attributes for `databricks_model_serving` ([#4575](https://github.com/databricks/terraform-provider-databricks/pull/4575)). * Fix automatic cluster creation for `databricks_sql_permissions` ([#4141](https://github.com/databricks/terraform-provider-databricks/pull/4141)) ### Documentation diff --git a/serving/model_serving_test.go b/serving/model_serving_test.go index 2d7e851a9a..4209dfe2fc 100644 --- a/serving/model_serving_test.go +++ b/serving/model_serving_test.go @@ -46,6 +46,18 @@ func TestAccModelServing(t *testing.T) { } } } + tags { + key = "key1" + value = "value-should-not-change" + } + tags { + key = "key2" + value = "value-should-change" + } + tags { + key = "key3" + value = "should-be-deleted" + } } data "databricks_serving_endpoints" "all" {} @@ -79,6 +91,18 @@ func TestAccModelServing(t *testing.T) { } } } + tags { + key = "key1" + value = "value-should-not-change" + } + tags { + key = "key2" + value = "value-should-change-to-something-new" + } + tags { + key = "key4" + value = "should-be-added" + } } data "databricks_serving_endpoints" "all" {} `, name), diff --git a/serving/resource_model_serving.go b/serving/resource_model_serving.go index 1e10aa7a70..5c59f7cfb6 100644 --- a/serving/resource_model_serving.go +++ b/serving/resource_model_serving.go @@ -5,6 +5,7 @@ import ( "log" "time" + "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/serving" "github.com/databricks/terraform-provider-databricks/common" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" @@ -13,6 +14,85 @@ import ( const DefaultProvisionTimeout = 45 * time.Minute const deleteCallTimeout = 10 * time.Second +// updateConfig updates the configuration of the provided serving endpoint to the provided config. +func updateConfig(ctx context.Context, w *databricks.WorkspaceClient, name string, e *serving.EndpointCoreConfigInput, d *schema.ResourceData) error { + e.Name = name + waiter, err := w.ServingEndpoints.UpdateConfig(ctx, *e) + if err != nil { + return err + } + _, err = waiter.GetWithTimeout(d.Timeout(schema.TimeoutUpdate)) + if err != nil { + return err + } + return nil +} + +// updateTags updates the tags of the provided serving endpoint to the given tags. Any tags not present on the existing +// endpoint will be removed, any tags absent on the endpoint will be added, existing tags will be updated, and unchanged +// tags will remain as-is. +func updateTags(ctx context.Context, w *databricks.WorkspaceClient, name string, newTags []serving.EndpointTag, d *schema.ResourceData) error { + currentEndpoint, err := w.ServingEndpoints.Get(ctx, serving.GetServingEndpointRequest{ + Name: name, + }) + oldTags := currentEndpoint.Tags + if err != nil { + return err + } + req := serving.PatchServingEndpointTags{ + Name: name, + } + for _, newTag := range newTags { + found := false + for _, oldTag := range oldTags { + if oldTag.Key == newTag.Key && oldTag.Value == newTag.Value { + found = true + break + } + } + if !found { + req.AddTags = append(req.AddTags, newTag) + } + } + for _, oldTag := range oldTags { + found := false + for _, newTag := range newTags { + if oldTag.Key == newTag.Key { + found = true + break + } + } + if !found { + req.DeleteTags = append(req.DeleteTags, oldTag.Key) + } + } + if _, err := w.ServingEndpoints.Patch(ctx, req); err != nil { + return err + } + return nil +} + +// Update the rate limit configuration for a model serving endpoint. +func updateRateLimits(ctx context.Context, w *databricks.WorkspaceClient, name string, newRateLimits []serving.RateLimit, d *schema.ResourceData) error { + _, err := w.ServingEndpoints.Put(ctx, serving.PutRequest{ + Name: name, + RateLimits: newRateLimits, + }) + return err +} + +// Update the AI Gateway configuration for a model serving endpoint. +func updateAiGateway(ctx context.Context, w *databricks.WorkspaceClient, name string, newAiGateway serving.AiGatewayConfig, d *schema.ResourceData) error { + _, err := w.ServingEndpoints.PutAiGateway(ctx, serving.PutAiGatewayRequest{ + Name: name, + Guardrails: newAiGateway.Guardrails, + InferenceTableConfig: newAiGateway.InferenceTableConfig, + RateLimits: newAiGateway.RateLimits, + UsageTrackingConfig: newAiGateway.UsageTrackingConfig, + }) + return err +} + func ResourceModelServing() common.Resource { s := common.StructToSchema( serving.CreateServingEndpoint{}, @@ -43,6 +123,9 @@ func ResourceModelServing() common.Resource { common.MustSchemaPath(m, "config", "served_entities", "workload_size").Computed = true common.MustSchemaPath(m, "config", "served_entities", "workload_type").Computed = true + // route_optimized cannot be updated. + common.MustSchemaPath(m, "route_optimized").ForceNew = true + m["serving_endpoint_id"] = &schema.Schema{ Computed: true, Type: schema.TypeString, @@ -113,13 +196,22 @@ func ResourceModelServing() common.Resource { var e serving.CreateServingEndpoint common.DataToStructPointer(d, s, &e) if d.HasChange("config") { - e.Config.Name = e.Name - waiter, err := w.ServingEndpoints.UpdateConfig(ctx, *e.Config) - if err != nil { + if err := updateConfig(ctx, w, e.Name, e.Config, d); err != nil { + return err + } + } + if d.HasChange("tags") { + if err := updateTags(ctx, w, e.Name, e.Tags, d); err != nil { + return err + } + } + if d.HasChange("rate_limits") { + if err := updateRateLimits(ctx, w, e.Name, e.RateLimits, d); err != nil { return err } - _, err = waiter.GetWithTimeout(d.Timeout(schema.TimeoutUpdate)) - if err != nil { + } + if d.HasChange("ai_gateway") { + if err := updateAiGateway(ctx, w, e.Name, *e.AiGateway, d); err != nil { return err } }