From e04faf6626cec1939793b6d4aa9d5f0c658056fc Mon Sep 17 00:00:00 2001 From: Serkant Karaca Date: Thu, 14 Dec 2017 09:44:04 -0800 Subject: [PATCH] Aad token provider (#225) * Adding AAD token support. * improve lock * Adding MSI provider. * Some small changes * Adddresing comments and couple small fixes. * Remove string resource * Remove test secrets * Renaming paramater endpoint to endpointAddress * Renaming paramater endpoint to endpointAddress * Remove unused AsyncLock * Remove redundant config name * Remove using * Remove usings * Merge fix * API name change * Xmldoc type fix. * Remove SecurityToken internal constructor * Build error fix * Cleanup code * Pass temporary audience instead of empty string to SecurityToken. * Make token type required * Refactoring SharedAccessSignatureToken for cleanup * Function renaming to carry UTC notion. * Remove procted accesors from SecurityToken class properties. --- ...Microsoft.Azure.EventHubs.Processor.csproj | 1 - .../Amqp/AmqpEventHubClient.cs | 31 +++- .../Amqp/AmqpServiceClient.cs | 4 +- .../EventHubClient.cs | 153 +++++++++++++++- .../Microsoft.Azure.EventHubs.csproj | 14 +- .../AzureActiveDirectoryTokenProvider.cs | 96 ++++++++++ .../Primitives/ClientConstants.cs | 9 + .../EventHubsConnectionStringBuilder.cs | 58 +++--- .../Primitives/ITokenProvider.cs | 25 +++ .../Primitives/JsonSecurityToken.cs | 32 ++++ .../ManagedServiceIdentityTokenProvider.cs | 29 +++ .../Primitives/SecurityToken.cs | 140 ++++----------- .../Primitives/SharedAccessSignatureToken.cs | 150 ++++++++++++++++ .../SharedAccessSignatureTokenProvider.cs | 132 +++----------- .../Primitives/TokenProvider.cs | 141 +++++++++------ .../Client/ClientTestBase.cs | 2 +- .../Client/ConnectionStringBuilderTests.cs | 2 +- .../Client/MiscTests.cs | 41 ----- .../Client/TokenProviderTests.cs | 168 ++++++++++++++++++ 19 files changed, 874 insertions(+), 354 deletions(-) create mode 100644 src/Microsoft.Azure.EventHubs/Primitives/AzureActiveDirectoryTokenProvider.cs create mode 100644 src/Microsoft.Azure.EventHubs/Primitives/ITokenProvider.cs create mode 100644 src/Microsoft.Azure.EventHubs/Primitives/JsonSecurityToken.cs create mode 100644 src/Microsoft.Azure.EventHubs/Primitives/ManagedServiceIdentityTokenProvider.cs create mode 100644 src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureToken.cs create mode 100644 test/Microsoft.Azure.EventHubs.Tests/Client/TokenProviderTests.cs diff --git a/src/Microsoft.Azure.EventHubs.Processor/Microsoft.Azure.EventHubs.Processor.csproj b/src/Microsoft.Azure.EventHubs.Processor/Microsoft.Azure.EventHubs.Processor.csproj index 850bb1d..07cc13a 100644 --- a/src/Microsoft.Azure.EventHubs.Processor/Microsoft.Azure.EventHubs.Processor.csproj +++ b/src/Microsoft.Azure.EventHubs.Processor/Microsoft.Azure.EventHubs.Processor.csproj @@ -29,7 +29,6 @@ - $(DefineConstants);UAP10_0 UAP,Version=v10.0 UAP 10.0.14393.0 diff --git a/src/Microsoft.Azure.EventHubs/Amqp/AmqpEventHubClient.cs b/src/Microsoft.Azure.EventHubs/Amqp/AmqpEventHubClient.cs index 6a3f739..624354e 100644 --- a/src/Microsoft.Azure.EventHubs/Amqp/AmqpEventHubClient.cs +++ b/src/Microsoft.Azure.EventHubs/Amqp/AmqpEventHubClient.cs @@ -5,7 +5,6 @@ namespace Microsoft.Azure.EventHubs.Amqp { using System; using System.Linq; - using System.Net; using System.Threading.Tasks; using Microsoft.Azure.Amqp.Sasl; using Microsoft.Azure.Amqp; @@ -26,17 +25,33 @@ public AmqpEventHubClient(EventHubsConnectionStringBuilder csb) if (!string.IsNullOrWhiteSpace(csb.SharedAccessSignature)) { - this.TokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SharedAccessSignature); + this.InternalTokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SharedAccessSignature); } else { - this.TokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); + this.InternalTokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); } this.CbsTokenProvider = new TokenProviderAdapter(this); this.ConnectionManager = new FaultTolerantAmqpObject(this.CreateConnectionAsync, this.CloseConnection); } + public AmqpEventHubClient( + Uri endpointAddress, + string entityPath, + ITokenProvider tokenProvider, + TimeSpan operationTimeout, + TransportType transportType) + : base(new EventHubsConnectionStringBuilder(endpointAddress, entityPath, operationTimeout, transportType)) + { + this.ContainerId = Guid.NewGuid().ToString("N"); + this.AmqpVersion = new Version(1, 0, 0, 0); + this.MaxFrameSize = AmqpConstants.DefaultMaxFrameSize; + this.InternalTokenProvider = tokenProvider; + this.CbsTokenProvider = new TokenProviderAdapter(this); + this.ConnectionManager = new FaultTolerantAmqpObject(this.CreateConnectionAsync, this.CloseConnection); + } + internal ICbsTokenProvider CbsTokenProvider { get; } internal FaultTolerantAmqpObject ConnectionManager { get; } @@ -47,7 +62,7 @@ public AmqpEventHubClient(EventHubsConnectionStringBuilder csb) uint MaxFrameSize { get; } - internal TokenProvider TokenProvider { get; } + internal ITokenProvider InternalTokenProvider { get; } internal override EventDataSender OnCreateEventSender(string partitionId) { @@ -110,7 +125,7 @@ internal static AmqpSettings CreateAmqpSettings( string sslHostName = null, bool useWebSockets = false, bool sslStreamUpgrade = false, - NetworkCredential networkCredential = null, + System.Net.NetworkCredential networkCredential = null, bool forceTokenProvider = true) { var settings = new AmqpSettings(); @@ -266,11 +281,9 @@ public TokenProviderAdapter(AmqpEventHubClient eventHubClient) public async Task GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims) { - string claim = requiredClaims?.FirstOrDefault(); - var tokenProvider = this.eventHubClient.TokenProvider; var timeout = this.eventHubClient.ConnectionStringBuilder.OperationTimeout; - var token = await tokenProvider.GetTokenAsync(appliesTo, claim, timeout).ConfigureAwait(false); - return new CbsToken(token.TokenValue, CbsConstants.ServiceBusSasTokenType, token.ExpiresAtUtc); + var token = await this.eventHubClient.InternalTokenProvider.GetTokenAsync(appliesTo, timeout).ConfigureAwait(false); + return new CbsToken(token.TokenValue, token.TokenType, token.ExpiresAtUtc); } } } diff --git a/src/Microsoft.Azure.EventHubs/Amqp/AmqpServiceClient.cs b/src/Microsoft.Azure.EventHubs/Amqp/AmqpServiceClient.cs index f78640a..06ec5bd 100644 --- a/src/Microsoft.Azure.EventHubs/Amqp/AmqpServiceClient.cs +++ b/src/Microsoft.Azure.EventHubs/Amqp/AmqpServiceClient.cs @@ -150,9 +150,9 @@ async Task GetTokenString() // when checking for token expiry. if (this.token == null || DateTime.UtcNow > this.token.ExpiresAtUtc.Subtract(TimeSpan.FromMinutes(5))) { - this.token = await this.eventHubClient.TokenProvider.GetTokenAsync( + this.token = await this.eventHubClient.InternalTokenProvider.GetTokenAsync( this.eventHubClient.ConnectionStringBuilder.Endpoint.AbsoluteUri, - ClaimConstants.Listen, this.eventHubClient.ConnectionStringBuilder.OperationTimeout).ConfigureAwait(false); + this.eventHubClient.ConnectionStringBuilder.OperationTimeout).ConfigureAwait(false); } return this.token.TokenValue.ToString(); diff --git a/src/Microsoft.Azure.EventHubs/EventHubClient.cs b/src/Microsoft.Azure.EventHubs/EventHubClient.cs index e13133f..602cd33 100644 --- a/src/Microsoft.Azure.EventHubs/EventHubClient.cs +++ b/src/Microsoft.Azure.EventHubs/EventHubClient.cs @@ -3,11 +3,12 @@ namespace Microsoft.Azure.EventHubs { - using Amqp; using System; using System.Collections.Generic; using System.Diagnostics; using System.Threading.Tasks; + using Microsoft.Azure.EventHubs.Amqp; + using Microsoft.IdentityModel.Clients.ActiveDirectory; /// /// Anchor class - all EventHub client operations start here. @@ -70,6 +71,156 @@ public static EventHubClient CreateFromConnectionString(string connectionString) return Create(csb); } + /// + /// Creates a new instance of the Event Hubs client using the specified endpoint, entity path, and token provider. + /// + /// Fully qualified domain name for Event Hubs. Most likely, {yournamespace}.servicebus.windows.net + /// Event Hub path + /// Token provider which will generate security tokens for authorization. + /// Operation timeout for Event Hubs operations. + /// Transport type on connection. + /// + public static EventHubClient Create( + Uri endpointAddress, + string entityPath, + ITokenProvider tokenProvider, + TimeSpan? operationTimeout = null, + TransportType transportType = TransportType.Amqp) + { + if (endpointAddress == null) + { + throw Fx.Exception.ArgumentNull(nameof(endpointAddress)); + } + + if (string.IsNullOrWhiteSpace(entityPath)) + { + throw Fx.Exception.ArgumentNullOrWhiteSpace(nameof(entityPath)); + } + + if (tokenProvider == null) + { + throw Fx.Exception.ArgumentNull(nameof(tokenProvider)); + } + + EventHubsEventSource.Log.EventHubClientCreateStart(endpointAddress.Host, entityPath); + EventHubClient eventHubClient = new AmqpEventHubClient( + endpointAddress, + entityPath, + tokenProvider, + operationTimeout?? ClientConstants.DefaultOperationTimeout, + transportType); + EventHubsEventSource.Log.EventHubClientCreateStop(eventHubClient.ClientId); + return eventHubClient; + } + + /// + /// Creates a new instance of the Event Hubs client using the specified endpoint, entity path, AAD authentication context. + /// + /// Fully qualified domain name for Event Hubs. Most likely, {yournamespace}.servicebus.windows.net + /// Event Hub path + /// AuthenticationContext for AAD. + /// The app credential. + /// Operation timeout for Event Hubs operations. + /// Transport type on connection. + /// + public static EventHubClient Create( + Uri endpointAddress, + string entityPath, + AuthenticationContext authContext, + ClientCredential clientCredential, + TimeSpan? operationTimeout = null, + TransportType transportType = TransportType.Amqp) + { + return Create( + endpointAddress, + entityPath, + TokenProvider.CreateAadTokenProvider(authContext, clientCredential), + operationTimeout, + transportType); + } + + /// + /// Creates a new instance of the Event Hubs client using the specified endpoint, entity path, AAD authentication context. + /// + /// Fully qualified domain name for Event Hubs. Most likely, {yournamespace}.servicebus.windows.net + /// Event Hub path + /// AuthenticationContext for AAD. + /// ClientId for AAD. + /// The redirectUri on Client App. + /// Platform parameters + /// User Identifier + /// Operation timeout for Event Hubs operations. + /// Transport type on connection. + /// + public static EventHubClient Create( + Uri endpointAddress, + string entityPath, + AuthenticationContext authContext, + string clientId, + Uri redirectUri, + IPlatformParameters platformParameters, + UserIdentifier userIdentifier = null, + TimeSpan? operationTimeout = null, + TransportType transportType = TransportType.Amqp) + { + return Create( + endpointAddress, + entityPath, + TokenProvider.CreateAadTokenProvider(authContext, clientId, redirectUri, platformParameters, userIdentifier), + operationTimeout, + transportType); + } + +#if !UAP10_0 + /// + /// Creates a new instance of the Event Hubs client using the specified endpoint, entity path, AAD authentication context. + /// + /// Fully qualified domain name for Event Hubs. Most likely, {yournamespace}.servicebus.windows.net + /// Event Hub path + /// AuthenticationContext for AAD. + /// The client assertion certificate credential. + /// Operation timeout for Event Hubs operations. + /// Transport type on connection. + /// + public static EventHubClient Create( + Uri endpointAddress, + string entityPath, + AuthenticationContext authContext, + ClientAssertionCertificate clientAssertionCertificate, + TimeSpan? operationTimeout = null, + TransportType transportType = TransportType.Amqp) + { + return Create( + endpointAddress, + entityPath, + TokenProvider.CreateAadTokenProvider(authContext, clientAssertionCertificate), + operationTimeout, + transportType); + } +#endif + + /// + /// Creates a new instance of the Event Hubs client using the specified endpoint, entity path on Azure Managed Service Identity authentication. + /// + /// Fully qualified domain name for Event Hubs. Most likely, {yournamespace}.servicebus.windows.net + /// Event Hub path + /// Operation timeout for Event Hubs operations. + /// Transport type on connection. + /// + public static EventHubClient CreateWithManagedServiceIdentity( + Uri endpointAddress, + string entityPath, + TimeSpan? operationTimeout = null, + TransportType transportType = TransportType.Amqp) + { + return Create( + endpointAddress, + entityPath, + TokenProvider.CreateManagedServiceIdentityTokenProvider(), + operationTimeout, + transportType); + } + static EventHubClient Create(EventHubsConnectionStringBuilder csb) { if (string.IsNullOrWhiteSpace(csb.EntityPath)) diff --git a/src/Microsoft.Azure.EventHubs/Microsoft.Azure.EventHubs.csproj b/src/Microsoft.Azure.EventHubs/Microsoft.Azure.EventHubs.csproj index 80ba2c2..92709c8 100644 --- a/src/Microsoft.Azure.EventHubs/Microsoft.Azure.EventHubs.csproj +++ b/src/Microsoft.Azure.EventHubs/Microsoft.Azure.EventHubs.csproj @@ -29,6 +29,10 @@ 2.0.0.0 + + $(DefineConstants);NET461 + + $(DefineConstants);UAP10_0 UAP,Version=v10.0 @@ -39,6 +43,10 @@ v5.0 + + $(DefineConstants);NETSTANDARD2_0 + + @@ -50,12 +58,16 @@ - + + + + + diff --git a/src/Microsoft.Azure.EventHubs/Primitives/AzureActiveDirectoryTokenProvider.cs b/src/Microsoft.Azure.EventHubs/Primitives/AzureActiveDirectoryTokenProvider.cs new file mode 100644 index 0000000..6877cc6 --- /dev/null +++ b/src/Microsoft.Azure.EventHubs/Primitives/AzureActiveDirectoryTokenProvider.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs +{ + using System; + using System.Threading.Tasks; + using Microsoft.IdentityModel.Clients.ActiveDirectory; + + /// + /// Represents the Azure Active Directory token provider for the Event Hubs. + /// + public class AzureActiveDirectoryTokenProvider : TokenProvider + { + readonly AuthenticationContext authContext; + readonly ClientCredential clientCredential; +#if !UAP10_0 + readonly ClientAssertionCertificate clientAssertionCertificate; +#endif + readonly string clientId; + readonly Uri redirectUri; + readonly IPlatformParameters platformParameters; + readonly UserIdentifier userIdentifier; + + enum AuthType + { + ClientCredential, + UserPasswordCredential, + ClientAssertionCertificate, + InteractiveUserLogin + } + + readonly AuthType authType; + + internal AzureActiveDirectoryTokenProvider(AuthenticationContext authContext, ClientCredential credential) + { + this.clientCredential = credential; + this.authContext = authContext; + this.authType = AuthType.ClientCredential; + this.clientId = clientCredential.ClientId; + } + +#if !UAP10_0 + internal AzureActiveDirectoryTokenProvider(AuthenticationContext authContext, ClientAssertionCertificate clientAssertionCertificate) + { + this.clientAssertionCertificate = clientAssertionCertificate; + this.authContext = authContext; + this.authType = AuthType.ClientAssertionCertificate; + this.clientId = clientAssertionCertificate.ClientId; + } +#endif + + internal AzureActiveDirectoryTokenProvider(AuthenticationContext authContext, string clientId, Uri redirectUri, IPlatformParameters platformParameters, UserIdentifier userIdentifier) + { + this.authContext = authContext; + this.clientId = clientId; + this.redirectUri = redirectUri; + this.platformParameters = platformParameters; + this.userIdentifier = userIdentifier; + this.authType = AuthType.InteractiveUserLogin; + } + + /// + /// Gets a for the given audience and duration. + /// + /// The URI which the access token applies to + /// The time span that specifies the timeout value for the message that gets the security token + /// + public override async Task GetTokenAsync(string appliesTo, TimeSpan timeout) + { + AuthenticationResult authResult; + + switch (this.authType) + { + case AuthType.ClientCredential: + authResult = await this.authContext.AcquireTokenAsync(ClientConstants.AadEventHubsAudience, this.clientCredential); + break; + +#if !UAP10_0 + case AuthType.ClientAssertionCertificate: + authResult = await this.authContext.AcquireTokenAsync(ClientConstants.AadEventHubsAudience, this.clientAssertionCertificate); + break; +#endif + + case AuthType.InteractiveUserLogin: + authResult = await this.authContext.AcquireTokenAsync(ClientConstants.AadEventHubsAudience, this.clientId, this.redirectUri, this.platformParameters, this.userIdentifier); + break; + + default: + throw new NotSupportedException(); + } + + return new JsonSecurityToken(authResult.AccessToken, appliesTo); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Azure.EventHubs/Primitives/ClientConstants.cs b/src/Microsoft.Azure.EventHubs/Primitives/ClientConstants.cs index a968675..a2d6bc4 100644 --- a/src/Microsoft.Azure.EventHubs/Primitives/ClientConstants.cs +++ b/src/Microsoft.Azure.EventHubs/Primitives/ClientConstants.cs @@ -3,9 +3,18 @@ namespace Microsoft.Azure.EventHubs { + using System; + static class ClientConstants { public const int TimerToleranceInSeconds = 5; public const int ServerBusyBaseSleepTimeInSecs = 4; + + public const string SasTokenType = "servicebus.windows.net:sastoken"; + public const string JsonWebTokenType = "jwt"; + public const string AadEventHubsAudience = "https://eventhubs.azure.net/"; + + public static TimeSpan DefaultOperationTimeout = TimeSpan.FromMinutes(1); + public static TransportType DefaultTransportType = TransportType.Amqp; } } diff --git a/src/Microsoft.Azure.EventHubs/Primitives/EventHubsConnectionStringBuilder.cs b/src/Microsoft.Azure.EventHubs/Primitives/EventHubsConnectionStringBuilder.cs index e440427..449734c 100644 --- a/src/Microsoft.Azure.EventHubs/Primitives/EventHubsConnectionStringBuilder.cs +++ b/src/Microsoft.Azure.EventHubs/Primitives/EventHubsConnectionStringBuilder.cs @@ -49,8 +49,6 @@ public class EventHubsConnectionStringBuilder const char KeyValueSeparator = '='; const char KeyValuePairDelimiter = ';'; - static readonly TimeSpan DefaultOperationTimeout = TimeSpan.FromMinutes(1); - static readonly TransportType DefaultTransportType = TransportType.Amqp; static readonly string EndpointScheme = "amqps"; static readonly string EndpointConfigName = "Endpoint"; static readonly string SharedAccessKeyNameConfigName = "SharedAccessKeyName"; @@ -58,7 +56,6 @@ public class EventHubsConnectionStringBuilder static readonly string EntityPathConfigName = "EntityPath"; static readonly string OperationTimeoutConfigName = "OperationTimeout"; static readonly string TransportTypeConfigName = "TransportType"; - static readonly string OperationTimeoutName = "OperationTimeout"; static readonly string SharedAccessSignatureConfigName = "SharedAccessSignature"; /// @@ -73,7 +70,7 @@ public EventHubsConnectionStringBuilder( string entityPath, string sharedAccessKeyName, string sharedAccessKey) - : this (endpointAddress, entityPath, sharedAccessKeyName, sharedAccessKey, DefaultOperationTimeout) + : this (endpointAddress, entityPath, sharedAccessKeyName, sharedAccessKey, ClientConstants.DefaultOperationTimeout) { } @@ -124,10 +121,31 @@ public EventHubsConnectionStringBuilder( this.SharedAccessSignature = sharedAccessSignature; } - EventHubsConnectionStringBuilder( + /// + /// ConnectionString format: + /// Endpoint=sb://namespace_DNS_Name;EntityPath=EVENT_HUB_NAME;SharedAccessKeyName=SHARED_ACCESS_KEY_NAME;SharedAccessKey=SHARED_ACCESS_KEY + /// + /// Event Hubs ConnectionString + public EventHubsConnectionStringBuilder(string connectionString) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw Fx.Exception.ArgumentNullOrWhiteSpace(nameof(connectionString)); + } + + // Assign default values. + this.OperationTimeout = ClientConstants.DefaultOperationTimeout; + this.TransportType = TransportType.Amqp; + + // Parse the connection string now and override default values if any provided. + this.ParseConnectionString(connectionString); + } + + internal EventHubsConnectionStringBuilder( Uri endpointAddress, string entityPath, - TimeSpan operationTimeout) + TimeSpan operationTimeout, + TransportType transportType = TransportType.Amqp) { if (endpointAddress == null) { @@ -147,27 +165,7 @@ public EventHubsConnectionStringBuilder( this.EntityPath = entityPath; this.OperationTimeout = operationTimeout; - this.TransportType = DefaultTransportType; - } - - /// - /// ConnectionString format: - /// Endpoint=sb://namespace_DNS_Name;EntityPath=EVENT_HUB_NAME;SharedAccessKeyName=SHARED_ACCESS_KEY_NAME;SharedAccessKey=SHARED_ACCESS_KEY - /// - /// Event Hubs ConnectionString - public EventHubsConnectionStringBuilder(string connectionString) - { - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw Fx.Exception.ArgumentNullOrWhiteSpace(nameof(connectionString)); - } - - // Assign default values. - this.OperationTimeout = DefaultOperationTimeout; - this.TransportType = TransportType.Amqp; - - // Parse the connection string now and override default values if any provided. - this.ParseConnectionString(connectionString); + this.TransportType = transportType; } /// @@ -253,12 +251,12 @@ public override string ToString() connectionStringBuilder.Append($"{SharedAccessSignatureConfigName}{KeyValueSeparator}{this.SharedAccessSignature}{KeyValuePairDelimiter}"); } - if (this.OperationTimeout != DefaultOperationTimeout) + if (this.OperationTimeout != ClientConstants.DefaultOperationTimeout) { connectionStringBuilder.Append($"{OperationTimeoutConfigName}{KeyValueSeparator}{this.OperationTimeout}{KeyValuePairDelimiter}"); } - if (this.TransportType != DefaultTransportType) + if (this.TransportType != ClientConstants.DefaultTransportType) { connectionStringBuilder.Append($"{TransportTypeConfigName}{KeyValueSeparator}{TransportType}{KeyValuePairDelimiter}"); } @@ -339,7 +337,7 @@ void ParseConnectionString(string connectionString) { this.SharedAccessSignature = value; } - else if (key.Equals(OperationTimeoutName, StringComparison.OrdinalIgnoreCase)) + else if (key.Equals(OperationTimeoutConfigName, StringComparison.OrdinalIgnoreCase)) { this.OperationTimeout = TimeSpan.Parse(value); } diff --git a/src/Microsoft.Azure.EventHubs/Primitives/ITokenProvider.cs b/src/Microsoft.Azure.EventHubs/Primitives/ITokenProvider.cs new file mode 100644 index 0000000..f1c4fd9 --- /dev/null +++ b/src/Microsoft.Azure.EventHubs/Primitives/ITokenProvider.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using System.Threading.Tasks; + + /// + /// Provides interface definition of a token provider. + /// + public interface ITokenProvider + { + /// + /// Gets a . + /// + /// The URI which the access token applies to + /// The time span that specifies the timeout value for the message that gets the security token + /// + Task GetTokenAsync(string appliesTo, TimeSpan timeout); + } +} diff --git a/src/Microsoft.Azure.EventHubs/Primitives/JsonSecurityToken.cs b/src/Microsoft.Azure.EventHubs/Primitives/JsonSecurityToken.cs new file mode 100644 index 0000000..ae2f789 --- /dev/null +++ b/src/Microsoft.Azure.EventHubs/Primitives/JsonSecurityToken.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs +{ + using System; + using System.Collections.ObjectModel; + using System.IdentityModel.Tokens; + using System.IdentityModel.Tokens.Jwt; + + /// + /// Extends SecurityToken for JWT specific properties + /// + public class JsonSecurityToken : SecurityToken + { + /// + /// Creates a new instance of the class. + /// + /// Raw JSON Web Token string + /// The audience + public JsonSecurityToken(string rawToken, string audience) + : base(rawToken, GetExpirationDateTimeUtcFromToken(rawToken), audience, ClientConstants.JsonWebTokenType) + { + } + + static DateTime GetExpirationDateTimeUtcFromToken(string token) + { + var jwtSecurityToken = new JwtSecurityToken(token); + return jwtSecurityToken.ValidTo; + } + } +} diff --git a/src/Microsoft.Azure.EventHubs/Primitives/ManagedServiceIdentityTokenProvider.cs b/src/Microsoft.Azure.EventHubs/Primitives/ManagedServiceIdentityTokenProvider.cs new file mode 100644 index 0000000..ff3618b --- /dev/null +++ b/src/Microsoft.Azure.EventHubs/Primitives/ManagedServiceIdentityTokenProvider.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs +{ + using System; + using System.Threading.Tasks; + using Azure.Services.AppAuthentication; + + /// + /// Represents the Azure Active Directory token provider for Azure Managed Service Identity integration. + /// + public class ManagedServiceIdentityTokenProvider : TokenProvider + { + static AzureServiceTokenProvider azureServiceTokenProvider = new AzureServiceTokenProvider(); + + /// + /// Gets a for the given audience and duration. + /// + /// The URI which the access token applies to + /// The time span that specifies the timeout value for the message that gets the security token + /// + public async override Task GetTokenAsync(string appliesTo, TimeSpan timeout) + { + string accessToken = await azureServiceTokenProvider.GetAccessTokenAsync(ClientConstants.AadEventHubsAudience); + return new JsonSecurityToken(accessToken, appliesTo); + } + } +} diff --git a/src/Microsoft.Azure.EventHubs/Primitives/SecurityToken.cs b/src/Microsoft.Azure.EventHubs/Primitives/SecurityToken.cs index 9223606..fbc537a 100644 --- a/src/Microsoft.Azure.EventHubs/Primitives/SecurityToken.cs +++ b/src/Microsoft.Azure.EventHubs/Primitives/SecurityToken.cs @@ -4,80 +4,55 @@ namespace Microsoft.Azure.EventHubs { using System; - using System.Collections.Generic; - using System.Diagnostics.CodeAnalysis; - using System.Globalization; - using System.Net; /// /// Provides information about a security token such as audience, expiry time, and the string token value. /// public class SecurityToken { - // per Simple Web Token draft specification - private const string TokenAudience = "Audience"; - private const string TokenExpiresOn = "ExpiresOn"; - private const string TokenIssuer = "Issuer"; - private const string TokenDigest256 = "HMACSHA256"; - - const string InternalExpiresOnFieldName = "ExpiresOn"; - const string InternalAudienceFieldName = TokenAudience; - const string InternalKeyValueSeparator = "="; - const string InternalPairSeparator = "&"; - static readonly Func Decoder = WebUtility.UrlDecode; - static readonly DateTime EpochTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc); - readonly string token; - readonly DateTime expiresAtUtc; - readonly string audience; + /// + /// Token literal + /// + string token; /// - /// Creates a new instance of the class. + /// Expiry date-time /// - /// The token - /// The expiration time - /// The audience - public SecurityToken(string tokenString, DateTime expiresAtUtc, string audience) - { - if (tokenString == null || audience == null) - { - throw Fx.Exception.ArgumentNull(tokenString == null ? nameof(tokenString) : nameof(audience)); - } + DateTime expiresAtUtc; + + /// + /// Token audience + /// + string audience; - this.token = tokenString; - this.expiresAtUtc = expiresAtUtc; - this.audience = audience; - } + /// + /// Token type + /// + string tokenType; /// /// Creates a new instance of the class. /// /// The token /// The expiration time - public SecurityToken(string tokenString, DateTime expiresAtUtc) + /// The audience + /// The type of the token + public SecurityToken(string tokenString, DateTime expiresAtUtc, string audience, string tokenType) { - if (tokenString == null) + if (string.IsNullOrEmpty(tokenString)) { - throw Fx.Exception.ArgumentNull(nameof(tokenString)); + throw Fx.Exception.ArgumentNullOrWhiteSpace(nameof(tokenString)); } - this.token = tokenString; - this.expiresAtUtc = expiresAtUtc; - this.audience = GetAudienceFromToken(tokenString); - } - - /// - /// Creates a new instance of the class. - /// - /// The token - public SecurityToken(string tokenString) - { - if (tokenString == null) + if (string.IsNullOrEmpty(audience)) { - throw Fx.Exception.ArgumentNull(nameof(tokenString)); + throw Fx.Exception.ArgumentNullOrWhiteSpace(nameof(audience)); } this.token = tokenString; - GetExpirationDateAndAudienceFromToken(tokenString, out this.expiresAtUtc, out this.audience); + this.expiresAtUtc = expiresAtUtc; + this.audience = audience; + this.tokenType = tokenType; } /// @@ -90,69 +65,14 @@ public SecurityToken(string tokenString) /// public DateTime ExpiresAtUtc => this.expiresAtUtc; - /// - protected virtual string ExpiresOnFieldName => InternalExpiresOnFieldName; - - /// - protected virtual string AudienceFieldName => InternalAudienceFieldName; - - /// - protected virtual string KeyValueSeparator => InternalKeyValueSeparator; - - /// - protected virtual string PairSeparator => InternalPairSeparator; - /// /// Gets the actual token. /// - public object TokenValue => this.token; - - string GetAudienceFromToken(string token) - { - string audience; - IDictionary decodedToken = Decode(token, Decoder, Decoder, this.KeyValueSeparator, this.PairSeparator); - if (!decodedToken.TryGetValue(AudienceFieldName, out audience)) - { - throw new FormatException(Resources.TokenMissingAudience); - } - - return audience; - } - - void GetExpirationDateAndAudienceFromToken(string token, out DateTime expiresOn, out string audience) - { - string expiresIn; - IDictionary decodedToken = Decode(token, Decoder, Decoder, this.KeyValueSeparator, this.PairSeparator); - if (!decodedToken.TryGetValue(ExpiresOnFieldName, out expiresIn)) - { - throw new FormatException(Resources.TokenMissingExpiresOn); - } + public virtual string TokenValue => this.token; - if (!decodedToken.TryGetValue(AudienceFieldName, out audience)) - { - throw new FormatException(Resources.TokenMissingAudience); - } - - expiresOn = (EpochTime + TimeSpan.FromSeconds(double.Parse(expiresIn, CultureInfo.InvariantCulture))); - } - - static IDictionary Decode(string encodedString, Func keyDecoder, Func valueDecoder, string keyValueSeparator, string pairSeparator) - { - IDictionary dictionary = new Dictionary(); - IEnumerable valueEncodedPairs = encodedString.Split(new[] { pairSeparator }, StringSplitOptions.None); - foreach (string valueEncodedPair in valueEncodedPairs) - { - string[] pair = valueEncodedPair.Split(new[] { keyValueSeparator }, StringSplitOptions.None); - if (pair.Length != 2) - { - throw new FormatException(Resources.InvalidEncoding); - } - - dictionary.Add(keyDecoder(pair[0]), valueDecoder(pair[1])); - } - - return dictionary; - } + /// + /// Gets the token type. + /// + public virtual string TokenType => this.tokenType; } - } diff --git a/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureToken.cs b/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureToken.cs new file mode 100644 index 0000000..dec4701 --- /dev/null +++ b/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureToken.cs @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs +{ + using System; + using System.Collections.Generic; + using System.Globalization; + using System.Net; + + /// + /// A WCF SecurityToken that wraps a Shared Access Signature + /// + class SharedAccessSignatureToken : SecurityToken + { + internal const string SharedAccessSignature = "SharedAccessSignature"; + internal const string SignedResource = "sr"; + internal const string Signature = "sig"; + internal const string SignedKeyName = "skn"; + internal const string SignedExpiry = "se"; + internal const int MaxKeyNameLength = 256; + internal const int MaxKeyLength = 256; + + const string SignedResourceFullFieldName = SharedAccessSignature + " " + SignedResource; + const string SasPairSeparator = "&"; + const string SasKeyValueSeparator = "="; + + static readonly Func Decoder = WebUtility.UrlDecode; + + /// + /// Creates a new instance of the class. + /// + /// The token + public SharedAccessSignatureToken(string tokenString) + : base(tokenString, GetExpirationDateTimeUtcFromToken(tokenString), GetAudienceFromToken(tokenString), ClientConstants.SasTokenType) + { + } + + internal static void Validate(string sharedAccessSignature) + { + if (string.IsNullOrEmpty(sharedAccessSignature)) + { + throw new ArgumentNullException(nameof(sharedAccessSignature)); + } + + IDictionary parsedFields = ExtractFieldValues(sharedAccessSignature); + + string signature; + if (!parsedFields.TryGetValue(Signature, out signature)) + { + throw new ArgumentNullException(Signature); + } + + string expiry; + if (!parsedFields.TryGetValue(SignedExpiry, out expiry)) + { + throw new ArgumentNullException(SignedExpiry); + } + + string keyName; + if (!parsedFields.TryGetValue(SignedKeyName, out keyName)) + { + throw new ArgumentNullException(SignedKeyName); + } + + string encodedAudience; + if (!parsedFields.TryGetValue(SignedResource, out encodedAudience)) + { + throw new ArgumentNullException(SignedResource); + } + } + + static IDictionary ExtractFieldValues(string sharedAccessSignature) + { + string[] tokenLines = sharedAccessSignature.Split(); + + if (!string.Equals(tokenLines[0].Trim(), SharedAccessSignature, StringComparison.OrdinalIgnoreCase) || tokenLines.Length != 2) + { + throw new ArgumentNullException(nameof(sharedAccessSignature)); + } + + IDictionary parsedFields = new Dictionary(StringComparer.OrdinalIgnoreCase); + string[] tokenFields = tokenLines[1].Trim().Split(new[] { SasPairSeparator }, StringSplitOptions.None); + + foreach (string tokenField in tokenFields) + { + if (tokenField != string.Empty) + { + string[] fieldParts = tokenField.Split(new[] { SasKeyValueSeparator }, StringSplitOptions.None); + if (string.Equals(fieldParts[0], SignedResource, StringComparison.OrdinalIgnoreCase)) + { + // We need to preserve the casing of the escape characters in the audience, + // so defer decoding the URL until later. + parsedFields.Add(fieldParts[0], fieldParts[1]); + } + else + { + parsedFields.Add(fieldParts[0], WebUtility.UrlDecode(fieldParts[1])); + } + } + } + + return parsedFields; + } + + static string GetAudienceFromToken(string token) + { + string audience; + IDictionary decodedToken = Decode(token, Decoder, Decoder, SasKeyValueSeparator, SasPairSeparator); + if (!decodedToken.TryGetValue(SignedResourceFullFieldName, out audience)) + { + throw new FormatException(Resources.TokenMissingAudience); + } + + return audience; + } + + static DateTime GetExpirationDateTimeUtcFromToken(string token) + { + string expiresIn; + IDictionary decodedToken = Decode(token, Decoder, Decoder, SasKeyValueSeparator, SasPairSeparator); + if (!decodedToken.TryGetValue(SignedExpiry, out expiresIn)) + { + throw new FormatException(Resources.TokenMissingExpiresOn); + } + + var expiresOn = (SharedAccessSignatureTokenProvider.EpochTime + TimeSpan.FromSeconds(double.Parse(expiresIn, CultureInfo.InvariantCulture))); + + return expiresOn; + } + + static IDictionary Decode(string encodedString, Func keyDecoder, Func valueDecoder, string keyValueSeparator, string pairSeparator) + { + IDictionary dictionary = new Dictionary(); + IEnumerable valueEncodedPairs = encodedString.Split(new[] { pairSeparator }, StringSplitOptions.None); + foreach (string valueEncodedPair in valueEncodedPairs) + { + string[] pair = valueEncodedPair.Split(new[] { keyValueSeparator }, StringSplitOptions.None); + if (pair.Length != 2) + { + throw new FormatException(Resources.InvalidEncoding); + } + + dictionary.Add(keyDecoder(pair[0]), valueDecoder(pair[1])); + } + + return dictionary; + } + } +} diff --git a/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureTokenProvider.cs b/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureTokenProvider.cs index 187df6f..58730a9 100644 --- a/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureTokenProvider.cs +++ b/src/Microsoft.Azure.EventHubs/Primitives/SharedAccessSignatureTokenProvider.cs @@ -17,22 +17,33 @@ namespace Microsoft.Azure.EventHubs /// public class SharedAccessSignatureTokenProvider : TokenProvider { + const TokenScope DefaultTokenScope = TokenScope.Entity; + + internal static readonly TimeSpan DefaultTokenTimeout = TimeSpan.FromMinutes(60); + /// /// Represents 00:00:00 UTC Thursday 1, January 1970. /// public static readonly DateTime EpochTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc); + readonly byte[] encodedSharedAccessKey; readonly string keyName; readonly TimeSpan tokenTimeToLive; + readonly TokenScope tokenScope; readonly string sharedAccessSignature; + internal static readonly Func MessagingTokenProviderKeyEncoder = Encoding.UTF8.GetBytes; internal SharedAccessSignatureTokenProvider(string sharedAccessSignature) - : base(TokenScope.Entity) { SharedAccessSignatureToken.Validate(sharedAccessSignature); this.sharedAccessSignature = sharedAccessSignature; } + internal SharedAccessSignatureTokenProvider(string keyName, string sharedAccessKey, TokenScope tokenScope = TokenScope.Entity) + : this(keyName, sharedAccessKey, MessagingTokenProviderKeyEncoder, DefaultTokenTimeout, tokenScope) + { + } + internal SharedAccessSignatureTokenProvider(string keyName, string sharedAccessKey, TimeSpan tokenTimeToLive, TokenScope tokenScope = TokenScope.Entity) : this(keyName, sharedAccessKey, MessagingTokenProviderKeyEncoder, tokenTimeToLive, tokenScope) { @@ -45,7 +56,6 @@ internal SharedAccessSignatureTokenProvider(string keyName, string sharedAccessK /// /// protected SharedAccessSignatureTokenProvider(string keyName, string sharedAccessKey, Func customKeyEncoder, TimeSpan tokenTimeToLive, TokenScope tokenScope) - : base(tokenScope) { if (string.IsNullOrEmpty(keyName)) { @@ -76,15 +86,19 @@ protected SharedAccessSignatureTokenProvider(string keyName, string sharedAccess this.encodedSharedAccessKey = customKeyEncoder != null ? customKeyEncoder(sharedAccessKey) : MessagingTokenProviderKeyEncoder(sharedAccessKey); + this.tokenScope = tokenScope; } - /// - /// - /// - /// - /// - protected override Task OnGetTokenAsync(string appliesTo, string action, TimeSpan timeout) + /// + /// Gets a for the given audience and duration. + /// + /// The URI which the access token applies to + /// The time span that specifies the timeout value for the message that gets the security token + /// + public override Task GetTokenAsync(string appliesTo, TimeSpan timeout) { + TimeoutHelper.ThrowIfNegativeArgument(timeout); + appliesTo = NormalizeAppliesTo(appliesTo); string tokenString = this.BuildSignature(appliesTo); var securityToken = new SharedAccessSignatureToken(tokenString); return Task.FromResult(securityToken); @@ -104,6 +118,11 @@ protected virtual string BuildSignature(string targetUri) : this.sharedAccessSignature; } + string NormalizeAppliesTo(string appliesTo) + { + return EventHubsUriHelper.NormalizeUri(appliesTo, "http", true, stripPath: this.tokenScope == TokenScope.Namespace, ensureTrailingSlash: true); + } + static class SharedAccessSignatureBuilder { [SuppressMessage("Microsoft.Globalization", "CA1308:NormalizeStringsToUppercase", Justification = "Uris are normalized to lowercase")] @@ -152,102 +171,5 @@ static string Sign(string requestString, byte[] encodedSharedAccessKey) } } } - - /// - /// A WCF SecurityToken that wraps a Shared Access Signature - /// - class SharedAccessSignatureToken : SecurityToken - { - public const int MaxKeyNameLength = 256; - public const int MaxKeyLength = 256; - public const string SharedAccessSignature = "SharedAccessSignature"; - public const string SignedResource = "sr"; - public const string Signature = "sig"; - public const string SignedKeyName = "skn"; - public const string SignedExpiry = "se"; - public const string SignedResourceFullFieldName = SharedAccessSignature + " " + SignedResource; - public const string SasKeyValueSeparator = "="; - public const string SasPairSeparator = "&"; - - public SharedAccessSignatureToken(string tokenString) - : base(tokenString) - { - } - - protected override string AudienceFieldName => SignedResourceFullFieldName; - - protected override string ExpiresOnFieldName => SignedExpiry; - - protected override string KeyValueSeparator => SasKeyValueSeparator; - - protected override string PairSeparator => SasPairSeparator; - - internal static void Validate(string sharedAccessSignature) - { - if (string.IsNullOrEmpty(sharedAccessSignature)) - { - throw new ArgumentNullException(nameof(sharedAccessSignature)); - } - - IDictionary parsedFields = ExtractFieldValues(sharedAccessSignature); - - string signature; - if (!parsedFields.TryGetValue(Signature, out signature)) - { - throw new ArgumentNullException(Signature); - } - - string expiry; - if (!parsedFields.TryGetValue(SignedExpiry, out expiry)) - { - throw new ArgumentNullException(SignedExpiry); - } - - string keyName; - if (!parsedFields.TryGetValue(SignedKeyName, out keyName)) - { - throw new ArgumentNullException(SignedKeyName); - } - - string encodedAudience; - if (!parsedFields.TryGetValue(SignedResource, out encodedAudience)) - { - throw new ArgumentNullException(SignedResource); - } - } - - static IDictionary ExtractFieldValues(string sharedAccessSignature) - { - string[] tokenLines = sharedAccessSignature.Split(); - - if (!string.Equals(tokenLines[0].Trim(), SharedAccessSignature, StringComparison.OrdinalIgnoreCase) || tokenLines.Length != 2) - { - throw new ArgumentNullException(nameof(sharedAccessSignature)); - } - - IDictionary parsedFields = new Dictionary(StringComparer.OrdinalIgnoreCase); - string[] tokenFields = tokenLines[1].Trim().Split(new[] { SasPairSeparator }, StringSplitOptions.None); - - foreach (string tokenField in tokenFields) - { - if (tokenField != string.Empty) - { - string[] fieldParts = tokenField.Split(new[] { SasKeyValueSeparator }, StringSplitOptions.None); - if (string.Equals(fieldParts[0], SignedResource, StringComparison.OrdinalIgnoreCase)) - { - // We need to preserve the casing of the escape characters in the audience, - // so defer decoding the URL until later. - parsedFields.Add(fieldParts[0], fieldParts[1]); - } - else - { - parsedFields.Add(fieldParts[0], WebUtility.UrlDecode(fieldParts[1])); - } - } - } - - return parsedFields; - } - } } } diff --git a/src/Microsoft.Azure.EventHubs/Primitives/TokenProvider.cs b/src/Microsoft.Azure.EventHubs/Primitives/TokenProvider.cs index c29f505..145049e 100644 --- a/src/Microsoft.Azure.EventHubs/Primitives/TokenProvider.cs +++ b/src/Microsoft.Azure.EventHubs/Primitives/TokenProvider.cs @@ -4,40 +4,14 @@ namespace Microsoft.Azure.EventHubs { using System; - using System.Text; using System.Threading.Tasks; + using Microsoft.IdentityModel.Clients.ActiveDirectory; /// /// This abstract base class can be extended to implement additional token providers. /// - public abstract class TokenProvider + public abstract class TokenProvider : ITokenProvider { - internal static readonly TimeSpan DefaultTokenTimeout = TimeSpan.FromMinutes(60); - internal static readonly Func MessagingTokenProviderKeyEncoder = Encoding.UTF8.GetBytes; - const TokenScope DefaultTokenScope = TokenScope.Entity; - - /// - protected TokenProvider() - : this(TokenProvider.DefaultTokenScope) - { - } - - /// - /// - protected TokenProvider(TokenScope tokenScope) - { - this.TokenScope = tokenScope; - this.ThisLock = new object(); - } - - /// - /// Gets the scope or permissions associated with the token. - /// - public TokenScope TokenScope { get; } - - /// - protected object ThisLock { get; } - /// /// Construct a TokenProvider based on a sharedAccessSignature. /// @@ -56,7 +30,7 @@ public static TokenProvider CreateSharedAccessSignatureTokenProvider(string shar /// A TokenProvider initialized with the provided RuleId and Password public static TokenProvider CreateSharedAccessSignatureTokenProvider(string keyName, string sharedAccessKey) { - return new SharedAccessSignatureTokenProvider(keyName, sharedAccessKey, DefaultTokenTimeout); + return new SharedAccessSignatureTokenProvider(keyName, sharedAccessKey); } //internal static TokenProvider CreateIoTTokenProvider(string keyName, string sharedAccessKey) @@ -85,7 +59,7 @@ public static TokenProvider CreateSharedAccessSignatureTokenProvider(string keyN /// A TokenProvider initialized with the provided RuleId and Password public static TokenProvider CreateSharedAccessSignatureTokenProvider(string keyName, string sharedAccessKey, TokenScope tokenScope) { - return new SharedAccessSignatureTokenProvider(keyName, sharedAccessKey, DefaultTokenTimeout, tokenScope); + return new SharedAccessSignatureTokenProvider(keyName, sharedAccessKey, tokenScope); } /// @@ -101,33 +75,96 @@ public static TokenProvider CreateSharedAccessSignatureTokenProvider(string keyN return new SharedAccessSignatureTokenProvider(keyName, sharedAccessKey, tokenTimeToLive, tokenScope); } - /// - /// Gets a for the given audience and duration. - /// - /// The URI which the access token applies to - /// The request action - /// The time span that specifies the timeout value for the message that gets the security token - /// - public Task GetTokenAsync(string appliesTo, string action, TimeSpan timeout) + /// Creates an Azure Active Directory token provider. + /// AuthenticationContext for AAD. + /// The app credential. + /// The for returning Json web token. + public static TokenProvider CreateAadTokenProvider(AuthenticationContext authContext, ClientCredential clientCredential) { - TimeoutHelper.ThrowIfNegativeArgument(timeout); - appliesTo = NormalizeAppliesTo(appliesTo); - return this.OnGetTokenAsync(appliesTo, action, timeout); + if (authContext == null) + { + throw new ArgumentNullException(nameof(authContext)); + } + + if (clientCredential == null) + { + throw new ArgumentNullException(nameof(clientCredential)); + } + + return new AzureActiveDirectoryTokenProvider(authContext, clientCredential); } - /// - /// - /// - /// - /// - protected abstract Task OnGetTokenAsync(string appliesTo, string action, TimeSpan timeout); + /// Creates an Azure Active Directory token provider. + /// AuthenticationContext for AAD. + /// ClientId for AAD. + /// The redirectUri on Client App. + /// Platform parameters + /// User Identifier + /// The for returning Json web token. + public static TokenProvider CreateAadTokenProvider( + AuthenticationContext authContext, + string clientId, + Uri redirectUri, + IPlatformParameters platformParameters, + UserIdentifier userIdentifier = null) + { + if (authContext == null) + { + throw new ArgumentNullException(nameof(authContext)); + } - /// - /// - /// - protected virtual string NormalizeAppliesTo(string appliesTo) + if (string.IsNullOrEmpty(clientId)) + { + throw new ArgumentNullException(nameof(clientId)); + } + + if (redirectUri == null) + { + throw new ArgumentNullException(nameof(redirectUri)); + } + + if (platformParameters == null) + { + throw new ArgumentNullException(nameof(platformParameters)); + } + + return new AzureActiveDirectoryTokenProvider(authContext, clientId, redirectUri, platformParameters, userIdentifier); + } + +#if !UAP10_0 + /// Creates an Azure Active Directory token provider. + /// AuthenticationContext for AAD. + /// The client assertion certificate credential. + /// The for returning Json web token. + public static TokenProvider CreateAadTokenProvider(AuthenticationContext authContext, ClientAssertionCertificate clientAssertionCertificate) { - return EventHubsUriHelper.NormalizeUri(appliesTo, "http", true, stripPath: this.TokenScope == TokenScope.Namespace, ensureTrailingSlash: true); + if (authContext == null) + { + throw new ArgumentNullException(nameof(authContext)); + } + + if (clientAssertionCertificate == null) + { + throw new ArgumentNullException(nameof(clientAssertionCertificate)); + } + + return new AzureActiveDirectoryTokenProvider(authContext, clientAssertionCertificate); } +#endif + + /// Creates Azure Managed Service Identity token provider. + /// The for returning Json web token. + public static TokenProvider CreateManagedServiceIdentityTokenProvider() + { + return new ManagedServiceIdentityTokenProvider(); + } + + /// + /// Gets a for the given audience and duration. + /// + /// The URI which the access token applies to + /// The time span that specifies the timeout value for the message that gets the security token + /// + public abstract Task GetTokenAsync(string appliesTo, TimeSpan timeout); } } diff --git a/test/Microsoft.Azure.EventHubs.Tests/Client/ClientTestBase.cs b/test/Microsoft.Azure.EventHubs.Tests/Client/ClientTestBase.cs index 44e90ad..0ed71c3 100644 --- a/test/Microsoft.Azure.EventHubs.Tests/Client/ClientTestBase.cs +++ b/test/Microsoft.Azure.EventHubs.Tests/Client/ClientTestBase.cs @@ -18,7 +18,7 @@ public ClientTestBase() { // Create default EH client. this.EventHubClient = EventHubClient.CreateFromConnectionString(TestUtility.EventHubsConnectionString); - + // Discover partition ids. var eventHubInfo = this.EventHubClient.GetRuntimeInformationAsync().Result; this.PartitionIds = eventHubInfo.PartitionIds; diff --git a/test/Microsoft.Azure.EventHubs.Tests/Client/ConnectionStringBuilderTests.cs b/test/Microsoft.Azure.EventHubs.Tests/Client/ConnectionStringBuilderTests.cs index 33eb6b9..5628a54 100644 --- a/test/Microsoft.Azure.EventHubs.Tests/Client/ConnectionStringBuilderTests.cs +++ b/test/Microsoft.Azure.EventHubs.Tests/Client/ConnectionStringBuilderTests.cs @@ -121,7 +121,7 @@ async Task UseSharedAccessSignatureApi() // Generate shared access token. var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); var tokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); - var token = await tokenProvider.GetTokenAsync(csb.Endpoint.ToString(), "Send,Receive", TimeSpan.FromSeconds(120)); + var token = await tokenProvider.GetTokenAsync(csb.Endpoint.ToString(), TimeSpan.FromSeconds(120)); var sharedAccessSignature = token.TokenValue.ToString(); // Create connection string builder by SharedAccessSignature overload. diff --git a/test/Microsoft.Azure.EventHubs.Tests/Client/MiscTests.cs b/test/Microsoft.Azure.EventHubs.Tests/Client/MiscTests.cs index 574db00..796b534 100644 --- a/test/Microsoft.Azure.EventHubs.Tests/Client/MiscTests.cs +++ b/test/Microsoft.Azure.EventHubs.Tests/Client/MiscTests.cs @@ -75,46 +75,5 @@ async Task PartitionKeyValidation() Assert.True(totalReceived == NumberOfMessagesToSend, $"Didn't receive the same number of messages that we sent. Sent: {NumberOfMessagesToSend}, Received: {totalReceived}"); } - - [Fact] - [DisplayTestMethodName] - async Task UseSharedAccessSignature() - { - // Generate shared access token. - var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); - var tokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); - var token = await tokenProvider.GetTokenAsync(csb.Endpoint.ToString(), "Send,Receive", TimeSpan.FromSeconds(120)); - var sas = token.TokenValue.ToString(); - - // Update connection string builder to use shared access signature instead. - csb.SasKey = ""; - csb.SasKeyName = ""; - csb.SharedAccessSignature = sas; - - // Create new client with updated connection string. - var ehClient = EventHubClient.CreateFromConnectionString(csb.ToString()); - - // Send one event - TestUtility.Log($"Sending one message."); - var ehSender = ehClient.CreatePartitionSender("0"); - var eventData = new EventData(Encoding.UTF8.GetBytes("Hello EventHub by partitionKey!")); - await ehSender.SendAsync(eventData); - - // Receive event. - TestUtility.Log($"Receiving one message."); - var ehReceiver = ehClient.CreateReceiver(PartitionReceiver.DefaultConsumerGroupName, "0", PartitionReceiver.StartOfStream); - var msg = await ehReceiver.ReceiveAsync(1); - Assert.True(msg != null, "Failed to receive message."); - - // Get EH runtime information. - TestUtility.Log($"Getting Event Hub runtime information."); - var ehInfo = await ehClient.GetRuntimeInformationAsync(); - Assert.True(ehInfo != null, "Failed to get runtime information."); - - // Get EH partition runtime information. - TestUtility.Log($"Getting Event Hub partition '0' runtime information."); - var partitionInfo = await ehClient.GetPartitionRuntimeInformationAsync("0"); - Assert.True(ehInfo != null, "Failed to get runtime partition information."); - } } } diff --git a/test/Microsoft.Azure.EventHubs.Tests/Client/TokenProviderTests.cs b/test/Microsoft.Azure.EventHubs.Tests/Client/TokenProviderTests.cs new file mode 100644 index 0000000..7e8be95 --- /dev/null +++ b/test/Microsoft.Azure.EventHubs.Tests/Client/TokenProviderTests.cs @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.EventHubs.Tests.Client +{ + using System; + using System.Collections.Generic; + using System.Text; + using System.Threading.Tasks; + using Microsoft.IdentityModel.Clients.ActiveDirectory; + using Xunit; + + public class TokenProviderTests : ClientTestBase + { + [Fact] + [DisplayTestMethodName] + async Task UseSharedAccessSignature() + { + // Generate shared access token. + var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); + var tokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); + var token = await tokenProvider.GetTokenAsync(csb.Endpoint.ToString(), TimeSpan.FromSeconds(120)); + var sas = token.TokenValue.ToString(); + + // Update connection string builder to use shared access signature instead. + csb.SasKey = ""; + csb.SasKeyName = ""; + csb.SharedAccessSignature = sas; + + // Create new client with updated connection string. + var ehClient = EventHubClient.CreateFromConnectionString(csb.ToString()); + + // Send one event + TestUtility.Log($"Sending one message."); + var ehSender = ehClient.CreatePartitionSender("0"); + var eventData = new EventData(Encoding.UTF8.GetBytes("Hello EventHub!")); + await ehSender.SendAsync(eventData); + + // Receive event. + TestUtility.Log($"Receiving one message."); + var ehReceiver = ehClient.CreateReceiver(PartitionReceiver.DefaultConsumerGroupName, "0", PartitionReceiver.StartOfStream); + var msg = await ehReceiver.ReceiveAsync(1); + Assert.True(msg != null, "Failed to receive message."); + + // Get EH runtime information. + TestUtility.Log($"Getting Event Hub runtime information."); + var ehInfo = await ehClient.GetRuntimeInformationAsync(); + Assert.True(ehInfo != null, "Failed to get runtime information."); + + // Get EH partition runtime information. + TestUtility.Log($"Getting Event Hub partition '0' runtime information."); + var partitionInfo = await ehClient.GetPartitionRuntimeInformationAsync("0"); + Assert.True(ehInfo != null, "Failed to get runtime partition information."); + } + + [Fact] + [DisplayTestMethodName] + async Task UseITokenProviderWithSas() + { + // Generate SAS token provider. + var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); + var tokenProvider = TokenProvider.CreateSharedAccessSignatureTokenProvider(csb.SasKeyName, csb.SasKey); + + // Create new client with updated connection string. + var ehClient = EventHubClient.Create(csb.Endpoint, csb.EntityPath, tokenProvider); + + // Send one event + TestUtility.Log($"Sending one message."); + var ehSender = ehClient.CreatePartitionSender("0"); + var eventData = new EventData(Encoding.UTF8.GetBytes("Hello EventHub!")); + await ehSender.SendAsync(eventData); + + // Receive event. + TestUtility.Log($"Receiving one message."); + var ehReceiver = ehClient.CreateReceiver(PartitionReceiver.DefaultConsumerGroupName, "0", PartitionReceiver.StartOfStream); + var msg = await ehReceiver.ReceiveAsync(1); + Assert.True(msg != null, "Failed to receive message."); + + // Get EH runtime information. + TestUtility.Log($"Getting Event Hub runtime information."); + var ehInfo = await ehClient.GetRuntimeInformationAsync(); + Assert.True(ehInfo != null, "Failed to get runtime information."); + + // Get EH partition runtime information. + TestUtility.Log($"Getting Event Hub partition '0' runtime information."); + var partitionInfo = await ehClient.GetPartitionRuntimeInformationAsync("0"); + Assert.True(ehInfo != null, "Failed to get runtime partition information."); + } + + /// + /// This test is for manual only purpose. Fill in the tenant-id, app-id and app-secret before running. + /// + /// + [Fact] + [DisplayTestMethodName] + async Task UseITokenProviderWithAad() + { + var tenantId = ""; + var aadAppId = ""; + var aadAppSecret = ""; + + if (string.IsNullOrEmpty(tenantId)) + { + TestUtility.Log($"Skipping test during scheduled runs."); + return; + } + + var authContext = new AuthenticationContext($"https://login.windows.net/{tenantId}"); + var cc = new ClientCredential(aadAppId, aadAppSecret); + var tokenProvider = TokenProvider.CreateAadTokenProvider(authContext, cc); + + // Create new client with updated connection string. + var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); + var ehClient = EventHubClient.Create(csb.Endpoint, csb.EntityPath, tokenProvider); + + // Send one event + TestUtility.Log($"Sending one message."); + var ehSender = ehClient.CreatePartitionSender("0"); + var eventData = new EventData(Encoding.UTF8.GetBytes("Hello EventHub!")); + await ehSender.SendAsync(eventData); + + // Receive event. + TestUtility.Log($"Receiving one message."); + var ehReceiver = ehClient.CreateReceiver(PartitionReceiver.DefaultConsumerGroupName, "0", PartitionReceiver.StartOfStream); + var msg = await ehReceiver.ReceiveAsync(1); + Assert.True(msg != null, "Failed to receive message."); + } + + + /// + /// This test is for manual only purpose. Fill in the tenant-id, app-id and app-secret before running. + /// + /// + [Fact] + [DisplayTestMethodName] + async Task UseCreateApiWithAad() + { + var tenantId = ""; + var aadAppId = ""; + var aadAppSecret = ""; + + if (string.IsNullOrEmpty(tenantId)) + { + TestUtility.Log($"Skipping test during scheduled runs."); + return; + } + + var authContext = new AuthenticationContext($"https://login.windows.net/{tenantId}"); + var cc = new ClientCredential(aadAppId, aadAppSecret); + + // Create new client with updated connection string. + var csb = new EventHubsConnectionStringBuilder(TestUtility.EventHubsConnectionString); + var ehClient = EventHubClient.Create(csb.Endpoint, csb.EntityPath, authContext, cc); + + // Send one event + TestUtility.Log($"Sending one message."); + var ehSender = ehClient.CreatePartitionSender("0"); + var eventData = new EventData(Encoding.UTF8.GetBytes("Hello EventHub!")); + await ehSender.SendAsync(eventData); + + // Receive event. + TestUtility.Log($"Receiving one message."); + var ehReceiver = ehClient.CreateReceiver(PartitionReceiver.DefaultConsumerGroupName, "0", PartitionReceiver.StartOfStream); + var msg = await ehReceiver.ReceiveAsync(1); + Assert.True(msg != null, "Failed to receive message."); + } + } +}