Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in HttpHeaders parsing #103263

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,17 +347,14 @@ private IEnumerator<KeyValuePair<string, IEnumerable<string>>> GetEnumeratorCore
// during enumeration so that we can parse the raw value in order to a) return
// the correct set of parsed values, and b) update the instance for subsequent enumerations
// to reflect that parsing.
info = new HeaderStoreItemInfo() { RawValue = entry.Value };

if (EntriesAreLiveView)
{
entries[i].Value = info;
}
else
{
Debug.Assert(Contains(entry.Key));
((Dictionary<HeaderDescriptor, object>)_headerStore!)[entry.Key] = info;
}
#nullable disable // https://github.com/dotnet/roslyn/issues/73928
ref object storeValueRef = ref EntriesAreLiveView
? ref entries[i].Value
: ref CollectionsMarshal.GetValueRefOrNullRef((Dictionary<HeaderDescriptor, object>)_headerStore, entry.Key);

info = ReplaceWithHeaderStoreItemInfo(ref storeValueRef, entry.Value);
#nullable restore
}

// Make sure we parse all raw values before returning the result. Note that this has to be
Expand Down Expand Up @@ -729,15 +726,10 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)]
if (!Unsafe.IsNullRef(ref storeValueRef))
{
object value = storeValueRef;
if (value is HeaderStoreItemInfo hsi)
{
info = hsi;
}
else
{
Debug.Assert(value is string);
storeValueRef = info = new HeaderStoreItemInfo() { RawValue = value };
}

info = value is HeaderStoreItemInfo hsi
? hsi
: ReplaceWithHeaderStoreItemInfo(ref storeValueRef, value);

ParseRawHeaderValues(key, info);
return true;
Expand All @@ -747,6 +739,31 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)]
return false;
}

/// <summary>
/// Replaces <paramref name="storeValueRef"/> with a new <see cref="HeaderStoreItemInfo"/>,
/// or returns the existing <see cref="HeaderStoreItemInfo"/> if a different thread beat us to it.
/// </summary>
/// <remarks>
/// This helper should be used any time we're upgrading a storage slot from an unparsed string to a HeaderStoreItemInfo *while reading*.
/// Concurrent writes to the header collection are UB, so we don't need to worry about race conditions when doing the replacement there.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static HeaderStoreItemInfo ReplaceWithHeaderStoreItemInfo(ref object storeValueRef, object value)
{
Debug.Assert(value is string);

var info = new HeaderStoreItemInfo() { RawValue = value };
object previousValue = Interlocked.CompareExchange(ref storeValueRef, info, value);

if (ReferenceEquals(previousValue, value))
{
return info;
}

// Rare race condition: Another thread replaced the value with a HeaderStoreItemInfo.
return (HeaderStoreItemInfo)previousValue;
}

private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info)
{
// Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Net.Http.Headers;
Expand Down Expand Up @@ -2502,6 +2503,51 @@ static HttpRequestHeaders CreateHeaders()
}
}

[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task ConcurrentReads_ReturnTheSameParsedValues(bool useDictionary, bool useTypedProperty)
{
HttpContentHeaders dummyValues = new ByteArrayContent([]).Headers;
if (useDictionary)
{
for (int i = 0; i < HttpHeaders.ArrayThreshold; i++)
{
Assert.True(dummyValues.TryAddWithoutValidation($"foo-{i}", "Foo"));
}
}

Stopwatch s = Stopwatch.StartNew();

while (s.ElapsedMilliseconds < 100)
{
HttpContentHeaders headers = new ByteArrayContent([]).Headers;

headers.AddHeaders(dummyValues);

Assert.True(headers.TryAddWithoutValidation("Content-Type", "application/json; charset=utf-8"));

if (useTypedProperty)
{
Task<MediaTypeHeaderValue> task = Task.Run(() => headers.ContentType);
MediaTypeHeaderValue contentType1 = headers.ContentType;
MediaTypeHeaderValue contentType2 = await task;

Assert.Same(contentType1, contentType2);
}
else
{
Task task = Task.Run(() => headers.Count()); // Force enumeration
MediaTypeHeaderValue contentType1 = headers.ContentType;
await task;

Assert.Same(contentType1, headers.ContentType);
}
}
}

[Fact]
public void TryAddInvalidHeader_ShouldThrowFormatException()
{
Expand Down
Loading