From f0a98a39917d688273488e9b71ff5c6159c7dc5b Mon Sep 17 00:00:00 2001 From: Maarten Zuidhoorn Date: Mon, 6 Mar 2023 12:37:07 +0100 Subject: [PATCH 1/3] Allow zero private key for ed25519 --- src/BIP44CoinTypeNode.test.ts | 4 +--- src/SLIP10Node.test.ts | 30 ++++++++++++++++++++++++++++++ src/SLIP10Node.ts | 18 +++++++++++------- src/utils.test.ts | 33 +++++++++++++++++++++++++++++++++ src/utils.ts | 29 +++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 10 deletions(-) diff --git a/src/BIP44CoinTypeNode.test.ts b/src/BIP44CoinTypeNode.test.ts index cf1a2aed..66b5e302 100644 --- a/src/BIP44CoinTypeNode.test.ts +++ b/src/BIP44CoinTypeNode.test.ts @@ -113,9 +113,7 @@ describe('BIP44CoinTypeNode', () => { for (const input of inputs) { await expect( BIP44CoinTypeNode.fromJSON(input as any, arbitraryCoinType), - ).rejects.toThrow( - 'Invalid value: Must be a non-zero 32-byte byte array.', - ); + ).rejects.toThrow('Invalid value: Must be a 32-byte byte array.'); } await expect( diff --git a/src/SLIP10Node.test.ts b/src/SLIP10Node.test.ts index 64318a7a..83eee137 100644 --- a/src/SLIP10Node.test.ts +++ b/src/SLIP10Node.test.ts @@ -70,6 +70,21 @@ describe('SLIP10Node', () => { expect(node.publicKeyBytes).toHaveLength(33); }); + it('initializes a new ed25519 node from a zero private key', async () => { + const node = await SLIP10Node.fromExtendedKey({ + privateKey: new Uint8Array(32).fill(0), + chainCode: new Uint8Array(32).fill(1), + depth: 0, + parentFingerprint: 0, + index: 0, + curve: 'ed25519', + }); + + expect(node.depth).toBe(0); + expect(node.privateKeyBytes).toStrictEqual(new Uint8Array(32).fill(0)); + expect(node.publicKeyBytes).toHaveLength(33); + }); + it('initializes a new node from a public key', async () => { const { publicKeyBytes, chainCodeBytes } = await deriveChildKey({ path: fixtures.local.mnemonic, @@ -263,6 +278,21 @@ describe('SLIP10Node', () => { 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', ); }); + + it('throws if the private key is zero for secp256k1', async () => { + await expect( + SLIP10Node.fromExtendedKey({ + privateKey: new Uint8Array(32).fill(0), + chainCode: new Uint8Array(32).fill(1), + depth: 0, + parentFingerprint: 0, + index: 0, + curve: 'secp256k1', + }), + ).rejects.toThrow( + 'Invalid private key: Value is not a valid secp256k1 private key.', + ); + }); }); describe('fromDerivationPath', () => { diff --git a/src/SLIP10Node.ts b/src/SLIP10Node.ts index 689808c2..25802cee 100644 --- a/src/SLIP10Node.ts +++ b/src/SLIP10Node.ts @@ -1,4 +1,4 @@ -import { bytesToHex } from '@metamask/utils'; +import { assert, bytesToHex } from '@metamask/utils'; import { BIP44CoinTypeNode } from './BIP44CoinTypeNode'; import { BIP44Node } from './BIP44Node'; @@ -12,6 +12,7 @@ import { deriveKeyFromPath } from './derivation'; import { publicKeyToEthAddress } from './derivers/bip32'; import { getBytes, + getBytesUnsafe, getFingerprint, isValidInteger, validateBIP32Index, @@ -162,8 +163,14 @@ export class SLIP10Node implements SLIP10NodeInterface { validateBIP32Index(index); validateParentFingerprint(parentFingerprint); + const curveObject = getCurveByName(curve); + if (privateKey) { - const privateKeyBytes = getBytes(privateKey, BYTES_KEY_LENGTH); + const privateKeyBytes = getBytesUnsafe(privateKey, BYTES_KEY_LENGTH); + assert( + curveObject.isValidPrivateKey(privateKeyBytes), + `Invalid private key: Value is not a valid ${curve} private key.`, + ); return new SLIP10Node({ depth, @@ -172,16 +179,13 @@ export class SLIP10Node implements SLIP10NodeInterface { index, chainCode: chainCodeBytes, privateKey: privateKeyBytes, - publicKey: await getCurveByName(curve).getPublicKey(privateKeyBytes), + publicKey: await curveObject.getPublicKey(privateKeyBytes), curve, }); } if (publicKey) { - const publicKeyBytes = getBytes( - publicKey, - getCurveByName(curve).publicKeyLength, - ); + const publicKeyBytes = getBytes(publicKey, curveObject.publicKeyLength); return new SLIP10Node({ depth, diff --git a/src/utils.test.ts b/src/utils.test.ts index e93ec41b..29af2af1 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -22,6 +22,7 @@ import { encodeBase58check, decodeBase58check, mnemonicPhraseToBytes, + getBytesUnsafe, } from './utils'; // Inputs used for testing non-negative integers @@ -308,6 +309,38 @@ describe('getBytes', () => { 'Invalid value: Must be a non-zero 1-byte byte array.', ); }); + + it('throws if the value is zero', () => { + expect(() => getBytes('0x00', 1)).toThrow( + 'Invalid value: Must be a non-zero 1-byte byte array.', + ); + + expect(() => getBytes(new Uint8Array(1).fill(0), 1)).toThrow( + 'Invalid value: Must be a non-zero 1-byte byte array.', + ); + }); +}); + +describe('getBytesUnsafe', () => { + it('returns a Uint8Array for a hexadecimal string', () => { + expect(getBytesUnsafe('0x1234', 2)).toStrictEqual(hexStringToBytes('1234')); + expect(getBytesUnsafe('1234', 2)).toStrictEqual(hexStringToBytes('1234')); + }); + + it('returns the same Uint8Array if a Uint8Array is passed', () => { + const bytes = hexStringToBytes('1234'); + expect(getBytesUnsafe(bytes, 2)).toBe(bytes); + }); + + it('throws if the length is invalid', () => { + expect(() => getBytesUnsafe('1234', 1)).toThrow( + 'Invalid value: Must be a 1-byte byte array.', + ); + + expect(() => getBytesUnsafe(hexStringToBytes('1234'), 1)).toThrow( + 'Invalid value: Must be a 1-byte byte array.', + ); + }); }); describe('encodeBase58Check', () => { diff --git a/src/utils.ts b/src/utils.ts index 4e513e6a..70b7b11b 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -275,6 +275,35 @@ export function getBytes(value: unknown, length: number): Uint8Array { ); } +/** + * Get a `Uint8Array` from a hexadecimal string or `Uint8Array`. Validates that + * the length of the `Uint8Array` matches the specified length. + * + * This function is "unsafe," in the sense that it does not validate that the + * `Uint8Array` is not empty (i.e., all bytes are zero). + * + * @param value - The value to convert to a `Uint8Array`. + * @param length - The length to validate the `Uint8Array` against. + * @returns The `Uint8Array` corresponding to the hexadecimal string. + */ +export function getBytesUnsafe(value: unknown, length: number): Uint8Array { + if (value instanceof Uint8Array) { + assert( + value.length === length, + `Invalid value: Must be a ${length}-byte byte array.`, + ); + return value; + } + + if (typeof value === 'string') { + return getBytesUnsafe(hexToBytes(value), length); + } + + throw new Error( + `Invalid value: Expected an instance of Uint8Array or hexadecimal string.`, + ); +} + /** * Validate that the specified `Uint8Array` is not empty and has the specified * length. From 591ecfeef220a45e10c190342a301961156309fb Mon Sep 17 00:00:00 2001 From: Maarten Zuidhoorn Date: Tue, 7 Mar 2023 10:34:51 +0100 Subject: [PATCH 2/3] Add more tests --- src/utils.test.ts | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/utils.test.ts b/src/utils.test.ts index 29af2af1..2dcee991 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -319,6 +319,12 @@ describe('getBytes', () => { 'Invalid value: Must be a non-zero 1-byte byte array.', ); }); + + it('throws if the value is not a Uint8Array or a hexadecimal string', () => { + expect(() => getBytes(1, 1)).toThrow( + 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', + ); + }); }); describe('getBytesUnsafe', () => { @@ -341,6 +347,12 @@ describe('getBytesUnsafe', () => { 'Invalid value: Must be a 1-byte byte array.', ); }); + + it('throws if the value is not a Uint8Array or a hexadecimal string', () => { + expect(() => getBytesUnsafe(1, 1)).toThrow( + 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', + ); + }); }); describe('encodeBase58Check', () => { From 46906e4d91bcdb552abe3433100d2cd59b28739c Mon Sep 17 00:00:00 2001 From: Maarten Zuidhoorn Date: Tue, 7 Mar 2023 10:38:19 +0100 Subject: [PATCH 3/3] Add more tests again --- src/utils.test.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils.test.ts b/src/utils.test.ts index 2dcee991..3bed2831 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -331,11 +331,15 @@ describe('getBytesUnsafe', () => { it('returns a Uint8Array for a hexadecimal string', () => { expect(getBytesUnsafe('0x1234', 2)).toStrictEqual(hexStringToBytes('1234')); expect(getBytesUnsafe('1234', 2)).toStrictEqual(hexStringToBytes('1234')); + expect(getBytesUnsafe('0000', 2)).toStrictEqual(hexStringToBytes('0000')); }); it('returns the same Uint8Array if a Uint8Array is passed', () => { const bytes = hexStringToBytes('1234'); expect(getBytesUnsafe(bytes, 2)).toBe(bytes); + + const zeroBytes = hexStringToBytes('0000'); + expect(getBytesUnsafe(zeroBytes, 2)).toBe(zeroBytes); }); it('throws if the length is invalid', () => {