Skip to content

Commit

Permalink
chore: session manager tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elefantel committed Jul 9, 2024
1 parent 10ac83f commit 523b5d0
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 61 deletions.
9 changes: 6 additions & 3 deletions metamask-android-sdk/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ dependencies {
implementation 'com.google.android.material:material:1.9.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
implementation 'androidx.test.ext:junit-ktx:1.1.5'
testImplementation 'junit:junit:4.13.2'
implementation "androidx.lifecycle:lifecycle-livedata-ktx:2.7.0"
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'

testImplementation 'junit:junit:4.13.2'
testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.0'

androidTestImplementation 'androidx.test.ext:junit:1.2.1'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.6.1'
}

ext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import kotlinx.serialization.Serializable
import org.json.JSONObject
import java.lang.ref.WeakReference

internal class CommunicationClient(context: Context, callback: EthereumEventCallback?) {
internal class CommunicationClient(context: Context, callback: EthereumEventCallback?, private val logger: Logger = DefaultLogger) {

var sessionId: String = ""
private val keyExchange: KeyExchange = KeyExchange()
Expand Down Expand Up @@ -66,23 +66,23 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
messageService = IMessegeService.Stub.asInterface(service)
messageService?.registerCallback(messageServiceCallback)
isServiceConnected = true
Logger.log("CommunicationClient:: Service connected")
logger.log("CommunicationClient:: Service connected")
initiateKeyExchange()
}

override fun onServiceDisconnected(name: ComponentName?) {
messageService = null
isServiceConnected = false
Logger.error("CommunicationClient:: Service disconnected $name")
logger.error("CommunicationClient:: Service disconnected $name")
trackEvent(Event.SDK_DISCONNECTED)
}

override fun onBindingDied(name: ComponentName?) {
Logger.error("CommunicationClient:: binding died: $name")
logger.error("CommunicationClient:: binding died: $name")
}

override fun onNullBinding(name: ComponentName?) {
Logger.error("CommunicationClient:: null binding: $name")
logger.error("CommunicationClient:: null binding: $name")
}
}

Expand Down Expand Up @@ -148,17 +148,17 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall

when (json.optString(MessageType.TYPE.value)) {
MessageType.TERMINATE.value -> {
Logger.log("CommunicationClient:: Connection terminated by MetaMask")
logger.log("CommunicationClient:: Connection terminated by MetaMask")
unbindService()
keyExchange.reset()
}
MessageType.KEYS_EXCHANGED.value -> {
Logger.log("CommunicationClient:: Keys exchanged")
logger.log("CommunicationClient:: Keys exchanged")
keyExchange.complete()
sendOriginatorInfo()
}
MessageType.READY.value -> {
Logger.log("CommunicationClient:: Connection ready")
logger.log("CommunicationClient:: Connection ready")
isMetaMaskReady = true
resumeRequestJobs()
}
Expand All @@ -178,7 +178,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
handleEvent(dataJson)
}
} else {
Logger.log("CommunicationClient:: Received error $json")
logger.log("CommunicationClient:: Received error $json")
val id = json.optString("id")
val error = json.optString(MessageType.ERROR.value)
handleError(error, id)
Expand All @@ -188,7 +188,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}

private fun resumeRequestJobs() {
Logger.log("CommunicationClient:: Resuming jobs")
logger.log("CommunicationClient:: Resuming jobs")

while (requestJobs.isNotEmpty()) {
val job = requestJobs.removeFirstOrNull()
Expand All @@ -198,7 +198,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall

private fun queueRequestJob(job: () -> Unit) {
requestJobs.add(job)
Logger.log("CommunicationClient:: Queued job")
logger.log("CommunicationClient:: Queued job")
}

private fun clearPendingRequests() {
Expand Down Expand Up @@ -294,7 +294,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
if (result.isNotEmpty()) {
completeRequest(id, Result.Success.Item(result))
} else {
Logger.error("CommunicationClient:: Unexpected response: $data")
logger.error("CommunicationClient:: Unexpected response: $data")
}
}
EthereumMethod.METAMASK_BATCH.value -> {
Expand Down Expand Up @@ -322,7 +322,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
val errorCode = errorMap["code"] as? Double ?: -1
val code = errorCode.toInt()
val message = errorMap["message"] as? String ?: ErrorType.message(code)
Logger.error("CommunicationClient:: Got error $message")
logger.error("CommunicationClient:: Got error $message")
completeRequest(requestId, Result.Error(RequestError(code, message)))
return true
}
Expand All @@ -342,7 +342,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
val accountsJson = event.optString("params")
val accounts: List<String> = Gson().fromJson(accountsJson, object : TypeToken<List<String>>() {}.type)
accounts.getOrNull(0)?.let { account ->
Logger.error("CommunicationClient:: Event Updated to account $account")
logger.error("CommunicationClient:: Event Updated to account $account")
updateAccount(account)
}
}
Expand All @@ -355,7 +355,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}
else -> {
Logger.error("CommunicationClient:: Unexpected event: $event")
logger.error("CommunicationClient:: Unexpected event: $event")
}
}
}
Expand Down Expand Up @@ -389,7 +389,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
put(KeyExchange.TYPE, nextStep.type)
}.toString()

