Skip to content

Commit

Permalink
Update message writer and reader to ignore Message info from DI as it…
Browse files Browse the repository at this point in the history
… is overwritten.
  • Loading branch information
marabooy committed Sep 9, 2024
1 parent 8f63289 commit 78fd2d3
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 38 deletions.
28 changes: 11 additions & 17 deletions src/Microsoft.OData.Core/ODataMessageReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -904,24 +904,18 @@ private ODataMessageInfo GetOrCreateMessageInfo(Stream messageStream, bool isAsy
{
if (this.messageInfo == null)
{
if (this.serviceProvider == null)
this.messageInfo = new ODataMessageInfo
{
this.messageInfo = new ODataMessageInfo();
}
else
{
this.messageInfo = this.serviceProvider.GetRequiredService<ODataMessageInfo>();
}

this.messageInfo.Encoding = this.encoding;
this.messageInfo.IsResponse = this.readingResponse;
this.messageInfo.IsAsync = isAsync;
this.messageInfo.MediaType = this.contentType;
this.messageInfo.Model = this.model;
this.messageInfo.PayloadUriConverter = this.payloadUriConverter;
this.messageInfo.ServiceProvider = this.serviceProvider;
this.messageInfo.MessageStream = messageStream;
this.messageInfo.PayloadKind = this.readerPayloadKind;
Encoding = this.encoding,
IsResponse = this.readingResponse,
IsAsync = isAsync,
MediaType = this.contentType,
Model = this.model,
PayloadUriConverter = this.payloadUriConverter,
ServiceProvider = this.serviceProvider,
MessageStream = messageStream,
PayloadKind = this.readerPayloadKind
};
}

return this.messageInfo;
Expand Down
37 changes: 16 additions & 21 deletions src/Microsoft.OData.Core/ODataMessageWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace Microsoft.OData
using Microsoft.OData.Metadata;
#endregion Namespaces

/// <summary>
/// Writer class used to write all OData payloads (entries, resource sets, metadata documents, service documents, etc.).
/// </summary>
/// <summary>
/// Writer class used to write all OData payloads (entries, resource sets, metadata documents, service documents, etc.).
/// </summary>
#if NETCOREAPP
public sealed class ODataMessageWriter : IDisposable, IAsyncDisposable
#else
Expand Down Expand Up @@ -1111,7 +1111,7 @@ private ODataPayloadKind VerifyCanWriteValue(object value)
// We cannot use ODataRawValueUtils.TryConvertPrimitiveToString for all cases since binary values are
// converted into unencoded byte streams in the raw format
// (as opposed to base64 encoded byte streams in the ODataRawValueUtils); see OIPI 2.2.6.4.1.
return value is byte[] ? ODataPayloadKind.BinaryValue : ODataPayloadKind.Value;
return value is byte[]? ODataPayloadKind.BinaryValue : ODataPayloadKind.Value;
}

/// <summary>
Expand Down Expand Up @@ -1227,23 +1227,18 @@ private ODataMessageInfo GetOrCreateMessageInfo(Stream messageStream, bool isAsy
{
if (this.messageInfo == null)
{
if (this.serviceProvider == null)
{
this.messageInfo = new ODataMessageInfo();
}
else
{
this.messageInfo = this.serviceProvider.GetRequiredService<ODataMessageInfo>();
}

this.messageInfo.Encoding = this.encoding;
this.messageInfo.IsResponse = this.writingResponse;
this.messageInfo.IsAsync = isAsync;
this.messageInfo.MediaType = this.mediaType;
this.messageInfo.Model = this.model;
this.messageInfo.PayloadUriConverter = this.payloadUriConverter;
this.messageInfo.ServiceProvider = this.serviceProvider;
this.messageInfo.MessageStream = messageStream;
this.messageInfo = new ODataMessageInfo
{
Encoding = this.encoding,
IsResponse = this.writingResponse,
IsAsync = isAsync,
MediaType = this.mediaType,
Model = this.model,
PayloadUriConverter = this.payloadUriConverter,
ServiceProvider = this.serviceProvider,
MessageStream = messageStream
};
}

return this.messageInfo;
Expand Down Expand Up @@ -1321,7 +1316,7 @@ private async Task<TResult> WriteToOutputAsync<TResult>(ODataPayloadKind payload
this.outputContext = await this.format.CreateOutputContextAsync(
this.GetOrCreateMessageInfo(messageStream, true),
this.settings).ConfigureAwait(false);

return await writeFunc(this.outputContext).ConfigureAwait(false);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//---------------------------------------------------------------------
// <copyright file="MessageWriterConcurrencyTests.cs" company="Microsoft">
// Copyright (C) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information.
// </copyright>
//---------------------------------------------------------------------

using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using System;
using Xunit;
using Microsoft.Extensions.DependencyInjection;
using System.Linq;

namespace Microsoft.OData.Core.Tests
{
public class MessageWriterConcurrencyTests
{
/// <summary>
/// Verifies that concurrent message writer does not interleave execution and isolates the underlying streams.
/// </summary>
/// <returns>A task for the asyncronous test</returns>

[Fact]
public async Task VerifyConcurrentResultsAreConsistentAsync()
{
ServiceCollection services = new();
services.AddDefaultODataServices();
ServiceProvider serviceProvider = services.BuildServiceProvider();

await Task.CompletedTask;
var content1 = string.Concat(Enumerable.Repeat('A', 1000_000));
var content2 = string.Concat(Enumerable.Repeat('B', 1000_000));
for (int i = 0; i < 1000; i++)
{
var values = await Task.WhenAll([WritePayload(content1, serviceProvider), WritePayload(content2, serviceProvider)]);
Assert.Equal(content1.Length, values[0].Length);
Assert.Equal(content2.Length, values[1].Length);

Assert.Equal(content1, values[0]);
Assert.Equal(content2, values[1]);
}
}


/// <summary>
/// A helper function that writes to a strem using the message writer and returns the content that is present in the stream.
/// </summary>
/// <param name="content">String content to write.</param>
/// <param name="serviceProvider">A service provider with the default configurations.</param>
/// <returns>A task that resolves to the string present in the output stream.</returns>
private async Task<string> WritePayload(string content, IServiceProvider serviceProvider)
{
using Stream outputStream = new MemoryStream();

var message = new ODataMessage(outputStream, serviceProvider);
await using ODataMessageWriter writer = new ODataMessageWriter(message);
await Task.Yield();

await writer.WriteValueAsync(content);

outputStream.Position = 0;
using var reader = new StreamReader(outputStream);
await Task.Yield();
string writen = await reader.ReadToEndAsync();
await writer.DisposeAsync();
return writen;
}


class ODataMessage : IODataResponseMessage, IODataResponseMessageAsync, IServiceCollectionProvider
{
private Dictionary<string, string> _headers = new();
private Stream _outputStream;
public ODataMessage(Stream outputStream, IServiceProvider serviceProvider)
{
this.ServiceProvider = serviceProvider;
_outputStream = outputStream;
}
public IEnumerable<KeyValuePair<string, string>> Headers => _headers;

public int StatusCode { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public IServiceProvider ServiceProvider { get; private set; }

public string GetHeader(string headerName)
{
if (_headers.TryGetValue(headerName, out var value))
{
return value;
}

return null;
}

public Stream GetStream()
{
return _outputStream;
}

public Task<Stream> GetStreamAsync()
{
return Task.FromResult(_outputStream);
}

public void SetHeader(string headerName, string headerValue)
{
_headers[headerName] = headerValue;
}
}
}
}

0 comments on commit 78fd2d3

Please sign in to comment.