Skip to content

Commit

Permalink
Improved cancellation (#186)
Browse files Browse the repository at this point in the history
- Allow cancellation during connection attempt
- Don't display `OperationCancelledException` as error
- Some cleanup
  • Loading branch information
ShortDevelopment authored Jan 1, 2025
1 parent e071cd5 commit c29187b
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup Label="Language">
<LangVersion>latest</LangVersion>
<LangVersion>preview</LangVersion>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ public sealed class NearShareSender(ConnectedDevicesPlatform platform)

async Task<SenderStateMachine> PrepareTransferInternalAsync(EndpointInfo endpoint, CancellationToken cancellationToken)
{
var session = await Platform.ConnectAsync(endpoint, options: new() { TransportUpgraded = TransportUpgraded });
var session = await Platform.ConnectAsync(endpoint, options: new() { TransportUpgraded = TransportUpgraded }, cancellationToken);

Guid operationId = Guid.NewGuid();

HandshakeHandler handshake = new(Platform);
using var handShakeChannel = await session.StartClientChannelAsync(NearShareHandshakeApp.Id, NearShareHandshakeApp.Name, handshake, cancellationToken);
using var handShakeChannel = await session.StartClientChannelAsync(handshake, cancellationToken);
var handshakeResultMsg = await handshake.Execute(operationId);

// ToDo: CorrelationVector
Expand All @@ -49,8 +49,11 @@ public async Task SendFilesAsync(CdpDevice device, IReadOnlyList<CdpFileProvider
await senderStateMachine.SendFilesAsync(files, progress, cancellationToken);
}

sealed class HandshakeHandler(ConnectedDevicesPlatform cdp) : CdpAppBase(cdp)
sealed class HandshakeHandler(ConnectedDevicesPlatform cdp) : CdpAppBase(cdp), ICdpAppId
{
public static string Id { get; } = NearShareHandshakeApp.Id;
public static string Name { get; } = NearShareHandshakeApp.Name;

readonly TaskCompletionSource<CdpMessage> _promise = new();

public Task<CdpMessage> Execute(Guid operationId)
Expand Down
9 changes: 6 additions & 3 deletions lib/ShortDev.Microsoft.ConnectedDevices/CdpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal static CdpSession GetOrCreate(ConnectedDevicesPlatform platform, Endpoi
), out _);
}

internal static async Task<CdpSession> ConnectClientAsync(ConnectedDevicesPlatform platform, CdpSocket socket, ConnectOptions? options = null)
internal static async Task<CdpSession> ConnectClientAsync(ConnectedDevicesPlatform platform, CdpSocket socket, ConnectOptions? options = null, CancellationToken cancellationToken = default)
{
var session = _sessionRegistry.Create(localSessionId => new(
platform,
Expand All @@ -85,7 +85,7 @@ internal static async Task<CdpSession> ConnectClientAsync(ConnectedDevicesPlatfo
if (options is not null)
connectHandler.UpgradeHandler.Upgraded += options.TransportUpgraded;

await connectHandler.ConnectAsync(socket);
await connectHandler.ConnectAsync(socket, cancellationToken: cancellationToken);

return session;
}
Expand Down Expand Up @@ -178,12 +178,15 @@ void HandleSession(CommonHeader header, ref EndianReader reader)
}
#endregion

public Task<CdpChannel> StartClientChannelAsync<TApp>(TApp handler, CancellationToken cancellationToken = default) where TApp : CdpAppBase, ICdpAppId
=> StartClientChannelAsync(TApp.Id, TApp.Name, handler, cancellationToken);

public async Task<CdpChannel> StartClientChannelAsync(string appId, string appName, CdpAppBase handler, CancellationToken cancellationToken = default)
{
if (_channelHandler is not ClientChannelHandler clientChannelHandler)
throw new InvalidOperationException("Session is not a client");

var socket = await Platform.CreateSocketAsync(_connectHandler.UpgradeHandler.RemoteEndpoint);
var socket = await Platform.CreateSocketAsync(_connectHandler.UpgradeHandler.RemoteEndpoint, cancellationToken);
return await clientChannelHandler.CreateChannelAsync(appId, appName, handler, socket, cancellationToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ await Task.WhenAll(_transportMap.Values
}
}

public async Task<CdpSession> ConnectAsync([NotNull] EndpointInfo endpoint, ConnectOptions? options = null)
public async Task<CdpSession> ConnectAsync([NotNull] EndpointInfo endpoint, ConnectOptions? options = null, CancellationToken cancellationToken = default)
{
var socket = await CreateSocketAsync(endpoint).ConfigureAwait(false);
return await CdpSession.ConnectClientAsync(this, socket, options).ConfigureAwait(false);
var socket = await CreateSocketAsync(endpoint, cancellationToken).ConfigureAwait(false);
return await CdpSession.ConnectClientAsync(this, socket, options, cancellationToken).ConfigureAwait(false);
}

internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint)
internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint, CancellationToken cancellationToken = default)
{
if (TryGetKnownSocket(endpoint, out var knownSocket))
return knownSocket;

var transport = TryGetTransport(endpoint.TransportType) ?? throw new InvalidOperationException($"No single transport found for type {endpoint.TransportType}");
var socket = await transport.ConnectAsync(endpoint).ConfigureAwait(false);
var socket = await transport.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false);
ReceiveLoop(socket);
return socket;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
using ShortDev.Microsoft.ConnectedDevices.Messages.Connection.DeviceInfo;
using ShortDev.Microsoft.ConnectedDevices.Session.Upgrade;
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Runtime.CompilerServices;

namespace ShortDev.Microsoft.ConnectedDevices.Session.Connection;
internal sealed class ClientConnectHandler(CdpSession session, ClientUpgradeHandler upgradeHandler) : ConnectHandler(session, upgradeHandler)
{
readonly ClientUpgradeHandler _clientUpgradeHandler = upgradeHandler;
readonly ILogger _logger = session.Platform.CreateLogger<ClientConnectHandler>();

TaskCompletionSource? _promise;
public Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true)
ConnectionTask? _promise;
public async Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true, CancellationToken cancellationToken = default)
{
if (_promise != null)
throw new InvalidOperationException("Already connecting");
cancellationToken.ThrowIfCancellationRequested();

_promise = new();
if (Interlocked.CompareExchange(ref _promise, new(cancellationToken), null) is not null)
throw new InvalidOperationException("Already connecting");

CommonHeader header = new()
{
Expand Down Expand Up @@ -52,11 +53,14 @@ public Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true)

_session.SendMessage(socket, header, writer);

return _promise.Task;
await _promise;
}

protected override void HandleMessageInternal(CdpSocket socket, CommonHeader header, ConnectionHeader connectionHeader, ref EndianReader reader)
{
if (_promise?.CancellationToken.IsCancellationRequested == true)
return;

if (connectionHeader.MessageType == ConnectionType.ConnectResponse)
{
if (_session.Cryptor != null)
Expand Down Expand Up @@ -135,7 +139,7 @@ async void PrepareSession(CdpSocket socket)
try
{
var oldSocket = socket;
socket = await _clientUpgradeHandler.RequestUpgradeAsync(oldSocket);
socket = await _clientUpgradeHandler.UpgradeAsync(oldSocket);
oldSocket.Dispose();
}
catch (Exception ex)
Expand All @@ -146,26 +150,21 @@ async void PrepareSession(CdpSocket socket)

try
{
SendAuthDone(socket);
EndianWriter writer = new(Endianness.BigEndian);
new ConnectionHeader()
{
ConnectionMode = ConnectionMode.Proximal,
MessageType = ConnectionType.AuthDoneRequest
}.Write(writer);

header.Flags = 0;
_session.SendMessage(socket, header, writer);
}
catch (Exception ex)
{
_promise?.TrySetException(ex);
}
}

void SendAuthDone(CdpSocket socket)
{
EndianWriter writer = new(Endianness.BigEndian);
new ConnectionHeader()
{
ConnectionMode = ConnectionMode.Proximal,
MessageType = ConnectionType.AuthDoneRequest
}.Write(writer);

header.Flags = 0;
_session.SendMessage(socket, header, writer);
}
}

void HandleDeviceAuthResponse(CdpSocket socket, CommonHeader header)
Expand Down Expand Up @@ -212,4 +211,26 @@ void HandleAuthDoneResponse(CdpSocket socket, ref EndianReader reader)

_promise?.TrySetResult();
}

