Skip to content

Commit

Permalink
Fix: multi-account bug, same account used across multiple authorizati…
Browse files Browse the repository at this point in the history
…ons (#1012)

* an inefficient fix for same account accross diff auths

* fake app changes

* add unit test

* optimize
  • Loading branch information
Funkatronics authored Dec 3, 2024
1 parent 91a34c3 commit 5a6c4bc
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,9 @@ class MainViewModel(application: Application) : AndroidViewModel(application) {
signInPayload: SignInWithSolana.Payload? = null
): MobileWalletAdapterClient.AuthorizationResult {
val result = try {
client.authorize(identity, chain, signInPayload, uiState.value.sessionProtocolVersion!!)
client.authorize(identity, chain, signInPayload,
uiState.value.accounts?.map { it.publicKey },
uiState.value.sessionProtocolVersion!!)
} catch (e: MobileWalletAdapterUseCase.MobileWalletAdapterOperationFailedException) {
_uiState.update {
it.copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ object MobileWalletAdapterUseCase {
identity: DappIdentity,
chain: String?,
signInPayload: SignInWithSolana.Payload?,
publicKeys: List<ByteArray>?,
protocolVersion: ProtocolVersion = ProtocolVersion.V1
): MobileWalletAdapterClient.AuthorizationResult = coroutineScope {
try {
runInterruptible(Dispatchers.IO) {
if (protocolVersion == ProtocolVersion.V1) {
client.authorize(
identity.uri, identity.iconRelativeUri, identity.name, chain,
null, null, null, signInPayload
null, null, publicKeys?.toTypedArray(), signInPayload
).get()!!
} else {
val cluster = when (chain) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,74 @@ class MainActivityTest {
assertTrue(signer.verifySignature(authResult.signInResult!!.signature))
}

@Test
fun authorizationFlow_SuccessfulReauthorizeSingleAccountMultiAuth() {
// given
val uiDevice = UiDevice.getInstance(InstrumentationRegistry.getInstrumentation())

val identity1Uri = Uri.parse("https://test1.com")
val identity1Name = "Test 1"
val identity2Uri = Uri.parse("https://test2.com")
val identity2Name = "Test 2"
val iconUri = Uri.parse("favicon.ico")
val chain = ProtocolContract.CHAIN_SOLANA_TESTNET

// simulate client side scenarios
val localAssociation1 = LocalAssociationScenario(Scenario.DEFAULT_CLIENT_TIMEOUT_MS)
val associationIntent1 = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation1.port,
localAssociation1.session
)

val localAssociation2 = LocalAssociationScenario(Scenario.DEFAULT_CLIENT_TIMEOUT_MS)
val associationIntent2 = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation2.port,
localAssociation2.session
)

// when
ActivityScenario.launch<MainActivity>(associationIntent1)

// First, simulate client 1 authorizing
// trigger authorization from client 1
var mwaClient = localAssociation1.start().get()
val authorization1 = mwaClient.authorize(identity1Uri, iconUri, identity1Name, chain,
null, null, null, null)

uiDevice.wait(Until.hasObject(By.res(FAKEWALLET_PACKAGE, "authorize")), WINDOW_CHANGE_TIMEOUT)

onView(withId(R.id.btn_authorize))
.check(matches(isDisplayed())).perform(click())

val accounts = authorization1.get().accounts.map { it.publicKey }
localAssociation1.close().get()

// Now, authorize client 2 for the same account (publickey) that was used with client 1
ActivityScenario.launch<MainActivity>(associationIntent2)

// trigger authorization from client 2
mwaClient = localAssociation2.start().get()
val authorization2 = mwaClient.authorize(identity2Uri, iconUri, identity2Name, chain,
null, null, accounts.toTypedArray(), null)

uiDevice.wait(Until.hasObject(By.res(FAKEWALLET_PACKAGE, "authorize")), WINDOW_CHANGE_TIMEOUT)

onView(withId(R.id.btn_authorize))
.check(matches(isDisplayed())).perform(click())

val authResult = authorization2.get()

// reauthorize - this is needed to trigger the auth lookup
val reauthResult = mwaClient.authorize(identity2Uri, iconUri, identity2Name, chain,
authResult.authToken, null, accounts.toTypedArray(), null).get()

// then
assertNotNull(reauthResult)
assertTrue(reauthResult.authToken == authResult.authToken)
}

@Test
fun signingFlow_SuccessfulSignMessagesMultiAccount() {
// given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import android.util.Log
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import com.funkatronics.encoders.Base58
import com.funkatronics.encoders.Base64
import com.solana.mobilewalletadapter.common.ProtocolContract
import com.solana.mobilewalletadapter.common.protocol.SessionProperties
import com.solana.mobilewalletadapter.common.signin.SignInWithSolana
Expand Down Expand Up @@ -118,10 +119,19 @@ class MobileWalletAdapterViewModel(application: Application) : AndroidViewModel(
viewModelScope.launch {
if (authorized) {
val accounts = (0 until numAccounts).map {
val keypair = getApplication<FakeWalletApplication>().keyRepository.generateKeypair()
val publicKey = keypair.public as Ed25519PublicKeyParameters
Log.d(TAG, "Generated a new keypair (pub=${publicKey.encoded.contentToString()}) for authorize request")
buildAccount(publicKey.encoded, "fakewallet account $it")
val publicKeyBytes = request.request.addresses?.get(it)?.let { address ->
val keypair = getApplication<FakeWalletApplication>().keyRepository
.getKeypair(Base64.decode(address)) ?: return@let null
val publicKey = keypair.public as Ed25519PublicKeyParameters
Log.d(TAG, "Reusing known keypair (pub=${publicKey.encoded.contentToString()}) for authorize request")
publicKey.encoded
} ?: run {
val keypair = getApplication<FakeWalletApplication>().keyRepository.generateKeypair()
val publicKey = keypair.public as Ed25519PublicKeyParameters
Log.d(TAG, "Generated a new keypair (pub=${publicKey.encoded.contentToString()}) for authorize request")
publicKey.encoded
}
buildAccount(publicKeyBytes, "fakewallet account $it")
}
request.request.completeWithAuthorize(accounts.toTypedArray(), null,
request.sourceVerificationState.authorizationScope.encodeToByteArray(), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ public AccountRecord query(@NonNull byte[] publicKey) {
}
}

@Nullable
@Override
public AccountRecord query(long parentId, @NonNull byte[] publicKey) {
final SQLiteDatabase.CursorFactory accountCursorFactory = (db1, masterQuery, editTable, query) -> {
query.bindBlob(1, publicKey);
return new SQLiteCursor(masterQuery, editTable, query);
};
try (final Cursor cursor = super.queryWithFactory(accountCursorFactory,
TABLE_ACCOUNTS,
ACCOUNTS_COLUMNS,
COLUMN_ACCOUNTS_PUBLIC_KEY_RAW + "=? AND " +
COLUMN_ACCOUNTS_PARENT_ID + "=?" + parentId,
null)) {
if (!cursor.moveToNext()) {
return null;
}
return cursorToEntity(cursor);
}
}

@Override
public void deleteUnreferencedAccounts() {
final SQLiteStatement deleteUnreferencedAccounts = super.compileStatement(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ long insert(@NonNull long parentId, @NonNull byte[] publicKey,
@IntRange(from = 0)
long updateParentId(long oldParentId, long newParentId);

// @Nullable
// List<AccountRecord> query(int parentId);

@Nullable
AccountRecord query(@NonNull byte[] publicKey);

@Nullable
AccountRecord query(long parentId, @NonNull byte[] publicKey);

void deleteUnreferencedAccounts();
}
Original file line number Diff line number Diff line change
Expand Up @@ -402,20 +402,11 @@ public AuthRecord issue(@NonNull String name,
// Finally, try and look up the accounts
final List<AccountRecord> accountRecords = new ArrayList<>();
for (AuthorizedAccount account: accounts) {

final AccountRecord accountRecordQueried = mAccountsDao.query(account.publicKey);

final int accountId;
final AccountRecord accountRecord;
// If no matching account exists, create one
if (accountRecordQueried == null) {
accountId = (int) mAccountsDao.insert(authRecordId, account.publicKey,
// create an account record for each account in this auth record
final int accountId = (int) mAccountsDao.insert(authRecordId, account.publicKey,
account.accountLabel, account.accountIcon, account.chains, account.features);
accountRecord = new AccountRecord(accountId, authRecordId, account.publicKey,
final AccountRecord accountRecord = new AccountRecord(accountId, authRecordId, account.publicKey,
account.accountLabel, account.accountIcon, account.chains, account.features);
} else {
accountRecord = accountRecordQueried;
}
accountRecords.add(accountRecord);
}

Expand Down

0 comments on commit 5a6c4bc

Please sign in to comment.