Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(all): Enable TLSv1.2 support for older devices #3258

Merged
merged 15 commits into from
May 9, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package com.amazonaws.mobileconnectors.cognitoauth.util;

import android.content.Context;

import com.amazonaws.http.TLS12SocketFactory;
import com.amazonaws.mobileconnectors.cognitoauth.exceptions.AuthClientException;
import com.amazonaws.mobileconnectors.cognitoauth.exceptions.AuthServiceException;

Expand Down Expand Up @@ -46,6 +45,7 @@ public String httpPost(final URL uri, final Map<String, String> headerParams, fi
}

final HttpsURLConnection httpsURLConnection = (HttpsURLConnection) uri.openConnection();
TLS12SocketFactory.fixTLSPre21(httpsURLConnection);
DataOutputStream httpOutputStream = null;
BufferedReader br = null;
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
*
* * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* *
* * Licensed under the Apache License, Version 2.0 (the "License").
* * You may not use this file except in compliance with the License.
* * A copy of the License is located at
* *
* * http://aws.amazon.com/apache2.0
* *
* * or in the "license" file accompanying this file. This file is distributed
* * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* * express or implied. See the License for the specific language governing
* * permissions and limitations under the License.
*
*/

package com.amazonaws.http;

import com.amazonaws.logging.LogFactory;

import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSession;