sealed class ConnectionTask
{
readonly TaskCompletionSource _promise = new();

public CancellationToken CancellationToken { get; }
public ConnectionTask(CancellationToken cancellationToken)
{
CancellationToken = cancellationToken;

cancellationToken.Register(() => _promise.TrySetCanceled());
}

public void TrySetResult()
=> _promise.TrySetResult();

public void TrySetException(Exception ex)
=> _promise.TrySetException(ex);

public TaskAwaiter GetAwaiter()
=> _promise.Task.GetAwaiter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,15 @@ protected override bool TryHandleConnectInternal(CdpSocket socket, ConnectionHea
static readonly IReadOnlyList<EndpointMetadata> UpgradeEndpoints = [EndpointMetadata.Tcp];

UpgradeInstance? _currentUpgrade;
public async ValueTask<CdpSocket> RequestUpgradeAsync(CdpSocket oldSocket)
public async ValueTask<CdpSocket> UpgradeAsync(CdpSocket oldSocket)
{
if (_currentUpgrade != null)
if (Interlocked.CompareExchange(ref _currentUpgrade, new(), null) is not null)
throw new InvalidOperationException("Only a single upgrade may occur at the same time");

_currentUpgrade = new();
try
{
_logger.SendingUpgradeRequest(_currentUpgrade.Id, UpgradeEndpoints);
SendUpgradeRequest(oldSocket, _currentUpgrade.Id, UpgradeEndpoints);
return await _currentUpgrade.Promise.Task;
}
finally
{
_currentUpgrade = null;
}

void SendUpgradeRequest(CdpSocket socket, Guid upgradeId, IReadOnlyList<EndpointMetadata> endpoints)
{
CommonHeader header = new()
{
Type = MessageType.Connect
Expand All @@ -73,11 +63,17 @@ void SendUpgradeRequest(CdpSocket socket, Guid upgradeId, IReadOnlyList<Endpoint

new UpgradeRequest()
{
UpgradeId = upgradeId,
Endpoints = endpoints
UpgradeId = _currentUpgrade.Id,
Endpoints = UpgradeEndpoints
}.Write(writer);

_session.SendMessage(socket, header, writer);
_session.SendMessage(oldSocket, header, writer);

return await _currentUpgrade;
}
finally
{
_currentUpgrade = null;
}
}

Expand Down Expand Up @@ -107,19 +103,15 @@ async void FindNewEndpoint()
if (_currentUpgrade == null)
return;

_currentUpgrade.NewSocket = tasks.FirstOrDefault(x => x != null);
if (_currentUpgrade.NewSocket == null)
{
_currentUpgrade.Promise.TrySetCanceled();
if (!_currentUpgrade.TryChooseSocket(tasks.FirstOrDefault(x => x != null)))
return;
}

SendUpgradFinalization(oldSocket);

// Cancel after timeout if upgrade has not finished yet
await Task.Delay(UpgradeInstance.Timeout);

_currentUpgrade?.Promise.TrySetCanceled();
_currentUpgrade?.TrySetCanceled();
}
}

Expand Down Expand Up @@ -176,7 +168,7 @@ void HandleUpgradeFailure(ref EndianReader reader)
{
var msg = HResultPayload.Parse(ref reader);

_currentUpgrade?.Promise.TrySetException(
_currentUpgrade?.TrySetException(
new Exception($"Transport upgrade failed with HResult {msg.HResult} (hresult: {HResultPayload.HResultToString(msg.HResult)}, errorCode: {HResultPayload.ErrorCodeToString(msg.HResult)})")
);
}
Expand All @@ -195,19 +187,41 @@ void HandleTransportConfirmation(CdpSocket socket, ref EndianReader reader)
RemoteEndpoint = socket.Endpoint;

// Complete promise
_currentUpgrade.Promise.TrySetResult(socket);
_currentUpgrade.TrySetResult(socket);
}

sealed class UpgradeInstance
{
public static readonly TimeSpan Timeout = TimeSpan.FromSeconds(2);

public Guid Id { get; } = Guid.NewGuid();
public TaskCompletionSource<CdpSocket> Promise { get; } = new();

readonly TaskCompletionSource<CdpSocket> _promise = new();
public bool TrySetCanceled()
=> _promise.TrySetCanceled();

public bool TrySetResult(CdpSocket socket)
=> _promise.TrySetResult(socket);

public bool TrySetException(Exception ex)
=> _promise.TrySetException(ex);

public TaskAwaiter<CdpSocket> GetAwaiter()
=> Promise.Task.GetAwaiter();
=> _promise.Task.GetAwaiter();

CdpSocket? _newSocket;
public bool TryChooseSocket(CdpSocket? newSocket)
{
if (newSocket is null)
{
_promise.TrySetCanceled();
return false;
}

return Interlocked.CompareExchange(ref _newSocket, newSocket, null) is null;
}

public CdpSocket? NewSocket { get; set; }
public CdpSocket? NewSocket
=> _newSocket;
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
namespace ShortDev.Microsoft.ConnectedDevices.Transports.Bluetooth;
public sealed class BluetoothTransport(IBluetoothHandler handler) : ICdpTransport, ICdpDiscoverableTransport
{
public CdpTransportType TransportType { get; } = CdpTransportType.Rfcomm;
readonly IBluetoothHandler _handler = handler;

public IBluetoothHandler Handler { get; } = handler;
public CdpTransportType TransportType { get; } = CdpTransportType.Rfcomm;
public bool IsEnabled => _handler.IsEnabled;

public event DeviceConnectedEventHandler? DeviceConnected;
public async Task Listen(CancellationToken cancellationToken)
{
await Handler.ListenRfcommAsync(
await _handler.ListenRfcommAsync(
new RfcommOptions()
{
ServiceId = Constants.RfcommServiceId,
Expand All @@ -19,21 +20,21 @@ await Handler.ListenRfcommAsync(
);
}

public async Task<CdpSocket> ConnectAsync(EndpointInfo endpoint)
=> await Handler.ConnectRfcommAsync(endpoint, new RfcommOptions()
public async Task<CdpSocket> ConnectAsync(EndpointInfo endpoint, CancellationToken cancellationToken = default)
=> await _handler.ConnectRfcommAsync(endpoint, new RfcommOptions()
{
ServiceId = Constants.RfcommServiceId,
ServiceName = Constants.RfcommServiceName,
SocketConnected = (socket) => DeviceConnected?.Invoke(this, socket)
});
}, cancellationToken);

public async Task Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken)
{
await Handler.AdvertiseBLeBeaconAsync(
await _handler.AdvertiseBLeBeaconAsync(
new AdvertiseOptions()
{
ManufacturerId = Constants.BLeBeaconManufacturerId,
BeaconData = new BLeBeacon(deviceInfo.Type, Handler.MacAddress, deviceInfo.Name)
BeaconData = new BLeBeacon(deviceInfo.Type, _handler.MacAddress, deviceInfo.Name)
},
cancellationToken
);
Expand All @@ -42,7 +43,7 @@ await Handler.AdvertiseBLeBeaconAsync(
public event DeviceDiscoveredEventHandler? DeviceDiscovered;
public async Task Discover(CancellationToken cancellationToken)
{
await Handler.ScanBLeAsync(new()
await _handler.ScanBLeAsync(new()
{
OnDeviceDiscovered = (advertisement, rssi) =>
{
Expand All @@ -65,5 +66,5 @@ public void Dispose()
}

public EndpointInfo GetEndpoint()
=> new(TransportType, Handler.MacAddress.ToStringFormatted(), Constants.RfcommServiceId);
=> new(TransportType, _handler.MacAddress.ToStringFormatted(), Constants.RfcommServiceId);
}
Loading

0 comments on commit c29187b

Please sign in to comment.