Skip to content

Commit

Permalink
Remove class constraint from Interlocked.{Compare}Exchange
Browse files Browse the repository at this point in the history
Today `Interlocked.CompareExchange<T>` and `Interlocked.Exchange<T>` support only reference type `T`s. Now that we have corresponding {Compare}Exchange methods that support types of size 1, 2, 4, and 8, we can remove the constraint and support any `T` that's either a reference type, a primitive type, or an enum type, making the generic overload more useful and avoiding consumers needing to choose less-than-ideal types just because of the need for atomicity with Interlocked.{Compare}Exchange.
  • Loading branch information
stephentoub committed Jul 8, 2024
1 parent 670d11f commit a58f945
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 139 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -102,21 +103,6 @@ public static long Exchange(ref long location1, long value)
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.InternalCall)]
private static extern object? ExchangeObject([NotNullIfNotNull(nameof(value))] ref object? location1, object? value);

// The below whole method reduces to a single call to Exchange(ref object, object) but
// the JIT thinks that it will generate more native code than it actually does.

/// <summary>Sets a variable of the specified type <typeparamref name="T"/> to a specified value and returns the original value, as an atomic operation.</summary>
/// <param name="location1">The variable to set to the specified value.</param>
/// <param name="value">The value to which the <paramref name="location1"/> parameter is set.</param>
/// <returns>The original value of <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of location1 is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/> and <paramref name="value"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class? =>
Unsafe.As<T>(Exchange(ref Unsafe.As<T, object?>(ref location1), value));
#endregion

#region CompareExchange
Expand Down Expand Up @@ -183,29 +169,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
[MethodImpl(MethodImplOptions.InternalCall)]
[return: NotNullIfNotNull(nameof(location1))]
private static extern object? CompareExchangeObject(ref object? location1, object? value, object? comparand);

// Note that getILIntrinsicImplementationForInterlocked() in vm\jitinterface.cpp replaces
// the body of the following method with the following IL:
// ldarg.0
// ldarg.1
// ldarg.2
// call System.Threading.Interlocked::CompareExchange(ref Object, Object, Object)
// ret
// The workaround is no longer strictly necessary now that we have Unsafe.As but it does
// have the advantage of being less sensitive to JIT's inliner decisions.

/// <summary>Compares two instances of the specified reference type <typeparamref name="T"/> for reference equality and, if they are equal, replaces the first one.</summary>
/// <param name="location1">The destination, whose value is compared by reference with <paramref name="comparand"/> and possibly replaced.</param>
/// <param name="value">The value that replaces the destination value if the comparison by reference results in equality.</param>
/// <param name="comparand">The object that is compared by reference to the value at <paramref name="location1"/>.</param>
/// <returns>The original value in <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of <paramref name="location1"/> is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/>, <paramref name="value"/>, and <paramref name="comparand"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class? =>
Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
#endregion

#region Add
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
#endif
}

