diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs index 6ca28d8d7f5..4112666f376 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/ResilienceKeys.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.Http.Resilience.Internal; internal static class ResilienceKeys { - public static readonly ResiliencePropertyKey RequestMessage = new("Resilience.Http.RequestMessage"); + public static readonly ResiliencePropertyKey RequestMessage = new("Resilience.Http.RequestMessage"); public static readonly ResiliencePropertyKey RoutingStrategy = new("Resilience.Http.RequestRoutingStrategy"); diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs new file mode 100644 index 00000000000..7218f7da7d8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Resilience/HttpResilienceContextExtensions.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using Microsoft.Extensions.Http.Resilience.Internal; +using Microsoft.Shared.Diagnostics; +using Polly; + +namespace Polly; + +/// +/// Provides utility methods for working with . +/// +public static class HttpResilienceContextExtensions +{ + /// + /// Gets the request message from the . + /// + /// The resilience context. + /// + /// The request message. + /// If the request message is not present in the the method returns . + /// + /// is . + public static HttpRequestMessage? GetRequestMessage(this ResilienceContext context) + { + _ = Throw.IfNull(context); + return context.Properties.GetValue(ResilienceKeys.RequestMessage, default); + } + + /// + /// Sets the request message on the . + /// + /// The resilience context. + /// The request message. + /// is . + public static void SetRequestMessage(this ResilienceContext context, HttpRequestMessage? requestMessage) + { + _ = Throw.IfNull(context); + context.Properties.Set(ResilienceKeys.RequestMessage, requestMessage); + } +} diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs new file mode 100644 index 00000000000..3ca02116e16 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using Microsoft.Extensions.Http.Resilience.Internal; +using Polly; +using Xunit; + +namespace Microsoft.Extensions.Http.Resilience.Test.Resilience; + +public class HttpResilienceContextExtensionsTests +{ + [Fact] + public void GetRequestMessage_ResilienceContextIsNull_Throws() + { + ResilienceContext context = null!; + Assert.Throws(context.GetRequestMessage); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsMissing_ReturnsNull() + { + var context = ResilienceContextPool.Shared.Get(); + + Assert.Null(context.GetRequestMessage()); + + ResilienceContextPool.Shared.Return(context); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsNull_ReturnsNull() + { + var context = ResilienceContextPool.Shared.Get(); + context.Properties.Set(ResilienceKeys.RequestMessage, null); + + Assert.Null(context.GetRequestMessage()); + + ResilienceContextPool.Shared.Return(context); + } + + [Fact] + public void GetRequestMessage_RequestMessageIsPresent_ReturnsRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + using var request = new HttpRequestMessage(); + context.Properties.Set(ResilienceKeys.RequestMessage, request); + + Assert.Same(request, context.GetRequestMessage()); + + ResilienceContextPool.Shared.Return(context); + } + + [Fact] + public void SetRequestMessage_ResilienceContextIsNull_Throws() + { + ResilienceContext context = null!; + using var request = new HttpRequestMessage(); + Assert.Throws(() => context.SetRequestMessage(request)); + } + + [Fact] + public void SetRequestMessage_RequestMessageIsNull_SetsNullRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? request)); + Assert.Null(request); + + ResilienceContextPool.Shared.Return(context); + } + + [Fact] + public void SetRequestMessage_RequestMessageIsNotNull_SetsRequestMessage() + { + var context = ResilienceContextPool.Shared.Get(); + using var request = new HttpRequestMessage(); + context.SetRequestMessage(request); + + Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? actualRequest)); + Assert.Same(request, actualRequest); + + ResilienceContextPool.Shared.Return(context); + } +}