diff --git a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs
index 71098e62..6a9072f5 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs
@@ -111,6 +111,7 @@ public class Channel : IDisposableObservable, IDuplexPipe
///
/// The to use to get data to be transmitted over the .
+ /// Any errors passed to this are transmitted to the remote side.
///
private PipeReader? mxStreamIOReader;
@@ -142,6 +143,11 @@ public class Channel : IDisposableObservable, IDuplexPipe
///
private bool? existingPipeGiven;
+ ///
+ /// A value indicating whether this received an error from a remote party in .
+ ///
+ private bool receivedRemoteException;
+
///
/// Initializes a new instance of the class.
///
@@ -316,6 +322,11 @@ private long RemoteWindowRemaining
}
}
+ ///
+ /// Gets the exception sent from the remote side over this channel, null otherwise.
+ ///
+ private Exception? RemoteException => this.receivedRemoteException ? this.faultingException : null;
+
///
/// Gets a value indicating whether backpressure support is enabled.
///
@@ -332,9 +343,6 @@ public void Dispose()
{
if (!this.IsDisposed)
{
- this.acceptanceSource.TrySetCanceled();
- this.optionsAppliedTaskSource?.TrySetCanceled();
-
PipeWriter? mxStreamIOWriter;
lock (this.SyncObject)
{
@@ -342,6 +350,21 @@ public void Dispose()
mxStreamIOWriter = this.mxStreamIOWriter;
}
+ // If we are disposing due to a faulting error, transition the acceptanceSource to an error state
+ lock (this.SyncObject)
+ {
+ if (this.faultingException != null)
+ {
+ this.acceptanceSource.TrySetException(this.faultingException);
+ this.optionsAppliedTaskSource?.TrySetException(this.faultingException);
+ }
+ else
+ {
+ this.acceptanceSource.TrySetCanceled();
+ this.optionsAppliedTaskSource?.TrySetCanceled();
+ }
+ }
+
// Complete writing so that the mxstream cannot write to this channel any more.
// We must also cancel a pending flush since no one is guaranteed to be reading this any more
// and we don't want to deadlock on a full buffer in a disposed channel's pipe.
@@ -361,7 +384,7 @@ public void Dispose()
mxStreamIOWriter = self.mxStreamIOWriter;
}
- mxStreamIOWriter?.Complete();
+ mxStreamIOWriter?.Complete(self.RemoteException);
self.mxStreamIOWriterCompleted.Set();
}
finally
@@ -401,7 +424,17 @@ public void Dispose()
this.remoteWindowHasCapacity.Set();
this.disposalTokenSource.Cancel();
- this.completionSource.TrySetResult(null);
+
+ // If we are disposing due to receiving or sending an exception, relay that to our client.
+ if (this.faultingException is Exception faultingException)
+ {
+ this.completionSource.TrySetException(faultingException);
+ }
+ else
+ {
+ this.completionSource.TrySetResult(null);
+ }
+
this.MultiplexingStream.OnChannelDisposed(this);
}
}
@@ -418,7 +451,7 @@ internal async Task OnChannelTerminatedAsync()
// We Complete the writer because only the writing (logical) thread should complete it
// to avoid race conditions, and Channel.Dispose can be called from any thread.
using PipeWriterRental writerRental = await this.GetReceivedMessagePipeWriterAsync().ConfigureAwait(false);
- await writerRental.Writer.CompleteAsync().ConfigureAwait(false);
+ await writerRental.Writer.CompleteAsync(this.RemoteException).ConfigureAwait(false);
}
catch (ObjectDisposedException)
{
@@ -497,32 +530,49 @@ internal async ValueTask OnContentAsync(ReadOnlySequence payload, Cancella
///
/// Called by the when it will not be writing any more data to the channel.
///
- internal void OnContentWritingCompleted()
+ /// The error in writing that originated on the remote side, if applicable.
+ internal void OnContentWritingCompleted(MultiplexingProtocolException? error = null)
{
+ // If we have already received an error from the remote side then no need to complete the channel again.
+ if (this.receivedRemoteException)
+ {
+ return;
+ }
+
+ // Set the state of the channel based on whether we are completing due to an error.
+ lock (this.SyncObject)
+ {
+ this.faultingException ??= error;
+ this.receivedRemoteException = error != null;
+ }
+
this.DisposeSelfOnFailure(Task.Run(async delegate
{
if (!this.IsDisposed)
{
try
{
+ // If the channel is not disposed, then first try to close the writer used by the channel owner
using PipeWriterRental writerRental = await this.GetReceivedMessagePipeWriterAsync().ConfigureAwait(false);
- await writerRental.Writer.CompleteAsync().ConfigureAwait(false);
+ await writerRental.Writer.CompleteAsync(this.RemoteException).ConfigureAwait(false);
}
catch (ObjectDisposedException)
{
+ // If not, try to close the underlying writer.
if (this.mxStreamIOWriter != null)
{
using AsyncSemaphore.Releaser releaser = await this.mxStreamIOWriterSemaphore.EnterAsync().ConfigureAwait(false);
- await this.mxStreamIOWriter.CompleteAsync().ConfigureAwait(false);
+ await this.mxStreamIOWriter.CompleteAsync(this.RemoteException).ConfigureAwait(false);
}
}
}
else
{
+ // If the channel has not been disposed then just close the underlying writer.
if (this.mxStreamIOWriter != null)
{
using AsyncSemaphore.Releaser releaser = await this.mxStreamIOWriterSemaphore.EnterAsync().ConfigureAwait(false);
- await this.mxStreamIOWriter.CompleteAsync().ConfigureAwait(false);
+ await this.mxStreamIOWriter.CompleteAsync(this.RemoteException).ConfigureAwait(false);
}
}
@@ -545,8 +595,15 @@ internal bool TryAcceptOffer(ChannelOptions channelOptions)
}
var acceptanceParameters = new AcceptanceParameters(this.localWindowSize.Value);
- if (this.acceptanceSource.TrySetResult(acceptanceParameters))
+
+ try
{
+ // Set up the channel options and ensure that the channel is still valid
+ // before we transition to an accepted state.
+ this.ApplyChannelOptions(channelOptions);
+ Verify.NotDisposed(this);
+
+ // If we aren't a seeded channel then send an offer accepted frame.
if (this.QualifiedId.Source != ChannelSource.Seeded)
{
ReadOnlySequence payload = this.MultiplexingStream.formatter.Serialize(acceptanceParameters);
@@ -560,16 +617,20 @@ internal bool TryAcceptOffer(ChannelOptions channelOptions)
CancellationToken.None);
}
- try
+ // Update the acceptance source to the acceptance parameters.
+ return this.acceptanceSource.TrySetResult(acceptanceParameters);
+ }
+ catch (Exception exception)
+ {
+ // Record the exception in the acceptance source.
+ this.acceptanceSource.TrySetException(exception);
+
+ // If we caught an disposal error due to the channel self faulting then swallow
+ // the exception.
+ if (exception is ObjectDisposedException && this.faultingException != null)
{
- this.ApplyChannelOptions(channelOptions);
return true;
}
- catch (ObjectDisposedException)
- {
- // A (harmless) race condition was hit.
- // Swallow it and return false below.
- }
}
return false;
@@ -744,16 +805,35 @@ private async Task ProcessOutboundTransmissionsAsync()
this.mxStreamIOReader = new UnownedPipeReader(mxStreamIOReader);
}
+ bool channelAccepted = false;
try
{
// Don't transmit data on the channel until the remote party has accepted it.
// This is not just a courtesy: it ensure we don't transmit data from the offering party before the offer frame itself.
// Likewise: it may help prevent transmitting data from the accepting party before the acceptance frame itself.
- await this.Acceptance.ConfigureAwait(false);
+ try
+ {
+ await this.Acceptance.ConfigureAwait(false);
+ channelAccepted = true;
+ }
+ catch (Exception exception)
+ {
+ // This await will only throw an exception if the channel has been disposed and thus we can swallow.
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Error) ?? false)
+ {
+ this.TraceSource.TraceEvent(
+ TraceEventType.Error,
+ (int)TraceEventId.NonFatalInternalError,
+ "Channel {0} swalled acceptance exception in {1}: {2}",
+ this.QualifiedId,
+ nameof(this.ProcessOutboundTransmissionsAsync),
+ exception);
+ }
+ }
- while (!this.Completion.IsCompleted)
+ while (channelAccepted && !this.Completion.IsCompleted)
{
- if (!this.remoteWindowHasCapacity.IsSet && this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose))
+ if (!this.remoteWindowHasCapacity.IsSet && (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false))
{
this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Remote window is full. Waiting for remote party to process data before sending more.");
}
@@ -761,7 +841,7 @@ private async Task ProcessOutboundTransmissionsAsync()
await this.remoteWindowHasCapacity.WaitAsync().ConfigureAwait(false);
if (this.IsRemotelyTerminated)
{
- if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose))
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false)
{
this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Transmission on channel {0} \"{1}\" terminated the remote party terminated the channel.", this.QualifiedId, this.Name);
}
@@ -774,7 +854,7 @@ private async Task ProcessOutboundTransmissionsAsync()
if (result.IsCanceled)
{
// We've been asked to cancel. Presumably the channel has faulted or been disposed.
- if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose))
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false)
{
this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Transmission terminated because the read was canceled.");
}
@@ -793,7 +873,7 @@ private async Task ProcessOutboundTransmissionsAsync()
ReadOnlySequence bufferToRelay = result.Buffer.Slice(0, bytesToSend);
this.OnTransmittingBytes(bufferToRelay.Length);
bool isCompleted = result.IsCompleted && result.Buffer.Length == bufferToRelay.Length;
- if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose))
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false)
{
this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "{0} of {1} bytes will be transmitted.", bufferToRelay.Length, result.Buffer.Length);
}
@@ -819,7 +899,7 @@ private async Task ProcessOutboundTransmissionsAsync()
if (isCompleted)
{
- if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information))
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Information) ?? false)
{
this.TraceSource.TraceEvent(TraceEventType.Information, 0, "Transmission terminated because the writer completed.");
}
@@ -832,20 +912,37 @@ private async Task ProcessOutboundTransmissionsAsync()
}
catch (Exception ex)
{
+ // If the operation had been cancelled then we are expecting to receive this error so don't transmit it.
if (ex is OperationCanceledException && this.DisposalToken.IsCancellationRequested)
{
await mxStreamIOReader!.CompleteAsync().ConfigureAwait(false);
}
else
{
+ // If not record it as the error to dispose this channel with
+ lock (this.SyncObject)
+ {
+ this.faultingException ??= ex;
+ }
+
+ // Since we're not expecting to receive this error, transmit the error to the remote side.
await mxStreamIOReader!.CompleteAsync(ex).ConfigureAwait(false);
+
+ if (channelAccepted)
+ {
+ this.MultiplexingStream.OnChannelWritingError(this, ex);
+ }
}
throw;
}
finally
{
- this.MultiplexingStream.OnChannelWritingCompleted(this);
+ // Send the completion message to the remote if the channel was accepted
+ if (channelAccepted)
+ {
+ this.MultiplexingStream.OnChannelWritingCompleted(this);
+ }
// Restore the PipeReader to the field.
lock (this.SyncObject)
@@ -916,17 +1013,26 @@ private async Task AutoCloseOnPipesClosureAsync()
private void Fault(Exception exception)
{
- if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Critical) ?? false)
- {
- this.TraceSource!.TraceEvent(TraceEventType.Critical, (int)TraceEventId.FatalError, "Channel Closing self due to exception: {0}", exception);
- }
-
+ // Record the faulting exception unless it is not the original exception.
lock (this.SyncObject)
{
this.faultingException ??= exception;
}
this.mxStreamIOReader?.CancelPendingRead();
+
+ // Only dispose if not already disposed.
+ if (this.IsDisposed)
+ {
+ return;
+ }
+
+ // Record the fact that we are about to close the channel due to a fault.
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Error) ?? false)
+ {
+ this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.FatalError, "Channel {0} closing self due to exception: {1}", this.QualifiedId, exception);
+ }
+
this.Dispose();
}
diff --git a/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs b/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs
index 8adddd80..e53c7b46 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs
@@ -50,6 +50,10 @@ public ChannelOptions()
/// The specified in *must* be created with that *exceeds*
/// the value of and .
///
+ ///
+ /// A faulted (one where is called with an exception)
+ /// will have its exception relayed to the remote party before closing the channel.
+ ///
///
/// Thrown if set to an that returns null for either of its properties.
public IDuplexPipe? ExistingPipe
diff --git a/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs b/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs
index 91a9b364..5b01fd34 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs
@@ -44,6 +44,13 @@ internal enum ControlCode : byte
/// allowing them to send more data.
///
ContentProcessed,
+
+ ///
+ /// Sent when one party experiences an exception related to a particular channel and carries details regarding the error,
+ /// when using protocol version 2 or later.
+ /// This is sent before a frame closes that channel.
+ ///
+ ContentWritingError,
}
}
}
diff --git a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs
index c0550733..c5cf9472 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs
@@ -5,6 +5,7 @@ namespace Nerdbank.Streams
{
using System;
using System.Buffers;
+ using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
@@ -495,6 +496,53 @@ internal override Channel.AcceptanceParameters DeserializeAcceptanceParameters(R
return new Channel.AcceptanceParameters(remoteWindowSize);
}
+ ///
+ /// Returns the serialized representation of a object using .
+ ///
+ /// An instance of that we want to seralize.
+ /// A which is the serialized version of the error.
+ internal ReadOnlySequence SerializeWriteError(WriteError error)
+ {
+ // Create the payload
+ Sequence errorSequence = new();
+ MessagePackWriter writer = new(errorSequence);
+
+ // Write the error message and the protocol version to the payload
+ writer.WriteArrayHeader(1);
+ writer.Write(error.Message);
+
+ // Return the payload to the caller
+ writer.Flush();
+ return errorSequence.AsReadOnlySequence;
+ }
+
+ ///
+ /// Extracts an object from the payload using .
+ ///
+ /// The payload we are trying to extract the error object from.
+ /// A object.
+ internal WriteError DeserializeWriteError(ReadOnlySequence serializedError)
+ {
+ MessagePackReader reader = new(serializedError);
+ int numElements = reader.ReadArrayHeader();
+
+ string? errorMessage = null;
+ for (int i = 0; i < numElements; i++)
+ {
+ switch (i)
+ {
+ case 0:
+ errorMessage = reader.ReadString();
+ break;
+ default:
+ reader.Skip();
+ break;
+ }
+ }
+
+ return new WriteError(errorMessage ?? string.Empty);
+ }
+
protected virtual (FrameHeader Header, ReadOnlySequence Payload) DeserializeFrame(ReadOnlySequence frameSequence)
{
var reader = new MessagePackReader(frameSequence);
diff --git a/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs b/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs
new file mode 100644
index 00000000..b46801d0
--- /dev/null
+++ b/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Andrew Arnott. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Nerdbank.Streams
+{
+ ///
+ /// Contains the nested type.
+ ///
+ public partial class MultiplexingStream
+ {
+ ///
+ /// A class containing information about a write error and which is sent to the
+ /// remote alongside .
+ ///
+ internal class WriteError
+ {
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The error message we want to send to the receiver.
+ internal WriteError(string? message)
+ {
+ this.Message = message;
+ }
+
+ ///
+ /// Gets the error message associated with this error.
+ ///
+ internal string? Message { get; }
+ }
+ }
+}
diff --git a/src/Nerdbank.Streams/MultiplexingStream.cs b/src/Nerdbank.Streams/MultiplexingStream.cs
index 904baf24..3244674c 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.cs
@@ -209,6 +209,16 @@ private enum TraceEventId
/// Raised when the protocol handshake is starting, to annouce the major version being used.
///
HandshakeStarted,
+
+ ///
+ /// Raised when receiving or sending a .
+ ///
+ WriteError,
+
+ ///
+ /// An error occurred that is likely not fatal.
+ ///
+ NonFatalInternalError,
}
///
@@ -220,7 +230,7 @@ private enum TraceEventId
/// Gets the logger used by this instance.
///
/// Never null.
- public TraceSource TraceSource { get; }
+ public TraceSource TraceSource { get; private set; }
///
/// Gets the default window size used for new channels that do not specify a value for .
@@ -244,6 +254,11 @@ private enum TraceEventId
///
private Func? DefaultChannelTraceSourceFactory { get; }
+ ///
+ /// Gets a value indicating whether this stream can send or receive frames of type.
+ ///
+ private bool ContentWritingErrorSupported => this.protocolMajorVersion > 1;
+
///
/// Initializes a new instance of the class
/// with set to 3.
@@ -837,6 +852,9 @@ private async Task ReadStreamAsync()
case ControlCode.ChannelTerminated:
await this.OnChannelTerminatedAsync(header.RequiredChannelId).ConfigureAwait(false);
break;
+ case ControlCode.ContentWritingError:
+ this.OnContentWritingError(header.RequiredChannelId, frame.Value.Payload);
+ break;
default:
break;
}
@@ -919,6 +937,56 @@ private async Task OnChannelTerminatedAsync(QualifiedChannelId channelId)
}
}
+ ///
+ /// Occurs when the channel receives a frame with code from the remote.
+ ///
+ /// The channel id of the sender of the frame.
+ /// The payload that the sender sent in the frame.
+ private void OnContentWritingError(QualifiedChannelId channelId, ReadOnlySequence payload)
+ {
+ // Get the channel that send this frame
+ Channel channel;
+ lock (this.syncObject)
+ {
+ channel = this.openChannels[channelId];
+ }
+
+ // Determines if the channel is in a state to receive messages
+ bool channelInValidState = channelId.Source != ChannelSource.Local || channel.IsAccepted;
+
+ // If the channel is in a valid state and we have a valid protocol version, then process the message
+ if (channelInValidState && this.ContentWritingErrorSupported)
+ {
+ // Deserialize the payload and verify that it was in an expected state
+ V2Formatter errorDeserializingFormattter = (V2Formatter)this.formatter;
+ WriteError error = errorDeserializingFormattter.DeserializeWriteError(payload);
+
+ // Get the error message and complete the channel using it
+ string errorMessage = error.Message ?? "";
+ MultiplexingProtocolException channelClosingException = new MultiplexingProtocolException($"Remote party indicated writing error: {errorMessage}");
+ channel.OnContentWritingCompleted(channelClosingException);
+ }
+ else if (channelInValidState && !this.ContentWritingErrorSupported)
+ {
+ // The channel is in a valid state but we have a protocol version that doesn't support processing errrors
+ // so don't do anything.
+ if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Warning) ?? false)
+ {
+ this.TraceSource.TraceEvent(
+ TraceEventType.Warning,
+ (int)TraceEventId.WriteError,
+ "Rejecting writing error from channel {0} as MultiplexingStream has protocol version of {1}",
+ channelId,
+ this.protocolMajorVersion);
+ }
+ }
+ else
+ {
+ // The channel is in an invalid state so throw an error indicating so
+ throw new MultiplexingProtocolException($"Remote party indicated they encountered errors writing to channel {channelId} before accepting it.");
+ }
+ }
+
private void OnContentWritingCompleted(QualifiedChannelId channelId)
{
Channel channel;
@@ -1064,12 +1132,7 @@ private bool TryAcceptChannel(Channel channel, ChannelOptions options)
Requires.NotNull(channel, nameof(channel));
Requires.NotNull(options, nameof(options));
- if (channel.TryAcceptOffer(options))
- {
- return true;
- }
-
- return false;
+ return channel.TryAcceptOffer(options);
}
private void AcceptChannelOrThrow(Channel channel, ChannelOptions options)
@@ -1079,7 +1142,12 @@ private void AcceptChannelOrThrow(Channel channel, ChannelOptions options)
if (!this.TryAcceptChannel(channel, options))
{
- if (channel.IsAccepted)
+ // If we disposed of the channel due to a user provided error then ignore the error.
+ if (channel.IsDisposed && (channel.Completion.IsFaulted || channel.Acceptance.IsFaulted))
+ {
+ return;
+ }
+ else if (channel.IsAccepted)
{
throw new InvalidOperationException("Channel is already accepted.");
}
@@ -1113,6 +1181,50 @@ private void OnChannelDisposed(Channel channel)
}
}
+ ///
+ /// Informs the remote party of a local error that prevents sending all the required data to this channel
+ /// by transmitting a frame.
+ ///
+ /// The channel whose writing was halted.
+ /// The exception that caused the writing to be haulted.
+ private void OnChannelWritingError(Channel channel, Exception exception)
+ {
+ if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Error))
+ {
+ this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.WriteError, "Local channel {0} encountered write error {1}", channel.QualifiedId, exception.Message);
+ }
+
+ // Verify that we can send a message over this channel.
+ // The race condition here is handled within SendFrameAsync which will drop the frame
+ // if the conditions we're checking for here change after we check them, so our check here
+ // is just an optimization to avoid work when we can predict its failure.
+ bool channelInValidState = true;
+ lock (this.syncObject)
+ {
+ channelInValidState = !this.channelsPendingTermination.Contains(channel.QualifiedId)
+ && this.openChannels.ContainsKey(channel.QualifiedId);
+ }
+
+ // If we can send messages over this channel and we have the correct protocol version then send the error
+ if (channelInValidState && this.ContentWritingErrorSupported)
+ {
+ // Create the payload to send to the remote side
+ V2Formatter errorSerializationFormatter = (V2Formatter)this.formatter;
+ WriteError error = new(exception.Message);
+ ReadOnlySequence serializedError = errorSerializationFormatter.SerializeWriteError(error);
+
+ // Create the frame header indicating that we encountered a content writing error
+ FrameHeader header = new FrameHeader
+ {
+ Code = ControlCode.ContentWritingError,
+ ChannelId = channel.QualifiedId,
+ };
+
+ // Send the frame alongside the payload to the remote side
+ this.SendFrame(header, serializedError, CancellationToken.None);
+ }
+ }
+
///
/// Indicates that the local end will not be writing any more data to this channel,
/// leading to the transmission of a frame being sent for this channel.
@@ -1191,7 +1303,7 @@ private async Task SendFrameAsync(FrameHeader header, ReadOnlySequence pay
// In such cases, we should just suppress transmission of the frame because the other side does not care.
// ContentWritingCompleted can be sent to SendFrame after a ChannelTerminated message such that neither have been transmitted yet
// and thus wasn't in the termination collection until later, so forgive that too.
- if (header.Code is ControlCode.ContentProcessed or ControlCode.ContentWritingCompleted)
+ if (header.Code is ControlCode.ContentProcessed or ControlCode.ContentWritingCompleted or ControlCode.ContentWritingError)
{
this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.FrameSendSkipped, "Skipping {0} frame for channel {1} because we're about to terminate it.", header.Code, header.ChannelId);
return;
diff --git a/src/nerdbank-streams/src/Channel.ts b/src/nerdbank-streams/src/Channel.ts
index eae435dd..257e83c7 100644
--- a/src/nerdbank-streams/src/Channel.ts
+++ b/src/nerdbank-streams/src/Channel.ts
@@ -53,6 +53,15 @@ export abstract class Channel implements IDisposableObservable {
return this._isDisposed;
}
+ /**
+ * Closes this channel after transmitting an error to the remote party.
+ * @param error The error to transmit to the remote party.
+ */
+ public fault(error: Error): Promise {
+ // The interesting stuff is in the derived class.
+ return Promise.resolve();
+ }
+
/**
* Closes this channel.
*/
@@ -71,6 +80,7 @@ export class ChannelClass extends Channel {
private readonly _completion = new Deferred();
public localWindowSize?: number;
private remoteWindowSize?: number;
+ private remoteError?: Error;
/**
* The number of bytes transmitted from here but not yet acknowledged as processed from there,
@@ -214,7 +224,13 @@ export class ChannelClass extends Channel {
return this._acceptance.resolve();
}
- public onContent(buffer: Buffer | null) {
+ public onContent(buffer: Buffer | null, error?: Error) {
+ // If we have already received an error from the remote party, then don't process any future messages.
+ if (this.remoteError) {
+ return;
+ }
+
+ this.remoteError = error;
this._duplex.push(buffer);
// We should find a way to detect when we *actually* share the received buffer with the Channel's user
@@ -244,6 +260,23 @@ export class ChannelClass extends Channel {
}
}
+ public async fault(error: Error) {
+ // If the channel is already disposed then don't do anything
+ if (this.isDisposed) {
+ return;
+ }
+
+ // Send the error message to the remote side
+ await this._multiplexingStream.onChannelWritingError(this, error);
+
+ // Set the remote exception to the passed in error so that the channel is
+ // completed with this error
+ this.remoteError = error;
+
+ // Dispose of the channel
+ await this.dispose();
+ }
+
public async dispose() {
if (!this.isDisposed) {
super.dispose();
@@ -251,11 +284,19 @@ export class ChannelClass extends Channel {
this._acceptance.reject(new CancellationToken.CancellationError("disposed"));
// For the pipes, we Complete *our* ends, and leave the user's ends alone.
- // The completion will propagate when it's ready to.
+ // The completion will propagate when it's ready to. No need to destroy the duplex
+ // as the frame containing the error message has already been sent.
this._duplex.end();
this._duplex.push(null);
- this._completion.resolve();
+ // If we are sending an error to the remote side or received an error from the remote,
+ // relay that information to the clients.
+ if (this.remoteError) {
+ this._completion.reject(this.remoteError);
+ } else {
+ this._completion.resolve();
+ }
+
await this._multiplexingStream.onChannelDisposed(this);
}
}
diff --git a/src/nerdbank-streams/src/ControlCode.ts b/src/nerdbank-streams/src/ControlCode.ts
index 3ae71fed..5459cda6 100644
--- a/src/nerdbank-streams/src/ControlCode.ts
+++ b/src/nerdbank-streams/src/ControlCode.ts
@@ -32,4 +32,11 @@ export enum ControlCode {
* Sent when a channel has finished processing data received from the remote party, allowing them to send more data.
*/
ContentProcessed,
+
+ /**
+ * Sent when one party experiences an exception related to a particular channel and carries details regarding the error,
+ * when using protocol version 2 or later.
+ * This is sent right before a ContentWritingCompleted frame closes that channel.
+ */
+ ContentWritingError,
}
diff --git a/src/nerdbank-streams/src/MultiplexingStream.ts b/src/nerdbank-streams/src/MultiplexingStream.ts
index 1f6eafcc..4265a32b 100644
--- a/src/nerdbank-streams/src/MultiplexingStream.ts
+++ b/src/nerdbank-streams/src/MultiplexingStream.ts
@@ -25,6 +25,7 @@ import {
import { OfferParameters } from "./OfferParameters";
import { Semaphore } from 'await-semaphore';
import { QualifiedChannelId, ChannelSource } from "./QualifiedChannelId";
+import { WriteError } from "./WriteError";
export abstract class MultiplexingStream implements IDisposableObservable {
/**
@@ -106,9 +107,9 @@ export abstract class MultiplexingStream implements IDisposableObservable {
* @param options Options to customize the behavior of the stream.
* @returns The multiplexing stream.
*/
- public static Create(
+ public static Create(
stream: NodeJS.ReadWriteStream,
- options?: MultiplexingStreamOptions) : MultiplexingStream {
+ options?: MultiplexingStreamOptions): MultiplexingStream {
options ??= { protocolMajorVersion: 3 };
options.protocolMajorVersion ??= 3;
@@ -578,6 +579,26 @@ export class MultiplexingStreamClass extends MultiplexingStream {
}
}
+ public async onChannelWritingError(channel: ChannelClass, error: Error) {
+ // Make sure that we are in a protocol version in which we can write errors.
+ if (this.protocolMajorVersion === 1) {
+ return;
+ }
+
+ // Make sure we can send error messages on this channel.
+ if (!this.getOpenChannel(channel.qualifiedId)) {
+ return;
+ }
+
+ // Convert the error message into a payload into a formatter.
+ const writingError = new WriteError(error.message);
+ const errorSerializingFormatter = this.formatter as MultiplexingStreamV2Formatter;
+ const errorPayload = errorSerializingFormatter.serializeContentWritingError(writingError);
+
+ // Sent the error to the remote side.
+ await this.sendFrameAsync(new FrameHeader(ControlCode.ContentWritingError, channel.qualifiedId), errorPayload);
+ }
+
public onChannelWritingCompleted(channel: ChannelClass) {
// Only inform the remote side if this channel has not already been terminated.
if (!channel.isDisposed && this.getOpenChannel(channel.qualifiedId)) {
@@ -629,6 +650,9 @@ export class MultiplexingStreamClass extends MultiplexingStream {
case ControlCode.ContentWritingCompleted:
this.onContentWritingCompleted(frame.header.requiredChannel);
break;
+ case ControlCode.ContentWritingError:
+ this.onContentWritingError(frame.header.requiredChannel, frame.payload);
+ break;
case ControlCode.ChannelTerminated:
this.onChannelTerminated(frame.header.requiredChannel);
break;
@@ -716,6 +740,27 @@ export class MultiplexingStreamClass extends MultiplexingStream {
channel.onContentProcessed(bytesProcessed);
}
+ private onContentWritingError(channelId: QualifiedChannelId, payload: Buffer) {
+ // Make sure that the channel has the proper formatter to process the output.
+ if (this.protocolMajorVersion === 1) {
+ return;
+ }
+
+ // Ensure that we received the message on an open channel.
+ const channel = this.getOpenChannel(channelId);
+ if (!channel) {
+ throw new Error(`No channel with id ${channelId} found.`);
+ }
+
+ // Extract the error from the payload.
+ const errorDeserializingFormatter = this.formatter as MultiplexingStreamV2Formatter;
+ const writingError = errorDeserializingFormatter.deserializeContentWritingError(payload);
+
+ // Pass the error received from the remote to the channel.
+ const remoteErr = new Error(`Received error message from remote: ${writingError.message}`);
+ channel.onContent(null, remoteErr);
+ }
+
private onContentWritingCompleted(channelId: QualifiedChannelId) {
const channel = this.getOpenChannel(channelId);
if (!channel) {
diff --git a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts
index af625669..ce4ec7e0 100644
--- a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts
+++ b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts
@@ -8,6 +8,7 @@ import * as msgpack from 'msgpack-lite';
import { Deferred } from "./Deferred";
import { FrameHeader } from "./FrameHeader";
import { ControlCode } from "./ControlCode";
+import { WriteError } from "./WriteError";
import { ChannelSource } from "./QualifiedChannelId";
export interface Version {
@@ -293,6 +294,19 @@ export class MultiplexingStreamV2Formatter extends MultiplexingStreamFormatter {
return msgpack.decode(payload)[0];
}
+ serializeContentWritingError(writingError: WriteError): Buffer {
+ const payload: any[] = [writingError.message];
+ return msgpack.encode(payload);
+ }
+
+ deserializeContentWritingError(payload: Buffer): WriteError {
+ const msgpackObject = msgpack.decode(payload);
+
+ // Return the error message to the caller.
+ const errorMsg: string | undefined = msgpackObject[0];
+ return new WriteError(errorMsg ?? "");
+ }
+
protected async readMessagePackAsync(cancellationToken: CancellationToken): Promise<{} | [] | null> {
const streamEnded = new Deferred();
while (true) {
diff --git a/src/nerdbank-streams/src/WriteError.ts b/src/nerdbank-streams/src/WriteError.ts
new file mode 100644
index 00000000..81970f05
--- /dev/null
+++ b/src/nerdbank-streams/src/WriteError.ts
@@ -0,0 +1,13 @@
+/**
+ * A class that is used to store information related to ContentWritingError.
+ * It is used by both the sending and receiving streams to transmit errors encountered while
+ * writing content.
+ */
+export class WriteError {
+ /**
+ * Initializes a new instance of the WriteError class.
+ * @param message The error message.
+ */
+ constructor(public readonly message: string) {
+ }
+}
diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts
index 15341867..108aa344 100644
--- a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts
+++ b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts
@@ -4,6 +4,7 @@ import { Deferred } from "../Deferred";
import { FullDuplexStream } from "../FullDuplexStream";
import { MultiplexingStream } from "../MultiplexingStream";
import { ChannelOptions } from "../ChannelOptions";
+import * as assert from "assert";
[1, 2, 3].forEach(protocolMajorVersion => {
describe(`MultiplexingStream v${protocolMajorVersion} (interop) `, () => {
@@ -14,6 +15,8 @@ import { ChannelOptions } from "../ChannelOptions";
const dotnetEnvBlock: NodeJS.ProcessEnv = {
DOTNET_SKIP_FIRST_TIME_EXPERIENCE: "1", // prevent warnings in stdout that corrupt our interop stream.
};
+ let expectedDisposeError: boolean;
+
beforeAll(
async () => {
proc = spawn(
@@ -26,7 +29,8 @@ import { ChannelOptions } from "../ChannelOptions";
proc.once("exit", (code) => procExited.resolve(code));
// proc.stdout!.pipe(process.stdout);
proc.stderr!.pipe(process.stderr);
- expect(await procExited.promise).toEqual(0);
+ const buildExitVal = await procExited.promise;
+ expect(buildExitVal).toEqual(0);
} finally {
proc.kill();
proc = null;
@@ -50,12 +54,21 @@ import { ChannelOptions } from "../ChannelOptions";
proc = null;
throw e;
}
+ expectedDisposeError = false;
}, 10000); // leave time for dotnet to start.
afterEach(async () => {
if (mx) {
mx.dispose();
- await mx.completion;
+
+ // See if we encounter any errors in the multplexing stream and rethrow them if they are unexpected
+ try {
+ await mx.completion;
+ } catch (error) {
+ if (!expectedDisposeError) {
+ throw error;
+ }
+ }
}
if (proc) {
@@ -86,6 +99,31 @@ import { ChannelOptions } from "../ChannelOptions";
expect(recv).toEqual(`recv: ${bigdata}`);
});
+ it("Can send error to remote", async () => {
+ expectedDisposeError = true;
+ const errorWriteChannel = await mx.offerChannelAsync("clientErrorOffer");
+ const responseReceiveChannel = await mx.offerChannelAsync("clientResponseOffer");
+
+ const errorMessage = "couldn't send all of the data";
+ const errorToSend = new Error(errorMessage);
+
+ let caughtCompletionErr = false;
+ errorWriteChannel.completion.catch(err => {
+ caughtCompletionErr = true;
+ });
+
+ await errorWriteChannel.fault(errorToSend);
+ assert.deepStrictEqual(caughtCompletionErr, true);
+
+ let expectedMessage = `received error: Remote party indicated writing error: ${errorMessage}`;
+ if (protocolMajorVersion === 1) {
+ expectedMessage = "didn't receive any errors";
+ }
+
+ const receivedMessage = await readLineAsync(responseReceiveChannel.stream);
+ assert.deepStrictEqual(receivedMessage?.trim(), expectedMessage);
+ });
+
if (protocolMajorVersion >= 3) {
it("Can communicate over seeded channel", async () => {
const channel = mx.acceptChannel(0);
diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts
index e78cc49a..08db0365 100644
--- a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts
+++ b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts
@@ -230,6 +230,46 @@ import * as assert from "assert";
await channels[1].completion;
});
+ it("channel disposes with an error", async () => {
+ const errorMessage = "couldn't send all of the data";
+ const errorToSend = new Error(errorMessage);
+
+ const channels = await Promise.all([
+ mx1.offerChannelAsync("test"),
+ mx2.acceptChannelAsync("test"),
+ ]);
+
+ await channels[0].fault(errorToSend);
+
+ // Ensure that the current channel disposes with the error.
+ let caughtSenderError = false;
+ try {
+ await channels[0].completion;
+ } catch (error) {
+ let completionErrMsg = String(error);
+ if (error instanceof Error) {
+ completionErrMsg = (error as Error).message;
+ }
+ caughtSenderError = completionErrMsg.includes(errorMessage);
+ }
+
+ assert.strictEqual(true, caughtSenderError);
+
+ // Ensure that the remote side received the error only for version >= 1
+ let caughtRemoteError = false;
+ try {
+ await channels[1].completion;
+ } catch (error) {
+ let completionErrMsg = String(error);
+ if (error instanceof Error) {
+ completionErrMsg = (error as Error).message;
+ }
+ caughtRemoteError = completionErrMsg.includes(errorMessage);
+ }
+
+ assert.deepStrictEqual(protocolMajorVersion > 1, caughtRemoteError);
+ })
+
it("channels complete when mxstream is disposed", async () => {
const channels = await Promise.all([
mx1.offerChannelAsync("test"),
diff --git a/test/Nerdbank.Streams.Interop.Tests/Program.cs b/test/Nerdbank.Streams.Interop.Tests/Program.cs
index d5b5a46e..b0c46048 100644
--- a/test/Nerdbank.Streams.Interop.Tests/Program.cs
+++ b/test/Nerdbank.Streams.Interop.Tests/Program.cs
@@ -61,6 +61,7 @@ private static (StreamReader Reader, StreamWriter Writer) CreateStreamIO(Multipl
private async Task RunAsync(int protocolMajorVersion)
{
this.ClientOfferAsync().Forget();
+ this.ClientOfferErrorAsync().Forget();
this.ServerOfferAsync().Forget();
if (protocolMajorVersion >= 3)
@@ -79,6 +80,29 @@ private async Task ClientOfferAsync()
await w.WriteLineAsync($"recv: {line}");
}
+ private async Task ClientOfferErrorAsync()
+ {
+ // Await both of the channels from the sender, one to read the error and the other to return the response.
+ MultiplexingStream.Channel? incomingChannel = await this.mx.AcceptChannelAsync("clientErrorOffer");
+ MultiplexingStream.Channel? outgoingChannel = await this.mx.AcceptChannelAsync("clientResponseOffer");
+
+ // Determine the response to send back on the whether the incoming channel completed with an exception.
+ string? responseMessage = "didn't receive any errors";
+ try
+ {
+ await incomingChannel.Completion;
+ }
+ catch (Exception error)
+ {
+ responseMessage = "received error: " + error.Message;
+ }
+
+ // Create a writer using the outgoing channel and send the response to the sender.
+ (StreamReader _, StreamWriter writer) = CreateStreamIO(outgoingChannel);
+
+ await writer.WriteLineAsync(responseMessage);
+ }
+
private async Task ServerOfferAsync()
{
MultiplexingStream.Channel? channel = await this.mx.OfferChannelAsync("serverOffer");
diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
index 83e6717c..c772fb0c 100644
--- a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
+++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
@@ -100,6 +100,74 @@ public void DefaultMajorProtocolVersion()
Assert.Equal(1, new MultiplexingStream.Options().ProtocolMajorVersion);
}
+ [Fact]
+ public async Task ClosePipeWithError()
+ {
+ (MultiplexingStream.Channel channel1, MultiplexingStream.Channel channel2) = await this.EstablishChannelsAsync("test");
+ await channel1.Output.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken);
+ ReadResult readResult = await channel2.Input.ReadAtLeastAsync(3, this.TimeoutToken);
+ channel2.Input.AdvanceTo(readResult.Buffer.End);
+
+ // Now fail one side.
+ const string expectedErrorMessage = "Inflicted error";
+ await channel1.Output.CompleteAsync(new ApplicationException(expectedErrorMessage));
+ if (this.ProtocolMajorVersion > 1)
+ {
+ MultiplexingProtocolException ex = await Assert.ThrowsAnyAsync(async () => await channel2.Input.ReadAsync(this.TimeoutToken));
+ Assert.Contains(expectedErrorMessage, ex.Message);
+ }
+ else
+ {
+ await channel2.Input.ReadAsync(this.TimeoutToken);
+ }
+ }
+
+ [Fact]
+ public async Task OfferPipeWithError()
+ {
+ string errorMessage = "Hello World";
+
+ // Prepare a readonly pipe that is already populated with data and an error.
+ var pipe = new Pipe();
+ await pipe.Writer.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken);
+ pipe.Writer.Complete(new ApplicationException(errorMessage));
+
+ // Create a sending and receiving channel using the channel.
+ MultiplexingStream.Channel? localChannel = this.mx1.CreateChannel(new MultiplexingStream.ChannelOptions { ExistingPipe = new DuplexPipe(pipe.Reader) });
+ await this.WaitForEphemeralChannelOfferToPropagateAsync();
+ MultiplexingStream.Channel? remoteChannel = this.mx2.AcceptChannel(localChannel.QualifiedId.Id);
+
+ async Task ReadAllDataAsync()
+ {
+ // Read the latest input from the local channel and determine if we should continue reading.
+ ReadResult readResult = await remoteChannel.Input.ReadAsync(this.TimeoutToken);
+ remoteChannel.Input.AdvanceTo(readResult.Buffer.End);
+
+ if (readResult.IsCompleted || readResult.IsCanceled)
+ {
+ return;
+ }
+ }
+
+ if (this.ProtocolMajorVersion > 1)
+ {
+ MultiplexingProtocolException caughtException = await Assert.ThrowsAnyAsync(ReadAllDataAsync);
+ this.Logger.WriteLine(caughtException.ToString());
+ Assert.Contains(errorMessage, caughtException.Message);
+
+ MultiplexingProtocolException remoteCompletionException = await Assert.ThrowsAnyAsync(() => remoteChannel.Completion);
+ Assert.Contains(errorMessage, remoteCompletionException.Message);
+ }
+ else
+ {
+ await ReadAllDataAsync();
+ }
+
+ // Ensure that the writer of the error completes with that error, no matter what version of the protocol they are using.
+ ApplicationException localCompletionException = await Assert.ThrowsAnyAsync(() => localChannel.Completion);
+ Assert.Contains(errorMessage, localCompletionException.Message);
+ }
+
[Fact]
public async Task OfferReadOnlyDuplexPipe()
{