diff --git a/src/asn1_type.rs b/src/asn1_type.rs index b79f208..eb43a74 100644 --- a/src/asn1_type.rs +++ b/src/asn1_type.rs @@ -21,14 +21,14 @@ fn clone_asn1_schema_obj<'py>(asn1_schema_obj: &'py PyAny, args: &PyTuple, kwarg } pub trait Decoder<'py> { - fn verify_raw(self: &Self) -> Option { - None + fn verify_raw(self: &Self) -> PyResult<()> { + Ok(()) } fn decode(self: &Self) -> PyResult<&'py PyAny>; - fn verify_decoded(self: &Self, _asn1_value: &PyAny) -> Option { - None + fn verify_decoded(self: &Self, _asn1_value: &PyAny) -> PyResult<()> { + Ok(()) } } @@ -44,21 +44,21 @@ impl<'py> BooleanDecoder<'py> { } impl<'py> Decoder<'py> for BooleanDecoder<'py> { - fn verify_raw(&self) -> Option { + fn verify_raw(&self) -> PyResult<()> { if self.step.tag().format() != tag::FORMAT_SIMPLE { - return Some(self.step.create_error("Invalid BOOLEAN value format")); + return Err(self.step.create_error("Invalid BOOLEAN value format")); } match self.step.value_substrate_len() { 1 => { if self.step.value_substrate()[0] != 0 && self.step.value_substrate()[0] != 0xFF { - Some(self.step.create_error("Non-canonical BOOLEAN encoding")) + Err(self.step.create_error("Non-canonical BOOLEAN encoding")) } else { - None + Ok(()) } } - l => Some(self.step.create_error(&format!("Invalid BOOLEAN value length of {} octets", l))) + l => Err(self.step.create_error(&format!("Invalid BOOLEAN value length of {} octets", l))) } } @@ -89,23 +89,23 @@ impl<'py> IntegerDecoder<'py> { } impl<'py> Decoder<'py> for IntegerDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { if self.step.tag().format() != tag::FORMAT_SIMPLE { - return Some(self.step.create_error(&format!("Invalid {} value format", self.type_name))); + return Err(self.step.create_error(&format!("Invalid {} value format", self.type_name))); } let value_substrate = self.step.value_substrate(); if value_substrate.len() == 0 { - return Some(self.step.create_error(&format!("Substrate under-run in {} value", self.type_name))) + return Err(self.step.create_error(&format!("Substrate under-run in {} value", self.type_name))) } else if value_substrate.len() >= 2 { if (value_substrate[0] == 0 && value_substrate[1] & 0x80 == 0) || (value_substrate[0] == 0xFF && value_substrate[1] & 0x80 != 0) { - return Some(self.step.create_error(&format!("Non-minimal {} encoding", self.type_name))) + return Err(self.step.create_error(&format!("Non-minimal {} encoding", self.type_name))) } } - None + Ok(()) } fn decode(self: &Self) -> PyResult<&'py PyAny> { @@ -126,22 +126,22 @@ impl<'py> BitStringDecoder<'py> { Self { step } } - fn check_named_bit_string(self: &Self, trailer_bit_count: u8, last_octet: u8) -> Option { + fn check_named_bit_string(self: &Self, trailer_bit_count: u8, last_octet: u8) -> PyResult<()> { let last_value_bit_mask = 1 << trailer_bit_count; if last_value_bit_mask & last_octet == 0 { - Some(self.step.create_error("Trailing zero bit in named BIT STRING")) + Err(self.step.create_error("Trailing zero bit in named BIT STRING")) } else { - None + Ok(()) } } } impl<'py> Decoder<'py> for BitStringDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { if self.step.tag().format() != tag::FORMAT_SIMPLE { - return Some(self.step.create_error("Invalid BIT STRING value format")); + return Err(self.step.create_error("Invalid BIT STRING value format")); } let value_substrate = self.step.value_substrate(); @@ -149,20 +149,20 @@ impl<'py> Decoder<'py> for BitStringDecoder<'py> { let value_substrate_len = value_substrate.len(); if value_substrate_len == 0 { - return Some(self.step.create_error("Substrate under-run in BIT STRING")); + return Err(self.step.create_error("Substrate under-run in BIT STRING")); } let trailer_bit_count = value_substrate[0]; if trailer_bit_count > 7 || (value_substrate_len == 1 && trailer_bit_count != 0) { - return Some(self.step.create_error(&format!("Invalid trailer length of {} bits in BIT STRING", trailer_bit_count))); + return Err(self.step.create_error(&format!("Invalid trailer length of {} bits in BIT STRING", trailer_bit_count))); } if value_substrate_len >= 2 { let trailer_bits = value_substrate[value_substrate_len - 1] & ((1 << trailer_bit_count) - 1); if trailer_bits != 0 { - return Some(self.step.create_error("Non-zero trailer value in BIT STRING")); + return Err(self.step.create_error("Non-zero trailer value in BIT STRING")); } if self.step.asn1_spec().getattr(intern![self.step.asn1_spec().py(), "namedValues"]).unwrap().is_true().unwrap() { @@ -172,7 +172,7 @@ impl<'py> Decoder<'py> for BitStringDecoder<'py> { } } - None + Ok(()) } fn decode(self: &Self) -> PyResult<&'py PyAny> { @@ -204,10 +204,10 @@ impl<'py> OctetStringDecoder<'py> { impl<'py> Decoder<'py> for OctetStringDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_SIMPLE => None, - _ => Some(self.step.create_error("Invalid OCTET STRING value format")) + tag::FORMAT_SIMPLE => Ok(()), + _ => Err(self.step.create_error("Invalid OCTET STRING value format")) } } @@ -230,14 +230,14 @@ impl<'py> NullDecoder<'py> { } impl<'py> Decoder<'py> for NullDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { if self.step.tag().format() != tag::FORMAT_SIMPLE { - return Some(self.step.create_error("Invalid NULL value format")) + return Err(self.step.create_error("Invalid NULL value format")) } match self.step.value_substrate_len() { - 0 => None, - _ => Some(self.step.create_error("Invalid NULL value length")) + 0 => Ok(()), + _ => Err(self.step.create_error("Invalid NULL value length")) } } @@ -260,10 +260,10 @@ impl<'py> ObjectIdentifierDecoder<'py> { } impl<'py> Decoder<'py> for ObjectIdentifierDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_SIMPLE => None, - _ => Some(self.step.create_error("Invalid OBJECT IDENTIFIER value format")) + tag::FORMAT_SIMPLE => Ok(()), + _ => Err(self.step.create_error("Invalid OBJECT IDENTIFIER value format")) } } @@ -295,10 +295,10 @@ impl<'py> CharacterStringDecoder<'py> { } impl<'py> Decoder<'py> for CharacterStringDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_SIMPLE => None, - _ => Some(self.step.create_error(&format!("Invalid {} value format", self.type_name))) + tag::FORMAT_SIMPLE => Ok(()), + _ => Err(self.step.create_error(&format!("Invalid {} value format", self.type_name))) } } @@ -320,15 +320,15 @@ impl<'py> PrintableStringDecoder<'py> { } impl<'py> Decoder<'py> for PrintableStringDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { tag::FORMAT_SIMPLE => (), - _ => return Some(self.step.create_error("Invalid PRINTABLESTRING value format")) + _ => return Err(self.step.create_error("Invalid PRINTABLESTRING value format")) }; match PrintableStringRef::new(self.step.value_substrate()) { - Ok(_) => None, - Err(e) => Some(self.step.create_error(&format!("Error decoding PRINTABLESTRING: {}", e.to_string()))) + Ok(_) => Ok(()), + Err(e) => Err(self.step.create_error(&format!("Error decoding PRINTABLESTRING: {}", e.to_string()))) } } @@ -340,8 +340,33 @@ impl<'py> Decoder<'py> for PrintableStringDecoder<'py> { } +fn check_consistency(step: DecodeStep, asn1_value: &PyAny) -> PyResult<()> { + let py = asn1_value.py(); + + match asn1_value.getattr(intern![py, "isInconsistent"]) { + Ok(o) => { + if o.is_true().unwrap() { + Err(step.create_error(&o.to_string())) + } + else { + Ok(()) + } + }, + Err(e) => Err(e) + } + +} + + +fn get_constructed_set_component_kwargs(m: NativeHelperModule) -> &PyDict { + m.module.getattr(intern![m.module.py(), CONSTRUCTED_SET_COMPONENT_KWARGS]).unwrap().downcast_exact().unwrap() +} +fn get_choice_set_component_kwargs(m: NativeHelperModule) -> &PyDict { + m.module.getattr(intern![m.module.py(), CHOICE_SET_COMPONENT_KWARGS]).unwrap().downcast_exact().unwrap() +} + pub struct SequenceDecoder<'py> { step: DecodeStep<'py> @@ -384,21 +409,11 @@ impl<'py> SequenceDecoder<'py> { } -fn get_constructed_set_component_kwargs(m: NativeHelperModule) -> &PyDict { - m.module.getattr(intern![m.module.py(), CONSTRUCTED_SET_COMPONENT_KWARGS]).unwrap().downcast_exact().unwrap() -} - - -fn get_choice_set_component_kwargs(m: NativeHelperModule) -> &PyDict { - m.module.getattr(intern![m.module.py(), CHOICE_SET_COMPONENT_KWARGS]).unwrap().downcast_exact().unwrap() -} - - impl<'py> Decoder<'py> for SequenceDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_CONSTRUCTED => None, - _ => return Some(self.step.create_error("Invalid SEQUENCE value format")) + tag::FORMAT_CONSTRUCTED => Ok(()), + _ => return Err(self.step.create_error("Invalid SEQUENCE value format")) } } @@ -465,13 +480,8 @@ impl<'py> Decoder<'py> for SequenceDecoder<'py> { } } - fn verify_decoded(self: &Self, asn1_value: &PyAny) -> Option { - let py = asn1_value.py(); - - match asn1_value.getattr(intern![py, "isInconsistent"]) { - Ok(_) => None, - Err(e) => Some(e) - } + fn verify_decoded(self: &Self, asn1_value: &PyAny) -> PyResult<()> { + check_consistency(self.step, asn1_value) } } @@ -486,10 +496,10 @@ impl<'py> SequenceOfDecoder<'py> { } impl<'py> Decoder<'py> for SequenceOfDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_CONSTRUCTED => None, - _ => return Some(self.step.create_error("Invalid SEQUENCE value format")) + tag::FORMAT_CONSTRUCTED => Ok(()), + _ => return Err(self.step.create_error("Invalid SEQUENCE value format")) } } @@ -526,13 +536,8 @@ impl<'py> Decoder<'py> for SequenceOfDecoder<'py> { Ok(asn1_object) } - fn verify_decoded(self: &Self, asn1_value: &PyAny) -> Option { - let py = asn1_value.py(); - - match asn1_value.getattr(intern![py, "isInconsistent"]) { - Ok(_) => None, - Err(e) => Some(e) - } + fn verify_decoded(self: &Self, asn1_value: &PyAny) -> PyResult<()> { + check_consistency(self.step, asn1_value) } } @@ -549,10 +554,10 @@ impl<'py> SetOfDecoder<'py> { impl<'py> Decoder<'py> for SetOfDecoder<'py> { - fn verify_raw(self: &Self) -> Option { + fn verify_raw(self: &Self) -> PyResult<()> { match self.step.tag().format() { - tag::FORMAT_CONSTRUCTED => None, - _ => return Some(self.step.create_error("Invalid SET value format")) + tag::FORMAT_CONSTRUCTED => Ok(()), + _ => return Err(self.step.create_error("Invalid SET value format")) } } @@ -601,13 +606,8 @@ impl<'py> Decoder<'py> for SetOfDecoder<'py> { Ok(asn1_object) } - fn verify_decoded(self: &Self, asn1_value: &PyAny) -> Option { - let py = asn1_value.py(); - - match asn1_value.getattr(intern![py, "isInconsistent"]) { - Ok(_) => None, - Err(e) => Some(e) - } + fn verify_decoded(self: &Self, asn1_value: &PyAny) -> PyResult<()> { + check_consistency(self.step, asn1_value) } } diff --git a/src/decoder.rs b/src/decoder.rs index 5dbda6e..9a33943 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -158,8 +158,8 @@ pub fn decode_asn1_spec_value(step: DecodeStep) -> PyResult<&PyAny> { }; match decoder.verify_raw() { - Some(e) => return Err(e), - None => () + Err(e) => return Err(e), + Ok(()) => () }; let decoded_result = decoder.decode(); @@ -168,8 +168,8 @@ pub fn decode_asn1_spec_value(step: DecodeStep) -> PyResult<&PyAny> { Err(e) => Err(e), Ok(decoded) => { match decoder.verify_decoded(decoded) { - None => Ok(decoded), - Some(e) => Err(e) + Err(e) => Err(e), + Ok(()) => Ok(decoded) } } } diff --git a/test.py b/test.py index 53e0ccf..4e12c10 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,6 @@ import binascii from pyasn1.type.univ import ObjectIdentifier, BitString, Integer, OctetString -from pyasn1.type import univ, char, useful +from pyasn1.type import univ, char, useful, constraint from pyasn1.type.char import PrintableString, UTF8String from pyasn1_fasder import decode_der from pyasn1_alt_modules import rfc5280 @@ -52,6 +52,17 @@ -----END CERTIFICATE----- ''' +class SequenceOfTest(univ.SequenceOf): + pass + + +SequenceOfTest.componentType = char.PrintableString() +SequenceOfTest.sizeSpec = constraint.ValueSizeConstraint(2, float('inf')) +print(repr(SequenceOfTest())) +d, _ = decode_der(binascii.unhexlify(b'3003130161'), asn1Spec=SequenceOfTest()) +print(repr(d)) +print(d.isInconsistent) + der = x509.load_pem_x509_certificate(trustwave.encode()).public_bytes(serialization.Encoding.DER) c = rfc5280.Certificate() diff --git a/tests/test_universal_constructed.py b/tests/test_universal_constructed.py new file mode 100644 index 0000000..4227422 --- /dev/null +++ b/tests/test_universal_constructed.py @@ -0,0 +1,51 @@ +import binascii + +import pytest +from pyasn1 import error +from pyasn1.error import PyAsn1Error +from pyasn1.type import univ, namedtype, char, useful, constraint + +from pyasn1_fasder import decode_der + + +MAX = float('inf') + + +def _wrapper(substrate_hex, asn1Spec): + return decode_der(binascii.unhexlify(substrate_hex), asn1Spec=asn1Spec) + + +class SequenceTest(univ.Sequence): + pass + + +SequenceTest.componentType = namedtype.NamedTypes( + namedtype.NamedType('first', char.PrintableString()), + namedtype.OptionalNamedType('optional', char.UTF8String()), + namedtype.DefaultedNamedType('default', useful.UTCTime()), + namedtype.NamedType('last', char.PrintableString()), +) + + +class SequenceOfTest(univ.SequenceOf): + pass + + +SequenceOfTest.componentType = char.PrintableString() +SequenceOfTest.sizeSpec = constraint.ValueSizeConstraint(2, MAX) + + +def test_sequenceof_one_element(): + decoded, _ = _wrapper(b'3003130141', SequenceOfTest()) + + assert str(decoded[0]) == 'A' + + +def test_sequenceof_empty(): + with pytest.raises(error.PyAsn1Error): + _wrapper(b'3000', SequenceOfTest()) + + + + +