Skip to content

Commit

Permalink
Adds Keep Alive message to SSE (#6133)
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalSenn authored May 10, 2023
1 parent 3702e13 commit b2000ed
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,19 @@ public async IAsyncEnumerable<string> OnException(

throw new GraphQLException(ErrorBuilder.New().SetMessage("Foo").Build());
}

#pragma warning disable CS0618
[SubscribeAndResolve]
#pragma warning restore CS0618
public async IAsyncEnumerable<string> Delay(
[EnumeratorCancellation] CancellationToken cancellationToken,
int delay,
int count)
{
while (!cancellationToken.IsCancellationRequested && count-- > 0)
{
yield return "next";
await Task.Delay(delay, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -451,14 +451,59 @@ public async Task DefferedQuery_NoStreamableAcceptHeader()
{""errors"":[{""message"":""The specified operation kind is not allowed.""}]}");
}

[Fact]
public async Task EventStream_Sends_KeepAlive()
{
// arrange
var server = CreateStarWarsServer();
var client = server.CreateClient();
client.Timeout = TimeSpan.FromSeconds(30);

// act
using var request = new HttpRequestMessage(HttpMethod.Post, _url)
{
Content = JsonContent.Create(
new ClientQueryRequest { Query = "subscription {delay(count: 2, delay:15000)}" }),
Headers = { { "Accept", "text/event-stream" } }
};

using var response = await client.SendAsync(request, ResponseHeadersRead);

// assert
Snapshot
.Create()
.Add(response)
.MatchInline("""
Headers:
Cache-Control: no-cache
Content-Type: text/event-stream; charset=utf-8
-------------------------->
Status Code: OK
-------------------------->
event: next
data: {"data":{"delay":"next"}}
:
event: next
data: {"data":{"delay":"next"}}
:
event: complete
""");
}

private HttpClient GetClient(HttpTransportVersion serverTransportVersion)
{
var server = CreateStarWarsServer(
configureServices: s => s.AddHttpResponseFormatter(
new HttpResponseFormatterOptions
{
HttpTransportVersion = serverTransportVersion
}));
configureServices: s => s.AddHttpResponseFormatter(
new HttpResponseFormatterOptions
{
HttpTransportVersion = serverTransportVersion
}));

return server.CreateClient();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type Subscription {
onReview(episode: Episode!): Review!
onNext: String!
onException: String!
delay(delay: Int! count: Int!): String!
}

union SearchResult = Starship | Human | Droid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type Subscription {
onReview(episode: Episode!): Review!
onNext: String!
onException: String!
delay(delay: Int! count: Int!): String!
}

union SearchResult = Starship | Human | Droid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type Subscription {
onReview(episode: Episode!): Review!
onNext: String!
onException: String!
delay(delay: Int! count: Int!): String!
}

union SearchResult = Starship | Human | Droid
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.IO;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -15,19 +14,14 @@ namespace HotChocolate.Execution.Serialization;
/// </summary>
public sealed class EventStreamResultFormatter : IExecutionResultFormatter
{
private static readonly byte[] _eventField
= { (byte)'e', (byte)'v', (byte)'e', (byte)'n', (byte)'t', (byte)':', (byte)' ' };
private static readonly byte[] _dataField
= { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' };
private static readonly byte[] _nextEvent
= { (byte)'n', (byte)'e', (byte)'x', (byte)'t' };
private static readonly byte[] _completeEvent
=
{
(byte)'c', (byte)'o', (byte)'m', (byte)'p',
(byte)'l', (byte)'e', (byte)'t', (byte)'e'
};
private static readonly byte[] _newLine = { (byte)'\n' };
private static readonly TimeSpan _keepAliveTimeSpan = TimeSpan.FromSeconds(12);

private static readonly byte[] _eventField = "event: "u8.ToArray();
private static readonly byte[] _dataField = "data: "u8.ToArray();
private static readonly byte[] _nextEvent = "next"u8.ToArray();
private static readonly byte[] _keepAlive = ":\n\n"u8.ToArray();
private static readonly byte[] _completeEvent = "complete"u8.ToArray();
private static readonly byte[] _newLine = "\n"u8.ToArray();

private readonly JsonResultFormatter _payloadFormatter;
private readonly JsonWriterOptions _options;
Expand Down Expand Up @@ -70,50 +64,121 @@ private async ValueTask FormatInternalAsync(
{
if (result.Kind is SingleResult)
{
await WriteNextMessageAsync((IQueryResult)result, outputStream).ConfigureAwait(false);
await WriteNextMessageAsync((IQueryResult)result, outputStream, ct)
.ConfigureAwait(false);
await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);
await WriteCompleteMessage(outputStream).ConfigureAwait(false);
await WriteCompleteMessage(outputStream, ct).ConfigureAwait(false);
await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);
}
else if (result.Kind is DeferredResult or BatchResult or SubscriptionResult)
{
var responseStream = (IResponseStream)result;

await foreach (var queryResult in responseStream.ReadResultsAsync()
.WithCancellation(ct).ConfigureAwait(false))
{
try
{
await WriteNextMessageAsync(queryResult, outputStream)
.ConfigureAwait(false);
}
finally
{
await queryResult.DisposeAsync().ConfigureAwait(false);
}

await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);
}

await WriteCompleteMessage(outputStream).ConfigureAwait(false);
await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);
// synchronization of the output stream is required to ensure that the messages are not
// interleaved.
using var synchronization = new SemaphoreSlim(1, 1);

// we need to keep track if the stream is completed so that we can stop sending keep
// alive messages.
var completion = new TaskCompletionSource<bool>();

// we await all tasks so that we can catch all exceptions.
await Task.WhenAll(
ProcessResponseStreamAsync(
synchronization,
completion,
responseStream,
outputStream,
ct),
SendKeepAliveMessagesAsync(synchronization, completion, outputStream, ct));
}

else
{
throw new NotSupportedException();
}
}

private async ValueTask WriteNextMessageAsync(IQueryResult result, Stream outputStream)
private static async Task SendKeepAliveMessagesAsync(
SemaphoreSlim synchronization,
TaskCompletionSource<bool> completion,
Stream outputStream,
CancellationToken ct)
{
while (true)
{
await Task.WhenAny(Task.Delay(_keepAliveTimeSpan, ct), completion.Task);

if (!ct.IsCancellationRequested && !completion.Task.IsCompleted)
{
// we do not need try-finally here because we dispose the semaphore in the parent
// method.
await synchronization.WaitAsync(ct);

await WriteKeepAliveAndFlush(outputStream, ct);

synchronization.Release();
}
else
{
break;
}
}
}

private async Task ProcessResponseStreamAsync(
SemaphoreSlim synchronization,
TaskCompletionSource<bool> completion,
IResponseStream responseStream,
Stream outputStream,
CancellationToken ct)
{
await foreach (var queryResult in responseStream.ReadResultsAsync()
.WithCancellation(ct)
.ConfigureAwait(false))
{
// we do not need try-finally here because we dispose the semaphore in the parent
// method.

await synchronization.WaitAsync(ct);

try
{
await WriteNextMessageAsync(queryResult, outputStream, ct)
.ConfigureAwait(false);
}
finally
{
await queryResult.DisposeAsync().ConfigureAwait(false);
}

await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);

synchronization.Release();
}

await synchronization.WaitAsync(ct);

await WriteCompleteMessage(outputStream, ct).ConfigureAwait(false);
await WriteNewLineAndFlushAsync(outputStream, ct).ConfigureAwait(false);

synchronization.Release();
completion.SetResult(true);
}

private async ValueTask WriteNextMessageAsync(
IQueryResult result,
Stream outputStream,
CancellationToken ct)
{
#if NETCOREAPP3_1_OR_GREATER
await outputStream.WriteAsync(_eventField).ConfigureAwait(false);
await outputStream.WriteAsync(_nextEvent).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine).ConfigureAwait(false);
await outputStream.WriteAsync(_eventField, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_nextEvent, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, ct).ConfigureAwait(false);
#else
await outputStream.WriteAsync(_eventField, 0, _eventField.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_nextEvent, 0, _nextEvent.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_eventField, 0, _eventField.Length, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_nextEvent, 0, _nextEvent.Length, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length, ct).ConfigureAwait(false);
#endif

using var bufferWriter = new ArrayWriter();
Expand All @@ -132,30 +197,31 @@ private async ValueTask WriteNextMessageAsync(IQueryResult result, Stream output
}

#if NETCOREAPP3_1_OR_GREATER
await outputStream.WriteAsync(_dataField).ConfigureAwait(false);
await outputStream.WriteAsync(buffer).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine).ConfigureAwait(false);
await outputStream.WriteAsync(_dataField, ct).ConfigureAwait(false);
await outputStream.WriteAsync(buffer, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, ct).ConfigureAwait(false);
#else
await outputStream.WriteAsync(_dataField, 0, _dataField.Length).ConfigureAwait(false);
await outputStream.WriteAsync(bufferWriter.GetInternalBuffer(), read, buffer.Length)
await outputStream.WriteAsync(_dataField, 0, _dataField.Length, ct)
.ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length).ConfigureAwait(false);
await outputStream.WriteAsync(bufferWriter.GetInternalBuffer(), read, buffer.Length, ct)
.ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length, ct).ConfigureAwait(false);
#endif

