Skip to content

Commit

Permalink
Merge pull request #39342 from takakuni-classmethod/f-aws_bedrock_inf…
Browse files Browse the repository at this point in the history
…erence_profile-datasource

d/aws_bedrock_inference_profiles: new data source
  • Loading branch information
ewbankkit authored Oct 10, 2024
2 parents f254e92 + e4b9ef3 commit 7f260d9
Show file tree
Hide file tree
Showing 19 changed files with 632 additions and 202 deletions.
6 changes: 6 additions & 0 deletions .changelog/39342.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
```release-note:new-data-source
aws_bedrock_inference_profile
```
```release-note:new-data-source
aws_bedrock_inference_profiles
```
24 changes: 24 additions & 0 deletions internal/framework/data_source_list_of_object.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package framework

import (
"context"

"github.com/hashicorp/terraform-plugin-framework/datasource/schema"
"github.com/hashicorp/terraform-plugin-framework/types"
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
)

// DataSourceComputedListOfObjectAttribute returns a new schema.ListAttribute for objects of the specified type.
// The list is Computed-only.
func DataSourceComputedListOfObjectAttribute[T any](ctx context.Context) schema.ListAttribute {
return schema.ListAttribute{
CustomType: fwtypes.NewListNestedObjectTypeOf[T](ctx),
Computed: true,
ElementType: types.ObjectType{
AttrTypes: fwtypes.AttributeTypesMust[T](ctx),
},
}
}
24 changes: 24 additions & 0 deletions internal/framework/resource_list_of_object.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package framework

import (
"context"

"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/types"
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
)

