Skip to content

Commit

Permalink
STJ: Dispose enumerator on exception (dotnet#100194)
Browse files Browse the repository at this point in the history
* STJ: Dispose enumerator on exception

* Avoid code duplication

* Rework fix

* Remove useless Disposable field

* Apply fix on all collection converters

* Remove duplicate assignments

* Skip fix for no-op Dispose implementation

* Move IEnumerator disposal to WriteCore method.

---------

Co-authored-by: Eirik Tsarpalis <eirik.tsarpalis@gmail.com>
  • Loading branch information
manandre and eiriktsarpalis authored Jun 25, 2024
1 parent ddcbc8b commit a54d9e9
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ protected internal override bool OnWriteResume(
if (state.Current.CollectionEnumerator == null)
{
enumerator = value.GetEnumerator();
state.Current.CollectionEnumerator = enumerator;
if (!enumerator.MoveNext())
{
enumerator.Dispose();
Expand All @@ -47,7 +48,6 @@ protected internal override bool OnWriteResume(
{
if (ShouldFlush(ref state, writer))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand All @@ -61,7 +61,6 @@ protected internal override bool OnWriteResume(
TValue element = enumerator.Current.Value;
if (!_valueConverter.TryWrite(writer, element, options, ref state))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar
if (state.Current.CollectionEnumerator == null)
{
enumerator = value.GetEnumerator();
state.Current.CollectionEnumerator = enumerator;
if (!enumerator.MoveNext())
{
return true;
Expand All @@ -62,7 +63,6 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar
{
if (ShouldFlush(ref state, writer))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand All @@ -87,7 +87,6 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar
object? element = enumerator.Value;
if (!_valueConverter.TryWrite(writer, element, options, ref state))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ protected override bool OnWriteResume(
if (state.Current.CollectionEnumerator == null)
{
enumerator = value.GetEnumerator();
state.Current.CollectionEnumerator = enumerator;
if (!enumerator.MoveNext())
{
return true;
Expand All @@ -61,14 +62,12 @@ protected override bool OnWriteResume(
{
if (ShouldFlush(ref state, writer))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

object? element = enumerator.Current;
if (!converter.TryWrite(writer, element, options, ref state))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TCollection value,
if (state.Current.CollectionEnumerator == null)
{
enumerator = value.GetEnumerator();
state.Current.CollectionEnumerator = enumerator;
if (!enumerator.MoveNext())
{
enumerator.Dispose();
Expand All @@ -39,14 +40,12 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TCollection value,
{
if (ShouldFlush(ref state, writer))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

TElement element = enumerator.Current;
if (!converter.TryWrite(writer, element, options, ref state))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ protected sealed override bool OnWriteResume(Utf8JsonWriter writer, TCollection
if (state.Current.CollectionEnumerator == null)
{
enumerator = value.GetEnumerator();
state.Current.CollectionEnumerator = enumerator;
if (!enumerator.MoveNext())
{
return true;
Expand All @@ -61,14 +62,12 @@ protected sealed override bool OnWriteResume(Utf8JsonWriter writer, TCollection
{
if (ShouldFlush(ref state, writer))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

object? element = enumerator.Current;
if (!converter.TryWrite(writer, element, options, ref state))
{
state.Current.CollectionEnumerator = enumerator;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,37 @@ internal bool WriteCore(
{
return TryWrite(writer, value, options, ref state);
}
catch (InvalidOperationException ex) when (ex.Source == ThrowHelper.ExceptionSourceValueToRethrowAsJsonException)
catch (Exception ex)
{
ThrowHelper.ReThrowWithPath(ref state, ex);
throw;
}
catch (JsonException ex) when (ex.Path == null)
{
// JsonExceptions where the Path property is already set
// typically originate from nested calls to JsonSerializer;
// treat these cases as any other exception type and do not
// overwrite any exception information.
if (!state.SupportAsync)
{
// Async serializers should dispose sync and
// async disposables from the async root method.
state.DisposePendingDisposablesOnException();
}

ThrowHelper.AddJsonExceptionInformation(ref state, ex);
throw;
}
catch (NotSupportedException ex)
{
// If the message already contains Path, just re-throw. This could occur in serializer re-entry cases.
// To get proper Path semantics in re-entry cases, APIs that take 'state' need to be used.
if (ex.Message.Contains(" Path: "))
switch (ex)
{
throw;
case InvalidOperationException when ex.Source == ThrowHelper.ExceptionSourceValueToRethrowAsJsonException:
ThrowHelper.ReThrowWithPath(ref state, ex);
break;

case JsonException { Path: null } jsonException:
// JsonExceptions where the Path property is already set
// typically originate from nested calls to JsonSerializer;
// treat these cases as any other exception type and do not
// overwrite any exception information.
ThrowHelper.AddJsonExceptionInformation(ref state, jsonException);
break;

case NotSupportedException when !ex.Message.Contains(" Path: "):
// If the message already contains Path, just re-throw. This could occur in serializer re-entry cases.
// To get proper Path semantics in re-entry cases, APIs that take 'state' need to be used.
ThrowHelper.ThrowNotSupportedException(ref state, ex);
break;
}

ThrowHelper.ThrowNotSupportedException(ref state, ex);
return default;
throw;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ rootValue is not null &&
{
ThrowHelper.ThrowInvalidOperationException_PipeWriterDoesNotImplementUnflushedBytes(bufferWriter);
}

state.PipeWriter = bufferWriter;
state.FlushThreshold = (int)(bufferWriter.Capacity * JsonSerializer.FlushThreshold);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ public void AddCompletedAsyncDisposable(IAsyncDisposable asyncDisposable)
=> (CompletedAsyncDisposables ??= new List<IAsyncDisposable>()).Add(asyncDisposable);

// Asynchronously dispose of any AsyncDisposables that have been scheduled for disposal
public async ValueTask DisposeCompletedAsyncDisposables()
public readonly async ValueTask DisposeCompletedAsyncDisposables()
{
Debug.Assert(CompletedAsyncDisposables?.Count > 0);
Exception? exception = null;
Expand Down Expand Up @@ -307,7 +307,7 @@ public async ValueTask DisposeCompletedAsyncDisposables()
/// Walks the stack cleaning up any leftover IDisposables
/// in the event of an exception on serialization
/// </summary>
public void DisposePendingDisposablesOnException()
public readonly void DisposePendingDisposablesOnException()
{
Exception? exception = null;

Expand Down Expand Up @@ -346,7 +346,7 @@ static void DisposeFrame(IEnumerator? collectionEnumerator, ref Exception? excep
/// Walks the stack cleaning up any leftover I(Async)Disposables
/// in the event of an exception on async serialization
/// </summary>
public async ValueTask DisposePendingDisposablesOnExceptionAsync()
public readonly async ValueTask DisposePendingDisposablesOnExceptionAsync()
{
Exception? exception = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ public static void ThrowNotSupportedException(scoped ref ReadStack state, in Utf
}

[DoesNotReturn]
public static void ThrowNotSupportedException(ref WriteStack state, NotSupportedException ex)
public static void ThrowNotSupportedException(ref WriteStack state, Exception innerException)
{
string message = ex.Message;
string message = innerException.Message;

// The caller should check to ensure path is not already set.
Debug.Assert(!message.Contains(" Path: "));
Expand All @@ -608,7 +608,7 @@ public static void ThrowNotSupportedException(ref WriteStack state, NotSupported

message += $" Path: {state.PropertyPath()}.";

throw new NotSupportedException(message, ex);
throw new NotSupportedException(message, innerException);
}

[DoesNotReturn]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,34 @@ public async Task WriteISetT_DisposesEnumerators()
}
}

[Fact]
public async Task WriteIEnumerableT_ElementSerializationThrows_DisposesEnumerators()
{
var items = new RefCountedList<IEnumerable<int>>(Enumerable.Repeat(ThrowingEnumerable(), 1));
await Assert.ThrowsAsync<DivideByZeroException>(() => Serializer.SerializeWrapper(items.AsEnumerable()));
Assert.Equal(0, items.RefCount);

static IEnumerable<int> ThrowingEnumerable()
{
yield return 42;
throw new DivideByZeroException();
}
}

[Fact]
public async Task WriteIDictionaryT_ElementSerializationThrows_DisposesEnumerators()
{
var items = new RefCountedDictionary<int, IEnumerable<int>>(Enumerable.Repeat(new KeyValuePair<int, IEnumerable<int>>(42, ThrowingEnumerable()), 1));
await Assert.ThrowsAsync<DivideByZeroException>(() => Serializer.SerializeWrapper((IDictionary<int, IEnumerable<int>>)items));
Assert.Equal(0, items.RefCount);

static IEnumerable<int> ThrowingEnumerable()
{
yield return 42;
throw new DivideByZeroException();
}
}

public class SimpleClassWithKeyValuePairs
{
public KeyValuePair<string, string> KvpWStrVal { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public async Task DeserializeAsyncEnumerable()
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(IDictionary<int, int>))]
[JsonSerializable(typeof(IDictionary<int, IEnumerable<int>>))]
[JsonSerializable(typeof(Dictionary<string, ClassWithInternalParameterlessConstructor>))]
[JsonSerializable(typeof(Dictionary<string, ClassWithPrivateParameterlessConstructor>))]
[JsonSerializable(typeof(Dictionary<string, Dictionary<string, CustomClass>>))]
Expand Down Expand Up @@ -556,6 +557,7 @@ public CollectionTests_Default()
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(IDictionary<int, int>))]
[JsonSerializable(typeof(IDictionary<int, IEnumerable<int>>))]
[JsonSerializable(typeof(Dictionary<string, ClassWithInternalParameterlessConstructor>))]
[JsonSerializable(typeof(Dictionary<string, ClassWithPrivateParameterlessConstructor>))]
[JsonSerializable(typeof(Dictionary<string, Dictionary<string, CustomClass>>))]
Expand Down

0 comments on commit a54d9e9

Please sign in to comment.