diff --git a/app/src/main/java/chat/rocket/android/authentication/server/presentation/ServerPresenter.kt b/app/src/main/java/chat/rocket/android/authentication/server/presentation/ServerPresenter.kt index 415ace7783..be0ea59c46 100644 --- a/app/src/main/java/chat/rocket/android/authentication/server/presentation/ServerPresenter.kt +++ b/app/src/main/java/chat/rocket/android/authentication/server/presentation/ServerPresenter.kt @@ -1,9 +1,14 @@ package chat.rocket.android.authentication.server.presentation +import android.security.KeyChain +import android.security.KeyChainAliasCallback; import chat.rocket.android.authentication.domain.model.LoginDeepLinkInfo import chat.rocket.android.authentication.presentation.AuthenticationNavigator import chat.rocket.android.core.behaviours.showMessage import chat.rocket.android.core.lifecycle.CancelStrategy +import chat.rocket.android.authentication.ui.AuthenticationActivity +import chat.rocket.android.helper.ClientCertHelper +import chat.rocket.android.server.domain.SaveClientCertInteractor import chat.rocket.android.server.domain.GetAccountsInteractor import chat.rocket.android.server.domain.GetSettingsInteractor import chat.rocket.android.server.domain.RefreshSettingsInteractor @@ -23,9 +28,12 @@ class ServerPresenter @Inject constructor( private val serverInteractor: SaveConnectingServerInteractor, private val refreshSettingsInteractor: RefreshSettingsInteractor, private val getAccountsInteractor: GetAccountsInteractor, + private val saveClientCertInteractor: SaveClientCertInteractor, + private val clientCertHelper: ClientCertHelper, val settingsInteractor: GetSettingsInteractor, - val factory: RocketChatClientFactory -) : CheckServerPresenter( + val factory: RocketChatClientFactory, + internal val activity: AuthenticationActivity +) : KeyChainAliasCallback, CheckServerPresenter( strategy = strategy, factory = factory, settingsInteractor = settingsInteractor, @@ -117,4 +125,19 @@ class ServerPresenter @Inject constructor( } } } + + override fun alias(alias: String) { + saveClientCertInteractor.save(alias) + clientCertHelper.getClient() + } + + fun requestClientCert() { + KeyChain.choosePrivateKeyAlias(activity, + this, // Callback + null, // Any key types. + null, // Any issuers. + null, // Any host + -1, // Any port + "RocketChat") + } } \ No newline at end of file diff --git a/app/src/main/java/chat/rocket/android/authentication/server/ui/ServerFragment.kt b/app/src/main/java/chat/rocket/android/authentication/server/ui/ServerFragment.kt index adeebc8d20..c627e4fc78 100644 --- a/app/src/main/java/chat/rocket/android/authentication/server/ui/ServerFragment.kt +++ b/app/src/main/java/chat/rocket/android/authentication/server/ui/ServerFragment.kt @@ -139,6 +139,9 @@ class ServerFragment : Fragment(), ServerView { button_connect.setOnClickListener { presenter.checkServer("$protocol${text_server_url.textContent.sanitize()}") } + button_tls_cert.setOnClickListener { + presenter.requestClientCert() + } } override fun showInvalidServerUrlMessage() = diff --git a/app/src/main/java/chat/rocket/android/dagger/module/AppModule.kt b/app/src/main/java/chat/rocket/android/dagger/module/AppModule.kt index c06b4c4cf9..ac66d45d61 100644 --- a/app/src/main/java/chat/rocket/android/dagger/module/AppModule.kt +++ b/app/src/main/java/chat/rocket/android/dagger/module/AppModule.kt @@ -19,6 +19,7 @@ import chat.rocket.android.dagger.qualifier.ForAuthentication import chat.rocket.android.dagger.qualifier.ForMessages import chat.rocket.android.db.DatabaseManager import chat.rocket.android.db.DatabaseManagerFactory +import chat.rocket.android.helper.ClientCertHelper import chat.rocket.android.helper.MessageParser import chat.rocket.android.infrastructure.LocalRepository import chat.rocket.android.infrastructure.SharedPreferencesLocalRepository @@ -28,6 +29,7 @@ import chat.rocket.android.server.domain.AccountsRepository import chat.rocket.android.server.domain.AnalyticsTrackingInteractor import chat.rocket.android.server.domain.AnalyticsTrackingRepository import chat.rocket.android.server.domain.ChatRoomsRepository +import chat.rocket.android.server.domain.ClientCertRepository import chat.rocket.android.server.domain.CurrentServerRepository import chat.rocket.android.server.domain.GetAccountInteractor import chat.rocket.android.server.domain.GetAccountsInteractor @@ -54,6 +56,7 @@ import chat.rocket.android.server.infraestructure.SharedPreferencesPermissionsRe import chat.rocket.android.server.infraestructure.SharedPreferencesSettingsRepository import chat.rocket.android.server.infraestructure.SharedPrefsAnalyticsTrackingRepository import chat.rocket.android.server.infraestructure.SharedPrefsConnectingServerRepository +import chat.rocket.android.server.infraestructure.SharedPrefsClientCertRepository import chat.rocket.android.server.infraestructure.SharedPrefsCurrentServerRepository import chat.rocket.android.util.AppJsonAdapterFactory import chat.rocket.android.util.HttpLoggingInterceptor @@ -78,7 +81,6 @@ import okhttp3.OkHttpClient import ru.noties.markwon.SpannableConfiguration import ru.noties.markwon.spans.SpannableTheme import timber.log.Timber -import java.util.concurrent.TimeUnit import javax.inject.Named import javax.inject.Singleton @@ -123,25 +125,21 @@ class AppModule { @Provides @Singleton - fun provideOkHttpClient(logger: HttpLoggingInterceptor, basicAuthenticator: BasicAuthenticatorInterceptor): OkHttpClient { - return OkHttpClient.Builder() - .addInterceptor(logger) - .addInterceptor(basicAuthenticator) - .connectTimeout(15, TimeUnit.SECONDS) - .readTimeout(20, TimeUnit.SECONDS) - .writeTimeout(15, TimeUnit.SECONDS) - .build() + fun provideOkHttpClient( + clientCertHelper: ClientCertHelper + ): OkHttpClient { + return clientCertHelper.getClient() } @Provides @Singleton fun provideImagePipelineConfig( context: Context, - okHttpClient: OkHttpClient + clientCertHelper: ClientCertHelper ): ImagePipelineConfig { val listeners = setOf(RequestLoggingListener()) - return OkHttpImagePipelineConfigFactory.newBuilder(context, okHttpClient) + return OkHttpImagePipelineConfigFactory.newBuilder(context, clientCertHelper.getClient()) .setRequestListeners(listeners) .setDownsampleEnabled(true) .experiment().setPartialImageCachingEnabled(true).build() @@ -307,6 +305,12 @@ class AppModule { ): AccountsRepository = SharedPreferencesAccountsRepository(preferences, moshi) + @Provides + @Singleton + fun provideClientCertRepository(prefs: SharedPreferences): ClientCertRepository { + return SharedPrefsClientCertRepository(prefs) + } + @Provides fun provideNotificationManager(context: Application) = context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager diff --git a/app/src/main/java/chat/rocket/android/helper/ClientCertHelper.kt b/app/src/main/java/chat/rocket/android/helper/ClientCertHelper.kt new file mode 100644 index 0000000000..c35e21e075 --- /dev/null +++ b/app/src/main/java/chat/rocket/android/helper/ClientCertHelper.kt @@ -0,0 +1,161 @@ +package chat.rocket.android.helper + +import android.app.Application +import android.os.AsyncTask +import android.security.KeyChain +import java.net.Socket +import java.security.KeyStore +import java.security.Principal +import java.security.PrivateKey +import java.security.cert.CertificateException +import java.security.cert.X509Certificate +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.X509TrustManager +import javax.net.ssl.X509ExtendedKeyManager +import javax.net.ssl.KeyManager +import chat.rocket.android.server.domain.GetClientCertInteractor +import chat.rocket.android.util.HttpLoggingInterceptor +import chat.rocket.android.util.BasicAuthenticatorInterceptor +import okhttp3.OkHttpClient +import javax.inject.Inject +import javax.inject.Singleton +import java.util.concurrent.TimeUnit +import chat.rocket.android.util.ClientCertInterceptor + + +data class SslStuff(val privKey: PrivateKey, val certChain: Array) + +class SslTask @Inject constructor( + private val context: Application, + private val getClientCertInteractor: GetClientCertInteractor, + private val logger: HttpLoggingInterceptor, + private val basicAuthenticator: BasicAuthenticatorInterceptor, + private val clientCertHelper: ClientCertHelper +): AsyncTask() { + + override fun doInBackground(vararg params: Void?): SslStuff { + var alias = getClientCertInteractor.get() + alias = alias.toString() + val privKey = KeyChain.getPrivateKey(context.applicationContext, alias) + val certChain = KeyChain.getCertificateChain(context.applicationContext, alias) + + return SslStuff(privKey, certChain) + } + + override fun onPostExecute(result: SslStuff?) { + var alias = getClientCertInteractor.get() + if (result != null && !clientCertHelper.getSetSslSocket()) { + alias = alias.toString() + val (privateKey, certificates) = result + val trustStore = KeyStore.getInstance(KeyStore.getDefaultType()) + val keyManager = object : X509ExtendedKeyManager() { + override fun chooseClientAlias(strings: Array, principals: Array?, socket: Socket): String { + return alias + } + + override fun chooseServerAlias(s: String, principals: Array, socket: Socket): String { + return alias + } + + override fun getCertificateChain(s: String): Array? { + return certificates + } + + override fun getClientAliases(s: String, principals: Array): Array { + return arrayOf(alias) + } + + override fun getServerAliases(s: String, principals: Array): Array { + return arrayOf(alias) + } + + override fun getPrivateKey(s: String): PrivateKey? { + return privateKey + } + } + + val trustFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustFactory.init(trustStore) + + val tm = arrayOf(object : X509TrustManager { + @Throws(CertificateException::class) + override fun checkClientTrusted(chain: Array, authType: String) { + } + + @Throws(CertificateException::class) + override fun checkServerTrusted(chain: Array, authType: String) { + } + + override fun getAcceptedIssuers(): Array { + return certificates + } + + fun isClientTrusted(arg0: Array): Boolean { + return true + } + + fun isServerTrusted(arg0: Array): Boolean { + return true + } + + }) + + val sslContext = SSLContext.getInstance("TLS") + sslContext.init(arrayOf(keyManager), tm, null) + SSLContext.setDefault(sslContext) + + if (!clientCertHelper.getSetSslSocket()) { + clientCertHelper.setOkHttpClient(OkHttpClient.Builder() + .addInterceptor(logger) + .addInterceptor(basicAuthenticator) + .sslSocketFactory(sslContext.socketFactory) + .connectTimeout(15, TimeUnit.SECONDS) + .readTimeout(20, TimeUnit.SECONDS) + .writeTimeout(15, TimeUnit.SECONDS) + .build()) + } + } + } +} + + +@Singleton +class ClientCertHelper @Inject constructor( + private val context: Application, + private val getClientCertInteractor: GetClientCertInteractor, + private val logger: HttpLoggingInterceptor, + private val basicAuthenticator: BasicAuthenticatorInterceptor +) { + private val clientCert: ClientCertInterceptor = ClientCertInterceptor(this) + private var setSslSocket: Boolean = false + private var okHttpClient: OkHttpClient = OkHttpClient.Builder() + .addInterceptor(clientCert) + .addInterceptor(logger) + .addInterceptor(basicAuthenticator) + .connectTimeout(15, TimeUnit.SECONDS) + .readTimeout(20, TimeUnit.SECONDS) + .writeTimeout(15, TimeUnit.SECONDS) + .build() + + fun getEnabled(): Boolean { + var alias = getClientCertInteractor.get() + return !alias.isNullOrEmpty() + } + + fun getSetSslSocket(): Boolean { + return this.setSslSocket + } + + fun setOkHttpClient(client: OkHttpClient) { + this.okHttpClient = client + this.setSslSocket = true + } + + fun getClient(): OkHttpClient { + if (this.getEnabled() && !this.setSslSocket) { + SslTask(context, getClientCertInteractor, logger, basicAuthenticator, this).execute() + } + return this.okHttpClient + } +} diff --git a/app/src/main/java/chat/rocket/android/server/domain/ClientCertRepository.kt b/app/src/main/java/chat/rocket/android/server/domain/ClientCertRepository.kt new file mode 100644 index 0000000000..9b0f0661ef --- /dev/null +++ b/app/src/main/java/chat/rocket/android/server/domain/ClientCertRepository.kt @@ -0,0 +1,7 @@ +package chat.rocket.android.server.domain + +interface ClientCertRepository { + fun save(alias: String) + fun get(): String? + fun clear() +} diff --git a/app/src/main/java/chat/rocket/android/server/domain/GetClientInteractor.kt b/app/src/main/java/chat/rocket/android/server/domain/GetClientInteractor.kt new file mode 100644 index 0000000000..49ce8d10b9 --- /dev/null +++ b/app/src/main/java/chat/rocket/android/server/domain/GetClientInteractor.kt @@ -0,0 +1,11 @@ +package chat.rocket.android.server.domain + +import javax.inject.Inject + +class GetClientCertInteractor @Inject constructor(private val repository: ClientCertRepository) { + fun get(): String? = repository.get() + + fun clear() { + repository.clear() + } +} diff --git a/app/src/main/java/chat/rocket/android/server/domain/SaveClientInteractor.kt b/app/src/main/java/chat/rocket/android/server/domain/SaveClientInteractor.kt new file mode 100644 index 0000000000..33d87c68b5 --- /dev/null +++ b/app/src/main/java/chat/rocket/android/server/domain/SaveClientInteractor.kt @@ -0,0 +1,7 @@ +package chat.rocket.android.server.domain + +import javax.inject.Inject + +class SaveClientCertInteractor @Inject constructor(private val repository: ClientCertRepository) { + fun save(alias: String) = repository.save(alias) +} diff --git a/app/src/main/java/chat/rocket/android/server/infraestructure/ConnectionManagerFactory.kt b/app/src/main/java/chat/rocket/android/server/infraestructure/ConnectionManagerFactory.kt index d1f04094ee..d9359485b7 100644 --- a/app/src/main/java/chat/rocket/android/server/infraestructure/ConnectionManagerFactory.kt +++ b/app/src/main/java/chat/rocket/android/server/infraestructure/ConnectionManagerFactory.kt @@ -1,13 +1,14 @@ package chat.rocket.android.server.infraestructure import chat.rocket.android.db.DatabaseManagerFactory -import chat.rocket.android.infrastructure.LocalRepository +import chat.rocket.android.helper.ClientCertHelper import timber.log.Timber import javax.inject.Inject import javax.inject.Singleton @Singleton class ConnectionManagerFactory @Inject constructor( + private val clientCertHelper: ClientCertHelper, private val factory: RocketChatClientFactory, private val dbFactory: DatabaseManagerFactory ) { @@ -21,7 +22,13 @@ class ConnectionManagerFactory @Inject constructor( Timber.d("Returning FRESH Manager for: $url") val manager = ConnectionManager(factory.create(url), dbFactory.create(url)) - cache[url] = manager + if (clientCertHelper.getEnabled()) { + if (clientCertHelper.getSetSslSocket()) { + cache[url] = manager + } + } else { + cache[url] = manager + } return manager } diff --git a/app/src/main/java/chat/rocket/android/server/infraestructure/RocketChatClientFactory.kt b/app/src/main/java/chat/rocket/android/server/infraestructure/RocketChatClientFactory.kt index 3f0bb6ce77..7a71351cf2 100644 --- a/app/src/main/java/chat/rocket/android/server/infraestructure/RocketChatClientFactory.kt +++ b/app/src/main/java/chat/rocket/android/server/infraestructure/RocketChatClientFactory.kt @@ -2,17 +2,17 @@ package chat.rocket.android.server.infraestructure import android.os.Build import chat.rocket.android.BuildConfig +import chat.rocket.android.helper.ClientCertHelper import chat.rocket.android.server.domain.TokenRepository import chat.rocket.common.util.PlatformLogger import chat.rocket.core.RocketChatClient -import okhttp3.OkHttpClient import timber.log.Timber import javax.inject.Inject import javax.inject.Singleton @Singleton class RocketChatClientFactory @Inject constructor( - private val okHttpClient: OkHttpClient, + private val clientCertHelper: ClientCertHelper, private val repository: TokenRepository, private val logger: PlatformLogger ) { @@ -25,7 +25,7 @@ class RocketChatClientFactory @Inject constructor( } val client = RocketChatClient.create { - httpClient = okHttpClient + httpClient = clientCertHelper.getClient() restUrl = url userAgent = "RC Mobile; Android ${Build.VERSION.RELEASE}; v${BuildConfig.VERSION_NAME} (${BuildConfig.VERSION_CODE})" tokenRepository = repository @@ -34,7 +34,14 @@ class RocketChatClientFactory @Inject constructor( } Timber.d("Returning NEW client for: $url") - cache[url] = client + if (clientCertHelper.getEnabled()) { + if (clientCertHelper.getSetSslSocket()) { + cache[url] = client + } + } else { + cache[url] = client + } + return client } } \ No newline at end of file diff --git a/app/src/main/java/chat/rocket/android/server/infraestructure/SharedPrefsClientCertRepository.kt b/app/src/main/java/chat/rocket/android/server/infraestructure/SharedPrefsClientCertRepository.kt new file mode 100644 index 0000000000..c99b311b99 --- /dev/null +++ b/app/src/main/java/chat/rocket/android/server/infraestructure/SharedPrefsClientCertRepository.kt @@ -0,0 +1,23 @@ +package chat.rocket.android.server.infraestructure + +import android.content.SharedPreferences +import chat.rocket.android.server.domain.ClientCertRepository + +class SharedPrefsClientCertRepository(private val preferences: SharedPreferences) : ClientCertRepository { + + override fun save(alias: String) { + preferences.edit().putString(CLIENT_KEY, alias).apply() + } + + override fun get(): String? { + return preferences.getString(CLIENT_KEY, null) + } + + companion object { + private const val CLIENT_KEY = "" + } + + override fun clear() { + preferences.edit().remove(CLIENT_KEY).apply() + } +} diff --git a/app/src/main/java/chat/rocket/android/util/ClientCertInterceptor.kt b/app/src/main/java/chat/rocket/android/util/ClientCertInterceptor.kt new file mode 100644 index 0000000000..d4a595ff7e --- /dev/null +++ b/app/src/main/java/chat/rocket/android/util/ClientCertInterceptor.kt @@ -0,0 +1,31 @@ +package chat.rocket.android.util + +import chat.rocket.android.helper.ClientCertHelper +import okhttp3.Interceptor +import okhttp3.OkHttpClient +import okhttp3.Response +import java.io.IOException + +/** + * An OkHttp interceptor which waits for clientCert to be done before overriding the existing + * okHttpClient, if enabled. + * [application interceptor][OkHttpClient.interceptors] + * or as a [ ][OkHttpClient.networkInterceptors]. + */ +class ClientCertInterceptor( + private val clientCertHelper: ClientCertHelper +) : Interceptor { + @Volatile + internal var clientOverride: OkHttpClient? = null + + @Throws(IOException::class) + override fun intercept(chain: Interceptor.Chain): Response { + if (clientCertHelper.getSetSslSocket()) { + clientOverride = clientCertHelper.getClient() + } + val override = clientOverride + return if (override != null) { + override.newCall(chain.request()).execute() + } else chain.proceed(chain.request()) + } +} diff --git a/app/src/main/res/layout/fragment_authentication_server.xml b/app/src/main/res/layout/fragment_authentication_server.xml index ba32e82b50..34c7ae5f4e 100644 --- a/app/src/main/res/layout/fragment_authentication_server.xml +++ b/app/src/main/res/layout/fragment_authentication_server.xml @@ -74,6 +74,16 @@ app:layout_constraintStart_toStartOf="parent" app:layout_constraintTop_toBottomOf="@+id/server_url_container" /> + + Connect + Select Client Certificate Use this username Terms of Service Privacy Policy