diff --git a/Generator.Equals.Runtime/Attributes.cs b/Generator.Equals.Runtime/Attributes.cs index ad5f1cf..e41b04f 100644 --- a/Generator.Equals.Runtime/Attributes.cs +++ b/Generator.Equals.Runtime/Attributes.cs @@ -62,6 +62,19 @@ public class SetEqualityAttribute : Attribute { } + [GeneratedCode("Generator.Equals", "1.0.0.0")] + [Conditional("GENERATOR_EQUALS")] + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] + public class StringEqualityAttribute : Attribute + { + public StringComparison ComparisonType { get; } + + public StringEqualityAttribute(StringComparison comparisonType) + { + ComparisonType = comparisonType; + } + } + [GeneratedCode("Generator.Equals", "1.0.0.0")] [Conditional("GENERATOR_EQUALS")] [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] diff --git a/Generator.Equals.SnapshotTests/Classes/StringEquality.Net60.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Classes/StringEquality.Net60.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Classes/StringEquality.Net60.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/Classes/StringEquality.NetFramework48.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Classes/StringEquality.NetFramework48.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Classes/StringEquality.NetFramework48.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.Net60.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.Net60.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.Net60.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.NetFramework48.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.NetFramework48.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/RecordStructs/StringEquality.NetFramework48.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/Records/StringEquality.Net60.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Records/StringEquality.Net60.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Records/StringEquality.Net60.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/Records/StringEquality.NetFramework48.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Records/StringEquality.NetFramework48.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Records/StringEquality.NetFramework48.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/Structs/StringEquality.Net60.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Structs/StringEquality.Net60.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Structs/StringEquality.Net60.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.SnapshotTests/Structs/StringEquality.NetFramework48.Diagnostics.verified.txt b/Generator.Equals.SnapshotTests/Structs/StringEquality.NetFramework48.Diagnostics.verified.txt new file mode 100644 index 0000000..ad47dbb --- /dev/null +++ b/Generator.Equals.SnapshotTests/Structs/StringEquality.NetFramework48.Diagnostics.verified.txt @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/Generator.Equals.Tests/Classes/StringEquality.Sample.cs b/Generator.Equals.Tests/Classes/StringEquality.Sample.cs new file mode 100644 index 0000000..b85e79c --- /dev/null +++ b/Generator.Equals.Tests/Classes/StringEquality.Sample.cs @@ -0,0 +1,31 @@ +using System; + +namespace Generator.Equals.Tests.Classes +{ + public partial class StringEquality + { + [Equatable] + public partial class SampleCaseInsensitive + { + public SampleCaseInsensitive(string name) + { + Name = name; + } + + [StringEquality(StringComparison.CurrentCultureIgnoreCase)] + public string Name { get; } + } + + [Equatable] + public partial class SampleCaseSensitive + { + public SampleCaseSensitive(string name) + { + Name = name; + } + + [StringEquality(StringComparison.CurrentCulture)] + public string Name { get; } + } + } +} diff --git a/Generator.Equals.Tests/Classes/StringEquality.cs b/Generator.Equals.Tests/Classes/StringEquality.cs new file mode 100644 index 0000000..9fcfe2d --- /dev/null +++ b/Generator.Equals.Tests/Classes/StringEquality.cs @@ -0,0 +1,55 @@ +namespace Generator.Equals.Tests.Classes +{ + public partial class StringEquality + { + public class EqualsTestsNotCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive("BAR"); + public override object Factory2() => new SampleCaseInsensitive("bar"); + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class NotEqualsTestsNotCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive("BAR"); + public override object Factory2() => new SampleCaseInsensitive("foo"); + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class EqualsTestsCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseSensitive("Foo"); + public override object Factory2() => new SampleCaseSensitive("Foo"); + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 == (SampleCaseSensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 != (SampleCaseSensitive)value2; + } + + public class NotEqualsTestsCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseSensitive("Foo"); + public override object Factory2() => new SampleCaseSensitive("foo"); + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 == (SampleCaseSensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 != (SampleCaseSensitive)value2; + } + } +} diff --git a/Generator.Equals.Tests/RecordStructs/StringEquality.Sample.cs b/Generator.Equals.Tests/RecordStructs/StringEquality.Sample.cs new file mode 100644 index 0000000..481ccc6 --- /dev/null +++ b/Generator.Equals.Tests/RecordStructs/StringEquality.Sample.cs @@ -0,0 +1,17 @@ +using System; + +namespace Generator.Equals.Tests.RecordStructs +{ + public partial class StringEquality + { + [Equatable] + public partial record struct SampleCaseInsensitive( + [property: StringEquality(StringComparison.CurrentCultureIgnoreCase)] + string Name); + + [Equatable] + public partial record struct SampleCaseSensitive( + [property: StringEquality(StringComparison.CurrentCulture)] + string Name); + } +} diff --git a/Generator.Equals.Tests/RecordStructs/StringEquality.cs b/Generator.Equals.Tests/RecordStructs/StringEquality.cs new file mode 100644 index 0000000..72fb932 --- /dev/null +++ b/Generator.Equals.Tests/RecordStructs/StringEquality.cs @@ -0,0 +1,30 @@ +namespace Generator.Equals.Tests.RecordStructs +{ + public partial class StringEquality + { + public class EqualsTests : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "bar" }; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class NotEqualsTest : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "foo" }; + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + } +} diff --git a/Generator.Equals.Tests/Records/StringEquality.Sample.cs b/Generator.Equals.Tests/Records/StringEquality.Sample.cs new file mode 100644 index 0000000..aff8f5a --- /dev/null +++ b/Generator.Equals.Tests/Records/StringEquality.Sample.cs @@ -0,0 +1,21 @@ +using System; + +namespace Generator.Equals.Tests.Records +{ + public partial class StringEquality + { + [Equatable] + public partial record SampleCaseInsensitive + { + [StringEquality(StringComparison.CurrentCultureIgnoreCase)] + public string Name { get; init; } = ""; + } + + [Equatable] + public partial record SampleCaseSensitive + { + [StringEquality(StringComparison.CurrentCulture)] + public string Name { get; init; } = ""; + } + } +} diff --git a/Generator.Equals.Tests/Records/StringEquality.cs b/Generator.Equals.Tests/Records/StringEquality.cs new file mode 100644 index 0000000..84e9f3f --- /dev/null +++ b/Generator.Equals.Tests/Records/StringEquality.cs @@ -0,0 +1,55 @@ +namespace Generator.Equals.Tests.Records +{ + public partial class StringEquality + { + public class EqualsTestsNotCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "bar" }; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class NotEqualsTestsNotCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "foo" }; + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class EqualsTestsCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseSensitive { Name = "Foo" }; + public override object Factory2() => new SampleCaseSensitive { Name = "Foo" }; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 == (SampleCaseSensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 != (SampleCaseSensitive)value2; + } + + public class NotEqualsTestsCaseSensitive : EqualityTestCase + { + public override object Factory1() => new SampleCaseSensitive { Name = "Foo" }; + public override object Factory2() => new SampleCaseSensitive { Name = "foo" }; + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 == (SampleCaseSensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseSensitive)value1 != (SampleCaseSensitive)value2; + } + } +} diff --git a/Generator.Equals.Tests/Structs/StringEquality.Sample.cs b/Generator.Equals.Tests/Structs/StringEquality.Sample.cs new file mode 100644 index 0000000..15b5bbd --- /dev/null +++ b/Generator.Equals.Tests/Structs/StringEquality.Sample.cs @@ -0,0 +1,31 @@ +using System; + +namespace Generator.Equals.Tests.Structs +{ + public partial class StringEquality + { + [Equatable] + public partial struct SampleCaseInsensitive + { + public SampleCaseInsensitive(string name) + { + Name = name; + } + + [StringEquality(StringComparison.CurrentCultureIgnoreCase)] + public string Name { get; init; } = ""; + } + + [Equatable] + public partial struct SampleCaseSensitive + { + public SampleCaseSensitive(string name) + { + Name = name; + } + + [StringEquality(StringComparison.CurrentCulture)] + public string Name { get; init; } = ""; + } + } +} diff --git a/Generator.Equals.Tests/Structs/StringEquality.cs b/Generator.Equals.Tests/Structs/StringEquality.cs new file mode 100644 index 0000000..7d69503 --- /dev/null +++ b/Generator.Equals.Tests/Structs/StringEquality.cs @@ -0,0 +1,30 @@ +namespace Generator.Equals.Tests.Structs +{ + public partial class StringEquality + { + public class EqualsTests : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "bar" }; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + + public class NotEqualsTest : EqualityTestCase + { + public override object Factory1() => new SampleCaseInsensitive { Name = "BAR" }; + public override object Factory2() => new SampleCaseInsensitive { Name = "foo" }; + public override bool Expected => false; + + public override bool EqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 == (SampleCaseInsensitive)value2; + + public override bool NotEqualsOperator(object value1, object value2) => + (SampleCaseInsensitive)value1 != (SampleCaseInsensitive)value2; + } + } +} diff --git a/Generator.Equals/AttributesMetadata.cs b/Generator.Equals/AttributesMetadata.cs index cc4d762..c46fa0a 100644 --- a/Generator.Equals/AttributesMetadata.cs +++ b/Generator.Equals/AttributesMetadata.cs @@ -1,4 +1,5 @@ -using Microsoft.CodeAnalysis; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; namespace Generator.Equals { @@ -11,17 +12,21 @@ public class AttributesMetadata public INamedTypeSymbol UnorderedEquality { get; } public INamedTypeSymbol ReferenceEquality { get; } public INamedTypeSymbol SetEquality { get; } + public INamedTypeSymbol StringEquality { get; } public INamedTypeSymbol CustomEquality { get; } + public ImmutableDictionary StringComparisonLookup { get; } public AttributesMetadata( INamedTypeSymbol equatable, INamedTypeSymbol defaultEquality, INamedTypeSymbol orderedEquality, INamedTypeSymbol ignoreEquality, - INamedTypeSymbol unorderedEquality, - INamedTypeSymbol referenceEquality, + INamedTypeSymbol unorderedEquality, + INamedTypeSymbol referenceEquality, INamedTypeSymbol setEquality, - INamedTypeSymbol customEquality) + INamedTypeSymbol stringEquality, + INamedTypeSymbol customEquality, + ImmutableDictionary stringComparisonLookup) { Equatable = equatable; DefaultEquality = defaultEquality; @@ -30,7 +35,9 @@ public AttributesMetadata( UnorderedEquality = unorderedEquality; ReferenceEquality = referenceEquality; SetEquality = setEquality; + StringEquality = stringEquality; CustomEquality = customEquality; + StringComparisonLookup = stringComparisonLookup; } } } diff --git a/Generator.Equals/EqualityGeneratorBase.cs b/Generator.Equals/EqualityGeneratorBase.cs index 868ac7a..fa451a9 100644 --- a/Generator.Equals/EqualityGeneratorBase.cs +++ b/Generator.Equals/EqualityGeneratorBase.cs @@ -2,6 +2,7 @@ using System.CodeDom.Compiler; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Generator.Equals { @@ -79,6 +80,23 @@ static void BuildEquality(AttributesMetadata attributesMetadata, IndentedTextWri writer.WriteLine( $"&& global::Generator.Equals.SetEqualityComparer<{string.Join(", ", types.Value)}>.Default.Equals(this.{propertyName}!, other.{propertyName}!)"); } + else if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) + { + var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; + var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value); + + if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, + out var enumMemberName)) + { + // NOTE: Very unlikely as this would mean changes to the StringComparison enum + // which is not expected to change. It would also mean that the compiler + // and analyzer are running different dotnet versions. + throw new Exception("should not have gotten here."); + } + + writer.WriteLine( + $"&& global::System.StringComparer.{enumMemberName}.Equals(this.{propertyName}!, other.{propertyName}!)"); + } else if (memberSymbol.HasAttribute(attributesMetadata.CustomEquality)) { var attribute = memberSymbol.GetAttribute(attributesMetadata.CustomEquality); @@ -198,6 +216,26 @@ void BuildHashCodeAdd(Action action) $"global::Generator.Equals.SetEqualityComparer<{string.Join(", ", types.Value)}>.Default"); }); } + else if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) + { + BuildHashCodeAdd(() => + { + var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; + var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value); + + if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, + out var enumMemberName)) + { + // NOTE: Very unlikely as this would mean changes to the StringComparison enum + // which is not expected to change. It would also mean that the compiler + // and analyzer are running different dotnet versions. + throw new Exception("should not have gotten here."); + } + + writer.WriteLine( + $"global::System.StringComparer.{enumMemberName}"); + }); + } else if (memberSymbol.HasAttribute(attributesMetadata.CustomEquality)) { BuildHashCodeAdd(() => diff --git a/Generator.Equals/EqualsGenerator.cs b/Generator.Equals/EqualsGenerator.cs index d5a5fdd..543c141 100644 --- a/Generator.Equals/EqualsGenerator.cs +++ b/Generator.Equals/EqualsGenerator.cs @@ -41,6 +41,21 @@ public void Initialize(IncrementalGeneratorInitializationContext context) void Execute(SourceProductionContext productionContext, Compilation compilation, ImmutableArray syntaxArr) { + // Build a lookup for the System.StringComparison enum based on the compilation unit + INamedTypeSymbol typeSymbol = compilation.GetTypeByMetadataName("System.StringComparison")!; + + if (typeSymbol is not { TypeKind: TypeKind.Enum }) + { + throw new Exception("should not have gotten here. System.StringComparison is not an enum."); + } + + // Assume: Underlying type of enum is always `long` + var stringComparisonLookup = typeSymbol + .GetMembers() + .OfType() + .ToImmutableDictionary(key => Convert.ToInt64(key.ConstantValue), elem => elem.Name); + + var attributesMetadata = new AttributesMetadata( compilation.GetTypeByMetadataName("Generator.Equals.EquatableAttribute")!, compilation.GetTypeByMetadataName("Generator.Equals.DefaultEqualityAttribute")!, @@ -49,9 +64,11 @@ void Execute(SourceProductionContext productionContext, Compilation compilation, compilation.GetTypeByMetadataName("Generator.Equals.UnorderedEqualityAttribute")!, compilation.GetTypeByMetadataName("Generator.Equals.ReferenceEqualityAttribute")!, compilation.GetTypeByMetadataName("Generator.Equals.SetEqualityAttribute")!, - compilation.GetTypeByMetadataName("Generator.Equals.CustomEqualityAttribute")! + compilation.GetTypeByMetadataName("Generator.Equals.StringEqualityAttribute")!, + compilation.GetTypeByMetadataName("Generator.Equals.CustomEqualityAttribute")!, + stringComparisonLookup ); - + var handledSymbols = new HashSet(); foreach (var item in syntaxArr) diff --git a/README.md b/README.md index c2bedf0..432da3a 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,13 @@ public string Name { get; set; } // Will only return true if strings are the sam This will ignore whatever equality is implemented for a particular object and compare references instead. +### StringEquality + +```c# +[StringEquality(StringComparison.CurrentCulture | CurrentCultureIgnoreCase | InvariantCulture | InvariantCultureIgnoreCase | Ordinal | OrdinalIgnoreCase)] +public string Title { get; set; } // Will use the StringComparison set in constructor when comparing strings +``` + ### CustomEquality ```c#