read += buffer.Length + 1;
}
}

private static async ValueTask WriteCompleteMessage(Stream outputStream)
private static async ValueTask WriteCompleteMessage(Stream outputStream, CancellationToken ct)
{
#if NETCOREAPP3_1_OR_GREATER
await outputStream.WriteAsync(_eventField).ConfigureAwait(false);
await outputStream.WriteAsync(_completeEvent).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine).ConfigureAwait(false);
await outputStream.WriteAsync(_eventField, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_completeEvent, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, ct).ConfigureAwait(false);
#else
await outputStream.WriteAsync(_eventField, 0, _eventField.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_completeEvent, 0, _completeEvent.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length).ConfigureAwait(false);
await outputStream.WriteAsync(_eventField, 0, _eventField.Length, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_completeEvent, 0, _completeEvent.Length, ct).ConfigureAwait(false);
await outputStream.WriteAsync(_newLine, 0, _newLine.Length, ct).ConfigureAwait(false);
#endif
}

Expand All @@ -167,6 +233,18 @@ private static async ValueTask WriteNewLineAndFlushAsync(
await outputStream.WriteAsync(_newLine, ct).ConfigureAwait(false);
#else
await outputStream.WriteAsync(_newLine, 0, _newLine.Length, ct).ConfigureAwait(false);
#endif
await outputStream.FlushAsync(ct).ConfigureAwait(false);
}

private static async ValueTask WriteKeepAliveAndFlush(
Stream outputStream,
CancellationToken ct)
{
#if NETCOREAPP3_1_OR_GREATER
await outputStream.WriteAsync(_keepAlive, ct).ConfigureAwait(false);
#else
await outputStream.WriteAsync(_keepAlive, 0, _keepAlive.Length, ct).ConfigureAwait(false);
#endif
await outputStream.FlushAsync(ct).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Subscription {
onReview(episode: Episode!): Review!
onNext: String!
onException: String!
delay(delay: Int! count: Int!): String!
}

type Human implements Character {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Subscription {
onReview(episode: Episode!): Review!
onNext: String!
onException: String!
delay(delay: Int! count: Int!): String!
}

type Human implements Character {
Expand Down

0 comments on commit b2000ed

Please sign in to comment.