Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jckuester committed Jan 21, 2019
1 parent c3bedf0 commit 694c1c2
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 167 deletions.
118 changes: 66 additions & 52 deletions aws/resource_aws_sagemaker_endpoint_configuration.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package aws

import (
"bytes"
"fmt"
"log"
"time"

"github.com/hashicorp/terraform/helper/validation"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform/helper/hashcode"
"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/helper/schema"
"log"
"time"
)

func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
Expand Down Expand Up @@ -38,7 +39,7 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"production_variants": {
Type: schema.TypeSet,
Type: schema.TypeList,
Required: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
Expand All @@ -56,9 +57,10 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"initial_instance_count": {
Type: schema.TypeInt,
Required: true,
ForceNew: true,
Type: schema.TypeInt,
Required: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},

"instance_type": {
Expand All @@ -68,13 +70,19 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"initial_variant_weight": {
Type: schema.TypeFloat,
Required: true,
Type: schema.TypeFloat,
Optional: true,
ForceNew: true,
ValidateFunc: FloatAtLeast(0),
},

"accelerator_type": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
},
},
},
Set: resourceAwsSagmakerEndpointConfigEntryHash,
},

"kms_key_id": {
Expand All @@ -83,11 +91,6 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
ForceNew: true,
},

"creation_time": {
Type: schema.TypeString,
Computed: true,
},

"tags": tagsSchema(),
},
}
Expand All @@ -107,7 +110,7 @@ func resourceAwsSagemakerEndpointConfigurationCreate(d *schema.ResourceData, met
EndpointConfigName: aws.String(name),
}

prodVariants, err := expandProductionVariants(d.Get("production_variants").(*schema.Set).List())
prodVariants, err := expandProductionVariants(d.Get("production_variants").([]interface{}))
if err != nil {
return err
}
Expand All @@ -117,18 +120,18 @@ func resourceAwsSagemakerEndpointConfigurationCreate(d *schema.ResourceData, met
createOpts.KmsKeyId = aws.String(v.(string))
}

if v, ok := d.GetOk("tags"); ok {
createOpts.Tags = tagsFromMapSagemaker(v.(map[string]interface{}))
}

log.Printf("[DEBUG] Sagemaker endpoint configuration create config: %#v", *createOpts)
resp, err := conn.CreateEndpointConfig(createOpts)
_, err = conn.CreateEndpointConfig(createOpts)
if err != nil {
return fmt.Errorf("Error creating Sagemaker endpoint configuration: %s", err)
return fmt.Errorf("error creating Sagemaker endpoint configuration: %s", err)
}

d.SetId(name)
if err := d.Set("arn", resp.EndpointConfigArn); err != nil {
return err
}

return resourceAwsSagemakerEndpointConfigurationUpdate(d, meta)
return resourceAwsSagemakerEndpointConfigurationRead(d, meta)
}

func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta interface{}) error {
Expand All @@ -140,11 +143,12 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta

endpointConfig, err := conn.DescribeEndpointConfig(request)
if err != nil {
if sagemakerErr, ok := err.(awserr.Error); ok && sagemakerErr.Code() == "ResourceNotFound" {
if sagemakerErr, ok := err.(awserr.Error); ok && sagemakerErr.Code() == "ValidationException" {
log.Printf("[INFO] unable to find the sagemaker endpoint configuration resource and therefore it is removed from the state: %s", d.Id())
d.SetId("")
return nil
}
return fmt.Errorf("Error reading Sagemaker endpoint configuration %s: %s", d.Id(), err)
return fmt.Errorf("error reading Sagemaker endpoint configuration %s: %s", d.Id(), err)
}

if err := d.Set("arn", endpointConfig.EndpointConfigArn); err != nil {
Expand All @@ -156,14 +160,19 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta
if err := d.Set("production_variants", flattenProductionVariants(endpointConfig.ProductionVariants)); err != nil {
return err
}

if err := d.Set("kms_key_id", endpointConfig.KmsKeyId); err != nil {
return err
}
if err := d.Set("creation_time", endpointConfig.CreationTime.Format(time.RFC3339)); err != nil {

tagsOutput, err := conn.ListTags(&sagemaker.ListTagsInput{
ResourceArn: endpointConfig.EndpointConfigArn,
})
if err != nil {
return fmt.Errorf("error listing tags of Sagemaker endpoint configuration %s: %s", d.Id(), err)
}
if err := d.Set("tags", tagsToMapSagemaker(tagsOutput.Tags)); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -216,20 +225,28 @@ func expandProductionVariants(configured []interface{}) ([]*sagemaker.Production
for _, lRaw := range configured {
data := lRaw.(map[string]interface{})

var name string
if v, ok := data["variant_name"]; ok {
name = v.(string)
} else {
name = resource.UniqueId()
}

l := &sagemaker.ProductionVariant{
VariantName: aws.String(name),
InstanceType: aws.String(data["instance_type"].(string)),
ModelName: aws.String(data["model_name"].(string)),
InitialVariantWeight: aws.Float64(float64(data["initial_variant_weight"].(float64))),
InitialInstanceCount: aws.Int64(int64(data["initial_instance_count"].(int))),
}

if v, ok := data["variant_name"]; ok {
l.VariantName = aws.String(v.(string))
} else {
l.VariantName = aws.String(resource.UniqueId())
}

if v, ok := data["initial_variant_weight"]; ok {
l.InitialVariantWeight = aws.Float64(v.(float64))
} else {
l.InitialVariantWeight = aws.Float64(1)
}

if v, ok := data["accelerator_type"]; ok && v.(string) != "" {
l.AcceleratorType = aws.String(data["accelerator_type"].(string))
}

containers = append(containers, l)
}

Expand All @@ -238,27 +255,24 @@ func expandProductionVariants(configured []interface{}) ([]*sagemaker.Production

func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(list))

for _, i := range list {
l := map[string]interface{}{
"variant_name": *i.VariantName,
"instance_type": *i.InstanceType,
"model_name": *i.ModelName,
"initial_variant_weight": *i.InitialVariantWeight,
"initial_instance_count": *i.InitialInstanceCount,
}
if i.VariantName != nil {
l["variant_name"] = *i.VariantName
}
if i.InitialVariantWeight != nil {
l["initial_variant_weight"] = *i.InitialVariantWeight
}
if i.AcceleratorType != nil {
l["accelerator_type"] = *i.AcceleratorType
}

result = append(result, l)
}
return result
}

func resourceAwsSagmakerEndpointConfigEntryHash(v interface{}) int {
var buf bytes.Buffer
m := v.(map[string]interface{})
buf.WriteString(fmt.Sprintf("%s-", m["variant_name"].(string)))
buf.WriteString(fmt.Sprintf("%s-", m["model_name"].(string)))
buf.WriteString(fmt.Sprintf("%s-", m["instance_type"].(string)))
buf.WriteString(fmt.Sprintf("%f-", m["initial_variant_weight"].(float64)))
buf.WriteString(fmt.Sprintf("%d-", m["initial_instance_count"].(int)))

return hashcode.String(buf.String())
}
Loading

0 comments on commit 694c1c2

Please sign in to comment.