Skip to content

Commit

Permalink
Address Recovery Working
Browse files Browse the repository at this point in the history
  • Loading branch information
arietrouw committed Nov 15, 2024
1 parent 7cf6689 commit d15400c
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class AccountTest {

val testVectorPrivateKey = "7f71bc5644f8f521f7e9b73f7a391e82c05432f8a9d36c44d6b1edbf1d8db62f"
val testVectorPublicKey
= "04ed6f3b86542f45aab88ec48ab1366b462bd993fec83e234054afd8f2311fba774800fdb40c04918463b463a6044b83413a604550bfba8f8911beb65475d6528e"
= "ed6f3b86542f45aab88ec48ab1366b462bd993fec83e234054afd8f2311fba774800fdb40c04918463b463a6044b83413a604550bfba8f8911beb65475d6528e"
val testVectorAddress = "5e7a847447e7fec41011ae7d32d768f86605ba03"
val testVectorHash = "4b688df40bcedbe641ddb16ff0a1842d9c67ea1c3bf63f3e0471baa664531d1a"
val testVectorSignature
Expand All @@ -30,7 +30,7 @@ class AccountTest {
fun testKnownPrivateKeyAccount() {
val account = Account.fromPrivateKey(hexStringToByteArray(testVectorPrivateKey))
assert(account.privateKey.count() == 32)
assert(account.publicKey.count() == 65)
assert(account.publicKey.count() == 64)
Log.i("publicKey", account.publicKey.toHexString())
Log.i("publicKeyVector", testVectorPublicKey)
Log.i("address", account.address.toHexString())
Expand Down
81 changes: 69 additions & 12 deletions sdk/src/main/java/network/xyo/client/account/Account.kt
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
package network.xyo.client.account

import network.xyo.client.XyoSerializable
import android.os.Build
import androidx.annotation.RequiresApi
import network.xyo.client.account.model.AccountInstance
import network.xyo.client.account.model.AccountStatic
import network.xyo.client.account.model.PreviousHashStore
import network.xyo.client.address.XyoData.Companion.copyByteArrayWithLeadingPaddingOrTrim
import network.xyo.client.address.XyoEllipticKey.Companion.CURVE
import org.spongycastle.crypto.digests.SHA256Digest
import org.spongycastle.crypto.params.ECPrivateKeyParameters
import org.spongycastle.crypto.signers.ECDSASigner
import org.spongycastle.crypto.signers.HMacDSAKCalculator
import org.spongycastle.jcajce.provider.digest.Keccak
import tech.figure.hdwallet.ec.CurvePoint
import tech.figure.hdwallet.ec.PrivateKey
import tech.figure.hdwallet.ec.PublicKey
import tech.figure.hdwallet.ec.extensions.toBytesPadded
import tech.figure.hdwallet.ec.secp256k1Curve
import tech.figure.hdwallet.signer.ASN1Signature
import tech.figure.hdwallet.signer.BCECSigner
import tech.figure.hdwallet.signer.BTCSignature
import tech.figure.hdwallet.signer.ECDSASignature
Expand All @@ -35,17 +31,27 @@ open class Account(private val _privateKey: PrivateKey, private var _previousHas
final override val privateKey: ByteArray
get() = _privateKey.key.toBytesPadded(32)
final override val publicKey: ByteArray
get() = publicKeyFromPrivateKey(BigInteger(privateKey))
get() = _privateKey.toPublicKey().key.toBytesPadded(64)

override fun sign(hash: ByteArray): ByteArray {
val result = BCECSigner().sign(_privateKey, hash)
_previousHash = hash
return result.encodeAsBTC().toByteArray()
}

@OptIn(ExperimentalStdlibApi::class)
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
override fun verify(msg: ByteArray, signature: ByteArray): Boolean {
return BCECSigner().verify(_privateKey.toPublicKey(), msg, ECDSASignature.Companion.decode(
BTCSignature.fromByteArray(signature)))
val recoveredPublicKey = recoverPublicKey(msg, signature)
val recoveredPublicKeyHex = recoveredPublicKey?.toHexString()
val recoveredAddress = if (recoveredPublicKey == null) null else publicKeyToAddress(recoveredPublicKey)
val recoveredAddressHex = recoveredAddress?.toHexString()
val expectedAddress = address.toHexString()
val expectedPublicKeyHex = publicKey.toHexString()
val publicKey = if (recoveredPublicKey == null) null else PublicKey.fromBytes(byteArrayOf(4.toByte()) + recoveredPublicKey)
return recoveredPublicKey != null && publicKey != null && BCECSigner().verify(publicKey, msg, ECDSASignature.Companion.decode(
BTCSignature.fromByteArray(signature))) &&
recoveredAddressHex == expectedAddress
}

companion object: AccountStatic<AccountInstance> {
Expand All @@ -60,7 +66,7 @@ open class Account(private val _privateKey: PrivateKey, private var _previousHas
}

fun addressFromPublicKey(key: ByteArray): ByteArray {
val publicKeyHash = toKeccak(key.copyOfRange(1, key.size))
val publicKeyHash = toKeccak(key)
return publicKeyHash.copyOfRange(12, publicKeyHash.size)
}

Expand All @@ -85,9 +91,60 @@ open class Account(private val _privateKey: PrivateKey, private var _previousHas
return secp256k1Curve.g.mul(private).encoded(false)
}

fun recoverPublicKeyFromSignature(signature: ByteArray, msgHash: ByteArray): ByteArray? {
val signObj = ECDSASignature.Companion.decode(BTCSignature.fromByteArray(signature))
require(signature.size == 64) { "Signature must be 64 bytes (r, s format)" }

// Load secp256k1 curve parameters
val g = secp256k1Curve.g
val n = secp256k1Curve.n

// Adjust v to be 0 or 1 for public key recovery
val recId = 1

// Calculate the x-coordinate of the R point
val x = signObj.r.add(n.multiply(BigInteger.valueOf(recId.toLong())))

// Check if x is valid on the curve
// if (x >= curve.field.characteristic) return null

// Create the point R by decompression
val xBytes = x.toByteArray()
val xBytesFinal = if (xBytes.size == 33) xBytes.sliceArray(1 until xBytes.size) else xBytes
val rPoint: CurvePoint = secp256k1Curve.decodePoint(byteArrayOf((2 + recId).toByte()) + xBytesFinal)

// Calculate e = HASH(message)
val e = BigInteger(1, msgHash)

// Calculate r^-1 mod n
val rInv = signObj.r.modInverse(n)

// Calculate s * R
val sR = rPoint.mul(signObj.s)

// Calculate (-e) * G
val negE = e.negate().mod(n)
val negEG = g.mul(negE)

// Calculate Q = r^-1 * (s * R + (-e) * G)
val q = sR.add(negEG).mul(rInv).normalize()

// Convert the recovered public key to uncompressed format
val encodedPublicKey = q.encoded(false)
return encodedPublicKey.sliceArray(1 until encodedPublicKey.size) // Use `true` for compressed format if desired
}

fun bytesToBigInteger(bb: ByteArray): BigInteger {
return if (bb.isEmpty()) BigInteger.ZERO else BigInteger(1, bb)
}


fun extractRSV(signature: ByteArray): Pair<BigInteger, BigInteger> {

val decodedSignature = ECDSASignature.decode(BTCSignature.fromByteArray(signature))

return Pair(decodedSignature.r, decodedSignature.s)
}
}
}

8 changes: 7 additions & 1 deletion sdk/src/main/java/network/xyo/client/account/Wallet.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package network.xyo.client.account

import android.os.Build
import androidx.annotation.RequiresApi
import network.xyo.client.account.model.PreviousHashStore
import network.xyo.client.account.model.WalletInstance
import network.xyo.client.account.model.WalletStatic
import network.xyo.client.address.XyoAccount
import tech.figure.hdwallet.bip32.ExtKey
import tech.figure.hdwallet.bip32.toRootKey
import tech.figure.hdwallet.bip39.DeterministicSeed
import tech.figure.hdwallet.bip39.MnemonicWords
import tech.figure.hdwallet.ec.extensions.toBytesPadded

open class Wallet(private val _extKey: ExtKey, previousHash: ByteArray? = null) : Account(_extKey.keyPair.privateKey, previousHash), WalletInstance {
@RequiresApi(Build.VERSION_CODES.M)
open class Wallet(private val _extKey: ExtKey, previousHash: ByteArray? = null):
Account(_extKey.keyPair.privateKey.key.toBytesPadded(32), previousHash), WalletInstance {

override fun derivePath(path: String): WalletInstance {
return Wallet(_extKey.childKey(path))
Expand Down
123 changes: 123 additions & 0 deletions sdk/src/main/java/network/xyo/client/account/recoverAddress.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package network.xyo.client.account

import android.os.Build
import androidx.annotation.RequiresApi
import org.spongycastle.jcajce.provider.digest.Keccak
import java.math.BigInteger
import kotlin.experimental.xor

// Secp256k1 curve parameters
val P = BigInteger("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16)
val A = BigInteger.ZERO
val B = BigInteger.valueOf(7)
val N = BigInteger("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
val Gx = BigInteger("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16)
val Gy = BigInteger("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16)

// Elliptic curve point structure
data class Point(val x: BigInteger, val y: BigInteger) {
fun isAtInfinity() = x == BigInteger.ZERO && y == BigInteger.ZERO
}

fun ByteArray.padStart(targetLength: Int, padValue: Byte = 0): ByteArray {
if (this.size >= targetLength) return this
val padding = ByteArray(targetLength - this.size) { padValue }
return padding + this
}

// Double and Add algorithm for point multiplication
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun pointMultiply(k: BigInteger, point: Point): Point {
var result = Point(BigInteger.ZERO, BigInteger.ZERO) // Infinity point
var addend = point
var scalar = k

while (scalar > BigInteger.ZERO) {
if (scalar.and(BigInteger.ONE) == BigInteger.ONE) {
result = pointAdd(result, addend)
}
addend = pointDouble(addend)
scalar = scalar.shiftRight(1)
}

return result
}

// Point addition on the elliptic curve
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun pointAdd(p: Point, q: Point): Point {
if (p.isAtInfinity()) return q
if (q.isAtInfinity()) return p

val slope = if (p.x == q.x) {
if ((p.y + q.y).mod(P) == BigInteger.ZERO) return Point(BigInteger.ZERO, BigInteger.ZERO)
(BigInteger.valueOf(3) * p.x.pow(2) + A).mod(P) * (BigInteger.TWO * p.y).modInverse(P)
} else {
(q.y - p.y).mod(P) * (q.x - p.x).modInverse(P)
}.mod(P)

val xR = (slope.pow(2) - p.x - q.x).mod(P)
val yR = (slope * (p.x - xR) - p.y).mod(P)

return Point(xR, yR)
}

// Point doubling on the elliptic curve
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun pointDouble(p: Point): Point = pointAdd(p, p)

// Recover public key from signature
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun recoverPublicKey(messageHash: BigInteger, r: BigInteger, s: BigInteger, v: Int): Point? {
val isYEven = (v % 2 == 0)
val x = r.add(N.multiply(BigInteger.valueOf((v / 2).toLong())))

if (x >= P) return null

// Calculate y-coordinate
val alpha = (x.pow(3) + A * x + B).mod(P)
val beta = alpha.modPow((P + BigInteger.ONE).shiftRight(2), P)
val y = if (beta.testBit(0) == isYEven) beta else P - beta

val rPoint = Point(x, y)

// Calculate e and r^-1
val e = messageHash
val rInv = r.modInverse(N)

// Public key Q = r^-1 * (s * R - e * G)
val sR = pointMultiply(s, rPoint)
val eG = pointMultiply(e, Point(Gx, Gy))

return pointMultiply(rInv, pointAdd(sR, Point(eG.x, P - eG.y)))
}

private fun toKeccak(bytes: ByteArray): ByteArray {
val keccak = Keccak.Digest256()
keccak.update(bytes)
return keccak.digest()
}

// Convert public key to Ethereum address
fun publicKeyToAddress(publicKey: ByteArray): ByteArray {
val hash = toKeccak(publicKey.sliceArray(1 until publicKey.size))
return hash.sliceArray(12 until hash.size)
}

// Main function to recover address
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun recoverPublicKey(messageHash: ByteArray, signature: ByteArray): ByteArray? {
if (signature.size != 64) return null

val r = BigInteger(1, signature.copyOfRange(0, 32))
val s = BigInteger(1, signature.copyOfRange(32, 64))
val v = 0

val messageHashBI = BigInteger(1, messageHash)

val publicPoint = recoverPublicKey(messageHashBI, r, s, v) ?: return null
val uncompressedKey =
publicPoint.x.toByteArray().padStart(32, 0) +
publicPoint.y.toByteArray().padStart(32, 0)
return uncompressedKey
}

0 comments on commit d15400c

Please sign in to comment.