Skip to content

Commit

Permalink
keep sub alive when reading channel
Browse files Browse the repository at this point in the history
Signed-off-by: Caleb Lloyd <caleblloyd@gmail.com>
  • Loading branch information
caleblloyd committed Jun 4, 2024
1 parent 177112b commit ef71ec9
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 50 deletions.
60 changes: 53 additions & 7 deletions src/NATS.Client.Core/Internal/ActivityEndingMsgReader.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading.Channels;

namespace NATS.Client.Core.Internal;
Expand All @@ -7,7 +8,13 @@ internal sealed class ActivityEndingMsgReader<T> : ChannelReader<NatsMsg<T>>
{
private readonly ChannelReader<NatsMsg<T>> _inner;

public ActivityEndingMsgReader(ChannelReader<NatsMsg<T>> inner) => _inner = inner;
private readonly INatsSub<T> _sub;

public ActivityEndingMsgReader(ChannelReader<NatsMsg<T>> inner, INatsSub<T> sub)
{
_inner = inner;
_sub = sub;
}

public override bool CanCount => _inner.CanCount;

Expand All @@ -25,17 +32,56 @@ public override bool TryRead(out NatsMsg<T> item)
return false;

item.Headers?.Activity?.Dispose();

GC.KeepAlive(_sub);
return true;
}

/// <inheritdoc/>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default) => _inner.WaitToReadAsync(cancellationToken);
public override async ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
{
var handle = GCHandle.Alloc(_sub);
try
{
return await _inner.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}
finally
{
handle.Free();
}
}

public override ValueTask<NatsMsg<T>> ReadAsync(CancellationToken cancellationToken = default) => _inner.ReadAsync(cancellationToken);
public override async ValueTask<NatsMsg<T>> ReadAsync(CancellationToken cancellationToken = default)
{
var handle = GCHandle.Alloc(_sub);
try
{
var item = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false);
item.Headers?.Activity?.Dispose();
return item;
}
finally
{
handle.Free();
}
}

public override bool TryPeek(out NatsMsg<T> item) => _inner.TryPeek(out item);

public override IAsyncEnumerable<NatsMsg<T>> ReadAllAsync(CancellationToken cancellationToken = default) => _inner.ReadAllAsync(cancellationToken);
public override async IAsyncEnumerable<NatsMsg<T>> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var handle = GCHandle.Alloc(_sub);
try
{
while (await _inner.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (TryRead(out var msg))
{
yield return msg;
}
}
}
finally
{
handle.Free();
}
}
}
3 changes: 1 addition & 2 deletions src/NATS.Client.Core/NatsConnection.Subscribe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ public async IAsyncEnumerable<NatsMsg<T>> SubscribeAsync<T>(string subject, stri
{
serializer ??= Opts.SerializerRegistry.GetDeserializer<T>();

// call to RegisterSubAnchor is no longer needed; sub is kept alive in ActivityEndingMsgReader
await using var sub = new NatsSub<T>(this, SubscriptionManager.GetManagerFor(subject), subject, queueGroup, opts, serializer, cancellationToken);
using var anchor = RegisterSubAnchor(sub);

await SubAsync(sub, cancellationToken: cancellationToken).ConfigureAwait(false);

// We don't cancel the channel reader here because we want to keep reading until the subscription
Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/NatsSub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ internal NatsSub(
connection.GetChannelOpts(connection.Opts, opts?.ChannelOpts),
msg => Connection.OnMessageDropped(this, _msgs?.Reader.Count ?? 0, msg));

Msgs = new ActivityEndingMsgReader<T>(_msgs.Reader);
Msgs = new ActivityEndingMsgReader<T>(_msgs.Reader, this);

Serializer = serializer;
}
Expand Down
131 changes: 91 additions & 40 deletions tests/NATS.Client.Core.MemoryTests/NatsSubTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Threading.Channels;
using JetBrains.dotMemoryUnit;
using NATS.Client.Core.Tests;

Expand Down Expand Up @@ -42,66 +43,116 @@ async Task Isolator()
}

