From 92376ce36468ed605a66f09cfe02d4c436ea3a2b Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 20 Sep 2024 16:42:17 -0700 Subject: [PATCH] Fix IAsyncEnumerable controller methods to allow setting headers (#57924) * Fix IAsyncEnumerable controller methods to allow setting headers * name * httpjson extensions too * revert --- .../src/HttpResponseJsonExtensions.cs | 102 ++++-------------- .../test/HttpResponseJsonExtensionsTests.cs | 70 +++++++++++- ...ft.AspNetCore.Http.Extensions.Tests.csproj | 1 + .../SystemTextJsonOutputFormatter.cs | 5 - .../SystemTextJsonResultExecutor.cs | 6 -- .../SystemTextJsonOutputFormatterTest.cs | 17 +++ ...SystemTextJsonOutputFormatterController.cs | 8 ++ 7 files changed, 117 insertions(+), 92 deletions(-) diff --git a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs index c7d003e6bb0b..84e09c1a3581 100644 --- a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs +++ b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs @@ -90,22 +90,12 @@ public static Task WriteAsJsonAsync( response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset; - var startTask = Task.CompletedTask; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - startTask = response.StartAsync(cancellationToken); - } - // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException - if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled) + if (!cancellationToken.CanBeCanceled) { - return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, options, - ignoreOCE: !cancellationToken.CanBeCanceled, - cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted); + return WriteAsJsonAsyncSlow(response.BodyWriter, value, options, response.HttpContext.RequestAborted); } - startTask.GetAwaiter().GetResult(); return JsonSerializer.SerializeAsync(response.BodyWriter, value, options, cancellationToken); } @@ -131,33 +121,22 @@ public static Task WriteAsJsonAsync( response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset; - var startTask = Task.CompletedTask; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - startTask = response.StartAsync(cancellationToken); - } - // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException - if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled) + if (!cancellationToken.CanBeCanceled) { - return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo, - ignoreOCE: !cancellationToken.CanBeCanceled, - cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted); + return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted); } - startTask.GetAwaiter().GetResult(); return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken); - static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, TValue value, JsonTypeInfo jsonTypeInfo, - bool ignoreOCE, CancellationToken cancellationToken) + static async Task WriteAsJsonAsyncSlow(HttpResponse response, TValue value, JsonTypeInfo jsonTypeInfo, + CancellationToken cancellationToken) { try { - await startTask; await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken); } - catch (OperationCanceledException) when (ignoreOCE) { } + catch (OperationCanceledException) { } } } @@ -184,52 +163,38 @@ public static Task WriteAsJsonAsync( response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset; - var startTask = Task.CompletedTask; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - startTask = response.StartAsync(cancellationToken); - } - // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException - if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled) + if (!cancellationToken.CanBeCanceled) { - return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo, - ignoreOCE: !cancellationToken.CanBeCanceled, - cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted); + return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted); } - startTask.GetAwaiter().GetResult(); return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken); - static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo, - bool ignoreOCE, CancellationToken cancellationToken) + static async Task WriteAsJsonAsyncSlow(HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo, + CancellationToken cancellationToken) { try { - await startTask; await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken); } - catch (OperationCanceledException) when (ignoreOCE) { } + catch (OperationCanceledException) { } } } [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] [RequiresDynamicCode(RequiresDynamicCodeMessage)] private static async Task WriteAsJsonAsyncSlow( - Task startTask, PipeWriter body, TValue value, JsonSerializerOptions? options, - bool ignoreOCE, CancellationToken cancellationToken) { try { - await startTask; await JsonSerializer.SerializeAsync(body, value, options, cancellationToken); } - catch (OperationCanceledException) when (ignoreOCE) { } + catch (OperationCanceledException) { } } /// @@ -304,42 +269,30 @@ public static Task WriteAsJsonAsync( response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset; - var startTask = Task.CompletedTask; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - startTask = response.StartAsync(cancellationToken); - } - // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException - if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled) + if (!cancellationToken.CanBeCanceled) { - return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, options, - ignoreOCE: !cancellationToken.CanBeCanceled, - cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted); + return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, options, + response.HttpContext.RequestAborted); } - startTask.GetAwaiter().GetResult(); return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, options, cancellationToken); } [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] [RequiresDynamicCode(RequiresDynamicCodeMessage)] private static async Task WriteAsJsonAsyncSlow( - Task startTask, PipeWriter body, object? value, Type type, JsonSerializerOptions? options, - bool ignoreOCE, CancellationToken cancellationToken) { try { - await startTask; await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken); } - catch (OperationCanceledException) when (ignoreOCE) { } + catch (OperationCanceledException) { } } /// @@ -367,33 +320,22 @@ public static Task WriteAsJsonAsync( response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset; - var startTask = Task.CompletedTask; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - startTask = response.StartAsync(cancellationToken); - } - // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException - if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled) + if (!cancellationToken.CanBeCanceled) { - return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, context, - ignoreOCE: !cancellationToken.CanBeCanceled, - cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted); + return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, context, response.HttpContext.RequestAborted); } - startTask.GetAwaiter().GetResult(); return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, context, cancellationToken); - static async Task WriteAsJsonAsyncSlow(Task startTask, PipeWriter body, object? value, Type type, JsonSerializerContext context, - bool ignoreOCE, CancellationToken cancellationToken) + static async Task WriteAsJsonAsyncSlow(PipeWriter body, object? value, Type type, JsonSerializerContext context, + CancellationToken cancellationToken) { try { - await startTask; await JsonSerializer.SerializeAsync(body, value, type, context, cancellationToken); } - catch (OperationCanceledException) when (ignoreOCE) { } + catch (OperationCanceledException) { } } } diff --git a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs index 35cfa265d7f1..3d5007e73d26 100644 --- a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs +++ b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs @@ -1,12 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; -using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.TestHost; #nullable enable @@ -481,6 +484,71 @@ public async Task WriteAsJsonAsync_NullValue_WithJsonTypeInfo_JsonResponse() Assert.Equal("null", data); } + // Regression test: https://github.com/dotnet/aspnetcore/issues/57895 + [Fact] + public async Task AsyncEnumerableCanSetHeader() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + + await using var app = builder.Build(); + + app.MapGet("/", IAsyncEnumerable (HttpContext httpContext) => + { + return AsyncEnum(); + + async IAsyncEnumerable AsyncEnum() + { + await Task.Yield(); + httpContext.Response.Headers["Test"] = "t"; + yield return 1; + } + }); + + await app.StartAsync(); + + var client = app.GetTestClient(); + + var result = await client.GetAsync("/"); + result.EnsureSuccessStatusCode(); + var headerValue = Assert.Single(result.Headers.GetValues("Test")); + Assert.Equal("t", headerValue); + + await app.StopAsync(); + } + + // Regression test: https://github.com/dotnet/aspnetcore/issues/57895 + [Fact] + public async Task EnumerableCanSetHeader() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + + await using var app = builder.Build(); + + app.MapGet("/", IEnumerable (HttpContext httpContext) => + { + return Enum(); + + IEnumerable Enum() + { + httpContext.Response.Headers["Test"] = "t"; + yield return 1; + } + }); + + await app.StartAsync(); + + var client = app.GetTestClient(); + + var result = await client.GetAsync("/"); + result.EnsureSuccessStatusCode(); + var headerValue = Assert.Single(result.Headers.GetValues("Test")); + Assert.Equal("t", headerValue); + + await app.StopAsync(); + } + public class TestObject { public string? StringProperty { get; set; } diff --git a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj index 686ab34dd28a..4a35778afa55 100644 --- a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj +++ b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj @@ -18,6 +18,7 @@ + diff --git a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs index c51ca745d8e7..f4e82f6857f7 100644 --- a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs +++ b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs @@ -88,11 +88,6 @@ public sealed override async Task WriteResponseBodyAsync(OutputFormatterWriteCon try { var responseWriter = httpContext.Response.BodyWriter; - if (!httpContext.Response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - await httpContext.Response.StartAsync(); - } if (jsonTypeInfo is not null) { diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs index cfce28c8dc64..167d4f71bec0 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs @@ -66,12 +66,6 @@ public async Task ExecuteAsync(ActionContext context, JsonResult result) try { var responseWriter = response.BodyWriter; - if (!response.HasStarted) - { - // Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush. - await response.StartAsync(); - } - await JsonSerializer.SerializeAsync(responseWriter, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted); } catch (OperationCanceledException) when (context.HttpContext.RequestAborted.IsCancellationRequested) { } diff --git a/src/Mvc/test/Mvc.FunctionalTests/SystemTextJsonOutputFormatterTest.cs b/src/Mvc/test/Mvc.FunctionalTests/SystemTextJsonOutputFormatterTest.cs index df54ab0d8cd9..d4906ab320f6 100644 --- a/src/Mvc/test/Mvc.FunctionalTests/SystemTextJsonOutputFormatterTest.cs +++ b/src/Mvc/test/Mvc.FunctionalTests/SystemTextJsonOutputFormatterTest.cs @@ -65,4 +65,21 @@ public async Task Formatting_PolymorphicModel_WithJsonPolymorphism() await response.AssertStatusCodeAsync(HttpStatusCode.OK); Assert.Equal(expected, await response.Content.ReadAsStringAsync()); } + + // Regression test: https://github.com/dotnet/aspnetcore/issues/57895 + [Fact] + public async Task CanSetHeaderWithAsyncEnumerable() + { + // Arrange + var expected = "[1]"; + + // Act + var response = await Client.GetAsync($"/SystemTextJsonOutputFormatter/{nameof(SystemTextJsonOutputFormatterController.AsyncEnumerable)}"); + + // Assert + await response.AssertStatusCodeAsync(HttpStatusCode.OK); + Assert.Equal(expected, await response.Content.ReadAsStringAsync()); + var headerValue = Assert.Single(response.Headers.GetValues("Test")); + Assert.Equal("t", headerValue); + } } diff --git a/src/Mvc/test/WebSites/FormatterWebSite/Controllers/SystemTextJsonOutputFormatterController.cs b/src/Mvc/test/WebSites/FormatterWebSite/Controllers/SystemTextJsonOutputFormatterController.cs index 287ffa90fd91..dcbd10cb1171 100644 --- a/src/Mvc/test/WebSites/FormatterWebSite/Controllers/SystemTextJsonOutputFormatterController.cs +++ b/src/Mvc/test/WebSites/FormatterWebSite/Controllers/SystemTextJsonOutputFormatterController.cs @@ -19,6 +19,14 @@ public class SystemTextJsonOutputFormatterController : ControllerBase Address = "Some address", }; + [HttpGet] + public async IAsyncEnumerable AsyncEnumerable() + { + await Task.Yield(); + HttpContext.Response.Headers["Test"] = "t"; + yield return 1; + } + [JsonPolymorphic] [JsonDerivedType(typeof(DerivedModel), nameof(DerivedModel))] public class SimpleModel