Skip to content

Commit

Permalink
feat: check address strictly (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjchen7 authored May 24, 2022
1 parent 6023440 commit 922ffd1
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 11 deletions.
29 changes: 27 additions & 2 deletions address/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func ConvertPublicToAddress(mode Mode, publicKey string) (string, error) {
}

func Parse(address string) (*ParsedAddress, error) {
hrp, decoded, err := bech32.Decode(address)
encoding, hrp, decoded, err := bech32.Decode(address)
if err != nil {
return nil, err
}
Expand All @@ -181,14 +181,24 @@ func Parse(address string) (*ParsedAddress, error) {
var addressType Type
var script types.Script
if strings.HasPrefix(payload, "01") {
if encoding != bech32.BECH32 {
return nil, errors.New("payload header 0x01 should have encoding BECH32")
}
addressType = Short
if CodeHashIndexSingleSig == payload[2:4] {
if len(payload) != 44 {
return nil, errors.New("payload bytes length of secp256k1-sighash-all " +
"short address should be 22")
}
script = types.Script{
CodeHash: types.HexToHash(transaction.SECP256K1_BLAKE160_SIGHASH_ALL_TYPE_HASH),
HashType: types.HashTypeType,
Args: common.Hex2Bytes(payload[4:]),
}
} else if CodeHashIndexAnyoneCanPay == payload[2:4] {
if len(payload) < 44 || len(payload) > 48 {
return nil, errors.New("payload bytes length of acp short address should between 22-24")
}
script = types.Script{
HashType: types.HashTypeType,
Args: common.Hex2Bytes(payload[4:]),
Expand All @@ -198,28 +208,43 @@ func Parse(address string) (*ParsedAddress, error) {
} else {
script.CodeHash = types.HexToHash(utils.AnyoneCanPayCodeHashOnLina)
}
} else {
} else if CodeHashIndexMultisigSig == payload[2:4] {
if len(payload) != 44 {
return nil, errors.New("payload bytes length of secp256k1-multisig-all " +
"short address should be 22")
}
script = types.Script{
CodeHash: types.HexToHash(transaction.SECP256K1_BLAKE160_MULTISIG_ALL_TYPE_HASH),
HashType: types.HashTypeType,
Args: common.Hex2Bytes(payload[4:]),
}
} else {
return nil, errors.New("unknown code hash index " + payload[2:4])
}
} else if strings.HasPrefix(payload, "02") {
if encoding != bech32.BECH32 {
return nil, errors.New("payload header 0x02 should have encoding BECH32");
}
addressType = FullBech32
script = types.Script{
CodeHash: types.HexToHash(payload[2:66]),
HashType: types.HashTypeData,
Args: common.Hex2Bytes(payload[66:]),
}
} else if strings.HasPrefix(payload, "04") {
if encoding != bech32.BECH32 {
return nil, errors.New("payload header 0x04 should have encoding BECH32");
}
addressType = FullBech32
script = types.Script{
CodeHash: types.HexToHash(payload[2:66]),
HashType: types.HashTypeType,
Args: common.Hex2Bytes(payload[66:]),
}
} else if strings.HasPrefix(payload, "00") {
if encoding != bech32.BECH32M {
return nil, errors.New("payload header 0x00 should have encoding BECH32");
}
addressType = FullBech32m
script = types.Script{
CodeHash: types.HexToHash(payload[2:66]),
Expand Down
23 changes: 23 additions & 0 deletions address/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,26 @@ func TestConvertPublickeyToBech32mFullAddress(t *testing.T) {
}
assert.Equal(t, address, "ckb1qzda0cr08m85hc8jlnfp3zer7xulejywt49kt2rr0vthywaa50xwsqdnnw7qkdnnclfkg59uzn8umtfd2kwxceqxwquc4")
}

func TestInvalidAddressParse(t *testing.T) {
// These invalid address come form https://github.com/nervosnetwork/ckb-sdk-rust/pull/7/files
// INVALID bech32 encoding
_, err := Parse("ckb1qyqylv479ewscx3ms620sv34pgeuz6zagaaqh0knz7")
assert.NotNil(t, err)
// INVALID data length
_, err = Parse("ckb1qyqylv479ewscx3ms620sv34pgeuz6zagaarxdzvx03")
assert.NotNil(t, err)
// INVALID code hash index
_, err = Parse("ckb1qyg5lv479ewscx3ms620sv34pgeuz6zagaaqajch0c")
assert.NotNil(t, err)
// INVALID bech32m encoding
_, err = Parse("ckb1q2da0cr08m85hc8jlnfp3zer7xulejywt49kt2rr0vthywaa50xwsnajhch96rq68wrqn2tmhm")
assert.NotNil(t, err)
// Invalid ckb2021 format full address
_, err = Parse("ckb1qzda0cr08m85hc8jlnfp3zer7xulejywt49kt2rr0vthywaa50xwsq20k2lzuhgvrgacv4tmr88")
assert.NotNil(t, err)
_, err = Parse("ckb1qzda0cr08m85hc8jlnfp3zer7xulejywt49kt2rr0vthywaa50xwsqz0k2lzuhgvrgacvhcym08")
assert.NotNil(t, err)
_, err = Parse("ckb1qzda0cr08m85hc8jlnfp3zer7xulejywt49kt2rr0vthywaa50xwsqj0k2lzuhgvrgacvnhnzl8")
assert.NotNil(t, err)
}
24 changes: 16 additions & 8 deletions crypto/bech32/bech32.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,38 @@ const charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
var gen = []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}

const BECH32M_CONST = 0x2bc830a3
type Encoding uint
const (
BECH32 Encoding = iota
BECH32M
)

func Decode(bech string) (string, []byte, error) {
func Decode(bech string) (Encoding, string, []byte, error) {
for i := 0; i < len(bech); i++ {
if bech[i] < 33 || bech[i] > 126 {
return "", nil, fmt.Errorf("invalid character: '%c'", bech[i])
return BECH32, "", nil, fmt.Errorf("invalid character: '%c'", bech[i])
}
}

lower := strings.ToLower(bech)
upper := strings.ToUpper(bech)
if bech != lower && bech != upper {
return "", nil, errors.New("string not all lowercase or all uppercase")
return BECH32, "", nil, errors.New("string not all lowercase or all uppercase")
}

bech = lower

one := strings.LastIndexByte(bech, '1')
if one < 1 || one+7 > len(bech) {
return "", nil, fmt.Errorf("invalid index of 1")
return BECH32, "", nil, fmt.Errorf("invalid index of 1")
}

hrp := bech[:one]
data := bech[one+1:]

decoded, err := toBytes(data)
if err != nil {
return "", nil, errors.New(fmt.Sprintf("failed converting data to bytes: %v", err))
return BECH32, "", nil, errors.New(fmt.Sprintf("failed converting data to bytes: %v", err))
}

ints := make([]int, len(decoded))
Expand All @@ -48,7 +53,9 @@ func Decode(bech string) (string, []byte, error) {
polymod := append(bech32HrpExpand(hrp), ints...)
i := bech32Polymod(polymod)

var encoding Encoding
if i == 1 {
encoding = BECH32
if !bech32VerifyChecksum(hrp, decoded) {
moreInfo := ""
checksum := bech[len(bech)-6:]
Expand All @@ -57,9 +64,10 @@ func Decode(bech string) (string, []byte, error) {
if err == nil {
moreInfo = fmt.Sprintf("Expected %v, got %v.", expected, checksum)
}
return "", nil, errors.New("checksum failed. " + moreInfo)
return BECH32, "", nil, errors.New("checksum failed. " + moreInfo)
}
} else {
encoding = BECH32M
if !bech32VerifyChecksumWithBech32m(hrp, decoded) {
moreInfo := ""
checksum := bech[len(bech)-6:]
Expand All @@ -68,11 +76,11 @@ func Decode(bech string) (string, []byte, error) {
if err == nil {
moreInfo = fmt.Sprintf("Expected %v, got %v.", expected, checksum)
}
return "", nil, errors.New("checksum failed. " + moreInfo)
return BECH32M, "", nil, errors.New("checksum failed. " + moreInfo)
}
}

return hrp, decoded[:len(decoded)-6], nil
return encoding, hrp, decoded[:len(decoded)-6], nil
}

func Encode(hrp string, data []byte) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion crypto/bech32/bech32_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ func TestEncode(t *testing.T) {
}

func TestDecode(t *testing.T) {
hrp, decoded, err := Decode("ckb1qyqt705jmfy3r7jlvg88k87j0sksmhgduazqrr2qt2")
encoding, hrp, decoded, err := Decode("ckb1qyqt705jmfy3r7jlvg88k87j0sksmhgduazqrr2qt2")
if err != nil {
assert.Error(t, err)
}
assert.Equal(t, BECH32, encoding)
assert.Equal(t, "ckb", hrp)
assert.Equal(t, "0004000b1e0f14121b090411031e121f0c08070716071e120f1016101b17080d1c1d0200", hex.EncodeToString(decoded))
}

0 comments on commit 922ffd1

Please sign in to comment.