Skip to content
This repository has been archived by the owner on Jun 7, 2020. It is now read-only.

[FEATURE] Client Certificate authentication #2007

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package chat.rocket.android.authentication.server.presentation

import android.security.KeyChain
import android.security.KeyChainAliasCallback;
import chat.rocket.android.authentication.domain.model.LoginDeepLinkInfo
import chat.rocket.android.authentication.presentation.AuthenticationNavigator
import chat.rocket.android.core.behaviours.showMessage
import chat.rocket.android.core.lifecycle.CancelStrategy
import chat.rocket.android.authentication.ui.AuthenticationActivity
import chat.rocket.android.helper.ClientCertHelper
import chat.rocket.android.server.domain.SaveClientCertInteractor
import chat.rocket.android.server.domain.GetAccountsInteractor
import chat.rocket.android.server.domain.GetSettingsInteractor
import chat.rocket.android.server.domain.RefreshSettingsInteractor
Expand All @@ -23,9 +28,12 @@ class ServerPresenter @Inject constructor(
private val serverInteractor: SaveConnectingServerInteractor,
private val refreshSettingsInteractor: RefreshSettingsInteractor,
private val getAccountsInteractor: GetAccountsInteractor,
private val saveClientCertInteractor: SaveClientCertInteractor,
private val clientCertHelper: ClientCertHelper,
val settingsInteractor: GetSettingsInteractor,
val factory: RocketChatClientFactory
) : CheckServerPresenter(
val factory: RocketChatClientFactory,
internal val activity: AuthenticationActivity
) : KeyChainAliasCallback, CheckServerPresenter(
strategy = strategy,
factory = factory,
settingsInteractor = settingsInteractor,
Expand Down Expand Up @@ -117,4 +125,19 @@ class ServerPresenter @Inject constructor(
}
}
}

override fun alias(alias: String) {
saveClientCertInteractor.save(alias)
clientCertHelper.getClient()
}

