Skip to content

Commit

Permalink
Notify observer to complete async stream on network error. (#5060)
Browse files Browse the repository at this point in the history
  • Loading branch information
CoreDX9 authored Nov 7, 2022
1 parent c00ed5f commit 10263bb
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ namespace StrawberryShake.Transport.WebSockets;
/// </summary>
public interface ISocketClient : IAsyncDisposable
{
/// <summary>
/// An event that is called when the message receiving cycle stoped
/// </summary>
event EventHandler ReceiveFinished;

/// <summary>
/// The URI where the socket should connect to
/// </summary>
Expand Down Expand Up @@ -70,4 +75,4 @@ Task CloseAsync(
string message,
SocketCloseStatus closeStatus,
CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using StrawberryShake.Transport.WebSockets.Messages;

namespace StrawberryShake.Transport.WebSockets;
Expand All @@ -18,4 +20,13 @@ public interface ISocketOperation : IAsyncDisposable
/// CReate an operation message stream.
/// </summary>
IAsyncEnumerable<OperationMessage> ReadAsync();

/// <summary>
/// Complete the operation
/// </summary>
/// <param name="cancellationToken">
/// A <see cref="CancellationToken"/> to cancel the completion
/// </param>
/// <returns>A task that is completed once the operation is completed</returns>
ValueTask CompleteAsync(CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,16 @@ Task StopOperationAsync(
/// </summary>
/// <param name="listener"></param>
void Unsubscribe(OnReceiveAsync listener);
}

/// <summary>
/// Notify the protocol to complete
/// </summary>
/// <param name="operationId">The id of the operation to stop</param>
/// <param name="cancellationToken">
/// A <see cref="CancellationToken"/> to cancel the notification
/// </param>
/// <returns>A task that is completed once the notification is completed</returns>
ValueTask NotifyCompletion(
string operationId,
CancellationToken cancellationToken);
}
18 changes: 18 additions & 0 deletions src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public Session(ISocketClient socketClient)
{
_socketClient = socketClient ??
throw new ArgumentNullException(nameof(socketClient));

_socketClient.ReceiveFinished += ReceiveFinishHandler;
}

/// <inheritdoc />
Expand Down Expand Up @@ -93,6 +95,21 @@ await socketProtocol
}
}

/// <inheritdoc />
private async ValueTask CompleteOperation(CancellationToken cancellationToken)
{
foreach (var operation in _operations)
{
await operation.Value.CompleteAsync(cancellationToken);
}
}

/// <inheritdoc />
private void ReceiveFinishHandler(object? sender, EventArgs args)
{
_ = CompleteOperation(default);
}

/// <summary>
/// Opens a session over the socket
/// </summary>
Expand Down Expand Up @@ -166,6 +183,7 @@ public async ValueTask DisposeAsync()
_operations.Clear();
}

