From 37653639d6b8bdb511afc88f7f69b6b66e69d46d Mon Sep 17 00:00:00 2001
From: James Newton-King
Date: Tue, 9 Mar 2021 11:18:26 +1300
Subject: [PATCH] gRPC client retries (#1187)

---
 .../InterceptorRegistration.cs | 2 +-
 .../Internal/GrpcWebProtocolHelpers.cs | 2 +-
 .../Configuration/ConfigObject.cs | 68 ++
 .../Configuration/HedgingPolicy.cs | 89 ++
 .../Configuration/MethodConfig.cs | 101 +++
 .../Configuration/MethodName.cs | 82 ++
 .../Configuration/RetryPolicy.cs | 105 +++
 .../Configuration/RetryThrottlingPolicy.cs | 68 ++
 .../Configuration/ServiceConfig.cs | 74 ++
 src/Grpc.Net.Client/GrpcChannel.cs | 195 +++-
 src/Grpc.Net.Client/GrpcChannelOptions.cs | 68 +-
 .../Internal/ClientStreamWriterBase.cs | 100 +++
 .../Internal/Configuration/ConfigProperty.cs | 60 ++
 .../Internal/Configuration/ConvertHelpers.cs | 122 +++
 .../Internal/Configuration/IConfigValue.cs | 25 +
 .../Internal/Configuration/Values.cs | 97 ++
 .../DefaultChannelCredentialsConfigurator.cs | 62 ++
 .../Internal/GrpcCall.NonGeneric.cs | 80 ++
 src/Grpc.Net.Client/Internal/GrpcCall.cs | 249 +++---
 .../Internal/GrpcMethodInfo.cs | 107 ++-
 .../Internal/GrpcProtocolConstants.cs | 10 +-
 .../Internal/GrpcProtocolHelpers.cs | 11 +
 .../Internal/Http/PushStreamContent.cs | 22 +-
 .../Internal/Http/PushUnaryContent.cs | 23 +-
 ...UnaryContent.cs => WinHttpUnaryContent.cs} | 74 +-
 .../Internal/HttpClientCallInvoker.cs | 54 +-
 .../Internal/HttpContentClientStreamReader.cs | 2 +-
 .../Internal/HttpContentClientStreamWriter.cs | 101 +--
 src/Grpc.Net.Client/Internal/IGrpcCall.cs | 49 +
 .../Internal/IsExternalInit.cs | 29 +
 .../Internal/Retry/ChannelRetryThrottling.cs | 100 +++
 .../Internal/Retry/CommitReason.cs | 34 +
 .../Internal/Retry/HedgingCall.cs | 422 +++++++++
 .../Internal/Retry/RetryCall.cs | 337 +++++++
 .../Internal/Retry/RetryCallBase.Log.cs | 128 +++
 .../Internal/Retry/RetryCallBase.cs | 519 +++++++++++
 .../Retry/RetryCallBaseClientStreamReader.cs | 46 +
 .../Retry/RetryCallBaseClientStreamWriter.cs | 88 ++
 .../Internal/Retry/StatusGrpcCall.cs | 135 +++
 .../Internal/StreamExtensions.cs | 5 +-
 src/Shared/CommonGrpcProtocolHelpers.cs | 4 +-
 .../Client/CancellationTests.cs | 2 +-
 .../Client/EventSourceTests.cs | 4 +-
 test/FunctionalTests/Client/HedgingTests.cs | 593 +++++++++++++
 test/FunctionalTests/Client/RetryTests.cs | 596 +++++++++++++
 test/FunctionalTests/Client/StreamingTests.cs | 58 +-
 test/FunctionalTests/FunctionalTestBase.cs | 19 +-
 .../Grpc.AspNetCore.FunctionalTests.csproj | 1 +
 .../Server/ClientStreamingMethodTests.cs | 2 +-
 test/FunctionalTests/Server/DeadlineTests.cs | 11 +-
 .../Web/Base64PipeReaderTests.cs | 2 +-
 .../AsyncUnaryCallTests.cs | 42 +-
 .../Grpc.Net.Client.Tests.csproj | 2 +
 .../Grpc.Net.Client.Tests/GrpcChannelTests.cs | 21 +
 .../HttpContentClientStreamReaderTests.cs | 5 +-
 .../HttpClientCallInvokerFactory.cs | 7 +-
 .../Infrastructure/WinHttpHandler.cs | 28 +
 .../Retry/ChannelRetryThrottlingTests.cs | 45 +
 .../Retry/HedgingCallTests.cs | 373 ++++++++
 .../Retry/HedgingTests.cs | 690 ++++++++++++++
 .../Grpc.Net.Client.Tests/Retry/RetryTests.cs | 840 ++++++++++++++++++
 .../ServiceConfigTests.cs | 185 ++++
 .../Base64ResponseStreamTests.cs | 4 +-
 test/Shared/ClientTestHelpers.cs | 9 +-
 test/Shared/ExceptionAssert.cs | 5 +
 test/Shared/ServiceConfigHelpers.cs | 108 +++
 test/Shared/TestHelpers.cs | 8 +-
 .../InteropTestsWebsite/TestServiceImpl.cs | 1 +
 68 files changed, 7230 insertions(+), 380 deletions(-)
 create mode 100644 src/Grpc.Net.Client/Configuration/ConfigObject.cs
 create mode 100644
src/Grpc.Net.Client/Configuration/HedgingPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/MethodConfig.cs create mode 100644 src/Grpc.Net.Client/Configuration/MethodName.cs create mode 100644 src/Grpc.Net.Client/Configuration/RetryPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/ServiceConfig.cs create mode 100644 src/Grpc.Net.Client/Internal/ClientStreamWriterBase.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/Values.cs create mode 100644 src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs rename src/Grpc.Net.Client/Internal/Http/{LengthUnaryContent.cs => WinHttpUnaryContent.cs} (60%) create mode 100644 src/Grpc.Net.Client/Internal/IGrpcCall.cs create mode 100644 src/Grpc.Net.Client/Internal/IsExternalInit.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/CommitReason.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs create mode 100644 test/FunctionalTests/Client/HedgingTests.cs create mode 100644 test/FunctionalTests/Client/RetryTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/RetryTests.cs create mode 100644 test/Grpc.Net.Client.Tests/ServiceConfigTests.cs create mode 100644 test/Shared/ServiceConfigHelpers.cs diff --git a/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs b/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs index 828c032a2..51ee404d9 100644 --- a/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs +++ b/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs @@ -50,7 +50,7 @@ internal InterceptorRegistration( { throw new ArgumentNullException(nameof(arguments)); } - for (int i = 0; i < arguments.Length; i++) + for (var i = 0; i < arguments.Length; i++) { if (arguments[i] == null) { diff --git a/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs b/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs index 1950c5469..00d28233c 100644 --- a/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs +++ b/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs @@ -121,7 +121,7 @@ private static void WriteTrailersContent(Span buffer, IHeaderDictionary tr // gRPC-Web protocol says that names should be lower-case and grpc-web JS client // will check for 'grpc-status' and 'grpc-message' in trailers with lower-case key. 
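The loop this hunk touches lower-cases ASCII header names without per-character branching on letter classes: an unsigned subtraction acts as a range check for 'A'..'Z', and OR-ing 0x20 sets the lower-case bit. A standalone sketch of the same transform (the class and method names here are illustrative, not part of the patch):

    using System;
    using System.Text;

    static class AsciiLowerExample
    {
        // Mirrors the conversion in WriteTrailersContent: only 'A'-'Z' change.
        public static byte[] ToLowerAscii(string headerName)
        {
            var buffer = new byte[headerName.Length];
            for (var i = 0; i < headerName.Length; i++)
            {
                var c = headerName[i];
                // (uint)(c - 'A') <= 'Z' - 'A' is an unsigned range check for upper-case ASCII;
                // c | 0x20 flips such a character to lower case and leaves digits and '-' alone.
                buffer[i] = (byte)((uint)(c - 'A') <= ('Z' - 'A') ? c | 0x20 : c);
            }
            return buffer;
        }
    }

    // Encoding.ASCII.GetString(AsciiLowerExample.ToLowerAscii("Grpc-Status")) == "grpc-status"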
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md#protocol-differences-vs-grpc-over-http2 - for (int i = 0; i < kv.Key.Length; i++) + for (var i = 0; i < kv.Key.Length; i++) { char c = kv.Key[i]; currentBuffer[i] = (byte)((uint)(c - 'A') <= ('Z' - 'A') ? c | 0x20 : c); diff --git a/src/Grpc.Net.Client/Configuration/ConfigObject.cs b/src/Grpc.Net.Client/Configuration/ConfigObject.cs new file mode 100644 index 000000000..297494b67 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/ConfigObject.cs @@ -0,0 +1,68 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections; +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// Represents a configuration object. Implementations provide strongly typed wrappers over + /// collections of untyped values. + /// + public abstract class ConfigObject : IConfigValue + { + /// + /// Gets the underlying configuration values. + /// + public IDictionary Inner { get; } + + internal ConfigObject() : this(new Dictionary()) + { + } + + internal ConfigObject(IDictionary inner) + { + Inner = inner; + } + + object IConfigValue.Inner => Inner; + + internal T? GetValue(string key) + { + if (Inner.TryGetValue(key, out var value)) + { + return (T?)value; + } + return default; + } + + internal void SetValue(string key, T? value) + { + if (value == null) + { + Inner.Remove(key); + } + else + { + Inner[key] = value; + } + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs b/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs new file mode 100644 index 000000000..4af22d567 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs @@ -0,0 +1,89 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The hedging policy for outgoing calls. Hedged calls may execute more than + /// once on the server, so only idempotent methods should specify a hedging + /// policy. + /// + /// + /// + /// Represents the HedgingPolicy message in https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto. 
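Every configuration type in this patch derives from ConfigObject above: typed properties project over the untyped inner dictionary (an IDictionary<string, object> in the source), which is what lets a service config round-trip to its JSON-like wire form. A reduced sketch of the pattern, assuming that dictionary shape:

    using System.Collections.Generic;

    // Illustrative only: a typed property stored in an untyped value bag.
    public class ExampleConfig
    {
        public IDictionary<string, object> Inner { get; } = new Dictionary<string, object>();

        public int? MaxAttempts
        {
            get => Inner.TryGetValue("maxAttempts", out var value) ? (int?)value : null;
            set
            {
                // Assigning null removes the key, matching ConfigObject.SetValue.
                if (value == null) { Inner.Remove("maxAttempts"); }
                else { Inner["maxAttempts"] = value; }
            }
        }
    }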
+ /// + /// + public sealed class HedgingPolicy : ConfigObject + { + internal const string MaxAttemptsPropertyName = "maxAttempts"; + internal const string HedgingDelayPropertyName = "hedgingDelay"; + internal const string NonFatalStatusCodesPropertyName = "nonFatalStatusCodes"; + + private ConfigProperty, IList> _nonFatalStatusCodes = + new(i => new Values(i ?? new List(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), NonFatalStatusCodesPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public HedgingPolicy() { } + internal HedgingPolicy(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the maximum number of call attempts. This value includes the original attempt. + /// The hedging policy will send up to this number of calls. + /// + /// This property is required and must be 2 or greater. + /// This value is limited by . + /// + public int? MaxAttempts + { + get => GetValue(MaxAttemptsPropertyName); + set => SetValue(MaxAttemptsPropertyName, value); + } + + /// + /// Gets or sets the hedging delay. + /// The first call will be sent immediately, but the subsequent + /// hedged call will be sent at intervals of the specified delay. + /// Set this to 0 or null to immediately send all hedged calls. + /// + public TimeSpan? HedgingDelay + { + get => ConvertHelpers.ConvertDurationText(GetValue(HedgingDelayPropertyName)); + set => SetValue(HedgingDelayPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// Gets a collection of status codes which indicate other hedged calls may still + /// succeed. If a non-fatal status code is returned by the server, hedged + /// calls will continue. Otherwise, outstanding requests will be canceled and + /// the error returned to the client application layer. + /// + /// Specifying status codes is optional. + /// + public IList NonFatalStatusCodes + { + get => _nonFatalStatusCodes.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/MethodConfig.cs b/src/Grpc.Net.Client/Configuration/MethodConfig.cs new file mode 100644 index 000000000..140e3b466 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/MethodConfig.cs @@ -0,0 +1,101 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// Configuration for a method. + /// The collection is used to determine which methods this configuration applies to. + /// + /// + /// + /// Represents the MethodConfig message in https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto. 
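Putting the HedgingPolicy members above together: a policy is populated like any other config object, and only takes effect once attached to a MethodConfig (defined next) inside a channel's ServiceConfig. A small sketch with illustrative values:

    using System;
    using Grpc.Core;
    using Grpc.Net.Client.Configuration;

    var hedging = new HedgingPolicy
    {
        MaxAttempts = 3,                              // original call plus up to 2 hedged calls
        HedgingDelay = TimeSpan.FromMilliseconds(500) // stagger subsequent hedged calls by 0.5s
    };
    // Unavailable responses keep the remaining hedged calls running.
    hedging.NonFatalStatusCodes.Add(StatusCode.Unavailable);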
+ /// + /// + public sealed class MethodConfig : ConfigObject + { + private const string NamePropertyName = "name"; + private const string RetryPolicyPropertyName = "retryPolicy"; + private const string HedgingPolicyPropertyName = "hedgingPolicy"; + + private ConfigProperty, IList> _names = + new(i => new Values(i ?? new List(), s => s.Inner, s => new MethodName((IDictionary)s)), NamePropertyName); + + private ConfigProperty> _retryPolicy = + new(i => i != null ? new RetryPolicy(i) : null, RetryPolicyPropertyName); + + private ConfigProperty> _hedgingPolicy = + new(i => i != null ? new HedgingPolicy(i) : null, HedgingPolicyPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public MethodConfig() { } + internal MethodConfig(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the retry policy for outgoing calls. + /// A retry policy can't be combined with . + /// + public RetryPolicy? RetryPolicy + { + get => _retryPolicy.GetValue(this); + set => _retryPolicy.SetValue(this, value); + } + + /// + /// Gets or sets the hedging policy for outgoing calls. Hedged calls may execute + /// more than once on the server, so only idempotent methods should specify a hedging + /// policy. A hedging policy can't be combined with . + /// + public HedgingPolicy? HedgingPolicy + { + get => _hedgingPolicy.GetValue(this); + set => _hedgingPolicy.SetValue(this, value); + } + + /// + /// Gets a collection of names which determine the calls the method config will apply to. + /// A without names won't be used. Each name must be unique + /// across an entire . + /// + /// + /// + /// If a name's property isn't set then the method config is the default + /// for all methods for the specified service. + /// + /// + /// If a name's property isn't set then must also be unset, + /// and the method config is the default for all methods on all services. + /// represents this global default name. + /// + /// + /// When determining which method config to use for a given RPC, the most specific match wins. A method config + /// with a configured that exactly matches a call's method and service will be used + /// instead of a service or global default method config. + /// + /// + public IList Names + { + get => _names.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/MethodName.cs b/src/Grpc.Net.Client/Configuration/MethodName.cs new file mode 100644 index 000000000..1f8e00697 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/MethodName.cs @@ -0,0 +1,82 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using System.Collections.ObjectModel; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The name of a method. Used to configure what calls a applies to using + /// the collection. + /// + /// + /// + /// Represents the Name message in https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto. 
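The name-matching rules described above give three levels of specificity, most specific first. A sketch of what each level looks like (service and method names are placeholders):

    using Grpc.Net.Client.Configuration;

    var exact = new MethodConfig();
    exact.Names.Add(new MethodName { Service = "greet.Greeter", Method = "SayHello" });

    var serviceDefault = new MethodConfig();
    serviceDefault.Names.Add(new MethodName { Service = "greet.Greeter" });

    var globalDefault = new MethodConfig();
    globalDefault.Names.Add(MethodName.Default);

    // For greet.Greeter/SayHello the exact config wins; other greet.Greeter methods
    // use serviceDefault; methods on any other service fall back to globalDefault.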
+ /// + /// + /// If a name's property isn't set then the method config is the default + /// for all methods for the specified service. + /// + /// + /// If a name's property isn't set then must also be unset, + /// and the method config is the default for all methods on all services. + /// represents this global default name. + /// + /// + /// When determining which method config to use for a given RPC, the most specific match wins. A method config + /// with a configured that exactly matches a call's method and service will be used + /// instead of a service or global default method config. + /// + /// + public sealed class MethodName + : ConfigObject + { + /// + /// A global default name. + /// + public static readonly MethodName Default = new MethodName(new ReadOnlyDictionary(new Dictionary())); + + private const string ServicePropertyName = "service"; + private const string MethodPropertyName = "method"; + + /// + /// Initializes a new instance of the class. + /// + public MethodName() { } + internal MethodName(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the service name. + /// + public string? Service + { + get => GetValue(ServicePropertyName); + set => SetValue(ServicePropertyName, value); + } + + /// + /// Gets or sets the method name. + /// + public string? Method + { + get => GetValue(MethodPropertyName); + set => SetValue(MethodPropertyName, value); + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/RetryPolicy.cs b/src/Grpc.Net.Client/Configuration/RetryPolicy.cs new file mode 100644 index 000000000..9a35775af --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/RetryPolicy.cs @@ -0,0 +1,105 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The retry policy for outgoing calls. + /// + public sealed class RetryPolicy : ConfigObject + { + internal const string MaxAttemptsPropertyName = "maxAttempts"; + internal const string InitialBackoffPropertyName = "initialBackoff"; + internal const string MaxBackoffPropertyName = "maxBackoff"; + internal const string BackoffMultiplierPropertyName = "backoffMultiplier"; + internal const string RetryableStatusCodesPropertyName = "retryableStatusCodes"; + + private ConfigProperty, IList> _retryableStatusCodes = + new(i => new Values(i ?? new List(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), RetryableStatusCodesPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public RetryPolicy() { } + internal RetryPolicy(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the maximum number of call attempts. This value includes the original attempt. + /// This property is required and must be greater than 1. + /// This value is limited by . + /// + public int? 
MaxAttempts + { + get => GetValue(MaxAttemptsPropertyName); + set => SetValue(MaxAttemptsPropertyName, value); + } + + /// + /// Gets or sets the initial backoff. + /// A randomized delay between 0 and the current backoff value will determine when the next + /// retry attempt is made. + /// This property is required and must be greater than zero. + /// + /// The backoff will be multiplied by after each retry + /// attempt and will increase exponentially when the multiplier is greater than 1. + /// + /// + public TimeSpan? InitialBackoff + { + get => ConvertHelpers.ConvertDurationText(GetValue(InitialBackoffPropertyName)); + set => SetValue(InitialBackoffPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// Gets or sets the maximum backoff. + /// The maximum backoff places an upper limit on exponential backoff growth. + /// This property is required and must be greater than zero. + /// + public TimeSpan? MaxBackoff + { + get => ConvertHelpers.ConvertDurationText(GetValue(MaxBackoffPropertyName)); + set => SetValue(MaxBackoffPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// Gets or sets the backoff multiplier. + /// The backoff will be multiplied by after each retry + /// attempt and will increase exponentially when the multiplier is greater than 1. + /// This property is required and must be greater than 0. + /// + public double? BackoffMultiplier + { + get => GetValue(BackoffMultiplierPropertyName); + set => SetValue(BackoffMultiplierPropertyName, value); + } + + /// + /// Gets a collection of status codes which may be retried. + /// At least one status code is required. + /// + public IList RetryableStatusCodes + { + get => _retryableStatusCodes.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs b/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs new file mode 100644 index 000000000..f8f6ba6ce --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs @@ -0,0 +1,68 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The retry throttling policy for a server. + /// + /// For more information about configuring throttling, see https://github.com/grpc/proposal/blob/master/A6-client-retries.md#throttling-retry-attempts-and-hedged-rpcs. + /// + /// + /// + /// + /// Represents the RetryThrottlingPolicy message in https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto. + /// + /// + public sealed class RetryThrottlingPolicy : ConfigObject + { + internal const string MaxTokensPropertyName = "maxTokens"; + internal const string TokenRatioPropertyName = "tokenRatio"; + + /// + /// Initializes a new instance of the class. 
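A retry policy built from the RetryPolicy properties above, with illustrative values; per the XML docs every numeric property is required, along with at least one retryable status code:

    using System;
    using Grpc.Core;
    using Grpc.Net.Client.Configuration;

    var retry = new RetryPolicy
    {
        MaxAttempts = 5,
        InitialBackoff = TimeSpan.FromSeconds(1), // each delay is randomized between 0 and the current backoff
        MaxBackoff = TimeSpan.FromSeconds(5),     // cap on exponential growth
        BackoffMultiplier = 1.5
    };
    retry.RetryableStatusCodes.Add(StatusCode.Unavailable);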
+ /// + public RetryThrottlingPolicy() { } + internal RetryThrottlingPolicy(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the maximum number of tokens. + /// The number of tokens starts at and the token count will + /// always be between 0 and . + /// This property is required and must be greater than zero. + /// + public int? MaxTokens + { + get => GetValue(MaxTokensPropertyName); + set => SetValue(MaxTokensPropertyName, value); + } + + /// + /// Gets or sets the amount of tokens to add on each successful call. Typically this will + /// be some number between 0 and 1, e.g., 0.1. + /// This property is required and must be greater than zero. Up to 3 decimal places are supported. + /// + public double? TokenRatio + { + get => GetValue(TokenRatioPropertyName); + set => SetValue(TokenRatioPropertyName, value); + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/ServiceConfig.cs b/src/Grpc.Net.Client/Configuration/ServiceConfig.cs new file mode 100644 index 000000000..7fc67e080 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/ServiceConfig.cs @@ -0,0 +1,74 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// A represents information about a service. + /// + /// + /// + /// Represents the ServiceConfig message in https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto. + /// + /// + public sealed class ServiceConfig : ConfigObject + { + private const string MethodConfigPropertyName = "methodConfig"; + private const string RetryThrottlingPropertyName = "retryThrottling"; + + private ConfigProperty, IList> _methods = + new(i => new Values(i ?? new List(), s => s.Inner, s => new MethodConfig((IDictionary)s)), MethodConfigPropertyName); + + private ConfigProperty> _retryThrottling = + new(i => i != null ? new RetryThrottlingPolicy(i) : null, RetryThrottlingPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public ServiceConfig() { } + internal ServiceConfig(IDictionary inner) : base(inner) { } + + /// + /// Gets a collection of instances. This collection is used to specify + /// configuration on a per-method basis. determines which calls + /// a method config applies to. + /// + public IList MethodConfigs + { + get => _methods.GetValue(this)!; + } + + /// + /// Gets or sets the retry throttling policy. + /// If a is provided, gRPC will automatically throttle + /// retry attempts and hedged RPCs when the client's ratio of failures to + /// successes exceeds a threshold. + /// + /// For more information about configuring throttling, see https://github.com/grpc/proposal/blob/master/A6-client-retries.md#throttling-retry-attempts-and-hedged-rpcs. + /// + /// + public RetryThrottlingPolicy? 
RetryThrottling + { + get => _retryThrottling.GetValue(this); + set => _retryThrottling.SetValue(this, value); + } + } +} diff --git a/src/Grpc.Net.Client/GrpcChannel.cs b/src/Grpc.Net.Client/GrpcChannel.cs index 77ddc0cf5..0d69b95d9 100644 --- a/src/Grpc.Net.Client/GrpcChannel.cs +++ b/src/Grpc.Net.Client/GrpcChannel.cs @@ -23,10 +23,15 @@ using System.Net.Http; using Grpc.Core; using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Configuration; +using GrpcServiceConfig = Grpc.Net.Client.Configuration.ServiceConfig; using Grpc.Net.Compression; using Grpc.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Grpc.Net.Client.Internal.Retry; +using System.Threading; +using System.Diagnostics; namespace Grpc.Net.Client { @@ -38,9 +43,15 @@ namespace Grpc.Net.Client public sealed class GrpcChannel : ChannelBase, IDisposable { internal const int DefaultMaxReceiveMessageSize = 1024 * 1024 * 4; // 4 MB + internal const int DefaultMaxRetryAttempts = 5; + internal const long DefaultMaxRetryBufferSize = 1024 * 1024 * 16; // 16 MB + internal const long DefaultMaxRetryBufferPerCallSize = 1024 * 1024; // 1 MB + private readonly object _lock; private readonly ConcurrentDictionary _methodInfoCache; private readonly Func _createMethodInfoFunc; + private readonly Dictionary? _serviceConfigMethods; + private readonly Random? _random; // Internal for testing internal readonly HashSet ActiveCalls; @@ -49,6 +60,9 @@ public sealed class GrpcChannel : ChannelBase, IDisposable internal bool IsWinHttp { get; } internal int? SendMaxMessageSize { get; } internal int? ReceiveMaxMessageSize { get; } + internal int? MaxRetryAttempts { get; } + internal long? MaxRetryBufferSize { get; } + internal long? MaxRetryBufferPerCallSize { get; } internal ILoggerFactory LoggerFactory { get; } internal bool ThrowOperationCanceledOnCancellation { get; } internal bool? IsSecure { get; } @@ -57,6 +71,10 @@ public sealed class GrpcChannel : ChannelBase, IDisposable internal string MessageAcceptEncoding { get; } internal bool Disposed { get; private set; } + // Stateful + internal ChannelRetryThrottling? RetryThrottling { get; } + internal long CurrentRetryBufferSize; + // Options that are set in unit tests internal ISystemClock Clock = SystemClock.Instance; internal IOperatingSystem OperatingSystem = Internal.OperatingSystem.Instance; @@ -67,6 +85,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(address.Authority) { + _lock = new object(); _methodInfoCache = new ConcurrentDictionary(); // Dispose the HTTP client/handler if... @@ -80,12 +99,21 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr IsWinHttp = channelOptions.HttpHandler != null ? HttpHandlerFactory.HasHttpHandlerType(channelOptions.HttpHandler, "System.Net.Http.WinHttpHandler") : false; SendMaxMessageSize = channelOptions.MaxSendMessageSize; ReceiveMaxMessageSize = channelOptions.MaxReceiveMessageSize; + MaxRetryAttempts = channelOptions.MaxRetryAttempts; + MaxRetryBufferSize = channelOptions.MaxRetryBufferSize; + MaxRetryBufferPerCallSize = channelOptions.MaxRetryBufferPerCallSize; CompressionProviders = ResolveCompressionProviders(channelOptions.CompressionProviders); MessageAcceptEncoding = GrpcProtocolHelpers.GetMessageAcceptEncoding(CompressionProviders); LoggerFactory = channelOptions.LoggerFactory ?? 
NullLoggerFactory.Instance; ThrowOperationCanceledOnCancellation = channelOptions.ThrowOperationCanceledOnCancellation; _createMethodInfoFunc = CreateMethodInfo; ActiveCalls = new HashSet(); + if (channelOptions.ServiceConfig is { } serviceConfig) + { + RetryThrottling = serviceConfig.RetryThrottling != null ? CreateChannelRetryThrottling(serviceConfig.RetryThrottling) : null; + _serviceConfigMethods = CreateServiceConfigMethods(serviceConfig); + _random = new Random(); + } if (channelOptions.Credentials != null) { @@ -99,6 +127,46 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr } } + private ChannelRetryThrottling CreateChannelRetryThrottling(RetryThrottlingPolicy retryThrottling) + { + if (retryThrottling.MaxTokens == null) + { + throw CreateException(RetryThrottlingPolicy.MaxTokensPropertyName); + } + if (retryThrottling.TokenRatio == null) + { + throw CreateException(RetryThrottlingPolicy.TokenRatioPropertyName); + } + + return new ChannelRetryThrottling(retryThrottling.MaxTokens.GetValueOrDefault(), retryThrottling.TokenRatio.GetValueOrDefault(), LoggerFactory); + + static InvalidOperationException CreateException(string propertyName) + { + return new InvalidOperationException($"Retry throttling missing required property '{propertyName}'."); + } + } + + private static Dictionary CreateServiceConfigMethods(GrpcServiceConfig serviceConfig) + { + var configs = new Dictionary(); + for (var i = 0; i < serviceConfig.MethodConfigs.Count; i++) + { + var methodConfig = serviceConfig.MethodConfigs[i]; + for (var j = 0; j < methodConfig.Names.Count; j++) + { + var name = methodConfig.Names[j]; + var methodKey = new MethodKey(name.Service, name.Method); + if (configs.ContainsKey(methodKey)) + { + throw new InvalidOperationException($"Duplicate method config found. Service: '{name.Service}', method: '{name.Method}'."); + } + configs[methodKey] = methodConfig; + } + } + + return configs; + } + private static HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler) { // HttpMessageInvoker should always dispose handler if Disposed is called on it. @@ -121,7 +189,7 @@ private static HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? internal void RegisterActiveCall(IDisposable grpcCall) { - lock (ActiveCalls) + lock (_lock) { ActiveCalls.Add(grpcCall); } @@ -129,7 +197,7 @@ internal void RegisterActiveCall(IDisposable grpcCall) internal void FinishActiveCall(IDisposable grpcCall) { - lock (ActiveCalls) + lock (_lock) { ActiveCalls.Remove(grpcCall); } @@ -144,8 +212,31 @@ private GrpcMethodInfo CreateMethodInfo(IMethod method) { var uri = new Uri(method.FullName, UriKind.Relative); var scope = new GrpcCallScope(method.Type, uri); + var methodConfig = ResolveMethodConfig(method); - return new GrpcMethodInfo(scope, new Uri(Address, uri)); + return new GrpcMethodInfo(scope, new Uri(Address, uri), methodConfig); + } + + private MethodConfig? ResolveMethodConfig(IMethod method) + { + if (_serviceConfigMethods != null) + { + MethodConfig? methodConfig; + if (_serviceConfigMethods.TryGetValue(new MethodKey(method.ServiceName, method.Name), out methodConfig)) + { + return methodConfig; + } + if (_serviceConfigMethods.TryGetValue(new MethodKey(method.ServiceName, null), out methodConfig)) + { + return methodConfig; + } + if (_serviceConfigMethods.TryGetValue(new MethodKey(null, null), out methodConfig)) + { + return methodConfig; + } + } + + return null; } private static Dictionary ResolveCompressionProviders(IList? 
compressionProviders) @@ -156,7 +247,7 @@ private static Dictionary ResolveCompressionProvid } var resolvedCompressionProviders = new Dictionary(StringComparer.Ordinal); - for (int i = 0; i < compressionProviders.Count; i++) + for (var i = 0; i < compressionProviders.Count; i++) { var compressionProvider = compressionProviders[i]; if (!resolvedCompressionProviders.ContainsKey(compressionProvider.EncodingName)) @@ -199,47 +290,6 @@ public override CallInvoker CreateCallInvoker() return invoker; } - private class DefaultChannelCredentialsConfigurator : ChannelCredentialsConfiguratorBase - { - public bool? IsSecure { get; private set; } - public List? CallCredentials { get; private set; } - - public override void SetCompositeCredentials(object state, ChannelCredentials channelCredentials, CallCredentials callCredentials) - { - channelCredentials.InternalPopulateConfiguration(this, null); - - if (callCredentials != null) - { - if (CallCredentials == null) - { - CallCredentials = new List(); - } - - CallCredentials.Add(callCredentials); - } - } - - public override void SetInsecureCredentials(object state) - { - IsSecure = false; - } - - public override void SetSslCredentials(object state, string rootCertificates, KeyCertificatePair keyCertificatePair, VerifyPeerCallback verifyPeerCallback) - { - if (!string.IsNullOrEmpty(rootCertificates) || - keyCertificatePair != null || - verifyPeerCallback != null) - { - throw new InvalidOperationException( - $"{nameof(SslCredentials)} with non-null arguments is not supported by {nameof(GrpcChannel)}. " + - $"{nameof(GrpcChannel)} uses HttpClient to make gRPC calls and HttpClient automatically loads root certificates from the operating system certificate store. " + - $"Client certificates should be configured on HttpClient. See https://aka.ms/AA6we64 for details."); - } - - IsSecure = true; - } - } - /// /// Creates a for the specified address. /// @@ -309,7 +359,7 @@ public void Dispose() return; } - lock (ActiveCalls) + lock (_lock) { if (ActiveCalls.Count > 0) { @@ -330,5 +380,58 @@ public void Dispose() } Disposed = true; } + + internal bool TryAddToRetryBuffer(long messageSize) + { + lock (_lock) + { + if (CurrentRetryBufferSize + messageSize > MaxRetryBufferSize) + { + return false; + } + + CurrentRetryBufferSize += messageSize; + return true; + } + } + + internal void RemoveFromRetryBuffer(long messageSize) + { + lock (_lock) + { + CurrentRetryBufferSize -= messageSize; + } + } + + internal int GetRandomNumber(int minValue, int maxValue) + { + CompatibilityExtensions.Assert(_random != null); + + lock (_lock) + { + return _random.Next(minValue, maxValue); + } + } + + private struct MethodKey : IEquatable + { + public MethodKey(string? service, string? method) + { + Service = service; + Method = method; + } + + public string? Service { get; } + public string? Method { get; } + + public override bool Equals(object? obj) => obj is MethodKey n ? Equals(n) : false; + + // Service and method names are case sensitive. + public bool Equals(MethodKey other) => other.Service == Service && other.Method == Method; + + public override int GetHashCode() => + (Service != null ? StringComparer.Ordinal.GetHashCode(Service) : 0) ^ + (Method != null ? 
StringComparer.Ordinal.GetHashCode(Method) : 0); + } } } diff --git a/src/Grpc.Net.Client/GrpcChannelOptions.cs b/src/Grpc.Net.Client/GrpcChannelOptions.cs index 4cb89c2c5..dd766f938 100644 --- a/src/Grpc.Net.Client/GrpcChannelOptions.cs +++ b/src/Grpc.Net.Client/GrpcChannelOptions.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Net.Http; using Grpc.Core; +using Grpc.Net.Client.Configuration; using Grpc.Net.Compression; using Microsoft.Extensions.Logging; @@ -65,6 +66,56 @@ public sealed class GrpcChannelOptions /// public int? MaxReceiveMessageSize { get; set; } + /// + /// Gets or sets the maximum retry attempts. This value limits any retry and hedging attempt values specified in + /// the service config. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry attempts limit. Defaults to 5. + /// + /// + /// Note: Experimental API that can change or be removed without any prior notice. + /// + /// + public int? MaxRetryAttempts { get; set; } + + /// + /// Gets or sets the maximum buffer size in bytes that can be used to store sent messages when retrying + /// or hedging calls. If the buffer limit is exceeded then no more retry attempts are made and all + /// hedging calls but one will be canceled. This limit is applied across all calls made using the channel. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry buffer size limit. Defaults to 16,777,216 (16 MB). + /// + /// + /// Note: Experimental API that can change or be removed without any prior notice. + /// + /// + public long? MaxRetryBufferSize { get; set; } + + /// + /// Gets or sets the maximum buffer size in bytes that can be used to store sent messages when retrying + /// or hedging calls. If the buffer limit is exceeded then no more retry attempts are made and all + /// hedging calls but one will be canceled. This limit is applied to one call. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry buffer size limit per call. Defaults to 1,048,576 (1 MB). + /// + /// + /// Note: Experimental API that can change or be removed without any prior notice. + /// + /// + public long? MaxRetryBufferPerCallSize { get; set; } + /// /// Gets or sets a collection of compression providers. /// @@ -123,16 +174,31 @@ public sealed class GrpcChannelOptions /// Gets or sets a value indicating whether clients will throw for a call when its /// is triggered or its is exceeded. /// The default value is false. - /// Note: experimental API that can change or be removed without any prior notice. + /// + /// Note: Experimental API that can change or be removed without any prior notice. + /// /// public bool ThrowOperationCanceledOnCancellation { get; set; } + /// + /// Gets or sets the service config for a gRPC channel. A service config allows service owners to publish parameters + /// to be automatically used by all clients of their service. A service config can also be specified by a client + /// using this property. + /// + /// Note: Experimental API that can change or be removed without any prior notice. + /// + /// + public ServiceConfig? ServiceConfig { get; set; } + /// /// Initializes a new instance of the class. 
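Tying the new channel options together end to end: retries are enabled by supplying a ServiceConfig when the channel is created. A minimal sketch (the address and policy values are placeholders):

    using System;
    using Grpc.Core;
    using Grpc.Net.Client;
    using Grpc.Net.Client.Configuration;

    var defaultMethodConfig = new MethodConfig
    {
        Names = { MethodName.Default },
        RetryPolicy = new RetryPolicy
        {
            MaxAttempts = 5,
            InitialBackoff = TimeSpan.FromSeconds(1),
            MaxBackoff = TimeSpan.FromSeconds(5),
            BackoffMultiplier = 1.5,
            RetryableStatusCodes = { StatusCode.Unavailable }
        }
    };

    var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
    {
        ServiceConfig = new ServiceConfig { MethodConfigs = { defaultMethodConfig } }
    });
    // Calls made through this channel now retry Unavailable failures, subject to
    // MaxRetryAttempts and the retry buffer limits described above.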
/// public GrpcChannelOptions() { MaxReceiveMessageSize = GrpcChannel.DefaultMaxReceiveMessageSize; + MaxRetryAttempts = GrpcChannel.DefaultMaxRetryAttempts; + MaxRetryBufferSize = GrpcChannel.DefaultMaxRetryBufferSize; + MaxRetryBufferPerCallSize = GrpcChannel.DefaultMaxRetryBufferPerCallSize; } } } diff --git a/src/Grpc.Net.Client/Internal/ClientStreamWriterBase.cs b/src/Grpc.Net.Client/Internal/ClientStreamWriterBase.cs new file mode 100644 index 000000000..dd48d2765 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/ClientStreamWriterBase.cs @@ -0,0 +1,100 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Grpc.Net.Client.Internal +{ + internal abstract class ClientStreamWriterBase : IClientStreamWriter + where TRequest : class + { + protected ILogger Logger { get; } + protected object WriteLock { get; } + protected Task? WriteTask { get; set; } + + protected ClientStreamWriterBase(ILogger logger) + { + Logger = logger; + WriteLock = new object(); + } + + public abstract WriteOptions? WriteOptions { get; set; } + + public abstract Task CompleteAsync(); + + public abstract Task WriteAsync(TRequest message); + + protected Task CreateErrorTask(string message) + { + var ex = new InvalidOperationException(message); + Log.WriteMessageError(Logger, ex); + return Task.FromException(ex); + } + + public void Dispose() + { + } + + /// + /// A value indicating whether there is an async write already in progress. + /// Should only check this property when holding the write lock. 
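The contract in ClientStreamWriterBase is that derived writers take WriteLock, use the property below to reject a concurrent write, and store the new pending operation in WriteTask. A reduced standalone sketch of that guard (names and the error message are illustrative):

    using System;
    using System.Threading.Tasks;

    class SingleWriterSketch
    {
        private readonly object _writeLock = new object();
        private Task? _writeTask;

        public Task WriteAsync(Func<Task> startWrite)
        {
            lock (_writeLock)
            {
                // Equivalent to IsWriteInProgressUnsynchronized: a previous write
                // that hasn't completed means this write must be rejected.
                if (_writeTask != null && !_writeTask.IsCompleted)
                {
                    return Task.FromException(new InvalidOperationException(
                        "Can't write the message because the previous write is in progress."));
                }

                _writeTask = startWrite();
                return _writeTask;
            }
        }
    }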
+ /// + protected bool IsWriteInProgressUnsynchronized + { + get + { + Debug.Assert(Monitor.IsEntered(WriteLock)); + + var writeTask = WriteTask; + return writeTask != null && !writeTask.IsCompleted; + } + } + + protected static class Log + { + private static readonly Action _completingClientStream = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CompletingClientStream"), "Completing client stream."); + + private static readonly Action _writeMessageError = + LoggerMessage.Define(LogLevel.Error, new EventId(2, "WriteMessageError"), "Error writing message."); + + private static readonly Action _completeClientStreamError = + LoggerMessage.Define(LogLevel.Error, new EventId(3, "CompleteClientStreamError"), "Error completing client stream."); + + public static void CompletingClientStream(ILogger logger) + { + _completingClientStream(logger, null); + } + + public static void WriteMessageError(ILogger logger, Exception ex) + { + _writeMessageError(logger, ex); + } + + public static void CompleteClientStreamError(ILogger logger, Exception ex) + { + _completeClientStreamError(logger, ex); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs b/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs new file mode 100644 index 000000000..56ad27474 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs @@ -0,0 +1,60 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using Grpc.Net.Client.Configuration; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal struct ConfigProperty where TValue : IConfigValue + { + private TValue? _value; + private readonly Func _valueFactory; + private readonly string _key; + + public ConfigProperty(Func valueFactory, string key) + { + _value = default; + _valueFactory = valueFactory; + _key = key; + } + + public TValue? GetValue(ConfigObject inner) + { + if (_value == null) + { + var innerValue = inner.GetValue(_key); + _value = _valueFactory(innerValue); + + if (_value != null && innerValue == null) + { + // Set newly created value + SetValue(inner, _value); + } + } + + return _value; + } + + public void SetValue(ConfigObject inner, TValue? value) + { + _value = value; + inner.SetValue(_key, _value?.Inner); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs b/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs new file mode 100644 index 000000000..e403639b2 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs @@ -0,0 +1,122 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
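ConfigProperty above caches the typed wrapper after the first read and, when the underlying key was absent, immediately writes the newly created value back so later mutations land in the inner collection. The same idea in isolation (assuming the IDictionary<string, object> backing used by ConfigObject; the "items" key is hypothetical):

    using System.Collections.Generic;

    class LazyBackedList
    {
        private readonly IDictionary<string, object> _inner;
        private List<object>? _cache;

        public LazyBackedList(IDictionary<string, object> inner) => _inner = inner;

        public List<object> Value
        {
            get
            {
                if (_cache == null)
                {
                    if (_inner.TryGetValue("items", out var existing))
                    {
                        _cache = (List<object>)existing;
                    }
                    else
                    {
                        // Write the fresh value back, mirroring ConfigProperty.GetValue,
                        // so Add/Remove on the cache remain visible through _inner.
                        _cache = new List<object>();
                        _inner["items"] = _cache;
                    }
                }
                return _cache;
            }
        }
    }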
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Globalization; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal static class ConvertHelpers + { + public static string ConvertStatusCode(StatusCode statusCode) + { + return statusCode switch + { + StatusCode.OK => "OK", + StatusCode.Cancelled => "CANCELLED", + StatusCode.Unknown => "UNKNOWN", + StatusCode.InvalidArgument => "INVALID_ARGUMENT", + StatusCode.DeadlineExceeded => "DEADLINE_EXCEEDED", + StatusCode.NotFound => "NOT_FOUND", + StatusCode.AlreadyExists => "ALREADY_EXISTS", + StatusCode.PermissionDenied => "PERMISSION_DENIED", + StatusCode.Unauthenticated => "UNAUTHENTICATED", + StatusCode.ResourceExhausted => "RESOURCE_EXHAUSTED", + StatusCode.FailedPrecondition => "FAILED_PRECONDITION", + StatusCode.Aborted => "ABORTED", + StatusCode.OutOfRange => "OUT_OF_RANGE", + StatusCode.Unimplemented => "UNIMPLEMENTED", + StatusCode.Internal => "INTERNAL", + StatusCode.Unavailable => "UNAVAILABLE", + StatusCode.DataLoss => "DATA_LOSS", + _ => throw new InvalidOperationException($"Unexpected status code: {statusCode}") + }; + } + + public static StatusCode ConvertStatusCode(string statusCode) + { + return statusCode.ToUpperInvariant() switch + { + "OK" => StatusCode.OK, + "CANCELLED" => StatusCode.Cancelled, + "UNKNOWN" => StatusCode.Unknown, + "INVALID_ARGUMENT" => StatusCode.InvalidArgument, + "DEADLINE_EXCEEDED" => StatusCode.DeadlineExceeded, + "NOT_FOUND" => StatusCode.NotFound, + "ALREADY_EXISTS" => StatusCode.AlreadyExists, + "PERMISSION_DENIED" => StatusCode.PermissionDenied, + "UNAUTHENTICATED" => StatusCode.Unauthenticated, + "RESOURCE_EXHAUSTED" => StatusCode.ResourceExhausted, + "FAILED_PRECONDITION" => StatusCode.FailedPrecondition, + "ABORTED" => StatusCode.Aborted, + "OUT_OF_RANGE" => StatusCode.OutOfRange, + "UNIMPLEMENTED" => StatusCode.Unimplemented, + "INTERNAL" => StatusCode.Internal, + "UNAVAILABLE" => StatusCode.Unavailable, + "DATA_LOSS" => StatusCode.DataLoss, + _ => int.TryParse(statusCode, out var number) + ? (StatusCode)number + : throw new InvalidOperationException($"Unexpected status code: {statusCode}") + }; + } + + public static TimeSpan? ConvertDurationText(string? text) + { + if (text == null) + { + return null; + } + + // This format is based on the Protobuf duration's JSON mapping. + // https://github.com/protocolbuffers/protobuf/blob/35bdcabdd6a05ce9ee738ad7df8c1299d9c7fc4b/src/google/protobuf/duration.proto#L92 + // + // Note that this is precise down to ticks. Fractions that are smaller than ticks will be lost. + // This shouldn't matter because timers on Windows and Linux only have millisecond precision. 
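+ // For example (values chosen here for illustration):
+ //   "0.2s"  -> TimeSpan.FromMilliseconds(200)
+ //   "-1.5s" -> TimeSpan.FromSeconds(-1.5) (NumberStyles.AllowLeadingSign permits the minus)
+ //   "1m"    -> FormatException (the unit suffix must be 's')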
+ if (text.Length > 0 && text[text.Length - 1] == 's' && + decimal.TryParse(text.Substring(0, text.Length - 1), NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign, CultureInfo.InvariantCulture, out var seconds)) + { + try + { + var ticks = (long)(seconds * TimeSpan.TicksPerSecond); + return TimeSpan.FromTicks(ticks); + } + catch (Exception ex) + { + throw new FormatException($"'{text}' isn't a valid duration.", ex); + } + } + else + { + throw new FormatException($"'{text}' isn't a valid duration."); + } + } + + public static string? ToDurationText(TimeSpan? value) + { + if (value == null) + { + return null; + } + + // This format is based on the Protobuf duration's JSON mapping. + // https://github.com/protocolbuffers/protobuf/blob/35bdcabdd6a05ce9ee738ad7df8c1299d9c7fc4b/src/google/protobuf/duration.proto#L92 + return value.GetValueOrDefault().TotalSeconds + "s"; + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs b/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs new file mode 100644 index 000000000..e3a06371d --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs @@ -0,0 +1,25 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal interface IConfigValue + { + object Inner { get; } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/Values.cs b/src/Grpc.Net.Client/Internal/Configuration/Values.cs new file mode 100644 index 000000000..825f96356 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/Values.cs @@ -0,0 +1,97 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
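Values, defined below, keeps a typed view and the raw inner list in lockstep through a pair of conversion delegates. A sketch of how the retry configuration uses it for status codes; since these are internal types, this is an assumption-level illustration of the wiring rather than public usage:

    using System.Collections.Generic;
    using Grpc.Core;
    using Grpc.Net.Client.Internal.Configuration;

    var inner = new List<object>();
    var statusCodes = new Values<StatusCode, object>(
        inner,
        code => ConvertHelpers.ConvertStatusCode(code),            // StatusCode -> "UNAVAILABLE"
        raw => ConvertHelpers.ConvertStatusCode(raw.ToString()!)); // "UNAVAILABLE" -> StatusCode

    statusCodes.Add(StatusCode.Unavailable);
    // inner now holds the string "UNAVAILABLE"; statusCodes[0] is StatusCode.Unavailable.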
+ +#endregion + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal class Values : IList, IConfigValue + { + internal readonly IList Inner; + + private readonly IList _values; + internal readonly Func _convertTo; + internal readonly Func _convertFrom; + + public Values(IList inner, Func convertTo, Func convertFrom) + { + Inner = inner; + _values = new List(); + _convertTo = convertTo; + _convertFrom = convertFrom; + + foreach (var item in Inner) + { + _values.Add(_convertFrom(item)); + } + } + + public T this[int index] + { + get => _values[index]; + set + { + _values[index] = value; + Inner[index] = _convertTo(value); + } + } + + public int Count => Inner.Count; + public bool IsReadOnly => Inner.IsReadOnly; + + object IConfigValue.Inner => Inner; + + public void Add(T item) + { + _values.Add(item); + Inner.Add(_convertTo(item)); + } + + public void Clear() + { + _values.Clear(); + Inner.Clear(); + } + + public bool Contains(T item) => _values.Contains(item); + + public void CopyTo(T[] array, int arrayIndex) => _values.CopyTo(array, arrayIndex); + + public IEnumerator GetEnumerator() => _values.GetEnumerator(); + + public int IndexOf(T item) => _values.IndexOf(item); + + public void Insert(int index, T item) + { + _values.Insert(index, item); + Inner.Insert(index, _convertTo(item)); + } + + public bool Remove(T item) => _values.Remove(item) && Inner.Remove(_convertTo(item)); + + public void RemoveAt(int index) + { + _values.RemoveAt(index); + Inner.RemoveAt(index); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs b/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs new file mode 100644 index 000000000..d31fc1840 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs @@ -0,0 +1,62 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal +{ + internal class DefaultChannelCredentialsConfigurator : ChannelCredentialsConfiguratorBase + { + public bool? IsSecure { get; private set; } + public List? 
CallCredentials { get; private set; } + + public override void SetCompositeCredentials(object state, ChannelCredentials channelCredentials, CallCredentials callCredentials) + { + channelCredentials.InternalPopulateConfiguration(this, null); + + if (callCredentials != null) + { + if (CallCredentials == null) + { + CallCredentials = new List(); + } + + CallCredentials.Add(callCredentials); + } + } + + public override void SetInsecureCredentials(object state) => IsSecure = false; + + public override void SetSslCredentials(object state, string rootCertificates, KeyCertificatePair keyCertificatePair, VerifyPeerCallback verifyPeerCallback) + { + if (!string.IsNullOrEmpty(rootCertificates) || + keyCertificatePair != null || + verifyPeerCallback != null) + { + throw new InvalidOperationException( + $"{nameof(SslCredentials)} with non-null arguments is not supported by {nameof(GrpcChannel)}. " + + $"{nameof(GrpcChannel)} uses HttpClient to make gRPC calls and HttpClient automatically loads root certificates from the operating system certificate store. " + + $"Client certificates should be configured on HttpClient. See https://aka.ms/AA6we64 for details."); + } + + IsSecure = true; + } + } +} diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs index 23750c374..cdf6c37c3 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs @@ -19,6 +19,7 @@ using System; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Net; using System.Net.Http; using Grpc.Core; using Grpc.Shared; @@ -90,5 +91,84 @@ protected bool TryGetTrailers([NotNullWhen(true)] out Metadata? trailers) trailers = Trailers; return true; } + + internal static Status? ValidateHeaders(HttpResponseMessage httpResponse, out Metadata? trailers) + { + // gRPC status can be returned in the header when there is no message (e.g. unimplemented status) + // An explicitly specified status header has priority over other failing statuses + if (GrpcProtocolHelpers.TryGetStatusCore(httpResponse.Headers, out var status)) + { + // Trailers are in the header because there is no message. + // Note that some default headers will end up in the trailers (e.g. Date, Server). + trailers = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + return status; + } + + trailers = null; + + // ALPN negotiation is sending HTTP/1.1 and HTTP/2. + // Check that the response wasn't downgraded to HTTP/1.1. + if (httpResponse.Version < GrpcProtocolConstants.Http2Version) + { + return new Status(StatusCode.Internal, $"Bad gRPC response. Response protocol downgraded to HTTP/{httpResponse.Version.ToString(2)}."); + } + + if (httpResponse.StatusCode != HttpStatusCode.OK) + { + var statusCode = MapHttpStatusToGrpcCode(httpResponse.StatusCode); + return new Status(statusCode, "Bad gRPC response. HTTP status code: " + (int)httpResponse.StatusCode); + } + + if (httpResponse.Content?.Headers.ContentType == null) + { + return new Status(StatusCode.Cancelled, "Bad gRPC response. Response did not have a content-type header."); + } + + var grpcEncoding = httpResponse.Content.Headers.ContentType; + if (!CommonGrpcProtocolHelpers.IsContentType(GrpcProtocolConstants.GrpcContentType, grpcEncoding?.MediaType)) + { + return new Status(StatusCode.Cancelled, "Bad gRPC response. 
Invalid content-type value: " + grpcEncoding); + } + + // Call is still in progress + return null; + } + + private static StatusCode MapHttpStatusToGrpcCode(HttpStatusCode httpStatusCode) + { + switch (httpStatusCode) + { + case HttpStatusCode.BadRequest: // 400 +#if !NETSTANDARD2_0 + case HttpStatusCode.RequestHeaderFieldsTooLarge: // 431 +#else + case (HttpStatusCode)431: +#endif + return StatusCode.Internal; + case HttpStatusCode.Unauthorized: // 401 + return StatusCode.Unauthenticated; + case HttpStatusCode.Forbidden: // 403 + return StatusCode.PermissionDenied; + case HttpStatusCode.NotFound: // 404 + return StatusCode.Unimplemented; +#if !NETSTANDARD2_0 + case HttpStatusCode.TooManyRequests: // 429 +#else + case (HttpStatusCode)429: +#endif + case HttpStatusCode.BadGateway: // 502 + case HttpStatusCode.ServiceUnavailable: // 503 + case HttpStatusCode.GatewayTimeout: // 504 + return StatusCode.Unavailable; + default: + if ((int)httpStatusCode >= 100 && (int)httpStatusCode < 200) + { + // 1xx. These headers should have been ignored. + return StatusCode.Internal; + } + + return StatusCode.Unknown; + } + } } } diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.cs b/src/Grpc.Net.Client/Internal/GrpcCall.cs index 3bf17e799..8a59874db 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.cs @@ -18,8 +18,10 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net; using System.Net.Http; @@ -27,6 +29,7 @@ using System.Threading.Tasks; using Grpc.Core; using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Configuration; using Grpc.Shared; using Microsoft.Extensions.Logging; @@ -36,18 +39,19 @@ namespace Grpc.Net.Client.Internal { - internal sealed partial class GrpcCall : GrpcCall, IDisposable + internal sealed partial class GrpcCall : GrpcCall, IGrpcCall where TRequest : class where TResponse : class { - private const string ErrorStartingCallMessage = "Error starting gRPC call."; + internal const string ErrorStartingCallMessage = "Error starting gRPC call."; private readonly CancellationTokenSource _callCts; private readonly TaskCompletionSource _callTcs; private readonly DateTime _deadline; private readonly GrpcMethodInfo _grpcMethodInfo; + private readonly int _attemptCount; - private Task? _httpResponseTask; + internal Task? _httpResponseTask; private Task? _responseHeadersTask; private Timer? _deadlineTimer; private CancellationTokenRegistration? _ctsRegistration; @@ -57,10 +61,12 @@ internal sealed partial class GrpcCall : GrpcCall, IDisposa // These are set depending on the type of gRPC call private TaskCompletionSource? _responseTcs; + + public int MessagesWritten { get; private set; } public HttpContentClientStreamWriter? ClientStreamWriter { get; private set; } public HttpContentClientStreamReader? ClientStreamReader { get; private set; } - public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel) + public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int attemptCount) : base(options, channel) { // Validate deadline before creating any objects that require cleanup @@ -72,10 +78,13 @@ public GrpcCall(Method method, GrpcMethodInfo grpcMethodInf Method = method; _grpcMethodInfo = grpcMethodInfo; _deadline = options.Deadline ?? 
DateTime.MaxValue; + _attemptCount = attemptCount; Channel.RegisterActiveCall(this); } + public MethodConfigInfo? MethodConfig => _grpcMethodInfo.MethodConfig; + private void ValidateDeadline(DateTime? deadline) { if (deadline != null && deadline != DateTime.MaxValue && deadline != DateTime.MinValue && deadline.Value.Kind != DateTimeKind.Utc) @@ -94,40 +103,75 @@ public CancellationToken CancellationToken public override Type RequestType => typeof(TRequest); public override Type ResponseType => typeof(TResponse); - public void StartUnary(TRequest request) + IClientStreamWriter? IGrpcCall.ClientStreamWriter => ClientStreamWriter; + IAsyncStreamReader? IGrpcCall.ClientStreamReader => ClientStreamReader; + + public void StartUnary(TRequest request) => StartUnaryCore(CreatePushUnaryContent(request)); + + public void StartClientStreaming() + { + var clientStreamWriter = new HttpContentClientStreamWriter(this); + var content = new PushStreamContent(clientStreamWriter); + + StartClientStreamingCore(clientStreamWriter, content); + } + + public void StartServerStreaming(TRequest request) => StartServerStreamingCore(CreatePushUnaryContent(request)); + + private HttpContent CreatePushUnaryContent(TRequest request) + { + return !Channel.IsWinHttp + ? new PushUnaryContent(request, WriteAsync) + : new WinHttpUnaryContent(request, WriteAsync, this); + + ValueTask WriteAsync(TRequest request, Stream stream) + { + return WriteMessageAsync(stream, request, Options); + } + } + + public void StartDuplexStreaming() + { + var clientStreamWriter = new HttpContentClientStreamWriter(this); + var content = new PushStreamContent(clientStreamWriter); + + StartDuplexStreamingCore(clientStreamWriter, content); + } + + internal void StartUnaryCore(HttpContent content) { _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - SetMessageContent(request, message); + SetMessageContent(content, message); _ = RunCall(message, timeout); } - public void StartClientStreaming() + internal void StartServerStreamingCore(HttpContent content) { - _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - CreateWriter(message); + SetMessageContent(content, message); + ClientStreamReader = new HttpContentClientStreamReader(this); _ = RunCall(message, timeout); } - public void StartServerStreaming(TRequest request) + internal void StartClientStreamingCore(HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { + _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - SetMessageContent(request, message); - ClientStreamReader = new HttpContentClientStreamReader(this); + SetWriter(message, clientStreamWriter, content); _ = RunCall(message, timeout); } - public void StartDuplexStreaming() + public void StartDuplexStreamingCore(HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - CreateWriter(message); + SetWriter(message, clientStreamWriter, content); ClientStreamReader = new HttpContentClientStreamReader(this); _ = RunCall(message, timeout); } @@ -138,14 +182,14 @@ public void Dispose() { Disposed = true; - Cleanup(new Status(StatusCode.Cancelled, "gRPC call disposed.")); + 
Cleanup(GrpcProtocolConstants.DisposeCanceledStatus); } } /// /// Clean up can be called by: /// 1. The user. AsyncUnaryCall.Dispose et al will call this on Dispose - /// 2. will call dispose if errors fail validation + /// 2. will call dispose if errors fail validation /// 3. will call dispose /// private void Cleanup(Status status) @@ -254,7 +298,15 @@ private async Task GetResponseHeadersCoreAsync() await CallTask.ConfigureAwait(false); } - return GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + var metadata = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exposed-retry-metadata + if (_attemptCount > 1) + { + metadata.Add(GrpcProtocolConstants.RetryPreviousAttemptsHeader, (_attemptCount - 1).ToString(CultureInfo.InvariantCulture)); + } + + return metadata; } catch (Exception ex) when (ResolveException(ErrorStartingCallMessage, ex, out _, out var resolvedException)) { @@ -281,85 +333,6 @@ public Task GetResponseAsync() return _responseTcs.Task; } - private Status? ValidateHeaders(HttpResponseMessage httpResponse) - { - GrpcCallLog.ResponseHeadersReceived(Logger); - - // gRPC status can be returned in the header when there is no message (e.g. unimplemented status) - // An explicitly specified status header has priority over other failing statuses - if (GrpcProtocolHelpers.TryGetStatusCore(httpResponse.Headers, out var status)) - { - // Trailers are in the header because there is no message. - // Note that some default headers will end up in the trailers (e.g. Date, Server). - Trailers = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); - return status; - } - - // ALPN negotiation is sending HTTP/1.1 and HTTP/2. - // Check that the response wasn't downgraded to HTTP/1.1. - if (httpResponse.Version < GrpcProtocolConstants.Http2Version) - { - return new Status(StatusCode.Internal, $"Bad gRPC response. Response protocol downgraded to HTTP/{httpResponse.Version.ToString(2)}."); - } - - if (httpResponse.StatusCode != HttpStatusCode.OK) - { - var statusCode = MapHttpStatusToGrpcCode(httpResponse.StatusCode); - return new Status(statusCode, "Bad gRPC response. HTTP status code: " + (int)httpResponse.StatusCode); - } - - if (httpResponse.Content?.Headers.ContentType == null) - { - return new Status(StatusCode.Cancelled, "Bad gRPC response. Response did not have a content-type header."); - } - - var grpcEncoding = httpResponse.Content.Headers.ContentType; - if (!CommonGrpcProtocolHelpers.IsContentType(GrpcProtocolConstants.GrpcContentType, grpcEncoding?.MediaType)) - { - return new Status(StatusCode.Cancelled, "Bad gRPC response. 
Invalid content-type value: " + grpcEncoding); - } - - // Call is still in progress - return null; - } - - private static StatusCode MapHttpStatusToGrpcCode(HttpStatusCode httpStatusCode) - { - switch (httpStatusCode) - { - case HttpStatusCode.BadRequest: // 400 -#if !NETSTANDARD2_0 - case HttpStatusCode.RequestHeaderFieldsTooLarge: // 431 -#else - case (HttpStatusCode)431: -#endif - return StatusCode.Internal; - case HttpStatusCode.Unauthorized: // 401 - return StatusCode.Unauthenticated; - case HttpStatusCode.Forbidden: // 403 - return StatusCode.PermissionDenied; - case HttpStatusCode.NotFound: // 404 - return StatusCode.Unimplemented; -#if !NETSTANDARD2_0 - case HttpStatusCode.TooManyRequests: // 429 -#else - case (HttpStatusCode)429: -#endif - case HttpStatusCode.BadGateway: // 502 - case HttpStatusCode.ServiceUnavailable: // 503 - case HttpStatusCode.GatewayTimeout: // 504 - return StatusCode.Unavailable; - default: - if ((int)httpStatusCode >= 100 && (int)httpStatusCode < 200) - { - // 1xx. These headers should have been ignored. - return StatusCode.Internal; - } - - return StatusCode.Unknown; - } - } - public Metadata GetTrailers() { using (StartScope()) @@ -375,32 +348,17 @@ public Metadata GetTrailers() } } - private void SetMessageContent(TRequest request, HttpRequestMessage message) + private void SetMessageContent(HttpContent content, HttpRequestMessage message) { RequestGrpcEncoding = GrpcProtocolHelpers.GetRequestEncoding(message.Headers); - - if (!Channel.IsWinHttp) - { - message.Content = new PushUnaryContent( - request, - this, - GrpcProtocolConstants.GrpcContentTypeHeaderValue); - } - else - { - // WinHttp doesn't support streaming request data so a length needs to be specified. - message.Content = new LengthUnaryContent( - request, - this, - GrpcProtocolConstants.GrpcContentTypeHeaderValue); - } + message.Content = content; } public void CancelCallFromCancellationToken() { using (StartScope()) { - CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client.")); + CancelCall(GrpcProtocolConstants.ClientCanceledStatus); } } @@ -478,10 +436,11 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout) } catch (Exception ex) { - // Don't log OperationCanceledException if deadline has exceeded. + // Don't log OperationCanceledException if deadline has exceeded + // or the call has been canceled. if (ex is OperationCanceledException && _callTcs.Task.IsCompletedSuccessfully() && - _callTcs.Task.Result.StatusCode == StatusCode.DeadlineExceeded) + (_callTcs.Task.Result.StatusCode == StatusCode.DeadlineExceeded || _callTcs.Task.Result.StatusCode == StatusCode.Cancelled)) { throw; } @@ -492,7 +451,12 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? 
timeout) } } - status = ValidateHeaders(HttpResponse); + GrpcCallLog.ResponseHeadersReceived(Logger); + status = ValidateHeaders(HttpResponse, out var trailers); + if (trailers != null) + { + Trailers = trailers; + } // A status means either the call has failed or grpc-status was returned in the response header if (status != null) @@ -651,10 +615,7 @@ internal bool ResolveException(string summary, Exception ex, [NotNull] out Statu } else { - var exceptionMessage = CommonGrpcProtocolHelpers.ConvertToRpcExceptionMessage(ex); - var statusCode = GrpcProtocolHelpers.ResolveRpcExceptionStatusCode(ex); - - status = new Status(statusCode, summary + " " + exceptionMessage, ex); + status = GrpcProtocolHelpers.CreateStatusFromException(summary, ex); resolvedException = CreateRpcException(status.Value); return true; } @@ -816,12 +777,11 @@ private async Task ReadCredentials(HttpRequestMessage request) } } - private void CreateWriter(HttpRequestMessage message) + private void SetWriter(HttpRequestMessage message, HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { RequestGrpcEncoding = GrpcProtocolHelpers.GetRequestEncoding(message.Headers); - ClientStreamWriter = new HttpContentClientStreamWriter(this); - - message.Content = new PushStreamContent(ClientStreamWriter, GrpcProtocolConstants.GrpcContentTypeHeaderValue); + ClientStreamWriter = clientStreamWriter; + message.Content = content; } private HttpRequestMessage CreateHttpRequestMessage(TimeSpan? timeout) @@ -842,6 +802,12 @@ private HttpRequestMessage CreateHttpRequestMessage(TimeSpan? timeout) headers.TryAddWithoutValidation(GrpcProtocolConstants.TEHeader, GrpcProtocolConstants.TEHeaderValue); headers.TryAddWithoutValidation(GrpcProtocolConstants.MessageAcceptEncodingHeader, Channel.MessageAcceptEncoding); + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exposed-retry-metadata + if (_attemptCount > 1) + { + headers.TryAddWithoutValidation(GrpcProtocolConstants.RetryPreviousAttemptsHeader, (_attemptCount - 1).ToString(CultureInfo.InvariantCulture)); + } + if (Options.Headers != null && Options.Headers.Count > 0) { foreach (var entry in Options.Headers) @@ -929,38 +895,26 @@ private void DeadlineExceeded() internal ValueTask WriteMessageAsync( Stream stream, - TRequest message, - Action contextualSerializer, - CallOptions callOptions) - { - return stream.WriteMessageAsync( - this, - message, - contextualSerializer, - callOptions); - } - - internal ValueTask WriteMessageAsync( - Stream stream, - TRequest message, - Action contextualSerializer, - CallOptions callOptions) where TSerializationContext : SerializationContext, IMemoryOwner + ReadOnlyMemory message, + CancellationToken cancellationToken) { + MessagesWritten++; return stream.WriteMessageAsync( this, message, - contextualSerializer, - callOptions); + cancellationToken); } internal ValueTask WriteMessageAsync( Stream stream, - ReadOnlyMemory message, + TRequest message, CallOptions callOptions) { + MessagesWritten++; return stream.WriteMessageAsync( this, message, + Method.RequestMarshaller.ContextualSerializer, callOptions); } @@ -981,5 +935,10 @@ internal ValueTask WriteMessageAsync( singleMessage, cancellationToken); } + + public Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { + return ClientStreamWriter!.WriteAsync(writeFunc, state); + } } } diff --git a/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs b/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs index d495c4a7b..2158cae07 100644 --- 
a/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs +++ b/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs @@ -17,7 +17,10 @@ #endregion using System; +using System.Collections.Generic; +using System.Linq; using Grpc.Core; +using Grpc.Net.Client.Configuration; namespace Grpc.Net.Client.Internal { @@ -26,13 +29,115 @@ namespace Grpc.Net.Client.Internal /// internal class GrpcMethodInfo { - public GrpcMethodInfo(GrpcCallScope logScope, Uri callUri) + public GrpcMethodInfo(GrpcCallScope logScope, Uri callUri, MethodConfig? methodConfig) { LogScope = logScope; CallUri = callUri; + MethodConfig = CreateMethodConfig(methodConfig); + } + + private MethodConfigInfo? CreateMethodConfig(MethodConfig? methodConfig) + { + if (methodConfig == null) + { + return null; + } + if (methodConfig.RetryPolicy != null && methodConfig.HedgingPolicy != null) + { + throw new InvalidOperationException("Method config can't have a retry policy and hedging policy."); + } + + var m = new MethodConfigInfo(); + + if (methodConfig.RetryPolicy != null) + { + m.RetryPolicy = CreateRetryPolicy(methodConfig.RetryPolicy); + } + + if (methodConfig.HedgingPolicy != null) + { + m.HedgingPolicy = CreateHedgingPolicy(methodConfig.HedgingPolicy); + } + + return m; + } + + internal static RetryPolicyInfo CreateRetryPolicy(RetryPolicy r) + { + if (!(r.MaxAttempts > 1)) + { + throw new InvalidOperationException("Retry policy max attempts must be greater than 1."); + } + if (!(r.InitialBackoff > TimeSpan.Zero)) + { + throw new InvalidOperationException("Retry policy initial backoff must be greater than zero."); + } + if (!(r.MaxBackoff > TimeSpan.Zero)) + { + throw new InvalidOperationException("Retry policy maximum backoff must be greater than zero."); + } + if (!(r.BackoffMultiplier > 0)) + { + throw new InvalidOperationException("Retry policy backoff multiplier must be greater than 0."); + } + if (!(r.RetryableStatusCodes.Count > 0)) + { + throw new InvalidOperationException("Retry policy must specify at least 1 retryable status code."); + } + + return new RetryPolicyInfo + { + MaxAttempts = r.MaxAttempts.GetValueOrDefault(), + InitialBackoff = r.InitialBackoff.GetValueOrDefault(), + MaxBackoff = r.MaxBackoff.GetValueOrDefault(), + BackoffMultiplier = r.BackoffMultiplier.GetValueOrDefault(), + RetryableStatusCodes = r.RetryableStatusCodes.ToList() + }; + } + + internal static HedgingPolicyInfo CreateHedgingPolicy(HedgingPolicy h) + { + if (!(h.MaxAttempts > 1)) + { + throw new InvalidOperationException("Hedging policy max attempts must be greater than 1."); + } + if (!(h.HedgingDelay >= TimeSpan.Zero)) + { + throw new InvalidOperationException("Hedging policy delay must be equal or greater than zero."); + } + + return new HedgingPolicyInfo + { + MaxAttempts = h.MaxAttempts.GetValueOrDefault(), + HedgingDelay = h.HedgingDelay.GetValueOrDefault(), + NonFatalStatusCodes = h.NonFatalStatusCodes.ToList() + }; } public GrpcCallScope LogScope { get; } public Uri CallUri { get; } + public MethodConfigInfo? MethodConfig { get; } + } + + internal class MethodConfigInfo + { + public RetryPolicyInfo? RetryPolicy { get; set; } + public HedgingPolicyInfo? 
HedgingPolicy { get; set; } + } + + internal class RetryPolicyInfo + { + public int MaxAttempts { get; init; } + public TimeSpan InitialBackoff { get; init; } + public TimeSpan MaxBackoff { get; init; } + public double BackoffMultiplier { get; init; } + public List RetryableStatusCodes { get; init; } = default!; + } + + internal class HedgingPolicyInfo + { + public int MaxAttempts { get; set; } + public TimeSpan HedgingDelay { get; set; } + public List NonFatalStatusCodes { get; init; } = default!; } } diff --git a/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs b/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs index f96934dea..e19cb9869 100644 --- a/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs +++ b/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs @@ -22,6 +22,7 @@ using System.Linq; using System.Net.Http.Headers; using System.Reflection; +using Grpc.Core; using Grpc.Net.Compression; namespace Grpc.Net.Client.Internal @@ -46,9 +47,11 @@ internal static class GrpcProtocolConstants internal const string IdentityGrpcEncoding = "identity"; internal const string MessageAcceptEncodingHeader = "grpc-accept-encoding"; - internal const string CompressionRequestAlgorithmHeader = "grpc-internal-encoding-request"; + internal const string RetryPushbackHeader = "grpc-retry-pushback-ms"; + internal const string RetryPreviousAttemptsHeader = "grpc-previous-rpc-attempts"; + internal static readonly Dictionary DefaultCompressionProviders = new Dictionary(StringComparer.Ordinal) { ["gzip"] = new GzipCompressionProvider(System.IO.Compression.CompressionLevel.Fastest), @@ -65,6 +68,11 @@ internal static class GrpcProtocolConstants internal static readonly string TEHeader; internal static readonly string TEHeaderValue; + internal static readonly Status DeadlineExceededStatus = new Status(StatusCode.DeadlineExceeded, string.Empty); + internal static readonly Status ThrottledStatus = new Status(StatusCode.Cancelled, "Retries stopped because retry throttling is active."); + internal static readonly Status ClientCanceledStatus = new Status(StatusCode.Cancelled, "Call canceled by the client."); + internal static readonly Status DisposeCanceledStatus = new Status(StatusCode.Cancelled, "gRPC call disposed."); + internal static string GetMessageAcceptEncoding(Dictionary compressionProviders) { return IdentityGrpcEncoding + "," + diff --git a/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs b/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs index 445327574..eca30280f 100644 --- a/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs +++ b/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs @@ -402,6 +402,9 @@ public static StatusCode ResolveRpcExceptionStatusCode(Exception ex) } else if (current is IOException) { + // TODO(JamesNK): IOException is also returned for aborted requests. + // Need to think about what is the best status for aborted requests. + // IOException happens if there is a protocol mismatch. 
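+            // For example (illustrative, not from this change): forcing HTTP/2 against an
+            // endpoint that only speaks HTTP/1.x typically surfaces as an IOException here,
+            // and mapping it to Unavailable lets a retry policy treat it as transient.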
return StatusCode.Unavailable; } @@ -409,5 +412,13 @@ public static StatusCode ResolveRpcExceptionStatusCode(Exception ex) return StatusCode.Internal; } + + public static Status CreateStatusFromException(string summary, Exception ex) + { + var exceptionMessage = CommonGrpcProtocolHelpers.ConvertToRpcExceptionMessage(ex); + var statusCode = ResolveRpcExceptionStatusCode(ex); + + return new Status(statusCode, summary + " " + exceptionMessage, ex); + } } } diff --git a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs index e8ed68ad0..d2aee84e3 100644 --- a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs @@ -20,9 +20,12 @@ using System.IO; using System.Net; using System.Net.Http; -using System.Net.Http.Headers; using System.Threading.Tasks; +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + namespace Grpc.Net.Client.Internal.Http { internal class PushStreamContent : HttpContent @@ -30,19 +33,32 @@ internal class PushStreamContent : HttpContent where TResponse : class { private readonly HttpContentClientStreamWriter _streamWriter; + private readonly Func? _startCallback; - public PushStreamContent(HttpContentClientStreamWriter streamWriter, MediaTypeHeaderValue mediaType) + public PushStreamContent(HttpContentClientStreamWriter streamWriter) { - Headers.ContentType = mediaType; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; _streamWriter = streamWriter; } + public PushStreamContent( + HttpContentClientStreamWriter streamWriter, + Func? startCallback) : this(streamWriter) + { + _startCallback = startCallback; + } + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context) { // Immediately flush request stream to send headers // https://github.com/dotnet/corefx/issues/39586#issuecomment-516210081 await stream.FlushAsync().ConfigureAwait(false); + if (_startCallback != null) + { + await _startCallback(stream).ConfigureAwait(false); + } + // Pass request stream to writer _streamWriter.WriteStreamTcs.TrySetResult(stream); diff --git a/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs b/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs index acedd9ce8..01a9ac761 100644 --- a/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs @@ -16,40 +16,37 @@ #endregion +using System; using System.IO; using System.Net; using System.Net.Http; -using System.Net.Http.Headers; using System.Threading.Tasks; #if NETSTANDARD2_0 using ValueTask = System.Threading.Tasks.Task; #endif -namespace Grpc.Net.Client.Internal.Http +namespace Grpc.Net.Client.Internal { + // TODO: Still need generic args? internal class PushUnaryContent : HttpContent where TRequest : class where TResponse : class { - private readonly TRequest _content; - private readonly GrpcCall _call; + private readonly TRequest _request; + private readonly Func _startCallback; - public PushUnaryContent(TRequest content, GrpcCall call, MediaTypeHeaderValue mediaType) + public PushUnaryContent(TRequest request, Func startCallback) { - _content = content; - _call = call; - Headers.ContentType = mediaType; + _request = request; + _startCallback = startCallback; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; } protected override Task SerializeToStreamAsync(Stream stream, TransportContext? 
context) { #pragma warning disable CA2012 // Use ValueTasks correctly - var writeMessageTask = _call.WriteMessageAsync( - stream, - _content, - _call.Method.RequestMarshaller.ContextualSerializer, - _call.Options); + var writeMessageTask = _startCallback(_request, stream); #pragma warning restore CA2012 // Use ValueTasks correctly if (writeMessageTask.IsCompletedSuccessfully()) { diff --git a/src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs b/src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs similarity index 60% rename from src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs rename to src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs index 924dced98..945b6c2aa 100644 --- a/src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs @@ -37,64 +37,66 @@ namespace Grpc.Net.Client.Internal.Http /// The payload is then written directly to the request using specialized context /// and serializer method. /// - internal class LengthUnaryContent : HttpContent + internal class WinHttpUnaryContent : HttpContent where TRequest : class where TResponse : class { - private readonly TRequest _content; + private readonly TRequest _request; + private readonly Func _startCallback; private readonly GrpcCall _call; - private byte[]? _payload; - public LengthUnaryContent(TRequest content, GrpcCall call, MediaTypeHeaderValue mediaType) + public WinHttpUnaryContent(TRequest request, Func startCallback, GrpcCall call) { - _content = content; + _request = request; + _startCallback = startCallback; _call = call; - Headers.ContentType = mediaType; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; } - // Serialize message. Need to know size to prefix the length in the header. - private byte[] SerializePayload() + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) { - var serializationContext = _call.SerializationContext; - serializationContext.CallOptions = _call.Options; - serializationContext.Initialize(); - - try - { - _call.Method.RequestMarshaller.ContextualSerializer(_content, serializationContext); - - return serializationContext.GetWrittenPayload().ToArray(); - } - finally +#pragma warning disable CA2012 // Use ValueTasks correctly + var writeMessageTask = _startCallback(_request, stream); +#pragma warning restore CA2012 // Use ValueTasks correctly + if (writeMessageTask.IsCompletedSuccessfully()) { - serializationContext.Reset(); + GrpcEventSource.Log.MessageSent(); + return Task.CompletedTask; } + + return WriteMessageCore(writeMessageTask); } - protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context) + private static async Task WriteMessageCore(ValueTask writeMessageTask) { - if (_payload == null) - { - _payload = SerializePayload(); - } - - await _call.WriteMessageAsync( - stream, - _payload, - _call.Options).ConfigureAwait(false); - + await writeMessageTask.ConfigureAwait(false); GrpcEventSource.Log.MessageSent(); } protected override bool TryComputeLength(out long length) { - if (_payload == null) + // This will serialize the request message again. + // Consider caching serialized content if it is a problem. 
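+            // WinHttp needs the Content-Length up front because it can't stream the request
+            // body (see CreatePushUnaryContent in GrpcCall), so the length is found by
+            // serializing the message here, and SerializeToStreamAsync serializes it a
+            // second time when the body is written.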
+ length = GetPayloadLength(); + return true; + } + + private int GetPayloadLength() + { + var serializationContext = _call.SerializationContext; + serializationContext.CallOptions = _call.Options; + serializationContext.Initialize(); + + try { - _payload = SerializePayload(); - } + _call.Method.RequestMarshaller.ContextualSerializer(_request, serializationContext); - length = _payload.Length; - return true; + return serializationContext.GetWrittenPayload().Length; + } + finally + { + serializationContext.Reset(); + } } } } diff --git a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs index 2cec54229..275949da5 100644 --- a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs +++ b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs @@ -19,6 +19,8 @@ using System; using System.Threading.Tasks; using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Retry; namespace Grpc.Net.Client.Internal { @@ -40,7 +42,7 @@ public HttpClientCallInvoker(GrpcChannel channel) /// public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string host, CallOptions options) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartClientStreaming(); return new AsyncClientStreamingCall( @@ -60,7 +62,7 @@ public override AsyncClientStreamingCall AsyncClientStreami /// public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string host, CallOptions options) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartDuplexStreaming(); return new AsyncDuplexStreamingCall( @@ -79,7 +81,7 @@ public override AsyncDuplexStreamingCall AsyncDuplexStreami /// public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartServerStreaming(request); return new AsyncServerStreamingCall( @@ -96,7 +98,7 @@ public override AsyncServerStreamingCall AsyncServerStreamingCall public override AsyncUnaryCall AsyncUnaryCall(Method method, string host, CallOptions options, TRequest request) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartUnary(request); return new AsyncUnaryCall( @@ -117,19 +119,47 @@ public override TResponse BlockingUnaryCall(Method CreateGrpcCall( + private static IGrpcCall CreateRootGrpcCall( + GrpcChannel channel, Method method, CallOptions options) where TRequest : class where TResponse : class { - if (Channel.Disposed) + var methodInfo = channel.GetCachedGrpcMethodInfo(method); + var retryPolicy = methodInfo.MethodConfig?.RetryPolicy; + var hedgingPolicy = methodInfo.MethodConfig?.HedgingPolicy; + + if (retryPolicy != null) + { + return new RetryCall(retryPolicy, channel, method, options); + } + else if (hedgingPolicy != null) + { + return new HedgingCall(hedgingPolicy, channel, method, options); + } + else + { + // No retry/hedge policy configured. Fast path! 
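+                // For illustration, a channel opts in to the retry path above with a service
+                // config (example values; assumes the ServiceConfig/MethodConfig/MethodName
+                // shapes added elsewhere in this change):
+                //
+                //   var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
+                //   {
+                //       ServiceConfig = new ServiceConfig
+                //       {
+                //           MethodConfigs =
+                //           {
+                //               new MethodConfig
+                //               {
+                //                   Names = { MethodName.Default },
+                //                   RetryPolicy = new RetryPolicy
+                //                   {
+                //                       MaxAttempts = 5,
+                //                       InitialBackoff = TimeSpan.FromSeconds(1),
+                //                       MaxBackoff = TimeSpan.FromSeconds(5),
+                //                       BackoffMultiplier = 1.5,
+                //                       RetryableStatusCodes = { StatusCode.Unavailable }
+                //                   }
+                //               }
+                //           }
+                //       }
+                //   });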
+ return CreateGrpcCall(channel, method, options, attempt: 1); + } + } + + public static GrpcCall CreateGrpcCall( + GrpcChannel channel, + Method method, + CallOptions options, + int attempt) + where TRequest : class + where TResponse : class + { + if (channel.Disposed) { throw new ObjectDisposedException(nameof(GrpcChannel)); } - var methodInfo = Channel.GetCachedGrpcMethodInfo(method); - var call = new GrpcCall(method, methodInfo, options, Channel); + var methodInfo = channel.GetCachedGrpcMethodInfo(method); + var call = new GrpcCall(method, methodInfo, options, channel, attempt); return call; } @@ -139,10 +169,10 @@ private static class Callbacks where TRequest : class where TResponse : class { - internal static readonly Func> GetResponseHeadersAsync = state => ((GrpcCall)state).GetResponseHeadersAsync(); - internal static readonly Func GetStatus = state => ((GrpcCall)state).GetStatus(); - internal static readonly Func GetTrailers = state => ((GrpcCall)state).GetTrailers(); - internal static readonly Action Dispose = state => ((GrpcCall)state).Dispose(); + internal static readonly Func> GetResponseHeadersAsync = state => ((IGrpcCall)state).GetResponseHeadersAsync(); + internal static readonly Func GetStatus = state => ((IGrpcCall)state).GetStatus(); + internal static readonly Func GetTrailers = state => ((IGrpcCall)state).GetTrailers(); + internal static readonly Action Dispose = state => ((IGrpcCall)state).Dispose(); } } } diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs index fa040ccf0..736ac08f8 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs @@ -31,7 +31,7 @@ internal class HttpContentClientStreamReader : IAsyncStream where TRequest : class where TResponse : class { - // Getting logger name from generic type is slow + // Getting logger name from generic type is slow. Cached copy. private const string LoggerName = "Grpc.Net.Client.Internal.HttpContentClientStreamReader"; private static readonly Task FinishedTask = Task.FromResult(false); diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs index 7b26faf23..879dffeb9 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs @@ -18,14 +18,16 @@ using System; using System.IO; -using System.Net.Http; using System.Threading.Tasks; using Grpc.Core; -using Microsoft.Extensions.Logging; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif namespace Grpc.Net.Client.Internal { - internal class HttpContentClientStreamWriter : IClientStreamWriter + internal class HttpContentClientStreamWriter : ClientStreamWriterBase where TRequest : class where TResponse : class { @@ -33,42 +35,38 @@ internal class HttpContentClientStreamWriter : IClientStrea private const string LoggerName = "Grpc.Net.Client.Internal.HttpContentClientStreamWriter"; private readonly GrpcCall _call; - private readonly ILogger _logger; - private readonly object _writeLock; - private Task? 
_writeTask; private bool _completeCalled; public TaskCompletionSource WriteStreamTcs { get; } public TaskCompletionSource CompleteTcs { get; } public HttpContentClientStreamWriter(GrpcCall call) + : base(call.Channel.LoggerFactory.CreateLogger(LoggerName)) { _call = call; - _logger = call.Channel.LoggerFactory.CreateLogger(LoggerName); WriteStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); CompleteTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _writeLock = new object(); WriteOptions = _call.Options.WriteOptions; } - public WriteOptions WriteOptions { get; set; } + public override WriteOptions? WriteOptions { get; set; } - public Task CompleteAsync() + public override Task CompleteAsync() { _call.EnsureNotDisposed(); using (_call.StartScope()) { - Log.CompletingClientStream(_logger); + Log.CompletingClientStream(Logger); - lock (_writeLock) + lock (WriteLock) { // Pending writes need to be awaited first if (IsWriteInProgressUnsynchronized) { var ex = new InvalidOperationException("Can't complete the client stream writer because the previous write is in progress."); - Log.CompleteClientStreamError(_logger, ex); + Log.CompleteClientStreamError(Logger, ex); return Task.FromException(ex); } @@ -81,16 +79,26 @@ public Task CompleteAsync() return Task.CompletedTask; } - public Task WriteAsync(TRequest message) + public override Task WriteAsync(TRequest message) { if (message == null) { throw new ArgumentNullException(nameof(message)); } + return WriteAsync(WriteMessageToStream, message); + + static ValueTask WriteMessageToStream(GrpcCall call, Stream writeStream, CallOptions callOptions, TRequest message) + { + return call.WriteMessageAsync(writeStream, message, callOptions); + } + } + + public Task WriteAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { _call.EnsureNotDisposed(); - lock (_writeLock) + lock (WriteLock) { using (_call.StartScope()) { @@ -122,27 +130,16 @@ public Task WriteAsync(TRequest message) } // Save write task to track whether it is complete. Must be set inside lock. - _writeTask = WriteAsyncCore(message); + WriteTask = WriteAsyncCore(writeFunc, state); } } - return _writeTask; - } - - private Task CreateErrorTask(string message) - { - var ex = new InvalidOperationException(message); - Log.WriteMessageError(_logger, ex); - return Task.FromException(ex); - } - - public void Dispose() - { + return WriteTask; } public GrpcCall Call => _call; - private async Task WriteAsyncCore(TRequest message) + public async Task WriteAsyncCore(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) { try { @@ -157,11 +154,7 @@ private async Task WriteAsyncCore(TRequest message) callOptions = callOptions.WithWriteOptions(WriteOptions); } - await _call.WriteMessageAsync( - writeStream, - message, - _call.Method.RequestMarshaller.ContextualSerializer, - callOptions).ConfigureAwait(false); + await writeFunc(_call, writeStream, callOptions, state).ConfigureAwait(false); // Flush stream to ensure messages are sent immediately await writeStream.FlushAsync(callOptions.CancellationToken).ConfigureAwait(false); @@ -173,45 +166,5 @@ await _call.WriteMessageAsync( throw _call.CreateCanceledStatusException(); } } - - /// - /// A value indicating whether there is an async write already in progress. - /// Should only check this property when holding the write lock. 
- /// - private bool IsWriteInProgressUnsynchronized - { - get - { - var writeTask = _writeTask; - return writeTask != null && !writeTask.IsCompleted; - } - } - - private static class Log - { - private static readonly Action _completingClientStream = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CompletingClientStream"), "Completing client stream."); - - private static readonly Action _writeMessageError = - LoggerMessage.Define(LogLevel.Error, new EventId(2, "WriteMessageError"), "Error writing message."); - - private static readonly Action _completeClientStreamError = - LoggerMessage.Define(LogLevel.Error, new EventId(3, "CompleteClientStreamError"), "Error completing client stream."); - - public static void CompletingClientStream(ILogger logger) - { - _completingClientStream(logger, null); - } - - public static void WriteMessageError(ILogger logger, Exception ex) - { - _writeMessageError(logger, ex); - } - - public static void CompleteClientStreamError(ILogger logger, Exception ex) - { - _completeClientStreamError(logger, ex); - } - } } } diff --git a/src/Grpc.Net.Client/Internal/IGrpcCall.cs b/src/Grpc.Net.Client/Internal/IGrpcCall.cs new file mode 100644 index 000000000..d796dd034 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/IGrpcCall.cs @@ -0,0 +1,49 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.IO; +using System.Threading.Tasks; +using Grpc.Core; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal +{ + internal interface IGrpcCall : IDisposable + where TRequest : class + where TResponse : class + { + Task GetResponseAsync(); + Task GetResponseHeadersAsync(); + Status GetStatus(); + Metadata GetTrailers(); + + IClientStreamWriter? ClientStreamWriter { get; } + IAsyncStreamReader? ClientStreamReader { get; } + + void StartUnary(TRequest request); + void StartClientStreaming(); + void StartServerStreaming(TRequest request); + void StartDuplexStreaming(); + + Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state); + } +} diff --git a/src/Grpc.Net.Client/Internal/IsExternalInit.cs b/src/Grpc.Net.Client/Internal/IsExternalInit.cs new file mode 100644 index 000000000..95482cdf3 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/IsExternalInit.cs @@ -0,0 +1,29 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace System.Runtime.CompilerServices +{ + // Required for init properties in netstandard2.0 + [ExcludeFromCodeCoverage, DebuggerNonUserCode] + internal static class IsExternalInit + { + } +} \ No newline at end of file diff --git a/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs b/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs new file mode 100644 index 000000000..356c72a9d --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs @@ -0,0 +1,100 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Diagnostics; +using System.Threading; +using Grpc.Net.Client.Configuration; +using Microsoft.Extensions.Logging; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class ChannelRetryThrottling + { + private readonly object _lock = new object(); + private readonly double _tokenRatio; + private readonly int _maxTokens; + private readonly ILogger _logger; + + private double _tokenCount; + private double _tokenThreshold; + private bool _isRetryThrottlingActive; + + public ChannelRetryThrottling(int maxTokens, double tokenRatio, ILoggerFactory loggerFactory) + { + // Truncate token ratio to 3 decimal places + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#validation-of-retrythrottling + _tokenRatio = Math.Truncate(tokenRatio * 1000) / 1000; + + _maxTokens = maxTokens; + _tokenCount = maxTokens; + _tokenThreshold = _tokenCount / 2; + _logger = loggerFactory.CreateLogger(); + } + + public bool IsRetryThrottlingActive() + { + lock (_lock) + { + return _isRetryThrottlingActive; + } + } + + public void CallSuccess() + { + lock (_lock) + { + _tokenCount = Math.Min(_tokenCount + _tokenRatio, _maxTokens); + UpdateRetryThrottlingActive(); + } + } + + public void CallFailure() + { + lock (_lock) + { + _tokenCount = Math.Max(_tokenCount - 1, 0); + UpdateRetryThrottlingActive(); + } + } + + private void UpdateRetryThrottlingActive() + { + Debug.Assert(Monitor.IsEntered(_lock)); + + var newRetryThrottlingActive = _tokenCount <= _tokenThreshold; + + if (newRetryThrottlingActive != _isRetryThrottlingActive) + { + _isRetryThrottlingActive = newRetryThrottlingActive; + Log.RetryThrottlingActiveChanged(_logger, _isRetryThrottlingActive); + } + } + + private static class Log + { + private static readonly Action _retryThrottlingActiveChanged = + LoggerMessage.Define(LogLevel.Trace, new EventId(1, "RetryThrottlingActiveChanged"), "Retry throttling active state changed. 
New value: {RetryThrottlingActive}"); + + public static void RetryThrottlingActiveChanged(ILogger logger, bool isRetryThrottlingActive) + { + _retryThrottlingActiveChanged(logger, isRetryThrottlingActive, null); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs b/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs new file mode 100644 index 000000000..166e7032d --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs @@ -0,0 +1,34 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + + +namespace Grpc.Net.Client.Internal.Retry +{ + internal enum CommitReason + { + ResponseHeadersReceived, + FatalStatusCode, + ExceededAttemptCount, + DeadlineExceeded, + Throttled, + BufferExceeded, + PushbackStop, + UnexpectedError, + Canceled + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs new file mode 100644 index 000000000..e39635ffd --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs @@ -0,0 +1,422 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal sealed partial class HedgingCall : RetryCallBase + where TRequest : class + where TResponse : class + { + // Getting logger name from generic type is slow. Cached copy. + private const string LoggerName = "Grpc.Net.Client.Internal.HedgingCall"; + + private readonly HedgingPolicyInfo _hedgingPolicy; + + private CancellationTokenSource? _hedgingDelayCts; + private TaskCompletionSource? _delayInterruptTcs; + private TimeSpan? _pushbackDelay; + + // Internal for testing + internal List> _activeCalls { get; } + internal Task? 
CreateHedgingCallsTask { get; set; }
+
+        public HedgingCall(HedgingPolicyInfo hedgingPolicy, GrpcChannel channel, Method<TRequest, TResponse> method, CallOptions options)
+            : base(channel, method, options, LoggerName, hedgingPolicy.MaxAttempts)
+        {
+            _hedgingPolicy = hedgingPolicy;
+            _activeCalls = new List<IGrpcCall<TRequest, TResponse>>();
+
+            if (_hedgingPolicy.HedgingDelay > TimeSpan.Zero)
+            {
+                _delayInterruptTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
+                _hedgingDelayCts = new CancellationTokenSource();
+            }
+        }
+
+        private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
+        {
+            GrpcCall<TRequest, TResponse> call;
+            lock (Lock)
+            {
+                if (CommitedCallTask.IsCompletedSuccessfully())
+                {
+                    // Call has already been committed. This could happen if written messages exceed
+                    // buffer limits, which causes the call to immediately become committed and to clear buffers.
+                    return;
+                }
+
+                OnStartingAttempt();
+
+                call = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount);
+                _activeCalls.Add(call);
+
+                startCallFunc(call);
+
+                SetNewActiveCallUnsynchronized(call);
+            }
+
+            Status? responseStatus;
+
+            HttpResponseMessage? httpResponse = null;
+            try
+            {
+                call.CancellationToken.ThrowIfCancellationRequested();
+
+                CompatibilityExtensions.Assert(call._httpResponseTask != null, "Request should have been made if call is not preemptively cancelled.");
+                httpResponse = await call._httpResponseTask.ConfigureAwait(false);
+
+                responseStatus = GrpcCall.ValidateHeaders(httpResponse, out _);
+            }
+            catch (Exception ex)
+            {
+                call.ResolveException(GrpcCall.ErrorStartingCallMessage, ex, out responseStatus, out _);
+            }
+
+            if (CancellationTokenSource.IsCancellationRequested)
+            {
+                CommitCall(call, CommitReason.Canceled);
+                return;
+            }
+
+            // Check whether the response returned from the server makes the call committed.
+            // A null status code indicates the headers were valid and a "Response-Headers" response
+            // was received from the server.
+            // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#when-retries-are-valid
+            if (responseStatus == null)
+            {
+                // Headers were returned. We're committed.
+                CommitCall(call, CommitReason.ResponseHeadersReceived);
+
+                // Wait until the call has finished and then check its status code
+                // to update retry throttling tokens.
+                var status = await call.CallTask.ConfigureAwait(false);
+                if (status.StatusCode == StatusCode.OK)
+                {
+                    RetryAttemptCallSuccess();
+                }
+            }
+            else
+            {
+                var status = responseStatus.Value;
+
+                var retryPushbackMS = GetRetryPushback(httpResponse);
+
+                if (retryPushbackMS < 0)
+                {
+                    RetryAttemptCallFailure();
+                }
+                else if (_hedgingPolicy.NonFatalStatusCodes.Contains(status.StatusCode))
+                {
+                    // Needs to happen before interrupt.
+                    RetryAttemptCallFailure();
+
+                    // No need to interrupt if we started with no delay and all calls
+                    // have already been made when hedging started.
+                    if (_delayInterruptTcs != null)
+                    {
+                        lock (Lock)
+                        {
+                            if (retryPushbackMS >= 0)
+                            {
+                                _pushbackDelay = TimeSpan.FromMilliseconds(retryPushbackMS.GetValueOrDefault());
+                            }
+                            _delayInterruptTcs.TrySetResult(null);
+                        }
+                    }
+                }
+                else
+                {
+                    CommitCall(call, CommitReason.FatalStatusCode);
+                }
+            }
+
+            lock (Lock)
+            {
+                if (IsDeadlineExceeded())
+                {
+                    // Deadline has been exceeded, so immediately commit the call.
+                    CommitCall(call, CommitReason.DeadlineExceeded);
+                }
+                else if (_activeCalls.Count == 1 && AttemptCount >= MaxRetryAttempts)
+                {
+                    // This is the last active call and no more will be made.
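+                    // Committing selects this attempt as the one whose results flow back to
+                    // the public API; OnCommitCall below then disposes any other attempts.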
+                    CommitCall(call, CommitReason.ExceededAttemptCount);
+                }
+                else if (_activeCalls.Count == 1 && IsRetryThrottlingActive())
+                {
+                    // This is the last active call and throttling is active.
+                    CommitCall(call, CommitReason.Throttled);
+                }
+                else
+                {
+                    // Call isn't used and can be cancelled.
+                    // Note that the call could have already been removed and disposed if the
+                    // hedging call has been finalized or disposed.
+                    if (_activeCalls.Remove(call))
+                    {
+                        call.Dispose();
+                    }
+                }
+            }
+        }
+
+        protected override void OnCommitCall(IGrpcCall<TRequest, TResponse> call)
+        {
+            _activeCalls.Remove(call);
+
+            CleanUpUnsynchronized();
+        }
+
+        private void CleanUpUnsynchronized()
+        {
+            Debug.Assert(Monitor.IsEntered(Lock));
+
+            while (_activeCalls.Count > 0)
+            {
+                _activeCalls[_activeCalls.Count - 1].Dispose();
+                _activeCalls.RemoveAt(_activeCalls.Count - 1);
+            }
+        }
+
+        protected override void StartCore(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
+        {
+            var hedgingDelay = _hedgingPolicy.HedgingDelay;
+            if (hedgingDelay == TimeSpan.Zero)
+            {
+                // If there is no delay then start all calls immediately.
+                while (AttemptCount < MaxRetryAttempts)
+                {
+                    _ = StartCall(startCallFunc);
+
+                    // Don't send additional calls if retry throttling is active.
+                    if (IsRetryThrottlingActive())
+                    {
+                        Log.AdditionalCallsBlockedByRetryThrottling(Logger);
+                        break;
+                    }
+
+                    lock (Lock)
+                    {
+                        // Don't send additional calls if the call has been committed.
+                        if (CommitedCallTask.IsCompletedSuccessfully())
+                        {
+                            break;
+                        }
+                    }
+                }
+            }
+            else
+            {
+                CreateHedgingCallsTask = CreateHedgingCalls(startCallFunc);
+            }
+        }
+
+        private async Task CreateHedgingCalls(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
+        {
+            Log.StartingRetryWorker(Logger);
+
+            try
+            {
+                var hedgingDelay = _hedgingPolicy.HedgingDelay;
+
+                while (AttemptCount < MaxRetryAttempts)
+                {
+                    _ = StartCall(startCallFunc);
+
+                    await HedgingDelayAsync(hedgingDelay).ConfigureAwait(false);
+
+                    if (IsDeadlineExceeded())
+                    {
+                        CommitCall(new StatusGrpcCall<TRequest, TResponse>(new Status(StatusCode.DeadlineExceeded, string.Empty)), CommitReason.DeadlineExceeded);
+                        break;
+                    }
+                    else
+                    {
+                        lock (Lock)
+                        {
+                            if (IsRetryThrottlingActive())
+                            {
+                                if (_activeCalls.Count == 0)
+                                {
+                                    CommitCall(CreateStatusCall(GrpcProtocolConstants.ThrottledStatus), CommitReason.Throttled);
+                                }
+                                else
+                                {
+                                    Log.AdditionalCallsBlockedByRetryThrottling(Logger);
+                                }
+                                break;
+                            }
+
+                            // Don't send additional calls if the call has been committed.
+                            if (CommitedCallTask.IsCompletedSuccessfully())
+                            {
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+            catch (Exception ex)
+            {
+                HandleUnexpectedError(ex);
+            }
+            finally
+            {
+                Log.StoppingRetryWorker(Logger);
+            }
+        }
+
+        private async Task HedgingDelayAsync(TimeSpan hedgingDelay)
+        {
+            CompatibilityExtensions.Assert(_hedgingDelayCts != null);
+            CompatibilityExtensions.Assert(_delayInterruptTcs != null);
+
+            while (true)
+            {
+                CompatibilityExtensions.Assert(_hedgingDelayCts != null);
+
+                var completedTask = await Task.WhenAny(Task.Delay(hedgingDelay, _hedgingDelayCts.Token), _delayInterruptTcs.Task).ConfigureAwait(false);
+                if (completedTask != _delayInterruptTcs.Task)
+                {
+                    // Task.Delay won. Check CTS to see if it won because of cancellation.
+                    _hedgingDelayCts.Token.ThrowIfCancellationRequested();
+                    return;
+                }
+                else
+                {
+                    // Cancel the Task.Delay that's no longer needed.
+                    // https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/519ef7d231c01116f02bc04354816a735f2a36b6/AsyncGuidance.md#using-a-timeout
+                    _hedgingDelayCts.Cancel();
+                }
+
+                lock (Lock)
+                {
+                    // If we reach this point then the delay was interrupted.
+                    // Need to recreate the delay TCS/CTS for the next cycle.
+                    _delayInterruptTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
+                    _hedgingDelayCts = new CancellationTokenSource();
+
+                    // The interrupt could come from a pushback, or from a failing call with a non-fatal status.
+                    if (_pushbackDelay != null)
+                    {
+                        // Use the pushback value and delay again.
+                        hedgingDelay = _pushbackDelay.GetValueOrDefault();
+
+                        _pushbackDelay = null;
+                    }
+                    else
+                    {
+                        // Immediately return for a non-fatal status.
+                        return;
+                    }
+                }
+            }
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            lock (Lock)
+            {
+                base.Dispose(disposing);
+
+                CleanUpUnsynchronized();
+            }
+        }
+
+        public override Task ClientStreamCompleteAsync()
+        {
+            ClientStreamComplete = true;
+
+            return DoClientStreamActionAsync(calls =>
+            {
+                var completeTasks = new Task[calls.Count];
+                for (var i = 0; i < calls.Count; i++)
+                {
+                    completeTasks[i] = calls[i].ClientStreamWriter!.CompleteAsync();
+                }
+
+                return Task.WhenAll(completeTasks);
+            });
+        }
+
+        public override async Task ClientStreamWriteAsync(TRequest message)
+        {
+            // The retry client stream writer prevents multiple threads from reaching here.
+            await DoClientStreamActionAsync(calls =>
+            {
+                var writeTasks = new Task[calls.Count];
+                for (var i = 0; i < calls.Count; i++)
+                {
+                    writeTasks[i] = calls[i].WriteClientStreamAsync(WriteNewMessage, message);
+                }
+
+                return Task.WhenAll(writeTasks);
+            }).ConfigureAwait(false);
+
+            lock (Lock)
+            {
+                BufferedCurrentMessage = false;
+            }
+        }
+
+        private Task DoClientStreamActionAsync(Func<IReadOnlyList<IGrpcCall<TRequest, TResponse>>, Task> action)
+        {
+            // During a client streaming or bidirectional streaming call the app will call
+            // WriteAsync and CompleteAsync on the call request stream. If the call fails then
+            // an error will be thrown from those methods.
+            //
+            // The logic here will get the active call and apply the app action to the request stream.
+            // If there is an error we wait for the new active call and then run the user action on it again.
+            // Keep going until either the action succeeds, or there is no new active call
+            // because of exceeded attempts, a non-retryable status code or retry throttling.
+            //
+            // Because of hedging, multiple active calls can be in progress. Apply the action to all.
+
+            lock (Lock)
+            {
+                if (_activeCalls.Count > 0)
+                {
+                    return action(_activeCalls);
+                }
+                else
+                {
+                    return WaitForCallUnsynchronizedAsync(action);
+                }
+            }
+
+            async Task WaitForCallUnsynchronizedAsync(Func<IReadOnlyList<IGrpcCall<TRequest, TResponse>>, Task> action)
+            {
+                var call = await GetActiveCallUnsynchronizedAsync(previousCall: null).ConfigureAwait(false);
+                await action(new[] { call! }).ConfigureAwait(false);
+            }
+        }
+
+        protected override void OnCancellation()
+        {
+            _hedgingDelayCts?.Cancel();
+        }
+    }
+}
diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs
new file mode 100644
index 000000000..b2b52768f
--- /dev/null
+++ b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs
@@ -0,0 +1,337 @@
+#region Copyright notice and license
+
+// Copyright 2019 The gRPC Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using System;
+using System.Diagnostics;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Grpc.Core;
+
+namespace Grpc.Net.Client.Internal.Retry
+{
+    internal sealed class RetryCall<TRequest, TResponse> : RetryCallBase<TRequest, TResponse>
+        where TRequest : class
+        where TResponse : class
+    {
+        // Getting the logger name from the generic type is slow. Cached copy.
+        private const string LoggerName = "Grpc.Net.Client.Internal.RetryCall";
+
+        private readonly RetryPolicyInfo _retryPolicy;
+
+        private int _nextRetryDelayMilliseconds;
+
+        private GrpcCall<TRequest, TResponse>? _activeCall;
+
+        public RetryCall(RetryPolicyInfo retryPolicy, GrpcChannel channel, Method<TRequest, TResponse> method, CallOptions options)
+            : base(channel, method, options, LoggerName, retryPolicy.MaxAttempts)
+        {
+            _retryPolicy = retryPolicy;
+
+            _nextRetryDelayMilliseconds = Convert.ToInt32(retryPolicy.InitialBackoff.TotalMilliseconds);
+        }
+
+        private int CalculateNextRetryDelay()
+        {
+            var nextMilliseconds = _nextRetryDelayMilliseconds * _retryPolicy.BackoffMultiplier;
+            nextMilliseconds = Math.Min(nextMilliseconds, _retryPolicy.MaxBackoff.TotalMilliseconds);
+
+            return Convert.ToInt32(nextMilliseconds);
+        }
+
+        private CommitReason? EvaluateRetry(Status status, int? retryPushbackMilliseconds)
+        {
+            if (IsDeadlineExceeded())
+            {
+                return CommitReason.DeadlineExceeded;
+            }
+
+            if (IsRetryThrottlingActive())
+            {
+                return CommitReason.Throttled;
+            }
+
+            if (AttemptCount >= MaxRetryAttempts)
+            {
+                return CommitReason.ExceededAttemptCount;
+            }
+
+            if (retryPushbackMilliseconds != null)
+            {
+                if (retryPushbackMilliseconds >= 0)
+                {
+                    return null;
+                }
+                else
+                {
+                    return CommitReason.PushbackStop;
+                }
+            }
+
+            if (!_retryPolicy.RetryableStatusCodes.Contains(status.StatusCode))
+            {
+                return CommitReason.FatalStatusCode;
+            }
+
+            return null;
+        }
+
+        private async Task StartRetry(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
+        {
+            Log.StartingRetryWorker(Logger);
+
+            try
+            {
+                // This is the main retry loop. It will:
+                // 1. Check whether the result of the active call was successful.
+                // 2. If it was unsuccessful then evaluate whether the call can be retried.
+                // 3. If it can be retried then start a new active call and begin again.
+                while (true)
+                {
+                    GrpcCall<TRequest, TResponse> currentCall;
+                    lock (Lock)
+                    {
+                        // Start a new call.
+                        OnStartingAttempt();
+
+                        currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall<TRequest, TResponse>(Channel, Method, Options, AttemptCount);
+                        startCallFunc(currentCall);
+
+                        if (CommitedCallTask.IsCompletedSuccessfully())
+                        {
+                            // The call has already been committed. This could happen if written messages exceed
+                            // buffer limits, which causes the call to immediately become committed and to clear buffers.
+                            return;
+                        }
+
+                        SetNewActiveCallUnsynchronized(currentCall);
+                    }
+
+                    Status? responseStatus;
+
+                    HttpResponseMessage? httpResponse = null;
+                    try
+                    {
+                        currentCall.CancellationToken.ThrowIfCancellationRequested();
+
+                        CompatibilityExtensions.Assert(currentCall._httpResponseTask != null, "Request should have been made if call is not preemptively cancelled.");
+                        httpResponse = await currentCall._httpResponseTask.ConfigureAwait(false);
+
+                        responseStatus = GrpcCall.ValidateHeaders(httpResponse, out _);
+                    }
+                    catch (Exception ex)
+                    {
+                        currentCall.ResolveException(GrpcCall.ErrorStartingCallMessage, ex, out responseStatus, out _);
+                    }
+
+                    CancellationTokenSource.Token.ThrowIfCancellationRequested();
+
+                    // Check whether the response returned from the server commits the call.
+                    // A null status code indicates the headers were valid and a "Response-Headers" response
+                    // was received from the server.
+                    // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#when-retries-are-valid
+                    if (responseStatus == null)
+                    {
+                        // Headers were returned. We're committed.
+                        CommitCall(currentCall, CommitReason.ResponseHeadersReceived);
+
+                        responseStatus = await currentCall.CallTask.ConfigureAwait(false);
+                        if (responseStatus.GetValueOrDefault().StatusCode == StatusCode.OK)
+                        {
+                            RetryAttemptCallSuccess();
+                        }
+
+                        // Committed, so exit the retry loop.
+                        return;
+                    }
+
+                    if (CommitedCallTask.IsCompletedSuccessfully())
+                    {
+                        // The call has already been committed. This could happen if written messages exceed
+                        // buffer limits, which causes the call to immediately become committed and to clear buffers.
+                        return;
+                    }
+
+                    Status status = responseStatus.Value;
+
+                    var retryPushbackMS = GetRetryPushback(httpResponse);
+
+                    // Failures only count towards retry throttling if they have a known, retryable status.
+                    // This stops non-transient statuses, e.g. INVALID_ARGUMENT, from triggering throttling.
+                    if (_retryPolicy.RetryableStatusCodes.Contains(status.StatusCode) ||
+                        retryPushbackMS < 0)
+                    {
+                        RetryAttemptCallFailure();
+                    }
+
+                    var result = EvaluateRetry(status, retryPushbackMS);
+                    Log.RetryEvaluated(Logger, status.StatusCode, AttemptCount, result == null);
+
+                    if (result == null)
+                    {
+                        TimeSpan delayDuration;
+                        if (retryPushbackMS != null)
+                        {
+                            delayDuration = TimeSpan.FromMilliseconds(retryPushbackMS.GetValueOrDefault());
+                            _nextRetryDelayMilliseconds = retryPushbackMS.GetValueOrDefault();
+                        }
+                        else
+                        {
+                            delayDuration = TimeSpan.FromMilliseconds(Channel.GetRandomNumber(0, Convert.ToInt32(_nextRetryDelayMilliseconds)));
+                        }
+
+                        Log.StartingRetryDelay(Logger, delayDuration);
+                        await Task.Delay(delayDuration, CancellationTokenSource.Token).ConfigureAwait(false);
+
+                        _nextRetryDelayMilliseconds = CalculateNextRetryDelay();
+
+                        // Check if dispose was called on the call.
+                        CancellationTokenSource.Token.ThrowIfCancellationRequested();
+
+                        // Clean up the failed call.
+                        currentCall.Dispose();
+                    }
+                    else
+                    {
+                        // Handle the situation where the call failed with a non-deadline status, but the retry
+                        // didn't happen because of an exceeded deadline.
+                        IGrpcCall<TRequest, TResponse> resolvedCall = (IsDeadlineExceeded() && !(currentCall.CallTask.IsCompletedSuccessfully() && currentCall.CallTask.Result.StatusCode == StatusCode.DeadlineExceeded))
+                            ? CreateStatusCall(GrpcProtocolConstants.DeadlineExceededStatus)
+                            : currentCall;
+
+                        // Can't retry.
+                        // Signal public API exceptions that they should finish throwing and then exit the retry loop.
+                        CommitCall(resolvedCall, result.GetValueOrDefault());
+                        return;
+                    }
+                }
+            }
+            catch (Exception ex)
+            {
+                HandleUnexpectedError(ex);
+            }
+            finally
+            {
+                Log.StoppingRetryWorker(Logger);
+            }
+        }
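
The delay computation above follows the retry proposal: each failed attempt sleeps for a random duration up to the current backoff ceiling, and the ceiling grows by BackoffMultiplier until it reaches MaxBackoff (a server pushback overrides both). A standalone sketch of that schedule, with illustrative policy values:

using System;

public static class RetryBackoffSketch
{
    public static void Main()
    {
        // Illustrative values; in the library these come from the RetryPolicy in the service config.
        var nextBackoffMs = 1000d;         // InitialBackoff
        const double multiplier = 2.0;     // BackoffMultiplier
        const double maxBackoffMs = 8000;  // MaxBackoff
        var random = new Random();

        for (var attempt = 1; attempt <= 5; attempt++)
        {
            // Jitter: the actual delay is a random value below the current ceiling,
            // mirroring Channel.GetRandomNumber(0, _nextRetryDelayMilliseconds) above.
            var delayMs = random.Next(0, (int)nextBackoffMs);
            Console.WriteLine($"Attempt {attempt}: ceiling {nextBackoffMs}ms, delay {delayMs}ms");

            nextBackoffMs = Math.Min(nextBackoffMs * multiplier, maxBackoffMs);
        }
    }
}
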
+
+        protected override void OnCommitCall(IGrpcCall<TRequest, TResponse> call)
+        {
+            _activeCall = null;
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            lock (Lock)
+            {
+                base.Dispose(disposing);
+
+                _activeCall?.Dispose();
+            }
+        }
+
+        protected override void StartCore(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
+        {
+            _ = StartRetry(startCallFunc);
+        }
+
+        public override Task ClientStreamCompleteAsync()
+        {
+            ClientStreamComplete = true;
+
+            return DoClientStreamActionAsync(async call =>
+            {
+                await call.ClientStreamWriter!.CompleteAsync().ConfigureAwait(false);
+            });
+        }
+
+        public override Task ClientStreamWriteAsync(TRequest message)
+        {
+            // The retry client stream writer prevents multiple threads from reaching here.
+            return DoClientStreamActionAsync(async call =>
+            {
+                CompatibilityExtensions.Assert(call.ClientStreamWriter != null);
+
+                if (ClientStreamWriteOptions != null)
+                {
+                    call.ClientStreamWriter.WriteOptions = ClientStreamWriteOptions;
+                }
+
+                await call.WriteClientStreamAsync(WriteNewMessage, message).ConfigureAwait(false);
+
+                lock (Lock)
+                {
+                    BufferedCurrentMessage = false;
+                }
+
+                if (ClientStreamComplete)
+                {
+                    await call.ClientStreamWriter.CompleteAsync().ConfigureAwait(false);
+                }
+            });
+        }
+
+        private async Task DoClientStreamActionAsync(Func<IGrpcCall<TRequest, TResponse>, Task> action)
+        {
+            // During a client streaming or bidirectional streaming call the app will call
+            // WriteAsync and CompleteAsync on the call request stream. If the call fails then
+            // an error will be thrown from those methods.
+            //
+            // The logic here will get the active call and apply the app action to the request stream.
+            // If there is an error we wait for the new active call and then run the user action on it again.
+            // Keep going until either the action succeeds, or there is no new active call
+            // because of exceeded attempts, a non-retry status code or retry throttling.
+
+            var call = await GetActiveCallAsync(previousCall: null).ConfigureAwait(false);
+            while (true)
+            {
+                try
+                {
+                    await action(call!).ConfigureAwait(false);
+                    return;
+                }
+                catch
+                {
+                    call = await GetActiveCallAsync(previousCall: call).ConfigureAwait(false);
+                    if (call == null)
+                    {
+                        throw;
+                    }
+                }
+            }
+        }
+
+        private Task<IGrpcCall<TRequest, TResponse>?> GetActiveCallAsync(IGrpcCall<TRequest, TResponse>? previousCall)
+        {
+            Debug.Assert(NewActiveCallTcs != null);
+
+            lock (Lock)
+            {
+                // Return the currently active call if there is one, and it's not the previous call.
+                if (_activeCall != null && previousCall != _activeCall)
+                {
+                    return Task.FromResult<IGrpcCall<TRequest, TResponse>?>(_activeCall);
+                }
+
+                // Wait to see whether a new call will be made.
+                return GetActiveCallUnsynchronizedAsync(previousCall);
+            }
+        }
+    }
+}
diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs
new file mode 100644
index 000000000..1b5c6922e
--- /dev/null
+++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs
@@ -0,0 +1,128 @@
+#region Copyright notice and license
+
+// Copyright 2019 The gRPC Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal partial class RetryCallBase : IGrpcCall + where TRequest : class + where TResponse : class + { + protected static class Log + { + private static readonly Action _retryEvaluated = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "RetryEvaluated"), "Evaluated retry for failed gRPC call. Status code: '{StatusCode}', Attempt: {AttemptCount}, Retry: {WillRetry}"); + + private static readonly Action _retryPushbackReceived = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "RetryPushbackReceived"), "Retry pushback of '{RetryPushback}' received from the failed gRPC call."); + + private static readonly Action _startingRetryDelay = + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "StartingRetryDelay"), "Starting retry delay of {DelayDuration}."); + + private static readonly Action _errorRetryingCall = + LoggerMessage.Define(LogLevel.Error, new EventId(4, "ErrorRetryingCall"), "Error retrying gRPC call."); + + private static readonly Action _sendingBufferedMessages = + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "SendingBufferedMessages"), "Sending {MessageCount} buffered messages from previous failed gRPC calls."); + + private static readonly Action _messageAddedToBuffer = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "MessageAddedToBuffer"), "Message with {MessageSize} bytes added to the buffer. There are {CallBufferSize} bytes buffered for this call."); + + private static readonly Action _callCommited = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, "CallCommited"), "Call commited. Reason: {CommitReason}"); + + private static readonly Action _startingRetryWorker = + LoggerMessage.Define(LogLevel.Trace, new EventId(8, "StartingRetryWorker"), "Starting retry worker."); + + private static readonly Action _stoppingRetryWorker = + LoggerMessage.Define(LogLevel.Trace, new EventId(9, "StoppingRetryWorker"), "Stopping retry worker."); + + private static readonly Action _maxAttemptsLimited = + LoggerMessage.Define(LogLevel.Debug, new EventId(10, "MaxAttemptsLimited"), "The method has {ServiceConfigMaxAttempts} attempts specified in the service config. 
The number of attempts has been limited by channel configuration to {ChannelMaxAttempts}."); + + private static readonly Action _additionalCallsBlockedByRetryThrottling = + LoggerMessage.Define(LogLevel.Debug, new EventId(11, "AdditionalCallsBlockedByRetryThrottling"), "Additional calls blocked by retry throttling."); + + private static readonly Action _startingAttempt = + LoggerMessage.Define(LogLevel.Debug, new EventId(12, "StartingAttempt"), "Starting attempt {AttemptCount}."); + + internal static void RetryEvaluated(ILogger logger, StatusCode statusCode, int attemptCount, bool willRetry) + { + _retryEvaluated(logger, statusCode, attemptCount, willRetry, null); + } + + internal static void RetryPushbackReceived(ILogger logger, string retryPushback) + { + _retryPushbackReceived(logger, retryPushback, null); + } + + internal static void StartingRetryDelay(ILogger logger, TimeSpan delayDuration) + { + _startingRetryDelay(logger, delayDuration, null); + } + + internal static void ErrorRetryingCall(ILogger logger, Exception ex) + { + _errorRetryingCall(logger, ex); + } + + internal static void SendingBufferedMessages(ILogger logger, int messageCount) + { + _sendingBufferedMessages(logger, messageCount, null); + } + + internal static void MessageAddedToBuffer(ILogger logger, int messageSize, long callBufferSize) + { + _messageAddedToBuffer(logger, messageSize, callBufferSize, null); + } + + internal static void CallCommited(ILogger logger, CommitReason commitReason) + { + _callCommited(logger, commitReason, null); + } + + internal static void StartingRetryWorker(ILogger logger) + { + _startingRetryWorker(logger, null); + } + + internal static void StoppingRetryWorker(ILogger logger) + { + _stoppingRetryWorker(logger, null); + } + + internal static void MaxAttemptsLimited(ILogger logger, int serviceConfigMaxAttempts, int channelMaxAttempts) + { + _maxAttemptsLimited(logger, serviceConfigMaxAttempts, channelMaxAttempts, null); + } + + internal static void AdditionalCallsBlockedByRetryThrottling(ILogger logger) + { + _additionalCallsBlockedByRetryThrottling(logger, null); + } + + internal static void StartingAttempt(ILogger logger, int attempts) + { + _startingAttempt(logger, attempts, null); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs new file mode 100644 index 000000000..048510142 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs @@ -0,0 +1,519 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Net.Client.Internal.Http; +using Grpc.Shared; +using Microsoft.Extensions.Logging; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal.Retry +{ + internal abstract partial class RetryCallBase : IGrpcCall + where TRequest : class + where TResponse : class + { + private readonly TaskCompletionSource> _commitedCallTcs; + private RetryCallBaseClientStreamReader? _retryBaseClientStreamReader; + private RetryCallBaseClientStreamWriter? _retryBaseClientStreamWriter; + private CancellationTokenRegistration? _ctsRegistration; + + protected object Lock { get; } = new object(); + protected ILogger Logger { get; } + protected Method Method { get; } + protected CallOptions Options { get; } + protected int MaxRetryAttempts { get; } + protected CancellationTokenSource CancellationTokenSource { get; } + protected TaskCompletionSource?>? NewActiveCallTcs { get; set; } + protected bool Disposed { get; private set; } + + public GrpcChannel Channel { get; } + public Task> CommitedCallTask => _commitedCallTcs.Task; + public IAsyncStreamReader? ClientStreamReader => _retryBaseClientStreamReader ??= new RetryCallBaseClientStreamReader(this); + public IClientStreamWriter? ClientStreamWriter => _retryBaseClientStreamWriter ??= new RetryCallBaseClientStreamWriter(this); + public WriteOptions? ClientStreamWriteOptions { get; internal set; } + public bool ClientStreamComplete { get; set; } + + protected int AttemptCount { get; private set; } + protected List> BufferedMessages { get; } + protected long CurrentCallBufferSize { get; set; } + protected bool BufferedCurrentMessage { get; set; } + + protected RetryCallBase(GrpcChannel channel, Method method, CallOptions options, string loggerName, int retryAttempts) + { + Logger = channel.LoggerFactory.CreateLogger(loggerName); + Channel = channel; + Method = method; + Options = options; + _commitedCallTcs = new TaskCompletionSource>(TaskCreationOptions.RunContinuationsAsynchronously); + BufferedMessages = new List>(); + + // Raise OnCancellation event for cancellation related clean up. + CancellationTokenSource = new CancellationTokenSource(); + CancellationTokenSource.Token.Register(state => ((RetryCallBase)state!).OnCancellation(), this); + + // If the passed in token is canceled then we want to cancel the retry cancellation token. + // Note that if the token is already canceled then callback is run inline. 
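
A minimal standalone sketch of the cancellation wiring described in the comment above (not the class's own code): the caller's token and the deadline both feed one CancellationTokenSource that the retry machinery observes.

using System;
using System.Threading;

public static class RetryCancellationSketch
{
    public static void Main()
    {
        var userCts = new CancellationTokenSource();
        var retryCts = new CancellationTokenSource();

        // If the user's token is already canceled, Register runs this callback inline.
        using var registration = userCts.Token.Register(() => retryCts.Cancel());

        // A deadline becomes a timer on the same source (the real code clamps the
        // due time with GetTimerDueTime first).
        retryCts.CancelAfter(TimeSpan.FromSeconds(5));

        userCts.Cancel();
        Console.WriteLine(retryCts.IsCancellationRequested); // True
    }
}
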
+ if (options.CancellationToken.CanBeCanceled) + { + _ctsRegistration = options.CancellationToken.Register(state => ((RetryCallBase)state!).CancellationTokenSource.Cancel(), this); + } + + var deadline = Options.Deadline.GetValueOrDefault(DateTime.MaxValue); + if (deadline != DateTime.MaxValue) + { + var timeout = CommonGrpcProtocolHelpers.GetTimerDueTime(deadline - Channel.Clock.UtcNow, Channel.MaxTimerDueTime); + CancellationTokenSource.CancelAfter(TimeSpan.FromMilliseconds(timeout)); + } + + if (HasClientStream()) + { + // Run continuation synchronously so awaiters execute inside the lock + NewActiveCallTcs = new TaskCompletionSource?>(TaskCreationOptions.None); + } + + if (retryAttempts > Channel.MaxRetryAttempts) + { + Log.MaxAttemptsLimited(Logger, retryAttempts, Channel.MaxRetryAttempts.GetValueOrDefault()); + MaxRetryAttempts = Channel.MaxRetryAttempts.GetValueOrDefault(); + } + else + { + MaxRetryAttempts = retryAttempts; + } + } + + public async Task GetResponseAsync() + { + var call = await CommitedCallTask.ConfigureAwait(false); + return await call.GetResponseAsync().ConfigureAwait(false); + } + + public async Task GetResponseHeadersAsync() + { + var call = await CommitedCallTask.ConfigureAwait(false); + return await call.GetResponseHeadersAsync().ConfigureAwait(false); + } + + public Status GetStatus() + { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + return CommitedCallTask.Result.GetStatus(); + } + + throw new InvalidOperationException("Unable to get the status because the call is not complete."); + } + + public Metadata GetTrailers() + { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + return CommitedCallTask.Result.GetTrailers(); + } + + throw new InvalidOperationException("Can't get the call trailers because the call has not completed successfully."); + } + + public void Dispose() => Dispose(true); + + public void StartUnary(TRequest request) + { + StartCore(call => call.StartUnaryCore(CreatePushUnaryContent(request, call))); + } + + public void StartClientStreaming() + { + StartCore(call => + { + var clientStreamWriter = new HttpContentClientStreamWriter(call); + var content = CreatePushStreamContent(call, clientStreamWriter); + call.StartClientStreamingCore(clientStreamWriter, content); + }); + } + + public void StartServerStreaming(TRequest request) + { + StartCore(call => call.StartServerStreamingCore(CreatePushUnaryContent(request, call))); + } + + public void StartDuplexStreaming() + { + StartCore(call => + { + var clientStreamWriter = new HttpContentClientStreamWriter(call); + var content = CreatePushStreamContent(call, clientStreamWriter); + call.StartDuplexStreamingCore(clientStreamWriter, content); + }); + } + + private HttpContent CreatePushUnaryContent(TRequest request, GrpcCall call) + { + return !Channel.IsWinHttp + ? 
new PushUnaryContent<TRequest, TResponse>(request, WriteAsync)
+                : new WinHttpUnaryContent<TRequest, TResponse>(request, WriteAsync, call);
+
+            ValueTask WriteAsync(TRequest request, Stream stream)
+            {
+                return WriteNewMessage(call, stream, call.Options, request);
+            }
+        }
+
+        private PushStreamContent<TRequest, TResponse> CreatePushStreamContent(GrpcCall<TRequest, TResponse> call, HttpContentClientStreamWriter<TRequest, TResponse> clientStreamWriter)
+        {
+            return new PushStreamContent<TRequest, TResponse>(clientStreamWriter, async requestStream =>
+            {
+                ValueTask writeTask;
+                lock (Lock)
+                {
+                    Log.SendingBufferedMessages(Logger, BufferedMessages.Count);
+
+                    if (BufferedMessages.Count == 0)
+                    {
+#if NETSTANDARD2_0
+                        writeTask = Task.CompletedTask;
+#else
+                        writeTask = default;
+#endif
+                    }
+                    else if (BufferedMessages.Count == 1)
+                    {
+                        writeTask = call.WriteMessageAsync(requestStream, BufferedMessages[0], call.CancellationToken);
+                    }
+                    else
+                    {
+                        // Copy messages to a new collection in the lock for thread-safety.
+                        var bufferedMessageCopy = BufferedMessages.ToArray();
+                        writeTask = WriteBufferedMessages(call, requestStream, bufferedMessageCopy);
+                    }
+                }
+
+                await writeTask.ConfigureAwait(false);
+
+                if (ClientStreamComplete)
+                {
+                    await call.ClientStreamWriter!.CompleteAsync().ConfigureAwait(false);
+                }
+            });
+
+            static async ValueTask WriteBufferedMessages(GrpcCall<TRequest, TResponse> call, Stream requestStream, ReadOnlyMemory<byte>[] bufferedMessages)
+            {
+                foreach (var writtenMessage in bufferedMessages)
+                {
+                    await call.WriteMessageAsync(requestStream, writtenMessage, call.CancellationToken).ConfigureAwait(false);
+                }
+            }
+        }
+
+        protected abstract void StartCore(Action<GrpcCall<TRequest, TResponse>> startCallFunc);
+
+        public abstract Task ClientStreamCompleteAsync();
+
+        public abstract Task ClientStreamWriteAsync(TRequest message);
+
+        protected bool IsDeadlineExceeded()
+        {
+            return Options.Deadline != null && Options.Deadline <= Channel.Clock.UtcNow;
+        }
+
+        protected int? GetRetryPushback(HttpResponseMessage? httpResponse)
+        {
+            // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#pushback
+            if (httpResponse != null)
+            {
+                if (httpResponse.Headers.TryGetValues(GrpcProtocolConstants.RetryPushbackHeader, out var values))
+                {
+                    var headerValue = values.Single();
+                    Log.RetryPushbackReceived(Logger, headerValue);
+
+                    // A non-integer value means the server wants retries to stop.
+                    // Resolve a non-integer value to a negative integer, which also means stop.
+                    return int.TryParse(headerValue, out var value) ? value : -1;
+                }
+            }
+
+            return null;
+        }
+
+        protected byte[] SerializePayload(GrpcCall<TRequest, TResponse> call, CallOptions callOptions, TRequest request)
+        {
+            var serializationContext = call.SerializationContext;
+            serializationContext.CallOptions = callOptions;
+            serializationContext.Initialize();
+
+            try
+            {
+                call.Method.RequestMarshaller.ContextualSerializer(request, serializationContext);
+
+                // Need to take a copy because the serialization context will return a rented buffer.
+                return serializationContext.GetWrittenPayload().ToArray();
+            }
+            finally
+            {
+                serializationContext.Reset();
+            }
+        }
+
+        protected async ValueTask WriteNewMessage(GrpcCall<TRequest, TResponse> call, Stream writeStream, CallOptions callOptions, TRequest message)
+        {
+            // Serialize the current message and add it to the buffer.
+            ReadOnlyMemory<byte> messageData;
+
+            lock (Lock)
+            {
+                if (!BufferedCurrentMessage)
+                {
+                    messageData = SerializePayload(call, callOptions, message);
+
+                    // Don't buffer message data if the call has been committed.
+                    if (!CommitedCallTask.IsCompletedSuccessfully())
+                    {
+                        if (!TryAddToRetryBuffer(messageData))
+                        {
+                            CommitCall(call, CommitReason.BufferExceeded);
+                        }
+                        else
+                        {
+                            BufferedCurrentMessage = true;
+
+                            Log.MessageAddedToBuffer(Logger, messageData.Length, CurrentCallBufferSize);
+                        }
+                    }
+                }
+                else
+                {
+                    // There is a race between:
+                    // 1. A client stream starting for a new call. It will write all buffered messages, and
+                    // 2. Writing a new message here. The message may already have been buffered when the client
+                    //    stream started, so we don't want to write it again.
+                    //
+                    // Check the client stream write count against the buffered message count to ensure all buffered
+                    // messages haven't already been written.
+                    if (call.MessagesWritten == BufferedMessages.Count)
+                    {
+                        return;
+                    }
+
+                    messageData = BufferedMessages[BufferedMessages.Count - 1];
+                }
+            }
+
+            await call.WriteMessageAsync(writeStream, messageData, callOptions.CancellationToken).ConfigureAwait(false);
+        }
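
Buffering is what keeps a retry or hedge possible: once a serialized message no longer fits the per-call or per-channel budget, the call is committed with CommitReason.BufferExceeded and stops retrying. A rough sketch of that accounting (the limits here are hypothetical; the real checks are in TryAddToRetryBuffer below, fed by GrpcChannelOptions.MaxRetryBufferPerCallSize and MaxRetryBufferSize):

using System;
using System.Collections.Generic;

public sealed class RetryBufferSketch
{
    // Hypothetical limits for illustration only.
    private const long MaxPerCall = 1024 * 1024;          // 1 MB per call
    private const long MaxPerChannel = 16 * 1024 * 1024;  // 16 MB per channel

    private readonly List<ReadOnlyMemory<byte>> _bufferedMessages = new List<ReadOnlyMemory<byte>>();
    private long _callBufferSize;
    private long _channelBufferSize;

    public bool TryAdd(ReadOnlyMemory<byte> message)
    {
        if (_callBufferSize + message.Length > MaxPerCall ||
            _channelBufferSize + message.Length > MaxPerChannel)
        {
            // The caller reacts by committing the call (CommitReason.BufferExceeded).
            return false;
        }

        _callBufferSize += message.Length;
        _channelBufferSize += message.Length;
        _bufferedMessages.Add(message);
        return true;
    }
}
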
+
+        protected void CommitCall(IGrpcCall<TRequest, TResponse> call, CommitReason commitReason)
+        {
+            lock (Lock)
+            {
+                if (!CommitedCallTask.IsCompletedSuccessfully())
+                {
+                    // The buffer size is verified in unit tests after calls are completed.
+                    // Clear the buffer before committing the call.
+                    ClearRetryBuffer();
+
+                    OnCommitCall(call);
+
+                    // Log before committing for unit tests.
+                    Log.CallCommited(Logger, commitReason);
+
+                    NewActiveCallTcs?.SetResult(null);
+                    _commitedCallTcs.SetResult(call);
+                }
+            }
+        }
+
+        protected abstract void OnCommitCall(IGrpcCall<TRequest, TResponse> call);
+
+        protected bool HasClientStream()
+        {
+            return Method.Type == MethodType.ClientStreaming || Method.Type == MethodType.DuplexStreaming;
+        }
+
+        protected void SetNewActiveCallUnsynchronized(IGrpcCall<TRequest, TResponse> call)
+        {
+            Debug.Assert(!CommitedCallTask.IsCompletedSuccessfully());
+            Debug.Assert(Monitor.IsEntered(Lock));
+
+            if (NewActiveCallTcs != null)
+            {
+                // Run the continuation synchronously so awaiters execute inside the lock.
+                NewActiveCallTcs.SetResult(call);
+                NewActiveCallTcs = new TaskCompletionSource<IGrpcCall<TRequest, TResponse>?>(TaskCreationOptions.None);
+            }
+        }
+
+        Task IGrpcCall<TRequest, TResponse>.WriteClientStreamAsync<TState>(Func<GrpcCall<TRequest, TResponse>, Stream, CallOptions, TState, ValueTask> writeFunc, TState state)
+        {
+            throw new NotSupportedException();
+        }
+
+        protected async Task<IGrpcCall<TRequest, TResponse>?> GetActiveCallUnsynchronizedAsync(IGrpcCall<TRequest, TResponse>? previousCall)
+        {
+            CompatibilityExtensions.Assert(NewActiveCallTcs != null);
+
+            var call = await NewActiveCallTcs.Task.ConfigureAwait(false);
+
+            Debug.Assert(Monitor.IsEntered(Lock));
+            if (call == null)
+            {
+                call = await CommitedCallTask.ConfigureAwait(false);
+            }
+
+            // Avoid an infinite loop.
+ if (call == previousCall) + { + return null; + } + + return call; + } + + protected virtual void Dispose(bool disposing) + { + if (Disposed) + { + return; + } + + Disposed = true; + + if (disposing) + { + _ctsRegistration?.Dispose(); + CancellationTokenSource.Cancel(); + + if (CommitedCallTask.IsCompletedSuccessfully()) + { + CommitedCallTask.Result.Dispose(); + } + + ClearRetryBuffer(); + } + } + + internal bool TryAddToRetryBuffer(ReadOnlyMemory message) + { + lock (Lock) + { + var messageSize = message.Length; + if (CurrentCallBufferSize + messageSize > Channel.MaxRetryBufferPerCallSize) + { + return false; + } + if (!Channel.TryAddToRetryBuffer(messageSize)) + { + return false; + } + + CurrentCallBufferSize += messageSize; + BufferedMessages.Add(message); + return true; + } + } + + internal void ClearRetryBuffer() + { + lock (Lock) + { + if (BufferedMessages.Count > 0) + { + BufferedMessages.Clear(); + Channel.RemoveFromRetryBuffer(CurrentCallBufferSize); + CurrentCallBufferSize = 0; + } + } + } + + protected StatusGrpcCall CreateStatusCall(Status status) + { + return new StatusGrpcCall(status); + } + + protected void HandleUnexpectedError(Exception ex) + { + IGrpcCall resolvedCall; + CommitReason commitReason; + + // Cancellation token triggered by dispose could throw here. + if (ex is OperationCanceledException && CancellationTokenSource.IsCancellationRequested) + { + // Cancellation could have been caused by an exceeded deadline. + if (IsDeadlineExceeded()) + { + commitReason = CommitReason.DeadlineExceeded; + // An exceeded deadline inbetween calls means there is no active call. + // Create a fake call that returns exceeded deadline status to the app. + resolvedCall = CreateStatusCall(GrpcProtocolConstants.DeadlineExceededStatus); + } + else + { + commitReason = CommitReason.Canceled; + resolvedCall = CreateStatusCall(Disposed ? GrpcProtocolConstants.DisposeCanceledStatus : GrpcProtocolConstants.ClientCanceledStatus); + } + } + else + { + commitReason = CommitReason.UnexpectedError; + resolvedCall = CreateStatusCall(GrpcProtocolHelpers.CreateStatusFromException("Unexpected error during retry.", ex)); + + // Only log unexpected errors. + Log.ErrorRetryingCall(Logger, ex); + } + + CommitCall(resolvedCall, commitReason); + } + + protected void OnStartingAttempt() + { + Debug.Assert(Monitor.IsEntered(Lock)); + + AttemptCount++; + Log.StartingAttempt(Logger, AttemptCount); + } + + protected virtual void OnCancellation() + { + } + + protected bool IsRetryThrottlingActive() + { + return Channel.RetryThrottling?.IsRetryThrottlingActive() ?? false; + } + + protected void RetryAttemptCallSuccess() + { + Channel.RetryThrottling?.CallSuccess(); + } + + protected void RetryAttemptCallFailure() + { + Channel.RetryThrottling?.CallFailure(); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs new file mode 100644 index 000000000..46885f80f --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs @@ -0,0 +1,46 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class RetryCallBaseClientStreamReader : IAsyncStreamReader + where TRequest : class + where TResponse : class + { + private readonly RetryCallBase _retryCallBase; + + public RetryCallBaseClientStreamReader(RetryCallBase retryCallBase) + { + _retryCallBase = retryCallBase; + } + + public TResponse Current => _retryCallBase.CommitedCallTask.IsCompletedSuccessfully() + ? _retryCallBase.CommitedCallTask.Result.ClientStreamReader!.Current + : default!; + + public async Task MoveNext(CancellationToken cancellationToken) + { + var call = await _retryCallBase.CommitedCallTask.ConfigureAwait(false); + return await call.ClientStreamReader!.MoveNext(cancellationToken).ConfigureAwait(false); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs new file mode 100644 index 000000000..1f412c58f --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs @@ -0,0 +1,88 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class RetryCallBaseClientStreamWriter : ClientStreamWriterBase + where TRequest : class + where TResponse : class + { + // Getting logger name from generic type is slow + private const string LoggerName = "Grpc.Net.Client.Internal.Retry.RetryCallBaseClientStreamWriter"; + + private readonly RetryCallBase _retryCallBase; + + public RetryCallBaseClientStreamWriter(RetryCallBase retryCallBase) + : base(retryCallBase.Channel.LoggerFactory.CreateLogger(LoggerName)) + { + _retryCallBase = retryCallBase; + } + + public override WriteOptions? 
WriteOptions + { + get => _retryCallBase.ClientStreamWriteOptions; + set => _retryCallBase.ClientStreamWriteOptions = value; + } + + public override Task CompleteAsync() + { + lock (WriteLock) + { + // Pending writes need to be awaited first + if (IsWriteInProgressUnsynchronized) + { + var ex = new InvalidOperationException("Can't complete the client stream writer because the previous write is in progress."); + Log.CompleteClientStreamError(Logger, ex); + return Task.FromException(ex); + } + + return _retryCallBase.ClientStreamCompleteAsync(); + } + } + + public override Task WriteAsync(TRequest message) + { + lock (WriteLock) + { + // CompleteAsync has already been called + // Use explicit flag here. This error takes precedence over others. + if (_retryCallBase.ClientStreamComplete) + { + return CreateErrorTask("Request stream has already been completed."); + } + + // Pending writes need to be awaited first + if (IsWriteInProgressUnsynchronized) + { + return CreateErrorTask("Can't write the message because the previous write is in progress."); + } + + // Save write task to track whether it is complete. Must be set inside lock. + WriteTask = _retryCallBase.ClientStreamWriteAsync(message); + } + + return WriteTask; + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs new file mode 100644 index 000000000..e2728cec4 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs @@ -0,0 +1,135 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal.Retry +{ + internal sealed class StatusGrpcCall : IGrpcCall + where TRequest : class + where TResponse : class + { + private readonly Status _status; + private IClientStreamWriter? _clientStreamWriter; + private IAsyncStreamReader? _clientStreamReader; + + public IClientStreamWriter? ClientStreamWriter => _clientStreamWriter ??= new StatusClientStreamWriter(_status); + public IAsyncStreamReader? 
ClientStreamReader => _clientStreamReader ??= new StatusStreamReader(_status); + + public StatusGrpcCall(Status status) + { + _status = status; + } + + public void Dispose() + { + } + + public Task GetResponseAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Task GetResponseHeadersAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Status GetStatus() + { + return _status; + } + + public Metadata GetTrailers() + { + throw new InvalidOperationException("Can't get the call trailers because the call has not completed successfully."); + } + + public void StartClientStreaming() + { + throw new NotSupportedException(); + } + + public void StartDuplexStreaming() + { + throw new NotSupportedException(); + } + + public void StartServerStreaming(TRequest request) + { + throw new NotSupportedException(); + } + + public void StartUnary(TRequest request) + { + throw new NotSupportedException(); + } + + public Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { + return Task.FromException(new RpcException(_status)); + } + + private sealed class StatusClientStreamWriter : IClientStreamWriter + { + private readonly Status _status; + + public WriteOptions? WriteOptions { get; set; } + + public StatusClientStreamWriter(Status status) + { + _status = status; + } + + public Task CompleteAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Task WriteAsync(TRequest message) + { + return Task.FromException(new RpcException(_status)); + } + } + + private sealed class StatusStreamReader : IAsyncStreamReader + { + private readonly Status _status; + + public TResponse Current { get; set; } = default!; + + public StatusStreamReader(Status status) + { + _status = status; + } + + public Task MoveNext(CancellationToken cancellationToken) + { + return Task.FromException(new RpcException(_status)); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index cb5279abd..8444b9291 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -319,16 +319,15 @@ public static async ValueTask WriteMessageAsync( this Stream stream, GrpcCall call, ReadOnlyMemory data, - CallOptions callOptions) + CancellationToken cancellationToken) { - // Sync relevant changes here with other WriteMessageAsync try { GrpcCallLog.SendingMessage(call.Logger); // Sending the header+content in a single WriteAsync call has significant performance benefits // https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981 - await stream.WriteAsync(data, callOptions.CancellationToken).ConfigureAwait(false); + await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); GrpcCallLog.MessageSent(call.Logger); } diff --git a/src/Shared/CommonGrpcProtocolHelpers.cs b/src/Shared/CommonGrpcProtocolHelpers.cs index 787e17f1c..b5cce6059 100644 --- a/src/Shared/CommonGrpcProtocolHelpers.cs +++ b/src/Shared/CommonGrpcProtocolHelpers.cs @@ -31,7 +31,7 @@ internal static class CommonGrpcProtocolHelpers // - The timer is rescheduled to run in 0.5ms. // - The deadline callback is raised again and there is now 0.4ms until deadline. // - The timer is rescheduled to run in 0.4ms, etc. 
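
The hunk below raises that epsilon from 4ms to 7ms. For context, a sketch of the due-time computation it feeds into, reconstructed from the surrounding lines (the zero-clamp is an assumption; only the epsilon add and the Math.Min cap are visible in this hunk):

using System;

public static class TimerDueTimeSketch
{
    private const int TimerEpsilonMilliseconds = 7;

    public static long GetTimerDueTime(TimeSpan timeout, long maxTimerDueTime)
    {
        // Assumed: negative timeouts clamp to zero before padding.
        var dueTimeMilliseconds = Math.Max((long)timeout.TotalMilliseconds, 0);

        // Pad by the epsilon so the timer isn't rescheduled repeatedly just
        // short of the deadline, then cap at the maximum timer due time.
        dueTimeMilliseconds += TimerEpsilonMilliseconds;
        return Math.Min(dueTimeMilliseconds, maxTimerDueTime);
    }

    public static void Main()
    {
        // A 0.5ms remainder is padded past Timer precision: prints 7.
        Console.WriteLine(GetTimerDueTime(TimeSpan.FromMilliseconds(0.5), maxTimerDueTime: uint.MaxValue - 1));
    }
}
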
- private static readonly int TimerEpsilonMilliseconds = 4; + private static readonly int TimerEpsilonMilliseconds = 7; public static long GetTimerDueTime(TimeSpan timeout, long maxTimerDueTime) { @@ -41,7 +41,7 @@ public static long GetTimerDueTime(TimeSpan timeout, long maxTimerDueTime) // Add epislon to take into account Timer precision. // This will avoid rescheduling the timer multiple times, but means deadline - // might run for some extra milliseconds + // might run slightly longer than requested. dueTimeMilliseconds += TimerEpsilonMilliseconds; dueTimeMilliseconds = Math.Min(dueTimeMilliseconds, maxTimerDueTime); diff --git a/test/FunctionalTests/Client/CancellationTests.cs b/test/FunctionalTests/Client/CancellationTests.cs index 5f49663e5..ceb4319f8 100644 --- a/test/FunctionalTests/Client/CancellationTests.cs +++ b/test/FunctionalTests/Client/CancellationTests.cs @@ -107,7 +107,7 @@ await TestHelpers.RunParallel(tasks, async taskIndex => { try { - for (int i = 0; i < interations; i++) + for (var i = 0; i < interations; i++) { Logger.LogInformation($"Staring {taskIndex}-{i}"); diff --git a/test/FunctionalTests/Client/EventSourceTests.cs b/test/FunctionalTests/Client/EventSourceTests.cs index 6127e6e53..c40f197ab 100644 --- a/test/FunctionalTests/Client/EventSourceTests.cs +++ b/test/FunctionalTests/Client/EventSourceTests.cs @@ -172,7 +172,7 @@ async Task UnaryError(HelloRequest request, ServerCallContext contex public async Task UnaryMethod_DeadlineExceededCall_PollingCountersUpdatedCorrectly() { // Loop to ensure test is resilent across multiple runs - for (int i = 1; i < 3; i++) + for (var i = 1; i < 3; i++) { var syncPoint = new SyncPoint(); @@ -248,7 +248,7 @@ async Task UnaryDeadlineExceeded(HelloRequest request, ServerCallCon public async Task UnaryMethod_CancelCall_PollingCountersUpdatedCorrectly() { // Loop to ensure test is resilent across multiple runs - for (int i = 1; i < 3; i++) + for (var i = 1; i < 3; i++) { var syncPoint = new SyncPoint(); var cts = new CancellationTokenSource(); diff --git a/test/FunctionalTests/Client/HedgingTests.cs b/test/FunctionalTests/Client/HedgingTests.cs new file mode 100644 index 000000000..629cd28ba --- /dev/null +++ b/test/FunctionalTests/Client/HedgingTests.cs @@ -0,0 +1,593 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.AspNetCore.FunctionalTests.Infrastructure; +using Grpc.Core; +using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; +using Grpc.Net.Client.Internal; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using Streaming; + +namespace Grpc.AspNetCore.FunctionalTests.Client +{ + [TestFixture] + public class HedgingTests : FunctionalTestBase + { + [TestCase(0)] + [TestCase(20)] + public async Task Unary_ExceedAttempts_Failure(int hedgingDelay) + { + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Ignore errors + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelay))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [Test] + public async Task Duplex_ManyParallelRequests_MessageRoundTripped() + { + const string ImportantMessage = +@" _____ _____ _____ + | __ \| __ \ / ____| + __ _| |__) | |__) | | + / _` | _ /| ___/| | + | (_| | | \ \| | | |____ + \__, |_| \_\_| \_____| + __/ | + |___/ + _ + (_) + _ ___ + | / __| + | \__ \ _ + |_|___/ | | + ___ ___ ___ | | + / __/ _ \ / _ \| | + | (_| (_) | (_) | | + \___\___/ \___/|_| + + "; + + var attempts = 100; + var allUploads = new List(); + var allCompletedTasks = new List(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + async Task MessageUpload( + IAsyncStreamReader requestStream, + IServerStreamWriter responseStream, + ServerCallContext context) + { + // Receive chunks + var chunks = new List(); + await foreach (var chunk in requestStream.ReadAllAsync()) + { + chunks.Add(chunk.Value); + } + + Task completeTask; + lock (allUploads) + { + allUploads.Add(string.Join(Environment.NewLine, chunks)); + if (allUploads.Count < attempts) + { + // Check that unused calls are canceled. + completeTask = Task.Run(async () => + { + await tcs.Task; + + var cancellationTcs = new TaskCompletionSource(); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), cancellationTcs); + await cancellationTcs.Task; + }); + } + else + { + // Write response in used call. 
+ completeTask = Task.Run(async () => + { + // Write chunks + foreach (var chunk in chunks) + { + await responseStream.WriteAsync(new StringValue + { + Value = chunk + }); + } + }); + } + } + + await completeTask; + } + + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(MessageUpload); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 100, hedgingDelay: TimeSpan.Zero), maxRetryAttempts: 100); + + var client = TestClientFactory.Create(channel, method); + + using var call = client.DuplexStreamingCall(); + + var lines = ImportantMessage.Split(Environment.NewLine); + for (var i = 0; i < lines.Length; i++) + { + await call.RequestStream.WriteAsync(new StringValue { Value = lines[i] }).DefaultTimeout(); + await Task.Delay(TimeSpan.FromSeconds(0.01)).DefaultTimeout(); + } + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + await TestHelpers.AssertIsTrueRetryAsync(() => allUploads.Count == 100, "Wait for all calls to reach server.").DefaultTimeout(); + tcs.SetResult(null); + + var receivedLines = new List(); + await foreach (var line in call.ResponseStream.ReadAllAsync().DefaultTimeout()) + { + receivedLines.Add(line.Value); + } + + Assert.AreEqual(ImportantMessage, string.Join(Environment.NewLine, receivedLines)); + + foreach (var upload in allUploads) + { + Assert.AreEqual(ImportantMessage, upload); + } + + await Task.WhenAll(allCompletedTasks).DefaultTimeout(); + } + + [TestCase(1)] + [TestCase(2)] + public async Task Unary_DeadlineExceedAfterServerCall_Failure(int exceptedServerCallCount) + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + if (callCount < exceptedServerCallCount) + { + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(200))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Unary_DeadlineExceedDuringDelay_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""), new Metadata + { + new Metadata.Entry(GrpcProtocolConstants.RetryPushbackHeader, TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromSeconds(10), + nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, 
method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + } + + [Test] + public async Task Duplex_DeadlineExceedDuringDelay_Failure() + { + var callCount = 0; + Task DuplexDeadlineExceeded(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""), new Metadata + { + new Metadata.Entry(GrpcProtocolConstants.RetryPushbackHeader, TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(DuplexDeadlineExceeded); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromSeconds(10), + nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.DuplexStreamingCall(new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseStream.MoveNext(CancellationToken.None)).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Unary_DeadlineExceedBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. 
Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_CanceledBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(cancellationToken: cts.Token)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Canceled"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_TriggerRetryThrottling_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromSeconds(10), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(3, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. 
Reason: Throttled"); + } + + [TestCase(0)] + [TestCase(100)] + public async Task Unary_RetryThrottlingAlreadyActive_Failure(int hedgingDelayMilliseconds) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMilliseconds), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + // Manually trigger retry throttling + Debug.Assert(channel.RetryThrottling != null); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + Debug.Assert(channel.RetryThrottling.IsRetryThrottlingActive()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Throttled"); + } + + [Test] + public async Task Unary_RetryThrottlingBecomesActive_HasDelay_Failure() + { + var callCount = 0; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + async Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + await syncPoint.WaitToContinue(); + return request; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(100), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + await syncPoint.WaitForSyncPoint().DefaultTimeout(); + + // Manually trigger retry throttling + Debug.Assert(channel.RetryThrottling != null); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + Debug.Assert(channel.RetryThrottling.IsRetryThrottlingActive()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => HasLog(LogLevel.Debug, "AdditionalCallsBlockedByRetryThrottling", "Additional calls blocked by retry throttling."), "Check for expected log."); + + Assert.AreEqual(1, callCount); + syncPoint.Continue(); + + await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. 
Reason: ResponseHeadersReceived"); + } + + [TestCase(0)] + [TestCase(20)] + public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay) + { + var callCount = 0; + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException<DataMessage>(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 10, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelay))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + + AssertHasLog(LogLevel.Debug, "MaxAttemptsLimited", "The method has 10 attempts specified in the service config. The number of attempts has been limited by channel configuration to 5."); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [TestCase(0, false, 0)] + [TestCase(0, false, 1)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false, 0)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false, 1)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true, 0)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true, 1)] + public async Task Unary_LargeMessages_ExceedPerCallBufferSize(long payloadSize, bool exceedBufferLimit, int hedgingDelayMilliseconds) + { + var callCount = 0; + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException<DataMessage>(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Ignore errors + SetExpectedErrorsFilter(writeContext => + { + if (writeContext.EventId.Name == "ErrorSendingMessage" || + writeContext.EventId.Name == "ErrorExecutingServiceMethod") + { + return true; + } + + return false; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMilliseconds))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage + { + Data = ByteString.CopyFrom(new byte[payloadSize]) + }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + if (!exceedBufferLimit) + { + Assert.AreEqual(5, callCount); + } + else + { + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: BufferExceeded"); + + // Cancelled calls could cause server errors. Delay so these errors don't show up + // in the next unit test.
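For the BufferExceeded cases above, the rule being exercised is simple: a message is only kept for replay if it still fits in the per-call retry buffer. A minimal sketch of that check, using hypothetical names rather than the client's internal types:

    // Hypothetical sketch of the per-call buffering decision; the real bookkeeping
    // lives inside the client's retry call implementation.
    static bool CanBufferForReplay(long bufferedByCallBytes, long messageLength, long maxPerCallSize)
    {
        // When the next message would push the call over its limit, the call is
        // committed to the current attempt (CommitReason.BufferExceeded) and the
        // message is sent without being retained for replay.
        return bufferedByCallBytes + messageLength <= maxPerCallSize;
    }

The short delay that follows gives the server time to finish failing the abandoned attempts: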
+ await Task.Delay(100); + } + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + + } + } +} diff --git a/test/FunctionalTests/Client/RetryTests.cs b/test/FunctionalTests/Client/RetryTests.cs new file mode 100644 index 000000000..4754472ad --- /dev/null +++ b/test/FunctionalTests/Client/RetryTests.cs @@ -0,0 +1,596 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Grpc.AspNetCore.FunctionalTests.Infrastructure; +using Grpc.Core; +using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using Streaming; + +namespace Grpc.AspNetCore.FunctionalTests.Client +{ + [TestFixture] + public class RetryTests : FunctionalTestBase + { + [Test] + public async Task ClientStreaming_MultipleWritesAndRetries_Failure() + { + var nextFailure = 1; + + async Task ClientStreamingWithReadFailures(IAsyncStreamReader requestStream, ServerCallContext context) + { + List bytes = new List(); + await foreach (var message in requestStream.ReadAllAsync()) + { + if (bytes.Count >= nextFailure) + { + nextFailure = nextFailure * 2; + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + + bytes.Add(message.Data[0]); + } + + return new DataMessage + { + Data = ByteString.CopyFrom(bytes.ToArray()) + }; + } + + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddClientStreamingMethod(ClientStreamingWithReadFailures); + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10), maxRetryAttempts: 10); + var client = TestClientFactory.Create(channel, method); + var sentData = new List(); + + // Act + var call = client.ClientStreamingCall(); + + for (var i = 0; i < 20; i++) + { + sentData.Add((byte)i); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { (byte)i }) }).DefaultTimeout(); + await Task.Delay(1); + } + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var result = await call.ResponseAsync.DefaultTimeout(); + + // Assert + Assert.IsTrue(result.Data.Span.SequenceEqual(sentData.ToArray())); + } + + [Test] + public async Task Unary_ExceedRetryAttempts_Failure() + { + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + var metadata = new Metadata(); + metadata.Add("grpc-retry-pushback-ms", "5"); + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), metadata)); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = 
client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "RetryPushbackReceived", "Retry pushback of '5' received from the failed gRPC call."); + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 1, Retry: True"); + AssertHasLog(LogLevel.Trace, "StartingRetryDelay", "Starting retry delay of 00:00:00.0050000."); + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 5, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [Test] + public async Task Unary_TriggerRetryThrottling_Failure() + { + var callCount = 0; + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException<DataMessage>(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig( + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 3, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Throttled"); + } + + [TestCase(1)] + [TestCase(2)] + public async Task Unary_DeadlineExceedAfterServerCall_Failure(int expectedServerCallCount) + { + var callCount = 0; + var tcs = new TaskCompletionSource<DataMessage>(TaskCreationOptions.RunContinuationsAsynchronously); + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + if (callCount < expectedServerCallCount) + { + return Task.FromException<DataMessage>(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List<StatusCode> { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(200))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(expectedServerCallCount, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited.
Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_DeadlineExceedDuringBackoff_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), new Metadata + { + new Metadata.Entry("grpc-retry-pushback-ms", TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig( + initialBackoff: TimeSpan.FromSeconds(10), + maxBackoff: TimeSpan.FromSeconds(10), + retryableStatusCodes: new List { StatusCode.Unavailable }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(500))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Duplex_DeadlineExceedDuringBackoff_Failure() + { + var callCount = 0; + Task DuplexDeadlineExceeded(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), new Metadata + { + new Metadata.Entry("grpc-retry-pushback-ms", TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(DuplexDeadlineExceeded); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig( + initialBackoff: TimeSpan.FromSeconds(10), + maxBackoff: TimeSpan.FromSeconds(10), + retryableStatusCodes: new List { StatusCode.Unavailable }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.DuplexStreamingCall(new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseStream.MoveNext(CancellationToken.None)).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Unary_DeadlineExceedBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); 
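The ServiceConfigHelpers methods used throughout these tests wrap the same public configuration model an application would set on the channel directly. A representative retry configuration (values illustrative, not the helper's exact defaults):

    var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
    {
        ServiceConfig = new ServiceConfig
        {
            MethodConfigs =
            {
                new MethodConfig
                {
                    Names = { MethodName.Default },
                    RetryPolicy = new RetryPolicy
                    {
                        MaxAttempts = 5,
                        InitialBackoff = TimeSpan.FromSeconds(1),
                        MaxBackoff = TimeSpan.FromSeconds(5),
                        BackoffMultiplier = 1.5,
                        RetryableStatusCodes = { StatusCode.Unavailable }
                    }
                }
            }
        }
    });

MethodName.Default applies the policy to every method on the channel; the deadline tests here simply swap DeadlineExceeded in as the retryable status code.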
+ + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'DeadlineExceeded', Attempt: 1, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_CanceledBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource<DataMessage>(TaskCreationOptions.RunContinuationsAsynchronously); + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List<StatusCode> { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(cancellationToken: cts.Token)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Canceled"); + + tcs.SetResult(new DataMessage()); + } + + [TestCase(1)] + [TestCase(20)] + public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int initialBackoffMilliseconds) + { + var callCount = 0; + Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException<DataMessage>(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10, initialBackoff: TimeSpan.FromMilliseconds(initialBackoffMilliseconds))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + + AssertHasLog(LogLevel.Debug, "MaxAttemptsLimited", "The method has 10 attempts specified in the service config. The number of attempts has been limited by channel configuration to 5."); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited.
Reason: ExceededAttemptCount"); + } + + [TestCase(0, false)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true)] + public async Task Unary_LargeMessages_ExceedPerCallBufferSize(long payloadSize, bool exceedBufferLimit) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage + { + Data = ByteString.CopyFrom(new byte[payloadSize]) + }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + if (!exceedBufferLimit) + { + Assert.AreEqual(5, callCount); + } + else + { + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: BufferExceeded"); + } + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + } + + [Test] + public async Task Unary_MultipleLargeMessages_ExceedChannelMaxBufferSize() + { + // Arrange + var sp1 = new SyncPoint(runContinuationsAsynchronously: true); + var sp2 = new SyncPoint(runContinuationsAsynchronously: true); + var sp3 = new SyncPoint(runContinuationsAsynchronously: true); + var channel = CreateChannel( + serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(), + maxRetryBufferSize: 200, + maxRetryBufferPerCallSize: 100); + + var request = new DataMessage { Data = ByteString.CopyFrom(new byte[90]) }; + + // Act + var call1Task = MakeCall(Fixture, channel, request, sp1); + await sp1.WaitForSyncPoint(); + + var call2Task = MakeCall(Fixture, channel, request, sp2); + await sp2.WaitForSyncPoint(); + + // Will exceed channel buffer limit and won't retry + var call3Task = MakeCall(Fixture, channel, request, sp3); + await sp3.WaitForSyncPoint(); + + // Assert + Assert.AreEqual(194, channel.CurrentRetryBufferSize); + + sp1.Continue(); + sp2.Continue(); + sp3.Continue(); + + var response = await call1Task.DefaultTimeout(); + Assert.AreEqual(90, response.Data.Length); + + response = await call2Task.DefaultTimeout(); + Assert.AreEqual(90, response.Data.Length); + + // Can't retry because buffer size exceeded. 
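Both limits in play here are ordinary channel options that the test's CreateChannel overload forwards. Configured directly (the tiny sizes below mirror the test and exist only to force the limit):

    var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
    {
        ServiceConfig = serviceConfig,   // retry policy as above
        MaxRetryBufferSize = 200,        // total bytes buffered for retries across the channel
        MaxRetryBufferPerCallSize = 100  // bytes buffered for any single call
    });

With the first two requests holding 194 buffered bytes between them, the third call cannot reserve buffer space, so it is committed to its first attempt and fails without retrying: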
+ var ex = await ExceptionAssert.ThrowsAsync(() => call3Task).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + + static Task MakeCall(GrpcTestFixture fixture, GrpcChannel channel, DataMessage request, SyncPoint syncPoint) + { + var callCount = 0; + async Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + if (callCount == 1) + { + await syncPoint.WaitToContinue(); + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + else + { + return request; + } + } + + // Arrange + var method = fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var client = TestClientFactory.Create(channel, method); + + var call = client.UnaryCall(request); + + return call.ResponseAsync; + } + } + + [Test] + public async Task ClientStreaming_MultipleWritesExceedPerCallLimit_Failure() + { + var nextFailure = 2; + var callCount = 0; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + async Task ClientStreamingWithReadFailures(IAsyncStreamReader requestStream, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + + List bytes = new List(); + await foreach (var message in requestStream.ReadAllAsync()) + { + bytes.Add(message.Data[0]); + + Logger.LogInformation($"Current count: {bytes.Count}, next failure: {nextFailure}."); + + if (bytes.Count >= nextFailure) + { + await syncPoint.WaitToContinue(); + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + } + + return new DataMessage + { + Data = ByteString.CopyFrom(bytes.ToArray()) + }; + } + + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddClientStreamingMethod(ClientStreamingWithReadFailures); + var channel = CreateChannel( + serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10), + maxRetryAttempts: 10, + maxRetryBufferPerCallSize: 100); + var client = TestClientFactory.Create(channel, method); + var sentData = new List(); + + // Act + var call = client.ClientStreamingCall(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + await syncPoint.WaitForSyncPoint(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + var s = syncPoint; + syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + nextFailure = 15; + s.Continue(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await 
call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + Assert.AreEqual(96, channel.CurrentRetryBufferSize); + + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == 2, "Wait for server to have second call.").DefaultTimeout(); + + // This message exceeds the buffer size. Call is commited here. + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + await syncPoint.WaitForSyncPoint(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + s = syncPoint; + syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + nextFailure = int.MaxValue; + s.Continue(); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(2, callCount); + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + } + } +} diff --git a/test/FunctionalTests/Client/StreamingTests.cs b/test/FunctionalTests/Client/StreamingTests.cs index ef1227bdd..0f6fb0a4c 100644 --- a/test/FunctionalTests/Client/StreamingTests.cs +++ b/test/FunctionalTests/Client/StreamingTests.cs @@ -350,11 +350,35 @@ writeContext.Exception is InvalidOperationException && return false; }); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + async Task ClientStreamedData(IAsyncStreamReader requestStream, ServerCallContext context) + { + context.CancellationToken.Register(() => tcs.SetResult(null)); + + var total = 0L; + await foreach (var message in requestStream.ReadAllAsync()) + { + total += message.Data.Length; + + if (message.ServerDelayMilliseconds > 0) + { + await Task.Delay(message.ServerDelayMilliseconds); + } + } + + return new DataComplete + { + Size = total + }; + } + // Arrange - var data = CreateTestData(1024 * 64); // 64 KB + var data = CreateTestData(1024); // 1 KB + + var method = Fixture.DynamicGrpc.AddClientStreamingMethod(ClientStreamedData, "ClientStreamedDataTimeout"); var httpClient = Fixture.CreateClient(); - httpClient.Timeout = TimeSpan.FromSeconds(0.5); + httpClient.Timeout = TimeSpan.FromSeconds(0.3); var channel = GrpcChannel.ForAddress(httpClient.BaseAddress!, new GrpcChannelOptions { @@ -362,30 +386,31 @@ writeContext.Exception is InvalidOperationException && LoggerFactory = LoggerFactory }); - var client = new StreamService.StreamServiceClient(channel); + var client = TestClientFactory.Create(channel, method); + var dataMessage = new DataMessage { Data = ByteString.CopyFrom(data) }; // Act - var call = client.ClientStreamedData(); + var call = client.ClientStreamingCall(); - var ex = await ExceptionAssert.ThrowsAsync(async () => - { - while (true) - { - await call.RequestStream.WriteAsync(dataMessage).DefaultTimeout(); + await call.RequestStream.WriteAsync(dataMessage).DefaultTimeout(); - await Task.Delay(100); - } - }).DefaultTimeout(); + await 
tcs.Task.DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(dataMessage)).DefaultTimeout(); // Assert Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); AssertHasLog(LogLevel.Information, "GrpcStatusError", "Call failed with gRPC error status. Status code: 'Cancelled', Message: ''."); + + await TestHelpers.AssertIsTrueRetryAsync( + () => HasLog(LogLevel.Error, "ErrorExecutingServiceMethod", "Error when executing service method 'ClientStreamedDataTimeout'."), + "Wait for server error so it doesn't impact other tests.").DefaultTimeout(); } [Test] @@ -539,12 +564,11 @@ public async Task ServerStreaming_WriteAfterMethodCancelled_Error(bool writeBefo }); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - + var writeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); - Task? writeTask = null; async Task ServerStreamingWithTrailers(DataMessage request, IServerStreamWriter responseStream, ServerCallContext context) { - writeTask = Task.Run(async () => + var writeTask = Task.Run(async () => { if (writeBeforeExit) { @@ -557,6 +581,7 @@ async Task ServerStreamingWithTrailers(DataMessage request, IServerStreamWriter< await responseStream.WriteAsync(new DataMessage()); }); + writeTcs.SetResult(writeTask); await tcs.Task; } @@ -581,7 +606,8 @@ async Task ServerStreamingWithTrailers(DataMessage request, IServerStreamWriter< syncPoint.Continue(); - var serverException = await ExceptionAssert.ThrowsAsync(() => writeTask!).DefaultTimeout(); + var writeTask = await writeTcs.Task.DefaultTimeout(); + var serverException = await ExceptionAssert.ThrowsAsync(() => writeTask).DefaultTimeout(); Assert.AreEqual("Can't write the message because the request is complete.", serverException.Message); // Ensure the server abort reaches the client diff --git a/test/FunctionalTests/FunctionalTestBase.cs b/test/FunctionalTests/FunctionalTestBase.cs index ebe17a2e9..88c5e189e 100644 --- a/test/FunctionalTests/FunctionalTestBase.cs +++ b/test/FunctionalTests/FunctionalTestBase.cs @@ -22,6 +22,7 @@ using Grpc.AspNetCore.FunctionalTests.Infrastructure; using Grpc.Core; using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using NUnit.Framework; @@ -41,12 +42,26 @@ public class FunctionalTestBase protected GrpcChannel Channel => _channel ??= CreateChannel(); - protected GrpcChannel CreateChannel(bool useHandler = false) + protected GrpcChannel CreateChannel(bool useHandler = false, ServiceConfig? serviceConfig = null, int? maxRetryAttempts = null, long? maxRetryBufferSize = null, long? 
maxRetryBufferPerCallSize = null) { var options = new GrpcChannelOptions { - LoggerFactory = LoggerFactory + LoggerFactory = LoggerFactory, + ServiceConfig = serviceConfig }; + // Don't overwrite defaults + if (maxRetryAttempts != null) + { + options.MaxRetryAttempts = maxRetryAttempts; + } + if (maxRetryBufferSize != null) + { + options.MaxRetryBufferSize = maxRetryBufferSize; + } + if (maxRetryBufferPerCallSize != null) + { + options.MaxRetryBufferPerCallSize = maxRetryBufferPerCallSize; + } if (useHandler) { options.HttpHandler = Fixture.Handler; diff --git a/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj b/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj index dab69c71f..78326bf26 100644 --- a/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj +++ b/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj @@ -7,6 +7,7 @@ + diff --git a/test/FunctionalTests/Server/ClientStreamingMethodTests.cs b/test/FunctionalTests/Server/ClientStreamingMethodTests.cs index 43ff8a5f6..3e6b2fdbc 100644 --- a/test/FunctionalTests/Server/ClientStreamingMethodTests.cs +++ b/test/FunctionalTests/Server/ClientStreamingMethodTests.cs @@ -201,7 +201,7 @@ static async Task AccumulateCount(IAsyncStreamReader { - for (int i = 0; i < 10; i++) + for (var i = 0; i < 10; i++) { await s.WriteAsync(ms.ToArray()).AsTask().DefaultTimeout(); await s.FlushAsync().DefaultTimeout(); diff --git a/test/FunctionalTests/Server/DeadlineTests.cs b/test/FunctionalTests/Server/DeadlineTests.cs index feffdfa16..4ea56b347 100644 --- a/test/FunctionalTests/Server/DeadlineTests.cs +++ b/test/FunctionalTests/Server/DeadlineTests.cs @@ -251,15 +251,18 @@ public async Task WriteMessageAfterDeadline() { static async Task WriteUntilError(HelloRequest request, IServerStreamWriter responseStream, ServerCallContext context) { - var i = 0; - while (true) + for (var i = 0; i < 5; i++) { var message = $"How are you {request.Name}? {i}"; await responseStream.WriteAsync(new HelloReply { Message = message }).DefaultTimeout(); - i++; - await Task.Delay(10); } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), tcs); + await tcs.Task; + + await responseStream.WriteAsync(new HelloReply { Message = "Write after deadline" }).DefaultTimeout(); } // Arrange diff --git a/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs b/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs index 9674f3ac0..4cd7d7c95 100644 --- a/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs @@ -130,7 +130,7 @@ public async Task ReadAsync_ByteAtATime_Success() Assert.IsFalse(resultTask.IsCompleted); - for (int i = 0; i < base64Data.Length; i++) + for (var i = 0; i < base64Data.Length; i++) { await testPipe.Writer.WriteAsync(base64Data.AsMemory(i, 1)); await Task.Delay(10); diff --git a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs index 342cc6c09..a1b72feb7 100644 --- a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs @@ -43,10 +43,12 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() { // Arrange HttpRequestMessage? httpRequestMessage = null; + long? 
requestContentLength = null; var httpClient = ClientTestHelpers.CreateTestClient(async request => { httpRequestMessage = request; + requestContentLength = httpRequestMessage!.Content!.Headers!.ContentLength; HelloReply reply = new HelloReply { @@ -72,6 +74,7 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() Assert.AreEqual(new MediaTypeHeaderValue("application/grpc"), httpRequestMessage.Content?.Headers?.ContentType); Assert.AreEqual(GrpcProtocolConstants.TEHeaderValue, httpRequestMessage.Headers.TE.Single().Value); Assert.AreEqual("identity,gzip", httpRequestMessage.Headers.GetValues(GrpcProtocolConstants.MessageAcceptEncodingHeader).Single()); + Assert.AreEqual(null, requestContentLength); var userAgent = httpRequestMessage.Headers.UserAgent.Single()!; Assert.AreEqual("grpc-dotnet", userAgent.Product?.Name); @@ -83,6 +86,41 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() Assert.IsTrue(userAgent.Product!.Version!.Length <= 10); } + [Test] + public async Task AsyncUnaryCall_HasWinHttpHandler_ContentLengthOnHttpRequestMessagePopulated() + { + // Arrange + HttpRequestMessage? httpRequestMessage = null; + long? requestContentLength = null; + + var handler = TestHttpMessageHandler.Create(async request => + { + httpRequestMessage = request; + requestContentLength = httpRequestMessage!.Content!.Headers!.ContentLength; + + HelloReply reply = new HelloReply + { + Message = "Hello world" + }; + + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + // Just need to have a type called WinHttpHandler to activate new behavior. + var winHttpHandler = new WinHttpHandler(handler); + var invoker = HttpClientCallInvokerFactory.Create(winHttpHandler, "https://localhost"); + + // Act + var rs = await invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "Hello world" }); + + // Assert + Assert.AreEqual("Hello world", rs.Message); + + Assert.IsNotNull(httpRequestMessage); + Assert.AreEqual(18, requestContentLength); + } + [Test] public async Task AsyncUnaryCall_Success_RequestContentSent() { @@ -126,7 +164,7 @@ public async Task AsyncUnaryCall_Success_RequestContentSent() } [Test] - public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ReturnHeaders() + public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ThrowRpcError() { // Arrange var httpClient = ClientTestHelpers.CreateTestClient(request => @@ -144,7 +182,7 @@ public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ReturnHeaders } [Test] - public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessHeaders_ThrowRpcError() + public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessHeaders_ReturnHeaders() { // Arrange var httpClient = ClientTestHelpers.CreateTestClient(request => diff --git a/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj b/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj index 7cc960a6e..91082a74a 100644 --- a/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj +++ b/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj @@ -12,12 +12,14 @@ + + diff --git a/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs b/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs index a952a5333..bc357d3fa 100644 --- a/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs +++ b/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs @@ -23,6 +23,7 @@ using Greet; using Grpc.Core; using 
Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Net.Client.Configuration; using Grpc.Tests.Shared; using NUnit.Framework; @@ -168,6 +169,26 @@ public void Build_NoHttpProviderOnNetFx_Throw() } #endif + [Test] + public void Build_ServiceConfigDuplicateMethodConfigNames_Error() + { + // Arrange + var options = CreateGrpcChannelOptions(o => o.ServiceConfig = new ServiceConfig + { + MethodConfigs = + { + new MethodConfig { Names = { MethodName.Default } }, + new MethodConfig { Names = { MethodName.Default } } + } + }); + + // Act + var ex = Assert.Throws(() => GrpcChannel.ForAddress("https://localhost", options)); + + // Assert + Assert.AreEqual("Duplicate method config found. Service: '', method: ''.", ex.Message); + } + [Test] public void Dispose_NotCalled_NotDisposed() { diff --git a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs index 20fe713d2..e818a599c 100644 --- a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs +++ b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs @@ -234,9 +234,10 @@ private static GrpcCall CreateGrpcCall(GrpcChannel cha return new GrpcCall( ClientTestHelpers.ServiceMethod, - new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri), + new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri, methodConfig: null), new CallOptions(), - channel); + channel, + attemptCount: 0); } private static GrpcChannel CreateChannel(HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool? throwOperationCanceledOnCancellation = null) diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs b/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs index 8276fd622..4808e7bbf 100644 --- a/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs +++ b/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs @@ -18,6 +18,7 @@ using System; using System.Net.Http; +using Grpc.Net.Client.Configuration; using Grpc.Net.Client.Internal; using Microsoft.Extensions.Logging; @@ -32,12 +33,14 @@ public static HttpClientCallInvoker Create( Action? configure = null, bool? disableClientDeadline = null, long? maxTimerPeriod = null, - IOperatingSystem? operatingSystem = null) + IOperatingSystem? operatingSystem = null, + ServiceConfig? serviceConfig = null) { var channelOptions = new GrpcChannelOptions { LoggerFactory = loggerFactory, - HttpClient = httpClient + HttpClient = httpClient, + ServiceConfig = serviceConfig }; configure?.Invoke(channelOptions); diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs b/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs new file mode 100644 index 000000000..06f6a1c66 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs @@ -0,0 +1,28 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +// Namespace and class name needs to resolve to System.Net.Http.WinHttpHandler. +namespace System.Net.Http +{ + public class WinHttpHandler : DelegatingHandler + { + public WinHttpHandler(HttpMessageHandler innerHandler) : base(innerHandler) + { + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs b/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs new file mode 100644 index 000000000..021e541fa --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs @@ -0,0 +1,45 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Net.Client.Internal.Retry; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class ChannelRetryThrottlingTests + { + [Test] + public void IsRetryThrottlingActive_FailedAndSuccessCalls_ActivatedChanges() + { + var channelRetryThrottling = new ChannelRetryThrottling(maxTokens: 3, tokenRatio: 1.0, NullLoggerFactory.Instance); + + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallFailure(); + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallFailure(); + Assert.AreEqual(true, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallSuccess(); + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs new file mode 100644 index 000000000..1245cff02 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs @@ -0,0 +1,373 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Internal.Retry; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class HedgingCallTests + { + [Test] + public async Task Dispose_ActiveCalls_CleansUpActiveCalls() + { + // Arrange + var allCallsOnServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var waitUntilFinishedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + // All calls are in-progress at once. + Interlocked.Increment(ref callCount); + if (callCount == 5) + { + allCallsOnServerTcs.SetResult(null); + } + await waitUntilFinishedTcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(20)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + + await allCallsOnServerTcs.Task.DefaultTimeout(); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(5, hedgingCall._activeCalls.Count); + + hedgingCall.Dispose(); + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + + waitUntilFinishedTcs.SetResult(null); + } + + private HedgingPolicyInfo CreateHedgingPolicy(HedgingPolicy? hedgingPolicy) => GrpcMethodInfo.CreateHedgingPolicy(hedgingPolicy!); + + [Test] + public async Task ActiveCalls_FatalStatusCode_CleansUpActiveCalls() + { + // Arrange + var allCallsOnServerSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var waitUntilFinishedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var callLock = new object(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + await request.Content!.CopyToAsync(new MemoryStream()); + + // All calls are in-progress at once. 
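The "fatal" behavior this test drives comes from the hedging policy shape: any status not listed in NonFatalStatusCodes commits the call and cancels the remaining hedged attempts. The helper builds a policy roughly like this (values illustrative):

    var hedging = new HedgingPolicy
    {
        MaxAttempts = 5,
        HedgingDelay = TimeSpan.FromMilliseconds(20),
        NonFatalStatusCodes = { StatusCode.Unavailable }
    };

Because InvalidArgument is not in that list, the first response carrying it wins the race and the other in-flight attempts are torn down, which is what the test's active-call assertions verify.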
+ bool allCallsOnServer = false; + lock (callLock) + { + callCount++; + if (callCount == 5) + { + allCallsOnServer = true; + } + } + if (allCallsOnServer) + { + await allCallsOnServerSyncPoint.WaitToContinue(); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.InvalidArgument); + } + await waitUntilFinishedTcs.Task; + + throw new InvalidOperationException("Should never reach here."); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(20)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + await allCallsOnServerSyncPoint.WaitForSyncPoint().DefaultTimeout(); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(5, hedgingCall._activeCalls.Count); + + allCallsOnServerSyncPoint.Continue(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + + // Fatal status code will cancel other calls + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + + waitUntilFinishedTcs.SetResult(null); + } + + [Test] + public async Task ClientStreamWriteAsync_NoActiveCalls_WaitsForNextCall() + { + // Arrange + var allCallsOnServerSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var callLock = new object(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + var content = (PushStreamContent)request.Content!; + _ = content.ReadAsStreamAsync(); + + // All calls are in-progress at once. 
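From the caller's perspective a hedged client-streaming call is indistinguishable from a plain one; the buffering and replay are internal. A usage sketch under that assumption, using the same generated-client shape as these tests:

    using var call = client.ClientStreamingCall();
    // Each write is buffered so it can be replayed to every hedged attempt; a write
    // made while no attempt is active simply waits for the next attempt to start.
    await call.RequestStream.WriteAsync(new HelloRequest { Name = "Name 1" });
    await call.RequestStream.CompleteAsync();
    var response = await call.ResponseAsync;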
+ bool firstCallsOnServer = false; + lock (callLock) + { + callCount++; + if (callCount == 1) + { + firstCallsOnServer = true; + } + } + if (firstCallsOnServer) + { + await allCallsOnServerSyncPoint.WaitToContinue(); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "200"); + } + + await content.PushComplete.DefaultTimeout(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(200)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), new CallOptions()); + + // Act + hedgingCall.StartClientStreaming(); + await hedgingCall.ClientStreamWriter!.WriteAsync(new HelloRequest { Name = "Name 1" }).DefaultTimeout(); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + await allCallsOnServerSyncPoint.WaitForSyncPoint().DefaultTimeout(); + allCallsOnServerSyncPoint.Continue(); + + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Call should finish and then wait until next call."); + + // This call will wait until next hedging call starts + await hedgingCall.ClientStreamWriter!.WriteAsync(new HelloRequest { Name = "Name 2" }).DefaultTimeout(); + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + + await hedgingCall.ClientStreamWriter!.CompleteAsync().DefaultTimeout(); + + var responseMessage = await hedgingCall.GetResponseAsync().DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + } + + [Test] + public async Task ResponseAsync_PushbackStop_SuccessAfterPushbackStop() + { + // Arrange + var allCallsOnServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var returnSuccessTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + // All calls are in-progress at once. 
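The pushback header used above has simple semantics in the gRPC retry design: a non-negative grpc-retry-pushback-ms value asks the client to wait that long before the next attempt, while a negative or malformed value (the "-1" this handler returns) tells it to stop making attempts entirely. A sketch of that interpretation with a hypothetical helper, not the client's internal API:

    static TimeSpan? ParseRetryPushback(string? headerValue)
    {
        // A null result means "make no more attempts"; a value means "wait this long".
        if (headerValue != null && long.TryParse(headerValue, out var ms) && ms >= 0)
        {
            return TimeSpan.FromMilliseconds(ms);
        }
        return null;
    }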
+ Interlocked.Increment(ref callCount); + if (callCount == 2) + { + allCallsOnServerTcs.TrySetResult(null); + } + await allCallsOnServerTcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + if (request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out var headerValues) && + headerValues.Single() == "1") + { + await returnSuccessTcs.Task; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + } + else + { + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, customHeaders: new Dictionary + { + [GrpcProtocolConstants.RetryPushbackHeader] = "-1" + }); + } + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + + // Wait for both calls to be on the server + await allCallsOnServerTcs.Task; + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 1, "Wait for pushback to be returned."); + returnSuccessTcs.SetResult(null); + + var rs = await hedgingCall.GetResponseAsync().DefaultTimeout(); + Assert.AreEqual("Hello world", rs.Message); + Assert.AreEqual(StatusCode.OK, hedgingCall.GetStatus().StatusCode); + Assert.AreEqual(2, callCount); + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + } + + [Test] + public async Task RetryThrottling_BecomesActiveDuringDelay_CancelFailure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "200"); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(200), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + }); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + CompatibilityExtensions.Assert(invoker.Channel.RetryThrottling != null); + invoker.Channel.RetryThrottling.CallFailure(); + invoker.Channel.RetryThrottling.CallFailure(); + CompatibilityExtensions.Assert(invoker.Channel.RetryThrottling.IsRetryThrottlingActive()); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(1, callCount); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, hedgingCall.GetStatus().StatusCode); + Assert.AreEqual("Retries stopped because retry throttling is active.", hedgingCall.GetStatus().Detail); + } + + [Test] + 
public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var cts = new CancellationTokenSource(); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions(cancellationToken: cts.Token)); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + cts.Cancel(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("Call canceled by the client.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_DisposeDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + hedgingCall.Dispose(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("gRPC call disposed.", ex.Status.Detail); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs b/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs new file mode 100644 index 000000000..9799bbfe2 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs @@ -0,0 +1,690 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class HedgingTests + { + [TestCase(2)] + [TestCase(10)] + [TestCase(100)] + public async Task AsyncUnaryCall_OneAttempt_Success(int maxAttempts) + { + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await tcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: maxAttempts); + var invoker = HttpClientCallInvokerFactory.Create( + httpClient, + serviceConfig: serviceConfig, + configure: o => o.MaxRetryAttempts = maxAttempts); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == maxAttempts, "All calls made at once."); + tcs.SetResult(null); + + var rs = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", rs.Message); + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncClientStreamingCall_ManyParallelCalls_ReadDirectlyToRequestStream() + { + // Arrange + var requestStreams = new List(); + var attempts = 100; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + WriterTestStream writerTestStream; + lock (requestStreams) + { + Interlocked.Increment(ref callCount); + writerTestStream = new WriterTestStream(); + requestStreams.Add(writerTestStream); + } + await request.Content!.CopyToAsync(writerTestStream); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: attempts); + var invoker = HttpClientCallInvokerFactory.Create( + httpClient, + serviceConfig: serviceConfig, + configure: o => o.MaxRetryAttempts = attempts); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions()); + var writeAsyncTask = call.RequestStream.WriteAsync(new HelloRequest { Name = "World" }); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == attempts, "All calls made at once."); + + var firstMessages = await Task.WhenAll(requestStreams.Select(s => s.WaitForDataAsync())).DefaultTimeout(); + await writeAsyncTask.DefaultTimeout(); + + foreach (var message in 
firstMessages)
+            {
+                Assert.IsTrue(firstMessages[0].Span.SequenceEqual(message.Span));
+            }
+
+            writeAsyncTask = call.RequestStream.WriteAsync(new HelloRequest { Name = "World 2" });
+            var secondMessages = await Task.WhenAll(requestStreams.Select(s => s.WaitForDataAsync())).DefaultTimeout();
+            await writeAsyncTask.DefaultTimeout();
+
+            foreach (var message in secondMessages)
+            {
+                Assert.IsTrue(secondMessages[0].Span.SequenceEqual(message.Span));
+            }
+
+            await call.RequestStream.CompleteAsync().DefaultTimeout();
+
+            var rs = await call.ResponseAsync.DefaultTimeout();
+            Assert.AreEqual("Hello world", rs.Message);
+            Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);
+        }
+
+        private class WriterTestStream : Stream
+        {
+            public TaskCompletionSource<ReadOnlyMemory<byte>> WriteAsyncTcs = new TaskCompletionSource<ReadOnlyMemory<byte>>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            public override bool CanRead { get; }
+            public override bool CanSeek { get; }
+            public override bool CanWrite { get; }
+            public override long Length { get; }
+            public override long Position { get; set; }
+
+            private SyncPoint _syncPoint;
+            private Func<Task> _awaiter;
+            private ReadOnlyMemory<byte> _currentWriteData;
+
+            public WriterTestStream()
+            {
+                _awaiter = SyncPoint.Create(out _syncPoint, runContinuationsAsynchronously: true);
+            }
+
+            public override void Flush()
+            {
+            }
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                throw new NotImplementedException();
+            }
+
+            public override long Seek(long offset, SeekOrigin origin)
+            {
+                throw new NotImplementedException();
+            }
+
+            public override void SetLength(long value)
+            {
+                throw new NotImplementedException();
+            }
+
+            public override void Write(byte[] buffer, int offset, int count)
+            {
+                throw new NotImplementedException();
+            }
+
+#if NET472
+            public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+#else
+            public override async ValueTask WriteAsync(ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
+#endif
+            {
+#if NET472
+                var data = buffer.AsMemory(offset, count);
+#endif
+                _currentWriteData = data.ToArray();
+
+                // Wait until data is read by WaitForDataAsync
+                await _awaiter();
+            }
+
+            public override Task FlushAsync(CancellationToken cancellationToken)
+            {
+                return Task.CompletedTask;
+            }
+
+            public async Task<ReadOnlyMemory<byte>> WaitForDataAsync()
+            {
+                await _syncPoint.WaitForSyncPoint();
+
+                ResetSyncPointAndContinuePrevious();
+
+                return _currentWriteData;
+            }
+
+            private void ResetSyncPointAndContinuePrevious()
+            {
+                // We have read all data. Signal WriteAsync to continue
+                // and reset the sync point for the next read.
+                var syncPoint = _syncPoint;
+
+                ResetSyncPoint();
+
+                syncPoint.Continue();
+            }
+
+            private void ResetSyncPoint()
+            {
+                _awaiter = SyncPoint.Create(out _syncPoint, runContinuationsAsynchronously: true);
+            }
+        }
+
+        [Test]
+        public async Task AsyncUnaryCall_ExceedAttempts_Failure()
+        {
+            // Arrange
+            var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
+            var requestMessages = new List<HelloRequest>();
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                // All calls are in-progress at once.
+ Interlocked.Increment(ref callCount); + if (callCount == 5) + { + tcs.TrySetResult(null); + } + await tcs.Task; + + var requestContent = await request.Content!.ReadAsStreamAsync(); + var requestMessage = await ReadRequestMessage(requestContent); + lock (requestMessages) + { + requestMessages.Add(requestMessage!); + } + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(5, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, requestMessages.Count); + foreach (var requestMessage in requestMessages) + { + Assert.AreEqual("World", requestMessage.Name); + } + } + + [Test] + public async Task AsyncUnaryCall_ExceedDeadlineWithActiveCalls_Failure() + { + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + Interlocked.Increment(ref callCount); + return tcs.Task; + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(200)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(100)), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_ManyAttemptsNoDelay_MarshallerCalledOnce() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + var marshallerCount = 0; + var requestMarshaller = Marshallers.Create( + r => + { + Interlocked.Increment(ref marshallerCount); + return r.ToByteArray(); + }, + data => HelloRequest.Parser.ParseFrom(data)); + var method = ClientTestHelpers.GetServiceMethod(requestMarshaller: requestMarshaller); + + // Act + var call = invoker.AsyncUnaryCall(method, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(1, marshallerCount); + } + + [Test] + public async Task 
AsyncUnaryCall_ExceedAttempts_PushbackDelay_Failure()
+        {
+            // Arrange
+            var stopwatch = new Stopwatch();
+            var callIntervals = new List<long>();
+            var hedgeDelay = TimeSpan.FromMilliseconds(100);
+            const int timerResolutionMs = 15 * 2; // Timer has a precision of about 15ms. Double it, just to be safe
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                callIntervals.Add(stopwatch.ElapsedMilliseconds);
+                stopwatch.Restart();
+                Interlocked.Increment(ref callCount);
+
+                await request.Content!.CopyToAsync(new MemoryStream());
+                return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: hedgeDelay.TotalMilliseconds.ToString());
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: hedgeDelay);
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            stopwatch.Start();
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+
+            // Assert
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
+            Assert.AreEqual(2, callCount);
+            Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
+            Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode);
+
+            // First call should happen immediately
+            Assert.LessOrEqual(callIntervals[0], hedgeDelay.TotalMilliseconds);
+
+            // Second call should happen after delay
+            Assert.GreaterOrEqual(callIntervals[1], hedgeDelay.TotalMilliseconds - timerResolutionMs);
+        }
+
+        [Test]
+        public async Task AsyncUnaryCall_ExceedAttempts_NoPushbackDelay_Failure()
+        {
+            // Arrange
+            var stopwatch = new Stopwatch();
+            var callIntervals = new List<long>();
+            var hedgeDelay = TimeSpan.FromSeconds(10);
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                callIntervals.Add(stopwatch.ElapsedMilliseconds);
+                stopwatch.Restart();
+                Interlocked.Increment(ref callCount);
+
+                await request.Content!.CopyToAsync(new MemoryStream());
+                return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: hedgeDelay);
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            stopwatch.Start();
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+
+            // Assert
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
+            Assert.AreEqual(2, callCount);
+            Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
+            Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode);
+
+            // First call should happen immediately
+            Assert.LessOrEqual(callIntervals[0], hedgeDelay.TotalMilliseconds);
+
+            // Second call should happen immediately
+            Assert.LessOrEqual(callIntervals[1], hedgeDelay.TotalMilliseconds);
+        }
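+
+        // A sketch of the pushback semantics exercised above and below, driven by the
+        // 'grpc-retry-pushback-ms' response header (GrpcProtocolConstants.RetryPushbackHeader):
+        // a positive value delays the next hedged attempt by that many milliseconds, "0"
+        // schedules the next attempt immediately, and "-1" (or an unparsable value) stops
+        // further attempts. For example, a handler that defers the next attempt by one
+        // second would return:
+        //
+        //     return ResponseUtils.CreateHeadersOnlyResponse(
+        //         HttpStatusCode.OK,
+        //         StatusCode.Unavailable,
+        //         retryPushbackHeader: "1000");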
+
+        [Test]
+        public async Task AsyncUnaryCall_PushbackDelay_PushbackDelayUpdatesNextCallDelay()
+        {
+            // Arrange
+            var stopwatch = new Stopwatch();
+            var callIntervals = new List<long>();
+            var hedgingDelay = TimeSpan.FromSeconds(10);
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                callIntervals.Add(stopwatch.ElapsedMilliseconds);
+                stopwatch.Restart();
+                Interlocked.Increment(ref callCount);
+
+                await request.Content!.CopyToAsync(new MemoryStream());
+                string? hedgingPushback = hedgingDelay.TotalMilliseconds.ToString();
+                if (callCount == 1)
+                {
+                    hedgingPushback = "0";
+                }
+                return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: hedgingPushback);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: hedgingDelay);
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            stopwatch.Start();
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+
+            // Assert
+            await TestHelpers.AssertIsTrueRetryAsync(() => callIntervals.Count == 2, "Only two calls should be made.").DefaultTimeout();
+
+            // First call should happen immediately
+            Assert.LessOrEqual(callIntervals[0], 100);
+
+            // Second call should happen quickly because the first response's pushback of 0
+            // replaces the configured 10 second hedging delay
+            Assert.LessOrEqual(callIntervals[1], hedgingDelay.TotalMilliseconds);
+        }
+
+        [Test]
+        public async Task AsyncUnaryCall_FatalStatusCode_HedgeDelay_Failure()
+        {
+            // Arrange
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                Interlocked.Increment(ref callCount);
+
+                await request.Content!.CopyToAsync(new MemoryStream());
+                return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, (callCount == 1) ? StatusCode.Unavailable : StatusCode.InvalidArgument);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(50));
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+
+            // Assert
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
+            Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode);
+            Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode);
+            Assert.AreEqual(2, callCount);
+        }
+
+        [Test]
+        public async Task AsyncServerStreamingCall_SuccessAfterRetry_RequestContentSent()
+        {
+            // Arrange
+            var syncPoint = new SyncPoint(runContinuationsAsynchronously: true);
+            MemoryStream? 
requestContent = null; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + var s = await request.Content!.ReadAsStreamAsync(); + var ms = new MemoryStream(); + await s.CopyToAsync(ms); + + if (callCount == 1) + { + await syncPoint.WaitForSyncPoint(); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + requestContent = ms; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: TimeSpan.FromMilliseconds(50)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ServerStreaming), string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + requestContent!.Seek(0, SeekOrigin.Begin); + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("World", requestMessage!.Name); + } + + [TestCase(0)] + [TestCase(1)] + [TestCase(100)] + public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent(int hedgingDelayMS) + { + // Arrange + var callLock = new object(); + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + var firstCall = false; + lock (callLock) + { + callCount++; + if (callCount == 1) + { + firstCall = true; + } + } + if (firstCall) + { + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + var content = (PushStreamContent)request.Content!; + await content.PushComplete.DefaultTimeout(); + + await request.Content!.CopyToAsync(requestContent); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMS)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await 
call.RequestStream.CompleteAsync().DefaultTimeout();
+
+            var responseMessage = await responseTask.DefaultTimeout();
+            Assert.AreEqual("Hello world", responseMessage.Message);
+
+            requestContent.Seek(0, SeekOrigin.Begin);
+
+            var requests = new List<HelloRequest>();
+            while (true)
+            {
+                var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout();
+                if (requestMessage == null)
+                {
+                    break;
+                }
+
+                requests.Add(requestMessage);
+            }
+
+            Assert.AreEqual(2, requests.Count);
+            Assert.AreEqual("1", requests[0].Name);
+            Assert.AreEqual("2", requests[1].Name);
+        }
+
+        [Test]
+        public async Task AsyncClientStreamingCall_CompleteAndWriteAfterResult_Error()
+        {
+            // Arrange
+            var requestContent = new MemoryStream();
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                Interlocked.Increment(ref callCount);
+
+                _ = request.Content!.ReadAsStreamAsync();
+
+                var reply = new HelloReply { Message = "Hello world" };
+                var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
+
+                return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig();
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions());
+
+            // Assert
+            var responseMessage = await call.ResponseAsync.DefaultTimeout();
+            Assert.AreEqual("Hello world", responseMessage.Message);
+
+            requestContent.Seek(0, SeekOrigin.Begin);
+
+            await call.RequestStream.CompleteAsync().DefaultTimeout();
+
+            var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
+            Assert.AreEqual("Request stream has already been completed.", ex.Message);
+        }
+
+        [Test]
+        public async Task AsyncClientStreamingCall_WriteAfterResult_Error()
+        {
+            // Arrange
+            var requestContent = new MemoryStream();
+
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                Interlocked.Increment(ref callCount);
+
+                _ = request.Content!.ReadAsStreamAsync();
+
+                var reply = new HelloReply { Message = "Hello world" };
+                var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
+
+                return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig();
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions());
+
+            // Assert
+            var responseMessage = await call.ResponseAsync.DefaultTimeout();
+            Assert.AreEqual("Hello world", responseMessage.Message);
+
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
+            Assert.AreEqual(StatusCode.OK, ex.StatusCode);
+        }
+
+        private static Task<HelloRequest?> ReadRequestMessage(Stream requestContent)
+        {
+            return StreamSerializationHelper.ReadMessageAsync(
+                requestContent,
+                ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
+                GrpcProtocolConstants.IdentityGrpcEncoding,
+                maximumMessageSize: null,
+                GrpcProtocolConstants.DefaultCompressionProviders,
+                singleMessage: false,
+                CancellationToken.None);
+        }
+    }
+}
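
The hedging tests above configure policies through ServiceConfigHelpers and drive them with HttpClientCallInvokerFactory; application code reaches the same retry machinery through channel configuration. A minimal sketch of that usage, assuming the service config is attached via the GrpcChannelOptions.ServiceConfig property added in this change (the address below is a placeholder):

using System;
using Grpc.Core;
using Grpc.Net.Client;
using Grpc.Net.Client.Configuration;

// Hedge every method on the channel: up to 5 attempts in total,
// a new attempt each second, and Unavailable treated as non-fatal.
var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
{
    ServiceConfig = new ServiceConfig
    {
        MethodConfigs =
        {
            new MethodConfig
            {
                Names = { MethodName.Default },
                HedgingPolicy = new HedgingPolicy
                {
                    MaxAttempts = 5,
                    HedgingDelay = TimeSpan.FromSeconds(1),
                    NonFatalStatusCodes = { StatusCode.Unavailable }
                }
            }
        }
    }
});

diff --git 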
a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs new file mode 100644 index 000000000..082ba0da1 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs @@ -0,0 +1,840 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class RetryTests + { + [Test] + public async Task AsyncUnaryCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + HttpContent? content = null; + + bool? firstRequestPreviousAttemptsHeader = null; + string? secondRequestPreviousAttemptsHeaderValue = null; + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + content = request.Content!; + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + if (callCount == 1) + { + firstRequestPreviousAttemptsHeader = request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out _); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + if (request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out var retryAttemptCountValue)) + { + secondRequestPreviousAttemptsHeaderValue = retryAttemptCountValue.Single(); + } + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent, customTrailers: new Dictionary + { + ["custom-trailer"] = "Value!" 
+ }); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + Assert.AreEqual("1", (await call.ResponseHeadersAsync.DefaultTimeout()).GetValue(GrpcProtocolConstants.RetryPreviousAttemptsHeader)); + + Assert.IsNotNull(content); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + + Assert.AreEqual("World", requestMessage!.Name); + + Assert.IsFalse(firstRequestPreviousAttemptsHeader); + Assert.AreEqual("1", secondRequestPreviousAttemptsHeaderValue); + + var trailers = call.GetTrailers(); + Assert.AreEqual("Value!", trailers.GetValue("custom-trailer")); + } + + [Test] + public async Task AsyncUnaryCall_SuccessAfterRetry_AccessResponseHeaders_SuccessfullyResponseHeadersReturned() + { + // Arrange + HttpContent? content = null; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + content = request.Content!; + + if (callCount == 1) + { + await content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse( + HttpStatusCode.OK, + StatusCode.Unavailable, + customHeaders: new Dictionary { ["call-count"] = callCount.ToString() }); + } + + syncPoint.Continue(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse( + HttpStatusCode.OK, + streamContent, + customHeaders: new Dictionary { ["call-count"] = callCount.ToString() }); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var headersTask = call.ResponseHeadersAsync; + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + + var headers = await headersTask.DefaultTimeout(); + Assert.AreEqual("2", headers.GetValue("call-count")); + } + + [Test] + public async Task AsyncUnaryCall_ExceedRetryAttempts_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 3); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(3, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => 
call.ResponseAsync).DefaultTimeout();
+            Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
+            Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode);
+        }
+
+        [Test]
+        public async Task AsyncUnaryCall_FailureWithLongDelay_Dispose_CallImmediatelyDisposed()
+        {
+            // Arrange
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                callCount++;
+                await request.Content!.CopyToAsync(new MemoryStream());
+                return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable);
+            });
+            // Very long delay
+            var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(initialBackoff: TimeSpan.FromSeconds(30), maxBackoff: TimeSpan.FromSeconds(30));
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+            var resultTask = call.ResponseAsync;
+
+            // Test will timeout if dispose doesn't kill the timer.
+            call.Dispose();
+
+            // Assert
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => resultTask).DefaultTimeout();
+            Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
+            Assert.AreEqual("gRPC call disposed.", ex.Status.Detail);
+        }
+
+        [TestCase("")]
+        [TestCase("-1")]
+        [TestCase("stop")]
+        public async Task AsyncUnaryCall_PushbackStop_Failure(string header)
+        {
+            // Arrange
+            var callCount = 0;
+            var httpClient = ClientTestHelpers.CreateTestClient(async request =>
+            {
+                callCount++;
+                await request.Content!.CopyToAsync(new MemoryStream());
+                return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StringContent(""), StatusCode.Unavailable, retryPushbackHeader: header);
+            });
+            var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig();
+            var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
+
+            // Act
+            var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" });
+
+            // Assert
+            Assert.AreEqual(1, callCount);
+            var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
+            Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
+            Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode);
+        }
+
+        [Test]
+        public async Task AsyncUnaryCall_PushbackExplicitDelay_DelayForSpecifiedDuration()
+        {
+            // Arrange
+            Task? 
delayTask = null; + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + if (callCount == 1) + { + await request.Content!.CopyToAsync(new MemoryStream()); + delayTask = Task.Delay(100); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "200"); + } + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(backoffMultiplier: 1); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Delay of 100ms will finish before second record which has a pushback delay of 200ms + var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask!).DefaultTimeout(); + var rs = await call.ResponseAsync.DefaultTimeout(); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Response task should finish after + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", rs.Message); + } + + [Test] + public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var cts = new CancellationTokenSource(); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" }); + + var delayTask = Task.Delay(100); + var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Ensure that we're waiting for retry + + cts.Cancel(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("Call canceled by the client.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_DisposeDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var cts = new CancellationTokenSource(); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" }); + + var delayTask = Task.Delay(100); + var completedTask = await 
Task.WhenAny(call.ResponseAsync, delayTask); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Ensure that we're waiting for retry + + call.Dispose(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("gRPC call disposed.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_PushbackExplicitDelayExceedAttempts_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "0"); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 5); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(5, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_UnsupportedStatusCode_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StringContent(""), StatusCode.InvalidArgument); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_Success_RequestContentSent() + { + // Arrange + HttpContent? 
content = null; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + content = request.Content; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + } + + [Test] + public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + var currentContent = new MemoryStream(); + await request.Content!.CopyToAsync(currentContent); + + if (callCount == 1) + { + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + currentContent.Seek(0, SeekOrigin.Begin); + await currentContent.CopyToAsync(requestContent); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var responseMessage = await responseTask.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + var requests = new List(); + while (true) + { + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + if (requestMessage == null) + { + break; + } + + requests.Add(requestMessage); + } + + Assert.AreEqual(2, requests.Count); + Assert.AreEqual("1", requests[0].Name); + Assert.AreEqual("2", requests[1].Name); + + call.Dispose(); + } + + [Test] + public async Task ClientStreamWriter_WriteWhilePendingWrite_ErrorThrown() + { + // Arrange + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + var streamContent = new StreamContent(new SyncPointMemoryStream()); + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent)); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = 
invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var writeTask1 = call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }); + Assert.IsFalse(writeTask1.IsCompleted); + + var writeTask2 = call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }); + var ex = await ExceptionAssert.ThrowsAsync(() => writeTask2).DefaultTimeout(); + + Assert.AreEqual("Can't write the message because the previous write is in progress.", ex.Message); + } + + [Test] + public async Task ClientStreamWriter_WriteWhileComplete_ErrorThrown() + { + // Arrange + var streamContent = new SyncPointMemoryStream(); + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent))); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + await call.RequestStream.CompleteAsync().DefaultTimeout(); + var resultTask = call.ResponseAsync; + + // Assert + var writeException1 = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual("Request stream has already been completed.", writeException1.Message); + + await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply + { + Message = "Hello world 1" + }).DefaultTimeout()).DefaultTimeout(); + await streamContent.AddDataAndWait(new byte[0]); + + var result = await resultTask.DefaultTimeout(); + Assert.AreEqual("Hello world 1", result.Message); + + var writeException2 = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "2" })).DefaultTimeout(); + Assert.AreEqual("Request stream has already been completed.", writeException2.Message); + } + + [Test] + public async Task AsyncClientStreamingCall_CompleteAndWriteAfterResult_Error() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual("Request stream has already been completed.", ex.Message); + } + + [Test] + public async Task AsyncClientStreamingCall_WriteAfterResult_Error() 
+ { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual(StatusCode.OK, ex.StatusCode); + } + + [Test] + public async Task AsyncClientStreamingCall_OneMessageSentThenRetryThenAnotherMessage_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + var content = (PushStreamContent)request.Content!; + + if (callCount == 1) + { + _ = content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + await content.PushComplete.DefaultTimeout(); + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var responseMessage = await responseTask.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("1", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("2", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + [Test] + public async 
Task AsyncServerStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + var content = request.Content!; + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + if (callCount == 1) + { + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ServerStreaming), string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("World", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + [Test] + public async Task AsyncServerStreamingCall_FailureAfterReadingResponseMessage_Failure() + { + // Arrange + var streamContent = new SyncPointMemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + callCount++; + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent))); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest()); + + var responseStream = call.ResponseStream; + + // Assert + Assert.IsNull(responseStream.Current); + + var moveNextTask1 = responseStream.MoveNext(CancellationToken.None); + Assert.IsFalse(moveNextTask1.IsCompleted); + + await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply + { + Message = "Hello world 1" + }).DefaultTimeout()).DefaultTimeout(); + + Assert.IsTrue(await moveNextTask1.DefaultTimeout()); + Assert.IsNotNull(responseStream.Current); + Assert.AreEqual("Hello world 1", responseStream.Current.Message); + + var moveNextTask2 = responseStream.MoveNext(CancellationToken.None); + Assert.IsFalse(moveNextTask2.IsCompleted); + + await streamContent.AddExceptionAndWait(new Exception("Exception!")).DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => moveNextTask2).DefaultTimeout(); + Assert.AreEqual(StatusCode.Internal, ex.StatusCode); + Assert.AreEqual(StatusCode.Internal, call.GetStatus().StatusCode); + Assert.AreEqual("Error reading next message. 
Exception: Exception!", call.GetStatus().Detail); + } + + [Test] + public async Task AsyncDuplexStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + var content = (PushStreamContent)request.Content!; + + if (callCount == 1) + { + _ = content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + await content.PushComplete.DefaultTimeout(); + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncDuplexStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.DuplexStreaming), string.Empty, new CallOptions()); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask.DefaultTimeout()); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("1", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("2", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + private static Task ReadRequestMessage(Stream requestContent) + { + return StreamSerializationHelper.ReadMessageAsync( + requestContent, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, + GrpcProtocolConstants.IdentityGrpcEncoding, + maximumMessageSize: null, + GrpcProtocolConstants.DefaultCompressionProviders, + singleMessage: false, + CancellationToken.None); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs b/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs new file mode 100644 index 000000000..d83ab310a --- /dev/null +++ b/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs @@ -0,0 +1,185 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using Grpc.Core;
+using Grpc.Net.Client.Configuration;
+using Grpc.Net.Client.Internal.Configuration;
+using NUnit.Framework;
+
+namespace Grpc.Net.Client.Tests
+{
+    [TestFixture]
+    public class ServiceConfigTests
+    {
+        [Test]
+        public void ServiceConfig_CreateUnderlyingConfig()
+        {
+            // Arrange & Act
+            var serviceConfig = new ServiceConfig
+            {
+                MethodConfigs =
+                {
+                    new MethodConfig
+                    {
+                        Names = { new MethodName() },
+                        RetryPolicy = new RetryPolicy
+                        {
+                            MaxAttempts = 5,
+                            InitialBackoff = TimeSpan.FromSeconds(1),
+                            RetryableStatusCodes = { StatusCode.Unavailable, StatusCode.Aborted }
+                        }
+                    }
+                }
+            };
+
+            // Assert
+            Assert.AreEqual(1, serviceConfig.MethodConfigs.Count);
+            Assert.AreEqual(1, serviceConfig.MethodConfigs[0].Names.Count);
+            Assert.AreEqual(5, serviceConfig.MethodConfigs[0].RetryPolicy!.MaxAttempts);
+            Assert.AreEqual(TimeSpan.FromSeconds(1), serviceConfig.MethodConfigs[0].RetryPolicy!.InitialBackoff);
+            Assert.AreEqual(StatusCode.Unavailable, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[0]);
+            Assert.AreEqual(StatusCode.Aborted, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[1]);
+
+            var inner = serviceConfig.Inner;
+            var methodConfigs = (IList<object>)inner["methodConfig"];
+            var allServices = (IDictionary<string, object>)methodConfigs[0];
+
+            Assert.AreEqual(5, (int)((IDictionary<string, object>)allServices["retryPolicy"])["maxAttempts"]);
+            Assert.AreEqual("1s", (string)((IDictionary<string, object>)allServices["retryPolicy"])["initialBackoff"]);
+            Assert.AreEqual("UNAVAILABLE", (string)((IList<object>)((IDictionary<string, object>)allServices["retryPolicy"])["retryableStatusCodes"])[0]);
+            Assert.AreEqual("ABORTED", (string)((IList<object>)((IDictionary<string, object>)allServices["retryPolicy"])["retryableStatusCodes"])[1]);
+        }
+
+        [Test]
+        public void ServiceConfig_ReadUnderlyingConfig()
+        {
+            // Arrange
+            var inner = new Dictionary<string, object>
+            {
+                ["methodConfig"] = new List<object>
+                {
+                    new Dictionary<string, object>
+                    {
+                        ["name"] = new List<object> { new Dictionary<string, object>() },
+                        ["retryPolicy"] = new Dictionary<string, object>
+                        {
+                            ["maxAttempts"] = 5,
+                            ["initialBackoff"] = "1s",
+                            ["retryableStatusCodes"] = new List<object> { "UNAVAILABLE", "ABORTED" }
+                        }
+                    }
+                }
+            };
+
+            // Act
+            var serviceConfig = new ServiceConfig(inner);
+
+            // Assert
+            Assert.AreEqual(1, serviceConfig.MethodConfigs.Count);
+            Assert.AreEqual(1, serviceConfig.MethodConfigs[0].Names.Count);
+            Assert.AreEqual(5, serviceConfig.MethodConfigs[0].RetryPolicy!.MaxAttempts);
+            Assert.AreEqual(TimeSpan.FromSeconds(1), serviceConfig.MethodConfigs[0].RetryPolicy!.InitialBackoff);
+            Assert.AreEqual(StatusCode.Unavailable, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[0]);
+            Assert.AreEqual(StatusCode.Aborted, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[1]);
+        }
+
+        [Test]
+        public void RetryPolicy_ReadUnderlyingConfig_Success()
+        {
+            // Arrange
+            var inner = new Dictionary<string, object>
+            {
+                ["initialBackoff"] = "1.1s",
+                ["retryableStatusCodes"] = new List<object> { "UNAVAILABLE", "Aborted", 1 }
+            };
+
+            // Act
+            var retryPolicy = new RetryPolicy(inner);
+
+            // Assert
+            Assert.AreEqual(TimeSpan.FromSeconds(1.1), retryPolicy.InitialBackoff);
+            Assert.AreEqual(StatusCode.Unavailable, retryPolicy.RetryableStatusCodes[0]);
+            Assert.AreEqual(StatusCode.Aborted, retryPolicy.RetryableStatusCodes[1]);
+            Assert.AreEqual(StatusCode.Cancelled, retryPolicy.RetryableStatusCodes[2]);
+        }
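+
+        // The duration strings in the cases below use the proto3 JSON duration format:
+        // decimal seconds with an 's' suffix, as used by initialBackoff, maxBackoff and
+        // hedgingDelay in the service config. A sketch of the mapping exercised here:
+        //
+        //     ConvertHelpers.ConvertDurationText("1.1s")              // TimeSpan.FromSeconds(1.1)
+        //     ConvertHelpers.ConvertDurationText("-1s")               // TimeSpan.FromSeconds(-1)
+        //     ConvertHelpers.ToDurationText(TimeSpan.FromSeconds(1))  // "1s"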
[TestCase("-0s", 0)] + [TestCase("1s", 1 * TimeSpan.TicksPerSecond)] + [TestCase("1.0s", 1 * TimeSpan.TicksPerSecond)] + [TestCase("1.1s", (long)(1.1 * TimeSpan.TicksPerSecond))] + [TestCase("-1s", -1 * TimeSpan.TicksPerSecond)] + [TestCase("3.0000001s", (long)(3.0000001 * TimeSpan.TicksPerSecond))] + [TestCase("315576000000s", (315576000000 * TimeSpan.TicksPerSecond))] + [TestCase("-315576000000s", (-315576000000 * TimeSpan.TicksPerSecond))] + public void ConvertDurationText_Success(string text, long ticks) + { + // Arrange & Act + var timespan = ConvertHelpers.ConvertDurationText(text); + + // Assert + Assert.AreEqual(ticks, timespan!.Value.Ticks); + } + + [TestCase("0s", null)] + [TestCase("0.0s", "0s")] + [TestCase("-0s", "0s")] + [TestCase("1s", null)] + [TestCase("1.0s", "1s")] + [TestCase("1.1s", null)] + [TestCase("-1s", null)] + [TestCase("3.0000001s", null)] + [TestCase("315576000000s", null)] + [TestCase("-315576000000s", null)] + public void Duration_Roundtrip(string text, string explicitResult) + { + // Arrange & Act + var timespan = ConvertHelpers.ConvertDurationText(text); + + // Assert + Assert.AreEqual(explicitResult ?? text, ConvertHelpers.ToDurationText(timespan)); + } + + [TestCase("")] + [TestCase("s")] + [TestCase("0")] + [TestCase("-")] + [TestCase("1xs")] + [TestCase("1,1s")] + [TestCase("1.2345678e7")] + [TestCase("1.2345678e7s")] + public void ConvertDurationText_Failure(string text) + { + // Arrange & Act + var ex = Assert.Throws(() => ConvertHelpers.ConvertDurationText(text)); + + // Assert + Assert.AreEqual($"'{text}' isn't a valid duration.", ex.Message); + } + + [Test] + public void MethodName_Default_ErrorOnChange() + { + // Arrange & Act & Assert + Assert.Throws(() => MethodName.Default.Method = "This will break"); + } + } +} diff --git a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs index e8a56e658..0e87b91f1 100644 --- a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs +++ b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs @@ -48,7 +48,7 @@ public async Task ReadAsync_ReadLargeData_Success() var messageCount = 3; var streamContent = new List(); - for (int i = 0; i < messageCount; i++) + for (var i = 0; i < messageCount; i++) { streamContent.AddRange(messageContent); } @@ -56,7 +56,7 @@ public async Task ReadAsync_ReadLargeData_Success() var ms = new LimitedReadMemoryStream(streamContent.ToArray(), 3); var base64Stream = new Base64ResponseStream(ms); - for (int i = 0; i < messageCount; i++) + for (var i = 0; i < messageCount; i++) { // Assert 1 var resolvedHeaderData = await ReadContent(base64Stream, 5, CancellationToken.None); diff --git a/test/Shared/ClientTestHelpers.cs b/test/Shared/ClientTestHelpers.cs index 89f5fb182..81c904d18 100644 --- a/test/Shared/ClientTestHelpers.cs +++ b/test/Shared/ClientTestHelpers.cs @@ -17,6 +17,7 @@ #endregion using System; +using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Http; @@ -26,6 +27,7 @@ using Google.Protobuf; using Greet; using Grpc.Core; +using Grpc.Net.Client.Configuration; using Grpc.Net.Compression; namespace Grpc.Tests.Shared @@ -35,7 +37,12 @@ internal static class ClientTestHelpers public static readonly Marshaller HelloRequestMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloRequest.Parser.ParseFrom(data)); public static readonly Marshaller HelloReplyMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloReply.Parser.ParseFrom(data)); - 
diff --git a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs
index e8a56e658..0e87b91f1 100644
--- a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs
+++ b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs
@@ -48,7 +48,7 @@ public async Task ReadAsync_ReadLargeData_Success()
             var messageCount = 3;
             var streamContent = new List<byte>();
 
-            for (int i = 0; i < messageCount; i++)
+            for (var i = 0; i < messageCount; i++)
             {
                 streamContent.AddRange(messageContent);
             }
@@ -56,7 +56,7 @@ public async Task ReadAsync_ReadLargeData_Success()
             var ms = new LimitedReadMemoryStream(streamContent.ToArray(), 3);
             var base64Stream = new Base64ResponseStream(ms);
 
-            for (int i = 0; i < messageCount; i++)
+            for (var i = 0; i < messageCount; i++)
             {
                 // Assert 1
                 var resolvedHeaderData = await ReadContent(base64Stream, 5, CancellationToken.None);
diff --git a/test/Shared/ClientTestHelpers.cs b/test/Shared/ClientTestHelpers.cs
index 89f5fb182..81c904d18 100644
--- a/test/Shared/ClientTestHelpers.cs
+++ b/test/Shared/ClientTestHelpers.cs
@@ -17,6 +17,7 @@
 #endregion
 
 using System;
+using System.Collections.Generic;
 using System.IO;
 using System.Net;
 using System.Net.Http;
@@ -26,6 +27,7 @@
 using Google.Protobuf;
 using Greet;
 using Grpc.Core;
+using Grpc.Net.Client.Configuration;
 using Grpc.Net.Compression;
 
 namespace Grpc.Tests.Shared
@@ -35,7 +37,12 @@ internal static class ClientTestHelpers
         public static readonly Marshaller<HelloRequest> HelloRequestMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloRequest.Parser.ParseFrom(data));
         public static readonly Marshaller<HelloReply> HelloReplyMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloReply.Parser.ParseFrom(data));
 
-        public static readonly Method<HelloRequest, HelloReply> ServiceMethod = new Method<HelloRequest, HelloReply>(MethodType.Unary, "ServiceName", "MethodName", HelloRequestMarshaller, HelloReplyMarshaller);
+        public static readonly Method<HelloRequest, HelloReply> ServiceMethod = GetServiceMethod(MethodType.Unary);
+
+        public static Method<HelloRequest, HelloReply> GetServiceMethod(MethodType? methodType = null, Marshaller<HelloRequest>? requestMarshaller = null)
+        {
+            return new Method<HelloRequest, HelloReply>(methodType ?? MethodType.Unary, "ServiceName", "MethodName", requestMarshaller ?? HelloRequestMarshaller, HelloReplyMarshaller);
+        }
 
         public static TestHttpMessageHandler CreateTestMessageHandler(HelloReply reply)
         {
diff --git a/test/Shared/ExceptionAssert.cs b/test/Shared/ExceptionAssert.cs
index 0317c2908..0a0684e12 100644
--- a/test/Shared/ExceptionAssert.cs
+++ b/test/Shared/ExceptionAssert.cs
@@ -27,6 +27,11 @@ public static class ExceptionAssert
         public static async Task<TException> ThrowsAsync<TException>(Func<Task> action, params string[] possibleMessages)
             where TException : Exception
         {
+            if (action == null)
+            {
+                throw new ArgumentNullException(nameof(action));
+            }
+
             try
             {
                 await action();
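ServiceConfigHelpers below is how the unit and functional tests construct retry and hedging configs in one line. The hedging shape it builds corresponds to this application-level configuration; again a sketch, where the address, delay, and attempt count are illustrative:

    using System;
    using Grpc.Core;
    using Grpc.Net.Client;
    using Grpc.Net.Client.Configuration;

    // Hedging sends up to MaxAttempts attempts in parallel, spaced by HedgingDelay;
    // a response with a status outside NonFatalStatusCodes commits the call.
    var hedgingMethodConfig = new MethodConfig
    {
        Names = { MethodName.Default },
        HedgingPolicy = new HedgingPolicy
        {
            MaxAttempts = 5,
            HedgingDelay = TimeSpan.FromSeconds(10),
            NonFatalStatusCodes = { StatusCode.Unavailable }
        }
    };

    using var channel = GrpcChannel.ForAddress("https://localhost:5001", new GrpcChannelOptions
    {
        ServiceConfig = new ServiceConfig { MethodConfigs = { hedgingMethodConfig } }
    });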
diff --git a/test/Shared/ServiceConfigHelpers.cs b/test/Shared/ServiceConfigHelpers.cs
new file mode 100644
index 000000000..0dd67aae1
--- /dev/null
+++ b/test/Shared/ServiceConfigHelpers.cs
@@ -0,0 +1,108 @@
+#region Copyright notice and license
+
+// Copyright 2019 The gRPC Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using Grpc.Core;
+using Grpc.Net.Client.Configuration;
+
+namespace Grpc.Tests.Shared
+{
+    internal static class ServiceConfigHelpers
+    {
+        public static ServiceConfig CreateRetryServiceConfig(
+            int? maxAttempts = null,
+            TimeSpan? initialBackoff = null,
+            TimeSpan? maxBackoff = null,
+            double? backoffMultiplier = null,
+            IList<StatusCode>? retryableStatusCodes = null,
+            RetryThrottlingPolicy? retryThrottling = null)
+        {
+            var retryPolicy = new RetryPolicy
+            {
+                MaxAttempts = maxAttempts ?? 5,
+                InitialBackoff = initialBackoff ?? TimeSpan.FromMilliseconds(1),
+                MaxBackoff = maxBackoff ?? TimeSpan.FromMilliseconds(1),
+                BackoffMultiplier = backoffMultiplier ?? 1
+            };
+
+            if (retryableStatusCodes != null)
+            {
+                foreach (var statusCode in retryableStatusCodes)
+                {
+                    retryPolicy.RetryableStatusCodes.Add(statusCode);
+                }
+            }
+            else
+            {
+                retryPolicy.RetryableStatusCodes.Add(StatusCode.Unavailable);
+            }
+
+            return new ServiceConfig
+            {
+                MethodConfigs =
+                {
+                    new MethodConfig
+                    {
+                        Names = { MethodName.Default },
+                        RetryPolicy = retryPolicy
+                    }
+                },
+                RetryThrottling = retryThrottling
+            };
+        }
+
+        public static ServiceConfig CreateHedgingServiceConfig(
+            int? maxAttempts = null,
+            TimeSpan? hedgingDelay = null,
+            IList<StatusCode>? nonFatalStatusCodes = null,
+            RetryThrottlingPolicy? retryThrottling = null)
+        {
+            var hedgingPolicy = new HedgingPolicy
+            {
+                MaxAttempts = maxAttempts ?? 5,
+                HedgingDelay = hedgingDelay ?? TimeSpan.Zero
+            };
+
+            if (nonFatalStatusCodes != null)
+            {
+                foreach (var statusCode in nonFatalStatusCodes)
+                {
+                    hedgingPolicy.NonFatalStatusCodes.Add(statusCode);
+                }
+            }
+            else
+            {
+                hedgingPolicy.NonFatalStatusCodes.Add(StatusCode.Unavailable);
+            }
+
+            return new ServiceConfig
+            {
+                MethodConfigs =
+                {
+                    new MethodConfig
+                    {
+                        Names = { MethodName.Default },
+                        HedgingPolicy = hedgingPolicy
+                    }
+                },
+                RetryThrottling = retryThrottling
+            };
+        }
+    }
+}
diff --git a/test/Shared/TestHelpers.cs b/test/Shared/TestHelpers.cs
index a80f6bcbb..bab8b0803 100644
--- a/test/Shared/TestHelpers.cs
+++ b/test/Shared/TestHelpers.cs
@@ -33,9 +33,9 @@ public static string ResolvePath(string relativePath)
 
         public static async Task AssertIsTrueRetryAsync(Func<bool> assert, string message)
         {
-            const int Retrys = 10;
+            const int Retries = 10;
 
-            for (int i = 0; i < Retrys; i++)
+            for (var i = 0; i < Retries; i++)
             {
                 if (i > 0)
                 {
@@ -48,13 +48,13 @@ public static async Task AssertIsTrueRetryAsync(Func<bool> assert, string messag
                 }
             }
 
-            throw new Exception($"Assert failed after {Retrys} retries: {message}");
+            throw new Exception($"Assert failed after {Retries} retries: {message}");
         }
 
         public static async Task RunParallel(int count, Func<int, Task> action)
         {
             var actionTasks = new Task[count];
-            for (int i = 0; i < actionTasks.Length; i++)
+            for (var i = 0; i < actionTasks.Length; i++)
             {
                 actionTasks[i] = action(i);
             }
diff --git a/testassets/InteropTestsWebsite/TestServiceImpl.cs b/testassets/InteropTestsWebsite/TestServiceImpl.cs
index 86123d4b2..dbe7bddf0 100644
--- a/testassets/InteropTestsWebsite/TestServiceImpl.cs
+++ b/testassets/InteropTestsWebsite/TestServiceImpl.cs
@@ -71,6 +71,7 @@ await requestStream.ForEachAsync(request =>
                 sum += request.Payload.Body.Length;
                 return Task.CompletedTask;
             });
+
             return new StreamingInputCallResponse { AggregatedPayloadSize = sum };
         }
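Both helpers also accept an optional RetryThrottlingPolicy, which the channel uses to suspend retries and hedging when too many recent calls have failed. A sketch of a test passing one through; the token values are illustrative:

    // Per the gRPC retry design: each failed attempt subtracts one token, each
    // success adds TokenRatio, and retries pause while tokens <= MaxTokens / 2.
    var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(
        retryThrottling: new RetryThrottlingPolicy
        {
            MaxTokens = 10,
            TokenRatio = 0.5
        });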