Skip to content

Commit

Permalink
C++ can read !!binary into vector<byte>
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanCurtis-TRI committed Dec 17, 2024
1 parent 12a08ee commit 2af36ef
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 1 deletion.
14 changes: 14 additions & 0 deletions common/yaml/test/example_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ bool operator==(const StringStruct& a, const StringStruct& b) {
return a.value == b.value;
}

// A serializable example type whose only field is a vector of bytes.
struct BytesStruct {
  // The field defaults to the three bytes 0, 1, 2.
  std::vector<std::byte> value{std::byte{0}, std::byte{1}, std::byte{2}};

  // Presents the `value` field to the given archive.
  template <typename Archive>
  void Serialize(Archive* archive) {
    archive->Visit(DRAKE_NVP(value));
  }
};

// This is used only for EXPECT_EQ, not by any YAML operations.
// Two BytesStructs compare equal when their `value` vectors hold the same
// bytes in the same order. Declared `inline` because this definition lives in
// a header: without it, including the header from more than one translation
// unit of the same binary would produce duplicate definitions.
inline bool operator==(const BytesStruct& a, const BytesStruct& b) {
  return a.value == b.value;
}

struct PathStruct {
template <typename Archive>
void Serialize(Archive* a) {
Expand Down
86 changes: 86 additions & 0 deletions common/yaml/test/yaml_read_archive_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,92 @@ TEST_P(YamlReadArchiveTest, DoubleMissing) {
EXPECT_EQ(x.value, kNominalDouble);
}

TEST_P(YamlReadArchiveTest, Bytes) {
  // Loads `yaml` into a BytesStruct and checks that the stored bytes are the
  // characters of `expected`, one byte per character.
  const auto expect_decodes_to = [](const std::string& yaml,
                                    const std::string& expected) {
    const auto* const first =
        reinterpret_cast<const std::byte*>(expected.data());
    const std::vector<std::byte> expected_bytes(first,
                                                first + expected.size());
    const auto& loaded = AcceptNoThrow<BytesStruct>(LoadSingleValue(yaml));
    EXPECT_EQ(loaded.value, expected_bytes)
        << "Expected string: '" << expected << "'";
  };

  // Using !!binary on a schema whose type is bytes.
  const std::string decoded_text{"\x03test_str\x03"};
  expect_decodes_to("!!binary A3Rlc3Rfc3RyAw==", decoded_text);
  // Note: The number of spaces is critical to producing proper formatted yaml.
  expect_decodes_to("!!binary |\n A3Rlc3Rfc3RyAw==", decoded_text);
  expect_decodes_to("!!binary |\n A3Rlc3R\n fc3RyAw==", decoded_text);
  expect_decodes_to("!!binary ", "");

  // Malformed base64 values. The proper encoding of "\x03t_str\x03" is
  // 'A3Rfc3RyAw=='; the first entry is missing a character and the second
  // contains an invalid character.
  for (const char* const malformed :
       {"!!binary A3Rfc3RyAw=", "!!binary A3Rfc*RyAw=="}) {
    DRAKE_EXPECT_THROWS_MESSAGE(
        AcceptIntoDummy<BytesStruct>(LoadSingleValue(malformed)),
        "Expected a base64-encoded value.*error decoding.*");
  }

  // Assigning any other type to bytes is rejected.
  // Note: these various value strings should be converted to various primitive
  // types (string, int, etc.) before we process the scalar value. However,
  // this doesn't currently happen so the error message can't complain that the
  // wrong *type* has been passed to value. When we aggressively convert and
  // check for mismatch, these error matches will shift to match what is done
  // in yaml.py.
  const auto expect_not_base64 = [](const std::string& value,
                                    std::string_view tag = {}) {
    // Prefix the value with the tag (when one is given) to form the document.
    std::string document = value;
    if (!tag.empty()) {
      document = fmt::format("{} {}", tag, value);
    }
    DRAKE_EXPECT_THROWS_MESSAGE(
        AcceptIntoDummy<BytesStruct>(LoadSingleValue(document)),
        fmt::format(".*yaml value must be base64.*{}.*", value));
  };
  // String.
  expect_not_base64("test string");
  expect_not_base64("1234", "!!str");
  // Int.
  expect_not_base64("12", "!!int");
  expect_not_base64("0x3");
  expect_not_base64("0o3");
  expect_not_base64("00:03");
  // Float.
  expect_not_base64("1234.5");
  expect_not_base64("1234.5", "!!float");
  expect_not_base64(".inf");
  expect_not_base64("00:03.3");
  // Bool
  expect_not_base64("true");

  // Null is a special case; it is a non-scalar so doesn't get treated as the
  // other scalar types.
  const auto expect_non_scalar = [](const std::string& value) {
    DRAKE_EXPECT_THROWS_MESSAGE(
        AcceptIntoDummy<BytesStruct>(LoadSingleValue(value)),
        ".*has non-Scalar.*");
  };
  // Null
  expect_non_scalar("null");
  expect_non_scalar("");

  // Using !!binary for non-binary types is bad. The second argument is only
  // used to deduce the serializable type T.
  const auto expect_bad_target = []<class T>(const std::string& value,
                                             const T&) {
    DRAKE_EXPECT_THROWS_MESSAGE(
        AcceptIntoDummy<T>(LoadSingleValue(value)),
        ".*!!binary tag can only be applied to.*byte.*");
  };  // NOLINT -- templated lambda confuses cpplint about the semicolon.
  // These all use valid base64 encoding of what would otherwise be valid string
  // values for the serializable type.
  expect_bad_target("!!binary LmluZg==", DoubleStruct{});     // .inf
  expect_bad_target("!!binary MTIzNC41", DoubleStruct{});     // 1234.5
  expect_bad_target("!!binary dGVzdC9wYXRo", PathStruct{});   // test/path
}

TEST_P(YamlReadArchiveTest, Path) {
const auto test_valid = [](const std::string& value,
const std::string& expected) {
Expand Down
3 changes: 3 additions & 0 deletions common/yaml/yaml_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ class Node final {
// https://yaml.org/spec/1.2.2/#generic-string
static constexpr std::string_view kTagStr{"tag:yaml.org,2002:str"};

// The tag for base64-encoded binary data, as described in the YAML
// language-independent type repository: https://yaml.org/type/binary.html
static constexpr std::string_view kTagBinary{"tag:yaml.org,2002:binary"};

/* Sets the filename where this Node was read from. A nullopt indicates that
the filename is not known. */
void SetFilename(std::optional<std::string> filename);
Expand Down
50 changes: 50 additions & 0 deletions common/yaml/yaml_read_archive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,49 @@ void YamlReadArchive::ParseScalar(const std::string& value,
*result = value;
}

// Decodes `encoded` as base64 and returns the decoded bytes as a string (one
// char per decoded byte).
//
// The underlying decoder reports failure by returning an empty result, which
// is also what it returns for an input with no payload. The two cases are
// told apart here by checking whether `encoded` holds any non-whitespace
// character: the decoder skips whitespace, so an input that is empty or
// whitespace-only is a legitimately empty payload and decodes to "".
//
// @param encoded The ostensibly base64-encoded input.
// @param target The name of the node being decoded; only used in the
//               exception message.
// @throws std::runtime_error if `encoded` has payload characters but the
//         decoder produced nothing.
std::string YamlReadArchive::DecodeBase64OrThrow(const std::string& encoded,
                                                 std::string_view target) {
  const std::vector<unsigned char> chars = YAML::DecodeBase64(encoded);

  if (chars.empty()) {
    // The whitespace set below is what std::isspace accepts in the "C" locale.
    const bool has_payload =
        encoded.find_first_not_of(" \t\n\v\f\r") != std::string::npos;
    if (has_payload) {
      // Echo only the start of the offending input to keep the message short.
      constexpr std::string::size_type kHeadLength = 25;
      std::string head = encoded.substr(0, kHeadLength);
      if (head.size() < encoded.size()) {
        head += "...";
      }
      throw std::runtime_error(fmt::format(
          "Expected a base64-encoded value for '{}'; error decoding: '{}'",
          target, head));
    }
  }

  return std::string(chars.begin(), chars.end());
}

// Verifies that the node carries the !!binary tag exactly when the output is
// a byte vector; throws std::runtime_error describing the mismatch otherwise.
// `get_output_type` is only invoked when the tag is present but the output is
// not binary.
void YamlReadArchive::ThrowIfBadBinaryScalar(
    bool output_is_binary, const internal::Node& node,
    std::string_view node_name, std::function<std::string()> get_output_type) {
  const bool tagged_as_binary =
      (node.GetTag() == internal::Node::kTagBinary);

  // Input and output agree; nothing to report.
  if (tagged_as_binary == output_is_binary) {
    return;
  }

  // A !!binary value is being written to something other than bytes.
  if (tagged_as_binary) {
    throw std::runtime_error(fmt::format(
        "The !!binary tag can only be applied to values written to "
        "std::vector<std::byte>. The value for '{}' has type {}.",
        node_name, get_output_type()));
  }

  // The output is bytes, but the yaml value lacks the !!binary tag.
  throw std::runtime_error(fmt::format(
      "The C++ type for '{}' is std::vector<std::byte>. Its yaml value "
      "must be base64 encoded with the !!binary tag. Given '{}'.",
      node_name, node.GetScalar()));
}

void YamlReadArchive::ParseScalar(const std::string& value,
std::filesystem::path* result) {
DRAKE_DEMAND(result != nullptr);
Expand All @@ -273,6 +316,13 @@ void YamlReadArchive::ParseScalar(const std::string& value,
*result = std::filesystem::path(value).lexically_normal();
}

// Replaces the contents of `result` with the characters of `value`, each
// reinterpreted as a std::byte. `result` must not be null.
void YamlReadArchive::ParseScalar(const std::string& value,
                                  std::vector<std::byte>* result) {
  DRAKE_DEMAND(result != nullptr);
  const auto* const first = reinterpret_cast<const std::byte*>(value.data());
  const auto* const last = first + value.size();
  result->assign(first, last);
}

const internal::Node* YamlReadArchive::MaybeGetSubNode(const char* name) const {
DRAKE_DEMAND(name != nullptr);
if (mapish_item_key_ != nullptr) {
Expand Down
47 changes: 46 additions & 1 deletion common/yaml/yaml_read_archive.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <array>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <map>
#include <optional>
#include <ostream>
Expand Down Expand Up @@ -170,6 +171,12 @@ class YamlReadArchive final {
this->VisitVector(nvp);
}

// For std::vector<std::byte>.
// A byte vector is handed to VisitScalar, i.e., the whole vector is read from
// a single scalar node (see VisitScalar for the !!binary / base64 handling).
template <typename NVP>
void DoVisit(const NVP& nvp, const std::vector<std::byte>&, int32_t) {
this->VisitScalar(nvp);
}

// For std::array.
template <typename NVP, typename T, std::size_t N>
void DoVisit(const NVP& nvp, const std::array<T, N>&, int32_t) {
Expand Down Expand Up @@ -239,13 +246,50 @@ class YamlReadArchive final {
sub_archive.Accept(&value);
}

// Attempts to interpret `encoded` as a base64-encoded string, returning the
// decoded bytes packed into a std::string (one char per decoded byte).
//
// @param encoded The ostensibly base64-encoded input.
// @param target The name of the node the decoding was being done for. Only
// used in the exception message.
// @returns The decoded content; empty when `encoded` is empty.
// @throws std::runtime_error if `encoded` cannot be decoded.
static std::string DecodeBase64OrThrow(const std::string& encoded,
std::string_view target);

// Throws if the binary-ness of the input node and of the output type don't
// match. There are two mismatches: the node carries the !!binary tag but the
// output is not binary, or the output is binary but the node lacks the tag.
//
// @param output_is_binary Assertion that the output is binary (if true).
// @param node The node being processed.
// @param node_name The name of the node being processed; only used in
// the exception message.
// @param get_output_type A callable that reports the name of the output
// type. Only called if an exception is thrown.
// @throws std::runtime_error if there is a binary mismatch.
static void ThrowIfBadBinaryScalar(
bool output_is_binary, const internal::Node& node,
std::string_view node_name, std::function<std::string()> get_output_type);

template <typename NVP>
void VisitScalar(const NVP& nvp) {
const internal::Node* sub_node = GetSubNodeScalar(nvp.name());
if (sub_node == nullptr) {
return;
}
ParseScalar(sub_node->GetScalar(), nvp.value());

using OutValueType = typename NVP::value_type;
constexpr bool output_is_binary =
std::is_same_v<OutValueType, std::vector<std::byte>>;
ThrowIfBadBinaryScalar(
output_is_binary, *sub_node, nvp.name(),
static_cast<std::string (*)()>(NiceTypeName::Get<OutValueType>));

if constexpr (output_is_binary) {
ParseScalar(DecodeBase64OrThrow(sub_node->GetScalar(), nvp.name()),
nvp.value());
} else {
ParseScalar(sub_node->GetScalar(), nvp.value());
}
}

template <typename NVP>
Expand Down Expand Up @@ -518,6 +562,7 @@ class YamlReadArchive final {
void ParseScalar(const std::string& value, uint64_t* result);
void ParseScalar(const std::string& value, std::string* result);
void ParseScalar(const std::string& value, std::filesystem::path* result);
void ParseScalar(const std::string& value, std::vector<std::byte>* result);

template <typename T>
void ParseScalarImpl(const std::string& value, T* result);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
DecodeBase64 will return an empty result in one of three circumstances:
- The input is an empty string.
- The input is nothing but whitespace (a functionally empty string).
- The input has an invalid character.

However, if the input doesn't have the proper number of encoding characters (a
multiple of 4), then instead of reporting an error, it simply returns a
truncated result with no hint of any error.

This patch changes the function to detect the lack of sufficient data and to
signal it by returning an empty result. This gives the caller enough
information to infer a problem (and even the cause).


--- src/binary.cpp
+++ src/binary.cpp
@@ -74,7 +74,8 @@ std::vector<unsigned char> DecodeBase64(const std::string &input) {
unsigned char *out = &ret[0];

unsigned value = 0;
- for (std::size_t i = 0, cnt = 0; i < input.size(); i++) {
+ std::size_t cnt = 0;
+ for (std::size_t i = 0; i < input.size(); i++) {
if (std::isspace(static_cast<unsigned char>(input[i]))) {
// skip newlines
continue;
@@ -90,9 +91,14 @@ std::vector<unsigned char> DecodeBase64(const std::string &input) {
*out++ = value >> 8;
if (input[i] != '=')
*out++ = value;
+ cnt = 0;
+ } else {
+ ++cnt;
}
- ++cnt;
}
+ // An invalid number of characters were encountered.
+ if (cnt != 0)
+ return ret_type();

ret.resize(out - &ret[0]);
return ret;
2 changes: 2 additions & 0 deletions tools/workspace/yaml_cpp_internal/repository.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ def yaml_cpp_internal_repository(
mirrors = None):
github_archive(
name = name,
# local_repository_override = "/home/seancurtis/code/yaml-cpp",
repository = "jbeder/yaml-cpp",
commit = "0.8.0",
sha256 = "fbe74bbdcee21d656715688706da3c8becfd946d92cd44705cc6098bb23b3a16", # noqa
build_file = ":package.BUILD.bazel",
patches = [
":patches/upstream/b64_decode_failure_is_empty.patch",
":patches/emit-local-tag.patch",
],
mirrors = mirrors,
Expand Down

0 comments on commit 2af36ef

Please sign in to comment.