Skip to content

Commit

Permalink
Add tenantId and scopes to TokenCacheNotificationArgs (#3401)
Browse files Browse the repository at this point in the history
* Add tenantId and scopes to TokenCacheNotificationArgs

* Update src/client/Microsoft.Identity.Client/TokenCacheNotificationArgs.cs

Co-authored-by: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com>

Co-authored-by: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com>
  • Loading branch information
bgavrilMS and gladjohn authored Jun 18, 2022
1 parent 552f186 commit 2d9afe4
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@ private async Task RefreshCacheForReadOperationsAsync()
var args = new TokenCacheNotificationArgs(
TokenCacheInternal,
_requestParams.AppConfig.ClientId,
_requestParams.Account,
_requestParams.Account,
hasStateChanged: false,
isApplicationCache: TokenCacheInternal.IsApplicationCache,
suggestedCacheKey: key,
hasTokens: TokenCacheInternal.HasTokensNoLocks(),
cancellationToken: _requestParams.RequestContext.UserCancellationToken,
suggestedCacheExpiry: null,
correlationId: _requestParams.RequestContext.CorrelationId);
correlationId: _requestParams.RequestContext.CorrelationId,
requestScopes: _requestParams.Scope,
requestTenantId: _requestParams.AuthorityManager.OriginalAuthority.TenantId);

stopwatch.Start();
await TokenCacheInternal.OnBeforeAccessAsync(args).ConfigureAwait(false);
Expand All @@ -147,7 +149,9 @@ private async Task RefreshCacheForReadOperationsAsync()
hasTokens: TokenCacheInternal.HasTokensNoLocks(),
cancellationToken: _requestParams.RequestContext.UserCancellationToken,
suggestedCacheExpiry: null,
correlationId: _requestParams.RequestContext.CorrelationId);
correlationId: _requestParams.RequestContext.CorrelationId,
requestScopes: _requestParams.Scope,
requestTenantId: _requestParams.AuthorityManager.OriginalAuthority.TenantId);

await TokenCacheInternal.OnAfterAccessAsync(args).ConfigureAwait(false);
RequestContext.ApiEvent.DurationInCacheInMs += stopwatch.ElapsedMilliseconds;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ private void ShowPickerWithSplashScreenImpl()
splash.DialogResult = System.Windows.Forms.DialogResult.OK;
splash.TopMost = true;

#pragma warning disable VSTHRD101 // Avoid unsupported async delegates - Windows API mandates this
splash.Shown += async (s, e) =>
{
var windowHandle = splash.Handle;
await ShowPickerForWin32WindowAsync(windowHandle).ConfigureAwait(true);
splash.Close();
};
#pragma warning restore VSTHRD101 // Avoid unsupported async delegates

try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ async Task<Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem, Account>> IToke
hasTokens: tokenCacheInternal.HasTokensNoLocks(),
suggestedCacheExpiry: null,
cancellationToken: requestParams.RequestContext.UserCancellationToken,
correlationId: requestParams.RequestContext.CorrelationId);
correlationId: requestParams.RequestContext.CorrelationId,
requestScopes: requestParams.Scope,
requestTenantId: requestParams.AuthorityManager.OriginalAuthority.TenantId);

Stopwatch sw = Stopwatch.StartNew();

Expand Down Expand Up @@ -233,7 +235,10 @@ async Task<Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem, Account>> IToke
hasTokens: tokenCacheInternal.HasTokensNoLocks(),
suggestedCacheExpiry: cacheExpiry,
cancellationToken: requestParams.RequestContext.UserCancellationToken,
correlationId: requestParams.RequestContext.CorrelationId);
correlationId: requestParams.RequestContext.CorrelationId,
requestScopes: requestParams.Scope,
requestTenantId: requestParams.AuthorityManager.OriginalAuthority.TenantId);


Stopwatch sw = Stopwatch.StartNew();
await tokenCacheInternal.OnAfterAccessAsync(args).ConfigureAwait(false);
Expand Down Expand Up @@ -721,7 +726,9 @@ internal async Task ExpireAllAccessTokensForTestAsync()
hasTokens: tokenCacheInternal.HasTokensNoLocks(),
suggestedCacheExpiry: null,
cancellationToken: default,
correlationId: default);
correlationId: default,
requestScopes: null,
requestTenantId: null);

