Skip to content

Commit

Permalink
[wip] add cancellation token support (#1242)
Browse files Browse the repository at this point in the history
* add cancellation token support

* suggestion from JM

* update xml comments
  • Loading branch information
jennyf19 authored Jun 7, 2021
1 parent 15d3b6c commit b348c48
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 16 deletions.
79 changes: 78 additions & 1 deletion src/Microsoft.Identity.Web/Microsoft.Identity.Web.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions src/Microsoft.Identity.Web/TokenAcquisition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Net.Http;
using System.Security.Claims;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Authentication.OAuth;
Expand Down Expand Up @@ -362,7 +363,7 @@ public Task<AuthenticationResult> GetAuthenticationResultForAppAsync(

try
{
return builder.ExecuteAsync();
return builder.ExecuteAsync(tokenAcquisitionOptions != null ? tokenAcquisitionOptions.CancellationToken : CancellationToken.None);
}
catch (MsalServiceException exMsal) when (IsInvalidClientCertificateError(exMsal))
{
Expand Down Expand Up @@ -746,7 +747,7 @@ private IConfidentialClientApplication BuildConfidentialClientApplication(Merged
}
}

return await builder.ExecuteAsync()
return await builder.ExecuteAsync(tokenAcquisitionOptions != null ? tokenAcquisitionOptions.CancellationToken : CancellationToken.None)
.ConfigureAwait(false);
}

Expand Down Expand Up @@ -868,7 +869,7 @@ private Task<AuthenticationResult> GetAuthenticationResultForWebAppWithAccountFr
builder.WithAuthority(authority);
}

return builder.ExecuteAsync();
return builder.ExecuteAsync(tokenAcquisitionOptions != null ? tokenAcquisitionOptions.CancellationToken : CancellationToken.None);
}

private static bool AcceptedTokenVersionMismatch(MsalUiRequiredException msalServiceException)
Expand Down
8 changes: 8 additions & 0 deletions src/Microsoft.Identity.Web/TokenAcquisitionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Threading;
using Microsoft.Identity.Client.AppConfig;

namespace Microsoft.Identity.Web
Expand Down Expand Up @@ -46,6 +47,11 @@ public class TokenAcquisitionOptions
/// </summary>
public PoPAuthenticationConfiguration? PoPConfiguration { get; set; }

/// <summary>
/// Cancellation token to be used when calling the token acquisition methods.
/// </summary>
public CancellationToken CancellationToken { get; set; } = CancellationToken.None;

/// <summary>
/// Clone the options (to be able to override them).
/// </summary>
Expand All @@ -58,6 +64,8 @@ public TokenAcquisitionOptions Clone()
ExtraQueryParameters = ExtraQueryParameters,
ForceRefresh = ForceRefresh,
Claims = Claims,
PoPConfiguration = PoPConfiguration,
CancellationToken = CancellationToken,
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Threading;