fun requestClientCert() {
KeyChain.choosePrivateKeyAlias(activity,
this, // Callback
null, // Any key types.
null, // Any issuers.
null, // Any host
-1, // Any port
"RocketChat")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ class ServerFragment : Fragment(), ServerView {
button_connect.setOnClickListener {
presenter.checkServer("$protocol${text_server_url.textContent.sanitize()}")
}
button_tls_cert.setOnClickListener {
presenter.requestClientCert()
}
}

override fun showInvalidServerUrlMessage() =
Expand Down
26 changes: 15 additions & 11 deletions app/src/main/java/chat/rocket/android/dagger/module/AppModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import chat.rocket.android.dagger.qualifier.ForAuthentication
import chat.rocket.android.dagger.qualifier.ForMessages
import chat.rocket.android.db.DatabaseManager
import chat.rocket.android.db.DatabaseManagerFactory
import chat.rocket.android.helper.ClientCertHelper
import chat.rocket.android.helper.MessageParser
import chat.rocket.android.infrastructure.LocalRepository
import chat.rocket.android.infrastructure.SharedPreferencesLocalRepository
Expand All @@ -28,6 +29,7 @@ import chat.rocket.android.server.domain.AccountsRepository
import chat.rocket.android.server.domain.AnalyticsTrackingInteractor
import chat.rocket.android.server.domain.AnalyticsTrackingRepository
import chat.rocket.android.server.domain.ChatRoomsRepository
import chat.rocket.android.server.domain.ClientCertRepository
import chat.rocket.android.server.domain.CurrentServerRepository
import chat.rocket.android.server.domain.GetAccountInteractor
import chat.rocket.android.server.domain.GetAccountsInteractor
Expand All @@ -54,6 +56,7 @@ import chat.rocket.android.server.infraestructure.SharedPreferencesPermissionsRe
import chat.rocket.android.server.infraestructure.SharedPreferencesSettingsRepository
import chat.rocket.android.server.infraestructure.SharedPrefsAnalyticsTrackingRepository
import chat.rocket.android.server.infraestructure.SharedPrefsConnectingServerRepository
import chat.rocket.android.server.infraestructure.SharedPrefsClientCertRepository
import chat.rocket.android.server.infraestructure.SharedPrefsCurrentServerRepository
import chat.rocket.android.util.AppJsonAdapterFactory
import chat.rocket.android.util.HttpLoggingInterceptor
Expand All @@ -78,7 +81,6 @@ import okhttp3.OkHttpClient
import ru.noties.markwon.SpannableConfiguration
import ru.noties.markwon.spans.SpannableTheme
import timber.log.Timber
import java.util.concurrent.TimeUnit
import javax.inject.Named
import javax.inject.Singleton

Expand Down Expand Up @@ -123,25 +125,21 @@ class AppModule {

@Provides
@Singleton
fun provideOkHttpClient(logger: HttpLoggingInterceptor, basicAuthenticator: BasicAuthenticatorInterceptor): OkHttpClient {
return OkHttpClient.Builder()
.addInterceptor(logger)
.addInterceptor(basicAuthenticator)
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)
.build()
fun provideOkHttpClient(
clientCertHelper: ClientCertHelper
): OkHttpClient {
return clientCertHelper.getClient()
}

@Provides
@Singleton
fun provideImagePipelineConfig(
context: Context,
okHttpClient: OkHttpClient
clientCertHelper: ClientCertHelper
): ImagePipelineConfig {
val listeners = setOf(RequestLoggingListener())

return OkHttpImagePipelineConfigFactory.newBuilder(context, okHttpClient)
return OkHttpImagePipelineConfigFactory.newBuilder(context, clientCertHelper.getClient())
.setRequestListeners(listeners)
.setDownsampleEnabled(true)
.experiment().setPartialImageCachingEnabled(true).build()
Expand Down Expand Up @@ -307,6 +305,12 @@ class AppModule {
): AccountsRepository =
SharedPreferencesAccountsRepository(preferences, moshi)

@Provides
@Singleton
fun provideClientCertRepository(prefs: SharedPreferences): ClientCertRepository {
return SharedPrefsClientCertRepository(prefs)
}

@Provides
fun provideNotificationManager(context: Application) =
context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager
Expand Down
161 changes: 161 additions & 0 deletions app/src/main/java/chat/rocket/android/helper/ClientCertHelper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package chat.rocket.android.helper

import android.app.Application
import android.os.AsyncTask
import android.security.KeyChain
import java.net.Socket
import java.security.KeyStore
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import javax.net.ssl.X509ExtendedKeyManager
import javax.net.ssl.KeyManager
import chat.rocket.android.server.domain.GetClientCertInteractor
import chat.rocket.android.util.HttpLoggingInterceptor
import chat.rocket.android.util.BasicAuthenticatorInterceptor
import okhttp3.OkHttpClient
import javax.inject.Inject
import javax.inject.Singleton
import java.util.concurrent.TimeUnit
import chat.rocket.android.util.ClientCertInterceptor


data class SslStuff(val privKey: PrivateKey, val certChain: Array<X509Certificate>)

class SslTask @Inject constructor(
private val context: Application,
private val getClientCertInteractor: GetClientCertInteractor,
private val logger: HttpLoggingInterceptor,
private val basicAuthenticator: BasicAuthenticatorInterceptor,
private val clientCertHelper: ClientCertHelper
): AsyncTask<Void, Void, SslStuff>() {

override fun doInBackground(vararg params: Void?): SslStuff {
var alias = getClientCertInteractor.get()
alias = alias.toString()
val privKey = KeyChain.getPrivateKey(context.applicationContext, alias)
val certChain = KeyChain.getCertificateChain(context.applicationContext, alias)

return SslStuff(privKey, certChain)
}

override fun onPostExecute(result: SslStuff?) {
var alias = getClientCertInteractor.get()
if (result != null && !clientCertHelper.getSetSslSocket()) {
alias = alias.toString()
val (privateKey, certificates) = result
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType())
val keyManager = object : X509ExtendedKeyManager() {
override fun chooseClientAlias(strings: Array<String>, principals: Array<Principal>?, socket: Socket): String {
return alias
}

override fun chooseServerAlias(s: String, principals: Array<Principal>, socket: Socket): String {
return alias
}

override fun getCertificateChain(s: String): Array<X509Certificate>? {
return certificates
}

override fun getClientAliases(s: String, principals: Array<Principal>): Array<String> {
return arrayOf(alias)
}

override fun getServerAliases(s: String, principals: Array<Principal>): Array<String> {
return arrayOf(alias)
}

override fun getPrivateKey(s: String): PrivateKey? {
return privateKey
}
}

val trustFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustFactory.init(trustStore)

val tm = arrayOf<X509TrustManager>(object : X509TrustManager {
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {
}

@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
}

override fun getAcceptedIssuers(): Array<X509Certificate> {
return certificates
}

fun isClientTrusted(arg0: Array<X509Certificate>): Boolean {
return true
}

fun isServerTrusted(arg0: Array<X509Certificate>): Boolean {
return true
}

})

val sslContext = SSLContext.getInstance("TLS")
sslContext.init(arrayOf<KeyManager>(keyManager), tm, null)
SSLContext.setDefault(sslContext)

if (!clientCertHelper.getSetSslSocket()) {
clientCertHelper.setOkHttpClient(OkHttpClient.Builder()
.addInterceptor(logger)
.addInterceptor(basicAuthenticator)
.sslSocketFactory(sslContext.socketFactory)
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)
.build())
}
}
}
}


