Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notify observer to complete async stream on network error. #5060

Merged
merged 10 commits into from
Nov 7, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ namespace StrawberryShake.Transport.WebSockets;
/// </summary>
public interface ISocketClient : IAsyncDisposable
{
/// <summary>
/// A even that is called when the message receiving cycle stoped
/// </summary>
event EventHandler ReceiveFinished;

/// <summary>
/// The URI where the socket should connect to
/// </summary>
Expand Down Expand Up @@ -70,4 +75,4 @@ Task CloseAsync(
string message,
SocketCloseStatus closeStatus,
CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using StrawberryShake.Transport.WebSockets.Messages;

namespace StrawberryShake.Transport.WebSockets;
Expand All @@ -18,4 +20,13 @@ public interface ISocketOperation : IAsyncDisposable
/// CReate an operation message stream.
/// </summary>
IAsyncEnumerable<OperationMessage> ReadAsync();

/// <summary>
/// Complete the operation
/// </summary>
/// <param name="cancellationToken">
/// A <see cref="CancellationToken"/> to cancel the completion
/// </param>
/// <returns>A task that is completed once the operation is completed</returns>
ValueTask CompleteAsync(CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,16 @@ Task StopOperationAsync(
/// </summary>
/// <param name="listener"></param>
void Unsubscribe(OnReceiveAsync listener);
}

/// <summary>
/// Notify the protocol to complete
/// </summary>
/// <param name="operationId">The id of the operation to stop</param>
/// <param name="cancellationToken">
/// A <see cref="CancellationToken"/> to cancel the notification
/// </param>
/// <returns>A task that is completed once the notification is completed</returns>
ValueTask NotifyCompletion(
string operationId,
CancellationToken cancellationToken);
}
18 changes: 18 additions & 0 deletions src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public Session(ISocketClient socketClient)
{
_socketClient = socketClient ??
throw new ArgumentNullException(nameof(socketClient));

_socketClient.ReceiveFinished += ReceiveFinishHandler;
}

/// <inheritdoc />
Expand Down Expand Up @@ -93,6 +95,21 @@ await socketProtocol
}
}

/// <inheritdoc />
private async ValueTask CompleteOperation(CancellationToken cancellationToken)
{
foreach (var operation in _operations)
{
await operation.Value.CompleteAsync(cancellationToken);
}
}

/// <inheritdoc />
private void ReceiveFinishHandler(object? sender, EventArgs args)
{
_ = CompleteOperation(default);
}

/// <summary>
/// Opens a session over the socket
/// </summary>
Expand Down Expand Up @@ -166,6 +183,7 @@ public async ValueTask DisposeAsync()
_operations.Clear();
}

