Skip to content

Commit

Permalink
Add support for X.509 auth for HTTP and MQTT over Websockets (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrohera authored Dec 7, 2018
1 parent edc15a2 commit 9b56f3d
Show file tree
Hide file tree
Showing 21 changed files with 1,063 additions and 368 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Extensions.Logging;
Expand All @@ -20,7 +22,10 @@ public AmqpWebSocketListener()
{
}

public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId)
public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId) =>
await ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, Option.None<X509Certificate2>(), Option.None<IList<X509Certificate2>>());

public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId, Option<X509Certificate2> clientCert, Option<IList<X509Certificate2>> clientCertChain)
{
try
{
Expand Down Expand Up @@ -52,7 +57,6 @@ public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPo

protected override void OnListen()
{

}

static class Events
Expand All @@ -66,15 +70,11 @@ enum EventIds
Exception
}

public static void EstablishedConnection(string correlationId)
{
public static void EstablishedConnection(string correlationId) =>
Log.LogInformation((int)EventIds.Established, $"Connection established CorrelationId {correlationId}");
}

public static void FailedAcceptWebSocket(string correlationId, Exception ex)
{
public static void FailedAcceptWebSocket(string correlationId, Exception ex) =>
Log.LogWarning((int)EventIds.Exception, ex, $"Connection failed CorrelationId {correlationId}");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Core
{
using System.Collections.Generic;
using System.Net;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Edge.Util;

public interface IWebSocketListener
{
string SubProtocol { get; }

Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId);
Task ProcessWebSocketRequestAsync(WebSocket webSocket,
Option<EndPoint> localEndPoint,
EndPoint remoteEndPoint,
string correlationId);

Task ProcessWebSocketRequestAsync(WebSocket webSocket,
Option<EndPoint> localEndPoint,
EndPoint remoteEndPoint,
string correlationId,
Option<X509Certificate2> clientCert,
Option<IList<X509Certificate2>> clientCerthain);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Http.Adapters
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.Azure.Devices.Edge.Hub.Http.Middleware;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Extensions.Logging;

// https://github.com/aspnet/HttpAbstractions/issues/808
public class HttpsExtensionConnectionAdapter : IConnectionAdapter
{
// See http://oid-info.com/get/1.3.6.1.5.5.7.3.1
// Indicates that a certificate can be used as a SSL server certificate
const string ServerAuthenticationOid = "1.3.6.1.5.5.7.3.1";
const string AuthenticationSucceeded = "AuthenticationSucceeded";
internal const string DisableHandshakeTimeoutSwitch = "Switch.Microsoft.AspNetCore.Server.Kestrel.Https.DisableHandshakeTimeout";
static readonly TimeSpan HandshakeTimeout = TimeSpan.FromSeconds(10);
static readonly ClosedAdaptedConnection _closedAdaptedConnection = new ClosedAdaptedConnection();
readonly HttpsConnectionAdapterOptions options;
readonly X509Certificate2 serverCertificate;

public HttpsExtensionConnectionAdapter(HttpsConnectionAdapterOptions options)
{
this.options = Preconditions.CheckNotNull(options, nameof(options));
this.serverCertificate = Preconditions.CheckNotNull(options.ServerCertificate, nameof(options.ServerCertificate));
EnsureCertificateIsAllowedForServerAuth(this.serverCertificate);
}

public bool IsHttps => true;

public Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context) =>
Task.Run(() => InnerOnConnectionAsync(context));

async Task<IAdaptedConnection> InnerOnConnectionAsync(ConnectionAdapterContext context)
{
SslStream sslStream;
bool certificateRequired;

IList<X509Certificate2> chainElements = new List<X509Certificate2>();

if (this.options.ClientCertificateMode == ClientCertificateMode.NoCertificate)
{
sslStream = new SslStream(context.ConnectionStream);
certificateRequired = false;
}
else
{
sslStream = new SslStream(context.ConnectionStream,
leaveInnerStreamOpen: false,
userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) =>
{
if (certificate == null)
{
return this.options.ClientCertificateMode != ClientCertificateMode.RequireCertificate;
}

if (this.options.ClientCertificateValidation == null)
{
if (sslPolicyErrors != SslPolicyErrors.None)
{
return false;
}
}

var certificate2 = new X509Certificate2(certificate);
if (certificate2 == null)
{
return false;
}

if (this.options.ClientCertificateValidation != null)
{
if (!this.options.ClientCertificateValidation(certificate2, chain, sslPolicyErrors))
{
return false;
}
}

foreach (var element in chain.ChainElements)
{
chainElements.Add(element.Certificate);
}

return true;
});

certificateRequired = true;
}

try
{
if (AppContext.TryGetSwitch(DisableHandshakeTimeoutSwitch, out var handshakeDisabled) && handshakeDisabled)
{
await sslStream.AuthenticateAsServerAsync(
this.serverCertificate,
certificateRequired,
this.options.SslProtocols,
this.options.CheckCertificateRevocation);
}
else
{
try
{
var handshakeTask = sslStream.AuthenticateAsServerAsync(
this.serverCertificate,
certificateRequired,
this.options.SslProtocols,
this.options.CheckCertificateRevocation);
var handshakeTimeoutTask = Task.Delay(HandshakeTimeout);

var firstTask = await Task.WhenAny(handshakeTask, handshakeTimeoutTask);

if (firstTask == handshakeTimeoutTask)
{
Events.AuthenticationTimedOut();

// Observe any exception that might be raised from AuthenticateAsServerAsync after the timeout.
ObserveTaskException(handshakeTask);

// This will cause the request processing loop to exit immediately and close the underlying connection.
sslStream.Dispose();
return _closedAdaptedConnection;
}

// Observe potential handshake failures.
await handshakeTask;
}
catch (OperationCanceledException)
{
Events.AuthenticationTimedOut();
sslStream.Dispose();
return _closedAdaptedConnection;
}
}
}
catch (Exception)
{
Events.AuthenticationFailed();
sslStream.Dispose();
return _closedAdaptedConnection;
}

Events.AuthenticationSuccess();
// Always set the feature even though the cert might be null
var cert = (sslStream.RemoteCertificate != null) ? new X509Certificate2(sslStream.RemoteCertificate) : null;
context.Features.Set<ITlsConnectionFeature>(new TlsConnectionFeature
{
ClientCertificate = cert
});
context.Features.Set<ITlsConnectionFeatureExtended>(new TlsConnectionFeatureExtended
{
ChainElements = chainElements
});

return new HttpsAdaptedConnection(sslStream);
}

static void EnsureCertificateIsAllowedForServerAuth(X509Certificate2 certificate)
{
/* If the Extended Key Usage extension is included, then we check that the serverAuth usage is included. (http://oid-info.com/get/1.3.6.1.5.5.7.3.1)
* If the Extended Key Usage extension is not included, then we assume the certificate is allowed for all usages.
*
* See also https://blogs.msdn.microsoft.com/kaushal/2012/02/17/client-certificates-vs-server-certificates/
*
* From https://tools.ietf.org/html/rfc3280#section-4.2.1.13 "Certificate Extensions: Extended Key Usage"
*
* If the (Extended Key Usage) extension is present, then the certificate MUST only be used
* for one of the purposes indicated. If multiple purposes are
* indicated the application need not recognize all purposes indicated,
* as long as the intended purpose is present. Certificate using
* applications MAY require that a particular purpose be indicated in
* order for the certificate to be acceptable to that application.
*/

var hasEkuExtension = false;

foreach (var extension in certificate.Extensions.OfType<X509EnhancedKeyUsageExtension>())
{
hasEkuExtension = true;
foreach (var oid in extension.EnhancedKeyUsages)
{
if (oid.Value.Equals(ServerAuthenticationOid, StringComparison.Ordinal))
{
return;
}
}
}

if (hasEkuExtension)
{
throw new InvalidOperationException("InvalidServerCertificateEku");
}
}

static void ObserveTaskException(Task task)
{
_ = task.ContinueWith(t =>
{
_ = t.Exception;
}, TaskScheduler.Current);
}

class HttpsAdaptedConnection : IAdaptedConnection
{
readonly SslStream _sslStream;

public HttpsAdaptedConnection(SslStream sslStream)
{
_sslStream = sslStream;
}

public Stream ConnectionStream => _sslStream;

public void Dispose() => _sslStream.Dispose();
}

class ClosedAdaptedConnection : IAdaptedConnection
{
public Stream ConnectionStream { get; } = new ClosedStream();

[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", Justification = "Field does not need to be disposed.")]
public void Dispose()
{
}
}

internal class ClosedStream : Stream
{
static readonly Task<int> ZeroResultTask = Task.FromResult(result: 0);

public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;

public override long Length
{
get
{
throw new NotSupportedException();
}
}

public override long Position
{
get
{
throw new NotSupportedException();
}
set
{
throw new NotSupportedException();
}
}

public override void Flush()
{
}

public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

public override void SetLength(long value) => throw new NotSupportedException();

public override int Read(byte[] buffer, int offset, int count) => 0;

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ZeroResultTask;

public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
}
static class Events
{
static readonly ILogger Log = Logger.Factory.CreateLogger<HttpsExtensionConnectionAdapter>();
const int IdStart = HttpEventIds.HttpsExtensionConnectionAdapter;

enum EventIds
{
AuthenticationTimedOut = IdStart,
AuthenticationFailed,
AuthenticationSuccess
}

public static void AuthenticationTimedOut() =>
Log.LogInformation((int)EventIds.AuthenticationTimedOut, "HttpExtensionConnectionAdapter authentication timeout");

public static void AuthenticationFailed() =>
Log.LogInformation((int)EventIds.AuthenticationFailed, "HttpExtensionConnectionAdapter authentication failed");

public static void AuthenticationSuccess() =>
Log.LogDebug((int)EventIds.AuthenticationSuccess, "HttpExtensionConnectionAdapter authentication succeeded");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Http.Middleware
{
using Microsoft.AspNetCore.Http;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;

public static class HttpContextExtensions
{
public static IList<X509Certificate2> GetClientCertificateChain(this HttpContext context)
{
ITlsConnectionFeatureExtended feature = context.Features.Get<ITlsConnectionFeatureExtended>();
return (feature == null) ? new List<X509Certificate2>() : feature.ChainElements;
}
}
}
Loading

0 comments on commit 9b56f3d

Please sign in to comment.