diff --git a/pw_protobuf/public/pw_protobuf/stream_decoder.h b/pw_protobuf/public/pw_protobuf/stream_decoder.h index 12ab41dfff..e3efb9ab4a 100644 --- a/pw_protobuf/public/pw_protobuf/stream_decoder.h +++ b/pw_protobuf/public/pw_protobuf/stream_decoder.h @@ -578,6 +578,12 @@ class StreamDecoder { Status Advance(size_t end_position); + size_t RemainingBytes() { + return stream_bounds_.high < std::numeric_limits::max() + ? stream_bounds_.high - position_ + : std::numeric_limits::max(); + } + void CloseBytesReader(BytesReader& reader); void CloseNestedDecoder(StreamDecoder& nested); diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc index 1d18bcc5bc..4f163f46cd 100644 --- a/pw_protobuf/stream_decoder.cc +++ b/pw_protobuf/stream_decoder.cc @@ -208,7 +208,8 @@ Status StreamDecoder::ReadFieldKey() { PW_DCHECK(field_consumed_); uint64_t varint = 0; - PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &varint)); + PW_TRY_ASSIGN(size_t bytes_read, + varint::Read(reader_, &varint, RemainingBytes())); position_ += bytes_read; if (!FieldKey::IsValidKey(varint)) { @@ -220,7 +221,7 @@ Status StreamDecoder::ReadFieldKey() { if (current_field_.wire_type() == WireType::kDelimited) { // Read the length varint of length-delimited fields immediately to simplify // later processing of the field. - StatusWithSize sws = varint::Read(reader_, &varint); + StatusWithSize sws = varint::Read(reader_, &varint, RemainingBytes()); if (sws.IsOutOfRange()) { // Out of range indicates the end of the stream. As a value is expected // here, report it as a data loss and terminate the decode operation. @@ -260,7 +261,8 @@ Status StreamDecoder::SkipField() { switch (current_field_.wire_type()) { case WireType::kVarint: { // Consume the varint field; nothing more to skip afterward. - PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &value)); + PW_TRY_ASSIGN(size_t bytes_read, + varint::Read(reader_, &value, RemainingBytes())); position_ += bytes_read; break; } @@ -286,6 +288,11 @@ Status StreamDecoder::SkipField() { return status_; } + if (RemainingBytes() < bytes_to_skip) { + status_ = Status::DataLoss(); + return status_; + } + PW_TRY(Advance(position_ + bytes_to_skip)); } @@ -310,7 +317,7 @@ Status StreamDecoder::ReadVarintField(std::span out, StatusWithSize StreamDecoder::ReadOneVarint(std::span out, VarintType decode_type) { uint64_t value; - StatusWithSize sws = varint::Read(reader_, &value); + StatusWithSize sws = varint::Read(reader_, &value, RemainingBytes()); if (sws.IsOutOfRange()) { // Out of range indicates the end of the stream. As a value is expected // here, report it as a data loss and terminate the decode operation. @@ -367,6 +374,11 @@ Status StreamDecoder::ReadFixedField(std::span out) { return status_; } + if (RemainingBytes() < out.size()) { + status_ = Status::DataLoss(); + return status_; + } + PW_TRY(reader_.Read(out)); position_ += out.size(); field_consumed_ = true; diff --git a/pw_protobuf/stream_decoder_test.cc b/pw_protobuf/stream_decoder_test.cc index 5206a2e03d..c3fd949885 100644 --- a/pw_protobuf/stream_decoder_test.cc +++ b/pw_protobuf/stream_decoder_test.cc @@ -558,6 +558,233 @@ TEST(StreamDecoder, Decode_Nested_InvalidField) { EXPECT_EQ(decoder.Next(), Status::DataLoss()); } +TEST(StreamDecoder, Decode_Nested_InvalidFieldKey) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x02, + // type=invalid... + 0xff, 0xff, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. + ASSERT_EQ(reader.Tell(), 4u); + } +} + +TEST(StreamDecoder, Decode_Nested_MissingDelimitedLength) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=1 + 0x0a, 0x01, + // Delimited field (bytes) key=1, length=missing... + 0x0a, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. + ASSERT_EQ(reader.Tell(), 3u); + } +} + +TEST(StreamDecoder, Decode_Nested_InvalidDelimitedLength) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x02, + // Delimited field (bytes) key=1, length=invalid... + 0x0a, 0xff, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. + ASSERT_EQ(reader.Tell(), 4u); + } +} + +TEST(StreamDecoder, Decode_Nested_InvalidVarint) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x02, + // type=uint32 key=1, value=invalid... + 0x08, 0xff, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), OkStatus()); + ASSERT_EQ(*nested.FieldNumber(), 1u); + + Result uint32 = nested.ReadUint32(); + EXPECT_EQ(uint32.status(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. + ASSERT_EQ(reader.Tell(), 4u); + } +} + +TEST(StreamDecoder, Decode_Nested_SkipInvalidVarint) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x02, + // type=uint32 key=1, value=invalid... + 0x08, 0xff, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), OkStatus()); + ASSERT_EQ(*nested.FieldNumber(), 1u); + + // Skip without reading. + EXPECT_EQ(nested.Next(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. + ASSERT_EQ(reader.Tell(), 4u); + } +} + +TEST(StreamDecoder, Decode_Nested_TruncatedFixed) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x03, + // type=fixed32 key=1, value=truncated... + 0x0d, 0x42, 0x00, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), OkStatus()); + ASSERT_EQ(*nested.FieldNumber(), 1u); + + Result uint32 = nested.ReadFixed32(); + EXPECT_EQ(uint32.status(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. Note that this will not read the data at all in this case. + ASSERT_EQ(reader.Tell(), 3u); + } +} + +TEST(StreamDecoder, Decode_Nested_SkipTruncatedFixed) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // Submessage key=1, length=2 + 0x0a, 0x03, + // type=fixed32 key=1, value=truncated... + 0x0d, 0x42, 0x00, + // End submessage + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(*decoder.FieldNumber(), 1u); + + { + StreamDecoder nested = decoder.GetNestedDecoder(); + EXPECT_EQ(nested.Next(), OkStatus()); + ASSERT_EQ(*nested.FieldNumber(), 1u); + + // Skip without reading. + EXPECT_EQ(nested.Next(), Status::DataLoss()); + + // Make sure that the nested decoder didn't run off the end of the + // submessage. Note that this will be unable to skip the field without + // exceeding the range of the nested decoder, so it won't move the cursor. + ASSERT_EQ(reader.Tell(), 3u); + } +} + TEST(StreamDecoder, Decode_BytesReader) { // clang-format off constexpr uint8_t encoded_proto[] = { diff --git a/pw_varint/public/pw_varint/stream.h b/pw_varint/public/pw_varint/stream.h index 7ae3564131..c055bc52b0 100644 --- a/pw_varint/public/pw_varint/stream.h +++ b/pw_varint/public/pw_varint/stream.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include "pw_status/status_with_size.h" #include "pw_stream/stream.h" @@ -26,9 +27,14 @@ namespace varint { // // Returns the number of bytes read from the stream if successful, OutOfRange // if the varint does not fit in a int64_t / uint64_t or if the input is -// exhausted before the number terminates. Reads a maximum of 10 bytes. -StatusWithSize Read(stream::Reader& reader, int64_t* output); -StatusWithSize Read(stream::Reader& reader, uint64_t* output); +// exhausted before the number terminates. Reads a maximum of 10 bytes or +// max_size, whichever is smaller. +StatusWithSize Read(stream::Reader& reader, + int64_t* output, + size_t max_size = std::numeric_limits::max()); +StatusWithSize Read(stream::Reader& reader, + uint64_t* output, + size_t max_size = std::numeric_limits::max()); } // namespace varint } // namespace pw diff --git a/pw_varint/stream.cc b/pw_varint/stream.cc index 536cbbb2b1..a973a3714a 100644 --- a/pw_varint/stream.cc +++ b/pw_varint/stream.cc @@ -25,9 +25,9 @@ namespace pw { namespace varint { -StatusWithSize Read(stream::Reader& reader, int64_t* output) { +StatusWithSize Read(stream::Reader& reader, int64_t* output, size_t max_size) { uint64_t value = 0; - StatusWithSize count = Read(reader, &value); + StatusWithSize count = Read(reader, &value, max_size); if (!count.ok()) { return count; } @@ -36,7 +36,7 @@ StatusWithSize Read(stream::Reader& reader, int64_t* output) { return count; } -StatusWithSize Read(stream::Reader& reader, uint64_t* output) { +StatusWithSize Read(stream::Reader& reader, uint64_t* output, size_t max_size) { uint64_t value = 0; size_t count = 0; @@ -47,6 +47,14 @@ StatusWithSize Read(stream::Reader& reader, uint64_t* output) { return StatusWithSize::DataLoss(); } + if (count >= max_size) { + // Varint didn't fit within the range given; return OutOfRange() if + // max_size was 0, but DataLoss if we were reading something we thought + // was going to be a varint. + return count > 0 ? StatusWithSize::DataLoss() + : StatusWithSize::OutOfRange(); + } + std::byte b; if (auto result = reader.Read(std::span(&b, 1)); !result.ok()) { if (count > 0 && result.status().IsOutOfRange()) { diff --git a/pw_varint/stream_test.cc b/pw_varint/stream_test.cc index 7eed52c6a1..bd70694003 100644 --- a/pw_varint/stream_test.cc +++ b/pw_varint/stream_test.cc @@ -274,4 +274,30 @@ TEST(VarintRead, Errors) { } } +TEST(VarintRead, SizeLimit) { + uint64_t value = -1234; + + { + // buffer contains a valid varint, but we limit the read length to ensure + // that the final byte is not read, turning it into an error. + const auto buffer = MakeBuffer("\xff\xff\xff\xff\x0f"); + stream::MemoryReader reader(buffer); + const auto sws = Read(reader, &value, 4); + EXPECT_FALSE(sws.ok()); + EXPECT_EQ(sws.status(), Status::DataLoss()); + EXPECT_EQ(reader.Tell(), 4u); + } + + { + // If we tell varint::Read() to read zero bytes, it should always return + // OutOfRange() without moving the reader. + const auto buffer = MakeBuffer("\xff\xff\xff\xff\x0f"); + stream::MemoryReader reader(buffer); + const auto sws = Read(reader, &value, 0); + EXPECT_FALSE(sws.ok()); + EXPECT_EQ(sws.status(), Status::OutOfRange()); + EXPECT_EQ(reader.Tell(), 0u); + } +} + } // namespace pw::varint