Skip to content

Commit

Permalink
[Rust] update how "use declarations" are generated, now deriving Defa…
Browse files Browse the repository at this point in the history
…ults for enums (#952)

* Fixed several issues:
#1: LibRsDef explicitly us crate:: to disambiguate from built in mods like bool
#2: Added acting_version field to Composite structs to fix compilation errors when using Composite structs.  This is an incomplete implementation because the parent doesn't pass the acting_version to the composite because you need to change the signature of wrap(parent, offset) to include the acting_version, so this version just ensures that if the acting_version isn't set on the composite, it disregards the version check.
#3: fixed primitiveArrayDecoder to return an empty array of the right size if less than version required.

* [Rust] removed 'acting_version' field from SubGroup decoder

* [Rust] updated RustGenerator::generateEnum() to support derive "Default" instead of generating impl

* fixed formatting issue

* [Rust] updated how 'use declarations' are generated to prevent "ambiguous glob re-exports" warnings from rust compiler

---------

Co-authored-by: Adam Krieg <adam@talostrading.com>
Co-authored-by: Michael Ward <mward@drw.com>
  • Loading branch information
3 people authored Nov 1, 2023
1 parent 060031c commit c0dae9e
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 42 deletions.
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ baseline_bigendian = { path = "../generated/rust/baseline-bigendian" }
nested_composite_name = { path = "../generated/rust/nested-composite-name" }

[dev-dependencies]
criterion = "0.3"
criterion = "0.5"

[[bench]]
name = "car_benchmark"
Expand Down
1 change: 1 addition & 0 deletions rust/benches/car_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use examples_uk_co_real_logic_sbe_benchmarks::*;
use car_codec::encoder::*;

const MANUFACTURER: &[u8] = b"MANUFACTURER";
const MODEL: &[u8] = b"MODEL";
Expand Down
1 change: 1 addition & 0 deletions rust/benches/md_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use examples_uk_co_real_logic_sbe_benchmarks_fix::*;
use market_data_incremental_refresh_trades_codec::encoder::*;

struct State {
buffer: Vec<u8>,
Expand Down
1 change: 1 addition & 0 deletions rust/tests/baseline_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fs::File;
use std::io::prelude::*;

use examples_baseline::*;
use car_codec::encoder::*;

fn read_sbe_file_generated_from_java_example() -> std::io::Result<Vec<u8>> {
// Generated by the generateCarExampleDataFile gradle task.
Expand Down
1 change: 1 addition & 0 deletions rust/tests/big_endian_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use baseline_bigendian::*;
use car_codec::encoder::*;

#[test]
fn big_endian_baseline_example() -> SbeResult<()> {
Expand Down
1 change: 1 addition & 0 deletions rust/tests/extension_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fs::File;
use std::io::prelude::*;

use examples_extension::*;
use car_codec::encoder::*;

fn read_sbe_file_generated_from_java_example() -> std::io::Result<Vec<u8>> {
// Generated by the generateCarExampleDataFile gradle task.
Expand Down
14 changes: 8 additions & 6 deletions rust/tests/issue_895_test.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use ::issue_895::*;
use issue_895::{
issue_895_codec::{decoder::Issue895Decoder, encoder::Issue895Encoder},
MessageHeaderDecoder, ReadBuf, SbeResult, WriteBuf, ENCODED_LENGTH, SBE_BLOCK_LENGTH,
SBE_SCHEMA_ID, SBE_SCHEMA_VERSION, SBE_TEMPLATE_ID,
};

fn create_encoder(buffer: &mut Vec<u8>) -> Issue895Encoder {
let issue_895 = Issue895Encoder::default().wrap(
WriteBuf::new(buffer.as_mut_slice()),
ENCODED_LENGTH,
);
let issue_895 =
Issue895Encoder::default().wrap(WriteBuf::new(buffer.as_mut_slice()), ENCODED_LENGTH);
let mut header = issue_895.header(0);
header.parent().unwrap()
}
Expand Down Expand Up @@ -72,4 +74,4 @@ fn issue_895_double_none() -> SbeResult<()> {
assert_eq!(None, decoder.optional_double());

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void generate() throws IOException
// add re-export of modules
for (final String module : modules)
{
indent(libRs, 0, "pub use %s::*;\n", toLowerSnakeCase(module));
indent(libRs, 0, "pub use crate::%s::*;\n", toLowerSnakeCase(module));
}
indent(libRs, 0, "\n");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package uk.co.real_logic.sbe.generation.rust;

import org.agrona.Strings;
import org.agrona.Verify;
import org.agrona.generation.OutputManager;
import uk.co.real_logic.sbe.PrimitiveType;
Expand Down Expand Up @@ -119,7 +120,7 @@ public void generate() throws IOException
indent(writer, 0, "version = \"0.1.0\"\n");
indent(writer, 0, "authors = [\"sbetool\"]\n");
indent(writer, 0, "description = \"%s\"\n", ir.description());
indent(writer, 0, "edition = \"2018\"\n\n");
indent(writer, 0, "edition = \"2021\"\n\n");
indent(writer, 0, "[lib]\n");
indent(writer, 0, "name = \"%s\"\n", namespace);
indent(writer, 0, "path = \"src/lib.rs\"\n");
Expand Down Expand Up @@ -151,8 +152,8 @@ public void generate() throws IOException
try (Writer out = outputManager.createOutput(codecModName))
{
indent(out, 0, "use crate::*;\n\n");
indent(out, 0, "pub use encoder::*;\n");
indent(out, 0, "pub use decoder::*;\n\n");
indent(out, 0, "pub use encoder::%sEncoder;\n", formatStructName(msgToken.name()));
indent(out, 0, "pub use decoder::%sDecoder;\n\n", formatStructName(msgToken.name()));
final String blockLengthType = blockLengthType();
final String templateIdType = rustTypeName(ir.headerStructure().templateIdType());
final String schemaIdType = rustTypeName(ir.headerStructure().schemaIdType());
Expand Down Expand Up @@ -274,7 +275,18 @@ static void generateEncoderGroups(
final Token numInGroupToken = Generators.findFirst("numInGroup", tokens, index);
final PrimitiveType numInGroupPrimitiveType = numInGroupToken.encoding().primitiveType();

indent(sb, level, "/// GROUP ENCODER\n");
final String description = groupToken.description();
if (!Strings.isEmpty(description))
{
indent(sb, level, "/// GROUP ENCODER (id=%s, description='%s')\n",
groupToken.id(), description);
}
else
{
indent(sb, level, "/// GROUP ENCODER (id=%s)\n",
groupToken.id());
}

assert 4 == groupHeaderTokenCount;
indent(sb, level, "#[inline]\n");
indent(sb, level, "pub fn %s(self, count: %s, %1$s: %3$s<Self>) -> %3$s<Self> {\n",
Expand Down Expand Up @@ -553,7 +565,7 @@ private static void generateCompositeDecoder(
decoderName,
decoderTypeName);

indent(sb, level + 1, "if self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 2, "return Either::Left(self);\n");
indent(sb, level + 1, "}\n\n");

Expand Down Expand Up @@ -587,7 +599,7 @@ private static void generateBitSetDecoder(

if (bitsetToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", bitsetToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n", bitsetToken.version());
indent(sb, level + 2, "return %s::default();\n", structTypeName);
indent(sb, level + 1, "}\n\n");
}
Expand Down Expand Up @@ -647,7 +659,7 @@ private static void generatePrimitiveArrayDecoder(

if (fieldToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 2, "return [%s; %d];\n", encoding.applicableNullValue(), arrayLength);
indent(sb, level + 1, "}\n\n");
}
Expand Down Expand Up @@ -756,7 +768,7 @@ private static void generatePrimitiveOptionalDecoder(

if (fieldToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 2, "return None;\n");
indent(sb, level + 1, "}\n\n");
}
Expand Down Expand Up @@ -809,7 +821,7 @@ private static void generatePrimitiveRequiredDecoder(

if (fieldToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n", fieldToken.version());
indent(sb, level + 2, "return %s;\n",
generateRustLiteral(encoding.primitiveType(), encoding.applicableNullValue().toString()));
indent(sb, level + 1, "}\n\n");
Expand Down Expand Up @@ -897,9 +909,19 @@ static void generateDecoderGroups(
i = collectVarData(tokens, i, varData);

final String groupName = decoderName(formatStructName(groupToken.name()));
indent(sb, level, "/// GROUP DECODER\n");
assert 4 == groupHeaderTokenCount;
final String description = groupToken.description();
if (!Strings.isEmpty(description))
{
indent(sb, level, "/// GROUP DECODER (id=%s, description='%s')\n",
groupToken.id(), description);
}
else
{
indent(sb, level, "/// GROUP DECODER (id=%s)\n",
groupToken.id());
}

assert 4 == groupHeaderTokenCount;
indent(sb, level, "#[inline]\n");
if (groupToken.version() > 0)
{
Expand All @@ -910,18 +932,14 @@ static void generateDecoderGroups(
indent(sb, level + 2, "return None;\n");
indent(sb, level + 1, "}\n\n");

indent(sb, level + 1, "let acting_version = self.acting_version;\n");
indent(sb, level + 1, "Some(%s::default().wrap(self, acting_version as usize))\n",
groupName);
indent(sb, level + 1, "Some(%s::default().wrap(self))\n", groupName);
}
else
{
indent(sb, level, "pub fn %s(self) -> %2$s<Self> {\n",
formatFunctionName(groupName), groupName);

indent(sb, level + 1, "let acting_version = self.acting_version;\n");
indent(sb, level + 1, "%s::default().wrap(self, acting_version as usize)\n",
groupName);
indent(sb, level + 1, "%s::default().wrap(self)\n", groupName);
}
indent(sb, level, "}\n\n");

Expand Down Expand Up @@ -958,7 +976,8 @@ static void generateDecoderVarData(
{
if (varDataToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", varDataToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n",
varDataToken.version());
indent(sb, level + 2, "return (self.parent.as_ref().unwrap().get_limit(), 0);\n");
indent(sb, level + 1, "}\n\n");
}
Expand All @@ -973,7 +992,8 @@ static void generateDecoderVarData(
{
if (varDataToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", varDataToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n",
varDataToken.version());
indent(sb, level + 2, "return (self.get_limit(), 0);\n");
indent(sb, level + 1, "}\n\n");
}
Expand All @@ -992,7 +1012,8 @@ static void generateDecoderVarData(

if (varDataToken.version() > 0)
{
indent(sb, level + 1, "if self.acting_version < %d {\n", varDataToken.version());
indent(sb, level + 1, "if self.acting_version > 0 && self.acting_version < %d {\n",
varDataToken.version());
indent(sb, level + 2, "return &[] as &[u8];\n");
indent(sb, level + 1, "}\n\n");
}
Expand Down Expand Up @@ -1199,7 +1220,7 @@ private static void generateEnum(
throw new IllegalArgumentException("No valid values provided for enum " + originalEnumName);
}

indent(writer, 0, "#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]\n");
indent(writer, 0, "#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]\n");
final String primitiveType = rustTypeName(messageBody.get(0).encoding().primitiveType());
indent(writer, 0, "#[repr(%s)]\n", primitiveType);
indent(writer, 0, "pub enum %s {\n", enumRustName);
Expand All @@ -1216,16 +1237,11 @@ private static void generateEnum(
final Encoding encoding = messageBody.get(0).encoding();
final CharSequence nullVal = generateRustLiteral(encoding.primitiveType(),
encoding.applicableNullValue().toString());
indent(writer, 1, "#[default]\n");
indent(writer, 1, "NullVal = %s, \n", nullVal);
}
indent(writer, 0, "}\n");

// Default implementation to support Default in other structs
indent(writer, 0, "impl Default for %s {\n", enumRustName);
indent(writer, 1, "#[inline]\n");
indent(writer, 1, "fn default() -> Self { %s::%s }\n", enumRustName, "NullVal");
indent(writer, 0, "}\n");

// From impl
indent(writer, 0, "impl From<%s> for %s {\n", primitiveType, enumRustName);
indent(writer, 1, "#[inline]\n");
Expand Down Expand Up @@ -1271,8 +1287,8 @@ private static void generateComposite(
{
indent(out, 0, "use crate::*;\n\n");

indent(out, 0, "pub use encoder::*;\n");
indent(out, 0, "pub use decoder::*;\n\n");
indent(out, 0, "pub use encoder::%sEncoder;\n", formatStructName(compositeName));
indent(out, 0, "pub use decoder::%sDecoder;\n\n", formatStructName(compositeName));

final int encodedLength = tokens.get(0).encodedLength();
if (encodedLength > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ private static String sanitizeMethodOrProperty(final String name)
{
if (shadowsKeyword(name))
{
return name + "_";
return "r#" + name;
}
else
{
Expand Down Expand Up @@ -310,7 +310,7 @@ enum ReservedKeyword
{
for (final ReservedKeyword value : ReservedKeyword.values())
{
LOWER_CASE_NAMES.add(value.name());
LOWER_CASE_NAMES.add(value.name().toLowerCase());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ void generateDecoder(
indent(sb, level - 1, "pub struct %s<P> {\n", name);
indent(sb, level, "parent: Option<P>,\n");
indent(sb, level, "block_length: usize,\n");
indent(sb, level, "acting_version: usize,\n");
indent(sb, level, "count: %s,\n", rustTypeName(numInGroupPrimitiveType));
indent(sb, level, "index: usize,\n");
indent(sb, level, "offset: usize,\n");
Expand All @@ -174,7 +173,6 @@ void generateDecoder(
indent(sb, level, "pub fn wrap(\n");
indent(sb, level + 1, "mut self,\n");
indent(sb, level + 1, "mut parent: P,\n");
indent(sb, level + 1, "acting_version: usize,\n");
indent(sb, level, ") -> Self {\n");
indent(sb, level + 1, "let initial_offset = parent.get_limit();\n");
indent(sb, level + 1, "let block_length = parent.get_buf().get_%s_at(initial_offset) as usize;\n",
Expand All @@ -186,7 +184,6 @@ void generateDecoder(

indent(sb, level + 1, "self.parent = Some(parent);\n");
indent(sb, level + 1, "self.block_length = block_length;\n");
indent(sb, level + 1, "self.acting_version = acting_version;\n");
indent(sb, level + 1, "self.count = count;\n");
indent(sb, level + 1, "self.index = usize::MAX;\n");
indent(sb, level + 1, "self.offset = 0;\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ void functionNameCasing()
assertEquals("pricenull_9", formatFunctionName("PRICENULL9"));
assertEquals("price_9_book", formatFunctionName("PRICE9Book"));
assertEquals("issue_435", formatFunctionName("issue435"));
assertEquals("r#type", formatFunctionName("type"));

assertEquals("upper_case", formatFunctionName("UPPERCase"));
assertEquals("no_md_entries", formatFunctionName("NoMDEntries"));
assertEquals("md_entry_type_book", formatFunctionName("MD_EntryTYPEBook"));
Expand Down

0 comments on commit c0dae9e

Please sign in to comment.