Skip to content

Commit

Permalink
improve arrays support:
Browse files Browse the repository at this point in the history
- ensure that reading arrays of primitive types is as fast as possible when there is enough data in the stream, throw when there is not, fall back to slow path when we don't know
- make GetArray reuse the previously created instance
- and provide tests for all of that
  • Loading branch information
adamsitnik committed Jun 7, 2024
1 parent e3d1a7b commit fdbd424
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ internal ArrayOfClassesRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo

/// <inheritdoc/>
public override ClassRecord?[] GetArray(bool allowNulls = true)
=> (ClassRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));

private ClassRecord?[] ToArray(bool allowNulls)
{
ClassRecord?[] result = new ClassRecord?[Length];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace System.Runtime.Serialization.BinaryFormat;
#endif
abstract class ArrayRecord : SerializationRecord
{
private protected Array? _arrayNullsAllowed;
private protected Array? _arrayNullsNotAllowed;

private protected ArrayRecord(ArrayInfo arrayInfo)
{
ArrayInfo = arrayInfo;
Expand Down Expand Up @@ -80,7 +83,9 @@ public Array GetArray(Type expectedArrayType, bool allowNulls = true)
throw new InvalidOperationException(SR.Format(SR.Serialization_TypeMismatch, expectedArrayType.AssemblyQualifiedName, ElementTypeName.AssemblyQualifiedName));
}

return Deserialize(expectedArrayType, allowNulls);
return allowNulls
? _arrayNullsAllowed ??= Deserialize(expectedArrayType, true)
: _arrayNullsNotAllowed ??= Deserialize(expectedArrayType, false);
}

[RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public override TypeName ElementTypeName

/// <inheritdoc/>
public override object?[] GetArray(bool allowNulls = true)
=> (object?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));

private object?[] ToArray(bool allowNulls)
{
object?[] values = new object?[Length];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public override TypeName ElementTypeName

/// <inheritdoc/>
public override T[] GetArray(bool allowNulls = true)
=> Values is T[] array ? array : Values.ToArray();
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Expand All @@ -61,32 +61,136 @@ private protected override void AddValue(object value)

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
if (typeof(T) == typeof(byte) && reader.IsDataAvailable(count))
// For decimals, the input is provided as strings, so we can't compute the required size up-front.
if (typeof(T) == typeof(decimal))
{
return (List<T>)(object)DecodeDecimals(reader, count);
}

long requiredBytes = count;
if (typeof(T) != typeof(char)) // the input is UTF8
{
requiredBytes *= Unsafe.SizeOf<T>();
}

bool? isDataAvailable = reader.IsDataAvailable(requiredBytes);
if (!isDataAvailable.HasValue)
{
return DecodeFromNonSeekableStream(reader, count);
}

if (!isDataAvailable.Value)
{
// We are sure there is not enough data.
ThrowHelper.ThrowEndOfStreamException();
}

if (typeof(T) == typeof(byte))
{
return (T[])(object)reader.ReadBytes(count);
}
// the input is UTF8, so the check does not include sizeof(char)
else if (typeof(T) == typeof(char) && reader.IsDataAvailable(count))
else if (typeof(T) == typeof(char))
{
return (T[])(object)reader.ReadChars(count);
}

// For decimals, the input is provided as strings, so we can't compute the required size up-front.
bool canPreAllocate = typeof(T) != typeof(decimal) && reader.IsDataAvailable(count * Unsafe.SizeOf<T>());
// Most of the tests use MemoryStream or FileStream and they both allow for executing the fast path.
// To ensure the slow path is tested as well, the fast path is executed only for optimized builds.
#if NET && RELEASE
if (canPreAllocate)
// It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
T[] result = new T[count];
Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(result);
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));

while (!resultAsBytes.IsEmpty)
{
int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length));
if (bytesRead <= 0)
{
ArrayPool<byte>.Shared.Return(bytes);
ThrowHelper.ThrowEndOfStreamException();
}

bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
resultAsBytes = resultAsBytes.Slice(bytesRead);
}

ArrayPool<byte>.Shared.Return(bytes);
#endif

if (!BitConverter.IsLittleEndian)
{
if (typeof(T) == typeof(short) || typeof(T) == typeof(ushort))
{
Span<short> span = MemoryMarshal.Cast<T, short>(result);
#if NET
BinaryPrimitives.ReverseEndianness(span, span);
#else
for (int i = 0; i < span.Length; i++)
{
span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
}
#endif
}
else if (typeof(T) == typeof(int) || typeof(T) == typeof(uint) || typeof(T) == typeof(float))
{
Span<int> span = MemoryMarshal.Cast<T, int>(result);
#if NET
BinaryPrimitives.ReverseEndianness(span, span);
#else
for (int i = 0; i < span.Length; i++)
{
span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
}
#endif
}
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
Span<long> span = MemoryMarshal.Cast<T, long>(result);
#if NET
BinaryPrimitives.ReverseEndianness(span, span);
#else
for (int i = 0; i < span.Length; i++)
{
span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
}
#endif
}
}

return result;
}