Logger.log("Sending key exchange ${nextStep.type}")
logger.log("Sending key exchange ${nextStep.type}")
sendKeyExchangeMesage(exchangeMessage)
}
}
Expand All @@ -402,7 +402,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
if (keyExchange.keysExchanged()) {
messageService?.sendMessage(bundle)
} else {
Logger.log("CommunicationClient::sendMessage keys not exchanged, queueing job")
logger.log("CommunicationClient::sendMessage keys not exchanged, queueing job")
queueRequestJob { messageService?.sendMessage(bundle) }
}
}
Expand All @@ -416,29 +416,29 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
queuedRequests[request.id] = SubmittedRequest(request, callback)
queueRequestJob { processRequest(request, callback) }
if (!requestedBindService) {
Logger.log("CommunicationClient:: sendRequest - not yet connected to metamask, binding service first")
logger.log("CommunicationClient:: sendRequest - not yet connected to metamask, binding service first")
bindService()
} else {
Logger.log("CommunicationClient:: sendRequest - not yet connected to metamask, waiting for service to bind")
logger.log("CommunicationClient:: sendRequest - not yet connected to metamask, waiting for service to bind")
}
} else if (!keyExchange.keysExchanged()) {
Logger.log("CommunicationClient:: sendRequest - keys not yet exchanged")
logger.log("CommunicationClient:: sendRequest - keys not yet exchanged")
queuedRequests[request.id] = SubmittedRequest(request, callback)
queueRequestJob { processRequest(request, callback) }
initiateKeyExchange()
} else {
if (isMetaMaskReady) {
processRequest(request, callback)
} else {
Logger.log("CommunicationClient::sendRequest - wallet is not ready, queueing request")
logger.log("CommunicationClient::sendRequest - wallet is not ready, queueing request")
queueRequestJob { processRequest(request, callback) }
sendOriginatorInfo()
}
}
}

private fun processRequest(request: RpcRequest, callback: (Result) -> Unit) {
Logger.log("CommunicationClient:: sending request $request")
logger.log("CommunicationClient:: sending request $request")
if (queuedRequests[request.id] != null) {
queuedRequests.remove(request.id)
}
Expand Down Expand Up @@ -467,7 +467,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
val requestInfo = RequestInfo("originator_info", originatorInfo)
val requestInfoJson = Gson().toJson(requestInfo)

Logger.log("CommunicationClient:: Sending originator info: $requestInfoJson")
logger.log("CommunicationClient:: Sending originator info: $requestInfoJson")

val payload = keyExchange.encrypt(requestInfoJson)

Expand All @@ -493,7 +493,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}

private fun bindService() {
Logger.log("CommunicationClient:: Binding service")
logger.log("CommunicationClient:: Binding service")
requestedBindService = true

val serviceIntent = Intent()
Expand All @@ -510,29 +510,29 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
serviceConnection,
Context.BIND_AUTO_CREATE)
} else {
Logger.error("App context null")
logger.error("App context null")
}
}

fun unbindService() {
requestedBindService = false

if (isServiceConnected) {
Logger.log("CommunicationClient:: unbindService")
logger.log("CommunicationClient:: unbindService")
appContextRef.get()?.unbindService(serviceConnection)
isServiceConnected = false
}
}

