Skip to content

Commit

Permalink
Converts source from List to Set and adds test for handling multipl…
Browse files Browse the repository at this point in the history
…e sources
  • Loading branch information
gdavison committed May 2, 2024
1 parent 559461f commit c334544
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 62 deletions.
14 changes: 8 additions & 6 deletions internal/service/securitylake/securitylake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ func TestAccSecurityLake_serial(t *testing.T) {
"replication": testAccDataLake_replication,
},
"Subscriber": {
"accessType": testAccSubscriber_accessType,
"basic": testAccSubscriber_basic,
"customLogs": testAccSubscriber_customLogSource,
"disappears": testAccSubscriber_disappears,
"tags": testAccSubscriber_tags,
"updated": testAccSubscriber_update,
"accessType": testAccSubscriber_accessType,
"basic": testAccSubscriber_basic,
"customLogs": testAccSubscriber_customLogSource,
"disappears": testAccSubscriber_disappears,
"multipleSources": testAccSubscriber_multipleSources,
"tags": testAccSubscriber_tags,
"updated": testAccSubscriber_update,
"migrateSource": testAccSubscriber_migrate_source,
},
}

Expand Down
101 changes: 66 additions & 35 deletions internal/service/securitylake/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/hashicorp/aws-sdk-go-base/v2/tfawserr"
"github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts"
"github.com/hashicorp/terraform-plugin-framework-validators/listvalidator"
"github.com/hashicorp/terraform-plugin-framework-validators/setvalidator"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
Expand All @@ -32,6 +33,7 @@ import (
"github.com/hashicorp/terraform-provider-aws/internal/framework"
fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex"
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
"github.com/hashicorp/terraform-provider-aws/internal/slices"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
"github.com/hashicorp/terraform-provider-aws/names"
Expand Down Expand Up @@ -100,10 +102,10 @@ func (r *subscriberResource) Schema(ctx context.Context, request resource.Schema
names.AttrTagsAll: tftags.TagsAttributeComputedOnly(),
},
Blocks: map[string]schema.Block{
"source": schema.ListNestedBlock{
Validators: []validator.List{
listvalidator.IsRequired(),
listvalidator.SizeAtLeast(1),
"source": schema.SetNestedBlock{
Validators: []validator.Set{
setvalidator.IsRequired(),
setvalidator.SizeAtLeast(1),
},
NestedObject: schema.NestedBlockObject{
Blocks: map[string]schema.Block{
Expand Down Expand Up @@ -242,7 +244,7 @@ func (r *subscriberResource) Create(ctx context.Context, request resource.Create

subscriber, err := waitSubscriberCreated(ctx, conn, data.ID.ValueString(), r.CreateTimeout(ctx, data.Timeouts))
if err != nil {
response.Diagnostics.AddError(fmt.Sprintf("waiting for Security Lake Subscriber (%s) create", data.ID.ValueString()), err.Error())
response.Diagnostics.AddError(fmt.Sprintf("waiting for Security Lake Subscriber (%s) to be created", data.ID.ValueString()), err.Error())

return
}
Expand Down Expand Up @@ -347,16 +349,20 @@ func (r *subscriberResource) Update(ctx context.Context, request resource.Update
return
}

_, err = waitSubscriberUpdated(ctx, conn, new.ID.ValueString(), r.CreateTimeout(ctx, new.Timeouts))
subscriber, err := waitSubscriberUpdated(ctx, conn, new.ID.ValueString(), r.CreateTimeout(ctx, new.Timeouts))
if err != nil {
response.Diagnostics.AddError(fmt.Sprintf("waiting for Security Lake Subscriber (%s) create", new.ID.ValueString()), err.Error())
response.Diagnostics.AddError(fmt.Sprintf("waiting for Security Lake Subscriber (%s) to be updated", new.ID.ValueString()), err.Error())

return
}

var subscriberIdentity subscriberIdentityModel
response.Diagnostics.Append(fwflex.Flatten(ctx, subscriber.SubscriberIdentity, &subscriberIdentity)...)
if response.Diagnostics.HasError() {
return
}

new.ResourceShareArn = old.ResourceShareArn
new.ResourceShareName = old.ResourceShareName
new.SubscriberEndpoint = old.SubscriberEndpoint
response.Diagnostics.Append(new.refreshFromOutput(ctx, subscriberIdentity, subscriber)...)
}

response.Diagnostics.Append(response.State.Set(ctx, &new)...)
Expand Down Expand Up @@ -510,7 +516,7 @@ func expandSubscriptionValueSources(ctx context.Context, subscriberSourcesModels
if !item.AwsLogSourceResource.IsNull() && (len(item.AwsLogSourceResource.Elements()) > 0) {
var awsLogSources []awsLogSubscriberSourceModel
diags.Append(item.AwsLogSourceResource.ElementsAs(ctx, &awsLogSources, false)...)
subscriberLogSource := expandSubscriberLogSourceSource(ctx, awsLogSources)
subscriberLogSource := expandSubscriberAwsLogSourceSource(ctx, awsLogSources)
sources = append(sources, subscriberLogSource)
}
if (!item.CustomLogSourceResource.IsNull()) && (len(item.CustomLogSourceResource.Elements()) > 0) {
Expand All @@ -524,7 +530,7 @@ func expandSubscriptionValueSources(ctx context.Context, subscriberSourcesModels
return sources, diags
}

func expandSubscriberLogSourceSource(ctx context.Context, awsLogSources []awsLogSubscriberSourceModel) *awstypes.LogSourceResourceMemberAwsLogSource {
func expandSubscriberAwsLogSourceSource(ctx context.Context, awsLogSources []awsLogSubscriberSourceModel) *awstypes.LogSourceResourceMemberAwsLogSource {
if len(awsLogSources) == 0 {
return nil
}
Expand Down Expand Up @@ -553,38 +559,63 @@ func expandSubscriberCustomLogSourceSource(ctx context.Context, customLogSources
return customLogSourceResource
}

func flattenSubscriberSourcesModel(ctx context.Context, apiObject []awstypes.LogSourceResource) (types.List, diag.Diagnostics) {
func flattenSubscriberSources(ctx context.Context, apiObject []awstypes.LogSourceResource) (types.Set, diag.Diagnostics) {
var diags diag.Diagnostics
elemType := types.ObjectType{AttrTypes: subscriberSourcesModelAttrTypes}
result := types.SetNull(elemType)

obj := map[string]attr.Value{}
var elems []types.Object

for _, item := range apiObject {
switch v := item.(type) {
case *awstypes.LogSourceResourceMemberAwsLogSource:
subscriberLogSource, d := flattenSubscriberLogSourceResourceModel(ctx, &v.Value, nil, "aws")
diags.Append(d...)
obj = map[string]attr.Value{
"aws_log_source_resource": subscriberLogSource,
"custom_log_source_resource": types.ListNull(customLogSubscriberSourceModelAttrTypes),
}
case *awstypes.LogSourceResourceMemberCustomLogSource:
subscriberLogSource, d := flattenSubscriberLogSourceResourceModel(ctx, nil, &v.Value, "custom")
diags.Append(d...)
obj = map[string]attr.Value{
"aws_log_source_resource": types.ListNull(logSubscriberSourcesModelAttrTypes),
"custom_log_source_resource": subscriberLogSource,
}
elem, d := flattenSubscriberSourcesModel(ctx, item)
diags.Append(d...)
if d.HasError() {
return result, diags
}
elems = append(elems, elem)
}

objVal, d := types.ObjectValue(subscriberSourcesModelAttrTypes, obj)
setVal, d := types.SetValue(elemType, slices.ApplyToAll(elems, func(o types.Object) attr.Value {
return o
}))
diags.Append(d...)

listVal, d := types.ListValue(elemType, []attr.Value{objVal})
return setVal, diags
}

func flattenSubscriberSourcesModel(ctx context.Context, apiObject awstypes.LogSourceResource) (types.Object, diag.Diagnostics) {
var diags diag.Diagnostics
result := types.ObjectUnknown(subscriberSourcesModelAttrTypes)

obj := map[string]attr.Value{}

switch v := apiObject.(type) {
case *awstypes.LogSourceResourceMemberAwsLogSource:
subscriberLogSource, d := flattenSubscriberLogSourceResourceModel(ctx, &v.Value, nil, "aws")
diags.Append(d...)
if d.HasError() {
return result, diags
}
obj = map[string]attr.Value{
"aws_log_source_resource": subscriberLogSource,
"custom_log_source_resource": types.ListNull(customLogSubscriberSourceModelAttrTypes),
}
case *awstypes.LogSourceResourceMemberCustomLogSource:
subscriberLogSource, d := flattenSubscriberLogSourceResourceModel(ctx, nil, &v.Value, "custom")
diags.Append(d...)
if d.HasError() {
return result, diags
}
obj = map[string]attr.Value{
"aws_log_source_resource": types.ListNull(logSubscriberSourcesModelAttrTypes),
"custom_log_source_resource": subscriberLogSource,
}
}

result, d := types.ObjectValue(subscriberSourcesModelAttrTypes, obj)
diags.Append(d...)

return listVal, diags
return result, diags
}

func flattenSubscriberLogSourceResourceModel(ctx context.Context, awsLogApiObject *awstypes.AwsLogSourceResource, customLogApiObject *awstypes.CustomLogSourceResource, logSourceType string) (types.List, diag.Diagnostics) {
Expand Down Expand Up @@ -706,7 +737,7 @@ type subscriberResourceModel struct {
AccessTypes types.String `tfsdk:"access_type"`
SubscriberArn types.String `tfsdk:"arn"`
ID types.String `tfsdk:"id"`
Sources types.List `tfsdk:"source"`
Sources types.Set `tfsdk:"source"`
SubscriberDescription types.String `tfsdk:"subscriber_description"`
SubscriberIdentity fwtypes.ListNestedObjectValueOf[subscriberIdentityModel] `tfsdk:"subscriber_identity"`
SubscriberName types.String `tfsdk:"subscriber_name"`
Expand Down Expand Up @@ -764,14 +795,14 @@ func (rd *subscriberResourceModel) refreshFromOutput(ctx context.Context, subscr

rd.AccessTypes = fwflex.StringValueToFramework(ctx, subscriber.AccessTypes[0])
rd.SubscriberIdentity = fwtypes.NewListNestedObjectValueOfPtrMust(ctx, &subscriberIdentity)
sourcesOutput, d := flattenSubscriberSourcesModel(ctx, subscriber.Sources)
diags.Append(d...)
rd.ResourceShareArn = fwflex.StringToFrameworkLegacy(ctx, subscriber.ResourceShareArn)
rd.ResourceShareName = fwflex.StringToFramework(ctx, subscriber.ResourceShareName)
rd.S3BucketArn = fwflex.StringToFramework(ctx, subscriber.S3BucketArn)
rd.SubscriberEndpoint = fwflex.StringToFramework(ctx, subscriber.SubscriberEndpoint)
rd.SubscriberStatus = fwflex.StringValueToFramework(ctx, subscriber.SubscriberStatus)
rd.RoleArn = fwflex.StringToFramework(ctx, subscriber.RoleArn)
sourcesOutput, d := flattenSubscriberSources(ctx, subscriber.Sources)
diags.Append(d...)
rd.Sources = sourcesOutput
rd.SubscriberName = fwflex.StringToFramework(ctx, subscriber.SubscriberName)
rd.SubscriberDescription = fwflex.StringToFramework(ctx, subscriber.SubscriberDescription)
Expand Down
Loading

0 comments on commit c334544

Please sign in to comment.