private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
{
ThrowHelper.ThrowInvalidValue(stringLength);
}

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

values.Add(decimal.Parse(buffer.Slice(0, stringLength), CultureInfo.InvariantCulture));
}
#else
for (int i = 0; i < count; i++)
{
return DecodePrimitiveTypesToArray(reader, count);
values.Add(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture));
}
#endif
return DecodePrimitiveTypesToList(reader, count, canPreAllocate);
return values;
}

private static List<T> DecodePrimitiveTypesToList(BinaryReader reader, int count, bool canPreAllocate)
private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
{
List<T> values = new List<T>(canPreAllocate ? count : Math.Min(count, 4));
List<T> values = new List<T>(Math.Min(count, 4));
for (int i = 0; i < count; i++)
{
if (typeof(T) == typeof(byte))
Expand Down Expand Up @@ -137,10 +241,6 @@ private static List<T> DecodePrimitiveTypesToList(BinaryReader reader, int count
{
values.Add((T)(object)reader.ReadDouble());
}
else if (typeof(T) == typeof(decimal))
{
values.Add((T)(object)decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture));
}
else if (typeof(T) == typeof(DateTime))
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64()));
Expand All @@ -155,51 +255,4 @@ private static List<T> DecodePrimitiveTypesToList(BinaryReader reader, int count

return values;
}

#if NET
private static T[] DecodePrimitiveTypesToArray(BinaryReader reader, int count)
{
T[] result = new T[count];
Span<byte> bytes = MemoryMarshal.AsBytes<T>(result);
reader.BaseStream.ReadExactly(bytes);

if (!BitConverter.IsLittleEndian)
{
if (typeof(T) == typeof(short) || typeof(T) == typeof(ushort))
{
Span<short> span = MemoryMarshal.Cast<T, short>(result);
BinaryPrimitives.ReverseEndianness(span, span);
}
else if (typeof(T) == typeof(int) || typeof(T) == typeof(uint) || typeof(T) == typeof(float))
{
Span<int> span = MemoryMarshal.Cast<T, int>(result);
BinaryPrimitives.ReverseEndianness(span, span);
}
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
Span<long> span = MemoryMarshal.Cast<T, long>(result);
BinaryPrimitives.ReverseEndianness(span, span);
}
}

if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
Span<long> longs = MemoryMarshal.Cast<T, long>(result);
for (int i = 0; i < longs.Length; i++)
{
if (typeof(T) == typeof(DateTime))
{
result[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(longs[i]);
}
else
{
result[i] = (T)(object)new TimeSpan(longs[i]);
}
}
}

return result;
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA

/// <inheritdoc/>
public override string?[] GetArray(bool allowNulls = true)
=> (string?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));

private string?[] ToArray(bool allowNulls)
{
string?[] values = new string?[Length];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ internal static RectangularOrCustomOffsetArrayRecord Create(BinaryReader reader,
PrimitiveType.Boolean => sizeof(bool),
PrimitiveType.Byte => sizeof(byte),
PrimitiveType.SByte => sizeof(sbyte),
PrimitiveType.Char => sizeof(char),
PrimitiveType.Char => sizeof(byte), // it's UTF8
PrimitiveType.Int16 => sizeof(short),
PrimitiveType.UInt16 => sizeof(ushort),
PrimitiveType.Int32 => sizeof(int),
Expand All @@ -186,7 +186,20 @@ internal static RectangularOrCustomOffsetArrayRecord Create(BinaryReader reader,
_ => -1
};

canPreAllocate = sizeOfSingleValue != -1 && reader.IsDataAvailable(requiredBytes: arrayInfo.TotalElementsCount * sizeOfSingleValue);
if (sizeOfSingleValue > 0)
{
long size = arrayInfo.TotalElementsCount * sizeOfSingleValue;
bool? isDataAvailable = reader.IsDataAvailable(size);
if (isDataAvailable.HasValue)
{
if (!isDataAvailable.Value)
{
ThrowHelper.ThrowEndOfStreamException();
}

canPreAllocate = true;
}
}
}

return new RectangularOrCustomOffsetArrayRecord(elementType, arrayInfo, memberTypeInfo, lengths, offsets, canPreAllocate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,23 @@ internal static DateTime CreateDateTimeFromData(long data)
return Unsafe.As<long, DateTime>(ref data);
}

internal static bool IsDataAvailable(this BinaryReader reader, long requiredBytes)
internal static bool? IsDataAvailable(this BinaryReader reader, long requiredBytes)
{
if (!reader.BaseStream.CanSeek)
{
return false;
return null;
}

long availableBytes = 0;
try
{
availableBytes = reader.BaseStream.Length - reader.BaseStream.Position;
// If the values are equal, it's still not enough, as every NRBF payload
// needs to end with EndMessageByte and requiredBytes does not take it into account.
return (reader.BaseStream.Length - reader.BaseStream.Position) > requiredBytes;
}
catch
{
// seekable Stream can still throw when accessing Length and Position
return null;
}

return availableBytes > requiredBytes;
}
}
Loading

0 comments on commit fdbd424

Please sign in to comment.