await tokenCacheInternal.OnAfterAccessAsync(args).ConfigureAwait(false);
}
Expand Down Expand Up @@ -1138,7 +1145,10 @@ async Task ITokenCacheInternal.RemoveAccountAsync(IAccount account, Authenticati
hasTokens: tokenCacheInternal.HasTokensNoLocks(),
suggestedCacheExpiry: null,
cancellationToken: requestParameters.RequestContext.UserCancellationToken,
correlationId: requestParameters.RequestContext.CorrelationId);
correlationId: requestParameters.RequestContext.CorrelationId,
requestScopes: requestParameters.Scope,
requestTenantId: requestParameters.AuthorityManager.OriginalAuthority.TenantId);


await tokenCacheInternal.OnBeforeAccessAsync(args).ConfigureAwait(false);
await tokenCacheInternal.OnBeforeWriteAsync(args).ConfigureAwait(false);
Expand Down Expand Up @@ -1169,7 +1179,10 @@ async Task ITokenCacheInternal.RemoveAccountAsync(IAccount account, Authenticati
hasTokens: tokenCacheInternal.HasTokensNoLocks(),
suggestedCacheExpiry: null,
cancellationToken: requestParameters.RequestContext.UserCancellationToken,
correlationId: requestParameters.RequestContext.CorrelationId);
correlationId: requestParameters.RequestContext.CorrelationId,
requestScopes: requestParameters.Scope,
requestTenantId: requestParameters.AuthorityManager.OriginalAuthority.TenantId);


await tokenCacheInternal.OnAfterAccessAsync(args).ConfigureAwait(false);
}
Expand Down
58 changes: 56 additions & 2 deletions src/client/Microsoft.Identity.Client/TokenCacheNotificationArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using Microsoft.IdentityModel.Abstractions;

Expand Down Expand Up @@ -37,6 +39,8 @@ public TokenCacheNotificationArgs(
hasTokens,
suggestedCacheExpiry,
cancellationToken,
default,
default,
default)
{
}
Expand All @@ -54,7 +58,39 @@ public TokenCacheNotificationArgs(
bool hasTokens,
DateTimeOffset? suggestedCacheExpiry,
CancellationToken cancellationToken,
Guid correlationId)
Guid correlationId)
: this(tokenCache,
clientId,
account,
hasStateChanged,
isApplicationCache,
suggestedCacheKey,
hasTokens,
suggestedCacheExpiry,
cancellationToken,
correlationId,
default,
default)
{
}

/// <summary>
/// This constructor is for test purposes only. It allows apps to unit test their MSAL token cache implementation code.
/// </summary>
public TokenCacheNotificationArgs( // only use this constructor in product code
ITokenCacheSerializer tokenCache,
string clientId,
IAccount account,
bool hasStateChanged,
bool isApplicationCache,
string suggestedCacheKey,
bool hasTokens,
DateTimeOffset? suggestedCacheExpiry,
CancellationToken cancellationToken,
Guid correlationId,
IEnumerable<string> requestScopes,
string requestTenantId)