[Test]
public void Subscription_should_not_be_collected_when_in_async_enumerator()
{
var server = NatsServer.Start();
try
public void Subscription_should_not_be_collected_subscribe_async() =>
RunSubTest(async (nats, channelWriter, iterations) =>
{
var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) });
var i = 0;
#pragma warning disable SA1312
await foreach (var _ in nats.SubscribeAsync<string>("foo.*"))
#pragma warning restore SA1312
{
await channelWriter.WriteAsync(new object());
if (++i >= iterations)
break;
}
});

[Test]
public void Subscription_should_not_be_collected_subscribe_core_async_read_all_async() =>
RunSubTest(async (nats, channelWriter, iterations) =>
{
var i = 0;
await using var sub = await nats.SubscribeCoreAsync<string>("foo.*");
#pragma warning disable SA1312
await foreach (var _ in sub.Msgs.ReadAllAsync())
#pragma warning restore SA1312
{
await channelWriter.WriteAsync(new object());
if (++i >= iterations)
break;
}
});

var sync = 0;
[Test]
public void Subscription_should_not_be_collected_subscribe_core_async_read_async() =>
RunSubTest(async (nats, channelWriter, iterations) =>
{
var i = 0;
await using var sub = await nats.SubscribeCoreAsync<string>("foo.*");
while (true)
{
await sub.Msgs.ReadAsync();
await channelWriter.WriteAsync(new object());
if (++i >= iterations)
break;
}
});

var sub = Task.Run(async () =>
[Test]
public void Subscription_should_not_be_collected_subscribe_core_async_wait_to_read_async() =>
RunSubTest(async (nats, channelWriter, iterations) =>
{
var i = 0;
await using var sub = await nats.SubscribeCoreAsync<string>("foo.*");
while (await sub.Msgs.WaitToReadAsync())
{
var count = 0;
await foreach (var msg in nats.SubscribeAsync<string>("foo.*"))
while (sub.Msgs.TryRead(out _))
{
if (msg.Subject == "foo.sync")
{
Interlocked.Increment(ref sync);
continue;
}
await channelWriter.WriteAsync(new object());
i++;
}

if (++count == 10)
break;
if (i >= iterations)
{
break;
}
});
}
});

var pub = Task.Run(async () =>
private void RunSubTest(Func<NatsConnection, ChannelWriter<object>, int, Task> subTask)
{
var server = NatsServer.Start();
try
{
const int iterations = 10;
var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) });
var received = Channel.CreateUnbounded<object>();
var task = subTask(nats, received.Writer, iterations);

var i = 0;
var fail = 0;
while (true)
{
while (Volatile.Read(ref sync) == 0)
nats.PublishAsync("foo.data", "data").AsTask().GetAwaiter().GetResult();
try
{
await nats.PublishAsync("foo.sync", "sync");
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100));
received.Reader.ReadAsync(cts.Token).AsTask().GetAwaiter().GetResult();
}

for (var i = 0; i < 10; i++)
catch (OperationCanceledException)
{
GC.Collect();

dotMemory.Check(memory =>
if (++fail <= 10)
{
var count = memory.GetObjects(where => where.Type.Is<NatsSub<string>>()).ObjectsCount;
Assert.That(count, Is.EqualTo(1), "Alive");
});
continue;
}

await nats.PublishAsync("foo.data", "data");
Assert.Fail($"failed to receive a reply 10 times");
}
});

var waitPub = Task.WaitAll(new[] { pub }, TimeSpan.FromSeconds(10));
if (!waitPub)
{
Assert.Fail("Timed out waiting for pub task to complete");
}
if (++i >= iterations)
break;

var waitSub = Task.WaitAll(new[] { sub }, TimeSpan.FromSeconds(10));
if (!waitSub)
{
Assert.Fail("Timed out waiting for sub task to complete");
GC.Collect();
dotMemory.Check(memory =>
{
var count = memory.GetObjects(where => where.Type.Is<NatsSub<string>>()).ObjectsCount;
Assert.That(count, Is.EqualTo(1), $"Alive - received {i}");
});
}

GC.Collect();
task.GetAwaiter().GetResult();

GC.Collect();
dotMemory.Check(memory =>
{
var count = memory.GetObjects(where => where.Type.Is<NatsSub<string>>()).ObjectsCount;
Expand Down

0 comments on commit ef71ec9

Please sign in to comment.