namespace Microsoft.Identity.Web.TokenCacheProviders
{
/// <summary>
/// Set of properties that the token cache serialization implementations might use to optimize the cache.
/// </summary>
public class CacheSerializerHints
{
/// <summary>
/// CancellationToken enabling cooperative cancellation between threads, thread pool, or Task objects.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
Expand Down Expand Up @@ -71,6 +72,18 @@ public MsalDistributedTokenCacheAdapter(
/// <param name="cacheKey">Key of the cache to remove.</param>
/// <returns>A <see cref="Task"/> that completes when key removal has completed.</returns>
protected override async Task RemoveKeyAsync(string cacheKey)
{
await RemoveKeyAsync(cacheKey, new CacheSerializerHints()).ConfigureAwait(false);
}

/// <summary>
/// Removes a specific token cache, described by its cache key
/// from the distributed cache.
/// </summary>
/// <param name="cacheKey">Key of the cache to remove.</param>
/// <param name="cacheSerializerHints">Hints for the cache serialization implementation optimization.</param>
/// <returns>A <see cref="Task"/> that completes when key removal has completed.</returns>
protected override async Task RemoveKeyAsync(string cacheKey, CacheSerializerHints cacheSerializerHints)
{
string remove = "Remove";
_memoryCache.Remove(cacheKey);
Expand All @@ -79,7 +92,7 @@ protected override async Task RemoveKeyAsync(string cacheKey)

await L2OperationWithRetryOnFailureAsync(
remove,
(cacheKey) => _distributedCache.RemoveAsync(cacheKey),
(cacheKey) => _distributedCache.RemoveAsync(cacheKey, cacheSerializerHints.CancellationToken),
cacheKey).ConfigureAwait(false);
}

Expand All @@ -91,6 +104,19 @@ await L2OperationWithRetryOnFailureAsync(
/// <returns>Read blob representing a token cache for the cache key
/// (account or app).</returns>
protected override async Task<byte[]> ReadCacheBytesAsync(string cacheKey)
{
return await ReadCacheBytesAsync(cacheKey, new CacheSerializerHints()).ConfigureAwait(false);
}

/// <summary>
/// Read a specific token cache, described by its cache key, from the
/// distributed cache.
/// </summary>
/// <param name="cacheKey">Key of the cache item to retrieve.</param>
/// <param name="cacheSerializerHints">Hints for the cache serialization implementation optimization.</param>
/// <returns>Read blob representing a token cache for the cache key
/// (account or app).</returns>
protected override async Task<byte[]> ReadCacheBytesAsync(string cacheKey, CacheSerializerHints cacheSerializerHints)
{
string read = "Read";
// check memory cache first
Expand All @@ -99,14 +125,17 @@ protected override async Task<byte[]> ReadCacheBytesAsync(string cacheKey)

if (result == null)
{
var measure = await Task.Run(async () =>
var measure = await Task.Run(
async () =>
{
// not found in memory, check distributed cache
result = await L2OperationWithRetryOnFailureAsync(
read,
(cacheKey) => _distributedCache.GetAsync(cacheKey),
(cacheKey) => _distributedCache.GetAsync(cacheKey, cacheSerializerHints.CancellationToken),
cacheKey).ConfigureAwait(false);
}).Measure().ConfigureAwait(false);
#pragma warning disable CA1062 // Validate arguments of public methods
}, cacheSerializerHints.CancellationToken).Measure().ConfigureAwait(false);
#pragma warning restore CA1062 // Validate arguments of public methods

Logger.DistributedCacheReadTime(_logger, _distributedCacheType, read, measure.MilliSeconds, null);

Expand All @@ -128,7 +157,7 @@ protected override async Task<byte[]> ReadCacheBytesAsync(string cacheKey)
{
await L2OperationWithRetryOnFailureAsync(
"Refresh",
(cacheKey) => _distributedCache.RefreshAsync(cacheKey),
(cacheKey) => _distributedCache.RefreshAsync(cacheKey, cacheSerializerHints.CancellationToken),
cacheKey,
result!).ConfigureAwait(false);
}
Expand All @@ -145,6 +174,21 @@ await L2OperationWithRetryOnFailureAsync(
/// <param name="bytes">blob to write.</param>
/// <returns>A <see cref="Task"/> that completes when a write operation has completed.</returns>
protected override async Task WriteCacheBytesAsync(string cacheKey, byte[] bytes)
{
await WriteCacheBytesAsync(cacheKey, bytes, new CacheSerializerHints()).ConfigureAwait(false);
}

/// <summary>
/// Writes a token cache blob to the serialization cache (by key).
/// </summary>
/// <param name="cacheKey">Cache key.</param>
/// <param name="bytes">blob to write.</param>
/// <param name="cacheSerializerHints">Hints for the cache serialization implementation optimization.</param>
/// <returns>A <see cref="Task"/> that completes when a write operation has completed.</returns>
protected override async Task WriteCacheBytesAsync(
string cacheKey,
byte[] bytes,
CacheSerializerHints cacheSerializerHints)
{
string write = "Write";
MemoryCacheEntryOptions memoryCacheEntryOptions = new MemoryCacheEntryOptions()
Expand All @@ -160,7 +204,7 @@ protected override async Task WriteCacheBytesAsync(string cacheKey, byte[] bytes

await L2OperationWithRetryOnFailureAsync(
write,
(cacheKey) => _distributedCache.SetAsync(cacheKey, bytes, _distributedCacheOptions),
(cacheKey) => _distributedCache.SetAsync(cacheKey, bytes, _distributedCacheOptions, cacheSerializerHints.CancellationToken),
cacheKey).Measure().ConfigureAwait(false);
}

Expand Down
Loading

0 comments on commit b348c48

Please sign in to comment.