@Singleton
class ClientCertHelper @Inject constructor(
private val context: Application,
private val getClientCertInteractor: GetClientCertInteractor,
private val logger: HttpLoggingInterceptor,
private val basicAuthenticator: BasicAuthenticatorInterceptor
) {
private val clientCert: ClientCertInterceptor = ClientCertInterceptor(this)
private var setSslSocket: Boolean = false
private var okHttpClient: OkHttpClient = OkHttpClient.Builder()
.addInterceptor(clientCert)
.addInterceptor(logger)
.addInterceptor(basicAuthenticator)
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)
.build()

fun getEnabled(): Boolean {
var alias = getClientCertInteractor.get()
return !alias.isNullOrEmpty()
}

fun getSetSslSocket(): Boolean {
return this.setSslSocket
}

fun setOkHttpClient(client: OkHttpClient) {
this.okHttpClient = client
this.setSslSocket = true
}

fun getClient(): OkHttpClient {
if (this.getEnabled() && !this.setSslSocket) {
SslTask(context, getClientCertInteractor, logger, basicAuthenticator, this).execute()
}
return this.okHttpClient
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package chat.rocket.android.server.domain

interface ClientCertRepository {
fun save(alias: String)
fun get(): String?
fun clear()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package chat.rocket.android.server.domain

import javax.inject.Inject

class GetClientCertInteractor @Inject constructor(private val repository: ClientCertRepository) {
fun get(): String? = repository.get()

fun clear() {
repository.clear()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package chat.rocket.android.server.domain

import javax.inject.Inject

class SaveClientCertInteractor @Inject constructor(private val repository: ClientCertRepository) {
fun save(alias: String) = repository.save(alias)
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package chat.rocket.android.server.infraestructure

import chat.rocket.android.db.DatabaseManagerFactory
import chat.rocket.android.infrastructure.LocalRepository
import chat.rocket.android.helper.ClientCertHelper
import timber.log.Timber
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class ConnectionManagerFactory @Inject constructor(
private val clientCertHelper: ClientCertHelper,
private val factory: RocketChatClientFactory,
private val dbFactory: DatabaseManagerFactory
) {
Expand All @@ -21,7 +22,13 @@ class ConnectionManagerFactory @Inject constructor(

Timber.d("Returning FRESH Manager for: $url")
val manager = ConnectionManager(factory.create(url), dbFactory.create(url))
cache[url] = manager
if (clientCertHelper.getEnabled()) {
if (clientCertHelper.getSetSslSocket()) {
cache[url] = manager
}
} else {
cache[url] = manager
}
return manager
}

Expand Down
Loading