public class LoggingHandshakeCompletedListener implements HandshakeCompletedListener {

private static final com.amazonaws.logging.Log log =
LogFactory.getLog(LoggingHandshakeCompletedListener.class);
@Override
public void handshakeCompleted(HandshakeCompletedEvent event) {
try {
SSLSession session = event.getSession();
String protocol = session.getProtocol();
String cipherSuite = session.getCipherSuite();

log.debug("Protocol: " + protocol + ", CipherSuite: " + cipherSuite);
} catch (Exception exception) {
log.debug("Failed to log connection protocol/cipher suite", exception);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.http;

import android.os.Build;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.UnknownHostException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

/**
* Although this has public access, it is intended for internal use and should not be used directly by host
* applications. The behavior of this may change without warning.
*/
public class TLS12SocketFactory extends SSLSocketFactory {

private static final Object contextLock = new Object();
private static final String[] SUPPORTED_PROTOCOLS =
new String[] { "TLSv1", "TLSv1.1", "TLSv1.2" };
private static SSLContext sslContext = null;
private final SSLSocketFactory delegate;
private LoggingHandshakeCompletedListener handshakeCompletedListener;

@Nullable
public static TLS12SocketFactory createTLS12SocketFactory() {
return createTLS12SocketFactory(null);
}

@Nullable
public static TLS12SocketFactory createTLS12SocketFactory(
@Nullable SSLContext sslContext
) {
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) {
try {
return new TLS12SocketFactory(sslContext);
} catch (Exception e) {
//
}
}
return null;
}

public static void fixTLSPre21(@NonNull HttpsURLConnection connection) {
fixTLSPre21(connection, createTLS12SocketFactory());
}

public static void fixTLSPre21(
@NonNull HttpsURLConnection connection,
@Nullable TLS12SocketFactory tls12SocketFactory
) {
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP &&
tls12SocketFactory != null) {
try {
connection.setSSLSocketFactory(tls12SocketFactory);
} catch (Exception e) {
// Failed to enabled TLS1.2 on < Android 21 device
}
}
}

private TLS12SocketFactory(@Nullable SSLContext customSSLContext)
throws KeyManagementException, NoSuchAlgorithmException {

if (customSSLContext != null) {
delegate = customSSLContext.getSocketFactory();
} else {
// Cache SSLContext due to weight and hold static
synchronized (contextLock) {
if (sslContext == null) {
sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, null, null);
}
}
delegate = sslContext.getSocketFactory();
}
this.handshakeCompletedListener = new LoggingHandshakeCompletedListener();
}

@Override
public String[] getDefaultCipherSuites() {
return delegate.getDefaultCipherSuites();
}

@Override
public String[] getSupportedCipherSuites() {
return delegate.getSupportedCipherSuites();
}

@Override
public Socket createSocket() throws IOException {
SSLSocket socket = (SSLSocket) delegate.createSocket();
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

@Override
public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
SSLSocket socket = (SSLSocket) delegate.createSocket(s, host, port, autoClose);
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

@Override
public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

@Override
public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException {
SSLSocket socket = (SSLSocket) delegate.createSocket(host, port, localHost, localPort);
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

@Override
public Socket createSocket(InetAddress host, int port) throws IOException {
SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

@Override
public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException {
SSLSocket socket = (SSLSocket) delegate.createSocket(address, port, localAddress, localPort);
socket.addHandshakeCompletedListener(handshakeCompletedListener);
return updateTLSProtocols(socket);
}

private Socket updateTLSProtocols(Socket socket) {
if(socket instanceof SSLSocket) {
try {
((SSLSocket) socket).setEnabledProtocols(SUPPORTED_PROTOCOLS);
} catch (Exception e) {
// TLS 1.2 may not be supported on device
}
}
return socket;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,24 @@ public class UrlHttpClient implements HttpClient {
private static final int BUFFER_SIZE_MULTIPLIER = 8;
private final ClientConfiguration config;

// SocketFactory for Pre SDK 21 devices to enforce TLS 1.2
private final TLS12SocketFactory tls12SocketFactory;

// Cached SSLContext for connections using custom TrustManagers.
private SSLContext customTrustSSLContext = null;

// SocketFactory for Pre SDK 21 devices to enforce TLS 1.2 that also holds custom TrustManagers.
private TLS12SocketFactory customTrustTls12SocketFactory;

/**
* Constructor.
* @param config the client config.
*/
public UrlHttpClient(ClientConfiguration config) {
this.config = config;

// will return null if SDK >= 21
tls12SocketFactory = TLS12SocketFactory.createTLS12SocketFactory();
}

@Override
Expand Down Expand Up @@ -279,26 +291,35 @@ void configureConnection(HttpRequest request, HttpURLConnection connection) {

if (config.getTrustManager() != null) {
enableCustomTrustManager(https);
} else if (tls12SocketFactory != null) {
TLS12SocketFactory.fixTLSPre21(https, tls12SocketFactory);
}
}
}

private SSLContext sc = null;

private void enableCustomTrustManager(HttpsURLConnection connection) {
if (sc == null) {
if (customTrustSSLContext == null) {
final TrustManager[] customTrustManagers = new TrustManager[] {
config.getTrustManager()
};
try {
sc = SSLContext.getInstance("TLS");
sc.init(null, customTrustManagers, null);
customTrustSSLContext = SSLContext.getInstance("TLS");
customTrustSSLContext.init(null, customTrustManagers, null);

if (customTrustTls12SocketFactory == null) {
customTrustTls12SocketFactory = TLS12SocketFactory
.createTLS12SocketFactory(customTrustSSLContext);
}
} catch (final GeneralSecurityException e) {
throw new RuntimeException(e);
}
}

connection.setSSLSocketFactory(sc.getSocketFactory());
if (customTrustTls12SocketFactory != null) {
connection.setSSLSocketFactory(customTrustTls12SocketFactory);
} else {
connection.setSSLSocketFactory(customTrustSSLContext.getSocketFactory());
}
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.amazonaws.ClientConfiguration;
import com.amazonaws.Request;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.http.TLS12SocketFactory;

import java.io.IOException;
import java.io.InputStream;
Expand All @@ -31,6 +32,8 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.net.ssl.HttpsURLConnection;

/**
* HTTP utils class.
*/
Expand Down Expand Up @@ -289,6 +292,9 @@ public static InputStream fetchFile(
final URL url = uri.toURL();
// TODO: support proxy?
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
if (connection instanceof HttpsURLConnection) {
TLS12SocketFactory.fixTLSPre21((HttpsURLConnection) connection);
}
connection.setConnectTimeout(getConnectionTimeout(config));
connection.setReadTimeout(getSocketTimeout(config));
connection.addRequestProperty("User-Agent", getUserAgent(config));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.amazonaws.AmazonClientException;
import com.amazonaws.SDKGlobalConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.http.TLS12SocketFactory;
import com.amazonaws.regions.Region;
import com.amazonaws.util.StringUtils;
import com.amazonaws.util.VersionInfoUtils;
Expand Down Expand Up @@ -229,10 +230,17 @@ public boolean isMetricsEnabled() {
return metricsIsEnabled;
}
/**
* Holds client socket factory. Set upon initial connect then reused on
* Holds client socket factory for keystore connect. Set upon initial connect then reused on
* reconnect.
*/
private SocketFactory clientSocketFactory;


/**
* Holds cached SocketFactory for non-keystore connect calls on Android versions < 21
* Set upon initial connect then reused on reconnect.
*/
private TLS12SocketFactory tls12SocketFactory;
/**
* Holds client provided AWS credentials provider.
* Set upon initial connect.
Expand Down Expand Up @@ -1138,6 +1146,7 @@ private void customAuthConnect(final MqttConnectOptions options) {
private void mqttConnect(MqttConnectOptions options) {
LOGGER.debug("ready to do mqtt connect");

fixTLSPre21(options);
options.setCleanSession(cleanSession);
options.setKeepAliveInterval(userKeepAlive);

Expand Down Expand Up @@ -1324,6 +1333,8 @@ void reconnectToSession() {
handleConnectionFailure(new IllegalStateException("Unexpected value: " + authMode));
}

fixTLSPre21(options);

setupCallbackForMqttClient();
try {
++autoReconnectsAttempted;
Expand Down Expand Up @@ -2055,4 +2066,18 @@ enum AuthenticationMode {
public boolean getSessionPresent() {
return sessionPresent;
}

/**
* Injects a SocketFactory that supports TLSv1.2 on pre 21 devices.
* If a SocketFactory is already specified (ex keystore connect uses its own), call is ignored.
* @param options for connect call
*/
private void fixTLSPre21(MqttConnectOptions options) {
if (options.getSocketFactory() == null &&
Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP
) {
this.tls12SocketFactory = TLS12SocketFactory.createTLS12SocketFactory();
options.setSocketFactory(tls12SocketFactory);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import androidx.browser.customtabs.CustomTabsSession;
import android.util.Log;

import com.amazonaws.http.TLS12SocketFactory;
import com.amazonaws.internal.keyvaluestore.AWSKeyValueStore;
import com.amazonaws.mobile.client.AWSMobileClient;
import com.amazonaws.mobile.client.Callback;
Expand Down Expand Up @@ -495,6 +496,7 @@ public static String httpPost(final URL uri, final Map<String, String> headerPar
}

final HttpsURLConnection httpsURLConnection = (HttpsURLConnection) uri.openConnection();
TLS12SocketFactory.fixTLSPre21(httpsURLConnection);
DataOutputStream httpOutputStream = null;
BufferedReader br = null;
try {
Expand Down
Loading