diff --git a/src/BIP44CoinTypeNode.test.ts b/src/BIP44CoinTypeNode.test.ts index cf1a2aed..66b5e302 100644 --- a/src/BIP44CoinTypeNode.test.ts +++ b/src/BIP44CoinTypeNode.test.ts @@ -113,9 +113,7 @@ describe('BIP44CoinTypeNode', () => { for (const input of inputs) { await expect( BIP44CoinTypeNode.fromJSON(input as any, arbitraryCoinType), - ).rejects.toThrow( - 'Invalid value: Must be a non-zero 32-byte byte array.', - ); + ).rejects.toThrow('Invalid value: Must be a 32-byte byte array.'); } await expect( diff --git a/src/SLIP10Node.test.ts b/src/SLIP10Node.test.ts index 64318a7a..83eee137 100644 --- a/src/SLIP10Node.test.ts +++ b/src/SLIP10Node.test.ts @@ -70,6 +70,21 @@ describe('SLIP10Node', () => { expect(node.publicKeyBytes).toHaveLength(33); }); + it('initializes a new ed25519 node from a zero private key', async () => { + const node = await SLIP10Node.fromExtendedKey({ + privateKey: new Uint8Array(32).fill(0), + chainCode: new Uint8Array(32).fill(1), + depth: 0, + parentFingerprint: 0, + index: 0, + curve: 'ed25519', + }); + + expect(node.depth).toBe(0); + expect(node.privateKeyBytes).toStrictEqual(new Uint8Array(32).fill(0)); + expect(node.publicKeyBytes).toHaveLength(33); + }); + it('initializes a new node from a public key', async () => { const { publicKeyBytes, chainCodeBytes } = await deriveChildKey({ path: fixtures.local.mnemonic, @@ -263,6 +278,21 @@ describe('SLIP10Node', () => { 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', ); }); + + it('throws if the private key is zero for secp256k1', async () => { + await expect( + SLIP10Node.fromExtendedKey({ + privateKey: new Uint8Array(32).fill(0), + chainCode: new Uint8Array(32).fill(1), + depth: 0, + parentFingerprint: 0, + index: 0, + curve: 'secp256k1', + }), + ).rejects.toThrow( + 'Invalid private key: Value is not a valid secp256k1 private key.', + ); + }); }); describe('fromDerivationPath', () => { diff --git a/src/SLIP10Node.ts b/src/SLIP10Node.ts index 689808c2..25802cee 100644 --- a/src/SLIP10Node.ts +++ b/src/SLIP10Node.ts @@ -1,4 +1,4 @@ -import { bytesToHex } from '@metamask/utils'; +import { assert, bytesToHex } from '@metamask/utils'; import { BIP44CoinTypeNode } from './BIP44CoinTypeNode'; import { BIP44Node } from './BIP44Node'; @@ -12,6 +12,7 @@ import { deriveKeyFromPath } from './derivation'; import { publicKeyToEthAddress } from './derivers/bip32'; import { getBytes, + getBytesUnsafe, getFingerprint, isValidInteger, validateBIP32Index, @@ -162,8 +163,14 @@ export class SLIP10Node implements SLIP10NodeInterface { validateBIP32Index(index); validateParentFingerprint(parentFingerprint); + const curveObject = getCurveByName(curve); + if (privateKey) { - const privateKeyBytes = getBytes(privateKey, BYTES_KEY_LENGTH); + const privateKeyBytes = getBytesUnsafe(privateKey, BYTES_KEY_LENGTH); + assert( + curveObject.isValidPrivateKey(privateKeyBytes), + `Invalid private key: Value is not a valid ${curve} private key.`, + ); return new SLIP10Node({ depth, @@ -172,16 +179,13 @@ export class SLIP10Node implements SLIP10NodeInterface { index, chainCode: chainCodeBytes, privateKey: privateKeyBytes, - publicKey: await getCurveByName(curve).getPublicKey(privateKeyBytes), + publicKey: await curveObject.getPublicKey(privateKeyBytes), curve, }); } if (publicKey) { - const publicKeyBytes = getBytes( - publicKey, - getCurveByName(curve).publicKeyLength, - ); + const publicKeyBytes = getBytes(publicKey, curveObject.publicKeyLength); return new SLIP10Node({ depth, diff --git a/src/utils.test.ts b/src/utils.test.ts index e93ec41b..3bed2831 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -22,6 +22,7 @@ import { encodeBase58check, decodeBase58check, mnemonicPhraseToBytes, + getBytesUnsafe, } from './utils'; // Inputs used for testing non-negative integers @@ -308,6 +309,54 @@ describe('getBytes', () => { 'Invalid value: Must be a non-zero 1-byte byte array.', ); }); + + it('throws if the value is zero', () => { + expect(() => getBytes('0x00', 1)).toThrow( + 'Invalid value: Must be a non-zero 1-byte byte array.', + ); + + expect(() => getBytes(new Uint8Array(1).fill(0), 1)).toThrow( + 'Invalid value: Must be a non-zero 1-byte byte array.', + ); + }); + + it('throws if the value is not a Uint8Array or a hexadecimal string', () => { + expect(() => getBytes(1, 1)).toThrow( + 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', + ); + }); +}); + +describe('getBytesUnsafe', () => { + it('returns a Uint8Array for a hexadecimal string', () => { + expect(getBytesUnsafe('0x1234', 2)).toStrictEqual(hexStringToBytes('1234')); + expect(getBytesUnsafe('1234', 2)).toStrictEqual(hexStringToBytes('1234')); + expect(getBytesUnsafe('0000', 2)).toStrictEqual(hexStringToBytes('0000')); + }); + + it('returns the same Uint8Array if a Uint8Array is passed', () => { + const bytes = hexStringToBytes('1234'); + expect(getBytesUnsafe(bytes, 2)).toBe(bytes); + + const zeroBytes = hexStringToBytes('0000'); + expect(getBytesUnsafe(zeroBytes, 2)).toBe(zeroBytes); + }); + + it('throws if the length is invalid', () => { + expect(() => getBytesUnsafe('1234', 1)).toThrow( + 'Invalid value: Must be a 1-byte byte array.', + ); + + expect(() => getBytesUnsafe(hexStringToBytes('1234'), 1)).toThrow( + 'Invalid value: Must be a 1-byte byte array.', + ); + }); + + it('throws if the value is not a Uint8Array or a hexadecimal string', () => { + expect(() => getBytesUnsafe(1, 1)).toThrow( + 'Invalid value: Expected an instance of Uint8Array or hexadecimal string.', + ); + }); }); describe('encodeBase58Check', () => { diff --git a/src/utils.ts b/src/utils.ts index 4e513e6a..70b7b11b 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -275,6 +275,35 @@ export function getBytes(value: unknown, length: number): Uint8Array { ); } +/** + * Get a `Uint8Array` from a hexadecimal string or `Uint8Array`. Validates that + * the length of the `Uint8Array` matches the specified length. + * + * This function is "unsafe," in the sense that it does not validate that the + * `Uint8Array` is not empty (i.e., all bytes are zero). + * + * @param value - The value to convert to a `Uint8Array`. + * @param length - The length to validate the `Uint8Array` against. + * @returns The `Uint8Array` corresponding to the hexadecimal string. + */ +export function getBytesUnsafe(value: unknown, length: number): Uint8Array { + if (value instanceof Uint8Array) { + assert( + value.length === length, + `Invalid value: Must be a ${length}-byte byte array.`, + ); + return value; + } + + if (typeof value === 'string') { + return getBytesUnsafe(hexToBytes(value), length); + } + + throw new Error( + `Invalid value: Expected an instance of Uint8Array or hexadecimal string.`, + ); +} + /** * Validate that the specified `Uint8Array` is not empty and has the specified * length.