fun initiateKeyExchange() {
Logger.log("CommunicationClient:: Initiating key exchange")
logger.log("CommunicationClient:: Initiating key exchange")

val keyExchange = JSONObject().apply {
put(KeyExchange.PUBLIC_KEY, keyExchange.publicKey)
put(KeyExchange.TYPE, KeyExchangeMessageType.KEY_HANDSHAKE_SYN.name)
}

Logger.log("Sending key exchange ${KeyExchangeMessageType.KEY_HANDSHAKE_SYN}")
logger.log("Sending key exchange ${KeyExchangeMessageType.KEY_HANDSHAKE_SYN}")
sendKeyExchangeMesage(keyExchange.toString())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ private const val DEFAULT_SESSION_DURATION: Long = 30 * 24 * 3600 // 30 days def
class Ethereum (
private val context: Context,
private val dappMetadata: DappMetadata,
sdkOptions: SDKOptions? = null): EthereumEventCallback {
sdkOptions: SDKOptions? = null,
private val logger: Logger = DefaultLogger
): EthereumEventCallback {
private var connectRequestSent = false
private val communicationClient: CommunicationClient? by lazy {
CommunicationClient(context, null)
Expand Down Expand Up @@ -58,7 +60,7 @@ class Ethereum (
private var sessionDuration: Long = DEFAULT_SESSION_DURATION

override fun updateAccount(account: String) {
Logger.log("Ethereum:: Selected account changed: $account")
logger.log("Ethereum:: Selected account changed: $account")
_ethereumState.postValue(
currentEthereumState.copy(
selectedAddress = account,
Expand All @@ -68,7 +70,7 @@ class Ethereum (
}

override fun updateChainId(newChainId: String) {
Logger.log("Ethereum:: ChainId changed: $newChainId")
logger.log("Ethereum:: ChainId changed: $newChainId")
_ethereumState.postValue(
currentEthereumState.copy(
chainId = newChainId,
Expand Down Expand Up @@ -96,7 +98,7 @@ class Ethereum (
return
}

Logger.log("Ethereum:: connecting...")
logger.log("Ethereum:: connecting...")
communicationClient?.dappMetadata = dappMetadata
communicationClient?.ethereumEventCallbackRef = WeakReference(this)
communicationClient?.updateSessionDuration(sessionDuration)
Expand All @@ -118,7 +120,7 @@ class Ethereum (
*/

fun connectWith(request: EthereumRequest, callback: ((Result) -> Unit)? = null) {
Logger.log("Ethereum:: connecting with ${request.method}...")
logger.log("Ethereum:: connecting with ${request.method}...")
connectRequestSent = true
communicationClient?.dappMetadata = dappMetadata
communicationClient?.ethereumEventCallbackRef = WeakReference(this)
Expand Down Expand Up @@ -261,7 +263,7 @@ class Ethereum (
}

fun disconnect(clearSession: Boolean = false) {
Logger.log("Ethereum:: disconnecting...")
logger.log("Ethereum:: disconnecting...")
connectRequestSent = false
communicationClient?.resetState()
communicationClient?.unbindService()
Expand Down Expand Up @@ -293,7 +295,7 @@ class Ethereum (
}

private fun requestAccounts(callback: ((Result) -> Unit)? = null) {
Logger.log("Ethereum:: Requesting ethereum accounts")
logger.log("Ethereum:: Requesting ethereum accounts")
connectRequestSent = true

val accountsRequest = EthereumRequest(
Expand All @@ -304,7 +306,7 @@ class Ethereum (
}

fun sendRequest(request: RpcRequest, callback: ((Result) -> Unit)? = null) {
Logger.log("Ethereum:: Sending request $request")
logger.log("Ethereum:: Sending request $request")

if (!connectRequestSent) {
requestAccounts {
Expand All @@ -314,7 +316,7 @@ class Ethereum (
}

if (EthereumMethod.isReadOnly(request.method) && infuraProvider?.supportsChain(chainId) == true) {
Logger.log("Ethereum:: Using Infura API for method ${request.method} on chain $chainId")
logger.log("Ethereum:: Using Infura API for method ${request.method} on chain $chainId")
infuraProvider.makeRequest(request, chainId, dappMetadata, callback)
} else {
communicationClient?.sendRequest(request) { response ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import okhttp3.RequestBody.Companion.toRequestBody
import org.json.JSONObject
import java.io.IOException

internal class HttpClient {
internal class HttpClient(private val logger: Logger = DefaultLogger) {
private val client = OkHttpClient()
private var additionalHeaders: Headers = Headers.headersOf("Accept", "application/json", "Content-Type", "application/json")

Expand All @@ -32,7 +32,7 @@ internal class HttpClient {

client.newCall(request).enqueue(object: Callback {
override fun onFailure(call: Call, e: IOException) {
Log.e(TAG,"HttpClient: error ${e.message}")
logger.error("HttpClient: error ${e.message}")
if (callback != null) {
callback(null, e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.metamask.androidsdk

import org.json.JSONObject

class InfuraProvider(private val infuraAPIKey: String) {
class InfuraProvider(private val infuraAPIKey: String, private val logger: Logger = DefaultLogger) {
val rpcUrls: Map<String, String> = mapOf(
// ###### Ethereum ######
// Mainnet
Expand Down Expand Up @@ -87,12 +87,12 @@ class InfuraProvider(private val infuraAPIKey: String) {

httpClient.newCall("${rpcUrls[chainId]}", parameters = params) { response, ioException ->
if (response != null) {
Logger.log("InfuraProvider:: response $response")
logger.log("InfuraProvider:: response $response")
try {
val result = JSONObject(response).optString("result") ?: ""
callback?.invoke(Result.Success.Item(result))
} catch (e: Exception) {
Logger.error("InfuraProvider:: error: ${e.message}")
logger.error("InfuraProvider:: error: ${e.message}")
callback?.invoke(Result.Error(RequestError(-1, response)))
}
} else if (ioException != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data class KeyExchangeMessage(
val publicKey: String?
)

class KeyExchange(private val crypto: Encryption = Crypto()) {
class KeyExchange(private val crypto: Encryption = Crypto(), private val logger: Logger = DefaultLogger) {
companion object {
const val TYPE = "type"
const val PUBLIC_KEY = "public_key"
Expand Down Expand Up @@ -56,7 +56,7 @@ class KeyExchange(private val crypto: Encryption = Crypto()) {
}

fun complete() {
Logger.log("KeyExchange:: Key exchange complete")
logger.log("KeyExchange:: Key exchange complete")
setIsKeysExchanged(true)
}

Expand Down
Loading

0 comments on commit 523b5d0

Please sign in to comment.