diff --git a/src/Microsoft.OData.Core/ODataMessageReader.cs b/src/Microsoft.OData.Core/ODataMessageReader.cs index 5b66f67bdb..00338fa61c 100644 --- a/src/Microsoft.OData.Core/ODataMessageReader.cs +++ b/src/Microsoft.OData.Core/ODataMessageReader.cs @@ -904,24 +904,18 @@ private ODataMessageInfo GetOrCreateMessageInfo(Stream messageStream, bool isAsy { if (this.messageInfo == null) { - if (this.serviceProvider == null) + this.messageInfo = new ODataMessageInfo { - this.messageInfo = new ODataMessageInfo(); - } - else - { - this.messageInfo = this.serviceProvider.GetRequiredService(); - } - - this.messageInfo.Encoding = this.encoding; - this.messageInfo.IsResponse = this.readingResponse; - this.messageInfo.IsAsync = isAsync; - this.messageInfo.MediaType = this.contentType; - this.messageInfo.Model = this.model; - this.messageInfo.PayloadUriConverter = this.payloadUriConverter; - this.messageInfo.ServiceProvider = this.serviceProvider; - this.messageInfo.MessageStream = messageStream; - this.messageInfo.PayloadKind = this.readerPayloadKind; + Encoding = this.encoding, + IsResponse = this.readingResponse, + IsAsync = isAsync, + MediaType = this.contentType, + Model = this.model, + PayloadUriConverter = this.payloadUriConverter, + ServiceProvider = this.serviceProvider, + MessageStream = messageStream, + PayloadKind = this.readerPayloadKind + }; } return this.messageInfo; diff --git a/src/Microsoft.OData.Core/ODataMessageWriter.cs b/src/Microsoft.OData.Core/ODataMessageWriter.cs index aa3ade2f22..fef522f3a2 100644 --- a/src/Microsoft.OData.Core/ODataMessageWriter.cs +++ b/src/Microsoft.OData.Core/ODataMessageWriter.cs @@ -19,9 +19,9 @@ namespace Microsoft.OData using Microsoft.OData.Metadata; #endregion Namespaces -/// -/// Writer class used to write all OData payloads (entries, resource sets, metadata documents, service documents, etc.). -/// + /// + /// Writer class used to write all OData payloads (entries, resource sets, metadata documents, service documents, etc.). + /// #if NETCOREAPP public sealed class ODataMessageWriter : IDisposable, IAsyncDisposable #else @@ -1111,7 +1111,7 @@ private ODataPayloadKind VerifyCanWriteValue(object value) // We cannot use ODataRawValueUtils.TryConvertPrimitiveToString for all cases since binary values are // converted into unencoded byte streams in the raw format // (as opposed to base64 encoded byte streams in the ODataRawValueUtils); see OIPI 2.2.6.4.1. - return value is byte[] ? ODataPayloadKind.BinaryValue : ODataPayloadKind.Value; + return value is byte[]? ODataPayloadKind.BinaryValue : ODataPayloadKind.Value; } /// @@ -1227,23 +1227,18 @@ private ODataMessageInfo GetOrCreateMessageInfo(Stream messageStream, bool isAsy { if (this.messageInfo == null) { - if (this.serviceProvider == null) - { - this.messageInfo = new ODataMessageInfo(); - } - else - { - this.messageInfo = this.serviceProvider.GetRequiredService(); - } - this.messageInfo.Encoding = this.encoding; - this.messageInfo.IsResponse = this.writingResponse; - this.messageInfo.IsAsync = isAsync; - this.messageInfo.MediaType = this.mediaType; - this.messageInfo.Model = this.model; - this.messageInfo.PayloadUriConverter = this.payloadUriConverter; - this.messageInfo.ServiceProvider = this.serviceProvider; - this.messageInfo.MessageStream = messageStream; + this.messageInfo = new ODataMessageInfo + { + Encoding = this.encoding, + IsResponse = this.writingResponse, + IsAsync = isAsync, + MediaType = this.mediaType, + Model = this.model, + PayloadUriConverter = this.payloadUriConverter, + ServiceProvider = this.serviceProvider, + MessageStream = messageStream + }; } return this.messageInfo; @@ -1321,7 +1316,7 @@ private async Task WriteToOutputAsync(ODataPayloadKind payload this.outputContext = await this.format.CreateOutputContextAsync( this.GetOrCreateMessageInfo(messageStream, true), this.settings).ConfigureAwait(false); - + return await writeFunc(this.outputContext).ConfigureAwait(false); } } diff --git a/test/FunctionalTests/Microsoft.OData.Core.Tests/MessageWriterConcurrencyTests.cs b/test/FunctionalTests/Microsoft.OData.Core.Tests/MessageWriterConcurrencyTests.cs new file mode 100644 index 0000000000..46483699e9 --- /dev/null +++ b/test/FunctionalTests/Microsoft.OData.Core.Tests/MessageWriterConcurrencyTests.cs @@ -0,0 +1,112 @@ +//--------------------------------------------------------------------- +// +// Copyright (C) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information. +// +//--------------------------------------------------------------------- + +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using System; +using Xunit; +using Microsoft.Extensions.DependencyInjection; +using System.Linq; + +namespace Microsoft.OData.Core.Tests +{ + public class MessageWriterConcurrencyTests + { + /// + /// Verifies that concurrent message writer does not interleave execution and isolates the underlying streams. + /// + /// A task for the asyncronous test + + [Fact] + public async Task VerifyConcurrentResultsAreConsistentAsync() + { + ServiceCollection services = new(); + services.AddDefaultODataServices(); + ServiceProvider serviceProvider = services.BuildServiceProvider(); + + await Task.CompletedTask; + var content1 = string.Concat(Enumerable.Repeat('A', 1000_000)); + var content2 = string.Concat(Enumerable.Repeat('B', 1000_000)); + for (int i = 0; i < 1000; i++) + { + var values = await Task.WhenAll([WritePayload(content1, serviceProvider), WritePayload(content2, serviceProvider)]); + Assert.Equal(content1.Length, values[0].Length); + Assert.Equal(content2.Length, values[1].Length); + + Assert.Equal(content1, values[0]); + Assert.Equal(content2, values[1]); + } + } + + + /// + /// A helper function that writes to a strem using the message writer and returns the content that is present in the stream. + /// + /// String content to write. + /// A service provider with the default configurations. + /// A task that resolves to the string present in the output stream. + private async Task WritePayload(string content, IServiceProvider serviceProvider) + { + using Stream outputStream = new MemoryStream(); + + var message = new ODataMessage(outputStream, serviceProvider); + await using ODataMessageWriter writer = new ODataMessageWriter(message); + await Task.Yield(); + + await writer.WriteValueAsync(content); + + outputStream.Position = 0; + using var reader = new StreamReader(outputStream); + await Task.Yield(); + string writen = await reader.ReadToEndAsync(); + await writer.DisposeAsync(); + return writen; + } + + + class ODataMessage : IODataResponseMessage, IODataResponseMessageAsync, IServiceCollectionProvider + { + private Dictionary _headers = new(); + private Stream _outputStream; + public ODataMessage(Stream outputStream, IServiceProvider serviceProvider) + { + this.ServiceProvider = serviceProvider; + _outputStream = outputStream; + } + public IEnumerable> Headers => _headers; + + public int StatusCode { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public IServiceProvider ServiceProvider { get; private set; } + + public string GetHeader(string headerName) + { + if (_headers.TryGetValue(headerName, out var value)) + { + return value; + } + + return null; + } + + public Stream GetStream() + { + return _outputStream; + } + + public Task GetStreamAsync() + { + return Task.FromResult(_outputStream); + } + + public void SetHeader(string headerName, string headerValue) + { + _headers[headerName] = headerValue; + } + } + } +}