Expose VectorMask<T> to support generic masking for Vector<T> #74613

Closed
Tracked by #79005
anthonycanino opened this issue Aug 25, 2022 · 8 comments
Labels
area-System.Runtime.Intrinsics avx512 Related to the AVX-512 architecture
Milestone

Comments

@anthonycanino
Contributor

anthonycanino commented Aug 25, 2022

Summary

For each Vector API, we introduce a corresponding VectorMask, which abstracts away low-level bit-masking and lets users express conditional SIMD processing as boolean logic over the Vector APIs. In particular, VectorMask<T> enables masking operations and conditional SIMD processing on the variable-length Vector<T> API, so that Vector<T> code can perform better and come closer to the SIMD processing done with Vector64/Vector128/Vector256.

Please see dotnet/designs#268 and https://github.com/anthonycanino/designs/blob/main/accepted/2022/enable-512-vectors.md#vectormask-usage for detailed discussion behind the rationale for VectorMask, though the APIs that are posted here reflect the most recent discussion on the proposal at dotnet/designs#268.

The API Proposal focuses on Vector128 and Vector with the associated VectorMask128 and VectorMask APIs, but we propose a corresponding VectorMaskX for each VectorX API, e.g., Vector64, Vector256, etc.
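For context, here is a minimal sketch of how conditional SIMD processing looks with the Vector<T> APIs that exist today, where the "mask" is a full-width vector whose lanes are all-ones or all-zeros rather than a compact bitmask. The class and method names are made up for illustration; only Vector.LessThan and Vector.ConditionalSelect are existing APIs, and this is not taken from the proposal.

```csharp
using System;
using System.Numerics;

public static class ConditionalSelectSketch
{
    // Replaces negative entries with zero. The comparison result is a
    // Vector<int> whose lanes are all-ones (true) or all-zeros (false).
    public static void ClampNegativesToZero(Span<int> data)
    {
        int i = 0;
        for (; i <= data.Length - Vector<int>.Count; i += Vector<int>.Count)
        {
            var v = new Vector<int>(data.Slice(i));
            Vector<int> isNegative = Vector.LessThan(v, Vector<int>.Zero);
            // Pick zero where the lane is negative, the original value otherwise.
            Vector<int> result = Vector.ConditionalSelect(isNegative, Vector<int>.Zero, v);
            result.CopyTo(data.Slice(i));
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for (; i < data.Length; i++)
        {
            if (data[i] < 0)
            {
                data[i] = 0;
            }
        }
    }
}
```

A dedicated mask type would presumably let the comparison result be carried as one bit per lane instead of a whole vector, which is what AVX-512 mask registers provide in hardware.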

API Proposal

namespace System.Runtime.Intrinsics
{
    public static partial class Vector64
    {
        public VectorMask64<T> ExtractMask<T>(this Vector64<T> vector);
    }

    public static partial class Vector128
    {
        public VectorMask128<T> ExtractMask<T>(this Vector128<T> vector);
    }

    public static partial class Vector256
    {
        public VectorMask256<T> ExtractMask<T>(this Vector256<T> vector);
    }

    public static partial class Vector512
    {
        public VectorMask512<T> ExtractMask<T>(this Vector512<T> vector);
    }

    public static class VectorMask64
    {
        public bool IsHardwareAccelerated { get; }

        public static VectorMask64<T> Create(ushort mask);

        public static VectorMask64<TTo> As<TFrom, TTo>(this VectorMask64<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask64<byte>   AsByte  <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<double> AsDouble<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<short>  AsInt16 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<int>    AsInt32 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<long>   AsInt64 <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<nint>   AsNInt  <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<nuint>  AsNUInt <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<sbyte>  AsSByte <T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<float>  AsSingle<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<ushort> AsUInt16<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<uint>   AsUInt32<T>(this VectorMask64<T> vector) where T : struct;
        public static VectorMask64<ulong>  AsUInt64<T>(this VectorMask64<T> vector) where T : struct;

        public static VectorMask64<T> BitwiseAnd<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> BitwiseOr<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> AndNot<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> OnesComplement<T>(VectorMask64<T> value);
        public static VectorMask64<T> Xor<T>(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> Xnor<T>(VectorMask64<T> left, VectorMask64<T> right);

        public static VectorMask64<T> ShiftLeft<T>(VectorMask64<T> value, int count);
        public static VectorMask64<T> ShiftRight<T>(VectorMask64<T> value, int count);

        public static bool Equals<T>(VectorMask64<T> left, VectorMask64<T> right);

        public static int LeadingZeroCount(VectorMask64<T> mask);
        public static int TrailingZeroCount(VectorMask64<T> mask);
        public static int PopCount(VectorMask64<T> mask);
        
        public static bool GetElement(this Vector64<T> vector, int index) where T : struct;
        public static Vector64Mask<T> WithElement(this Vector64<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask64<T> where T : struct 
    {
        private readonly byte _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask64<T> AllBitsSet { get; }
        public static VectorMask64<T> Zero { get; }

        public static bool this[int index] { get; }

        public static VectorMask64<T> operator &(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> operator |(VectorMask64<T> left, VectorMask64<T> right);
        public static VectorMask64<T> operator ~(VectorMask64<T> value);
        public static VectorMask64<T> operator ^(VectorMask64<T> left, VectorMask64<T> right);

        public static VectorMask64<T> operator <<(VectorMask64<T> value, int count);
        public static VectorMask64<T> operator >>(VectorMask64<T> value, int count);

        public static bool operator ==(VectorMask64<T> left, VectorMask64<T> right);
        public static bool operator !=(VectorMask64<T> left, VectorMask64<T> right);
    }

    public static class VectorMask128
    {
        public bool IsHardwareAccelerated { get; }

        public static VectorMask128<T> Create(ushort mask);

        public static VectorMask128<TTo> As<TFrom, TTo>(this VectorMask128<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask128<byte>   AsByte  <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<double> AsDouble<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<short>  AsInt16 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<int>    AsInt32 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<long>   AsInt64 <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<nint>   AsNInt  <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<nuint>  AsNUInt <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<sbyte>  AsSByte <T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<float>  AsSingle<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<ushort> AsUInt16<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<uint>   AsUInt32<T>(this VectorMask128<T> vector) where T : struct;
        public static VectorMask128<ulong>  AsUInt64<T>(this VectorMask128<T> vector) where T : struct;

        public static VectorMask128<T> BitwiseAnd<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> BitwiseOr<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> AndNot<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> OnesComplement<T>(VectorMask128<T> value);
        public static VectorMask128<T> Xor<T>(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> Xnor<T>(VectorMask128<T> left, VectorMask128<T> right);

        public static VectorMask128<T> ShiftLeft<T>(VectorMask128<T> value, int count);
        public static VectorMask128<T> ShiftRight<T>(VectorMask128<T> value, int count);

        public static bool Equals<T>(VectorMask128<T> left, VectorMask128<T> right);

        public static int LeadingZeroCount(VectorMask128<T> mask);
        public static int TrailingZeroCount(VectorMask128<T> mask);
        public static int PopCount(VectorMask128<T> mask);
        
        public static bool GetElement(this Vector128<T> vector, int index) where T : struct;
        public static Vector128Mask<T> WithElement(this Vector128<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask128<T> where T : struct 
    {
        private readonly ushort _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask128<T> AllBitsSet { get; }
        public static VectorMask128<T> Zero { get; }

        public static bool this[int index] { get; }

        public static VectorMask128<T> operator &(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> operator |(VectorMask128<T> left, VectorMask128<T> right);
        public static VectorMask128<T> operator ~(VectorMask128<T> value);
        public static VectorMask128<T> operator ^(VectorMask128<T> left, VectorMask128<T> right);

        public static VectorMask128<T> operator <<(VectorMask128<T> value, int count);
        public static VectorMask128<T> operator >>(VectorMask128<T> value, int count);

        public static bool operator ==(VectorMask128<T> left, VectorMask128<T> right);
        public static bool operator !=(VectorMask128<T> left, VectorMask128<T> right);
    }

    public static class VectorMask256
    {
        public bool IsHardwareAccelerated { get; }

        public static VectorMask256<T> Create(ushort mask);

        public static VectorMask256<TTo> As<TFrom, TTo>(this VectorMask256<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask256<byte>   AsByte  <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<double> AsDouble<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<short>  AsInt16 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<int>    AsInt32 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<long>   AsInt64 <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<nint>   AsNInt  <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<nuint>  AsNUInt <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<sbyte>  AsSByte <T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<float>  AsSingle<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<ushort> AsUInt16<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<uint>   AsUInt32<T>(this VectorMask256<T> vector) where T : struct;
        public static VectorMask256<ulong>  AsUInt64<T>(this VectorMask256<T> vector) where T : struct;

        public static VectorMask256<T> BitwiseAnd<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> BitwiseOr<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> AndNot<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> OnesComplement<T>(VectorMask256<T> value);
        public static VectorMask256<T> Xor<T>(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> Xnor<T>(VectorMask256<T> left, VectorMask256<T> right);

        public static VectorMask256<T> ShiftLeft<T>(VectorMask256<T> value, int count);
        public static VectorMask256<T> ShiftRight<T>(VectorMask256<T> value, int count);

        public static bool Equals<T>(VectorMask256<T> left, VectorMask256<T> right);

        public static int LeadingZeroCount(VectorMask256<T> mask);
        public static int TrailingZeroCount(VectorMask256<T> mask);
        public static int PopCount(VectorMask256<T> mask);
        
        public static bool GetElement(this Vector256<T> vector, int index) where T : struct;
        public static Vector256Mask<T> WithElement(this Vector256<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask256<T> where T : struct 
    {
        private readonly uint _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask256<T> AllBitsSet { get; }
        public static VectorMask256<T> Zero { get; }

        public static bool this[int index] { get; }

        public static VectorMask256<T> operator &(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> operator |(VectorMask256<T> left, VectorMask256<T> right);
        public static VectorMask256<T> operator ~(VectorMask256<T> value);
        public static VectorMask256<T> operator ^(VectorMask256<T> left, VectorMask256<T> right);

        public static VectorMask256<T> operator <<(VectorMask256<T> value, int count);
        public static VectorMask256<T> operator >>(VectorMask256<T> value, int count);

        public static bool operator ==(VectorMask256<T> left, VectorMask256<T> right);
        public static bool operator !=(VectorMask256<T> left, VectorMask256<T> right);
    }

    public static class VectorMask512
    {
        public bool IsHardwareAccelerated { get; }

        public static VectorMask512<T> Create(ushort mask);

        public static VectorMask512<TTo> As<TFrom, TTo>(this VectorMask512<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask512<byte>   AsByte  <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<double> AsDouble<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<short>  AsInt16 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<int>    AsInt32 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<long>   AsInt64 <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<nint>   AsNInt  <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<nuint>  AsNUInt <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<sbyte>  AsSByte <T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<float>  AsSingle<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<ushort> AsUInt16<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<uint>   AsUInt32<T>(this VectorMask512<T> vector) where T : struct;
        public static VectorMask512<ulong>  AsUInt64<T>(this VectorMask512<T> vector) where T : struct;

        public static VectorMask512<T> BitwiseAnd<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> BitwiseOr<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> AndNot<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> OnesComplement<T>(VectorMask512<T> value);
        public static VectorMask512<T> Xor<T>(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> Xnor<T>(VectorMask512<T> left, VectorMask512<T> right);

        public static VectorMask512<T> ShiftLeft<T>(VectorMask512<T> value, int count);
        public static VectorMask512<T> ShiftRight<T>(VectorMask512<T> value, int count);

        public static bool Equals<T>(VectorMask512<T> left, VectorMask512<T> right);

        public static int LeadingZeroCount(VectorMask512<T> mask);
        public static int TrailingZeroCount(VectorMask512<T> mask);
        public static int PopCount(VectorMask512<T> mask);
        
        public static bool GetElement(this Vector512<T> vector, int index) where T : struct;
        public static Vector512Mask<T> WithElement(this Vector512<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask512<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask512<T> AllBitsSet { get; }
        public static VectorMask512<T> Zero { get; }

        public static bool this[int index] { get; }

        public static VectorMask512<T> operator &(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> operator |(VectorMask512<T> left, VectorMask512<T> right);
        public static VectorMask512<T> operator ~(VectorMask512<T> value);
        public static VectorMask512<T> operator ^(VectorMask512<T> left, VectorMask512<T> right);

        public static VectorMask512<T> operator <<(VectorMask512<T> value, int count);
        public static VectorMask512<T> operator >>(VectorMask512<T> value, int count);

        public static bool operator ==(VectorMask512<T> left, VectorMask512<T> right);
        public static bool operator !=(VectorMask512<T> left, VectorMask512<T> right);
    }
}

namespace System.Numerics
{
    public static partial class Vector
    {
        public VectorMask<T> ExtractMask<T>(this Vector<T> vector);
    }

    public static class VectorMask
    {
        public bool IsHardwareAccelerated { get; }

        public static VectorMask<T> Create(byte[] value);
        public static VectorMask<T> Create(byte[] value, int index);
        public static VectorMask<T> Create(ReadOnlySpan<byte> value);

        public static VectorMask<TTo> As<TFrom, TTo>(this VectorMask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask<byte>   AsByte  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<double> AsDouble<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<short>  AsInt16 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<int>    AsInt32 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<long>   AsInt64 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nint>   AsNInt  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nuint>  AsNUInt <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<sbyte>  AsSByte <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<float>  AsSingle<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ushort> AsUInt16<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<uint>   AsUInt32<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ulong>  AsUInt64<T>(this VectorMask<T> vector) where T : struct;

        public static VectorMask<T> BitwiseAnd<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> BitwiseOr<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> AndNot<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> OnesComplement<T>(VectorMask<T> value);
        public static VectorMask<T> Xor<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> Xnor<T>(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> ShiftLeft<T>(VectorMask<T> value, int count);
        public static VectorMask<T> ShiftRight<T>(VectorMask<T> value, int count);

        public static bool Equals<T>(VectorMask<T> left, VectorMask<T> right);

        public static int LeadingZeroCount(VectorMask<T> mask);
        public static int TrailingZeroCount(VectorMask<T> mask);
        public static int PopCount(VectorMask<T> mask);
        
        public static bool GetElement(this Vector<T> vector, int index) where T : struct;
        public static VectorMask<T> WithElement(this Vector<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask<T> AllBitsSet { get; }
        public static VectorMask<T> Zero { get; }

        public static bool this[int index] { get; }

        public static VectorMask<T> operator &(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator |(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator ~(VectorMask<T> value);
        public static VectorMask<T> operator ^(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> operator <<(VectorMask<T> value, int count);
        public static VectorMask<T> operator >>(VectorMask<T> value, int count);

        public static bool operator ==(VectorMask<T> left, VectorMask<T> right);
        public static bool operator !=(VectorMask<T> left, VectorMask<T> right);
    }
}
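To make the operator surface above concrete, here is a rough model of what a 128-bit mask over int lanes could look like as a plain bitmask. Int32Mask128Sketch and its members are hypothetical and exist only for this illustration; the sketch assumes one bit per lane with lane 0 in the lowest bit, which the proposal does not spell out. Only BitOperations.PopCount is an existing API.

```csharp
using System;
using System.Numerics;

// Hypothetical model of a 128-bit mask over int lanes: one bit per lane,
// so only the low 4 bits of the backing ushort are meaningful.
public readonly struct Int32Mask128Sketch
{
    public const int Count = 4;
    private const ushort LaneBits = (1 << Count) - 1;
    private readonly ushort _value;

    // Bits above the lane count are cleared so they never affect comparisons.
    public Int32Mask128Sketch(ushort value) => _value = (ushort)(value & LaneBits);

    public ushort Bits => _value;

    // Returns whether the lane at 'index' is selected (false for indexes >= Count).
    public bool this[int index] => ((_value >> index) & 1) != 0;

    public static Int32Mask128Sketch operator &(Int32Mask128Sketch left, Int32Mask128Sketch right)
        => new Int32Mask128Sketch((ushort)(left._value & right._value));

    public static Int32Mask128Sketch operator |(Int32Mask128Sketch left, Int32Mask128Sketch right)
        => new Int32Mask128Sketch((ushort)(left._value | right._value));

    public static Int32Mask128Sketch operator ~(Int32Mask128Sketch value)
        => new Int32Mask128Sketch(unchecked((ushort)~value._value));

    public static int PopCount(Int32Mask128Sketch mask)
        => BitOperations.PopCount((uint)mask._value);
}
```

With masks 0b1001 and 0b0011, `&` yields 0b0001, `|` yields 0b1011, and `~` of the first yields 0b0110 because the constructor drops the bits above the four lanes.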

API Usage

A few points require further discussion:

VectorMask<T> does not have a Create method because, like Vector, its size is unknown. It would be nice to have VectorMask<byte>.Create(0xFF00) or VectorMask<byte>.Create(0xFFFF0000) (the first would cover the case where VectorMask == VectorMask128, the second where VectorMask == VectorMask256), but since VectorMask and Vector are technically variable length, that breaks the abstraction a bit. My proposed alternative is CreateUnsafe, where a boolean array sets each bit, and it is an error if the length of the boolean array != VectorMask.Count.

It would also be good to support a byte array to reduce the user effort a bit, e.g., instead of VectorMask<int>.CreateUnsafe([true, false, false, true]) we would have VectorMask<int>.CreateUnsafe([0x09]). We might then want to relax the constraint a bit and instead say: if (the length of the byte array) * 8 < VectorMask.Count, zero extend; if greater, truncate; etc.
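One possible reading of the two CreateUnsafe rules, sketched as standalone helpers that return the raw mask bits. MaskPackingSketch, FromBools, and FromBytes are hypothetical names; the sketch assumes lane 0 maps to the lowest bit, that bytes are consumed in little-endian order, and that the mask never exceeds 64 lanes.

```csharp
using System;

public static class MaskPackingSketch
{
    // Strict rule: one bool per lane; any other length is an error.
    public static ulong FromBools(ReadOnlySpan<bool> lanes, int laneCount)
    {
        if (lanes.Length != laneCount)
        {
            throw new ArgumentException("Expected exactly one bool per lane.", nameof(lanes));
        }

        ulong bits = 0;
        for (int i = 0; i < lanes.Length; i++)
        {
            if (lanes[i])
            {
                bits |= 1UL << i;
            }
        }
        return bits;
    }

    // Relaxed rule: packed bytes are zero-extended when too short and
    // truncated to laneCount bits when too long.
    public static ulong FromBytes(ReadOnlySpan<byte> packed, int laneCount)
    {
        if (laneCount < 1 || laneCount > 64)
        {
            throw new ArgumentOutOfRangeException(nameof(laneCount));
        }

        ulong bits = 0;
        int usable = Math.Min(packed.Length, 8);
        for (int i = 0; i < usable; i++)
        {
            bits |= (ulong)packed[i] << (8 * i);
        }

        ulong keep = laneCount == 64 ? ulong.MaxValue : (1UL << laneCount) - 1;
        return bits & keep;
    }
}
```

Under these assumptions, [true, false, false, true] and [0x09] both produce 0x09 for a four-lane mask, while [0xFF] is truncated to 0x0F.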

@anthonycanino anthonycanino added the api-suggestion Early API idea and discussion, it is NOT ready for implementation label Aug 25, 2022
@ghost ghost added the untriaged New issue has not been triaged by the area owner label Aug 25, 2022
@ghost

ghost commented Aug 25, 2022

Tagging subscribers to this area: @dotnet/area-system-runtime-intrinsics
See info in area-owners.md if you want to be subscribed.

Issue Details

Background and motivation

For each Vector API, we introduce a corresponding VectorMask, which abstracts away low-level bit-masking and lets users express conditional SIMD processing as boolean logic over the Vector APIs. In particular, VectorMask<T> enables masking operations and conditional SIMD processing on the variable-length Vector<T> API, so that Vector<T> code can perform better and come closer to the SIMD processing done with Vector64/Vector128/Vector256.

Please see dotnet/designs#268 and https://github.com/anthonycanino/designs/blob/main/accepted/2022/enable-512-vectors.md#vectormask-usage for detailed discussion behind the rationale for VectorMask, though the APIs that are posted here reflect the most recent discussion on the proposal at dotnet/designs#268.

The API Proposal focuses on Vector128 and Vector with the associated VectorMask128 and VectorMask APIs, but we propose a corresponding VectorMaskX for each VectorX API, e.g., Vector64, Vector256, etc.

API Proposal

public readonly struct Vector128<T> where T : struct
{
  public VectorMask128<T> AsMask();
}

public static class VectorMask128
{
  public static bool IsHardwareAccelerated;  

  // Create functions, have to align the type with the constant that can be used
  public static unsafe VectorMask128<T> Create(ushort mask);
  
  // The following we definitely want, standard boolean logic operators that are accelerated by AVX512

  public static VectorMask128<T> Add<T>(VectorMask128<T> left, VectorMask128<T> right);
  public static VectorMask128<T> And<T>(VectorMask128<T> left, VectorMask128<T> right);

  public static VectorMask128<T> AndNot<T>(VectorMask128<T> left, VectorMask128<T> right);
  public static VectorMask128<T> Not<T>(VectorMask128<T> other);
  public static VectorMask128<T> Or<T>(VectorMask128<T> left, VectorMask128<T> right);
  public static VectorMask128<T> Xor<T>(VectorMask128<T> left, VectorMask128<T> right);
  public static VectorMask128<T> Xnor<T>(VectorMask128<T> left, VectorMask128<T> right);

  public static VectorMask128<T> ShiftLeft<T>(VectorMask128<T> left, int count);
  public static VectorMask128<T> ShiftRight<T>(VectorMask128<T> left, int count);

  // To be a bit consistent with Vector128 etc, this can return either all bits set or zero

  public static VectorMask128<T> Equals<T>(VectorMask128<T> left, VectorMask128<T> right);

  public static int LeadingZeroCount(VectorMask128<T> mask);
  public static int TrailingZeroCount(VectorMask128<T> mask);
  public static int PopCount(VectorMask128<T> mask);

  public static bool GetCondition(int index);
  public static bool SetCondition(int index, bool cond);

  public static VectorMask128<TTo> As<TFrom, TTo>(this VectorMask128<TFrom> vector)
    where TFrom : struct
    where TTo : struct;
}


public readonly struct VectorMask128<T> where T : struct 
{
  private readonly ushort _01;  // 16 bits is the most needed to mask 16 bytes (vector128 of byte)

  public static VectorMask128<T> Zero;
  public static VectorMask128<T> AllBitsSet;

  // Count = sizeof(Vector128<T>) / sizeof(T), i.e. one mask bit per vector element
  public static int Count;

  public static unsafe VectorMask128<T> operator +(VectorMask128<T> left, VectorMask128<T> right);
  public static unsafe VectorMask128<T> operator &(VectorMask128<T> left, VectorMask128<T> right);
  public static unsafe VectorMask128<T> operator ~(VectorMask128<T> other);
  public static unsafe VectorMask128<T> operator |(VectorMask128<T> left, VectorMask128<T> right);
  public static unsafe VectorMask128<T> operator ^(VectorMask128<T> left, VectorMask128<T> right);

  // Users must use `Count` to determine how much count available to shift by
  public static unsafe VectorMask128<T> operator <<(VectorMask128<T> other, int count);
  public static unsafe VectorMask128<T> operator >>(VectorMask128<T> other, int count);

  // Total equality is consistent with Vector operator ==
  public static unsafe bool operator ==(VectorMask128<T> left, VectorMask128<T> right);
  public static unsafe bool operator !=(VectorMask128<T> left, VectorMask128<T> right);

  public static bool this[int index] => this.GetCondition(index);

}

public readonly struct Vector<T> where T : struct
{
  public VectorMask<T> AsMask();
}

public static class VectorMask
{
  public static bool IsHardwareAccelerated;  

  // The following we definitely want, standard boolean logic operators that are accelerated by AVX512

  public static VectorMask<T> Add<T>(VectorMask<T> left, VectorMask<T> right);
  public static VectorMask<T> And<T>(VectorMask<T> left, VectorMask<T> right);

  public static VectorMask<T> AndNot<T>(VectorMask<T> left, VectorMask<T> right);
  public static VectorMask<T> Not<T>(VectorMask<T> other);
  public static VectorMask<T> Or<T>(VectorMask<T> left, VectorMask<T> right);
  public static VectorMask<T> Xor<T>(VectorMask<T> left, VectorMask<T> right);

  public static VectorMask<T> Xnor<T>(VectorMask<T> left, VectorMask<T> right);

  // The following we want for AVX512 but have to consider how other arch fit

  public static VectorMask<T> ShiftLeft<T>(VectorMask<T> left, int count);
  public static VectorMask<T> ShiftRight<T>(VectorMask<T> left, int count);

  // To be a bit consistent with Vector128 etc, this can return either all bits set or zero

  public static VectorMask<T> Equals<T>(VectorMask<T> left, VectorMask<T> right);

  public static int LeadingZeroCount(VectorMask<T> mask);
  public static int TrailingZeroCount(VectorMask<T> mask);
  public static int PopCount(VectorMask<T> mask);

  public static bool GetCondition(int index);
  public static bool SetCondition(int index, bool cond);

  public static VectorMask<TTo> As<TFrom, TTo>(this VectorMask<TFrom> vector)
    where TFrom : struct
    where TTo : struct;
}

public readonly struct VectorMask<T> where T : struct 
{
  public static VectorMask<T> Zero;
  public static VectorMask<T> AllBitsSet;

  // Count  = Sizeof(VectorMask) / Sizeof(T)
  public static int Count;

  // If elementBits.Length != Count, error
  public static unsafe VectorMask<T> CreateUnsafe(bool[] elementBits);
  // If elementBits.Length != Count, error
  public static unsafe VectorMask<T> CreateUnsafe(byte[] elementBits);

  public static unsafe VectorMask<T> operator +(VectorMask<T> left, VectorMask<T> right);
  public static unsafe VectorMask<T> operator &(VectorMask<T> left, VectorMask<T> right);
  public static unsafe VectorMask<T> operator ~(VectorMask<T> other);
  public static unsafe VectorMask<T> operator |(VectorMask<T> left, VectorMask<T> right);
  public static unsafe VectorMask<T> operator ^(VectorMask<T> left, VectorMask<T> right);

  // Users must use `Count` to determine how much count available to shift by
  public static unsafe VectorMask<T> operator <<(VectorMask<T> other, int count);

  public static unsafe VectorMask<T> operator >>(VectorMask<T> other, int count);

  // Total equality is consistent with Vector operator ==
  public static unsafe bool operator ==(VectorMask<T> left, VectorMask<T> right);

  public static unsafe bool operator !=(VectorMask<T> left, VectorMask<T> right);

  public static bool this[int index] => this.GetCondition(index);
}

API Usage

A few points require further discussion:

VectorMask<T> does not have a Create method because, like Vector, its size is unknown. It would be nice to have VectorMask<byte>.Create(0xFF00) or VectorMask<byte>.Create(0xFFFF0000) (the first would cover the case where VectorMask == VectorMask128, the second where VectorMask == VectorMask256), but since VectorMask and Vector are technically variable length, that breaks the abstraction a bit. My proposed alternative is CreateUnsafe, where a boolean array sets each bit, and it is an error if the length of the boolean array != VectorMask.Count.

It would also be good to support a byte array to reduce the user effort a bit, e.g., instead of VectorMask<int>.CreateUnsafe([true, false, false, true]) we would have VectorMask<int>.CreateUnsafe([0x09]). We might then want to relax the constraint a bit and instead say: if (the length of the byte array) * 8 < VectorMask.Count, zero extend; if greater, truncate; etc.

Alternative Designs

No response

Risks

No response

Author: anthonycanino
Assignees: -
Labels:

api-suggestion, area-System.Runtime.Intrinsics

Milestone: -

@Zintom

Zintom commented Aug 26, 2022

The CreateUnsafe methods on VectorMask<T> should take Span<T> equivalents rather than the array types; this would allow stack allocated arrays to be passed to the constructor, saving unnecessary allocation to the managed heap; if a caller really wanted to use an array, they still can by creating a span over the array and passing that span.

The current proposal forces the caller to make an allocation.
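A small sketch of the difference being described: with a span-based parameter the lane flags can live on the stack, and an existing array still works through the implicit array-to-span conversion. SpanMaskSketch and PackLanes are hypothetical stand-ins for a span-taking CreateUnsafe and simply return the packed bits, assuming lane 0 maps to the lowest bit.

```csharp
using System;

public static class SpanMaskSketch
{
    // Hypothetical span-based counterpart of CreateUnsafe(bool[]).
    public static ulong PackLanes(ReadOnlySpan<bool> lanes)
    {
        if (lanes.Length > 64)
        {
            throw new ArgumentException("At most 64 lanes are supported in this sketch.", nameof(lanes));
        }

        ulong bits = 0;
        for (int i = 0; i < lanes.Length; i++)
        {
            if (lanes[i])
            {
                bits |= 1UL << i;
            }
        }
        return bits;
    }

    // The flags are stack allocated, so no managed heap allocation is needed.
    public static ulong PackFromStack()
    {
        ReadOnlySpan<bool> lanes = stackalloc bool[] { true, false, false, true };
        return PackLanes(lanes);
    }

    // Callers that already hold an array can pass it directly.
    public static ulong PackFromArray(bool[] lanes) => PackLanes(lanes);
}
```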

@tannergooding tannergooding added api-ready-for-review API is ready for review, it is NOT ready for implementation and removed api-suggestion Early API idea and discussion, it is NOT ready for implementation untriaged New issue has not been triaged by the area owner labels Aug 26, 2022
@tannergooding tannergooding added this to the 8.0.0 milestone Aug 26, 2022
@tannergooding
Member

Updated the API proposal to cover the full set of changes and to better match existing signatures in a few cases.

@bartonjs
Member

bartonjs commented Aug 30, 2022

Video

  • We corrected the IsHardwareAccelerated properties in the proposal to be static, as intended.
  • We agreed that instead of VectorMask64 it should be Vector64Mask.
  • We also would like to investigate what nested types would look like here, e.g. Vector64<T>.Mask
  • We discussed the naming/casing of Xnor, and decided it is correct as proposed.
  • After a very long discussion to try to understand the implications of the type-size conversions on vector masks, we renamed the As- operations to To- because of the complexities of "don't care" bits.
  • The argument to all of the VectorNMask.Create was adjusted based on feedback (byte, ushort, uint, ulong) vs what was proposed (all ushort)
namespace System.Runtime.Intrinsics
{
    public static partial class Vector64
    {
        public static Vector64Mask<T> ExtractMask<T>(this Vector64<T> vector);
    }

    public static partial class Vector128
    {
        public static Vector128Mask<T> ExtractMask<T>(this Vector128<T> vector);
    }

    public static partial class Vector256
    {
        public static Vector256Mask<T> ExtractMask<T>(this Vector256<T> vector);
    }

    public static partial class Vector512
    {
        public static Vector512Mask<T> ExtractMask<T>(this Vector512<T> vector);
    }

    public static class Vector64Mask
    {
        public static bool IsHardwareAccelerated { get; }

        public static Vector64Mask<T> Create<T>(byte mask);

        public static Vector64Mask<TTo> To<TFrom, TTo>(this Vector64Mask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static Vector64Mask<byte>   ToByte  <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<double> ToDouble<T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<short>  ToInt16 <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<int>    ToInt32 <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<long>   ToInt64 <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<nint>   ToNInt  <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<nuint>  ToNUInt <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<sbyte>  ToSByte <T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<float>  ToSingle<T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<ushort> ToUInt16<T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<uint>   ToUInt32<T>(this Vector64Mask<T> vector) where T : struct;
        public static Vector64Mask<ulong>  ToUInt64<T>(this Vector64Mask<T> vector) where T : struct;

        public static Vector64Mask<T> BitwiseAnd<T>(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> BitwiseOr<T>(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> AndNot<T>(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> OnesComplement<T>(Vector64Mask<T> value);
        public static Vector64Mask<T> Xor<T>(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> Xnor<T>(Vector64Mask<T> left, Vector64Mask<T> right);

        public static Vector64Mask<T> ShiftLeft<T>(Vector64Mask<T> value, int count);
        public static Vector64Mask<T> ShiftRight<T>(Vector64Mask<T> value, int count);

        public static bool Equals<T>(Vector64Mask<T> left, Vector64Mask<T> right);

        public static int LeadingZeroCount<T>(Vector64Mask<T> mask);
        public static int TrailingZeroCount<T>(Vector64Mask<T> mask);
        public static int PopCount<T>(Vector64Mask<T> mask);
        
        public static bool GetElement<T>(this Vector64Mask<T> vector, int index) where T : struct;
        public static Vector64Mask<T> WithElement<T>(this Vector64Mask<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct Vector64Mask<T> where T : struct 
    {
        private readonly byte _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static Vector64Mask<T> AllBitsSet { get; }
        public static Vector64Mask<T> Zero { get; }

        public bool this[int index] { get; }

        public static Vector64Mask<T> operator &(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> operator |(Vector64Mask<T> left, Vector64Mask<T> right);
        public static Vector64Mask<T> operator ~(Vector64Mask<T> value);
        public static Vector64Mask<T> operator ^(Vector64Mask<T> left, Vector64Mask<T> right);

        public static Vector64Mask<T> operator <<(Vector64Mask<T> value, int count);
        public static Vector64Mask<T> operator >>(Vector64Mask<T> value, int count);

        public static bool operator ==(Vector64Mask<T> left, Vector64Mask<T> right);
        public static bool operator !=(Vector64Mask<T> left, Vector64Mask<T> right);
    }

    public static class Vector128Mask
    {
        public static bool IsHardwareAccelerated { get; }

        public static Vector128Mask<T> Create<T>(ushort mask);

        public static Vector128Mask<TTo> To<TFrom, TTo>(this Vector128Mask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static Vector128Mask<byte>   ToByte  <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<double> ToDouble<T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<short>  ToInt16 <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<int>    ToInt32 <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<long>   ToInt64 <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<nint>   ToNInt  <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<nuint>  ToNUInt <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<sbyte>  ToSByte <T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<float>  ToSingle<T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<ushort> ToUInt16<T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<uint>   ToUInt32<T>(this Vector128Mask<T> vector) where T : struct;
        public static Vector128Mask<ulong>  ToUInt64<T>(this Vector128Mask<T> vector) where T : struct;

        public static Vector128Mask<T> BitwiseAnd<T>(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> BitwiseOr<T>(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> AndNot<T>(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> OnesComplement<T>(Vector128Mask<T> value);
        public static Vector128Mask<T> Xor<T>(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> Xnor<T>(Vector128Mask<T> left, Vector128Mask<T> right);

        public static Vector128Mask<T> ShiftLeft<T>(Vector128Mask<T> value, int count);
        public static Vector128Mask<T> ShiftRight<T>(Vector128Mask<T> value, int count);

        public static bool Equals<T>(Vector128Mask<T> left, Vector128Mask<T> right);

        public static int LeadingZeroCount<T>(Vector128Mask<T> mask);
        public static int TrailingZeroCount<T>(Vector128Mask<T> mask);
        public static int PopCount<T>(Vector128Mask<T> mask);
        
        public static bool GetElement<T>(this Vector128Mask<T> vector, int index) where T : struct;
        public static Vector128Mask<T> WithElement<T>(this Vector128Mask<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct Vector128Mask<T> where T : struct 
    {
        private readonly ushort _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static Vector128Mask<T> AllBitsSet { get; }
        public static Vector128Mask<T> Zero { get; }

        public bool this[int index] { get; }

        public static Vector128Mask<T> operator &(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> operator |(Vector128Mask<T> left, Vector128Mask<T> right);
        public static Vector128Mask<T> operator ~(Vector128Mask<T> value);
        public static Vector128Mask<T> operator ^(Vector128Mask<T> left, Vector128Mask<T> right);

        public static Vector128Mask<T> operator <<(Vector128Mask<T> value, int count);
        public static Vector128Mask<T> operator >>(Vector128Mask<T> value, int count);

        public static bool operator ==(Vector128Mask<T> left, Vector128Mask<T> right);
        public static bool operator !=(Vector128Mask<T> left, Vector128Mask<T> right);
    }

    public static class Vector256Mask
    {
        public static bool IsHardwareAccelerated { get; }

        public static Vector256Mask<T> Create<T>(uint mask);

        public static Vector256Mask<TTo> To<TFrom, TTo>(this Vector256Mask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static Vector256Mask<byte>   ToByte  <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<double> ToDouble<T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<short>  ToInt16 <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<int>    ToInt32 <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<long>   ToInt64 <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<nint>   ToNInt  <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<nuint>  ToNUInt <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<sbyte>  ToSByte <T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<float>  ToSingle<T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<ushort> ToUInt16<T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<uint>   ToUInt32<T>(this Vector256Mask<T> vector) where T : struct;
        public static Vector256Mask<ulong>  ToUInt64<T>(this Vector256Mask<T> vector) where T : struct;

        public static Vector256Mask<T> BitwiseAnd<T>(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> BitwiseOr<T>(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> AndNot<T>(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> OnesComplement<T>(Vector256Mask<T> value);
        public static Vector256Mask<T> Xor<T>(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> Xnor<T>(Vector256Mask<T> left, Vector256Mask<T> right);

        public static Vector256Mask<T> ShiftLeft<T>(Vector256Mask<T> value, int count);
        public static Vector256Mask<T> ShiftRight<T>(Vector256Mask<T> value, int count);

        public static bool Equals<T>(Vector256Mask<T> left, Vector256Mask<T> right);

        public static int LeadingZeroCount<T>(Vector256Mask<T> mask);
        public static int TrailingZeroCount<T>(Vector256Mask<T> mask);
        public static int PopCount<T>(Vector256Mask<T> mask);
        
        public static bool GetElement<T>(this Vector256Mask<T> vector, int index) where T : struct;
        public static Vector256Mask<T> WithElement<T>(this Vector256Mask<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct Vector256Mask<T> where T : struct 
    {
        private readonly uint _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static Vector256Mask<T> AllBitsSet { get; }
        public static Vector256Mask<T> Zero { get; }

        public bool this[int index] { get; }

        public static Vector256Mask<T> operator &(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> operator |(Vector256Mask<T> left, Vector256Mask<T> right);
        public static Vector256Mask<T> operator ~(Vector256Mask<T> value);
        public static Vector256Mask<T> operator ^(Vector256Mask<T> left, Vector256Mask<T> right);

        public static Vector256Mask<T> operator <<(Vector256Mask<T> value, int count);
        public static Vector256Mask<T> operator >>(Vector256Mask<T> value, int count);

        public static bool operator ==(Vector256Mask<T> left, Vector256Mask<T> right);
        public static bool operator !=(Vector256Mask<T> left, Vector256Mask<T> right);
    }

    public static class Vector512Mask
    {
        public static bool IsHardwareAccelerated { get; }

        public static Vector512Mask<T> Create<T>(ulong mask);

        public static Vector512Mask<TTo> To<TFrom, TTo>(this Vector512Mask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static Vector512Mask<byte>   ToByte  <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<double> ToDouble<T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<short>  ToInt16 <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<int>    ToInt32 <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<long>   ToInt64 <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<nint>   ToNInt  <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<nuint>  ToNUInt <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<sbyte>  ToSByte <T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<float>  ToSingle<T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<ushort> ToUInt16<T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<uint>   ToUInt32<T>(this Vector512Mask<T> vector) where T : struct;
        public static Vector512Mask<ulong>  ToUInt64<T>(this Vector512Mask<T> vector) where T : struct;

        public static Vector512Mask<T> BitwiseAnd<T>(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> BitwiseOr<T>(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> AndNot<T>(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> OnesComplement<T>(Vector512Mask<T> value);
        public static Vector512Mask<T> Xor<T>(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> Xnor<T>(Vector512Mask<T> left, Vector512Mask<T> right);

        public static Vector512Mask<T> ShiftLeft<T>(Vector512Mask<T> value, int count);
        public static Vector512Mask<T> ShiftRight<T>(Vector512Mask<T> value, int count);

        public static bool Equals<T>(Vector512Mask<T> left, Vector512Mask<T> right);

        public static int LeadingZeroCount<T>(Vector512Mask<T> mask);
        public static int TrailingZeroCount<T>(Vector512Mask<T> mask);
        public static int PopCount<T>(Vector512Mask<T> mask);
        
        public static bool GetElement<T>(this Vector512Mask<T> vector, int index) where T : struct;
        public static Vector512Mask<T> WithElement<T>(this Vector512Mask<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct Vector512Mask<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static Vector512Mask<T> AllBitsSet { get; }
        public static Vector512Mask<T> Zero { get; }

        public bool this[int index] { get; }

        public static Vector512Mask<T> operator &(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> operator |(Vector512Mask<T> left, Vector512Mask<T> right);
        public static Vector512Mask<T> operator ~(Vector512Mask<T> value);
        public static Vector512Mask<T> operator ^(Vector512Mask<T> left, Vector512Mask<T> right);

        public static Vector512Mask<T> operator <<(Vector512Mask<T> value, int count);
        public static Vector512Mask<T> operator >>(Vector512Mask<T> value, int count);

        public static bool operator ==(Vector512Mask<T> left, Vector512Mask<T> right);
        public static bool operator !=(Vector512Mask<T> left, Vector512Mask<T> right);
    }
}

namespace System.Numerics
{
    public static partial class Vector
    {
        public static VectorMask<T> ExtractMask<T>(this Vector<T> vector);
    }

    public static class VectorMask
    {
        public static bool IsHardwareAccelerated { get; }

        public static VectorMask<T> Create<T>(byte[] value);
        public static VectorMask<T> Create<T>(byte[] value, int index);
        public static VectorMask<T> Create<T>(ReadOnlySpan<byte> value);

        public static VectorMask<TTo> To<TFrom, TTo>(this VectorMask<TFrom> vector) where TFrom : struct where TTo : struct;

        public static VectorMask<byte>   ToByte  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<double> ToDouble<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<short>  ToInt16 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<int>    ToInt32 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<long>   ToInt64 <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nint>   ToNInt  <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<nuint>  ToNUInt <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<sbyte>  ToSByte <T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<float>  ToSingle<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ushort> ToUInt16<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<uint>   ToUInt32<T>(this VectorMask<T> vector) where T : struct;
        public static VectorMask<ulong>  ToUInt64<T>(this VectorMask<T> vector) where T : struct;

        public static VectorMask<T> BitwiseAnd<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> BitwiseOr<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> AndNot<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> OnesComplement<T>(VectorMask<T> value);
        public static VectorMask<T> Xor<T>(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> Xnor<T>(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> ShiftLeft<T>(VectorMask<T> value, int count);
        public static VectorMask<T> ShiftRight<T>(VectorMask<T> value, int count);

        public static bool Equals<T>(VectorMask<T> left, VectorMask<T> right);

        public static int LeadingZeroCount<T>(VectorMask<T> mask);
        public static int TrailingZeroCount<T>(VectorMask<T> mask);
        public static int PopCount<T>(VectorMask<T> mask);
        
        public static bool GetElement<T>(this VectorMask<T> vector, int index) where T : struct;
        public static VectorMask<T> WithElement<T>(this VectorMask<T> vector, int index, bool value) where T : struct;
    }

    public readonly struct VectorMask<T> where T : struct 
    {
        private readonly ulong _value;

        public static bool IsSupported { get; }
        public static int Count { get; }

        public static VectorMask<T> AllBitsSet { get; }
        public static VectorMask<T> Zero { get; }

        public bool this[int index] { get; }

        public static VectorMask<T> operator &(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator |(VectorMask<T> left, VectorMask<T> right);
        public static VectorMask<T> operator ~(VectorMask<T> value);
        public static VectorMask<T> operator ^(VectorMask<T> left, VectorMask<T> right);

        public static VectorMask<T> operator <<(VectorMask<T> value, int count);
        public static VectorMask<T> operator >>(VectorMask<T> value, int count);

        public static bool operator ==(VectorMask<T> left, VectorMask<T> right);
        public static bool operator !=(VectorMask<T> left, VectorMask<T> right);
    }
}
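
The Create argument widths above (byte, ushort, uint, ulong) line up with the largest lane count each vector size can have, which is reached with byte-sized elements. The sketch below works through that arithmetic. MaskWidth, LaneCount and UsedBits are hypothetical names, and reading the "don't care" bits as the backing-integer bits beyond the lane count is an assumption, not something the review notes spell out.

```csharp
using System;
using System.Runtime.CompilerServices;

// Hypothetical helper for illustration; not part of the approved API.
public static class MaskWidth
{
    // Number of lanes in a vector of vectorBits bits with element type T.
    public static int LaneCount<T>(int vectorBits) where T : struct
        => vectorBits / (8 * Unsafe.SizeOf<T>());

    // Bits of the backing integer that a T-typed mask actually uses.
    // Assumption: the remaining high bits are the "don't care" bits.
    public static ulong UsedBits<T>(int vectorBits) where T : struct
    {
        int lanes = LaneCount<T>(vectorBits);
        return lanes >= 64 ? ulong.MaxValue : (1UL << lanes) - 1;
    }
}
```

Under this reading, a 128-bit vector of byte has 16 lanes and fills a ushort, while a 128-bit vector of int has 4 lanes and leaves the upper 12 bits of the same ushort unused, which is where a conversion between element types has to decide what those bits mean.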

@bartonjs bartonjs added api-approved API was approved in API review, it can be implemented and removed api-ready-for-review API is ready for review, it is NOT ready for implementation labels Aug 30, 2022
@sparker-arm

Is it feasible to add an extra template parameter to specify the width, instead of having many underlying classes? If so, the surface area of this could be reduced considerably.

@tannergooding
Member

Is it feasible to add an extra template parameter to specify the width, instead of having many underlying classes? If so, the surface area of this could be reduced considerably.

.NET doesn't have any kind of support like that today. That is, you can't specify something like VectorMask<T, int>.

You could specify something like VectorMask<TVector, T> and the user would have to consume it as VectorMask<Vector128<T>, T> but that has a much worse overall UX and comes with many other complications/considerations for usability.
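
A rough sketch of what that second shape might look like, to show where the usability cost comes from. WidthGenericMask and MaskShapeDemo are hypothetical names, and the ulong backing field is an assumption made only so the example compiles and runs.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical shape only: the vector type is passed as a type argument
// because C# generics cannot take an integer width.
public readonly struct WidthGenericMask<TVector, T>
    where TVector : struct
    where T : struct
{
    private readonly ulong _bits;

    public WidthGenericMask(ulong bits) => _bits = bits;

    public ulong Bits => _bits;
}

public static class MaskShapeDemo
{
    // The element type has to be written twice, and nothing stops a caller
    // from pairing Vector128<int> with a different T.
    public static WidthGenericMask<Vector128<int>, int> Make()
        => new WidthGenericMask<Vector128<int>, int>(0b1001);
}
```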

@tannergooding
Member

Superseded by #87097

@tannergooding tannergooding removed the api-approved API was approved in API review, it can be implemented label Jul 24, 2023
@ghost ghost locked as resolved and limited conversation to collaborators Aug 24, 2023