diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index 2e5d3fd815..f8f2facb59 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -138,196 +138,154 @@ internal class SNICommon internal const int LocalDBBadRuntime = 57; /// - /// We only validate Server name in Certificate to match with "targetServerName". + /// We either validate that the provided 'validationCert' matches the 'serverCert', or we validate that the server name in the 'serverCert' matches 'targetServerName'. /// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method] /// This method is called as a result of callback for SSL Stream Certificate validation. /// + /// Connection ID/GUID for tracing /// Server that client is expecting to connect to - /// X.509 certificate + /// Optional hostname to use for server certificate validation + /// X.509 certificate from the server + /// Path to an X.509 certificate file from the application to compare with the serverCert /// Policy errors /// True if certificate is valid - internal static bool ValidateSslServerCertificate(string targetServerName, X509Certificate cert, SslPolicyErrors policyErrors) + internal static bool ValidateSslServerCertificate(Guid connectionId, string targetServerName, string hostNameInCertificate, X509Certificate serverCert, string validationCertFileName, SslPolicyErrors policyErrors) { using (TrySNIEventScope.Create("SNICommon.ValidateSslServerCertificate | SNI | SCOPE | INFO | Entering Scope {0} ")) { if (policyErrors == SslPolicyErrors.None) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "targetServerName {0}, SSL Server certificate not validated as PolicyErrors set to None.", args0: targetServerName); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Connection Id {0}, targetServerName {1}, SSL Server certificate not validated as PolicyErrors set to None.", args0: connectionId, args1: targetServerName); return true; } - // If we get to this point then there is a ssl policy flag. - StringBuilder messageBuilder = new(); - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors)) + string serverNameToValidate; + X509Certificate validationCertificate = null; + if (!string.IsNullOrEmpty(hostNameInCertificate)) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, SslPolicyError {1}, SSL Policy certificate chain has errors.", args0: targetServerName, args1: policyErrors); + serverNameToValidate = hostNameInCertificate; + } + else + { + serverNameToValidate = targetServerName; + } - // get the chain status from the certificate - X509Certificate2 cert2 = cert as X509Certificate2; - X509Chain chain = new(); - chain.ChainPolicy.RevocationMode = X509RevocationMode.Offline; - StringBuilder chainStatusInformation = new(); - bool chainIsValid = chain.Build(cert2); - Debug.Assert(!chainIsValid, "RemoteCertificateChainError flag is detected, but certificate chain is valid."); - if (!chainIsValid) + if (!string.IsNullOrEmpty(validationCertFileName)) + { + try { - foreach (X509ChainStatus chainStatus in chain.ChainStatus) - { - chainStatusInformation.Append($"{chainStatus.StatusInformation}, [Status: {chainStatus.Status}]"); - chainStatusInformation.AppendLine(); - } + validationCertificate = new X509Certificate(validationCertFileName); + } + catch (Exception e) + { + // if this fails, then fall back to the HostNameInCertificate or TargetServer validation. + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Exception occurred loading specified ServerCertificate: {1}, treating it as if ServerCertificate has not been specified.", args0: connectionId, args1: e.Message); } - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, SslPolicyError {1}, SSL Policy certificate chain has errors. ChainStatus {2}", args0: targetServerName, args1: policyErrors, args2: chainStatusInformation); - messageBuilder.AppendFormat(Strings.SQL_RemoteCertificateChainErrors, chainStatusInformation); - messageBuilder.AppendLine(); } - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) + if (validationCertificate != null) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, SSL Policy invalidated certificate.", args0: targetServerName); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNotAvailable); + if (serverCert.GetRawCertData().AsSpan().SequenceEqual(validationCertificate.GetRawCertData().AsSpan())) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Connection Id {0}, ServerCertificate matches the certificate provided by the server. Certificate validation passed.", args0: connectionId); + return true; + } + else + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Connection Id {0}, ServerCertificate doesn't match the certificate provided by the server. Certificate validation failed.", args0: connectionId); + throw ADP.SSLCertificateAuthenticationException(Strings.SQL_RemoteCertificateDoesNotMatchServerCertificate); + } } - - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) + else { -#if NET8_0_OR_GREATER - X509Certificate2 cert2 = cert as X509Certificate2; - if (!cert2.MatchesHostname(targetServerName)) + // If we get to this point then there is a ssl policy flag. + StringBuilder messageBuilder = new(); + if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, Target Server name or HNIC does not match the Subject/SAN in Certificate.", args0: targetServerName); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, targetServerName {1}, SSL Server certificate not validated as PolicyErrors set to RemoteCertificateNotAvailable.", args0: connectionId, args1: targetServerName); + messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNotAvailable); } -#else - // To Do: include certificate SAN (Subject Alternative Name) check. - string certServerName = cert.Subject.Substring(cert.Subject.IndexOf('=') + 1); - // Verify that target server name matches subject in the certificate - if (targetServerName.Length > certServerName.Length) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, Target Server name is of greater length than Subject in Certificate.", args0: targetServerName); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); - } - else if (targetServerName.Length == certServerName.Length) + if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors)) { - // Both strings have the same length, so targetServerName must be a FQDN - if (!targetServerName.Equals(certServerName, StringComparison.OrdinalIgnoreCase)) + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, targetServerName {0}, SslPolicyError {1}, SSL Policy certificate chain has errors.", args0: connectionId, args1: targetServerName, args2: policyErrors); + + // get the chain status from the certificate + X509Certificate2 cert2 = serverCert as X509Certificate2; + X509Chain chain = new(); + chain.ChainPolicy.RevocationMode = X509RevocationMode.Offline; + StringBuilder chainStatusInformation = new(); + bool chainIsValid = chain.Build(cert2); + Debug.Assert(!chainIsValid, "RemoteCertificateChainError flag is detected, but certificate chain is valid."); + if (!chainIsValid) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, Target Server name does not match Subject in Certificate.", args0: targetServerName); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + foreach (X509ChainStatus chainStatus in chain.ChainStatus) + { + chainStatusInformation.Append($"{chainStatus.StatusInformation}, [Status: {chainStatus.Status}]"); + chainStatusInformation.AppendLine(); + } } + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, targetServerName {1}, SslPolicyError {2}, SSL Policy certificate chain has errors. ChainStatus {3}", args0: connectionId, args1: targetServerName, args2: policyErrors, args3: chainStatusInformation); + messageBuilder.AppendFormat(Strings.SQL_RemoteCertificateChainErrors, chainStatusInformation); + messageBuilder.AppendLine(); } - else + + if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) { - if (string.Compare(targetServerName, 0, certServerName, 0, targetServerName.Length, StringComparison.OrdinalIgnoreCase) != 0) +#if NET8_0_OR_GREATER + X509Certificate2 cert2 = serverCert as X509Certificate2; + if (!cert2.MatchesHostname(serverNameToValidate)) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, Target Server name does not match Subject in Certificate.", args0: targetServerName); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, serverNameToValidate {1}, Target Server name or HNIC does not match the Subject/SAN in Certificate.", args0: connectionId, args1: serverNameToValidate); messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); } +#else + // To Do: include certificate SAN (Subject Alternative Name) check. + string certServerName = serverCert.Subject.Substring(serverCert.Subject.IndexOf('=') + 1); - // Server name matches cert name for its whole length, so ensure that the - // character following the server name is a '.'. This will avoid - // having server name "ab" match "abc.corp.company.com" - // (Names have different lengths, so the target server can't be a FQDN.) - if (certServerName[targetServerName.Length] != '.') + // Verify that target server name matches subject in the certificate + if (serverNameToValidate.Length > certServerName.Length) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "targetServerName {0}, Target Server name does not match Subject in Certificate.", args0: targetServerName); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, serverNameToValidate {1}, Target Server name is of greater length than Subject in Certificate.", args0: connectionId, args1: serverNameToValidate); messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); } - } -#endif - } - - if (messageBuilder.Length > 0) - { - throw ADP.SSLCertificateAuthenticationException(messageBuilder.ToString()); - } - - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, " Remote certificate with subject: {0}, validated successfully.", args0: cert.Subject); - return true; - } - } - - /// - /// We validate the provided certificate provided by the client with the one from the server to see if it matches. - /// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method] - /// This method is called as a result of callback for SSL Stream Certificate validation. - /// - /// X.509 certificate provided by the client - /// X.509 certificate provided by the server - /// Policy errors - /// True if certificate is valid - internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X509Certificate serverCert, SslPolicyErrors policyErrors) - { - using (TrySNIEventScope.Create("SNICommon.ValidateSslServerCertificate | SNI | SCOPE | INFO | Entering Scope {0} ")) - { - if (policyErrors == SslPolicyErrors.None) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "serverCert {0}, SSL Server certificate not validated as PolicyErrors set to None.", args0: clientCert.Subject); - return true; - } - - StringBuilder messageBuilder = new(); - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "serverCert {0}, SSL Server certificate not validated as PolicyErrors set to RemoteCertificateNotAvailable.", args0: clientCert.Subject); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNotAvailable); - } - - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors)) - { - // get the chain status from the server certificate - X509Certificate2 cert2 = serverCert as X509Certificate2; - X509Chain chain = new(); - chain.ChainPolicy.RevocationMode = X509RevocationMode.Offline; - StringBuilder chainStatusInformation = new(); - bool chainIsValid = chain.Build(cert2); - Debug.Assert(!chainIsValid, "RemoteCertificateChainError flag is detected, but certificate chain is valid."); - if (!chainIsValid) - { - foreach (X509ChainStatus chainStatus in chain.ChainStatus) + else if (serverNameToValidate.Length == certServerName.Length) { - chainStatusInformation.Append($"{chainStatus.StatusInformation}, [Status: {chainStatus.Status}]"); - chainStatusInformation.AppendLine(); + // Both strings have the same length, so serverNameToValidate must be a FQDN + if (!serverNameToValidate.Equals(certServerName, StringComparison.OrdinalIgnoreCase)) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, serverNameToValidate {1}, Target Server name does not match Subject in Certificate.", args0: connectionId, args1: serverNameToValidate); + messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + } } - } - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "certificate subject from server is {0}, and does not match with the certificate provided client.", args0: cert2.SubjectName.Name); - messageBuilder.AppendFormat(Strings.SQL_RemoteCertificateChainErrors, chainStatusInformation); - messageBuilder.AppendLine(); - } - - if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) - { -#if NET8_0_OR_GREATER - X509Certificate2 s_cert = serverCert as X509Certificate2; - X509Certificate2 c_cert = clientCert as X509Certificate2; - - if (!s_cert.MatchesHostname(c_cert.SubjectName.Name)) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "certificate from server does not match with the certificate provided client.", args0: s_cert.Subject); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); - } -#else - // Verify that subject name matches - if (serverCert.Subject != clientCert.Subject) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "certificate subject from server is {0}, and does not match with the certificate provided client.", args0: serverCert.Subject); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + else + { + if (string.Compare(serverNameToValidate, 0, certServerName, 0, serverNameToValidate.Length, StringComparison.OrdinalIgnoreCase) != 0) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, serverNameToValidate {1}, Target Server name does not match Subject in Certificate.", args0: connectionId, args1: serverNameToValidate); + messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + } + + // Server name matches cert name for its whole length, so ensure that the + // character following the server name is a '.'. This will avoid + // having server name "ab" match "abc.corp.company.com" + // (Names have different lengths, so the target server can't be a FQDN.) + if (certServerName[serverNameToValidate.Length] != '.') + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "Connection Id {0}, serverNameToValidate {1}, Target Server name does not match Subject in Certificate.", args0: connectionId, args1: serverNameToValidate); + messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + } + } +#endif } - if (!serverCert.Equals(clientCert)) + if (messageBuilder.Length > 0) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.ERR, "certificate from server does not match with the certificate provided client.", args0: serverCert.Subject); - messageBuilder.AppendLine(Strings.SQL_RemoteCertificateNameMismatch); + throw ADP.SSLCertificateAuthenticationException(messageBuilder.ToString()); } -#endif } - if (messageBuilder.Length > 0) - { - throw ADP.SSLCertificateAuthenticationException(messageBuilder.ToString()); - } - - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "certificate subject {0}, Client certificate validated successfully.", args0: clientCert.Subject); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Connection Id {0}, certificate with subject: {1}, validated successfully.", args0: connectionId, args1: serverCert.Subject); return true; } } @@ -342,11 +300,11 @@ internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer ti args0: serverName, args1: remainingTimeout); using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); - - return Dns.GetHostAddressesAsync(serverName, cts.Token) - .ConfigureAwait(false) - .GetAwaiter() - .GetResult(); + // using this overload to support netstandard + Task task = Dns.GetHostAddressesAsync(serverName); + task.ConfigureAwait(false); + task.Wait(cts.Token); + return task.Result; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs index 2889ce6bb4..8f8af57f58 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs @@ -24,6 +24,8 @@ internal sealed class SNINpHandle : SNIPhysicalHandle private readonly string _targetServer; private readonly object _sendSync; + private readonly string _hostNameInCertificate; + private readonly string _serverCertificateFilename; private readonly bool _tlsFirst; private Stream _stream; private NamedPipeClientStream _pipeStream; @@ -38,7 +40,7 @@ internal sealed class SNINpHandle : SNIPhysicalHandle private int _bufferSize = TdsEnums.DEFAULT_LOGIN_PACKET_SIZE; private readonly Guid _connectionId = Guid.NewGuid(); - public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, bool tlsFirst) + public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, bool tlsFirst, string hostNameInCertificate, string serverCertificateFilename) { using (TrySNIEventScope.Create(nameof(SNINpHandle))) { @@ -47,6 +49,8 @@ public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, boo _sendSync = new object(); _targetServer = serverName; _tlsFirst = tlsFirst; + _hostNameInCertificate = hostNameInCertificate; + _serverCertificateFilename = serverCertificateFilename; try { _pipeStream = new NamedPipeClientStream( @@ -369,14 +373,14 @@ public override void DisableSsl() /// Validate server certificate /// /// Sender object - /// X.509 certificate + /// X.509 certificate /// X.509 chain /// Policy errors /// true if valid - private bool ValidateServerCertificate(object sender, X509Certificate cert, X509Chain chain, SslPolicyErrors policyErrors) + private bool ValidateServerCertificate(object sender, X509Certificate serverCertificate, X509Chain chain, SslPolicyErrors policyErrors) { using (TrySNIEventScope.Create(nameof(SNINpHandle))) - { + { if (!_validateCert) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Certificate validation not requested.", args0: ConnectionId); @@ -384,8 +388,8 @@ private bool ValidateServerCertificate(object sender, X509Certificate cert, X509 } SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Proceeding to SSL certificate validation.", args0: ConnectionId); - return SNICommon.ValidateSslServerCertificate(_targetServer, cert, policyErrors); - } + return SNICommon.ValidateSslServerCertificate(_connectionId, _targetServer, _hostNameInCertificate, serverCertificate, _serverCertificateFilename, policyErrors); + } } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 77b44a1cf5..0ee31e5f46 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -206,7 +206,7 @@ internal static SNIHandle CreateConnectionHandle( tlsFirst, hostNameInCertificate, serverCertificateFilename); break; case DataSource.Protocol.NP: - sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst); + sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst, hostNameInCertificate, serverCertificateFilename); break; default: Debug.Fail($"Unexpected connection protocol: {details._connectionProtocol}"); @@ -362,8 +362,10 @@ private static SNITCPHandle CreateTcpHandle( /// Timer expiration /// Should MultiSubnetFailover be used. Only returns an error for named pipes. /// + /// Host name in certificate + /// Used for the path to the Server Certificate /// SNINpHandle - private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeout, bool parallel, bool tlsFirst) + private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeout, bool parallel, bool tlsFirst, string hostNameInCertificate, string serverCertificateFilename) { if (parallel) { @@ -371,7 +373,7 @@ private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeo SNICommon.ReportSNIError(SNIProviders.NP_PROV, 0, SNICommon.MultiSubnetFailoverWithNonTcpProtocol, Strings.SNI_ERROR_49); return null; } - return new SNINpHandle(details.PipeHostName, details.PipeName, timeout, tlsFirst); + return new SNINpHandle(details.PipeHostName, details.PipeName, timeout, tlsFirst, hostNameInCertificate, serverCertificateFilename); } /// @@ -539,8 +541,10 @@ private void PopulateProtocol() internal static string GetLocalDBInstance(string dataSource, out bool error) { string instanceName = null; + // ReadOnlySpan is not supported in netstandard 2.0, but installing System.Memory solves the issue ReadOnlySpan input = dataSource.AsSpan().TrimStart(); error = false; + // NetStandard 2.0 does not support passing a string to ReadOnlySpan int index = input.IndexOf(LocalDbHost.AsSpan().Trim(), StringComparison.InvariantCultureIgnoreCase); if (input.StartsWith(LocalDbHost_NP.AsSpan().Trim(), StringComparison.InvariantCultureIgnoreCase)) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 9d415bcfc8..2791de17a4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -644,7 +644,6 @@ public override uint EnableSsl(uint options) } else { - // TODO: Resolve whether to send _serverNameIndication or _targetServer. _serverNameIndication currently results in error. Why? _sslStream.AuthenticateAsClient(_targetServer, null, s_supportedProtocols, false); } if (_sslOverTdsStream is not null) @@ -698,33 +697,8 @@ private bool ValidateServerCertificate(object sender, X509Certificate serverCert return true; } - string serverNameToValidate; - if (!string.IsNullOrEmpty(_hostNameInCertificate)) - { - serverNameToValidate = _hostNameInCertificate; - } - else - { - serverNameToValidate = _targetServer; - } - - if (!string.IsNullOrEmpty(_serverCertificateFilename)) - { - X509Certificate clientCertificate = null; - try - { - clientCertificate = new X509Certificate(_serverCertificateFilename); - return SNICommon.ValidateSslServerCertificate(clientCertificate, serverCertificate, policyErrors); - } - catch (Exception e) - { - // if this fails, then fall back to the HostNameInCertificate or TargetServer validation. - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, IOException occurred: {1}", args0: _connectionId, args1: e.Message); - } - } - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Certificate will be validated for Target Server name", args0: _connectionId); - return SNICommon.ValidateSslServerCertificate(serverNameToValidate, serverCertificate, policyErrors); + return SNICommon.ValidateSslServerCertificate(_connectionId, _targetServer, _hostNameInCertificate, serverCertificate, _serverCertificateFilename, policyErrors); } /// diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs index 533ffa86ad..a685787e9e 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs @@ -10215,6 +10215,15 @@ internal static string SQL_RemoteCertificateChainErrors { } } + /// + /// Looks up a localized string similar to The certificate provided by the server does not match the certificate provided by the ServerCertificate option.. + /// + internal static string SQL_RemoteCertificateDoesNotMatchServerCertificate { + get { + return ResourceManager.GetString("SQL_RemoteCertificateDoesNotMatchServerCertificate", resourceCulture); + } + } + /// /// Looks up a localized string similar to Certificate name mismatch. The provided 'DataSource' or 'HostNameInCertificate' does not match the name in the certificate.. /// diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx index 3fb2e0a74a..c2dd68b867 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx @@ -4737,4 +4737,7 @@ Certificate not available while validating the certificate. + + The certificate provided by the server does not match the certificate provided by the ServerCertificate option. + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 8f16c09aa3..a164149a60 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -12,6 +12,7 @@ using System.Security; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlServer.TDS.PreLogin; using Microsoft.SqlServer.TDS.Servers; using Xunit; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs index 1ead74f58d..ef45bdbc7a 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs @@ -65,6 +65,5 @@ public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool ena public void Dispose() => _endpoint?.Stop(); public string ConnectionString { get; private set; } - } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParameters.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParameters.cs new file mode 100644 index 0000000000..4b6e7b087b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParameters.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlServer.TDS.PreLogin; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.DataCommon +{ + public class ConnectionTestParameters + { + private SqlConnectionEncryptOption _encryptionOption; + private TDSPreLoginTokenEncryptionType _encryptionType; + private string _hnic; + private string _cert; + private bool _result; + private bool _trustServerCert; + + public SqlConnectionEncryptOption Encrypt => _encryptionOption; + public bool TrustServerCertificate => _trustServerCert; + public string Certificate => _cert; + public string HostNameInCertificate => _hnic; + public bool TestResult => _result; + public TDSPreLoginTokenEncryptionType TdsEncryptionType => _encryptionType; + + public ConnectionTestParameters(TDSPreLoginTokenEncryptionType tdsEncryptionType, SqlConnectionEncryptOption encryptOption, bool trustServerCert, string cert, string hnic, bool result) + { + _encryptionOption = encryptOption; + _trustServerCert = trustServerCert; + _cert = cert; + _hnic = hnic; + _result = result; + _encryptionType = tdsEncryptionType; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParametersData.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParametersData.cs new file mode 100644 index 0000000000..5a2e01a77c --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/ConnectionTestParametersData.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.IO; +using Microsoft.SqlServer.TDS.PreLogin; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.DataCommon +{ + public class ConnectionTestParametersData + { + private const int CASES = 30; + private string _empty = string.Empty; + // It was advised to store the client certificate in its own folder. + private static readonly string s_fullPathToCer = Path.Combine(Directory.GetCurrentDirectory(), "clientcert", "localhostcert.cer"); + private static readonly string s_mismatchedcert = Path.Combine(Directory.GetCurrentDirectory(), "clientcert", "mismatchedcert.cer"); + + private static readonly string s_hostName = System.Net.Dns.GetHostName(); + public static ConnectionTestParametersData Data { get; } = new ConnectionTestParametersData(); + public List ConnectionTestParametersList { get; set; } + + public static IEnumerable GetConnectionTestParameters() + { + for (int i = 0; i < CASES; i++) + { + yield return new object[] { Data.ConnectionTestParametersList[i] }; + } + } + + public ConnectionTestParametersData() + { + // Test cases possible field values for connection parameters: + // These combinations are based on the possible values of Encrypt, TrustServerCertificate, Certificate, HostNameInCertificate + /* + * TDSEncryption | Encrypt | TrustServerCertificate | Certificate | HNIC | TestResults + * ---------------------------------------------------------------------------------------------- + * Off | Optional | true | valid | valid name | true + * On | Mandatory | false | mismatched | empty | false + * Required | | x | ChainError? | wrong name? | + */ + ConnectionTestParametersList = new List + { + // TDSPreLoginTokenEncryptionType.Off + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Optional, false, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, false, _empty, _empty, false), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Optional, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, false, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, true, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, false, _empty, s_hostName, false), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, true, _empty, s_hostName, true), + + // TDSPreLoginTokenEncryptionType.On + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Optional, false, _empty, _empty, false), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, false, _empty, _empty, false), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Optional, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, false, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, true, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, false, _empty, s_hostName, false), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, true, _empty, s_hostName, true), + + // TDSPreLoginTokenEncryptionType.Required + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Optional, false, _empty, _empty, false), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, false, _empty, _empty, false), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Optional, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, true, _empty, _empty, true), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, false, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, true, s_fullPathToCer, _empty, true), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, false, _empty, s_hostName, false), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, true, _empty, s_hostName, true), + + // Mismatched certificate test + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, false, s_mismatchedcert, _empty, false), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, true, s_mismatchedcert, _empty, false), + new(TDSPreLoginTokenEncryptionType.Off, SqlConnectionEncryptOption.Mandatory, true, s_mismatchedcert, _empty, true), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, false, s_mismatchedcert, _empty, false), + new(TDSPreLoginTokenEncryptionType.On, SqlConnectionEncryptOption.Mandatory, true, s_mismatchedcert, _empty, true), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, false, s_mismatchedcert, _empty, false), + new(TDSPreLoginTokenEncryptionType.Required, SqlConnectionEncryptOption.Mandatory, true, s_mismatchedcert, _empty, true), + }; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index f75aef8edb..8962d5ab15 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -270,6 +270,8 @@ + + @@ -287,6 +289,7 @@ + @@ -354,6 +357,15 @@ + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + Always diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs index fc358acb05..d8a402236e 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs @@ -32,7 +32,7 @@ public class CertificateTest : IDisposable // InstanceName will get replaced with an instance name in the connection string private static string InstanceName = "MSSQLSERVER"; - // InstanceNamePrefix will get replaced with MSSQL$ is there is an instance name in connection string + // s_instanceNamePrefix will get replaced with MSSQL$ is there is an instance name in connection string private static string InstanceNamePrefix = ""; // SlashInstance is used to override IPV4 and IPV6 defined about so it includes an instance name diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs new file mode 100644 index 0000000000..9cfc1c71bb --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs @@ -0,0 +1,269 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; +using System.Security.Principal; +using System.ServiceProcess; +using System.Text; +using Microsoft.Data.SqlClient.ManualTesting.Tests.DataCommon; +using Microsoft.Win32; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class CertificateTestWithTdsServer : IDisposable + { + private static readonly string s_fullPathToPowershellScript = Path.Combine(Directory.GetCurrentDirectory(), "makepfxcert.ps1"); + private static readonly string s_fullPathToCleanupPowershellScript = Path.Combine(Directory.GetCurrentDirectory(), "removecert.ps1"); + private static readonly string s_fullPathToPfx = Path.Combine(Directory.GetCurrentDirectory(), "localhostcert.pfx"); + private static readonly string s_fullPathTothumbprint = Path.Combine(Directory.GetCurrentDirectory(), "thumbprint.txt"); + private static readonly string s_fullPathToClientCert = Path.Combine(Directory.GetCurrentDirectory(), "clientcert"); + private static bool s_windowsAdmin = true; + private static string s_instanceName = "MSSQLSERVER"; + // s_instanceNamePrefix will get replaced with MSSQL$ is there is an instance name in the connection string + private static string s_instanceNamePrefix = ""; + private const string LocalHost = "localhost"; + + public CertificateTestWithTdsServer() + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName)); + + if (!string.IsNullOrEmpty(instanceName)) + { + s_instanceName = instanceName; + s_instanceNamePrefix = "MSSQL$"; + } + + // Confirm that user has elevated access on Windows + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + using WindowsIdentity identity = WindowsIdentity.GetCurrent(); + WindowsPrincipal principal = new(identity); + if (principal.IsInRole(WindowsBuiltInRole.Administrator)) + s_windowsAdmin = true; + else + s_windowsAdmin = false; + } + + if (!Directory.Exists(s_fullPathToClientCert)) + { + Directory.CreateDirectory(s_fullPathToClientCert); + } + + RunPowershellScript(s_fullPathToPowershellScript); + } + + private static bool IsLocalHost() + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out _)); + return LocalHost.Equals(hostname, StringComparison.OrdinalIgnoreCase); + } + + private static bool AreConnStringsSetup() => DataTestUtility.AreConnStringsSetup(); + private static bool IsNotAzureServer() => DataTestUtility.IsNotAzureServer(); + private static bool UseManagedSNIOnWindows() => DataTestUtility.UseManagedSNIOnWindows; + + private static string ForceEncryptionRegistryPath + { + get + { + if (DataTestUtility.IsSQL2022()) + { + return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL16.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib"; + } + if (DataTestUtility.IsSQL2019()) + { + return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib"; + } + if (DataTestUtility.IsSQL2016()) + { + return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib"; + } + return string.Empty; + } + } + + [ConditionalTheory(nameof(AreConnStringsSetup), nameof(IsNotAzureServer), nameof(IsLocalHost))] + [MemberData(nameof(ConnectionTestParametersData.GetConnectionTestParameters), MemberType = typeof(ConnectionTestParametersData))] + [PlatformSpecific(TestPlatforms.Windows)] + public void BeginWindowsConnectionTest(ConnectionTestParameters connectionTestParameters) + { + if (!s_windowsAdmin) + { + Assert.Fail("User needs to have elevated access for these set of tests"); + } + + ConnectionTest(connectionTestParameters); + } + + [ConditionalTheory(nameof(AreConnStringsSetup), nameof(IsNotAzureServer), nameof(IsLocalHost))] + [MemberData(nameof(ConnectionTestParametersData.GetConnectionTestParameters), MemberType = typeof(ConnectionTestParametersData))] + [PlatformSpecific(TestPlatforms.Linux)] + public void BeginLinuxConnectionTest(ConnectionTestParameters connectionTestParameters) + { + ConnectionTest(connectionTestParameters); + } + + private void ConnectionTest(ConnectionTestParameters connectionTestParameters) + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + + // The TestTdsServer does not validate the user name and password, so we can use any value if they are not defined. + string userId = string.IsNullOrWhiteSpace(builder.UserID) ? "user" : builder.UserID; + string password = string.IsNullOrWhiteSpace(builder.Password) ? "password" : builder.Password; + + using TestTdsServer server = TestTdsServer.StartTestServer(enableFedAuth: false, enableLog: false, connectionTimeout: 15, + methodName: "", new X509Certificate2(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), + encryptionType: connectionTestParameters.TdsEncryptionType); + + builder = new(server.ConnectionString) + { + UserID = userId, + Password = password, + TrustServerCertificate = connectionTestParameters.TrustServerCertificate, + Encrypt = connectionTestParameters.Encrypt, + }; + + if (!string.IsNullOrEmpty(connectionTestParameters.Certificate)) + { + builder.ServerCertificate = connectionTestParameters.Certificate; + } + + if (!string.IsNullOrEmpty(connectionTestParameters.HostNameInCertificate)) + { + builder.HostNameInCertificate = connectionTestParameters.HostNameInCertificate; + } + + using SqlConnection connection = new(builder.ConnectionString); + try + { + connection.Open(); + Assert.Equal(connectionTestParameters.TestResult, (connection.State == ConnectionState.Open)); + } + catch (Exception) + { + Assert.False(connectionTestParameters.TestResult); + } + } + + private static void RunPowershellScript(string script) + { + string currentDirectory = Directory.GetCurrentDirectory(); + string powerShellCommand = "powershell.exe"; + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + powerShellCommand = "pwsh"; + } + + if (File.Exists(script)) + { + StringBuilder output = new(); + Process proc = new() + { + StartInfo = + { + FileName = powerShellCommand, + RedirectStandardError = true, + RedirectStandardOutput = true, + UseShellExecute = false, + Arguments = $"{script} -OutDir {currentDirectory} > result.txt", + CreateNoWindow = false, + Verb = "runas" + } + }; + + proc.EnableRaisingEvents = true; + + proc.OutputDataReceived += new DataReceivedEventHandler((sender, e) => + { + if (e.Data != null) + { + output.AppendLine(e.Data); + } + }); + + proc.ErrorDataReceived += new DataReceivedEventHandler((sender, e) => + { + if (e.Data != null) + { + output.AppendLine(e.Data); + } + }); + + proc.Start(); + + proc.BeginOutputReadLine(); + proc.BeginErrorReadLine(); + + if (!proc.WaitForExit(60000)) + { + proc.Kill(); + proc.WaitForExit(2000); + throw new Exception($"Could not generate certificate. Error output: {output}"); + } + } + else + { + throw new Exception($"Could not find makepfxcert.ps1"); + } + } + + private void RemoveCertificate() + { + string thumbprint = File.ReadAllText(s_fullPathTothumbprint); + using X509Store certStore = new(StoreName.Root, StoreLocation.LocalMachine); + certStore.Open(OpenFlags.ReadWrite); + X509Certificate2Collection certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false); + if (certCollection.Count > 0) + { + certStore.Remove(certCollection[0]); + } + certStore.Close(); + + File.Delete(s_fullPathTothumbprint); + Directory.Delete(s_fullPathToClientCert, true); + } + + private static void RemoveForceEncryptionFromRegistryPath(string registryPath) + { + RegistryKey key = Registry.LocalMachine.OpenSubKey(registryPath, true); + key?.SetValue("ForceEncryption", 0, RegistryValueKind.DWord); + key?.SetValue("Certificate", "", RegistryValueKind.String); + ServiceController sc = new($"{s_instanceNamePrefix}{s_instanceName}"); + sc.Stop(); + sc.WaitForStatus(ServiceControllerStatus.Stopped); + sc.Start(); + sc.WaitForStatus(ServiceControllerStatus.Running); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (disposing && !string.IsNullOrEmpty(s_fullPathTothumbprint)) + { + RemoveCertificate(); + RemoveForceEncryptionFromRegistryPath(ForceEncryptionRegistryPath); + } + } + else + { + RunPowershellScript(s_fullPathToCleanupPowershellScript); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs index 3552204886..a4557d72b6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs @@ -3,27 +3,35 @@ // See the LICENSE file in the project root for more information. using System; +using System.Linq; using System.Net; +using System.Net.Sockets; using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; using Microsoft.SqlServer.TDS.EndPoint; +using Microsoft.SqlServer.TDS.PreLogin; using Microsoft.SqlServer.TDS.Servers; namespace Microsoft.Data.SqlClient.ManualTesting.Tests { internal class TestTdsServer : GenericTDSServer, IDisposable { + private const int DefaultConnectionTimeout = 5; + private TDSServerEndPoint _endpoint = null; - private SqlConnectionStringBuilder connectionStringBuilder; + private SqlConnectionStringBuilder _connectionStringBuilder; public TestTdsServer(TDSServerArguments args) : base(args) { } public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args) { - this.Engine = engine; + Engine = engine; } - public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "") + public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, + int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", + X509Certificate2 encryptionCertificate = null, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) { TDSServerArguments args = new TDSServerArguments() { @@ -32,10 +40,18 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool if (enableFedAuth) { - args.FedAuthRequiredPreLoginOption = Microsoft.SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; + args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; + } + + if (encryptionCertificate != null) + { + args.EncryptionCertificate = encryptionCertificate; } + args.Encryption = encryptionType; + TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args); + server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; server._endpoint.EndpointName = methodName; // The server EventLog should be enabled as it logs the exceptions. @@ -43,19 +59,37 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool server._endpoint.Start(); int port = server._endpoint.ServerEndPoint.Port; - server.connectionStringBuilder = new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = 5, Encrypt = SqlConnectionEncryptOption.Optional }; - server.ConnectionString = server.connectionStringBuilder.ConnectionString; + + server._connectionStringBuilder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + port, + ConnectTimeout = connectionTimeout, + }; + + if (encryptionType == TDSPreLoginTokenEncryptionType.Off || + encryptionType == TDSPreLoginTokenEncryptionType.None || + encryptionType == TDSPreLoginTokenEncryptionType.NotSupported) + { + server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Optional; + } + else + { + server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Mandatory; + } + + server.ConnectionString = server._connectionStringBuilder.ConnectionString; return server; } - public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "") + public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, + int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", + X509Certificate2 encryptionCertificate = null, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) { - return StartServerWithQueryEngine(null, false, false, methodName); + return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, methodName, encryptionCertificate, encryptionType); } public void Dispose() => _endpoint?.Stop(); public string ConnectionString { get; private set; } - } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/makepfxcert.ps1 b/src/Microsoft.Data.SqlClient/tests/ManualTests/makepfxcert.ps1 new file mode 100644 index 0000000000..02d558d77b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/makepfxcert.ps1 @@ -0,0 +1,158 @@ +# Licensed to the .NET Foundation under one or more agreements. +# The .NET Foundation licenses this file to you under the MIT license. +# See the LICENSE file in the project root for more information. +# Script: Invoke-SqlServerCertificateCommand# +# Author: SqlClient Team +# Date: March 20, 2024 +# Comments: This scripts creates SSL Self-Signed Certificate for TestTdsServer in pfx format. +# This script is not intended to be used in any production environments. + +param ($OutDir) + +function Invoke-SqlServerCertificateCommand { + [CmdletBinding()] + param( + [Parameter(Mandatory = $false)] + [string] $certificateName = "localhostcert.cer", + [string] $myCertStoreLocation = "Cert:\LocalMachine\My", + [string] $rootCertStoreLocation = "Cert:\LocalMachine\Root", + [string] $sqlAliasName = "SQLAliasName", + [string] $localhost = "localhost", + [string] $LoopBackIPV4 = "127.0.0.1", + [string] $LoopBackIPV6 = "::1" + ) + Write-Output "Certificate generation started..." + + # Change directory to where the tests are + Write-Output "Change directory to $OutDir ..." + cd $OutDir + pwd + + try { + # Get FQDN of the machine + Write-Output "Get FQDN of the machine..." + $fqdn = [System.Net.Dns]::GetHostByName(($env:computerName)).HostName + Write-Output "FQDN = $fqdn" + + $OS = [System.Environment]::OSVersion.Platform + Write-Output "Operating System is $OS" + + # Create a self-signed certificate + if ($OS -eq "Unix") { + chmod 777 $OutDir + # Install OpenSSL module + Install-Module -Name OpenSSL -Repository PSGallery -Force + # Show version of OpenSSL just to make sure it is installed + openssl version + + # Create self signed certificate using openssl + Write-Output "Creating certificate for linux..." + if ($fqdn.length -gt 64) { + $machineId = $fqdn.Substring(0,15) + openssl req -x509 -newkey rsa:4096 -sha256 -days 1095 -nodes -keyout $OutDir/localhostcert.key -out $OutDir/localhostcert.cer -subj "/CN=$machineId" -addext "subjectAltName=DNS:$fqdn,DNS:localhost,IP:127.0.0.1,IP:::1" + } + else { + openssl req -x509 -newkey rsa:4096 -sha256 -days 1095 -nodes -keyout $OutDir/localhostcert.key -out $OutDir/localhostcert.cer -subj "/CN=$fqdn" -addext "subjectAltName=DNS:$fqdn,DNS:localhost,IP:127.0.0.1,IP:::1" + } + chmod 777 $OutDir/localhostcert.key $OutDir/localhostcert.cer + # Copy the certificate to the clientcert folder + cp $OutDir/localhostcert.cer $OutDir/clientcert/ + # Export the certificate to pfx + Write-Output "Exporting certificate to pfx..." + openssl pkcs12 -export -in $OutDir/localhostcert.cer -inkey $OutDir/localhostcert.key -out $OutDir/localhostcert.pfx -password pass:nopassword + chmod 777 $OutDir/localhostcert.pfx + + Write-Output "Converting certificate to pem..." + # Create pem from cer + cp $OutDir/localhostcert.cer $OutDir/localhostcert.pem + chmod 777 $OutDir/localhostcert.pem + + # Add trust to the pem certificate + Write-Output "Adding trust to pem certificate..." + openssl x509 -trustout -addtrust "serverAuth" -in $OutDir/localhostcert.pem + + # Import the certificate to the Root store ------------------------------------------------------------------------------ + # NOTE: The process must have root privileges to add the certificate to the Root store. If not, then use + # "chmod 777 /usr/local/share/ca-certificates" to give read, write and execute privileges to anyone on that folder + # Copy the certificate to /usr/local/share/ca-certificates folder while changing the extension to "crt". + # Only certificates with extension "crt" gets added for some reason. + Write-Output "Copy the pem certificate to /usr/local/share/ca-certificates folder..." + cp $OutDir/localhostcert.pem /usr/local/share/ca-certificates/localhostcert.crt + + # Add trust to the mismatched certificate as well + $ openssl x509 -in $OutDir/mismatchedcert.cer -inform der -out $OutDir/mismatchedcert.pem + # Copy the mismatched certificate to the clientcert folder + cp $OutDir/mismatchedcert.cer $OutDir/clientcert/ + openssl x509 -trustout -addtrust "serverAuth" -in $OutDir/mismatchedcert.pem + cp $OutDir/mismatchedcert.pem /usr/local/share/ca-certificates/mismatchedcert.crt + + # enable certificate as CA certificate + dpkg-reconfigure ca-certificates -f noninteractive -p critical + + # Update the certificates store + Write-Output "Updating the certificates store..." + update-ca-certificates -v + } else { + Write-Output "Creating a self-signed certificate..." + $params = @{ + Type = "SSLServerAuthentication" + Subject = "CN=$fqdn" + KeyAlgorithm = "RSA" + KeyLength = 4096 + HashAlgorithm = "SHA256" + TextExtension = "2.5.29.37={text}1.3.6.1.5.5.7.3.1", "2.5.29.17={text}DNS=$fqdn&DNS=$localhost&IPAddress=$LoopBackIPV4&DNS=$sqlAliasName&IPAddress=$LoopBackIPV6" + NotAfter = (Get-Date).AddMonths(36) + KeySpec = "KeyExchange" + Provider = "Microsoft RSA SChannel Cryptographic Provider" + CertStoreLocation = $myCertStoreLocation + FriendlyName = "TestTDSServerCertificate" + } + + $certificate = New-SelfSignedCertificate @params + Write-Output "Certificate created successfully" + Write-Output "Certificate Thumbprint: $($certificate.Thumbprint)" + + # Export the certificate to a file + Write-Output "Exporting the certificate to a file..." + Export-Certificate -Cert $certificate -FilePath "$OutDir/$certificateName" -Type CERT + + # Copy the certificate to the clientcert folder + copy $OutDir/$certificateName $OutDir/clientcert/ + copy $OutDir/mismatchedcert.cer $OutDir/clientcert/ + + # Import the certificate to the Root store + Write-Output "Importing the certificate to the Root store..." + $params = @{ + FilePath = "$OutDir/$certificateName" + CertStoreLocation = $rootCertStoreLocation + } + Import-Certificate @params + + Write-Output "Converting certificate to pfx..." + Write-Output "Cert:\LocalMachine\my\$($certificate.Thumbprint)" + + $pwd = ConvertTo-SecureString -String 'nopassword' -Force -AsPlainText + # Export the certificate to a pfx format + Export-PfxCertificate -Password $pwd -FilePath "$OutDir\localhostcert.pfx" -Cert "Cert:\LocalMachine\my\$($certificate.Thumbprint)" + + # Write the certificate thumbprint to a file + echo $certificate.Thumbprint | Out-File -FilePath "$OutDir\thumbprint.txt" -Encoding ascii + } + + Write-Output "Done creating pfx certificate..." + } + catch { + $e = $_.Exception + $msg = $e.Message + while ($e.InnerException) { + $e = $e.InnerException + $msg += "`n" + $e.Message + } + + Write-Output "Certificate generation was not successfull. $msg" + } + + Write-Output "Certificate generation task completed." +} + +Invoke-SqlServerCertificateCommand diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/mismatchedcert.cer b/src/Microsoft.Data.SqlClient/tests/ManualTests/mismatchedcert.cer new file mode 100644 index 0000000000..6b35e97a55 Binary files /dev/null and b/src/Microsoft.Data.SqlClient/tests/ManualTests/mismatchedcert.cer differ diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/removecert.ps1 b/src/Microsoft.Data.SqlClient/tests/ManualTests/removecert.ps1 new file mode 100644 index 0000000000..14b944de80 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/removecert.ps1 @@ -0,0 +1,21 @@ +# Licensed to the .NET Foundation under one or more agreements. +# The .NET Foundation licenses this file to you under the MIT license. +# See the LICENSE file in the project root for more information. +# Script: removecert.ps1 +# Author: SqlClient Team +# Date: May 24, 2024 +# Comments: This script deletes the SSL Self-Signed Certificate from Linux certificate store. +# This script is not intended to be used in any production environments. + +param ($OutDir) + +# Delete all certificates +rm $OutDir/clientcer/*.cer +rm $OutDir/localhostcert.pem +rm $OutDir/mismatchedcert.pem +rm /usr/local/share/ca-certificates/localhostcert.crt +rm /usr/local/share/ca-certificates/mismatchedcert.crt + +# Update the certificates store +Write-Output "Updating the certificates store..." +update-ca-certificates -v diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSStream.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSStream.cs index 94abbf5818..80d8633501 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSStream.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSStream.cs @@ -264,6 +264,10 @@ public override int Read(byte[] buffer, int offset, int count) // Calculate how much data can be read until the end of the packet is reached long packetDataAvailable = IncomingPacketHeader.Length - IncomingPacketPosition; + // Set count to actual size of data to be read from the buffer so this loop can exit + if (packetDataAvailable < count) + count = (int)packetDataAvailable; + // Check how much data we should give back in the current iteration int packetDataToRead = Math.Min((int)packetDataAvailable, count - bufferReadPosition);