Skip to content

Commit

Permalink
azuread_conditional_access_policy: improve handling of the `session_c…
Browse files Browse the repository at this point in the history
…ontrols` block

- Make `sign_in_frequency_authentication_type` and
  `sign_in_frequency_internal` both Optional + Computed and remove their
  default values.
- Handle the setting of these default values in the
  `expandConditionalAccessSessionControls()` function.
- Expand test coverage to all reasonable permutatons of the properties
  in this block to ensure no trailing diff or incorrect setting of
  values in the request.
  • Loading branch information
manicminer committed May 16, 2024
1 parent 0598828 commit dc82167
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ func conditionalAccessPolicyResource() *pluginsdk.Resource {
"sign_in_frequency_authentication_type": {
Type: pluginsdk.TypeString,
Optional: true,
Default: msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication,
Computed: true,
ValidateFunc: validation.StringInSlice([]string{
msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication,
msgraph.ConditionalAccessAuthenticationTypeSecondaryAuthentication,
Expand All @@ -592,7 +592,7 @@ func conditionalAccessPolicyResource() *pluginsdk.Resource {
"sign_in_frequency_interval": {
Type: pluginsdk.TypeString,
Optional: true,
Default: msgraph.ConditionalAccessFrequencyIntervalTimeBased,
Computed: true,
ValidateFunc: validation.StringInSlice([]string{
msgraph.ConditionalAccessFrequencyIntervalTimeBased,
msgraph.ConditionalAccessFrequencyIntervalEveryTime,
Expand Down Expand Up @@ -637,12 +637,14 @@ func conditionalAccessPolicyCustomizeDiff(_ context.Context, diff *pluginsdk.Res
func conditionalAccessPolicyDiffSuppress(k, old, new string, d *pluginsdk.ResourceData) bool {
suppress := false

// When ineffectual `session_controls` are specified, you must send `sessionControls: null`, and when policy has ineffectual
// `sessionControls`, the API condenses it to `sessionControls: null` in the response.
if k == "session_controls.#" && old == "0" && new == "1" {
// When an ineffectual `session_controls` block is configured, the API just ignores it and returns
// sessionControls: null
sessionControlsRaw := d.Get("session_controls").([]interface{})
if len(sessionControlsRaw) == 1 && sessionControlsRaw[0] != nil {
sessionControls := sessionControlsRaw[0].(map[string]interface{})

// Suppress by default, but only if all the block properties have a non-default value
suppress = true
if v, ok := sessionControls["application_enforced_restrictions_enabled"]; ok && v.(bool) {
suppress = false
Expand All @@ -659,10 +661,10 @@ func conditionalAccessPolicyDiffSuppress(k, old, new string, d *pluginsdk.Resour
if v, ok := sessionControls["sign_in_frequency"]; ok && v.(int) > 0 {
suppress = false
}
if v, ok := sessionControls["sign_in_frequency_authentication_type"]; ok && v.(string) != msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication {
if v, ok := sessionControls["sign_in_frequency_authentication_type"]; ok && v.(string) != "" {
suppress = false
}
if v, ok := sessionControls["sign_in_frequency_interval"]; ok && v.(string) != msgraph.ConditionalAccessFrequencyIntervalTimeBased {
if v, ok := sessionControls["sign_in_frequency_interval"]; ok && v.(string) != "" {
suppress = false
}
if v, ok := sessionControls["sign_in_frequency_period"]; ok && v.(string) != "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,6 @@ func TestAccConditionalAccessPolicy_sessionControls(t *testing.T) {
),
},
data.ImportStep(),
})
}

func TestAccConditionalAccessPolicy_sessionControlsDisabled(t *testing.T) {
// This is testing the DiffSuppressFunc for the `session_controls` block

data := acceptance.BuildTestData(t, "azuread_conditional_access_policy", "test")
r := ConditionalAccessPolicyResource{}

data.ResourceTest(t, r, []acceptance.TestStep{
{
Config: r.sessionControlsDisabled(data),
Check: acceptance.ComposeTestCheckFunc(
Expand Down Expand Up @@ -199,6 +189,46 @@ func TestAccConditionalAccessPolicy_sessionControlsDisabled(t *testing.T) {
),
},
data.ImportStep(),
{
Config: r.sessionControlsApplicationEnforcedRestrictions(data),
Check: acceptance.ComposeTestCheckFunc(
check.That(data.ResourceName).ExistsInAzure(r),
check.That(data.ResourceName).Key("id").Exists(),
check.That(data.ResourceName).Key("display_name").HasValue(fmt.Sprintf("acctest-CONPOLICY-%d", data.RandomInteger)),
check.That(data.ResourceName).Key("state").HasValue("disabled"),
),
},
data.ImportStep(),
{
Config: r.sessionControlsCloudAppSecurityPolicy(data),
Check: acceptance.ComposeTestCheckFunc(
check.That(data.ResourceName).ExistsInAzure(r),
check.That(data.ResourceName).Key("id").Exists(),
check.That(data.ResourceName).Key("display_name").HasValue(fmt.Sprintf("acctest-CONPOLICY-%d", data.RandomInteger)),
check.That(data.ResourceName).Key("state").HasValue("disabled"),
),
},
data.ImportStep(),
{
Config: r.sessionControlsPersistentBrowserMode(data),
Check: acceptance.ComposeTestCheckFunc(
check.That(data.ResourceName).ExistsInAzure(r),
check.That(data.ResourceName).Key("id").Exists(),
check.That(data.ResourceName).Key("display_name").HasValue(fmt.Sprintf("acctest-CONPOLICY-%d", data.RandomInteger)),
check.That(data.ResourceName).Key("state").HasValue("disabled"),
),
},
data.ImportStep(),
{
Config: r.sessionControlsDisabled(data),
Check: acceptance.ComposeTestCheckFunc(
check.That(data.ResourceName).ExistsInAzure(r),
check.That(data.ResourceName).Key("id").Exists(),
check.That(data.ResourceName).Key("display_name").HasValue(fmt.Sprintf("acctest-CONPOLICY-%d", data.RandomInteger)),
check.That(data.ResourceName).Key("state").HasValue("disabled"),
),
},
data.ImportStep(),
})
}

Expand Down Expand Up @@ -302,6 +332,11 @@ func TestAccConditionalAccessPolicy_guestsOrExternalUsers(t *testing.T) {
}

func (r ConditionalAccessPolicyResource) Exists(ctx context.Context, clients *clients.Client, state *pluginsdk.InstanceState) (*bool, error) {
clients.ConditionalAccess.PoliciesClient.BaseClient.DisableRetries = true
defer func() {
clients.ConditionalAccess.PoliciesClient.BaseClient.DisableRetries = false
}()

var id *string

app, status, err := clients.ConditionalAccess.PoliciesClient.Get(ctx, state.ID, odata.Query{})
Expand Down Expand Up @@ -523,6 +558,129 @@ resource "azuread_conditional_access_policy" "test" {
`, data.RandomInteger)
}

func (ConditionalAccessPolicyResource) sessionControlsApplicationEnforcedRestrictions(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azuread" {}
resource "azuread_conditional_access_policy" "test" {
display_name = "acctest-CONPOLICY-%[1]d"
state = "disabled"
conditions {
client_app_types = ["browser"]
applications {
included_applications = ["All"]
}
locations {
included_locations = ["All"]
}
platforms {
included_platforms = ["all"]
}
users {
included_users = ["All"]
excluded_users = ["GuestsOrExternalUsers"]
}
}
grant_controls {
operator = "OR"
built_in_controls = ["block"]
}
session_controls {
application_enforced_restrictions_enabled = true
}
}
`, data.RandomInteger)
}

func (ConditionalAccessPolicyResource) sessionControlsCloudAppSecurityPolicy(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azuread" {}
resource "azuread_conditional_access_policy" "test" {
display_name = "acctest-CONPOLICY-%[1]d"
state = "disabled"
conditions {
client_app_types = ["browser"]
applications {
included_applications = ["All"]
}
locations {
included_locations = ["All"]
}
platforms {
included_platforms = ["all"]
}
users {
included_users = ["All"]
excluded_users = ["GuestsOrExternalUsers"]
}
}
grant_controls {
operator = "OR"
built_in_controls = ["block"]
}
session_controls {
cloud_app_security_policy = "monitorOnly"
}
}
`, data.RandomInteger)
}

func (ConditionalAccessPolicyResource) sessionControlsPersistentBrowserMode(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azuread" {}
resource "azuread_conditional_access_policy" "test" {
display_name = "acctest-CONPOLICY-%[1]d"
state = "disabled"
conditions {
client_app_types = ["browser"]
applications {
included_applications = ["All"]
}
locations {
included_locations = ["All"]
}
platforms {
included_platforms = ["all"]
}
users {
included_users = ["All"]
excluded_users = ["GuestsOrExternalUsers"]
}
}
grant_controls {
operator = "OR"
built_in_controls = ["block"]
}
session_controls {
persistent_browser_mode = "always"
}
}
`, data.RandomInteger)
}

func (ConditionalAccessPolicyResource) clientApplicationsIncluded(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azuread" {}
Expand Down
13 changes: 9 additions & 4 deletions internal/services/conditionalaccess/conditionalaccess.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,24 @@ func expandConditionalAccessSessionControls(in []interface{}) *msgraph.Condition
signInFrequency.IsEnabled = pointer.To(true)
signInFrequency.Type = pointer.To(config["sign_in_frequency_period"].(string))
signInFrequency.Value = pointer.To(int32(frequencyValue))

// AuthenticationType and FrequencyInterval must be set to default values here
signInFrequency.AuthenticationType = pointer.To(msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication)
signInFrequency.FrequencyInterval = pointer.To(msgraph.ConditionalAccessFrequencyIntervalTimeBased)
}

if authenticationType, ok := config["sign_in_frequency_authentication_type"]; ok {
if authenticationType, ok := config["sign_in_frequency_authentication_type"]; ok && authenticationType.(string) != "" {
signInFrequency.AuthenticationType = pointer.To(authenticationType.(string))
}

if interval, ok := config["sign_in_frequency_interval"]; ok {
if interval, ok := config["sign_in_frequency_interval"]; ok && interval.(string) != "" {
signInFrequency.FrequencyInterval = pointer.To(interval.(string))
}

// API returns 400 error if signInFrequency is set with all default/zero values
if pointer.From(signInFrequency.IsEnabled) || pointer.From(signInFrequency.FrequencyInterval) != msgraph.ConditionalAccessFrequencyIntervalTimeBased ||
pointer.From(signInFrequency.AuthenticationType) != msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication {
if (signInFrequency.IsEnabled != nil && *signInFrequency.IsEnabled) ||
(signInFrequency.FrequencyInterval != nil && *signInFrequency.FrequencyInterval != msgraph.ConditionalAccessFrequencyIntervalTimeBased) ||
(signInFrequency.AuthenticationType != nil && *signInFrequency.AuthenticationType != msgraph.ConditionalAccessAuthenticationTypePrimaryAndSecondaryAuthentication) {
result.SignInFrequency = &signInFrequency
}

Expand Down

0 comments on commit dc82167

Please sign in to comment.