_socketClient.ReceiveFinished -= ReceiveFinishHandler;
_socketProtocol?.Unsubscribe(ReceiveMessage);
await _socketClient.DisposeAsync();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ public SocketOperation(
public IAsyncEnumerable<OperationMessage> ReadAsync()
=> new MessageStream(this, _channel);

/// <inheritdoc />
public async ValueTask CompleteAsync(CancellationToken cancellationToken)
{
if (!_disposed)
{
try
{
await _channel.Writer.WriteAsync(CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException)
{
// if the channel is closed we will move on.
}
}
}

private sealed class MessageStream : IAsyncEnumerable<OperationMessage>
{
private readonly SocketOperation _operation;
Expand Down Expand Up @@ -113,4 +129,4 @@ public async ValueTask DisposeAsync()
_disposed = true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ protected async ValueTask Notify(
}
}

/// <inheritdoc />
public async ValueTask NotifyCompletion(
string operationId,
CancellationToken cancellationToken)
{
await Notify(operationId, CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public virtual ValueTask DisposeAsync()
{
Expand All @@ -78,4 +86,4 @@ public virtual ValueTask DisposeAsync()
_disposed = true;
return default;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ public sealed class WebSocketClient : IWebSocketClient
private readonly IReadOnlyList<ISocketProtocolFactory> _protocolFactories;
private readonly ClientWebSocket _socket;
private ISocketProtocol? _activeProtocol;
private bool _receiveFinishEventTriggered = false;
private bool _disposed;

/// <inheritdoc />
public event EventHandler ReceiveFinished = default!;

/// <summary>
/// Creates a new instance of <see cref="WebSocketClient"/>
/// </summary>
Expand Down Expand Up @@ -52,10 +56,22 @@ public WebSocketClient(
public string Name { get; }

/// <inheritdoc />
public bool IsClosed =>
_disposed
|| _socket.CloseStatus.HasValue
|| _socket.State == WebSocketState.Aborted;
public bool IsClosed
{
get
{
var closed = _disposed
|| _socket.CloseStatus.HasValue
|| _socket.State == WebSocketState.Aborted;

if (closed && !_receiveFinishEventTriggered)
{
_receiveFinishEventTriggered = true;
ReceiveFinished?.Invoke(this, EventArgs.Empty);
}
return closed;
}
}

/// <inheritdoc />
public WebSocket Socket => _socket;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public sealed class SocketClientStub : ISocketClient
new(TaskCreationOptions.None);
private bool _isClosed = true;

public event EventHandler ReceiveFinished = default!;

public SemaphoreSlim Blocker { get; } = new(0);

public Uri? Uri { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Net.WebSockets;
using System.Reflection;
using System.Threading.Tasks;
using HotChocolate.AspNetCore.Tests.Utilities;
using HotChocolate.StarWars.Models;
using HotChocolate.Subscriptions;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using StrawberryShake.Transport.WebSockets;
using StrawberryShake.Transport.WebSockets.Protocols;
using Xunit;

namespace StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsOnReviewSubCompletion
Expand Down Expand Up @@ -63,11 +66,128 @@ public async Task Watch_StarWarsOnReviewSubCompletion_Test()
{
await Task.Delay(1_000);
}

// assert
Assert.True(commentary is not null && completionTriggered);

session.Dispose();
}

[Fact]
public async Task Watch_StarWarsOnReviewSubCompletionPassively_Test()
{
// arrange
using IWebHost host = TestServerHelper.CreateServer(
_ => { },
out var port);
var topicEventSender = host.Services.GetRequiredService<ITopicEventSender>();

var serviceCollection = new ServiceCollection();
serviceCollection.AddStarWarsOnReviewSubCompletionClient(
profile: StarWarsOnReviewSubCompletionClientProfileKind.Default)
.ConfigureHttpClient(
c => c.BaseAddress = new Uri("http://localhost:" + port + "/graphql"))
.ConfigureWebSocketClient(
c => c.Uri = new Uri("ws://localhost:" + port + "/graphql"));

serviceCollection.AddSingleton<SubscriptionSocketStateMonitor>();

// act
IServiceProvider services = serviceCollection.BuildServiceProvider();
IStarWarsOnReviewSubCompletionClient client = services.GetRequiredService<IStarWarsOnReviewSubCompletionClient>();

string? commentary = null;
bool completionTriggered = false;

var sub = client.OnReviewSub.Watch();
var session = sub.Subscribe(
result => commentary = result.Data?.OnReview?.Commentary,
() => completionTriggered = true);

var topic = Episode.NewHope;

// try to send message 10 times
// make sure the subscription connection is successful
for (int times = 0; commentary is null && times < 10; times++)
{
await topicEventSender.SendAsync(topic, new Review { Stars = 1, Commentary = "Commentary" });
await Task.Delay(1_000);
}

// simulate network error
var monitor = services.GetRequiredService<SubscriptionSocketStateMonitor>();
monitor.AbortSocket();

//await host.StopAsync();

// waiting for completion message sent
for (int times = 0; !completionTriggered && times < 10; times++)
{
await Task.Delay(1_000);
}

// assert
Assert.True(commentary is not null && completionTriggered);

session.Dispose();
}
}

public class SubscriptionSocketStateMonitor
{
private const BindingFlags _bindingFlags = BindingFlags.NonPublic | BindingFlags.Instance;

private readonly ISessionPool _sessionPool;
private readonly Type _sessionPoolType;
private readonly FieldInfo _sessionsField;

private readonly FieldInfo _socketOperationsDictionaryField = typeof(Session).GetField("_operations", _bindingFlags)!;
private readonly FieldInfo _socketOperationManagerField = typeof(SocketOperation).GetField("_manager", _bindingFlags)!;
private readonly FieldInfo _socketProtocolField = typeof(Session)!.GetField("_socketProtocol", _bindingFlags)!;
private readonly FieldInfo _protocolReceiverField = typeof(GraphQLWebSocketProtocol).GetField("_receiver", _bindingFlags)!;

private Type? _sessionInfoType;
private PropertyInfo? _sessionProperty;
private Type? _receiverType;
private FieldInfo? _receiverClientField;

public SubscriptionSocketStateMonitor(ISessionPool sessionPool)
{
_sessionPool = sessionPool;
_sessionPoolType = _sessionPool.GetType();
_sessionsField = _sessionPoolType.GetField("_sessions", _bindingFlags)!;
}

public void AbortSocket()
{
var sessionInfos = (_sessionsField!.GetValue(_sessionPool) as System.Collections.IDictionary)!.Values;

foreach (var sessionInfo in sessionInfos)
{
_sessionInfoType ??= sessionInfo.GetType();
_sessionProperty ??= _sessionInfoType.GetProperty("Session")!;
var session = _sessionProperty.GetValue(sessionInfo) as Session;
var socketOperations = _socketOperationsDictionaryField
.GetValue(session) as System.Collections.Concurrent.ConcurrentDictionary<string, SocketOperation>;

foreach (var operation in socketOperations!)
{
var operationsession = _socketOperationManagerField.GetValue(operation.Value) as Session;
var protocol = _socketProtocolField.GetValue(operationsession) as GraphQLWebSocketProtocol;

var receiver = _protocolReceiverField.GetValue(protocol)!;

_receiverType ??= receiver.GetType();
_receiverClientField ??= _receiverType.GetField("_client", _bindingFlags)!;
var client = _receiverClientField.GetValue(receiver) as ISocketClient;

if (client!.IsClosed is false && client is WebSocketClient webSocketClient)
{
var socket = typeof(WebSocketClient).GetField("_socket", _bindingFlags)!.GetValue(webSocketClient) as ClientWebSocket;
socket!.Abort();
}
}
}
}
}
}

0 comments on commit 10263bb

Please sign in to comment.