Skip to content

Commit

Permalink
Fix IAsyncEnumerable controller methods to allow setting headers (#57924
Browse files Browse the repository at this point in the history
)

* Fix IAsyncEnumerable controller methods to allow setting headers

* name

* httpjson extensions too

* revert
  • Loading branch information
BrennanConroy authored Sep 20, 2024
1 parent 8371724 commit 92376ce
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 92 deletions.
102 changes: 22 additions & 80 deletions src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,12 @@ public static Task WriteAsJsonAsync<TValue>(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, options,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, options, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, options, cancellationToken);
}

Expand All @@ -131,33 +121,22 @@ public static Task WriteAsJsonAsync<TValue>(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, TValue value, JsonTypeInfo<TValue> jsonTypeInfo,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(HttpResponse response, TValue value, JsonTypeInfo<TValue> jsonTypeInfo,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

Expand All @@ -184,52 +163,38 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
[RequiresDynamicCode(RequiresDynamicCodeMessage)]
private static async Task WriteAsJsonAsyncSlow<TValue>(
Task startTask,
PipeWriter body,
TValue value,
JsonSerializerOptions? options,
bool ignoreOCE,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, options, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}

/// <summary>
Expand Down Expand Up @@ -304,42 +269,30 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, options,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, options,
response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, options, cancellationToken);
}

[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
[RequiresDynamicCode(RequiresDynamicCodeMessage)]
private static async Task WriteAsJsonAsyncSlow(
Task startTask,
PipeWriter body,
object? value,
Type type,
JsonSerializerOptions? options,
bool ignoreOCE,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}

/// <summary>
Expand Down Expand Up @@ -367,33 +320,22 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, context,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, context, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, context, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, PipeWriter body, object? value, Type type, JsonSerializerContext context,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(PipeWriter body, object? value, Type type, JsonSerializerContext context,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, type, context, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

Expand Down
70 changes: 69 additions & 1 deletion src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.TestHost;

#nullable enable

Expand Down Expand Up @@ -481,6 +484,71 @@ public async Task WriteAsJsonAsync_NullValue_WithJsonTypeInfo_JsonResponse()
Assert.Equal("null", data);
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task AsyncEnumerableCanSetHeader()
{
var builder = WebApplication.CreateBuilder();
builder.WebHost.UseTestServer();

await using var app = builder.Build();

app.MapGet("/", IAsyncEnumerable<int> (HttpContext httpContext) =>
{
return AsyncEnum();

async IAsyncEnumerable<int> AsyncEnum()
{
await Task.Yield();
httpContext.Response.Headers["Test"] = "t";
yield return 1;
}
});

await app.StartAsync();

var client = app.GetTestClient();

var result = await client.GetAsync("/");
result.EnsureSuccessStatusCode();
var headerValue = Assert.Single(result.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);

await app.StopAsync();
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task EnumerableCanSetHeader()
{
var builder = WebApplication.CreateBuilder();
builder.WebHost.UseTestServer();

await using var app = builder.Build();

app.MapGet("/", IEnumerable<int> (HttpContext httpContext) =>
{
return Enum();

IEnumerable<int> Enum()
{
httpContext.Response.Headers["Test"] = "t";
yield return 1;
}
});

await app.StartAsync();

var client = app.GetTestClient();

var result = await client.GetAsync("/");
result.EnsureSuccessStatusCode();
var headerValue = Assert.Single(result.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);

await app.StopAsync();
}

public class TestObject
{
public string? StringProperty { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<Reference Include="Microsoft.AspNetCore.Http.Results" />
<Reference Include="Microsoft.AspNetCore.Http.Extensions" />
<Reference Include="Microsoft.AspNetCore.Mvc.Core" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Microsoft.Extensions.DependencyInjection" />
<Reference Include="Microsoft.Extensions.DependencyModel" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ public sealed override async Task WriteResponseBodyAsync(OutputFormatterWriteCon
try
{
var responseWriter = httpContext.Response.BodyWriter;
if (!httpContext.Response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
}

if (jsonTypeInfo is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ public async Task ExecuteAsync(ActionContext context, JsonResult result)
try
{
var responseWriter = response.BodyWriter;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
}

await JsonSerializer.SerializeAsync(responseWriter, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
}
catch (OperationCanceledException) when (context.HttpContext.RequestAborted.IsCancellationRequested) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,21 @@ public async Task Formatting_PolymorphicModel_WithJsonPolymorphism()
await response.AssertStatusCodeAsync(HttpStatusCode.OK);
Assert.Equal(expected, await response.Content.ReadAsStringAsync());
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task CanSetHeaderWithAsyncEnumerable()
{
// Arrange
var expected = "[1]";

// Act
var response = await Client.GetAsync($"/SystemTextJsonOutputFormatter/{nameof(SystemTextJsonOutputFormatterController.AsyncEnumerable)}");

// Assert
await response.AssertStatusCodeAsync(HttpStatusCode.OK);
Assert.Equal(expected, await response.Content.ReadAsStringAsync());
var headerValue = Assert.Single(response.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ public class SystemTextJsonOutputFormatterController : ControllerBase
Address = "Some address",
};

[HttpGet]
public async IAsyncEnumerable<int> AsyncEnumerable()
{
await Task.Yield();
HttpContext.Response.Headers["Test"] = "t";
yield return 1;
}

[JsonPolymorphic]
[JsonDerivedType(typeof(DerivedModel), nameof(DerivedModel))]
public class SimpleModel
Expand Down

0 comments on commit 92376ce

Please sign in to comment.