diff --git a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs index 71098e62..6a9072f5 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs @@ -111,6 +111,7 @@ public class Channel : IDisposableObservable, IDuplexPipe /// /// The to use to get data to be transmitted over the . + /// Any errors passed to this are transmitted to the remote side. /// private PipeReader? mxStreamIOReader; @@ -142,6 +143,11 @@ public class Channel : IDisposableObservable, IDuplexPipe /// private bool? existingPipeGiven; + /// + /// A value indicating whether this received an error from a remote party in . + /// + private bool receivedRemoteException; + /// /// Initializes a new instance of the class. /// @@ -316,6 +322,11 @@ private long RemoteWindowRemaining } } + /// + /// Gets the exception sent from the remote side over this channel, null otherwise. + /// + private Exception? RemoteException => this.receivedRemoteException ? this.faultingException : null; + /// /// Gets a value indicating whether backpressure support is enabled. /// @@ -332,9 +343,6 @@ public void Dispose() { if (!this.IsDisposed) { - this.acceptanceSource.TrySetCanceled(); - this.optionsAppliedTaskSource?.TrySetCanceled(); - PipeWriter? mxStreamIOWriter; lock (this.SyncObject) { @@ -342,6 +350,21 @@ public void Dispose() mxStreamIOWriter = this.mxStreamIOWriter; } + // If we are disposing due to a faulting error, transition the acceptanceSource to an error state + lock (this.SyncObject) + { + if (this.faultingException != null) + { + this.acceptanceSource.TrySetException(this.faultingException); + this.optionsAppliedTaskSource?.TrySetException(this.faultingException); + } + else + { + this.acceptanceSource.TrySetCanceled(); + this.optionsAppliedTaskSource?.TrySetCanceled(); + } + } + // Complete writing so that the mxstream cannot write to this channel any more. // We must also cancel a pending flush since no one is guaranteed to be reading this any more // and we don't want to deadlock on a full buffer in a disposed channel's pipe. @@ -361,7 +384,7 @@ public void Dispose() mxStreamIOWriter = self.mxStreamIOWriter; } - mxStreamIOWriter?.Complete(); + mxStreamIOWriter?.Complete(self.RemoteException); self.mxStreamIOWriterCompleted.Set(); } finally @@ -401,7 +424,17 @@ public void Dispose() this.remoteWindowHasCapacity.Set(); this.disposalTokenSource.Cancel(); - this.completionSource.TrySetResult(null); + + // If we are disposing due to receiving or sending an exception, relay that to our client. + if (this.faultingException is Exception faultingException) + { + this.completionSource.TrySetException(faultingException); + } + else + { + this.completionSource.TrySetResult(null); + } + this.MultiplexingStream.OnChannelDisposed(this); } } @@ -418,7 +451,7 @@ internal async Task OnChannelTerminatedAsync() // We Complete the writer because only the writing (logical) thread should complete it // to avoid race conditions, and Channel.Dispose can be called from any thread. using PipeWriterRental writerRental = await this.GetReceivedMessagePipeWriterAsync().ConfigureAwait(false); - await writerRental.Writer.CompleteAsync().ConfigureAwait(false); + await writerRental.Writer.CompleteAsync(this.RemoteException).ConfigureAwait(false); } catch (ObjectDisposedException) { @@ -497,32 +530,49 @@ internal async ValueTask OnContentAsync(ReadOnlySequence payload, Cancella /// /// Called by the when it will not be writing any more data to the channel. /// - internal void OnContentWritingCompleted() + /// The error in writing that originated on the remote side, if applicable. + internal void OnContentWritingCompleted(MultiplexingProtocolException? error = null) { + // If we have already received an error from the remote side then no need to complete the channel again. + if (this.receivedRemoteException) + { + return; + } + + // Set the state of the channel based on whether we are completing due to an error. + lock (this.SyncObject) + { + this.faultingException ??= error; + this.receivedRemoteException = error != null; + } + this.DisposeSelfOnFailure(Task.Run(async delegate { if (!this.IsDisposed) { try { + // If the channel is not disposed, then first try to close the writer used by the channel owner using PipeWriterRental writerRental = await this.GetReceivedMessagePipeWriterAsync().ConfigureAwait(false); - await writerRental.Writer.CompleteAsync().ConfigureAwait(false); + await writerRental.Writer.CompleteAsync(this.RemoteException).ConfigureAwait(false); } catch (ObjectDisposedException) { + // If not, try to close the underlying writer. if (this.mxStreamIOWriter != null) { using AsyncSemaphore.Releaser releaser = await this.mxStreamIOWriterSemaphore.EnterAsync().ConfigureAwait(false); - await this.mxStreamIOWriter.CompleteAsync().ConfigureAwait(false); + await this.mxStreamIOWriter.CompleteAsync(this.RemoteException).ConfigureAwait(false); } } } else { + // If the channel has not been disposed then just close the underlying writer. if (this.mxStreamIOWriter != null) { using AsyncSemaphore.Releaser releaser = await this.mxStreamIOWriterSemaphore.EnterAsync().ConfigureAwait(false); - await this.mxStreamIOWriter.CompleteAsync().ConfigureAwait(false); + await this.mxStreamIOWriter.CompleteAsync(this.RemoteException).ConfigureAwait(false); } } @@ -545,8 +595,15 @@ internal bool TryAcceptOffer(ChannelOptions channelOptions) } var acceptanceParameters = new AcceptanceParameters(this.localWindowSize.Value); - if (this.acceptanceSource.TrySetResult(acceptanceParameters)) + + try { + // Set up the channel options and ensure that the channel is still valid + // before we transition to an accepted state. + this.ApplyChannelOptions(channelOptions); + Verify.NotDisposed(this); + + // If we aren't a seeded channel then send an offer accepted frame. if (this.QualifiedId.Source != ChannelSource.Seeded) { ReadOnlySequence payload = this.MultiplexingStream.formatter.Serialize(acceptanceParameters); @@ -560,16 +617,20 @@ internal bool TryAcceptOffer(ChannelOptions channelOptions) CancellationToken.None); } - try + // Update the acceptance source to the acceptance parameters. + return this.acceptanceSource.TrySetResult(acceptanceParameters); + } + catch (Exception exception) + { + // Record the exception in the acceptance source. + this.acceptanceSource.TrySetException(exception); + + // If we caught an disposal error due to the channel self faulting then swallow + // the exception. + if (exception is ObjectDisposedException && this.faultingException != null) { - this.ApplyChannelOptions(channelOptions); return true; } - catch (ObjectDisposedException) - { - // A (harmless) race condition was hit. - // Swallow it and return false below. - } } return false; @@ -744,16 +805,35 @@ private async Task ProcessOutboundTransmissionsAsync() this.mxStreamIOReader = new UnownedPipeReader(mxStreamIOReader); } + bool channelAccepted = false; try { // Don't transmit data on the channel until the remote party has accepted it. // This is not just a courtesy: it ensure we don't transmit data from the offering party before the offer frame itself. // Likewise: it may help prevent transmitting data from the accepting party before the acceptance frame itself. - await this.Acceptance.ConfigureAwait(false); + try + { + await this.Acceptance.ConfigureAwait(false); + channelAccepted = true; + } + catch (Exception exception) + { + // This await will only throw an exception if the channel has been disposed and thus we can swallow. + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Error) ?? false) + { + this.TraceSource.TraceEvent( + TraceEventType.Error, + (int)TraceEventId.NonFatalInternalError, + "Channel {0} swalled acceptance exception in {1}: {2}", + this.QualifiedId, + nameof(this.ProcessOutboundTransmissionsAsync), + exception); + } + } - while (!this.Completion.IsCompleted) + while (channelAccepted && !this.Completion.IsCompleted) { - if (!this.remoteWindowHasCapacity.IsSet && this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose)) + if (!this.remoteWindowHasCapacity.IsSet && (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false)) { this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Remote window is full. Waiting for remote party to process data before sending more."); } @@ -761,7 +841,7 @@ private async Task ProcessOutboundTransmissionsAsync() await this.remoteWindowHasCapacity.WaitAsync().ConfigureAwait(false); if (this.IsRemotelyTerminated) { - if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose)) + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false) { this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Transmission on channel {0} \"{1}\" terminated the remote party terminated the channel.", this.QualifiedId, this.Name); } @@ -774,7 +854,7 @@ private async Task ProcessOutboundTransmissionsAsync() if (result.IsCanceled) { // We've been asked to cancel. Presumably the channel has faulted or been disposed. - if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose)) + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false) { this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "Transmission terminated because the read was canceled."); } @@ -793,7 +873,7 @@ private async Task ProcessOutboundTransmissionsAsync() ReadOnlySequence bufferToRelay = result.Buffer.Slice(0, bytesToSend); this.OnTransmittingBytes(bufferToRelay.Length); bool isCompleted = result.IsCompleted && result.Buffer.Length == bufferToRelay.Length; - if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Verbose)) + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Verbose) ?? false) { this.TraceSource.TraceEvent(TraceEventType.Verbose, 0, "{0} of {1} bytes will be transmitted.", bufferToRelay.Length, result.Buffer.Length); } @@ -819,7 +899,7 @@ private async Task ProcessOutboundTransmissionsAsync() if (isCompleted) { - if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information)) + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Information) ?? false) { this.TraceSource.TraceEvent(TraceEventType.Information, 0, "Transmission terminated because the writer completed."); } @@ -832,20 +912,37 @@ private async Task ProcessOutboundTransmissionsAsync() } catch (Exception ex) { + // If the operation had been cancelled then we are expecting to receive this error so don't transmit it. if (ex is OperationCanceledException && this.DisposalToken.IsCancellationRequested) { await mxStreamIOReader!.CompleteAsync().ConfigureAwait(false); } else { + // If not record it as the error to dispose this channel with + lock (this.SyncObject) + { + this.faultingException ??= ex; + } + + // Since we're not expecting to receive this error, transmit the error to the remote side. await mxStreamIOReader!.CompleteAsync(ex).ConfigureAwait(false); + + if (channelAccepted) + { + this.MultiplexingStream.OnChannelWritingError(this, ex); + } } throw; } finally { - this.MultiplexingStream.OnChannelWritingCompleted(this); + // Send the completion message to the remote if the channel was accepted + if (channelAccepted) + { + this.MultiplexingStream.OnChannelWritingCompleted(this); + } // Restore the PipeReader to the field. lock (this.SyncObject) @@ -916,17 +1013,26 @@ private async Task AutoCloseOnPipesClosureAsync() private void Fault(Exception exception) { - if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Critical) ?? false) - { - this.TraceSource!.TraceEvent(TraceEventType.Critical, (int)TraceEventId.FatalError, "Channel Closing self due to exception: {0}", exception); - } - + // Record the faulting exception unless it is not the original exception. lock (this.SyncObject) { this.faultingException ??= exception; } this.mxStreamIOReader?.CancelPendingRead(); + + // Only dispose if not already disposed. + if (this.IsDisposed) + { + return; + } + + // Record the fact that we are about to close the channel due to a fault. + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Error) ?? false) + { + this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.FatalError, "Channel {0} closing self due to exception: {1}", this.QualifiedId, exception); + } + this.Dispose(); } diff --git a/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs b/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs index 8adddd80..e53c7b46 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.ChannelOptions.cs @@ -50,6 +50,10 @@ public ChannelOptions() /// The specified in *must* be created with that *exceeds* /// the value of and . /// + /// + /// A faulted (one where is called with an exception) + /// will have its exception relayed to the remote party before closing the channel. + /// /// /// Thrown if set to an that returns null for either of its properties. public IDuplexPipe? ExistingPipe diff --git a/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs b/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs index 91a9b364..5b01fd34 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.ControlCode.cs @@ -44,6 +44,13 @@ internal enum ControlCode : byte /// allowing them to send more data. /// ContentProcessed, + + /// + /// Sent when one party experiences an exception related to a particular channel and carries details regarding the error, + /// when using protocol version 2 or later. + /// This is sent before a frame closes that channel. + /// + ContentWritingError, } } } diff --git a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs index c0550733..c5cf9472 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs @@ -5,6 +5,7 @@ namespace Nerdbank.Streams { using System; using System.Buffers; + using System.Diagnostics; using System.IO; using System.IO.Pipelines; using System.Threading; @@ -495,6 +496,53 @@ internal override Channel.AcceptanceParameters DeserializeAcceptanceParameters(R return new Channel.AcceptanceParameters(remoteWindowSize); } + /// + /// Returns the serialized representation of a object using . + /// + /// An instance of that we want to seralize. + /// A which is the serialized version of the error. + internal ReadOnlySequence SerializeWriteError(WriteError error) + { + // Create the payload + Sequence errorSequence = new(); + MessagePackWriter writer = new(errorSequence); + + // Write the error message and the protocol version to the payload + writer.WriteArrayHeader(1); + writer.Write(error.Message); + + // Return the payload to the caller + writer.Flush(); + return errorSequence.AsReadOnlySequence; + } + + /// + /// Extracts an object from the payload using . + /// + /// The payload we are trying to extract the error object from. + /// A object. + internal WriteError DeserializeWriteError(ReadOnlySequence serializedError) + { + MessagePackReader reader = new(serializedError); + int numElements = reader.ReadArrayHeader(); + + string? errorMessage = null; + for (int i = 0; i < numElements; i++) + { + switch (i) + { + case 0: + errorMessage = reader.ReadString(); + break; + default: + reader.Skip(); + break; + } + } + + return new WriteError(errorMessage ?? string.Empty); + } + protected virtual (FrameHeader Header, ReadOnlySequence Payload) DeserializeFrame(ReadOnlySequence frameSequence) { var reader = new MessagePackReader(frameSequence); diff --git a/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs b/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs new file mode 100644 index 00000000..b46801d0 --- /dev/null +++ b/src/Nerdbank.Streams/MultiplexingStream.WriteError.cs @@ -0,0 +1,32 @@ +// Copyright (c) Andrew Arnott. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Nerdbank.Streams +{ + /// + /// Contains the nested type. + /// + public partial class MultiplexingStream + { + /// + /// A class containing information about a write error and which is sent to the + /// remote alongside . + /// + internal class WriteError + { + /// + /// Initializes a new instance of the class. + /// + /// The error message we want to send to the receiver. + internal WriteError(string? message) + { + this.Message = message; + } + + /// + /// Gets the error message associated with this error. + /// + internal string? Message { get; } + } + } +} diff --git a/src/Nerdbank.Streams/MultiplexingStream.cs b/src/Nerdbank.Streams/MultiplexingStream.cs index 904baf24..3244674c 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.cs @@ -209,6 +209,16 @@ private enum TraceEventId /// Raised when the protocol handshake is starting, to annouce the major version being used. /// HandshakeStarted, + + /// + /// Raised when receiving or sending a . + /// + WriteError, + + /// + /// An error occurred that is likely not fatal. + /// + NonFatalInternalError, } /// @@ -220,7 +230,7 @@ private enum TraceEventId /// Gets the logger used by this instance. /// /// Never null. - public TraceSource TraceSource { get; } + public TraceSource TraceSource { get; private set; } /// /// Gets the default window size used for new channels that do not specify a value for . @@ -244,6 +254,11 @@ private enum TraceEventId /// private Func? DefaultChannelTraceSourceFactory { get; } + /// + /// Gets a value indicating whether this stream can send or receive frames of type. + /// + private bool ContentWritingErrorSupported => this.protocolMajorVersion > 1; + /// /// Initializes a new instance of the class /// with set to 3. @@ -837,6 +852,9 @@ private async Task ReadStreamAsync() case ControlCode.ChannelTerminated: await this.OnChannelTerminatedAsync(header.RequiredChannelId).ConfigureAwait(false); break; + case ControlCode.ContentWritingError: + this.OnContentWritingError(header.RequiredChannelId, frame.Value.Payload); + break; default: break; } @@ -919,6 +937,56 @@ private async Task OnChannelTerminatedAsync(QualifiedChannelId channelId) } } + /// + /// Occurs when the channel receives a frame with code from the remote. + /// + /// The channel id of the sender of the frame. + /// The payload that the sender sent in the frame. + private void OnContentWritingError(QualifiedChannelId channelId, ReadOnlySequence payload) + { + // Get the channel that send this frame + Channel channel; + lock (this.syncObject) + { + channel = this.openChannels[channelId]; + } + + // Determines if the channel is in a state to receive messages + bool channelInValidState = channelId.Source != ChannelSource.Local || channel.IsAccepted; + + // If the channel is in a valid state and we have a valid protocol version, then process the message + if (channelInValidState && this.ContentWritingErrorSupported) + { + // Deserialize the payload and verify that it was in an expected state + V2Formatter errorDeserializingFormattter = (V2Formatter)this.formatter; + WriteError error = errorDeserializingFormattter.DeserializeWriteError(payload); + + // Get the error message and complete the channel using it + string errorMessage = error.Message ?? ""; + MultiplexingProtocolException channelClosingException = new MultiplexingProtocolException($"Remote party indicated writing error: {errorMessage}"); + channel.OnContentWritingCompleted(channelClosingException); + } + else if (channelInValidState && !this.ContentWritingErrorSupported) + { + // The channel is in a valid state but we have a protocol version that doesn't support processing errrors + // so don't do anything. + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Warning) ?? false) + { + this.TraceSource.TraceEvent( + TraceEventType.Warning, + (int)TraceEventId.WriteError, + "Rejecting writing error from channel {0} as MultiplexingStream has protocol version of {1}", + channelId, + this.protocolMajorVersion); + } + } + else + { + // The channel is in an invalid state so throw an error indicating so + throw new MultiplexingProtocolException($"Remote party indicated they encountered errors writing to channel {channelId} before accepting it."); + } + } + private void OnContentWritingCompleted(QualifiedChannelId channelId) { Channel channel; @@ -1064,12 +1132,7 @@ private bool TryAcceptChannel(Channel channel, ChannelOptions options) Requires.NotNull(channel, nameof(channel)); Requires.NotNull(options, nameof(options)); - if (channel.TryAcceptOffer(options)) - { - return true; - } - - return false; + return channel.TryAcceptOffer(options); } private void AcceptChannelOrThrow(Channel channel, ChannelOptions options) @@ -1079,7 +1142,12 @@ private void AcceptChannelOrThrow(Channel channel, ChannelOptions options) if (!this.TryAcceptChannel(channel, options)) { - if (channel.IsAccepted) + // If we disposed of the channel due to a user provided error then ignore the error. + if (channel.IsDisposed && (channel.Completion.IsFaulted || channel.Acceptance.IsFaulted)) + { + return; + } + else if (channel.IsAccepted) { throw new InvalidOperationException("Channel is already accepted."); } @@ -1113,6 +1181,50 @@ private void OnChannelDisposed(Channel channel) } } + /// + /// Informs the remote party of a local error that prevents sending all the required data to this channel + /// by transmitting a frame. + /// + /// The channel whose writing was halted. + /// The exception that caused the writing to be haulted. + private void OnChannelWritingError(Channel channel, Exception exception) + { + if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Error)) + { + this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.WriteError, "Local channel {0} encountered write error {1}", channel.QualifiedId, exception.Message); + } + + // Verify that we can send a message over this channel. + // The race condition here is handled within SendFrameAsync which will drop the frame + // if the conditions we're checking for here change after we check them, so our check here + // is just an optimization to avoid work when we can predict its failure. + bool channelInValidState = true; + lock (this.syncObject) + { + channelInValidState = !this.channelsPendingTermination.Contains(channel.QualifiedId) + && this.openChannels.ContainsKey(channel.QualifiedId); + } + + // If we can send messages over this channel and we have the correct protocol version then send the error + if (channelInValidState && this.ContentWritingErrorSupported) + { + // Create the payload to send to the remote side + V2Formatter errorSerializationFormatter = (V2Formatter)this.formatter; + WriteError error = new(exception.Message); + ReadOnlySequence serializedError = errorSerializationFormatter.SerializeWriteError(error); + + // Create the frame header indicating that we encountered a content writing error + FrameHeader header = new FrameHeader + { + Code = ControlCode.ContentWritingError, + ChannelId = channel.QualifiedId, + }; + + // Send the frame alongside the payload to the remote side + this.SendFrame(header, serializedError, CancellationToken.None); + } + } + /// /// Indicates that the local end will not be writing any more data to this channel, /// leading to the transmission of a frame being sent for this channel. @@ -1191,7 +1303,7 @@ private async Task SendFrameAsync(FrameHeader header, ReadOnlySequence pay // In such cases, we should just suppress transmission of the frame because the other side does not care. // ContentWritingCompleted can be sent to SendFrame after a ChannelTerminated message such that neither have been transmitted yet // and thus wasn't in the termination collection until later, so forgive that too. - if (header.Code is ControlCode.ContentProcessed or ControlCode.ContentWritingCompleted) + if (header.Code is ControlCode.ContentProcessed or ControlCode.ContentWritingCompleted or ControlCode.ContentWritingError) { this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.FrameSendSkipped, "Skipping {0} frame for channel {1} because we're about to terminate it.", header.Code, header.ChannelId); return; diff --git a/src/nerdbank-streams/src/Channel.ts b/src/nerdbank-streams/src/Channel.ts index eae435dd..257e83c7 100644 --- a/src/nerdbank-streams/src/Channel.ts +++ b/src/nerdbank-streams/src/Channel.ts @@ -53,6 +53,15 @@ export abstract class Channel implements IDisposableObservable { return this._isDisposed; } + /** + * Closes this channel after transmitting an error to the remote party. + * @param error The error to transmit to the remote party. + */ + public fault(error: Error): Promise { + // The interesting stuff is in the derived class. + return Promise.resolve(); + } + /** * Closes this channel. */ @@ -71,6 +80,7 @@ export class ChannelClass extends Channel { private readonly _completion = new Deferred(); public localWindowSize?: number; private remoteWindowSize?: number; + private remoteError?: Error; /** * The number of bytes transmitted from here but not yet acknowledged as processed from there, @@ -214,7 +224,13 @@ export class ChannelClass extends Channel { return this._acceptance.resolve(); } - public onContent(buffer: Buffer | null) { + public onContent(buffer: Buffer | null, error?: Error) { + // If we have already received an error from the remote party, then don't process any future messages. + if (this.remoteError) { + return; + } + + this.remoteError = error; this._duplex.push(buffer); // We should find a way to detect when we *actually* share the received buffer with the Channel's user @@ -244,6 +260,23 @@ export class ChannelClass extends Channel { } } + public async fault(error: Error) { + // If the channel is already disposed then don't do anything + if (this.isDisposed) { + return; + } + + // Send the error message to the remote side + await this._multiplexingStream.onChannelWritingError(this, error); + + // Set the remote exception to the passed in error so that the channel is + // completed with this error + this.remoteError = error; + + // Dispose of the channel + await this.dispose(); + } + public async dispose() { if (!this.isDisposed) { super.dispose(); @@ -251,11 +284,19 @@ export class ChannelClass extends Channel { this._acceptance.reject(new CancellationToken.CancellationError("disposed")); // For the pipes, we Complete *our* ends, and leave the user's ends alone. - // The completion will propagate when it's ready to. + // The completion will propagate when it's ready to. No need to destroy the duplex + // as the frame containing the error message has already been sent. this._duplex.end(); this._duplex.push(null); - this._completion.resolve(); + // If we are sending an error to the remote side or received an error from the remote, + // relay that information to the clients. + if (this.remoteError) { + this._completion.reject(this.remoteError); + } else { + this._completion.resolve(); + } + await this._multiplexingStream.onChannelDisposed(this); } } diff --git a/src/nerdbank-streams/src/ControlCode.ts b/src/nerdbank-streams/src/ControlCode.ts index 3ae71fed..5459cda6 100644 --- a/src/nerdbank-streams/src/ControlCode.ts +++ b/src/nerdbank-streams/src/ControlCode.ts @@ -32,4 +32,11 @@ export enum ControlCode { * Sent when a channel has finished processing data received from the remote party, allowing them to send more data. */ ContentProcessed, + + /** + * Sent when one party experiences an exception related to a particular channel and carries details regarding the error, + * when using protocol version 2 or later. + * This is sent right before a ContentWritingCompleted frame closes that channel. + */ + ContentWritingError, } diff --git a/src/nerdbank-streams/src/MultiplexingStream.ts b/src/nerdbank-streams/src/MultiplexingStream.ts index 1f6eafcc..4265a32b 100644 --- a/src/nerdbank-streams/src/MultiplexingStream.ts +++ b/src/nerdbank-streams/src/MultiplexingStream.ts @@ -25,6 +25,7 @@ import { import { OfferParameters } from "./OfferParameters"; import { Semaphore } from 'await-semaphore'; import { QualifiedChannelId, ChannelSource } from "./QualifiedChannelId"; +import { WriteError } from "./WriteError"; export abstract class MultiplexingStream implements IDisposableObservable { /** @@ -106,9 +107,9 @@ export abstract class MultiplexingStream implements IDisposableObservable { * @param options Options to customize the behavior of the stream. * @returns The multiplexing stream. */ - public static Create( + public static Create( stream: NodeJS.ReadWriteStream, - options?: MultiplexingStreamOptions) : MultiplexingStream { + options?: MultiplexingStreamOptions): MultiplexingStream { options ??= { protocolMajorVersion: 3 }; options.protocolMajorVersion ??= 3; @@ -578,6 +579,26 @@ export class MultiplexingStreamClass extends MultiplexingStream { } } + public async onChannelWritingError(channel: ChannelClass, error: Error) { + // Make sure that we are in a protocol version in which we can write errors. + if (this.protocolMajorVersion === 1) { + return; + } + + // Make sure we can send error messages on this channel. + if (!this.getOpenChannel(channel.qualifiedId)) { + return; + } + + // Convert the error message into a payload into a formatter. + const writingError = new WriteError(error.message); + const errorSerializingFormatter = this.formatter as MultiplexingStreamV2Formatter; + const errorPayload = errorSerializingFormatter.serializeContentWritingError(writingError); + + // Sent the error to the remote side. + await this.sendFrameAsync(new FrameHeader(ControlCode.ContentWritingError, channel.qualifiedId), errorPayload); + } + public onChannelWritingCompleted(channel: ChannelClass) { // Only inform the remote side if this channel has not already been terminated. if (!channel.isDisposed && this.getOpenChannel(channel.qualifiedId)) { @@ -629,6 +650,9 @@ export class MultiplexingStreamClass extends MultiplexingStream { case ControlCode.ContentWritingCompleted: this.onContentWritingCompleted(frame.header.requiredChannel); break; + case ControlCode.ContentWritingError: + this.onContentWritingError(frame.header.requiredChannel, frame.payload); + break; case ControlCode.ChannelTerminated: this.onChannelTerminated(frame.header.requiredChannel); break; @@ -716,6 +740,27 @@ export class MultiplexingStreamClass extends MultiplexingStream { channel.onContentProcessed(bytesProcessed); } + private onContentWritingError(channelId: QualifiedChannelId, payload: Buffer) { + // Make sure that the channel has the proper formatter to process the output. + if (this.protocolMajorVersion === 1) { + return; + } + + // Ensure that we received the message on an open channel. + const channel = this.getOpenChannel(channelId); + if (!channel) { + throw new Error(`No channel with id ${channelId} found.`); + } + + // Extract the error from the payload. + const errorDeserializingFormatter = this.formatter as MultiplexingStreamV2Formatter; + const writingError = errorDeserializingFormatter.deserializeContentWritingError(payload); + + // Pass the error received from the remote to the channel. + const remoteErr = new Error(`Received error message from remote: ${writingError.message}`); + channel.onContent(null, remoteErr); + } + private onContentWritingCompleted(channelId: QualifiedChannelId) { const channel = this.getOpenChannel(channelId); if (!channel) { diff --git a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts index af625669..ce4ec7e0 100644 --- a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts +++ b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts @@ -8,6 +8,7 @@ import * as msgpack from 'msgpack-lite'; import { Deferred } from "./Deferred"; import { FrameHeader } from "./FrameHeader"; import { ControlCode } from "./ControlCode"; +import { WriteError } from "./WriteError"; import { ChannelSource } from "./QualifiedChannelId"; export interface Version { @@ -293,6 +294,19 @@ export class MultiplexingStreamV2Formatter extends MultiplexingStreamFormatter { return msgpack.decode(payload)[0]; } + serializeContentWritingError(writingError: WriteError): Buffer { + const payload: any[] = [writingError.message]; + return msgpack.encode(payload); + } + + deserializeContentWritingError(payload: Buffer): WriteError { + const msgpackObject = msgpack.decode(payload); + + // Return the error message to the caller. + const errorMsg: string | undefined = msgpackObject[0]; + return new WriteError(errorMsg ?? ""); + } + protected async readMessagePackAsync(cancellationToken: CancellationToken): Promise<{} | [] | null> { const streamEnded = new Deferred(); while (true) { diff --git a/src/nerdbank-streams/src/WriteError.ts b/src/nerdbank-streams/src/WriteError.ts new file mode 100644 index 00000000..81970f05 --- /dev/null +++ b/src/nerdbank-streams/src/WriteError.ts @@ -0,0 +1,13 @@ +/** + * A class that is used to store information related to ContentWritingError. + * It is used by both the sending and receiving streams to transmit errors encountered while + * writing content. + */ +export class WriteError { + /** + * Initializes a new instance of the WriteError class. + * @param message The error message. + */ + constructor(public readonly message: string) { + } +} diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts index 15341867..108aa344 100644 --- a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts +++ b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts @@ -4,6 +4,7 @@ import { Deferred } from "../Deferred"; import { FullDuplexStream } from "../FullDuplexStream"; import { MultiplexingStream } from "../MultiplexingStream"; import { ChannelOptions } from "../ChannelOptions"; +import * as assert from "assert"; [1, 2, 3].forEach(protocolMajorVersion => { describe(`MultiplexingStream v${protocolMajorVersion} (interop) `, () => { @@ -14,6 +15,8 @@ import { ChannelOptions } from "../ChannelOptions"; const dotnetEnvBlock: NodeJS.ProcessEnv = { DOTNET_SKIP_FIRST_TIME_EXPERIENCE: "1", // prevent warnings in stdout that corrupt our interop stream. }; + let expectedDisposeError: boolean; + beforeAll( async () => { proc = spawn( @@ -26,7 +29,8 @@ import { ChannelOptions } from "../ChannelOptions"; proc.once("exit", (code) => procExited.resolve(code)); // proc.stdout!.pipe(process.stdout); proc.stderr!.pipe(process.stderr); - expect(await procExited.promise).toEqual(0); + const buildExitVal = await procExited.promise; + expect(buildExitVal).toEqual(0); } finally { proc.kill(); proc = null; @@ -50,12 +54,21 @@ import { ChannelOptions } from "../ChannelOptions"; proc = null; throw e; } + expectedDisposeError = false; }, 10000); // leave time for dotnet to start. afterEach(async () => { if (mx) { mx.dispose(); - await mx.completion; + + // See if we encounter any errors in the multplexing stream and rethrow them if they are unexpected + try { + await mx.completion; + } catch (error) { + if (!expectedDisposeError) { + throw error; + } + } } if (proc) { @@ -86,6 +99,31 @@ import { ChannelOptions } from "../ChannelOptions"; expect(recv).toEqual(`recv: ${bigdata}`); }); + it("Can send error to remote", async () => { + expectedDisposeError = true; + const errorWriteChannel = await mx.offerChannelAsync("clientErrorOffer"); + const responseReceiveChannel = await mx.offerChannelAsync("clientResponseOffer"); + + const errorMessage = "couldn't send all of the data"; + const errorToSend = new Error(errorMessage); + + let caughtCompletionErr = false; + errorWriteChannel.completion.catch(err => { + caughtCompletionErr = true; + }); + + await errorWriteChannel.fault(errorToSend); + assert.deepStrictEqual(caughtCompletionErr, true); + + let expectedMessage = `received error: Remote party indicated writing error: ${errorMessage}`; + if (protocolMajorVersion === 1) { + expectedMessage = "didn't receive any errors"; + } + + const receivedMessage = await readLineAsync(responseReceiveChannel.stream); + assert.deepStrictEqual(receivedMessage?.trim(), expectedMessage); + }); + if (protocolMajorVersion >= 3) { it("Can communicate over seeded channel", async () => { const channel = mx.acceptChannel(0); diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts index e78cc49a..08db0365 100644 --- a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts +++ b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts @@ -230,6 +230,46 @@ import * as assert from "assert"; await channels[1].completion; }); + it("channel disposes with an error", async () => { + const errorMessage = "couldn't send all of the data"; + const errorToSend = new Error(errorMessage); + + const channels = await Promise.all([ + mx1.offerChannelAsync("test"), + mx2.acceptChannelAsync("test"), + ]); + + await channels[0].fault(errorToSend); + + // Ensure that the current channel disposes with the error. + let caughtSenderError = false; + try { + await channels[0].completion; + } catch (error) { + let completionErrMsg = String(error); + if (error instanceof Error) { + completionErrMsg = (error as Error).message; + } + caughtSenderError = completionErrMsg.includes(errorMessage); + } + + assert.strictEqual(true, caughtSenderError); + + // Ensure that the remote side received the error only for version >= 1 + let caughtRemoteError = false; + try { + await channels[1].completion; + } catch (error) { + let completionErrMsg = String(error); + if (error instanceof Error) { + completionErrMsg = (error as Error).message; + } + caughtRemoteError = completionErrMsg.includes(errorMessage); + } + + assert.deepStrictEqual(protocolMajorVersion > 1, caughtRemoteError); + }) + it("channels complete when mxstream is disposed", async () => { const channels = await Promise.all([ mx1.offerChannelAsync("test"), diff --git a/test/Nerdbank.Streams.Interop.Tests/Program.cs b/test/Nerdbank.Streams.Interop.Tests/Program.cs index d5b5a46e..b0c46048 100644 --- a/test/Nerdbank.Streams.Interop.Tests/Program.cs +++ b/test/Nerdbank.Streams.Interop.Tests/Program.cs @@ -61,6 +61,7 @@ private static (StreamReader Reader, StreamWriter Writer) CreateStreamIO(Multipl private async Task RunAsync(int protocolMajorVersion) { this.ClientOfferAsync().Forget(); + this.ClientOfferErrorAsync().Forget(); this.ServerOfferAsync().Forget(); if (protocolMajorVersion >= 3) @@ -79,6 +80,29 @@ private async Task ClientOfferAsync() await w.WriteLineAsync($"recv: {line}"); } + private async Task ClientOfferErrorAsync() + { + // Await both of the channels from the sender, one to read the error and the other to return the response. + MultiplexingStream.Channel? incomingChannel = await this.mx.AcceptChannelAsync("clientErrorOffer"); + MultiplexingStream.Channel? outgoingChannel = await this.mx.AcceptChannelAsync("clientResponseOffer"); + + // Determine the response to send back on the whether the incoming channel completed with an exception. + string? responseMessage = "didn't receive any errors"; + try + { + await incomingChannel.Completion; + } + catch (Exception error) + { + responseMessage = "received error: " + error.Message; + } + + // Create a writer using the outgoing channel and send the response to the sender. + (StreamReader _, StreamWriter writer) = CreateStreamIO(outgoingChannel); + + await writer.WriteLineAsync(responseMessage); + } + private async Task ServerOfferAsync() { MultiplexingStream.Channel? channel = await this.mx.OfferChannelAsync("serverOffer"); diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs index 83e6717c..c772fb0c 100644 --- a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs +++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs @@ -100,6 +100,74 @@ public void DefaultMajorProtocolVersion() Assert.Equal(1, new MultiplexingStream.Options().ProtocolMajorVersion); } + [Fact] + public async Task ClosePipeWithError() + { + (MultiplexingStream.Channel channel1, MultiplexingStream.Channel channel2) = await this.EstablishChannelsAsync("test"); + await channel1.Output.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken); + ReadResult readResult = await channel2.Input.ReadAtLeastAsync(3, this.TimeoutToken); + channel2.Input.AdvanceTo(readResult.Buffer.End); + + // Now fail one side. + const string expectedErrorMessage = "Inflicted error"; + await channel1.Output.CompleteAsync(new ApplicationException(expectedErrorMessage)); + if (this.ProtocolMajorVersion > 1) + { + MultiplexingProtocolException ex = await Assert.ThrowsAnyAsync(async () => await channel2.Input.ReadAsync(this.TimeoutToken)); + Assert.Contains(expectedErrorMessage, ex.Message); + } + else + { + await channel2.Input.ReadAsync(this.TimeoutToken); + } + } + + [Fact] + public async Task OfferPipeWithError() + { + string errorMessage = "Hello World"; + + // Prepare a readonly pipe that is already populated with data and an error. + var pipe = new Pipe(); + await pipe.Writer.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken); + pipe.Writer.Complete(new ApplicationException(errorMessage)); + + // Create a sending and receiving channel using the channel. + MultiplexingStream.Channel? localChannel = this.mx1.CreateChannel(new MultiplexingStream.ChannelOptions { ExistingPipe = new DuplexPipe(pipe.Reader) }); + await this.WaitForEphemeralChannelOfferToPropagateAsync(); + MultiplexingStream.Channel? remoteChannel = this.mx2.AcceptChannel(localChannel.QualifiedId.Id); + + async Task ReadAllDataAsync() + { + // Read the latest input from the local channel and determine if we should continue reading. + ReadResult readResult = await remoteChannel.Input.ReadAsync(this.TimeoutToken); + remoteChannel.Input.AdvanceTo(readResult.Buffer.End); + + if (readResult.IsCompleted || readResult.IsCanceled) + { + return; + } + } + + if (this.ProtocolMajorVersion > 1) + { + MultiplexingProtocolException caughtException = await Assert.ThrowsAnyAsync(ReadAllDataAsync); + this.Logger.WriteLine(caughtException.ToString()); + Assert.Contains(errorMessage, caughtException.Message); + + MultiplexingProtocolException remoteCompletionException = await Assert.ThrowsAnyAsync(() => remoteChannel.Completion); + Assert.Contains(errorMessage, remoteCompletionException.Message); + } + else + { + await ReadAllDataAsync(); + } + + // Ensure that the writer of the error completes with that error, no matter what version of the protocol they are using. + ApplicationException localCompletionException = await Assert.ThrowsAnyAsync(() => localChannel.Completion); + Assert.Contains(errorMessage, localCompletionException.Message); + } + [Fact] public async Task OfferReadOnlyDuplexPipe() {