diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 7b0579e73c727..37998df2e9603 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -227,6 +227,7 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec state.Connected = true; state.ConnectTcs!.SetResult(MsQuicStatusCodes.Success); + state.ConnectTcs = null; } return MsQuicStatusCodes.Success; @@ -576,7 +577,8 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'."); } - _state.ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // We store TCS to local variable to avoid NRE if callbacks finish fast and set _state.ConnectTcs to null. + var tcs = _state.ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); try { @@ -600,7 +602,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d throw; } - return new ValueTask(_state.ConnectTcs.Task); + return new ValueTask(tcs.Task); } private ValueTask ShutdownAsync( @@ -683,9 +685,10 @@ private static uint NativeCallbackHandler( if (state.ConnectTcs != null) { - state.ConnectTcs.SetException(ex); - state.ConnectTcs = null; + // This is opportunistic if we get exception and have ability to propagate it to caller. + state.ConnectTcs.TrySetException(ex); state.Connection = null; + state.ConnectTcs = null; } else { diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 0d1ec65f00e79..0adb63860c4f3 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -152,7 +152,6 @@ public async Task CertificateCallbackThrowPropagates() } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/56263")] public async Task ConnectWithCertificateCallback() { X509Certificate2 c1 = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate(); @@ -235,7 +234,6 @@ public async Task ConnectWithCertificateCallback() } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/56454")] public async Task ConnectWithCertificateForDifferentName_Throws() { (X509Certificate2 certificate, _) = System.Net.Security.Tests.TestHelper.GenerateCertificates("localhost");