{
TokenCache = tokenCache;
ClientId = clientId;
Expand All @@ -65,6 +101,8 @@ public TokenCacheNotificationArgs(
HasTokens = hasTokens;
CancellationToken = cancellationToken;
CorrelationId = correlationId;
RequestScopes = requestScopes;
RequestTenantId = requestTenantId;
SuggestedCacheExpiry = suggestedCacheExpiry;
}

Expand Down Expand Up @@ -135,13 +173,29 @@ public TokenCacheNotificationArgs(
/// </summary>
public Guid CorrelationId { get; }

/// <summary>
/// Scopes specified in the AcquireToken* method.
/// </summary>
/// <remarks>
/// Note that Azure Active Directory may return more scopes than requested, however this property will only contain the scopes requested.
/// </remarks>
public IEnumerable<string> RequestScopes { get; }

/// <summary>
/// Tenant Id specified in the AcquireToken* method, if any.
/// </summary>
/// <remarks>
/// Note that if "common" or "organizations" is specified, Azure Active Directory discovers the host tenant for the user, and the tokens
/// are associated with it. This property is not impacted.</remarks>
public string RequestTenantId { get; }

/// <summary>
/// Suggested value of the expiry, to help determining the cache eviction time.
/// This value is <b>only</b> set on the <code>OnAfterAccess</code> delegate, on a cache write
/// operation (that is when <code>args.HasStateChanged</code> is <code>true</code>) and when the cache write
/// is triggered from the <code>AcquireTokenForClient</code> method. In all other cases it's <code>null</code>, as there is a refresh token, and therefore the
/// access tokens are refreshable.
/// </summary>
public DateTimeOffset? SuggestedCacheExpiry { get; private set; }
public DateTimeOffset? SuggestedCacheExpiry { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class TokenCacheAccessRecorder
public TokenCacheNotificationArgs LastBeforeWriteNotificationArgs { get; private set; }
public TokenCacheNotificationArgs LastAfterAccessNotificationArgs { get; private set; }

public TokenCacheAccessRecorder(TokenCache tokenCache)
public TokenCacheAccessRecorder(TokenCache tokenCache, Action<TokenCacheNotificationArgs> assertLogic = null)
{
_tokenCache = tokenCache;

Expand All @@ -38,6 +38,7 @@ public TokenCacheAccessRecorder(TokenCache tokenCache)
var existingBeforeAccessCallback = _tokenCache.BeforeAccess;
_tokenCache.BeforeAccess = (args) =>
{
assertLogic?.Invoke(args);
BeforeAccessCount++;
LastBeforeAccessNotificationArgs = args;
existingBeforeAccessCallback?.Invoke(args);
Expand All @@ -46,6 +47,7 @@ public TokenCacheAccessRecorder(TokenCache tokenCache)
var existingBeforeWriteCallback = _tokenCache.BeforeWrite;
_tokenCache.BeforeWrite = (args) =>
{
assertLogic?.Invoke(args);
BeforeWriteCount++;
LastBeforeWriteNotificationArgs = args;
Expand All @@ -55,6 +57,7 @@ public TokenCacheAccessRecorder(TokenCache tokenCache)
var existingAfterAccessCallback = _tokenCache.AfterAccess;
_tokenCache.AfterAccess = (args) =>
{
assertLogic?.Invoke(args);
AfterAccessTotalCount++;
LastAfterAccessNotificationArgs = args;
Expand All @@ -65,7 +68,6 @@ public TokenCacheAccessRecorder(TokenCache tokenCache)
existingAfterAccessCallback?.Invoke(args);
};

}

public void AssertAccessCounts(int expectedReads, int expectedWrites)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Cache;

namespace Microsoft.Identity.Test.Common.Core.Helpers
{
internal static class TokenCacheExtensions
{
public static TokenCacheAccessRecorder RecordAccess(this ITokenCache tokenCache)
public static TokenCacheAccessRecorder RecordAccess(this ITokenCache tokenCache, Action<TokenCacheNotificationArgs> assertLogic = null)
{
return new TokenCacheAccessRecorder(tokenCache as TokenCache);
return new TokenCacheAccessRecorder(tokenCache as TokenCache, assertLogic);
}

public static void ClearAccessTokens(this ITokenCacheAccessor accessor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ private async Task RunClientCredsAsync(Cloud cloud, CredentialType credentialTyp
Assert.IsTrue(appCacheRecorder.LastAfterAccessNotificationArgs.HasTokens);
Assert.AreEqual(correlationId, appCacheRecorder.LastAfterAccessNotificationArgs.CorrelationId);
Assert.AreEqual(correlationId, appCacheRecorder.LastBeforeAccessNotificationArgs.CorrelationId);
CollectionAssert.AreEquivalent(settings.AppScopes.ToArray(), appCacheRecorder.LastBeforeAccessNotificationArgs.RequestScopes.ToArray());
CollectionAssert.AreEquivalent(settings.AppScopes.ToArray(), appCacheRecorder.LastAfterAccessNotificationArgs.RequestScopes.ToArray());
Assert.AreEqual(settings.TenantId, appCacheRecorder.LastBeforeAccessNotificationArgs.RequestTenantId ?? "");
Assert.AreEqual(settings.TenantId, appCacheRecorder.LastAfterAccessNotificationArgs.RequestTenantId ?? "");
Assert.IsTrue(authResult.AuthenticationResultMetadata.DurationTotalInMs > 0);
Assert.IsTrue(authResult.AuthenticationResultMetadata.DurationInHttpInMs > 0);
Assert.AreEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ public async Task UnknownNodesTestAsync()
cache = notificationArgs.TokenCache.SerializeMsalV3();
});

var notification = new TokenCacheNotificationArgs(tokenCache, null, null, false, false, null, false, null, default);
var notification = new TokenCacheNotificationArgs(tokenCache, null, null, false, false, null, false, null, default, default, default, default);

await (tokenCache as ITokenCacheInternal).OnBeforeAccessAsync(notification).ConfigureAwait(false);
await (tokenCache as ITokenCacheInternal).OnAfterAccessAsync(notification).ConfigureAwait(false);
(tokenCache as ITokenCacheInternal).Accessor.AssertItemCount(5, 4, 3, 3, 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,5 +452,86 @@ public void IsSerializedTest()
Assert.IsTrue((cca.UserTokenCache as ITokenCacheInternal).IsAppSubscribedToSerializationEvents());

}

[TestMethod]
public async Task TokenCacheSerializationArgs_AppCache_TenantIdScopes_Async()
{
using (var harness = CreateTestHarness())
{

// Arrange
var cca = ConfidentialClientApplicationBuilder
.Create(TestConstants.ClientId)
.WithClientSecret(TestConstants.ClientSecret)
.WithHttpManager(harness.HttpManager)
.BuildConcrete();
CancellationTokenSource cts = new CancellationTokenSource();
var cancellationToken = cts.Token;

var appTokenCacheRecoder = cca.AppTokenCache.RecordAccess((args) =>
{
Assert.AreEqual(TestConstants.TenantId2, args.RequestTenantId);
Assert.AreEqual(TestConstants.ClientId, args.ClientId);
Assert.IsNull(args.Account);
Assert.IsTrue(args.IsApplicationCache);
Assert.AreEqual(cancellationToken, args.CancellationToken);
CollectionAssert.AreEquivalent(TestConstants.s_scope.ToArray(), args.RequestScopes.ToArray());
});

harness.HttpManager.AddAllMocks(TokenResponseType.Valid_ClientCredentials);

// Act - Client Credentials with authority override
await cca.AcquireTokenForClient(TestConstants.s_scope)
.WithTenantId(TestConstants.TenantId2)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);

appTokenCacheRecoder.AssertAccessCounts(1, 1);

}
}

[TestMethod]
public async Task TokenCacheSerializationArgs_UserCache_TenantIdScopes_Async()
{
string[] inputScope = new[] { "input_scope_different_than_aad_scope" };
using (var harness = CreateTestHarness())
{

// Arrange
var cca = ConfidentialClientApplicationBuilder
.Create(TestConstants.ClientId)
.WithClientSecret(TestConstants.ClientSecret)
.WithHttpManager(harness.HttpManager)
.BuildConcrete();
CancellationTokenSource cts = new CancellationTokenSource();
var cancellationToken = cts.Token;

var userCacheRecorder = cca.UserTokenCache.RecordAccess((args) =>
{
Assert.AreEqual(TestConstants.TenantId2, args.RequestTenantId);
Assert.AreEqual(TestConstants.ClientId, args.ClientId);
Assert.IsNotNull(args.Account);
Assert.IsFalse(args.IsApplicationCache);
Assert.AreEqual(cancellationToken, args.CancellationToken);
CollectionAssert.AreEquivalent(inputScope, args.RequestScopes.ToArray());
});

harness.HttpManager.AddAllMocks(TokenResponseType.Valid_UserFlows);

// Act - Client Credentials with authority override
var result = await cca.AcquireTokenByAuthorizationCode(inputScope, "code")
.WithTenantId(TestConstants.TenantId2)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);

userCacheRecorder.AssertAccessCounts(0, 1);

CollectionAssert.AreEquivalent(TestConstants.s_scope.ToArray(), result.Scopes.ToArray());

}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public void MultiThreadSuccessfulResponseFromLocalImds_HasOnlyOneImdsCall()
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region));
SemaphoreSlim semaphore = new SemaphoreSlim(0);
int threadCount = MaxThreadCount;
#pragma warning disable VSTHRD101 // Avoid unsupported async delegates - acceptable risk (crash the test proj)
var result = Parallel.For(0, MaxThreadCount, async (i) =>
{
try
Expand All @@ -115,8 +116,9 @@ public void MultiThreadSuccessfulResponseFromLocalImds_HasOnlyOneImdsCall()
{
Interlocked.Decrement(ref threadCount);
}
});

});
#pragma warning restore VSTHRD101 // Avoid unsupported async delegates

while (threadCount != 0)
{
Thread.Sleep(100);
Expand Down
Loading

0 comments on commit 2d9afe4

Please sign in to comment.