diff --git a/src/java.base/share/classes/sun/security/util/BitArray.java b/src/java.base/share/classes/sun/security/util/BitArray.java index cb788bec7ca..78aa26317df 100644 --- a/src/java.base/share/classes/sun/security/util/BitArray.java +++ b/src/java.base/share/classes/sun/security/util/BitArray.java @@ -63,22 +63,32 @@ public BitArray(int length) throws IllegalArgumentException { repn = new byte[(length + BITS_PER_UNIT - 1)/BITS_PER_UNIT]; } - /** * Creates a BitArray of the specified size, initialized from the - * specified byte array. The most significant bit of {@code a[0]} gets - * index zero in the BitArray. The array a must be large enough - * to specify a value for every bit in the BitArray. In other words, - * {@code 8*a.length <= length}. + * specified byte array. The most significant bit of {@code a[0]} gets + * index zero in the BitArray. The array must be large enough to specify + * a value for every bit of the BitArray. i.e. {@code 8*a.length <= length}. */ public BitArray(int length, byte[] a) throws IllegalArgumentException { + this(length, a, 0); + } + + /** + * Creates a BitArray of the specified size, initialized from the + * specified byte array starting at the specified offset. The most + * significant bit of {@code a[ofs]} gets index zero in the BitArray. + * The array must be large enough to specify a value for every bit of + * the BitArray, i.e. {@code 8*(a.length - ofs) <= length}. + */ + public BitArray(int length, byte[] a, int ofs) + throws IllegalArgumentException { if (length < 0) { throw new IllegalArgumentException("Negative length for BitArray"); } - if (a.length * BITS_PER_UNIT < length) { - throw new IllegalArgumentException("Byte array too short to represent " + - "bit array of given length"); + if ((a.length - ofs) * BITS_PER_UNIT < length) { + throw new IllegalArgumentException + ("Byte array too short to represent " + length + "-bit array"); } this.length = length; @@ -93,7 +103,7 @@ public BitArray(int length, byte[] a) throws IllegalArgumentException { 2. zero out extra bits in the last byte */ repn = new byte[repLength]; - System.arraycopy(a, 0, repn, 0, repLength); + System.arraycopy(a, ofs, repn, 0, repLength); if (repLength > 0) { repn[repLength - 1] &= bitMask; } @@ -270,7 +280,7 @@ public String toString() { public BitArray truncate() { for (int i=length-1; i>=0; i--) { if (get(i)) { - return new BitArray(i+1, Arrays.copyOf(repn, (i + BITS_PER_UNIT)/BITS_PER_UNIT)); + return new BitArray(i+1, repn, 0); } } return new BitArray(1); diff --git a/src/java.base/share/classes/sun/security/util/DerValue.java b/src/java.base/share/classes/sun/security/util/DerValue.java index 2739515d039..14be5ad44cd 100644 --- a/src/java.base/share/classes/sun/security/util/DerValue.java +++ b/src/java.base/share/classes/sun/security/util/DerValue.java @@ -689,6 +689,28 @@ public String getAsString() throws IOException { }; } + // check the number of pad bits, validate the pad bits in the bytes + // if enforcing DER (i.e. allowBER == false), and return the number of + // bits of the resulting BitString + private static int checkPaddedBits(int numOfPadBits, byte[] data, + int start, int end, boolean allowBER) throws IOException { + // number of pad bits should be from 0(min) to 7(max). + if ((numOfPadBits < 0) || (numOfPadBits > 7)) { + throw new IOException("Invalid number of padding bits"); + } + int lenInBits = ((end - start) << 3) - numOfPadBits; + if (lenInBits < 0) { + throw new IOException("Not enough bytes in BitString"); + } + + // padding bits should be all zeros for DER + if (!allowBER && numOfPadBits != 0 && + (data[end - 1] & (0xff >>> (8 - numOfPadBits))) != 0) { + throw new IOException("Invalid value of padding bits"); + } + return lenInBits; + } + /** * Returns an ASN.1 BIT STRING value, with the tag assumed implicit * based on the parameter. The bit string must be byte-aligned. @@ -705,18 +727,17 @@ public byte[] getBitString(boolean tagImplicit) throws IOException { } if (end == start) { throw new IOException("Invalid encoding: zero length bit string"); + } + data.pos = data.end; // Compatibility. Reach end. + int numOfPadBits = buffer[start]; - if ((numOfPadBits < 0) || (numOfPadBits > 7)) { - throw new IOException("Invalid number of padding bits"); - } - // minus the first byte which indicates the number of padding bits + checkPaddedBits(numOfPadBits, buffer, start + 1, end, allowBER); byte[] retval = Arrays.copyOfRange(buffer, start + 1, end); - if (numOfPadBits != 0) { - // get rid of the padding bits - retval[end - start - 2] &= (0xff << numOfPadBits); + if (allowBER && numOfPadBits != 0) { + // fix the potential non-zero padding bits + retval[retval.length - 1] &= (0xff << numOfPadBits); } - data.pos = data.end; // Compatibility. Reach end. return retval; } @@ -739,16 +760,11 @@ public BitArray getUnalignedBitString(boolean tagImplicit) throw new IOException("Invalid encoding: zero length bit string"); } data.pos = data.end; // Compatibility. Reach end. + int numOfPadBits = buffer[start]; - if ((numOfPadBits < 0) || (numOfPadBits > 7)) { - throw new IOException("Invalid number of padding bits"); - } - if (end == start + 1) { - return new BitArray(0); - } else { - return new BitArray(((end - start - 1) << 3) - numOfPadBits, - Arrays.copyOfRange(buffer, start + 1, end)); - } + int len = checkPaddedBits(numOfPadBits, buffer, start + 1, end, + allowBER); + return new BitArray(len, buffer, start + 1); } /** diff --git a/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java b/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java index 2e31159c25e..2e1f45e13c9 100644 --- a/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java +++ b/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2002, 2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,52 +26,82 @@ * @bug 4511556 * @summary Verify BitString value containing padding bits is accepted. * @modules java.base/sun.security.util + * @library /test/lib */ - import java.io.*; -import java.util.Arrays; import java.math.BigInteger; +import java.util.Arrays; +import java.util.HexFormat; +import jdk.test.lib.Asserts; +import jdk.test.lib.Utils; +import sun.security.util.BitArray; import sun.security.util.DerInputStream; public class PaddedBitString { // Relaxed the BitString parsing routine to accept bit strings - // with padding bits, ex. treat DER_BITSTRING_PAD6 as the same - // bit string as DER_BITSTRING_NOPAD. + // with padding bits, ex. treat DER_BITSTRING_PAD6_b as the same + // bit string as DER_BITSTRING_PAD6_0/DER_BITSTRING_NOPAD. // Note: // 1. the number of padding bits has to be in [0...7] // 2. value of the padding bits is ignored - // bit string (01011101 11000000) - // With 6 padding bits (01011101 11001011) - private final static byte[] DER_BITSTRING_PAD6 = { 3, 3, 6, - (byte)0x5d, (byte)0xcb }; - // With no padding bits private final static byte[] DER_BITSTRING_NOPAD = { 3, 3, 0, (byte)0x5d, (byte)0xc0 }; + // With 6 zero padding bits (01011101 11000000) + private final static byte[] DER_BITSTRING_PAD6_0 = { 3, 3, 6, + (byte)0x5d, (byte)0xc0 }; - public static void main(String args[]) throws Exception { - byte[] ba0, ba1; - try { - DerInputStream derin = new DerInputStream(DER_BITSTRING_PAD6); - ba1 = derin.getBitString(); - } catch( IOException e ) { - e.printStackTrace(); - throw new Exception("Unable to parse BitString with 6 padding bits"); - } + // With 6 nonzero padding bits (01011101 11001011) + private final static byte[] DER_BITSTRING_PAD6_b = { 3, 3, 6, + (byte)0x5d, (byte)0xcb }; - try { - DerInputStream derin = new DerInputStream(DER_BITSTRING_NOPAD); - ba0 = derin.getBitString(); - } catch( IOException e ) { - e.printStackTrace(); - throw new Exception("Unable to parse BitString with no padding"); - } + // With 8 padding bits + private final static byte[] DER_BITSTRING_PAD8_0 = { 3, 3, 8, + (byte)0x5d, (byte)0xc0 }; + + private final static byte[] BITS = { (byte)0x5d, (byte)0xc0 }; + + static enum Type { + BIT_STRING, + UNALIGNED_BIT_STRING; + } + + public static void main(String args[]) throws Exception { + test(DER_BITSTRING_NOPAD, new BitArray(16, BITS)); + test(DER_BITSTRING_PAD6_0, new BitArray(10, BITS)); + test(DER_BITSTRING_PAD6_b, new BitArray(10, BITS)); + test(DER_BITSTRING_PAD8_0, null); + System.out.println("Tests Passed"); + } - if (Arrays.equals(ba1, ba0) == false ) { - throw new Exception("BitString comparison check failed"); + private static void test(byte[] in, BitArray ans) throws IOException { + System.out.println("Testing " + + HexFormat.of().withUpperCase().formatHex(in)); + for (Type t : Type.values()) { + DerInputStream derin = new DerInputStream(in); + boolean shouldPass = (ans != null); + switch (t) { + case BIT_STRING: + if (shouldPass) { + Asserts.assertTrue(Arrays.equals(ans.toByteArray(), + derin.getBitString())); + } else { + Utils.runAndCheckException(() -> derin.getBitString(), + IOException.class); + } + break; + case UNALIGNED_BIT_STRING: + if (shouldPass) { + Asserts.assertEQ(ans, derin.getUnalignedBitString()); + } else { + Utils.runAndCheckException(() -> + derin.getUnalignedBitString(), IOException.class); + } + break; + } } } }