Skip to content

Commit

Permalink
Patterns: Allow non-negative and full integer sets to merge (intersec…
Browse files Browse the repository at this point in the history
…t or union) (dotnet#71968)
  • Loading branch information
jcouv authored Mar 6, 2024
1 parent 99fbf1b commit 3ae3ef9
Show file tree
Hide file tree
Showing 28 changed files with 667 additions and 317 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ namespace Microsoft.CodeAnalysis.CSharp

internal static partial class ValueSetFactory
{
private struct ByteTC : INumericTC<byte>
private class ByteTC : INumericTC<byte>
{
public static readonly ByteTC Instance = new ByteTC();

byte INumericTC<byte>.MinValue => byte.MinValue;

byte INumericTC<byte>.MaxValue => byte.MaxValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@

using System;
using System.Diagnostics;
using System.Globalization;

namespace Microsoft.CodeAnalysis.CSharp
{
using static BinaryOperatorKind;

internal static partial class ValueSetFactory
{
private struct CharTC : INumericTC<char>
private class CharTC : INumericTC<char>
{
public static readonly CharTC Instance = new CharTC();

char INumericTC<char>.MinValue => char.MinValue;

char INumericTC<char>.MaxValue => char.MaxValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ namespace Microsoft.CodeAnalysis.CSharp

internal static partial class ValueSetFactory
{
private struct DecimalTC : INumericTC<decimal>
private class DecimalTC : INumericTC<decimal>
{
public static readonly DecimalTC Instance = new DecimalTC();

// These are the smallest nonzero normal mantissa value (in three parts) below which you could use a higher scale.
// This is the 96-bit representation of ((2^96)-1) / 10;
private const uint transitionLow = 0x99999999;
Expand Down Expand Up @@ -112,7 +114,7 @@ decimal INumericTC<decimal>.Prev(decimal value)

public decimal Random(Random random)
{
INumericTC<uint> uinttc = default(UIntTC);
INumericTC<uint> uinttc = UIntTC.Instance;
return new DecimalRep(
low: uinttc.Random(random),
mid: uinttc.Random(random),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ private sealed class DecimalValueSetFactory : IValueSetFactory<decimal>, IValueS
{
public static readonly DecimalValueSetFactory Instance = new DecimalValueSetFactory();

private readonly IValueSetFactory<decimal> _underlying = NumericValueSetFactory<decimal, DecimalTC>.Instance;
private readonly IValueSetFactory<decimal> _underlying = new NumericValueSetFactory<decimal>(DecimalTC.Instance);

IValueSet IValueSetFactory.AllValues => NumericValueSet<decimal, DecimalTC>.AllValues;
IValueSet IValueSetFactory.AllValues => NumericValueSet<decimal>.AllValues(DecimalTC.Instance);

IValueSet IValueSetFactory.NoValues => NumericValueSet<decimal, DecimalTC>.NoValues;
IValueSet IValueSetFactory.NoValues => NumericValueSet<decimal>.NoValues(DecimalTC.Instance);

public IValueSet<decimal> Related(BinaryOperatorKind relation, decimal value) => _underlying.Related(relation, DecimalTC.Normalize(value));

IValueSet IValueSetFactory.Random(int expectedSize, Random random) => _underlying.Random(expectedSize, random);

ConstantValue IValueSetFactory.RandomValue(Random random) => ConstantValue.Create(default(DecimalTC).Random(random));
ConstantValue IValueSetFactory.RandomValue(Random random) => ConstantValue.Create(DecimalTC.Instance.Random(random));

IValueSet IValueSetFactory.Related(BinaryOperatorKind relation, ConstantValue value) =>
value.IsBad ? NumericValueSet<decimal, DecimalTC>.AllValues : Related(relation, default(DecimalTC).FromConstantValue(value));
value.IsBad ? NumericValueSet<decimal>.AllValues(DecimalTC.Instance) : Related(relation, DecimalTC.Instance.FromConstantValue(value));

bool IValueSetFactory.Related(BinaryOperatorKind relation, ConstantValue left, ConstantValue right) => _underlying.Related(relation, left, right);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ namespace Microsoft.CodeAnalysis.CSharp

internal static partial class ValueSetFactory
{
private struct DoubleTC : FloatingTC<double>, INumericTC<double>
private class DoubleTC : FloatingTC<double>, INumericTC<double>
{
public static readonly DoubleTC Instance = new DoubleTC();

double INumericTC<double>.MinValue => double.NegativeInfinity;

double INumericTC<double>.MaxValue => double.PositiveInfinity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Roslyn.Utilities;

Expand All @@ -17,8 +18,7 @@ internal static partial class ValueSetFactory
/// relational operators for it; such a set can be formed only by including explicitly mentioned
/// members (or the inverse, excluding them, by complementing the set).
/// </summary>
private sealed class EnumeratedValueSet<T, TTC> : IValueSet<T>
where TTC : struct, IEquatableValueTC<T>
private sealed class EnumeratedValueSet<T> : IValueSet<T>
where T : notnull
{
/// <summary>
Expand All @@ -29,14 +29,19 @@ private sealed class EnumeratedValueSet<T, TTC> : IValueSet<T>

private readonly ImmutableHashSet<T> _membersIncludedOrExcluded;

private EnumeratedValueSet(bool included, ImmutableHashSet<T> membersIncludedOrExcluded) =>
(this._included, this._membersIncludedOrExcluded) = (included, membersIncludedOrExcluded);
private readonly IEquatableValueTC<T> _tc;

public static readonly EnumeratedValueSet<T, TTC> AllValues = new EnumeratedValueSet<T, TTC>(included: false, ImmutableHashSet<T>.Empty);
private EnumeratedValueSet(bool included, ImmutableHashSet<T> membersIncludedOrExcluded, IEquatableValueTC<T> tc) =>
(this._included, this._membersIncludedOrExcluded, this._tc) = (included, membersIncludedOrExcluded, tc);

public static readonly EnumeratedValueSet<T, TTC> NoValues = new EnumeratedValueSet<T, TTC>(included: true, ImmutableHashSet<T>.Empty);
public static EnumeratedValueSet<T> AllValues(IEquatableValueTC<T> tc)
=> new EnumeratedValueSet<T>(included: false, ImmutableHashSet<T>.Empty, tc);

internal static EnumeratedValueSet<T, TTC> Including(T value) => new EnumeratedValueSet<T, TTC>(included: true, ImmutableHashSet<T>.Empty.Add(value));
public static EnumeratedValueSet<T> NoValues(IEquatableValueTC<T> tc)
=> new EnumeratedValueSet<T>(included: true, ImmutableHashSet<T>.Empty, tc);

internal static EnumeratedValueSet<T> Including(T value, IEquatableValueTC<T> tc)
=> new EnumeratedValueSet<T>(included: true, ImmutableHashSet<T>.Empty.Add(value), tc);

public bool IsEmpty => _included && _membersIncludedOrExcluded.IsEmpty;

Expand All @@ -45,25 +50,24 @@ ConstantValue IValueSet.Sample
get
{
if (IsEmpty) throw new ArgumentException();
var tc = default(TTC);
if (_included)
return tc.ToConstantValue(_membersIncludedOrExcluded.OrderBy(k => k).First());
return _tc.ToConstantValue(_membersIncludedOrExcluded.OrderBy(k => k).First());
if (typeof(T) == typeof(string))
{
// try some simple strings.
if (this.Any(BinaryOperatorKind.Equal, (T)(object)""))
return tc.ToConstantValue((T)(object)"");
return _tc.ToConstantValue((T)(object)"");
for (char c = 'A'; c <= 'z'; c++)
if (this.Any(BinaryOperatorKind.Equal, (T)(object)c.ToString()))
return tc.ToConstantValue((T)(object)c.ToString());
return _tc.ToConstantValue((T)(object)c.ToString());
}
// If that doesn't work, choose from a sufficiently large random selection of values.
// Since this is an excluded set, they cannot all be excluded
var candidates = tc.RandomValues(_membersIncludedOrExcluded.Count + 1, new Random(0), _membersIncludedOrExcluded.Count + 1);
var candidates = _tc.RandomValues(_membersIncludedOrExcluded.Count + 1, new Random(0), _membersIncludedOrExcluded.Count + 1);
foreach (var value in candidates)
{
if (this.Any(BinaryOperatorKind.Equal, value))
return tc.ToConstantValue(value);
return _tc.ToConstantValue(value);
}

throw ExceptionUtilities.Unreachable();
Expand All @@ -81,7 +85,7 @@ public bool Any(BinaryOperatorKind relation, T value)
}
}

bool IValueSet.Any(BinaryOperatorKind relation, ConstantValue value) => value.IsBad || Any(relation, default(TTC).FromConstantValue(value));
bool IValueSet.Any(BinaryOperatorKind relation, ConstantValue value) => value.IsBad || Any(relation, _tc.FromConstantValue(value));

public bool All(BinaryOperatorKind relation, T value)
{
Expand All @@ -104,28 +108,30 @@ public bool All(BinaryOperatorKind relation, T value)
}
}

bool IValueSet.All(BinaryOperatorKind relation, ConstantValue value) => !value.IsBad && All(relation, default(TTC).FromConstantValue(value));
bool IValueSet.All(BinaryOperatorKind relation, ConstantValue value) => !value.IsBad && All(relation, _tc.FromConstantValue(value));

public IValueSet<T> Complement() => new EnumeratedValueSet<T, TTC>(!_included, _membersIncludedOrExcluded);
public IValueSet<T> Complement() => new EnumeratedValueSet<T>(!_included, _membersIncludedOrExcluded, _tc);

IValueSet IValueSet.Complement() => this.Complement();

public IValueSet<T> Intersect(IValueSet<T> o)
{
if (this == o)
return this;
var other = (EnumeratedValueSet<T, TTC>)o;
var other = (EnumeratedValueSet<T>)o;
Debug.Assert(object.ReferenceEquals(this._tc, other._tc));

var (larger, smaller) = (this._membersIncludedOrExcluded.Count > other._membersIncludedOrExcluded.Count) ? (this, other) : (other, this);
switch (larger._included, smaller._included)
{
case (true, true):
return new EnumeratedValueSet<T, TTC>(true, larger._membersIncludedOrExcluded.Intersect(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(true, larger._membersIncludedOrExcluded.Intersect(smaller._membersIncludedOrExcluded), _tc);
case (true, false):
return new EnumeratedValueSet<T, TTC>(true, larger._membersIncludedOrExcluded.Except(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(true, larger._membersIncludedOrExcluded.Except(smaller._membersIncludedOrExcluded), _tc);
case (false, false):
return new EnumeratedValueSet<T, TTC>(false, larger._membersIncludedOrExcluded.Union(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(false, larger._membersIncludedOrExcluded.Union(smaller._membersIncludedOrExcluded), _tc);
case (false, true):
return new EnumeratedValueSet<T, TTC>(true, smaller._membersIncludedOrExcluded.Except(larger._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(true, smaller._membersIncludedOrExcluded.Except(larger._membersIncludedOrExcluded), _tc);
}
}

Expand All @@ -135,28 +141,31 @@ public IValueSet<T> Union(IValueSet<T> o)
{
if (this == o)
return this;
var other = (EnumeratedValueSet<T, TTC>)o;
var other = (EnumeratedValueSet<T>)o;
Debug.Assert(object.ReferenceEquals(this._tc, other._tc));

var (larger, smaller) = (this._membersIncludedOrExcluded.Count > other._membersIncludedOrExcluded.Count) ? (this, other) : (other, this);
switch (larger._included, smaller._included)
{
case (false, false):
return new EnumeratedValueSet<T, TTC>(false, larger._membersIncludedOrExcluded.Intersect(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(false, larger._membersIncludedOrExcluded.Intersect(smaller._membersIncludedOrExcluded), _tc);
case (false, true):
return new EnumeratedValueSet<T, TTC>(false, larger._membersIncludedOrExcluded.Except(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(false, larger._membersIncludedOrExcluded.Except(smaller._membersIncludedOrExcluded), _tc);
case (true, true):
return new EnumeratedValueSet<T, TTC>(true, larger._membersIncludedOrExcluded.Union(smaller._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(true, larger._membersIncludedOrExcluded.Union(smaller._membersIncludedOrExcluded), _tc);
case (true, false):
return new EnumeratedValueSet<T, TTC>(false, smaller._membersIncludedOrExcluded.Except(larger._membersIncludedOrExcluded));
return new EnumeratedValueSet<T>(false, smaller._membersIncludedOrExcluded.Except(larger._membersIncludedOrExcluded), _tc);
}
}

IValueSet IValueSet.Union(IValueSet other) => Union((IValueSet<T>)other);

public override bool Equals(object? obj)
{
if (obj is not EnumeratedValueSet<T, TTC> other)
if (obj is not EnumeratedValueSet<T> other)
return false;

Debug.Assert(object.ReferenceEquals(this._tc, other._tc));
return this._included == other._included
&& this._membersIncludedOrExcluded.SetEqualsWithoutIntermediateHashSet(other._membersIncludedOrExcluded);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

using System;
using System.Diagnostics;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
{
Expand All @@ -15,42 +14,40 @@ internal static partial class ValueSetFactory
/// <summary>
/// A value set factory that only supports equality and works by including or excluding specific values.
/// </summary>
private sealed class EnumeratedValueSetFactory<T, TTC> : IValueSetFactory<T> where TTC : struct, IEquatableValueTC<T> where T : notnull
private sealed class EnumeratedValueSetFactory<T> : IValueSetFactory<T> where T : notnull
{
public static readonly EnumeratedValueSetFactory<T, TTC> Instance = new EnumeratedValueSetFactory<T, TTC>();
private readonly IEquatableValueTC<T> _tc;

IValueSet IValueSetFactory.AllValues => EnumeratedValueSet<T, TTC>.AllValues;
IValueSet IValueSetFactory.AllValues => EnumeratedValueSet<T>.AllValues(_tc);

IValueSet IValueSetFactory.NoValues => EnumeratedValueSet<T, TTC>.NoValues;
IValueSet IValueSetFactory.NoValues => EnumeratedValueSet<T>.NoValues(_tc);

private EnumeratedValueSetFactory() { }
public EnumeratedValueSetFactory(IEquatableValueTC<T> tc) { _tc = tc; }

public IValueSet<T> Related(BinaryOperatorKind relation, T value)
{
switch (relation)
{
case Equal:
return EnumeratedValueSet<T, TTC>.Including(value);
return EnumeratedValueSet<T>.Including(value, _tc);
default:
return EnumeratedValueSet<T, TTC>.AllValues; // supported for error recovery
return EnumeratedValueSet<T>.AllValues(_tc); // supported for error recovery
}
}

IValueSet IValueSetFactory.Related(BinaryOperatorKind relation, ConstantValue value) =>
value.IsBad || value.IsNull ? EnumeratedValueSet<T, TTC>.AllValues : this.Related(relation, default(TTC).FromConstantValue(value));
value.IsBad || value.IsNull ? EnumeratedValueSet<T>.AllValues(_tc) : this.Related(relation, _tc.FromConstantValue(value));

bool IValueSetFactory.Related(BinaryOperatorKind relation, ConstantValue left, ConstantValue right)
{
Debug.Assert(relation == BinaryOperatorKind.Equal);
TTC tc = default;
return tc.FromConstantValue(left).Equals(tc.FromConstantValue(right));
return _tc.FromConstantValue(left).Equals(_tc.FromConstantValue(right));
}

public IValueSet Random(int expectedSize, Random random)
{
TTC tc = default;
T[] values = tc.RandomValues(expectedSize, random, expectedSize * 2);
IValueSet<T> result = EnumeratedValueSet<T, TTC>.NoValues;
T[] values = _tc.RandomValues(expectedSize, random, expectedSize * 2);
IValueSet<T> result = EnumeratedValueSet<T>.NoValues(_tc);
Debug.Assert(result.IsEmpty);
foreach (T value in values)
result = result.Union(Related(Equal, value));
Expand All @@ -60,8 +57,7 @@ public IValueSet Random(int expectedSize, Random random)

ConstantValue IValueSetFactory.RandomValue(Random random)
{
TTC tc = default;
return tc.ToConstantValue(tc.RandomValues(1, random, 100)[0]);
return _tc.ToConstantValue(_tc.RandomValues(1, random, 100)[0]);
}
}
}
Expand Down
Loading

0 comments on commit 3ae3ef9

Please sign in to comment.