[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class?
{
return Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down Expand Up @@ -92,16 +84,6 @@ public static long Exchange(ref long location1, long value)
#endif
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class?
{
if (Unsafe.IsNullRef(ref location1))
ThrowHelper.ThrowNullReferenceException();
return Unsafe.As<T>(RuntimeImports.InterlockedExchange(ref Unsafe.As<T, object?>(ref location1), value));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down
44 changes: 0 additions & 44 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7239,46 +7239,6 @@ bool getILIntrinsicImplementationForUnsafe(MethodDesc * ftn,
return false;
}

bool getILIntrinsicImplementationForInterlocked(MethodDesc * ftn,
CORINFO_METHOD_INFO * methInfo)
{
STANDARD_VM_CONTRACT;

_ASSERTE(CoreLibBinder::IsClass(ftn->GetMethodTable(), CLASS__INTERLOCKED));

// We are only interested if ftn's token and CompareExchange<T> token match
if (ftn->GetMemberDef() != CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_T)->GetMemberDef())
return false;

// Get MethodDesc for non-generic System.Threading.Interlocked.CompareExchange()
MethodDesc* cmpxchgObject = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_OBJECT);

// Setup up the body of the method
static BYTE il[] = {
CEE_LDARG_0,
CEE_LDARG_1,
CEE_LDARG_2,
CEE_CALL,0,0,0,0,
CEE_RET
};

// Get the token for non-generic System.Threading.Interlocked.CompareExchange(), and patch [target]
mdMethodDef cmpxchgObjectToken = cmpxchgObject->GetMemberDef();
il[4] = (BYTE)((int)cmpxchgObjectToken >> 0);
il[5] = (BYTE)((int)cmpxchgObjectToken >> 8);
il[6] = (BYTE)((int)cmpxchgObjectToken >> 16);
il[7] = (BYTE)((int)cmpxchgObjectToken >> 24);

// Initialize methInfo
methInfo->ILCode = const_cast<BYTE*>(il);
methInfo->ILCodeSize = sizeof(il);
methInfo->maxStack = 3;
methInfo->EHcount = 0;
methInfo->options = (CorInfoOptions)0;

return true;
}

bool IsBitwiseEquatable(TypeHandle typeHandle, MethodTable * methodTable)
{
if (!methodTable->IsValueType() ||
Expand Down Expand Up @@ -7628,10 +7588,6 @@ static void getMethodInfoHelper(
{
fILIntrinsic = getILIntrinsicImplementationForUnsafe(ftn, methInfo);
}
else if (CoreLibBinder::IsClass(pMT, CLASS__INTERLOCKED))
{
fILIntrinsic = getILIntrinsicImplementationForInterlocked(ftn, methInfo);
}
else if (CoreLibBinder::IsClass(pMT, CLASS__RUNTIME_HELPERS))
{
fILIntrinsic = getILIntrinsicImplementationForRuntimeHelpers(ftn, methInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4322,6 +4322,9 @@
<data name="NotSupported_EmitDebugInfo" xml:space="preserve">
<value>Emitting debug info is not supported for this member.</value>
</data>
<data name="NotSupported_ReferenceEnumOrPrimitiveTypeRequired" xml:space="preserve">
<value>The specified type must be a reference type, an enum type, or a primitive type.</value>
</data>
<data name="Argument_BadFieldForInitializeArray" xml:space="preserve">
<value>The field is invalid for initializing array or span.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Threading
{
Expand Down Expand Up @@ -222,6 +224,65 @@ public static UIntPtr Exchange(ref UIntPtr location1, UIntPtr value)
return (UIntPtr)Exchange(ref Unsafe.As<UIntPtr, int>(ref location1), (int)value);
#endif
}

/// <summary>Sets a variable of the specified type <typeparamref name="T"/> to a specified value and returns the original value, as an atomic operation.</summary>
/// <param name="location1">The variable to set to the specified value.</param>
/// <param name="value">The value to which the <paramref name="location1"/> parameter is set.</param>
/// <returns>The original value of <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of location1 is a null pointer.</exception>
/// <exception cref="NotSupportedException">An unsupported <typeparamref name="T"/> is specified.</exception>
/// <typeparam name="T">
/// The type to be used for <paramref name="location1"/> and <paramref name="value"/>.
/// This type must be a reference type, an enum type (i.e. typeof(T).IsEnum is true), or a primitive type (i.e. typeof(T).IsPrimitive is true).
/// </typeparam>
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value)
{
// Handle all reference types with CompareExchange(ref object, ...).
if (!typeof(T).IsValueType)
{
object? result = Exchange(ref Unsafe.As<T, object?>(ref location1), value);
return Unsafe.As<object?, T>(ref result);
}

// Handle everything else with a CompareExchange overload for the unsigned integral type of the corresponding size.
// Only primitive types and enum types (which are backed by primitive types) are supported.
if (!typeof(T).IsPrimitive && !typeof(T).IsEnum)
{
throw new NotSupportedException(SR.NotSupported_ReferenceEnumOrPrimitiveTypeRequired);
}

if (Unsafe.SizeOf<T>() == 1)
{
return Unsafe.BitCast<byte, T>(
Exchange(
ref Unsafe.As<T, byte>(ref location1),
Unsafe.BitCast<T, byte>(value)));
}

if (Unsafe.SizeOf<T>() == 2)
{
return Unsafe.BitCast<ushort, T>(
Exchange(
ref Unsafe.As<T, ushort>(ref location1),
Unsafe.BitCast<T, ushort>(value)));
}

if (Unsafe.SizeOf<T>() == 4)
{
return Unsafe.BitCast<uint, T>(
Exchange(
ref Unsafe.As<T, uint>(ref location1),
Unsafe.BitCast<T, uint>(value)));
}

Debug.Assert(Unsafe.SizeOf<T>() == 8);
return Unsafe.BitCast<ulong, T>(
Exchange(
ref Unsafe.As<T, ulong>(ref location1),
Unsafe.BitCast<T, ulong>(value)));
}
#endregion

#region CompareExchange
Expand Down Expand Up @@ -413,6 +474,70 @@ public static UIntPtr CompareExchange(ref UIntPtr location1, UIntPtr value, UInt
return (UIntPtr)CompareExchange(ref Unsafe.As<UIntPtr, int>(ref location1), (int)value, (int)comparand);
#endif
}

/// <summary>Compares two instances of the specified type <typeparamref name="T"/> for equality and, if they are equal, replaces the first one.</summary>
/// <param name="location1">The destination, whose value is compared with <paramref name="comparand"/> and possibly replaced.</param>
/// <param name="value">The value that replaces the destination value if the comparison results in equality.</param>
/// <param name="comparand">The object that is compared to the value at <paramref name="location1"/>.</param>
/// <returns>The original value in <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of <paramref name="location1"/> is a null pointer.</exception>
/// <exception cref="NotSupportedException">An unsupported <typeparamref name="T"/> is specified.</exception>
/// <typeparam name="T">
/// The type to be used for <paramref name="location1"/>, <paramref name="value"/>, and <paramref name="comparand"/>.
/// This type must be a reference type, an enum type (i.e. typeof(T).IsEnum is true), or a primitive type (i.e. typeof(T).IsPrimitive is true).
/// </typeparam>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T CompareExchange<T>(ref T location1, T value, T comparand)
{
// Handle all reference types with CompareExchange(ref object, ...).
if (!typeof(T).IsValueType)
{
object? result = CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand);
return Unsafe.As<object?, T>(ref result);
}

// Handle everything else with a CompareExchange overload for the unsigned integral type of the corresponding size.
// Only primitive types and enum types (which are backed by primitive types) are supported.
if (!typeof(T).IsPrimitive && !typeof(T).IsEnum)
{
throw new NotSupportedException(SR.NotSupported_ReferenceEnumOrPrimitiveTypeRequired);
}

if (Unsafe.SizeOf<T>() == 1)
{
return Unsafe.BitCast<byte, T>(
CompareExchange(
ref Unsafe.As<T, byte>(ref location1),
Unsafe.BitCast<T, byte>(value),
Unsafe.BitCast<T, byte>(comparand)));
}

if (Unsafe.SizeOf<T>() == 2)
{
return Unsafe.BitCast<ushort, T>(
CompareExchange(
ref Unsafe.As<T, ushort>(ref location1),
Unsafe.BitCast<T, ushort>(value),
Unsafe.BitCast<T, ushort>(comparand)));
}

if (Unsafe.SizeOf<T>() == 4)
{
return Unsafe.BitCast<uint, T>(
CompareExchange(
ref Unsafe.As<T, uint>(ref location1),
Unsafe.BitCast<T, uint>(value),
Unsafe.BitCast<T, uint>(comparand)));
}

Debug.Assert(Unsafe.SizeOf<T>() == 8);
return Unsafe.BitCast<ulong, T>(
CompareExchange(
ref Unsafe.As<T, ulong>(ref location1),
Unsafe.BitCast<T, ulong>(value),
Unsafe.BitCast<T, ulong>(comparand)));
}
#endregion

#region Add
Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Threading/ref/System.Threading.cs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public static partial class Interlocked
[System.CLSCompliantAttribute(false)]
public static ulong CompareExchange(ref ulong location1, ulong value, ulong comparand) { throw null; }
[return: System.Diagnostics.CodeAnalysis.NotNullIfNotNullAttribute("location1")]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class? { throw null; }
public static T CompareExchange<T>(ref T location1, T value, T comparand) { throw null; }
public static int Decrement(ref int location) { throw null; }
public static long Decrement(ref long location) { throw null; }
[System.CLSCompliantAttribute(false)]
Expand All @@ -296,7 +296,7 @@ public static partial class Interlocked
[System.CLSCompliantAttribute(false)]
public static ulong Exchange(ref ulong location1, ulong value) { throw null; }
[return: System.Diagnostics.CodeAnalysis.NotNullIfNotNullAttribute("location1")]
public static T Exchange<T>([System.Diagnostics.CodeAnalysis.NotNullIfNotNullAttribute("value")] ref T location1, T value) where T : class? { throw null; }
public static T Exchange<T>([System.Diagnostics.CodeAnalysis.NotNullIfNotNullAttribute("value")] ref T location1, T value) { throw null; }
public static int Increment(ref int location) { throw null; }
public static long Increment(ref long location) { throw null; }
[System.CLSCompliantAttribute(false)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,47 +72,10 @@ public static partial class Interlocked
[MethodImplAttribute(MethodImplOptions.InternalCall)]
public static extern long CompareExchange(ref long location1, long value, long comparand);

[return: NotNullIfNotNull(nameof(location1))]
[Intrinsic]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class?
{
if (Unsafe.IsNullRef(ref location1))
throw new NullReferenceException();
// Besides avoiding coop handles for efficiency,
// and correctness, this also appears needed to
// avoid an assertion failure in the runtime, related to
// coop handles over generics.
//
// See CompareExchange(object) for comments.
//
// This is not entirely convincing due to lack of volatile.
//
T? result = null;
// T : class so call the object overload.
CompareExchange(ref Unsafe.As<T, object?>(ref location1), ref Unsafe.As<T, object?>(ref value), ref Unsafe.As<T, object?>(ref comparand), ref Unsafe.As<T, object?>(ref result!));
return result;
}

[Intrinsic]
[MethodImplAttribute(MethodImplOptions.InternalCall)]
public static extern long Exchange(ref long location1, long value);

[return: NotNullIfNotNull(nameof(location1))]
[Intrinsic]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class?
{
if (Unsafe.IsNullRef(ref location1))
throw new NullReferenceException();
// See CompareExchange(T) for comments.
//
// This is not entirely convincing due to lack of volatile.
//
T? result = null;
// T : class so call the object overload.
Exchange(ref Unsafe.As<T, object?>(ref location1), ref Unsafe.As<T, object?>(ref value), ref Unsafe.As<T, object?>(ref result!));
return result;
}

[Intrinsic]
[MethodImplAttribute(MethodImplOptions.InternalCall)]
public static extern long Read(ref long location);
Expand Down

0 comments on commit a58f945

Please sign in to comment.