// NewResourceComputedListOfObjectSchema returns a new schema.ListAttribute for objects of the specified type.
// The list is Computed-only.
func ResourceComputedListOfObjectAttribute[T any](ctx context.Context) schema.ListAttribute {
return schema.ListAttribute{
CustomType: fwtypes.NewListNestedObjectTypeOf[T](ctx),
Computed: true,
ElementType: types.ObjectType{
AttrTypes: fwtypes.AttributeTypesMust[T](ctx),
},
}
}
17 changes: 9 additions & 8 deletions internal/service/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ func TestAccBedrock_serial(t *testing.T) {
testCases := map[string]map[string]func(t *testing.T){
// Model customization has a non-adjustable maximum concurrency of 2
"CustomModel": {
acctest.CtBasic: testAccBedrockCustomModel_basic,
acctest.CtDisappears: testAccBedrockCustomModel_disappears,
"tags": testAccBedrockCustomModel_tags,
"kmsKey": testAccBedrockCustomModel_kmsKey,
"validationDataConfig": testAccBedrockCustomModel_validationDataConfig,
"validationDataConfigWaitForCompletion": testAccBedrockCustomModel_validationDataConfigWaitForCompletion,
"vpcConfig": testAccBedrockCustomModel_vpcConfig,
"dataSourceBasic": testAccBedrockCustomModelDataSource_basic,
acctest.CtBasic: testAccCustomModel_basic,
acctest.CtDisappears: testAccCustomModel_disappears,
"tags": testAccCustomModel_tags,
"kmsKey": testAccCustomModel_kmsKey,
"validationDataConfig": testAccCustomModel_validationDataConfig,
"validationDataConfigWaitForCompletion": testAccCustomModel_validationDataConfigWaitForCompletion,
"vpcConfig": testAccCustomModel_vpcConfig,
"singularDataSourceBasic": testAccCustomModelDataSource_basic,
"pluralDataSourceBasic": testAccCustomModelsDataSource_basic,
},
"ModelInvocationLoggingConfiguration": {
acctest.CtBasic: testAccModelInvocationLoggingConfiguration_basic,
Expand Down
102 changes: 43 additions & 59 deletions internal/service/bedrock/custom_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts"
"github.com/hashicorp/terraform-plugin-framework-validators/listvalidator"
"github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/listplanmodifier"
Expand Down Expand Up @@ -136,30 +135,14 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
stringplanmodifier.RequiresReplace(),
},
},
names.AttrTags: tftags.TagsAttribute(),
names.AttrTagsAll: tftags.TagsAttributeComputedOnly(),
"training_metrics": schema.ListAttribute{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingMetricsModel](ctx),
Computed: true,
ElementType: types.ObjectType{
AttrTypes: map[string]attr.Type{
"training_loss": types.Float64Type,
},
},
},
"validation_metrics": schema.ListAttribute{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationMetricsModel](ctx),
Computed: true,
ElementType: types.ObjectType{
AttrTypes: map[string]attr.Type{
"validation_loss": types.Float64Type,
},
},
},
names.AttrTags: tftags.TagsAttribute(),
names.AttrTagsAll: tftags.TagsAttributeComputedOnly(),
"training_metrics": framework.ResourceComputedListOfObjectAttribute[trainingMetricsModel](ctx),
"validation_metrics": framework.ResourceComputedListOfObjectAttribute[validatorMetricModel](ctx),
},
Blocks: map[string]schema.Block{
"output_data_config": schema.ListNestedBlock{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelOutputDataConfigModel](ctx),
CustomType: fwtypes.NewListNestedObjectTypeOf[outputDataConfigModel](ctx),
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplace(),
},
Expand Down Expand Up @@ -187,7 +170,7 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
Delete: true,
}),
"training_data_config": schema.ListNestedBlock{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingDataConfigModel](ctx),
CustomType: fwtypes.NewListNestedObjectTypeOf[trainingDataConfigModel](ctx),
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplace(),
},
Expand All @@ -211,7 +194,7 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
},
},
"validation_data_config": schema.ListNestedBlock{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationDataConfigModel](ctx),
CustomType: fwtypes.NewListNestedObjectTypeOf[validationDataConfigModel](ctx),
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplace(),
},
Expand All @@ -221,7 +204,7 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
NestedObject: schema.NestedBlockObject{
Blocks: map[string]schema.Block{
"validator": schema.ListNestedBlock{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidatorConfigModel](ctx),
CustomType: fwtypes.NewListNestedObjectTypeOf[validatorModel](ctx),
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplace(),
},
Expand All @@ -238,7 +221,8 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
stringplanmodifier.RequiresReplace(),
},
Validators: []validator.String{
fwvalidators.S3URI()},
fwvalidators.S3URI(),
},
},
},
},
Expand All @@ -247,7 +231,7 @@ func (r *customModelResource) Schema(ctx context.Context, request resource.Schem
},
},
names.AttrVPCConfig: schema.ListNestedBlock{
CustomType: fwtypes.NewListNestedObjectTypeOf[customModelVPCConfigModel](ctx),
CustomType: fwtypes.NewListNestedObjectTypeOf[vpcConfigModel](ctx),
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplace(),
},
Expand Down Expand Up @@ -323,8 +307,8 @@ func (r *customModelResource) Create(ctx context.Context, request resource.Creat
data.CustomModelARN = fwflex.StringToFramework(ctx, job.OutputModelArn)
data.JobARN = fwflex.StringToFramework(ctx, job.JobArn)
data.JobStatus = fwtypes.StringEnumValue(job.Status)
data.TrainingMetrics = fwtypes.NewListNestedObjectValueOfNull[customModelTrainingMetricsModel](ctx)
data.ValidationMetrics = fwtypes.NewListNestedObjectValueOfNull[customModelValidationMetricsModel](ctx)
data.TrainingMetrics = fwtypes.NewListNestedObjectValueOfNull[trainingMetricsModel](ctx)
data.ValidationMetrics = fwtypes.NewListNestedObjectValueOfNull[validatorMetricModel](ctx)
data.setID()

response.Diagnostics.Append(response.State.Set(ctx, &data)...)
Expand Down Expand Up @@ -617,26 +601,26 @@ func waitModelCustomizationJobStopped(ctx context.Context, conn *bedrock.Client,
}

type customModelResourceModel struct {
BaseModelIdentifier fwtypes.ARN `tfsdk:"base_model_identifier"`
CustomModelARN types.String `tfsdk:"custom_model_arn"`
CustomModelKmsKeyID fwtypes.ARN `tfsdk:"custom_model_kms_key_id"`
CustomModelName types.String `tfsdk:"custom_model_name"`
CustomizationType fwtypes.StringEnum[awstypes.CustomizationType] `tfsdk:"customization_type"`
HyperParameters fwtypes.MapValueOf[types.String] `tfsdk:"hyperparameters"`
ID types.String `tfsdk:"id"`
JobARN types.String `tfsdk:"job_arn"`
JobName types.String `tfsdk:"job_name"`
JobStatus fwtypes.StringEnum[awstypes.ModelCustomizationJobStatus] `tfsdk:"job_status"`
OutputDataConfig fwtypes.ListNestedObjectValueOf[customModelOutputDataConfigModel] `tfsdk:"output_data_config"`
RoleARN fwtypes.ARN `tfsdk:"role_arn"`
Tags tftags.Map `tfsdk:"tags"`
TagsAll tftags.Map `tfsdk:"tags_all"`
Timeouts timeouts.Value `tfsdk:"timeouts"`
TrainingDataConfig fwtypes.ListNestedObjectValueOf[customModelTrainingDataConfigModel] `tfsdk:"training_data_config"`
TrainingMetrics fwtypes.ListNestedObjectValueOf[customModelTrainingMetricsModel] `tfsdk:"training_metrics"`
ValidationDataConfig fwtypes.ListNestedObjectValueOf[customModelValidationDataConfigModel] `tfsdk:"validation_data_config"`
ValidationMetrics fwtypes.ListNestedObjectValueOf[customModelValidationMetricsModel] `tfsdk:"validation_metrics"`
VPCConfig fwtypes.ListNestedObjectValueOf[customModelVPCConfigModel] `tfsdk:"vpc_config"`
BaseModelIdentifier fwtypes.ARN `tfsdk:"base_model_identifier"`
CustomModelARN types.String `tfsdk:"custom_model_arn"`
CustomModelKmsKeyID fwtypes.ARN `tfsdk:"custom_model_kms_key_id"`
CustomModelName types.String `tfsdk:"custom_model_name"`
CustomizationType fwtypes.StringEnum[awstypes.CustomizationType] `tfsdk:"customization_type"`
HyperParameters fwtypes.MapValueOf[types.String] `tfsdk:"hyperparameters"`
ID types.String `tfsdk:"id"`
JobARN types.String `tfsdk:"job_arn"`
JobName types.String `tfsdk:"job_name"`
JobStatus fwtypes.StringEnum[awstypes.ModelCustomizationJobStatus] `tfsdk:"job_status"`
OutputDataConfig fwtypes.ListNestedObjectValueOf[outputDataConfigModel] `tfsdk:"output_data_config"`
RoleARN fwtypes.ARN `tfsdk:"role_arn"`
Tags tftags.Map `tfsdk:"tags"`
TagsAll tftags.Map `tfsdk:"tags_all"`
Timeouts timeouts.Value `tfsdk:"timeouts"`
TrainingDataConfig fwtypes.ListNestedObjectValueOf[trainingDataConfigModel] `tfsdk:"training_data_config"`
TrainingMetrics fwtypes.ListNestedObjectValueOf[trainingMetricsModel] `tfsdk:"training_metrics"`
ValidationDataConfig fwtypes.ListNestedObjectValueOf[validationDataConfigModel] `tfsdk:"validation_data_config"`
ValidationMetrics fwtypes.ListNestedObjectValueOf[validatorMetricModel] `tfsdk:"validation_metrics"`
VPCConfig fwtypes.ListNestedObjectValueOf[vpcConfigModel] `tfsdk:"vpc_config"`
}

func (data *customModelResourceModel) InitFromID() error {
Expand All @@ -649,31 +633,31 @@ func (data *customModelResourceModel) setID() {
data.ID = data.JobARN
}

type customModelOutputDataConfigModel struct {
type outputDataConfigModel struct {
S3URI types.String `tfsdk:"s3_uri"`
}

type customModelTrainingDataConfigModel struct {
type trainingDataConfigModel struct {
S3URI types.String `tfsdk:"s3_uri"`
}

type customModelTrainingMetricsModel struct {
type trainingMetricsModel struct {
TrainingLoss types.Float64 `tfsdk:"training_loss"`
}

type customModelValidationDataConfigModel struct {
Validators fwtypes.ListNestedObjectValueOf[customModelValidatorConfigModel] `tfsdk:"validator"`
type validationDataConfigModel struct {
Validators fwtypes.ListNestedObjectValueOf[validatorModel] `tfsdk:"validator"`
}

type customModelValidationMetricsModel struct {
type validatorMetricModel struct {
ValidationLoss types.Float64 `tfsdk:"validation_loss"`
}

type customModelValidatorConfigModel struct {
type validatorModel struct {
S3URI types.String `tfsdk:"s3_uri"`
}

type customModelVPCConfigModel struct {
SecurityGroupIDs fwtypes.SetValueOf[types.String] `tfsdk:"security_group_ids"`
SubnetIDs fwtypes.SetValueOf[types.String] `tfsdk:"subnet_ids"`
type vpcConfigModel struct {
SecurityGroupIDs fwtypes.SetOfString `tfsdk:"security_group_ids"`
SubnetIDs fwtypes.SetOfString `tfsdk:"subnet_ids"`
}
Loading

0 comments on commit 7f260d9

Please sign in to comment.