diff --git a/Cargo.toml b/Cargo.toml index 2d531d7..a132cbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,8 @@ default-target = "x86_64-pc-windows-msvc" windows-sys = { version = "0.59", features = [ "Win32_Foundation", "Win32_Security_Cryptography", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", - "Win32_System_Memory"] } + "Win32_System_LibraryLoader", "Win32_System_Memory" +] } [dev-dependencies] windows-sys = { version = "0.59", features = ["Win32_System_SystemInformation", "Win32_System_Time"] } diff --git a/src/schannel_cred.rs b/src/schannel_cred.rs index c4b21d1..ca19224 100644 --- a/src/schannel_cred.rs +++ b/src/schannel_cred.rs @@ -177,6 +177,31 @@ impl Protocol { } } +fn verify_min_os_build(major: u32, build: u32) -> Option<()> { + use windows_sys::Win32::System::SystemInformation::OSVERSIONINFOW; + + let handle = std::ptr::NonNull::new(unsafe { + windows_sys::Win32::System::LibraryLoader::GetModuleHandleW(windows_sys::w!("ntdll.dll")) + })?; + let rtl_get_ver = unsafe { + windows_sys::Win32::System::LibraryLoader::GetProcAddress(handle.as_ptr(), windows_sys::s!("RtlGetVersion")) + }?; + + type RtlGetVersionFunc = unsafe extern "system" fn(*mut OSVERSIONINFOW) -> i32; + let proc: RtlGetVersionFunc = unsafe { mem::transmute(rtl_get_ver) }; + + let mut info: OSVERSIONINFOW = unsafe { mem::zeroed() }; + info.dwOSVersionInfoSize = mem::size_of::() as u32; + + unsafe { proc(&mut info) }; + + if info.dwMajorVersion > major || (info.dwMajorVersion == major && info.dwBuildNumber >= build) { + Some(()) + } else { + None + } +} + /// A builder type for `SchannelCred`s. #[derive(Default, Debug)] pub struct Builder { @@ -220,36 +245,58 @@ impl Builder { /// Creates a new `SchannelCred`. pub fn acquire(&self, direction: Direction) -> io::Result { unsafe { - let mut handle: Credentials::SecHandle = mem::zeroed(); - let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed(); - cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION; - cred_data.dwFlags = - Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS; - if let Some(ref supported_algorithms) = self.supported_algorithms { - cred_data.cSupportedAlgs = supported_algorithms.len() as u32; - cred_data.palgSupportedAlgs = supported_algorithms.as_ptr() as *mut _; - } - if let Some(ref enabled_protocols) = self.enabled_protocols { - cred_data.grbitEnabledProtocols = enabled_protocols + let mut enabled_protocols: u32 = 0; + if let Some(ref enable_list) = self.enabled_protocols { + enabled_protocols = enable_list .iter() .map(|p| p.dword(direction)) .fold(0, |acc, p| acc | p); } + + let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed(); + cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION; + cred_data.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS; + cred_data.grbitEnabledProtocols = enabled_protocols; let mut certs = self.certs.iter().map(|c| c.as_inner()).collect::>(); cred_data.cCreds = certs.len() as u32; cred_data.paCred = certs.as_mut_ptr() as _; + + let mut tls_param: Identity::TLS_PARAMETERS = mem::zeroed(); + let mut cred_data2: Identity::SCH_CREDENTIALS = mem::zeroed(); + let mut pauthdata: *const core::ffi::c_void = ptr::null(); + if let Some(ref supported_algorithms) = self.supported_algorithms { + cred_data.cSupportedAlgs = supported_algorithms.len() as u32; + cred_data.palgSupportedAlgs = supported_algorithms.as_ptr() as *mut _; + } else if verify_min_os_build(10, 17763).is_some() { + // If no algorithms specified and should be supported, use new SCH_CREDENTIALS interface which supports TLS1.3. + // Although we check for win10 build 17763 above, I have only seen this work on win 11. + tls_param.grbitDisabledProtocols = !enabled_protocols; + // TODO: support something to select tls13-ciphers + cred_data2.dwVersion = Identity::SCH_CREDENTIALS_VERSION; + cred_data2.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS; + cred_data2.cCreds = certs.len() as u32; + cred_data2.paCred = certs.as_mut_ptr() as _; + cred_data2.cTlsParameters = 1; + cred_data2.pTlsParameters = &mut tls_param; + pauthdata = &mut cred_data2 as *const _ as *const _; + } + + if pauthdata.is_null() { + pauthdata = &mut cred_data as *const _ as *const _; + } let direction = match direction { Direction::Inbound => Identity::SECPKG_CRED_INBOUND, Direction::Outbound => Identity::SECPKG_CRED_OUTBOUND, }; + let mut handle: Credentials::SecHandle = mem::zeroed(); match Identity::AcquireCredentialsHandleA( ptr::null(), Identity::UNISP_NAME_A, direction, ptr::null_mut(), - &mut cred_data as *const _ as *const _, + pauthdata, None, ptr::null_mut(), &mut handle, diff --git a/src/test.rs b/src/test.rs index 5ea1899..7ab1e75 100644 --- a/src/test.rs +++ b/src/test.rs @@ -271,6 +271,27 @@ fn verify_callback_success() { assert!(out.ends_with(b"\n")); } +#[test] +fn tls_13() { + let creds = SchannelCred::builder() + .enabled_protocols(&[Protocol::Tls12, Protocol::Tls13]) + .acquire(Direction::Outbound) + .unwrap(); + let stream = TcpStream::connect("tls13.akamai.io:443").unwrap(); + let mut stream = tls_stream::Builder::new() + .domain("tls13.akamai.io") + .connect(creds, stream) + .unwrap(); + stream + .write_all(b"GET / HTTP/1.0\r\nHost: tls13.akamai.io\r\n\r\n") + .unwrap(); + let mut out = vec![]; + stream.read_to_end(&mut out).unwrap(); + + let pattern = b"Your client negotiated TLS 1.3"; + assert!(out.windows(pattern.len()).any(|x| x == pattern)); +} + #[test] fn verify_callback_error() { let creds = SchannelCred::builder()