Use NatsMemoryOwner for Base64Url Encoder #565

Draft · wants to merge 4 commits into base: main
24 changes: 20 additions & 4 deletions src/NATS.Client.ObjectStore/Internal/Encoder.cs
@@ -1,5 +1,6 @@
using System.Buffers;
using System.Security.Cryptography;
using NATS.Client.Core;

namespace NATS.Client.ObjectStore.Internal;

@@ -68,18 +69,33 @@ public static string Encode(string arg)
/// <exception cref="ArgumentNullException">'inArray' is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">offset or length is negative OR offset plus length is greater than the length of inArray.</exception>
public static string Encode(Span<byte> inArray, bool raw = false)
{
using (var owner = EncodeToMemoryOwner(inArray, raw))
{
var segment = owner.DangerousGetArray();
Collaborator Author:

The use here isn't dangerous since we own it in scope; this was just the easiest way to call new string without #if-style branching for different target frameworks.

As it stands, this should still be a significant improvement in array allocations thanks to the pooling.

LMK if this deserves a comment.

Collaborator:

I think it's good. As you said, it's in scope so it's fine. A short comment might be nice for our future selves. BTW, if you're concerned, maybe use span.ToString()? I believe that was your suggestion for something else before 😅 Having said that, I'm guessing string.ctor(char[]) must be super fast!

Collaborator Author:

span.ToString()?

I had the same realization, lol... span.ToString() is going to be the best way to go on these.
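For future reference, both approaches materialize the same string; span.ToString() just removes the need to reach the backing array at all. A minimal BCL-only sketch (the buffer contents here are made up for illustration, standing in for a pooled encoder buffer):

```csharp
using System;

// A pooled char buffer where only the first 4 chars are meaningful
// (the trailing '=' stands in for padding that raw mode would drop)
char[] buffer = { 'Q', 'U', 'J', 'D', '=' };

// Option 1: string ctor over the backing array (needs array + offset + count)
string viaCtor = new string(buffer, 0, 4);

// Option 2: slice the span and let ToString copy once - no array access needed
string viaSpan = buffer.AsSpan(0, 4).ToString();

Console.WriteLine(viaCtor == viaSpan); // True
```

On modern runtimes, `ReadOnlySpan<char>.ToString()` special-cases char spans and copies the payload straight into a new string, so it should perform on par with the string(char[], int, int) constructor while also working over pooled or stack memory.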

if (segment.Array == null || segment.Array.Length == 0)
{
return string.Empty;
}

return new string(segment.Array, segment.Offset, segment.Count);
}
}

public static NatsMemoryOwner<char> EncodeToMemoryOwner(Span<byte> inArray, bool raw = false)
{
var offset = 0;
var length = inArray.Length;

if (length == 0)
return string.Empty;
return NatsMemoryOwner<char>.Empty;

var lengthMod3 = length % 3;
var limit = length - lengthMod3;
var output = new char[(length + 2) / 3 * 4];
var owner = NatsMemoryOwner<char>.Allocate((length + 2) / 3 * 4);
Collaborator Author:

Should we try-catch-dispose-rethrow around this, in case something fails in the encode?

Collaborator Author:

Also, should we add a flag to the method here to specify 'clearing' and pass it in? I guess it depends on the security concern of the SHA lingering.

CC @mtmk @caleblloyd

Collaborator:

Should we try-catch-dispose-rethrow around this, in case something fails in the encode?

Should we encode to a span instead? In the case of SHA we could maybe stackalloc?

Also, Should we be adding a flag to the method here to specify 'clearing' and pass it in here?

Do you mean clearing the input array or output?

Collaborator Author:

Should we encode to a span instead? In the case of SHA we could maybe stackalloc?

Hmm I can give it a shot, it might be cleaner...

Do you mean clearing the input array or output?

Clearing the buffered array. IIRC NatsMemoryOwner has a flag to say whether the array is cleared on return.

My gut says yes, but if we can get away with stackallocing the span, the point is moot. Will check.
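As a sketch of the stackalloc idea using only BCL APIs (standard Base64 via Convert rather than this repo's URL-safe encoder; the data and variable names are illustrative): a SHA-256 digest is always 32 bytes, and its Base64 form is (32 + 2) / 3 * 4 = 44 chars, so both buffers fit comfortably on the stack and can be zeroed explicitly when the hash material is considered sensitive.

```csharp
using System;
using System.Security.Cryptography;

byte[] data = { 1, 2, 3 };

using var sha256 = SHA256.Create();
byte[] hash = sha256.ComputeHash(data); // always 32 bytes for SHA-256

// Base64 output for 32 bytes: (32 + 2) / 3 * 4 = 44 chars - small enough to stackalloc
Span<char> encoded = stackalloc char[44];
Convert.TryToBase64Chars(hash, encoded, out var written);
string digest = encoded.Slice(0, written).ToString();

// If the digest material is sensitive, zero the intermediate buffers when done
encoded.Clear();
Array.Clear(hash, 0, hash.Length);
```

If the pooled NatsMemoryOwner route is kept instead, the clear-on-return flag mentioned above would be the analogue of the explicit Clear calls here.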

var table = SBase64Table;
int i, j = 0;
var output = owner.Span;

// takes 3 bytes from inArray and inserts 4 bytes into output
for (i = offset; i < limit; i += 3)
@@ -128,14 +144,14 @@ public static string Encode(Span<byte> inArray, bool raw = false)
}

if (raw)
return new string(output, 0, j);
return owner.Slice(0, j);

for (var k = j; k < output.Length; k++)
{
output[k] = Base64PadCharacter;
}

return new string(output);
return owner;
}

/// <summary>
251 changes: 132 additions & 119 deletions src/NATS.Client.ObjectStore/NatsObjStore.cs
@@ -94,65 +94,74 @@ public async ValueTask<ObjectMetadata> GetAsync(string key, Stream stream, bool

pushConsumer.Init();

string digest;
var chunks = 0;
var size = 0;
using (var sha256 = SHA256.Create())
var digest = NatsMemoryOwner<char>.Empty;
try
{
var chunks = 0;
var size = 0;
using (var sha256 = SHA256.Create())
{
#if NETSTANDARD2_0
using (var hashedStream = new CryptoStream(stream, sha256, CryptoStreamMode.Write))
#else
await using (var hashedStream = new CryptoStream(stream, sha256, CryptoStreamMode.Write, leaveOpen))
#endif
{
await foreach (var msg in pushConsumer.Msgs.ReadAllAsync(cancellationToken))
{
// We have to make sure to carry on consuming the channel to avoid any blocking:
// e.g. if the channel is full, we would be blocking the reads off the socket (this was intentionally
// done to avoid bloating the memory with a large backlog of messages or dropping messages at this level
// and signal the server that we are a slow consumer); then when we make a request-reply API call to
// delete the consumer, the socket would be blocked trying to send the response back to us; so we need to
// keep consuming the channel to avoid this.
if (pushConsumer.IsDone)
continue;

if (msg.Data.Length > 0)
await foreach (var msg in pushConsumer.Msgs.ReadAllAsync(cancellationToken))
{
using var memoryOwner = msg.Data;
chunks++;
size += memoryOwner.Memory.Length;
// We have to make sure to carry on consuming the channel to avoid any blocking:
// e.g. if the channel is full, we would be blocking the reads off the socket (this was intentionally
// done to avoid bloating the memory with a large backlog of messages or dropping messages at this level
// and signal the server that we are a slow consumer); then when we make a request-reply API call to
// delete the consumer, the socket would be blocked trying to send the response back to us; so we need to
// keep consuming the channel to avoid this.
if (pushConsumer.IsDone)
continue;

if (msg.Data.Length > 0)
{
using var memoryOwner = msg.Data;
chunks++;
size += memoryOwner.Memory.Length;
#if NETSTANDARD2_0
var segment = memoryOwner.DangerousGetArray();
await hashedStream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
#else
await hashedStream.WriteAsync(memoryOwner.Memory, cancellationToken);
#endif
}
}

var p = msg.Metadata?.NumPending;
if (p is 0)
{
pushConsumer.Done();
}
}
}

digest = Base64UrlEncoder.EncodeToMemoryOwner(sha256.Hash);
}

digest = Base64UrlEncoder.Encode(sha256.Hash);
}
if (info.Digest == null
|| info.Digest.StartsWith("SHA-256=") == false
|| info.Digest.AsSpan().Slice("SHA-256=".Length).SequenceEqual(digest.Span) == false)
Comment on lines +145 to +147
Collaborator Author:

There is likely a 'better way' to handle this case, but this seemed clear and should be performant enough.

Collaborator:

I think it's good. Maybe use ordinal StartsWith to avoid culture-sensitive comparison?
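The culture concern is real: String.StartsWith(string) uses the current culture's collation rules, which can diverge from a byte-wise match. A small sketch of the ordinal variant (the header value here is made up):

```csharp
using System;

const string Prefix = "SHA-256=";
string digestHeader = "SHA-256=abc123";

// Ordinal comparison matches char-by-char, ignoring culture collation rules,
// which is what a protocol prefix check wants - and it is faster, too
bool hasPrefix = digestHeader.StartsWith(Prefix, StringComparison.Ordinal);
Console.WriteLine(hasPrefix); // True
```

The subsequent Slice/SequenceEqual check is already ordinal by nature, since it compares chars directly.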

{
throw new NatsObjException("SHA-256 digest mismatch");
}

if ($"SHA-256={digest}" != info.Digest)
{
throw new NatsObjException("SHA-256 digest mismatch");
}
if (chunks != info.Chunks)
{
throw new NatsObjException("Chunks mismatch");
}

if (chunks != info.Chunks)
{
throw new NatsObjException("Chunks mismatch");
if (size != info.Size)
{
throw new NatsObjException("Size mismatch");
}
}

if (size != info.Size)
finally
{
throw new NatsObjException("Size mismatch");
digest.Dispose();
}

return info;
@@ -223,117 +232,121 @@ public async ValueTask<ObjectMetadata> PutAsync(ObjectMetadata meta, Stream stre
var chunks = 0;
var chunkSize = meta.Options.MaxChunkSize.Value;

string digest;
using (var sha256 = SHA256.Create())
var digest = NatsMemoryOwner<char>.Empty;
try
{
using (var sha256 = SHA256.Create())
{
#if NETSTANDARD2_0
using (var hashedStream = new CryptoStream(stream, sha256, CryptoStreamMode.Read))
#else
await using (var hashedStream = new CryptoStream(stream, sha256, CryptoStreamMode.Read, leaveOpen))
#endif
{
while (true)
{
var memoryOwner = NatsMemoryOwner<byte>.Allocate(chunkSize);

var memory = memoryOwner.Memory;
var currentChunkSize = 0;
var eof = false;

// Fill a chunk
while (true)
{
#if NETSTANDARD2_0
int read;
if (MemoryMarshal.TryGetArray((ReadOnlyMemory<byte>)memory, out var segment) == false)
{
read = await hashedStream.ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken);
}
else
var memoryOwner = NatsMemoryOwner<byte>.Allocate(chunkSize);

var memory = memoryOwner.Memory;
var currentChunkSize = 0;
var eof = false;

// Fill a chunk
while (true)
{
var bytes = ArrayPool<byte>.Shared.Rent(memory.Length);
try
#if NETSTANDARD2_0
int read;
if (MemoryMarshal.TryGetArray((ReadOnlyMemory<byte>)memory, out var segment) == false)
{
segment = new ArraySegment<byte>(bytes, 0, memory.Length);
read = await hashedStream.ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken);
segment.Array.AsMemory(0, read).CopyTo(memory);
}
finally
else
{
ArrayPool<byte>.Shared.Return(bytes);
var bytes = ArrayPool<byte>.Shared.Rent(memory.Length);
try
{
segment = new ArraySegment<byte>(bytes, 0, memory.Length);
read = await hashedStream.ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken);
segment.Array.AsMemory(0, read).CopyTo(memory);
}
finally
{
ArrayPool<byte>.Shared.Return(bytes);
}
}
}

#else
var read = await hashedStream.ReadAsync(memory, cancellationToken);
#endif

// End of stream
if (read == 0)
{
eof = true;
break;
}
// End of stream
if (read == 0)
{
eof = true;
break;
}

memory = memory.Slice(read);
currentChunkSize += read;

// Chunk filled
if (memory.IsEmpty)
{
break;
}
}
}

if (currentChunkSize > 0)
{
size += currentChunkSize;
chunks++;
}

var buffer = memoryOwner.Slice(0, currentChunkSize);

// Chunks
var ack = await _context.PublishAsync(GetChunkSubject(nuid), buffer, serializer: NatsRawSerializer<NatsMemoryOwner<byte>>.Default, cancellationToken: cancellationToken);
ack.EnsureSuccess();

if (eof)
break;
}
}
}

if (sha256.Hash == null)
throw new NatsObjException("Can't compute SHA256 hash");

digest = Base64UrlEncoder.Encode(sha256.Hash);
}
digest = Base64UrlEncoder.EncodeToMemoryOwner(sha256.Hash);
}

meta.Chunks = chunks;
meta.Size = size;
meta.Digest = $"SHA-256={digest}";

// Metadata
await PublishMeta(meta, cancellationToken);

// Delete the old object
if (info?.Nuid != null && info.Nuid != nuid)
{
try
{
await _context.JSRequestResponseAsync<StreamPurgeRequest, StreamPurgeResponse>(
subject: $"{_context.Opts.Prefix}.STREAM.PURGE.OBJ_{Bucket}",
request: new StreamPurgeRequest { Filter = GetChunkSubject(info.Nuid), },
cancellationToken);
}
catch (NatsJSApiException e)
{
if (e.Error.Code != 404)
throw;
}
}
}
finally
{
digest.Dispose();
}

return meta;
}