_socketClient.ReceiveFinished -= ReceiveFinishHandler;
_socketProtocol?.Unsubscribe(ReceiveMessage);
await _socketClient.DisposeAsync();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ public SocketOperation(
public IAsyncEnumerable<OperationMessage> ReadAsync()
=> new MessageStream(this, _channel);

/// <inheritdoc />
public async ValueTask CompleteAsync(CancellationToken cancellationToken)
{
if (!_disposed)
{
try
{
await _channel.Writer.WriteAsync(CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException)
{
// if the channel is closed we will move on.
}
}
}

private sealed class MessageStream : IAsyncEnumerable<OperationMessage>
{
private readonly SocketOperation _operation;
Expand Down Expand Up @@ -113,4 +129,4 @@ public async ValueTask DisposeAsync()
_disposed = true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ protected async ValueTask Notify(
}
}

/// <inheritdoc />
public async ValueTask NotifyCompletion(
string operationId,
CancellationToken cancellationToken)
{
await Notify(operationId, CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public virtual ValueTask DisposeAsync()
{
Expand All @@ -78,4 +86,4 @@ public virtual ValueTask DisposeAsync()
_disposed = true;
return default;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ public sealed class WebSocketClient : IWebSocketClient
private readonly IReadOnlyList<ISocketProtocolFactory> _protocolFactories;
private readonly ClientWebSocket _socket;
private ISocketProtocol? _activeProtocol;
private bool _receiveFinishEventTriggered = false;
private bool _disposed;

/// <inheritdoc />
public event EventHandler ReceiveFinished = default!;

/// <summary>
/// Creates a new instance of <see cref="WebSocketClient"/>
/// </summary>
Expand Down Expand Up @@ -52,10 +56,22 @@ public WebSocketClient(
public string Name { get; }

/// <inheritdoc />
public bool IsClosed =>
_disposed
|| _socket.CloseStatus.HasValue
|| _socket.State == WebSocketState.Aborted;
public bool IsClosed
{
get
{
var closed = _disposed
|| _socket.CloseStatus.HasValue
|| _socket.State == WebSocketState.Aborted;

if (closed && !_receiveFinishEventTriggered)
{
_receiveFinishEventTriggered = true;
ReceiveFinished?.Invoke(this, EventArgs.Empty);
}
return closed;
}
}

/// <inheritdoc />
public WebSocket Socket => _socket;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public sealed class SocketClientStub : ISocketClient
new(TaskCreationOptions.None);
private bool _isClosed = true;

public event EventHandler ReceiveFinished = default!;

public SemaphoreSlim Blocker { get; } = new(0);

public Uri? Uri { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Net.WebSockets;
using System.Reflection;
using System.Threading.Tasks;
using HotChocolate.AspNetCore.Tests.Utilities;
using HotChocolate.StarWars.Models;
using HotChocolate.Subscriptions;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using StrawberryShake.Transport.WebSockets;
using StrawberryShake.Transport.WebSockets.Protocols;
using Xunit;

namespace StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsOnReviewSubCompletion
Expand Down Expand Up @@ -63,11 +66,128 @@ public async Task Watch_StarWarsOnReviewSubCompletion_Test()
{
await Task.Delay(1_000);
}

// assert
Assert.True(commentary is not null && completionTriggered);

session.Dispose();
}

[Fact]
public async Task Watch_StarWarsOnReviewSubCompletionPassively_Test()
{
// arrange
using IWebHost host = TestServerHelper.CreateServer(
_ => { },
out var port);
var topicEventSender = host.Services.GetRequiredService<ITopicEventSender>();

var serviceCollection = new ServiceCollection();
serviceCollection.AddStarWarsOnReviewSubCompletionClient(
profile: StarWarsOnReviewSubCompletionClientProfileKind.Default)
.ConfigureHttpClient(
c => c.BaseAddress = new Uri("http://localhost:" + port + "/graphql"))
.ConfigureWebSocketClient(
c => c.Uri = new Uri("ws://localhost:" + port + "/graphql"));

serviceCollection.AddSingleton<SubscriptionSocketStateMonitor>();

// act
IServiceProvider services = serviceCollection.BuildServiceProvider();
IStarWarsOnReviewSubCompletionClient client = services.GetRequiredService<IStarWarsOnReviewSubCompletionClient>();

string? commentary = null;
bool completionTriggered = false;

var sub = client.OnReviewSub.Watch();
var session = sub.Subscribe(
result => commentary = result.Data?.OnReview?.Commentary,
() => completionTriggered = true);

var topic = Episode.NewHope;

// try to send message 10 times
// make sure the subscription connection is successful
for (int times = 0; commentary is null && times < 10; times++)
{
await topicEventSender.SendAsync(topic, new Review { Stars = 1, Commentary = "Commentary" });
await Task.Delay(1_000);
}

// simulate network error
var monitor = services.GetRequiredService<SubscriptionSocketStateMonitor>();
monitor.AbortSocket();

//await host.StopAsync();

// waiting for completion message sent
for (int times = 0; !completionTriggered && times < 10; times++)
{
await Task.Delay(1_000);
}

// assert
Assert.True(commentary is not null && completionTriggered);

session.Dispose();
}
}

public class SubscriptionSocketStateMonitor
{
private const BindingFlags _bindingFlags = BindingFlags.NonPublic | BindingFlags.Instance;

private readonly ISessionPool _sessionPool;
private readonly Type _sessionPoolType;
private readonly FieldInfo _sessionsField;

private readonly FieldInfo _socketOperationsDictionaryField = typeof(Session).GetField("_operations", _bindingFlags)!;
private readonly FieldInfo _socketOperationManagerField = typeof(SocketOperation).GetField("_manager", _bindingFlags)!;
private readonly FieldInfo _socketProtocolField = typeof(Session)!.GetField("_socketProtocol", _bindingFlags)!;
private readonly FieldInfo _protocolReceiverField = typeof(GraphQLWebSocketProtocol).GetField("_receiver", _bindingFlags)!;

private Type? _sessionInfoType;
private PropertyInfo? _sessionProperty;
private Type? _receiverType;
private FieldInfo? _receiverClientField;

public SubscriptionSocketStateMonitor(ISessionPool sessionPool)
{
_sessionPool = sessionPool;
_sessionPoolType = _sessionPool.GetType();
_sessionsField = _sessionPoolType.GetField("_sessions", _bindingFlags)!;
}

public void AbortSocket()
{
var sessionInfos = (_sessionsField!.GetValue(_sessionPool) as System.Collections.IDictionary)!.Values;

foreach (var sessionInfo in sessionInfos)
{
_sessionInfoType ??= sessionInfo.GetType();
_sessionProperty ??= _sessionInfoType.GetProperty("Session")!;
var session = _sessionProperty.GetValue(sessionInfo) as Session;
var socketOperations = _socketOperationsDictionaryField
.GetValue(session) as System.Collections.Concurrent.ConcurrentDictionary<string, SocketOperation>;

foreach (var operation in socketOperations!)
{
var operationsession = _socketOperationManagerField.GetValue(operation.Value) as Session;
var protocol = _socketProtocolField.GetValue(operationsession) as GraphQLWebSocketProtocol;

var receiver = _protocolReceiverField.GetValue(protocol)!;

_receiverType ??= receiver.GetType();
_receiverClientField ??= _receiverType.GetField("_client", _bindingFlags)!;
var client = _receiverClientField.GetValue(receiver) as ISocketClient;

if (client!.IsClosed is false && client is WebSocketClient webSocketClient)
{
var socket = typeof(WebSocketClient).GetField("_socket", _bindingFlags)!.GetValue(webSocketClient) as ClientWebSocket;
socket!.Abort();
}
}
}
}
}
}