From e82488dbbad94428f28e41ffa592d559774e8731 Mon Sep 17 00:00:00 2001 From: Tomas Weinfurt Date: Wed, 30 Sep 2020 15:37:23 -0700 Subject: [PATCH] do not call RemoteCertificateValidationCallback on the same certificate again (#42836) * do not call RemoteCertificateValidationCallback on the same certificate agin * feedback from review * feedback from review * make sure we do not use Linq --- .../HttpClientHandlerTest.SslProtocols.cs | 6 +++- .../src/System/Net/Security/SecureChannel.cs | 12 ++++++- .../SslStreamNetworkStreamTest.cs | 7 +++- .../tests/FunctionalTests/TestHelper.cs | 33 +++++++++++++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.SslProtocols.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.SslProtocols.cs index 17ab7a4b1c27d..b0b121d73693a 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.SslProtocols.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.SslProtocols.cs @@ -100,10 +100,12 @@ public static IEnumerable GetAsync_AllowedSSLVersion_Succeeds_MemberDa [MemberData(nameof(GetAsync_AllowedSSLVersion_Succeeds_MemberData))] public async Task GetAsync_AllowedSSLVersion_Succeeds(SslProtocols acceptedProtocol, bool requestOnlyThisProtocol) { + int count = 0; using (HttpClientHandler handler = CreateHttpClientHandler()) using (HttpClient client = CreateHttpClient(handler)) { - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + handler.ServerCertificateCustomValidationCallback = + (request, cert, chain, errors) => { count++; return true; }; if (requestOnlyThisProtocol) { @@ -131,6 +133,8 @@ await TestHelper.WhenAllCompletedOrAnyFailed( server.AcceptConnectionSendResponseAndCloseAsync(), client.GetAsync(url)); }, options); + + Assert.Equal(1, count); } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs index bcc6b16475af0..2261580a31512 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs @@ -954,7 +954,17 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot try { - _remoteCertificate = CertificateValidationPal.GetRemoteCertificate(_securityContext, out remoteCertificateStore); + X509Certificate2? certificate = CertificateValidationPal.GetRemoteCertificate(_securityContext, out remoteCertificateStore); + + if (_remoteCertificate != null && certificate != null && + certificate.RawData.AsSpan().SequenceEqual(_remoteCertificate.RawData)) + { + // This is renegotiation or TLS 1.3 and the certificate did not change. + // There is no reason to process callback again as we already established trust. + return true; + } + + _remoteCertificate = certificate; if (_remoteCertificate == null) { diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs index 258b8a1ee2824..7882d2ac4e6b6 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs @@ -230,6 +230,7 @@ public async Task SslStream_NestedAuth_Throws() public async Task SslStream_TargetHostName_Succeeds(bool useEmptyName) { string targetName = useEmptyName ? string.Empty : Guid.NewGuid().ToString("N"); + int count = 0; (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams(); using (clientStream) @@ -248,7 +249,7 @@ public async Task SslStream_TargetHostName_Succeeds(bool useEmptyName) { SslStream stream = (SslStream)sender; Assert.Equal(targetName, stream.TargetHostName); - + count++; return true; }; @@ -265,8 +266,12 @@ public async Task SslStream_TargetHostName_Succeeds(bool useEmptyName) await TestConfiguration.WhenAllOrAnyFailedWithTimeout( client.AuthenticateAsClientAsync(clientOptions), server.AuthenticateAsServerAsync(serverOptions)); + + await TestHelper.PingPong(client, server); + Assert.Equal(targetName, client.TargetHostName); Assert.Equal(targetName, server.TargetHostName); + Assert.Equal(1, count); } } diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs index 97e3a4a5efb9d..ad6104dd7c027 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs @@ -9,6 +9,8 @@ using System.Security.Cryptography.X509Certificates.Tests.Common; using System.Runtime.CompilerServices; using System.Text; +using System.Threading.Tasks; +using Xunit; namespace System.Net.Security.Tests { @@ -30,6 +32,9 @@ public static class TestHelper private static readonly X509BasicConstraintsExtension s_eeConstraints = new X509BasicConstraintsExtension(false, false, 0, false); + private static readonly byte[] s_ping = Encoding.UTF8.GetBytes("PING"); + private static readonly byte[] s_pong = Encoding.UTF8.GetBytes("PONG"); + public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams() { (Stream clientStream, Stream serverStream) = GetConnectedStreams(); @@ -151,5 +156,33 @@ internal static (X509Certificate2 certificate, X509Certificate2Collection) Gener return (endEntity, chain); } + + internal static async Task PingPong(SslStream client, SslStream server) + { + byte[] buffer = new byte[s_ping.Length]; + ValueTask t = client.WriteAsync(s_ping); + + int remains = s_ping.Length; + while (remains > 0) + { + int readLength = await server.ReadAsync(buffer, buffer.Length - remains, remains); + Assert.True(readLength > 0); + remains -= readLength; + } + Assert.Equal(s_ping, buffer); + await t; + + t = server.WriteAsync(s_pong); + remains = s_pong.Length; + while (remains > 0) + { + int readLength = await client.ReadAsync(buffer, buffer.Length - remains, remains); + Assert.True(readLength > 0); + remains -= readLength; + } + + Assert.Equal(s_pong, buffer); + await t; + } } }