diff --git a/Generator.Equals.DynamicGenerationTests/Base_Assertions.cs b/Generator.Equals.DynamicGenerationTests/Base_Assertions.cs new file mode 100644 index 0000000..c2a92d3 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Base_Assertions.cs @@ -0,0 +1,82 @@ +extern alias genEquals; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis.CSharp; + +namespace Generator.Equals.Tests.DynamicGeneration; + +public class Base_Assertions +{ + // Check if immutable arrays are comparable + [Fact] + public void Immutable_Arrays_Equatable() + { + var a = ImmutableArray.Create(1, 2, 3); + var b = ImmutableArray.Create(1, 2, 3); + Assert.Equal(a, b); + } + + [Fact] + public void Immutable_Arrays_Comparable() + { + var a = ImmutableArray.Create(1, 2, 3); + var b = ImmutableArray.Create(1, 2, 3); + Assert.True(a.SequenceEqual(b)); + } + + // Check if equality type model is comparable by value + [Fact] + public void EqualityTypeModel_Equatable() + { + var a = CreateEqualityTypeModelMock(); + var b = CreateEqualityTypeModelMock(); + Assert.Equal(a, b); + } + + [Fact] + public void AttributesMetadata_Equatable() + { + var a = AttributesMetadata.CreateDefault(); + var b = AttributesMetadata.CreateDefault(); + Assert.True(a.Equals(b)); + } + + internal static EqualityTypeModel CreateEqualityTypeModelMock() + { + var containingSymbols = ImmutableArray.Create( + new NamespaceContainingSymbol { Name = "Namespace1" }, + new TypeContainingSymbol { Name = "ContainingType", Kind = null } + ); + + var attributesMetadata = AttributesMetadata.Instance; + + var equalityMemberModels = ImmutableArray.Create( + new EqualityMemberModel + { + PropertyName = "Property1", + TypeName = "int", + EqualityType = EqualityType.DefaultEquality + }, + new EqualityMemberModel + { + PropertyName = "Property2", + TypeName = "string", + EqualityType = EqualityType.StringEquality, + StringComparer = attributesMetadata.StringComparisonLookup[4] + } + ); + + return new EqualityTypeModel + { + TypeName = "MyType", + ContainingSymbols = containingSymbols, + AttributesMetadata = attributesMetadata, + ExplicitMode = false, + IgnoreInheritedMembers = false, + SyntaxKind = SyntaxKind.ClassDeclaration, + BaseTypeName = "BaseType", + IsSealed = true, + BuildEqualityModels = equalityMemberModels, + Fullname = "Namespace1.MyType" + }; + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/DefaultsTest.cs b/Generator.Equals.DynamicGenerationTests/DefaultsTest.cs new file mode 100644 index 0000000..0865e2e --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/DefaultsTest.cs @@ -0,0 +1,121 @@ +using Generator.Equals.Tests.DynamicGeneration.Utils; + +namespace Generator.Equals.Tests.DynamicGeneration; + +public class DefaultsTest +{ + [Theory] + [InlineData("Unordered", + "global::Generator.Equals.UnorderedEqualityComparer.Default.Equals(this.Properties!, other.Properties!)")] + [InlineData("Ordered", + "global::Generator.Equals.OrderedEqualityComparer.Default.Equals(this.Properties!, other.Properties!)")] + public void Global_Enumerable_Order_Is_Respected(string enumerableEquality, string expectedComparisonLine) + { + var source = + """ + public partial class UnorderedEquality + { + [Equatable] + public partial class Sample + { + public List? Properties { get; set; } + } + } + """; + + + // generator_equals_comparison_string = OrdinalIgnoreCase + // generator_equals_comparison_enumerable = Unordered + var genResult = GeneratorTestHelper.RunGenerator(source, new() + { + { "generator_equals_comparison_enumerable", enumerableEquality } + }); + + var gensource = genResult.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + Assert.Contains( + expectedComparisonLine, + gensource.FirstOrDefault()?.ToString()); + } + + // generator_equals_comparison_string is respected + [Theory] + [InlineData("OrdinalIgnoreCase", + "global::System.StringComparer.OrdinalIgnoreCase.Equals(this.Tag!, other.Tag!)")] + [InlineData("Ordinal", + "global::System.StringComparer.Ordinal.Equals(this.Tag!, other.Tag!)")] + [InlineData("InvariantCulture", + "global::System.StringComparer.InvariantCulture.Equals(this.Tag!, other.Tag!)")] + public void Global_String_Comparison_Is_Respected(string stringComparison, string expectedComparisonLine) + { + var source = + """ + [Equatable] + public partial class Resource + { + public string Tag {get;set;} + } + """; + + var genResult = GeneratorTestHelper.RunGenerator(source, new() + { + { "generator_equals_comparison_string", stringComparison } + }); + + var gensource = genResult.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + Assert.Contains( + expectedComparisonLine, + gensource.FirstOrDefault()?.ToString()); + + } + + [Theory] + [InlineData("OrdinalIgnoreCase", + "new global::Generator.Equals.OrderedEqualityComparer(global::System.StringComparer.OrdinalIgnoreCase).Equals(this.Tags!, other.Tags!)")] + [InlineData("Ordinal", + "new global::Generator.Equals.OrderedEqualityComparer(global::System.StringComparer.Ordinal).Equals(this.Tags!, other.Tags!)")] + [InlineData("InvariantCulture", + "new global::Generator.Equals.OrderedEqualityComparer(global::System.StringComparer.InvariantCulture).Equals(this.Tags!, other.Tags!)")] + public void Global_String_Comparison_Is_Respected_In_Lists(string stringComparison, string expectedComparisonLine) + { + var source = + """ + [Equatable] + public partial class Resource + { + public List? Tags {get;set;} + } + """; + + var genResult = GeneratorTestHelper.RunGenerator(source, new() + { + { "generator_equals_comparison_string", stringComparison } + }); + + var gensource = genResult.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + Assert.Contains( + expectedComparisonLine, + gensource.FirstOrDefault()?.ToString()); + + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/Generator.Equals.Tests.DynamicGeneration.csproj b/Generator.Equals.DynamicGenerationTests/Generator.Equals.Tests.DynamicGeneration.csproj new file mode 100644 index 0000000..231fb4b --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Generator.Equals.Tests.DynamicGeneration.csproj @@ -0,0 +1,33 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + + + + + + + + + genEquals + + + + + + + + + diff --git a/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs b/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs new file mode 100644 index 0000000..fbd4cb9 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Issues/Issue-60-StringEquality-Enumerables.cs @@ -0,0 +1,91 @@ +using Generator.Equals.Tests.DynamicGeneration.Utils; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using SourceGeneratorTestHelpers; + +namespace Generator.Equals.DynamicGenerationTests.Issues; + +public class Issue_60_StringEquality_Enumerables +{ + public static readonly List References = GeneratorTestHelper.References; + + [Fact] + public void Comparison_is_correctly_generated() + { + var input = SourceText.CSharp( + """ + using System; + using System.Collections.Generic; + using Generator.Equals; + + [Equatable] + public partial class Resource + { + [UnorderedEquality] + [StringEqualityAttribute(StringComparison.OrdinalIgnoreCase)] + public string[] Tags { get; set; } = Array.Empty(); + + } + """ + ); + + var result = IncrementalGenerator.Run + ( + input, + new CSharpParseOptions(), + References + ); + + var gensource = result.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + var src = gensource.FirstOrDefault()?.ToString(); + + Assert.Contains( + "new global::Generator.Equals.UnorderedEqualityComparer(global::System.StringComparer.OrdinalIgnoreCase)", + src); + } + + [Fact] + public void Comparison_is_correctly_generated_without_attributes() + { + var input = SourceText.CSharp( + """ + using System; + using System.Collections.Generic; + using Generator.Equals; + + [Equatable] + public partial class Resource + { + public string[] Tags { get; set; } = Array.Empty(); + } + """ + ); + + var result = IncrementalGenerator.Run + ( + input, + new CSharpParseOptions(), + References + ); + + var gensource = result.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + + var src = gensource.FirstOrDefault()?.ToString(); + + Assert.Contains("new global::Generator.Equals.OrderedEqualityComparer(global::System.StringComparer.Ordinal)", + src); + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/LocalFieldsAssertions.cs b/Generator.Equals.DynamicGenerationTests/LocalFieldsAssertions.cs new file mode 100644 index 0000000..c2c3944 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/LocalFieldsAssertions.cs @@ -0,0 +1,31 @@ +namespace Generator.Equals.Tests.DynamicGeneration; + +public class LocalFieldsAssertions +{ + [Fact] + public static void Test() + { + var code = """ + [Equatable] + public partial class Sample + { + public Sample(string name) + { + Name = name; + } + + [ReferenceEquality] public string Name { get; } + } + """; + + var generated = Utils.GeneratorTestHelper.RunGenerator(code); + + var gensource = generated.Results + .SelectMany(x => x.GeneratedSources) + .Select(x => x.SourceText) + .ToList() + ; + + Assert.NotNull(gensource); + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/Utils/GeneratorTestHelper.cs b/Generator.Equals.DynamicGenerationTests/Utils/GeneratorTestHelper.cs new file mode 100644 index 0000000..1e80c51 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Utils/GeneratorTestHelper.cs @@ -0,0 +1,130 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Generator.Equals.Tests.DynamicGeneration.Utils; + +public class GeneratorTestHelper +{ + public static readonly List References = + AppDomain.CurrentDomain.GetAssemblies() + .Where(x => !x.IsDynamic && !string.IsNullOrWhiteSpace(x.Location)) + .Select(x => MetadataReference.CreateFromFile(x.Location)) + .Append(MetadataReference.CreateFromFile(typeof(EquatableAttribute).Assembly.Location)) + .ToList(); + + public static GeneratorDriverRunResult RunGenerator + ( + [StringSyntax("c#-test")] string source, + Dictionary? options = null + ) + { + // If source has no namespace, add one + if (!source.Contains("namespace")) + { + source = "namespace TestNamespace;\n" + source; + } + + // using CompositeDto.Generator.Runtime; + // using System; + + if (!source.Contains("using Generator.Equals;")) + { + source = "using Generator.Equals;\n" + source; + } + + if (!source.Contains("using System;")) + { + source = "using System;\n" + source; + } + + if (!source.Contains("using System.Collections.Generic;")) + { + source = "using System.Collections.Generic;\n" + source; + } + + + + + var syntaxTree = CSharpSyntaxTree.ParseText(source); + + var compilation = CSharpCompilation.Create( + assemblyName: "Tests", + syntaxTrees: [syntaxTree], + references: References, + options: new( + outputKind: OutputKind.DynamicallyLinkedLibrary + ) + ); + + IEnumerable generator = new[] + { + new EqualsGenerator().AsSourceGenerator() + }; + + AnalyzerConfigOptionsProvider? prov = + options is null + ? null + : new MockAnalyzerConfigOptionsProvider(options); + + var driver = CSharpGeneratorDriver + .Create + ( + generator + , optionsProvider: prov + ) + .RunGeneratorsAndUpdateCompilation( + compilation, + out var outputCompilation, + out var diagnostics + ); + + var diag = outputCompilation.GetDiagnostics(); + var warnings = diag.Where(d => d.Severity == DiagnosticSeverity.Warning).ToList(); + + // Assert.Empty( + // outputCompilation + // .GetDiagnostics() + // .Where(d => d.Severity is DiagnosticSeverity.Error or DiagnosticSeverity.Warning) + // ); + + // Assert.Empty(diagnostics); + return driver.GetRunResult(); + } +} + +public class MockAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider +{ + public override AnalyzerConfigOptions GlobalOptions { get; } + + public MockAnalyzerConfigOptionsProvider(Dictionary options) + { + GlobalOptions = new MockAnalyzerConfigOptions(options); + } + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) + { + return GlobalOptions; + } + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) + { + return GlobalOptions; + } +} + +public class MockAnalyzerConfigOptions : AnalyzerConfigOptions +{ + private readonly Dictionary _options; + + public MockAnalyzerConfigOptions(Dictionary options) + { + _options = options; + } + + public override bool TryGetValue(string key, [NotNullWhen(true)] out string? value) + { + return _options.TryGetValue(key, out value); + } +} \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/Utils/Globals.cs b/Generator.Equals.DynamicGenerationTests/Utils/Globals.cs new file mode 100644 index 0000000..a99ffd8 --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Utils/Globals.cs @@ -0,0 +1,4 @@ +extern alias genEquals; + +global using genEquals::Generator.Equals.Models; +global using genEquals::Generator.Equals; \ No newline at end of file diff --git a/Generator.Equals.DynamicGenerationTests/Utils/SourceText.cs b/Generator.Equals.DynamicGenerationTests/Utils/SourceText.cs new file mode 100644 index 0000000..bbfdf2b --- /dev/null +++ b/Generator.Equals.DynamicGenerationTests/Utils/SourceText.cs @@ -0,0 +1,10 @@ +extern alias genEquals; + +using genEquals::System.Diagnostics.CodeAnalysis; + +namespace Generator.Equals.DynamicGenerationTests; + +internal static class SourceText +{ + public static string CSharp([StringSyntax("c#-test")] string source) => source; +} \ No newline at end of file diff --git a/Generator.Equals.Runtime/Attributes.cs b/Generator.Equals.Runtime/Attributes.cs index e41b04f..0b8a83b 100644 --- a/Generator.Equals.Runtime/Attributes.cs +++ b/Generator.Equals.Runtime/Attributes.cs @@ -1,5 +1,6 @@ using System; using System.CodeDom.Compiler; +using System.Collections.Generic; using System.Diagnostics; namespace Generator.Equals diff --git a/Generator.Equals.Tests/.globalconfig b/Generator.Equals.Tests/.globalconfig new file mode 100644 index 0000000..65de447 --- /dev/null +++ b/Generator.Equals.Tests/.globalconfig @@ -0,0 +1,2 @@ +generator_equals_comparison_string = OrdinalIgnoreCase +generator_equals_comparison_enumerable = Unordered \ No newline at end of file diff --git a/Generator.Equals.Tests/Classes/StringArrayEquality.cs b/Generator.Equals.Tests/Classes/StringArrayEquality.cs new file mode 100644 index 0000000..9f46ad1 --- /dev/null +++ b/Generator.Equals.Tests/Classes/StringArrayEquality.cs @@ -0,0 +1,162 @@ +using System; +// ReSharper disable InconsistentNaming +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +namespace Generator.Equals.Tests.Classes; + +public partial class StringArrayEquality +{ + [Equatable] + public partial class Sample + { + // Order doesnt matter, case doesn't matter + [UnorderedEquality, StringEquality(StringComparison.OrdinalIgnoreCase)] + public string[] Unordered_Case_Insensitive_SS { get; set; } + + // Order matters, case doesn't matter + [OrderedEquality, StringEquality(StringComparison.OrdinalIgnoreCase)] + public string[] Ordered_Case_Insensitive { get; set; } + + // Order doesn't matter, case matters + [UnorderedEquality, StringEquality(StringComparison.Ordinal)] + public string[] Unordered_Case_Sensitive { get; set; } + + // Order matters, case matters + [OrderedEquality, StringEquality(StringComparison.Ordinal)] + public string[] Ordered_Case_Sensitive { get; set; } + + // Default expectation: Order matters, case matters + public string[] DefaultBehaviour { get; set; } + } +} + +public partial class StringArrayEquality +{ + public class EqualsTests : EqualityTestCase + { + public override object Factory1() + { + return new Sample + { + Unordered_Case_Insensitive_SS = new[] { "a", "b", "c" }, + Ordered_Case_Insensitive = new[] { "a", "b", "c" }, + Unordered_Case_Sensitive = new[] { "a", "b", "c" }, + Ordered_Case_Sensitive = new[] { "a", "b", "c" }, + DefaultBehaviour = new[] { "a", "b", "c" }, + }; + } + + public override object Factory2() + { + return new Sample + { + Unordered_Case_Insensitive_SS = new[] { "b", "A", "c" }, + Ordered_Case_Insensitive = new[] { "A", "b", "c" }, + Unordered_Case_Sensitive = new[] { "b", "a", "c" }, + Ordered_Case_Sensitive = new[] { "a", "b", "c" }, + DefaultBehaviour = new[] { "a", "b", "c" }, + }; + } + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // Now test if the equals operator can return false. + // Unordered_Case_Insensitive: "a", "b", "c" vs "b", "X", "c" (X is not in the first array) + public class NotEqualsTest_Unordered_Case_Insensitive : EqualityTestCase + { + public override bool Expected => false; + + public override object Factory1() => new Sample + { + Unordered_Case_Insensitive_SS = new[] { "a", "b", "c" }, + }; + + public override object Factory2() => new Sample + { + Unordered_Case_Insensitive_SS = new[] { "b", "X", "c" }, + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // Ordered_Case_Insensitive: "a", "b", "c" vs "A", "b", "c" (a is different case) + public class NotEqualsTest_Ordered_Case_Insensitive : EqualityTestCase + { + public override bool Expected => false; + + public override object Factory1() => new Sample + { + Ordered_Case_Insensitive = new[] { "a", "b", "c" }, + }; + + public override object Factory2() => new Sample + { + Ordered_Case_Insensitive = new[] { "b", "A", "c" }, + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // Unordered_Case_Sensitive: "a", "b", "c" vs "b", "a", "C" (C is different case) + public class NotEqualsTest_Unordered_Case_Sensitive : EqualityTestCase + { + public override bool Expected => false; + + public override object Factory1() => new Sample + { + Unordered_Case_Sensitive = new[] { "a", "b", "c" }, + }; + + public override object Factory2() => new Sample + { + Unordered_Case_Sensitive = new[] { "b", "a", "C" }, + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // Ordered_Case_Sensitive: "a", "b", "c" vs "b", "a", "C" (C is different case) + public class NotEqualsTest_Ordered_Case_Sensitive : EqualityTestCase + { + public override bool Expected => false; + + public override object Factory1() => new Sample + { + Ordered_Case_Sensitive = new[] { "a", "b", "c" }, + }; + + public override object Factory2() => new Sample + { + Ordered_Case_Sensitive = new[] { "b", "a", "C" }, + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } + + // DefaultBehaviour: "a", "b", "c" vs "b", "a", "C" (C is different case) + public class NotEqualsTest_DefaultBehaviour : EqualityTestCase + { + public override bool Expected => false; + + public override object Factory1() => new Sample + { + DefaultBehaviour = new[] { "a", "b", "c" }, + }; + + public override object Factory2() => new Sample + { + DefaultBehaviour = new[] { "b", "a", "C" }, + }; + + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; + } +} + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs b/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs index 9a5aa26..54eba18 100644 --- a/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs +++ b/Generator.Equals.Tests/Classes/UnorderedEquality.Sample.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; namespace Generator.Equals.Tests.Classes @@ -7,7 +8,8 @@ public partial class UnorderedEquality [Equatable] public partial class Sample { - [UnorderedEquality] public List? Properties { get; set; } + [UnorderedEquality] + public List? Properties { get; set; } } } } \ No newline at end of file diff --git a/Generator.Equals.Tests/Classes/UnorderedEquality.cs b/Generator.Equals.Tests/Classes/UnorderedEquality.cs index 07eb641..8e68dcf 100644 --- a/Generator.Equals.Tests/Classes/UnorderedEquality.cs +++ b/Generator.Equals.Tests/Classes/UnorderedEquality.cs @@ -18,30 +18,30 @@ public override object Factory1() Properties = Enumerable .Range(1, 1000) .OrderBy(_ => randomSort.NextDouble()) - .ToList() + .ToList(), }; } - public override bool EqualsOperator(object value1, object value2) => (Sample) value1 == (Sample) value2; - public override bool NotEqualsOperator(object value1, object value2) => (Sample) value1 != (Sample) value2; + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; } - + public class NotEqualsTest : EqualityTestCase { public override bool Expected => false; public override object Factory1() => new Sample { - Properties = Enumerable.Range(1, 1000).ToList() + Properties = Enumerable.Range(1, 1000).ToList(), }; public override object Factory2() => new Sample { - Properties = Enumerable.Range(1, 1001).ToList() + Properties = Enumerable.Range(1, 1001).ToList(), }; - public override bool EqualsOperator(object value1, object value2) => (Sample) value1 == (Sample) value2; - public override bool NotEqualsOperator(object value1, object value2) => (Sample) value1 != (Sample) value2; + public override bool EqualsOperator(object value1, object value2) => (Sample)value1 == (Sample)value2; + public override bool NotEqualsOperator(object value1, object value2) => (Sample)value1 != (Sample)value2; } } } \ No newline at end of file diff --git a/Generator.Equals.Tests/Generator.Equals.Tests.csproj b/Generator.Equals.Tests/Generator.Equals.Tests.csproj index 2b2b27c..b8321dc 100644 --- a/Generator.Equals.Tests/Generator.Equals.Tests.csproj +++ b/Generator.Equals.Tests/Generator.Equals.Tests.csproj @@ -6,6 +6,8 @@ disable NU1701 enable + + 1 @@ -40,5 +42,7 @@ - + + + \ No newline at end of file diff --git a/Generator.Equals.sln b/Generator.Equals.sln index 993c6a2..afaf96b 100644 --- a/Generator.Equals.sln +++ b/Generator.Equals.sln @@ -1,4 +1,4 @@ - + Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.3.32811.315 @@ -13,6 +13,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Generator.Equals.Tests.TopL EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Generator.Equals.Runtime", "Generator.Equals.Runtime\Generator.Equals.Runtime.csproj", "{2AD99F42-5E2C-451A-97EE-59C64EC8B270}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Generator.Equals.Tests.DynamicGeneration", "Generator.Equals.DynamicGenerationTests\Generator.Equals.Tests.DynamicGeneration.csproj", "{20F96A29-3BC9-4115-B99D-1E6D45C6340F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -39,6 +41,10 @@ Global {2AD99F42-5E2C-451A-97EE-59C64EC8B270}.Debug|Any CPU.Build.0 = Debug|Any CPU {2AD99F42-5E2C-451A-97EE-59C64EC8B270}.Release|Any CPU.ActiveCfg = Release|Any CPU {2AD99F42-5E2C-451A-97EE-59C64EC8B270}.Release|Any CPU.Build.0 = Release|Any CPU + {20F96A29-3BC9-4115-B99D-1E6D45C6340F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {20F96A29-3BC9-4115-B99D-1E6D45C6340F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {20F96A29-3BC9-4115-B99D-1E6D45C6340F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {20F96A29-3BC9-4115-B99D-1E6D45C6340F}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Generator.Equals/ContainingTypesBuilder.cs b/Generator.Equals/ContainingTypesBuilder.cs index c27d933..1dec2aa 100644 --- a/Generator.Equals/ContainingTypesBuilder.cs +++ b/Generator.Equals/ContainingTypesBuilder.cs @@ -4,28 +4,37 @@ using System.IO; using System.Linq; using System.Text; - using Generator.Equals.Models; - using Microsoft.CodeAnalysis.CSharp; -namespace Generator.Equals +namespace Generator.Equals; + +internal static class ContainingTypesBuilder { - internal static class ContainingTypesBuilder + public static string Build(EqualityTypeModel model, Action content) { - public static string Build(ImmutableArray containingSymbols, Action content) + var sb = new StringBuilder(capacity: 4096); + using (var writer = CreateWriter(sb, model.ContainingSymbols)) { - using var buffer = new StringWriter(new StringBuilder(capacity: 4096)); - using var writer = new IndentedTextWriter(buffer); + content(model, writer); + } - foreach (var parentSymbol in containingSymbols.Reverse()) + return sb.ToString(); + } + + public static IndentedTextWriter CreateWriter(StringBuilder sb, ImmutableArray containingSymbols) + { + var writer = new UnwindingTextWriter(sb); + foreach (var parentSymbol in containingSymbols.Reverse()) + { + switch (parentSymbol) { - if (parentSymbol is NamespaceContainingSymbol namespaceSymbol) + case NamespaceContainingSymbol namespaceSymbol: { writer.WriteLine(); - writer.WriteLine(EqualityGeneratorBase.EnableNullableContext); - writer.WriteLine(EqualityGeneratorBase.SuppressObsoleteWarningsPragma); - writer.WriteLine(EqualityGeneratorBase.SuppressTypeConflictsWarningsPragma); + writer.WriteLine(GeneratorConstants.EnableNullableContext); + writer.WriteLine(GeneratorConstants.SuppressObsoleteWarningsPragma); + writer.WriteLine(GeneratorConstants.SuppressTypeConflictsWarningsPragma); writer.WriteLine(); if (!string.IsNullOrEmpty(namespaceSymbol.Name)) @@ -33,8 +42,10 @@ public static string Build(ImmutableArray containingSymbols, A writer.WriteLine($"namespace {namespaceSymbol.Name}"); writer.AppendOpenBracket(); } + + break; } - else if (parentSymbol is TypeContainingSymbol typeContainingSymbol) + case TypeContainingSymbol typeContainingSymbol: { var keyword = typeContainingSymbol.Kind switch { @@ -47,14 +58,21 @@ public static string Build(ImmutableArray containingSymbols, A writer.WriteLine($"partial {keyword} {parentSymbol.Name}"); writer.AppendOpenBracket(); + break; } } + } - content(writer); - - writer.UnwindOpenedBrackets(); + return writer; + } - return buffer.ToString(); + internal class UnwindingTextWriter(StringBuilder sb) + : IndentedTextWriter(new StringWriter(sb)) + { + protected override void Dispose(bool disposing) + { + this.UnwindOpenedBrackets(); + base.Dispose(disposing); } } } \ No newline at end of file diff --git a/Generator.Equals/EqualityGeneratorBase.cs b/Generator.Equals/EqualityGeneratorBase.cs deleted file mode 100644 index 802f968..0000000 --- a/Generator.Equals/EqualityGeneratorBase.cs +++ /dev/null @@ -1,195 +0,0 @@ -using System.CodeDom.Compiler; -using System.Collections.Immutable; - -using Generator.Equals.Models; - -using Microsoft.CodeAnalysis; - -namespace Generator.Equals -{ - internal class EqualityGeneratorBase - { - protected const string GeneratedCodeAttributeDeclaration = - "[global::System.CodeDom.Compiler.GeneratedCodeAttribute(\"Generator.Equals\", \"1.0.0.0\")]"; - - internal const string EnableNullableContext = "#nullable enable"; - - // CS0612: Obsolete with no comment - // CS0618: obsolete with comment - internal const string SuppressObsoleteWarningsPragma = "#pragma warning disable CS0612,CS0618"; - - internal const string SuppressTypeConflictsWarningsPragma = "#pragma warning disable CS0436"; - - protected static readonly string[] EqualsOperatorCodeComment = @" -/// -/// Indicates whether the object on the left is equal to the object on the right. -/// -/// The left object -/// The right object -/// true if the objects are equal; otherwise, false.".ToLines(); - - protected static readonly string[] NotEqualsOperatorCodeComment = @" -/// -/// Indicates whether the object on the left is not equal to the object on the right. -/// -/// The left object -/// The right object -/// true if the objects are not equal; otherwise, false.".ToLines(); - - protected const string InheritDocComment = "/// "; - - private static void BuildEquality(EqualityMemberModel memberModel, IndentedTextWriter writer) - { - if (memberModel.Ignored) - { - return; - } - - switch (memberModel.EqualityType) - { - case EqualityType.IgnoreEquality: - break; - - case EqualityType.UnorderedEquality when !memberModel.IsDictionary: - writer.WriteLine( - $"&& global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.UnorderedEquality when memberModel.IsDictionary: - writer.WriteLine( - $"&& global::Generator.Equals.DictionaryEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.OrderedEquality: - writer.WriteLine( - $"&& global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.ReferenceEquality: - writer.WriteLine( - $"&& global::Generator.Equals.ReferenceEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.SetEquality: - writer.WriteLine( - $"&& global::Generator.Equals.SetEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.StringEquality: - writer.WriteLine( - $"&& global::System.StringComparer.{memberModel.StringComparer}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - - case EqualityType.CustomEquality when memberModel.ComparerHasStaticInstance: - writer.WriteLine( - $"&& {memberModel.ComparerType}.{memberModel.ComparerMemberName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - - break; - - case EqualityType.CustomEquality when !memberModel.ComparerHasStaticInstance: - writer.WriteLine( - $"&& new {memberModel.ComparerType}().Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - - break; - - case EqualityType.DefaultEquality: - writer.WriteLine( - $"&& global::Generator.Equals.DefaultEqualityComparer<{memberModel.TypeName}>.Default.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); - break; - } - } - - internal static void BuildMembersEquality( - ImmutableArray models, - IndentedTextWriter writer - ) - { - foreach (var model in models) - { - BuildEquality(model, writer); - } - } - - private static void BuildHashCode( - ISymbol memberSymbol, - ITypeSymbol typeSymbol, - AttributesMetadata attributesMetadata, - IndentedTextWriter writer, - bool explicitMode - ) - { - var model = EqualityMemberModelTransformer - .BuildEqualityModel(memberSymbol, typeSymbol, attributesMetadata, explicitMode); - - BuildHashCode(model, writer); - } - - private static void BuildHashCode(EqualityMemberModel memberModel, IndentedTextWriter writer) - { - if (memberModel.Ignored) - { - return; - } - - switch (memberModel.EqualityType) - { - case EqualityType.IgnoreEquality: - break; - - case EqualityType.UnorderedEquality when memberModel.IsDictionary: - BuildHashCodeAdd($"global::Generator.Equals.DictionaryEqualityComparer<{memberModel.TypeName}>.Default"); - break; - - case EqualityType.UnorderedEquality when !memberModel.IsDictionary: - BuildHashCodeAdd($"global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>.Default"); - break; - - case EqualityType.OrderedEquality: - BuildHashCodeAdd($"global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>.Default"); - break; - - case EqualityType.ReferenceEquality: - BuildHashCodeAdd($"global::Generator.Equals.ReferenceEqualityComparer<{memberModel.TypeName}>.Default"); - break; - - case EqualityType.SetEquality: - BuildHashCodeAdd($"global::Generator.Equals.SetEqualityComparer<{memberModel.TypeName}>.Default"); - break; - - case EqualityType.StringEquality: - BuildHashCodeAdd($"global::System.StringComparer.{memberModel.StringComparer}"); - break; - - case EqualityType.CustomEquality when memberModel.ComparerHasStaticInstance: - BuildHashCodeAdd($"{memberModel.ComparerType}.{memberModel.ComparerMemberName}"); - break; - - case EqualityType.CustomEquality when !memberModel.ComparerHasStaticInstance: - BuildHashCodeAdd($"new {memberModel.ComparerType}()"); - break; - - case EqualityType.DefaultEquality: - BuildHashCodeAdd($"global::Generator.Equals.DefaultEqualityComparer<{memberModel.TypeName}>.Default"); - break; - } - - void BuildHashCodeAdd(string comparer) - { - writer.WriteLine("hashCode.Add("); - writer.Indent++; - writer.WriteLine($"this.{memberModel.PropertyName}!,"); - writer.WriteLine(comparer); - writer.Indent--; - writer.WriteLine(");"); - } - } - - public static void BuildMembersHashCode(ImmutableArray models, IndentedTextWriter writer) - { - foreach (var model in models) - { - BuildHashCode(model, writer); - } - } - } -} \ No newline at end of file diff --git a/Generator.Equals/EqualityMemberModelTransformer.cs b/Generator.Equals/EqualityMemberModelTransformer.cs index 2297366..33928d5 100644 --- a/Generator.Equals/EqualityMemberModelTransformer.cs +++ b/Generator.Equals/EqualityMemberModelTransformer.cs @@ -58,8 +58,54 @@ bool explicitMode }; } - // Check for different equality attributes and map them to the model - if (memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality)) + + if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) + { + var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; + var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value); + + bool isDefaultStringComparer = false; + if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, out var enumMemberName)) + { + // Todo: Diagnostic + // throw new Exception("Unexpected StringComparison value."); + enumMemberName = "Ordinal"; + isDefaultStringComparer = true; + } + + + // Special case: We do this comparison through either OrderedEquality or UnorderedEquality + if (typeSymbol.IsStringArray() && typeSymbol.GetIEnumerableTypeArguments() is { } args) + { + var equalityType = memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality) + ? EqualityType.UnorderedEquality + : EqualityType.OrderedEquality; + + var isDefault = equalityType == EqualityType.OrderedEquality && + !memberSymbol.HasAttribute(attributesMetadata.OrderedEquality); + + // return new EqualityMemberModel(propertyName, args.Name, equalityType, stringComparer: enumMemberName); + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = args.Name, + EqualityType = equalityType, + StringComparer = enumMemberName, + IsDefaultEqualityType = isDefault, + IsDefaultStringComparer = isDefaultStringComparer + }; + } + + // return new EqualityMemberModel(propertyName, typeName, EqualityType.StringEquality, stringComparer: enumMemberName); + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = typeName, + EqualityType = EqualityType.StringEquality, + StringComparer = enumMemberName + }; + } + else if (memberSymbol.HasAttribute(attributesMetadata.UnorderedEquality)) { var args = typeSymbol.GetIDictionaryTypeArguments() ?? typeSymbol.GetIEnumerableTypeArguments()!; @@ -102,24 +148,6 @@ bool explicitMode EqualityType = EqualityType.SetEquality }; } - else if (memberSymbol.HasAttribute(attributesMetadata.StringEquality)) - { - var attribute = memberSymbol.GetAttribute(attributesMetadata.StringEquality)!; - var stringComparisonValue = Convert.ToInt64(attribute.ConstructorArguments[0].Value, CultureInfo.InvariantCulture); - - if (!attributesMetadata.StringComparisonLookup.TryGetValue(stringComparisonValue, out var enumMemberName)) - { - throw new InvalidOperationException("Unexpected StringComparison value."); - } - - return new EqualityMemberModel - { - PropertyName = propertyName, - TypeName = typeName, - EqualityType = EqualityType.StringEquality, - StringComparer = enumMemberName - }; - } else if (memberSymbol.HasAttribute(attributesMetadata.CustomEquality)) { var attribute = memberSymbol.GetAttribute(attributesMetadata.CustomEquality); @@ -143,13 +171,65 @@ bool explicitMode } var isIgnored = (explicitMode && !memberSymbol.HasAttribute(attributesMetadata.DefaultEquality)); + if (isIgnored) + { + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = typeName, + EqualityType = EqualityType.IgnoreEquality, + Ignored = true + }; + } + + // if is string: + if (typeSymbol.IsString()) + { + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = typeName, + EqualityType = EqualityType.StringEquality, + StringComparer = "Ordinal", + IsDefaultEqualityType = true, + IsDefaultStringComparer = true + }; + } + + if (typeSymbol.GetIEnumerableTypeArguments() is { } sargs) + { + if (string.Equals(sargs.Name, "string", StringComparison.Ordinal)) + { + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = sargs.Name, + EqualityType = EqualityType.OrderedEquality, + StringComparer = "Ordinal", + IsDefaultEqualityType = true, + IsDefaultStringComparer = true + }; + } + else + { + return new EqualityMemberModel + { + PropertyName = propertyName, + TypeName = sargs.Name, + EqualityType = EqualityType.OrderedEquality, + IsDefaultEqualityType = true + }; + } + } + return new EqualityMemberModel { PropertyName = propertyName, TypeName = typeName, - EqualityType = isIgnored ? EqualityType.IgnoreEquality : EqualityType.DefaultEquality, - Ignored = isIgnored + EqualityType = EqualityType.DefaultEquality, + Ignored = false, + IsDefaultEqualityType = true }; } } \ No newline at end of file diff --git a/Generator.Equals/EqualsGenerator.cs b/Generator.Equals/EqualsGenerator.cs index a1dd7d1..f038c75 100644 --- a/Generator.Equals/EqualsGenerator.cs +++ b/Generator.Equals/EqualsGenerator.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System; +using System.Linq; using System.Runtime.CompilerServices; using System.Text; using Generator.Equals.Generators; @@ -24,22 +25,34 @@ public void Initialize(IncrementalGeneratorInitializationContext context) (syntaxContext, ct) => new EqualityTypeModelTransformer(syntaxContext).Transform(ct) ); - context.RegisterSourceOutput(provider, (spc, ctx) => Execute(spc, ctx)); + var config = context.AnalyzerConfigOptionsProvider + .Select((options, _) => new GeneratorOptions(options.GlobalOptions)); + + var combinedProvider = provider.Combine(config); + + + // context.RegisterSourceOutput(provider, (spc, ctx) => Execute(spc, ctx)); + context.RegisterSourceOutput(combinedProvider, (spc, ctx) => Execute2(spc, ctx)); } - private static void Execute(SourceProductionContext productionContext, EqualityTypeModel? model) + + private void Execute2(SourceProductionContext productionContext, (EqualityTypeModel? model, GeneratorOptions options) ctx) { - if (productionContext.CancellationToken.IsCancellationRequested || model is null) + if (productionContext.CancellationToken.IsCancellationRequested || ctx.model is null) { return; } + var model = ctx.model + .WithGeneratorOptions(ctx.options); + + var source = model.SyntaxKind switch { - SyntaxKind.StructDeclaration => StructEqualityGenerator.Generate(model), - SyntaxKind.RecordStructDeclaration => RecordStructEqualityGenerator.Generate(model), - SyntaxKind.RecordDeclaration => RecordEqualityGenerator.Generate(model), - SyntaxKind.ClassDeclaration => ClassEqualityGenerator.Generate(model), + SyntaxKind.StructDeclaration => StructGenerator.Generate(model), + SyntaxKind.RecordStructDeclaration => RecordStructGenerator.Generate(model), + SyntaxKind.RecordDeclaration => RecordGenerator.Generate(model), + SyntaxKind.ClassDeclaration => ClassGenerator.Generate(model), _ => null }; @@ -48,6 +61,11 @@ private static void Execute(SourceProductionContext productionContext, EqualityT return; } + //Test: prepend "CounterEnabled: {options.CounterEnabled}" to the generated source + source = $"// CounterEnabled: {ctx.options.DefaultStringComparison}\n" + + $"// ArrayCompare: {ctx.options.ArrayCompare}\n" + + $"{source}"; + var fileName = $"{EscapeFileName(model.Fullname)}.Generator.Equals.g.cs"!; productionContext.AddSource(fileName, source); } diff --git a/Generator.Equals/SymbolHelpers.cs b/Generator.Equals/Extensions/SymbolHelpers.cs similarity index 85% rename from Generator.Equals/SymbolHelpers.cs rename to Generator.Equals/Extensions/SymbolHelpers.cs index e12ef1e..06aa3e9 100644 --- a/Generator.Equals/SymbolHelpers.cs +++ b/Generator.Equals/Extensions/SymbolHelpers.cs @@ -74,6 +74,23 @@ public static string ToFQF(this ISymbol symbol) ? null : new DictionaryArgumentsResult(res); } + + public static bool IsStringArray(this ITypeSymbol typeSymbol) + { + // Check if the symbol is an array + if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) + { + // Check if the element type is string + return arrayTypeSymbol.ElementType.SpecialType == SpecialType.System_String; + } + + return false; + } + + public static bool IsString(this ITypeSymbol typeSymbol) + { + return typeSymbol.SpecialType == SpecialType.System_String; + } } public record DictionaryArgumentsResult(ImmutableArray? Arguments) : ArgumentsResult(Arguments); diff --git a/Generator.Equals/GeneratorConstants.cs b/Generator.Equals/GeneratorConstants.cs new file mode 100644 index 0000000..82d9587 --- /dev/null +++ b/Generator.Equals/GeneratorConstants.cs @@ -0,0 +1,39 @@ +using System.CodeDom.Compiler; +using System.Collections.Immutable; +using Generator.Equals.Generators.Core; +using Generator.Equals.Models; +using Microsoft.CodeAnalysis; + +namespace Generator.Equals; + +internal class GeneratorConstants +{ + internal const string GeneratedCodeAttributeDeclaration = + "[global::System.CodeDom.Compiler.GeneratedCodeAttribute(\"Generator.Equals\", \"1.0.0.0\")]"; + + internal const string EnableNullableContext = "#nullable enable"; + + // CS0612: Obsolete with no comment + // CS0618: obsolete with comment + internal const string SuppressObsoleteWarningsPragma = "#pragma warning disable CS0612,CS0618"; + + internal const string SuppressTypeConflictsWarningsPragma = "#pragma warning disable CS0436"; + + internal static readonly string[] EqualsOperatorCodeComment = @" +/// +/// Indicates whether the object on the left is equal to the object on the right. +/// +/// The left object +/// The right object +/// true if the objects are equal; otherwise, false.".ToLines(); + + internal static readonly string[] NotEqualsOperatorCodeComment = @" +/// +/// Indicates whether the object on the left is not equal to the object on the right. +/// +/// The left object +/// The right object +/// true if the objects are not equal; otherwise, false.".ToLines(); + + internal const string InheritDocComment = "/// "; +} \ No newline at end of file diff --git a/Generator.Equals/GeneratorOptions.cs b/Generator.Equals/GeneratorOptions.cs new file mode 100644 index 0000000..82d49e9 --- /dev/null +++ b/Generator.Equals/GeneratorOptions.cs @@ -0,0 +1,52 @@ +using System; +using System.Diagnostics; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Generator.Equals; + +internal record GeneratorOptions +{ + public StringComparison DefaultStringComparison { get; init; } = StringComparison.Ordinal; + public ArrayComparison ArrayCompare { get; init; } = ArrayComparison.Ordered; + + public GeneratorOptions(AnalyzerConfigOptions options) + { + // if (options.TryGetValue("build_property.DemoSourceGenerator_Counter", out var counterEnabledValue)) + // { + // DefaultStringComparison = IsFeatureEnabled(counterEnabledValue); + // } + + if (options.TryGetValue("generator_equals_comparison_string", out var stringComparison)) + { + DefaultStringComparison = Enum.TryParse(stringComparison, out var comparison) + ? comparison + : StringComparison.Ordinal; + } + + if (options.TryGetValue("generator_equals_comparison_enumerable", out var arrayComparison)) + { + ArrayCompare = Enum.TryParse(arrayComparison, out var comparison) + ? comparison + : ArrayComparison.Ordered; + } + } + + private static bool IsFeatureEnabled(string counterEnabledValue) + { + return Equals("enable", counterEnabledValue) + || Equals("enabled", counterEnabledValue) + || Equals("true", counterEnabledValue) + || Equals("1", counterEnabledValue); + + static bool Equals(string v1, string v2) + { + return StringComparer.OrdinalIgnoreCase.Equals(v1, v2); + } + } +} + +internal enum ArrayComparison +{ + Ordered, + Unordered, +} \ No newline at end of file diff --git a/Generator.Equals/Generators/ClassEqualityGenerator.cs b/Generator.Equals/Generators/ClassEqualityGenerator.cs index e39fd1a..ac12915 100644 --- a/Generator.Equals/Generators/ClassEqualityGenerator.cs +++ b/Generator.Equals/Generators/ClassEqualityGenerator.cs @@ -1,118 +1,116 @@ using System.CodeDom.Compiler; - +using Generator.Equals.Generators.Core; using Generator.Equals.Models; -namespace Generator.Equals.Generators +namespace Generator.Equals.Generators; + +internal sealed class ClassGenerator { - internal sealed class ClassEqualityGenerator : EqualityGeneratorBase + private static void BuildEquals( + EqualityTypeModel model, + IndentedTextWriter writer + ) { - private static void BuildEquals( - EqualityTypeModel model, - IndentedTextWriter writer - ) + var ignoreInheritedMembers = model.IgnoreInheritedMembers; + var symbolName = model.TypeName; + var baseTypeName = model.BaseTypeName; + var isRootClass = baseTypeName == "object"; + + writer.WriteLine("// Fields"); + LocalFieldGenerator.BuildEqualityComparerFields(model.BuildEqualityModels, writer); + + writer.WriteLines(GeneratorConstants.EqualsOperatorCodeComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine("public static bool operator ==("); + writer.WriteLine(1, $"{symbolName}? left,"); + writer.WriteLine(1, $"{symbolName}? right) =>"); + writer.WriteLine(1, $"global::Generator.Equals.DefaultEqualityComparer<{symbolName}?>.Default"); + writer.WriteLine(2, $".Equals(left, right);"); + writer.WriteLine(); + + writer.WriteLines(GeneratorConstants.NotEqualsOperatorCodeComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"public static bool operator !=({symbolName}? left, {symbolName}? right) =>"); + writer.WriteLine(1, "!(left == right);"); + writer.WriteLine(); + + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine("public override bool Equals(object? obj) =>"); + writer.WriteLine(1, $"Equals(obj as {symbolName});"); + writer.WriteLine(); + + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"bool global::System.IEquatable<{symbolName}>.Equals({symbolName}? obj) => Equals((object?) obj);"); + writer.WriteLine(); + + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"{(model.IsSealed ? "private" : "protected")} bool Equals({symbolName}? other)"); + writer.AppendOpenBracket(); + + writer.WriteLine("if (ReferenceEquals(null, other)) return false;"); + writer.WriteLine("if (ReferenceEquals(this, other)) return true;"); + writer.WriteLine(); + + if (isRootClass || ignoreInheritedMembers) { - var ignoreInheritedMembers = model.IgnoreInheritedMembers; - var symbolName = model.TypeName; - var baseTypeName = model.BaseTypeName; - var isRootClass = baseTypeName == "object"; - - writer.WriteLines(EqualsOperatorCodeComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine("public static bool operator ==("); - writer.WriteLine(1, $"{symbolName}? left,"); - writer.WriteLine(1, $"{symbolName}? right) =>"); - writer.WriteLine(1, $"global::Generator.Equals.DefaultEqualityComparer<{symbolName}?>.Default"); - writer.WriteLine(2, $".Equals(left, right);"); - writer.WriteLine(); + writer.WriteLine("return other.GetType() == this.GetType()"); + } + else + { + writer.WriteLine($"return base.Equals(other as {baseTypeName})"); + } - writer.WriteLines(NotEqualsOperatorCodeComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"public static bool operator !=({symbolName}? left, {symbolName}? right) =>"); - writer.WriteLine(1, "!(left == right);"); - writer.WriteLine(); + writer.Indent++; + EqualityMethodGenerator.BuildMembersEquality(model.BuildEqualityModels, writer); + writer.WriteLine(";"); + writer.Indent--; - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine("public override bool Equals(object? obj) =>"); - writer.WriteLine(1, $"Equals(obj as {symbolName});"); - writer.WriteLine(); + writer.AppendCloseBracket(); + } - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"bool global::System.IEquatable<{symbolName}>.Equals({symbolName}? obj) => Equals((object?) obj);"); - writer.WriteLine(); + private static void BuildGetHashCode( + EqualityTypeModel model, + IndentedTextWriter writer + ) + { + var ignoreInheritedMembers = model.IgnoreInheritedMembers; + var baseTypeName = model.BaseTypeName; - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"{(model.IsSealed ? "private" : "protected")} bool Equals({symbolName}? other)"); - writer.AppendOpenBracket(); + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine(@"public override int GetHashCode()"); + writer.AppendOpenBracket(); - writer.WriteLine("if (ReferenceEquals(null, other)) return false;"); - writer.WriteLine("if (ReferenceEquals(this, other)) return true;"); - writer.WriteLine(); + writer.WriteLine(@"var hashCode = new global::System.HashCode();"); + writer.WriteLine(); - if (isRootClass || ignoreInheritedMembers) - { - writer.WriteLine("return other.GetType() == this.GetType()"); - } - else - { - writer.WriteLine($"return base.Equals(other as {baseTypeName})"); - } + writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers + ? "hashCode.Add(this.GetType());" + : "hashCode.Add(base.GetHashCode());"); - writer.Indent++; - BuildMembersEquality(model.BuildEqualityModels, writer); - writer.WriteLine(";"); - writer.Indent--; + HashCodeMethodGenerator.BuildMembersHashCode(model.BuildEqualityModels, writer); - writer.AppendCloseBracket(); - } + writer.WriteLine(); + writer.WriteLine("return hashCode.ToHashCode();"); - private static void BuildGetHashCode( - EqualityTypeModel model, - IndentedTextWriter writer - ) - { - var ignoreInheritedMembers = model.IgnoreInheritedMembers; - var baseTypeName = model.BaseTypeName; + writer.AppendCloseBracket(); + } - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine(@"public override int GetHashCode()"); + public static string Generate(EqualityTypeModel model) => + ContainingTypesBuilder.Build(model, static (model, writer) => + { + writer.WriteLine($"partial class {model.TypeName} : global::System.IEquatable<{model.TypeName}>"); writer.AppendOpenBracket(); - writer.WriteLine(@"var hashCode = new global::System.HashCode();"); - writer.WriteLine(); - - writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers - ? "hashCode.Add(this.GetType());" - : "hashCode.Add(base.GetHashCode());"); - - BuildMembersHashCode(model.BuildEqualityModels, writer); + BuildEquals(model, writer); writer.WriteLine(); - writer.WriteLine("return hashCode.ToHashCode();"); - - writer.AppendCloseBracket(); - } - public static string Generate(EqualityTypeModel model) - { - var code = ContainingTypesBuilder.Build(model.ContainingSymbols, content: writer => - { - writer.WriteLine($"partial class {model.TypeName} : global::System.IEquatable<{model.TypeName}>"); - writer.AppendOpenBracket(); - - BuildEquals(model, writer); - - writer.WriteLine(); - - BuildGetHashCode(model, writer); + BuildGetHashCode(model, writer); - writer.AppendCloseBracket(); - }); - - return code; - } - } + writer.AppendCloseBracket(); + }); } \ No newline at end of file diff --git a/Generator.Equals/Generators/Core/EqualityMethodGenerator.cs b/Generator.Equals/Generators/Core/EqualityMethodGenerator.cs new file mode 100644 index 0000000..e77fe89 --- /dev/null +++ b/Generator.Equals/Generators/Core/EqualityMethodGenerator.cs @@ -0,0 +1,102 @@ +using Generator.Equals.Models; + +using System; +using System.CodeDom.Compiler; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; + +namespace Generator.Equals.Generators.Core; + +internal class EqualityMethodGenerator +{ + internal static void BuildMembersEquality( + ImmutableArray models, + IndentedTextWriter writer + ) + { + foreach (var model in models) + { + BuildEquality(model, writer); + } + } + + internal static void BuildEquality(EqualityMemberModel memberModel, IndentedTextWriter writer) + { + if (memberModel.Ignored) + { + return; + } + + var comparerFieldName = LocalFieldGenerator.GetFieldName(memberModel); + + switch (memberModel.EqualityType) + { + case EqualityType.IgnoreEquality: + break; + + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: false, StringComparer: not null and not "" }: + + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: false, StringComparer: null }: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.UnorderedEquality + when memberModel is { IsDictionary: true, StringComparer: null }: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.OrderedEquality + when memberModel is { StringComparer: not null and not "" }: + + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.OrderedEquality: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.ReferenceEquality: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.SetEquality: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.StringEquality: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + + case EqualityType.CustomEquality when memberModel.ComparerHasStaticInstance: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + + break; + + case EqualityType.CustomEquality when !memberModel.ComparerHasStaticInstance: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + + break; + + case EqualityType.DefaultEquality: + writer.WriteLine( + $"&& {comparerFieldName}.Equals(this.{memberModel.PropertyName}!, other.{memberModel.PropertyName}!)"); + break; + } + } +} \ No newline at end of file diff --git a/Generator.Equals/Generators/Core/HashCodeMethodGenerator.cs b/Generator.Equals/Generators/Core/HashCodeMethodGenerator.cs new file mode 100644 index 0000000..fe17ca6 --- /dev/null +++ b/Generator.Equals/Generators/Core/HashCodeMethodGenerator.cs @@ -0,0 +1,96 @@ +using Generator.Equals.Models; + +using System; +using System.CodeDom.Compiler; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; + +namespace Generator.Equals.Generators.Core; + +internal class HashCodeMethodGenerator +{ + public static void BuildMembersHashCode(ImmutableArray models, IndentedTextWriter writer) + { + foreach (var model in models) + { + HashCodeMethodGenerator.BuildHashCode(model, writer); + } + } + + internal static void BuildHashCode(EqualityMemberModel memberModel, IndentedTextWriter writer) + { + if (memberModel.Ignored) + { + return; + } + + switch (memberModel.EqualityType) + { + case EqualityType.IgnoreEquality: + break; + + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: null, IsDictionary: true }: + + BuildHashCodeAdd($"global::Generator.Equals.DictionaryEqualityComparer<{memberModel.TypeName}>.Default"); + break; + + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: null, IsDictionary: false }: + + BuildHashCodeAdd($"global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>.Default"); + break; + + case EqualityType.UnorderedEquality + when memberModel is { StringComparer: not null and not "", IsDictionary: false }: + + BuildHashCodeAdd($"new global::Generator.Equals.UnorderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer})"); + break; + + case EqualityType.OrderedEquality + when memberModel is { StringComparer: not null and not "" }: + + BuildHashCodeAdd($"new global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>(global::System.StringComparer.{memberModel.StringComparer})"); + break; + + case EqualityType.OrderedEquality: + BuildHashCodeAdd($"global::Generator.Equals.OrderedEqualityComparer<{memberModel.TypeName}>.Default"); + break; + + case EqualityType.ReferenceEquality: + BuildHashCodeAdd($"global::Generator.Equals.ReferenceEqualityComparer<{memberModel.TypeName}>.Default"); + break; + + case EqualityType.SetEquality: + BuildHashCodeAdd($"global::Generator.Equals.SetEqualityComparer<{memberModel.TypeName}>.Default"); + break; + + case EqualityType.StringEquality: + BuildHashCodeAdd($"global::System.StringComparer.{memberModel.StringComparer}"); + break; + + case EqualityType.CustomEquality when memberModel.ComparerHasStaticInstance: + BuildHashCodeAdd($"{memberModel.ComparerType}.{memberModel.ComparerMemberName}"); + break; + + case EqualityType.CustomEquality when !memberModel.ComparerHasStaticInstance: + BuildHashCodeAdd($"new {memberModel.ComparerType}()"); + break; + + case EqualityType.DefaultEquality: + BuildHashCodeAdd($"global::Generator.Equals.DefaultEqualityComparer<{memberModel.TypeName}>.Default"); + break; + } + + void BuildHashCodeAdd(string comparer) + { + writer.WriteLine("hashCode.Add("); + writer.Indent++; + writer.WriteLine($"this.{memberModel.PropertyName}!,"); + writer.WriteLine(comparer); + writer.Indent--; + writer.WriteLine(");"); + } + } +} \ No newline at end of file diff --git a/Generator.Equals/Generators/Core/LocalFieldGenerator.cs b/Generator.Equals/Generators/Core/LocalFieldGenerator.cs new file mode 100644 index 0000000..4374c54 --- /dev/null +++ b/Generator.Equals/Generators/Core/LocalFieldGenerator.cs @@ -0,0 +1,131 @@ +// using System.CodeDom.Compiler; +// using System.Collections.Immutable; +// using Generator.Equals.Models; +// +// namespace Generator.Equals.Generators.Core; +// +// internal static class LocalFieldGenerator +// { +// internal static void BuildEqualityComparerFields( +// ImmutableArray models, +// IndentedTextWriter writer +// ) +// { +// foreach (var model in models) +// { +// if (model.Ignored) continue; +// +// switch (model.EqualityType) +// { +// case EqualityType.UnorderedEquality when model.StringComparer is not null: +// line =( +// $"private static readonly global::Generator.Equals.UnorderedEqualityComparer<{model.TypeName}> _unorderedEqualityComparer_{model.TypeName}_{model.StringComparer} = " + +// $"new(global::System.StringComparer.{model.StringComparer});"); +// break; +// +// case EqualityType.OrderedEquality when model.StringComparer is not null: +// line =( +// $"private static readonly global::Generator.Equals.OrderedEqualityComparer<{model.TypeName}> _orderedEqualityComparer_{model.TypeName}_{model.StringComparer} = " + +// $"new(global::System.StringComparer.{model.StringComparer});"); +// break; +// +// // Add more cases here if you need additional comparer types (SetEquality, ReferenceEquality, etc.) +// } +// } +// } +// } + +using System.CodeDom.Compiler; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; +using System.Text.RegularExpressions; +using Generator.Equals.Models; + +namespace Generator.Equals.Generators.Core; + +internal static class LocalFieldGenerator +{ + internal static void BuildEqualityComparerFields( + ImmutableArray models, + IndentedTextWriter writer + ) + { + var hashset = new HashSet(); + + foreach (var model in models) + { + if (model.Ignored) continue; + + var fieldName = GetFieldName(model); + + var localFieldDefinition = model.EqualityType switch + { + EqualityType.UnorderedEquality when model.IsDictionary && model.StringComparer is null => + $"private static readonly global::Generator.Equals.DictionaryEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.DictionaryEqualityComparer<{model.TypeName}>.Default;", + EqualityType.UnorderedEquality when model.StringComparer is not null => + $"private static readonly global::Generator.Equals.UnorderedEqualityComparer<{model.TypeName}> {fieldName} = " + + $"new(global::System.StringComparer.{model.StringComparer});", + EqualityType.UnorderedEquality when model.StringComparer is null => + $"private static readonly global::Generator.Equals.UnorderedEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.UnorderedEqualityComparer<{model.TypeName}>.Default;", + EqualityType.OrderedEquality when model.StringComparer is not null => + $"private static readonly global::Generator.Equals.OrderedEqualityComparer<{model.TypeName}> {fieldName} = " + + $"new(global::System.StringComparer.{model.StringComparer});", + EqualityType.OrderedEquality => + $"private static readonly global::Generator.Equals.OrderedEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.OrderedEqualityComparer<{model.TypeName}>.Default;", + EqualityType.ReferenceEquality => ( + $"private static readonly global::Generator.Equals.ReferenceEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.ReferenceEqualityComparer<{model.TypeName}>.Default;"), + EqualityType.SetEquality => ( + $"private static readonly global::Generator.Equals.SetEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.SetEqualityComparer<{model.TypeName}>.Default;"), + EqualityType.StringEquality when model.StringComparer is not null => ( + $"private static readonly global::System.StringComparer {fieldName} = " + + $"global::System.StringComparer.{model.StringComparer};"), + EqualityType.CustomEquality when model.ComparerHasStaticInstance => ( + $"private static readonly {model.ComparerType} {fieldName} = " + + $"{model.ComparerType}.{model.ComparerMemberName};"), + EqualityType.CustomEquality when !model.ComparerHasStaticInstance => ( + $"private static readonly {model.ComparerType} {fieldName} = " + + $"new {model.ComparerType}();"), + EqualityType.DefaultEquality => ( + $"private static readonly global::Generator.Equals.DefaultEqualityComparer<{model.TypeName}> {fieldName} = " + + $"global::Generator.Equals.DefaultEqualityComparer<{model.TypeName}>.Default;"), + _ => string.Empty + }; + + if (!string.IsNullOrEmpty(localFieldDefinition) && hashset.Add(localFieldDefinition)) + { + writer.WriteLine(localFieldDefinition); + } + } + } + + private static readonly Regex _typeNameClean = new Regex(@"[<>:.]", RegexOptions.Compiled); + + public static string GetFieldName(EqualityMemberModel model) + { + // Replace <>:. with _ + var cleanedTypeName = _typeNameClean.Replace(model.TypeName, "_"); + + + return model.EqualityType switch + { + EqualityType.UnorderedEquality when model.IsDictionary && model.StringComparer is null => $"_dictionaryEqualityComparer_{cleanedTypeName}", + EqualityType.UnorderedEquality when model.StringComparer is not null => $"_unorderedEqualityComparer_{cleanedTypeName}_{model.StringComparer}", + EqualityType.UnorderedEquality when model.StringComparer is null => $"_unorderedEqualityComparer_{cleanedTypeName}", + EqualityType.OrderedEquality when model.StringComparer is not null => $"_orderedEqualityComparer_{cleanedTypeName}_{model.StringComparer}", + EqualityType.OrderedEquality => $"_orderedEqualityComparer_{cleanedTypeName}", + EqualityType.ReferenceEquality => $"_referenceEqualityComparer_{cleanedTypeName}", + EqualityType.SetEquality => $"_setEqualityComparer_{cleanedTypeName}", + EqualityType.StringEquality when model.StringComparer is not null => $"_stringComparer_{model.StringComparer}", + EqualityType.CustomEquality when model.ComparerHasStaticInstance => $"_customComparer_{model.ComparerType}_{model.PropertyName}", + EqualityType.CustomEquality when !model.ComparerHasStaticInstance => $"_customComparer_{model.ComparerType}_{model.PropertyName}", + EqualityType.DefaultEquality => $"_defaultEqualityComparer_{cleanedTypeName}", + _ => string.Empty + }; + } +} \ No newline at end of file diff --git a/Generator.Equals/Generators/RecordEqualityGenerator.cs b/Generator.Equals/Generators/RecordEqualityGenerator.cs index 0e532df..e5af065 100644 --- a/Generator.Equals/Generators/RecordEqualityGenerator.cs +++ b/Generator.Equals/Generators/RecordEqualityGenerator.cs @@ -1,83 +1,78 @@ using System.CodeDom.Compiler; - +using Generator.Equals.Generators.Core; using Generator.Equals.Models; -namespace Generator.Equals.Generators +namespace Generator.Equals.Generators; + +internal static class RecordGenerator { - internal class RecordEqualityGenerator : EqualityGeneratorBase + static void BuildEquals( + EqualityTypeModel model, + IndentedTextWriter writer + ) { - static void BuildEquals( - EqualityTypeModel model, - IndentedTextWriter writer - ) - { - bool ignoreInheritedMembers = model.IgnoreInheritedMembers; - var symbolName = model.TypeName; - var baseTypeName = model.BaseTypeName; + bool ignoreInheritedMembers = model.IgnoreInheritedMembers; + var symbolName = model.TypeName; + var baseTypeName = model.BaseTypeName; - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine(model.IsSealed - ? $"public bool Equals({symbolName}? other)" - : $"public virtual bool Equals({symbolName}? other)"); - writer.AppendOpenBracket(); + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine(model.IsSealed + ? $"public bool Equals({symbolName}? other)" + : $"public virtual bool Equals({symbolName}? other)"); + writer.AppendOpenBracket(); - writer.WriteLine("return"); + writer.WriteLine("return"); - writer.Indent++; + writer.Indent++; - writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers - ? "!ReferenceEquals(other, null) && EqualityContract == other.EqualityContract" - : $"base.Equals(({baseTypeName}?)other)"); + writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers + ? "!ReferenceEquals(other, null) && EqualityContract == other.EqualityContract" + : $"base.Equals(({baseTypeName}?)other)"); - BuildMembersEquality(model.BuildEqualityModels, writer); + EqualityMethodGenerator.BuildMembersEquality(model.BuildEqualityModels, writer); - writer.WriteLine(";"); - writer.Indent--; + writer.WriteLine(";"); + writer.Indent--; - writer.AppendCloseBracket(); - } + writer.AppendCloseBracket(); + } - static void BuildGetHashCode( - EqualityTypeModel model, - IndentedTextWriter writer - ) - { - bool ignoreInheritedMembers = model.IgnoreInheritedMembers; - var baseTypeName = model.BaseTypeName; + static void BuildGetHashCode( + EqualityTypeModel model, + IndentedTextWriter writer + ) + { + bool ignoreInheritedMembers = model.IgnoreInheritedMembers; + var baseTypeName = model.BaseTypeName; - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine(@"public override int GetHashCode()"); - writer.AppendOpenBracket(); + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine(@"public override int GetHashCode()"); + writer.AppendOpenBracket(); - writer.WriteLine(@"var hashCode = new global::System.HashCode();"); - writer.WriteLine(); + writer.WriteLine(@"var hashCode = new global::System.HashCode();"); + writer.WriteLine(); - writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers - ? "hashCode.Add(this.EqualityContract);" - : "hashCode.Add(base.GetHashCode());"); + writer.WriteLine(baseTypeName == "object" || ignoreInheritedMembers + ? "hashCode.Add(this.EqualityContract);" + : "hashCode.Add(base.GetHashCode());"); - BuildMembersHashCode(model.BuildEqualityModels, writer); + HashCodeMethodGenerator.BuildMembersHashCode(model.BuildEqualityModels, writer); - writer.WriteLine(); - writer.WriteLine("return hashCode.ToHashCode();"); + writer.WriteLine(); + writer.WriteLine("return hashCode.ToHashCode();"); - writer.AppendCloseBracket(); - } + writer.AppendCloseBracket(); + } - public static string Generate(EqualityTypeModel model) + public static string Generate(EqualityTypeModel model) => + ContainingTypesBuilder.Build(model, static (model, writer) => { - var code = ContainingTypesBuilder.Build(model.ContainingSymbols, content: writer => - { - BuildEquals(model, writer); + BuildEquals(model, writer); - writer.WriteLine(); - - BuildGetHashCode(model, writer); - }); + writer.WriteLine(); - return code; - } - } + BuildGetHashCode(model, writer); + }); } \ No newline at end of file diff --git a/Generator.Equals/Generators/RecordStructEqualityGenerator.cs b/Generator.Equals/Generators/RecordStructEqualityGenerator.cs index 032dc54..2d680f7 100644 --- a/Generator.Equals/Generators/RecordStructEqualityGenerator.cs +++ b/Generator.Equals/Generators/RecordStructEqualityGenerator.cs @@ -1,63 +1,58 @@ using System.CodeDom.Compiler; - +using Generator.Equals.Generators.Core; using Generator.Equals.Models; -namespace Generator.Equals.Generators +namespace Generator.Equals.Generators; + +internal sealed class RecordStructGenerator { - internal sealed class RecordStructEqualityGenerator : EqualityGeneratorBase + private static void BuildEquals( + EqualityTypeModel model, + IndentedTextWriter writer + ) { - private static void BuildEquals( - EqualityTypeModel model, - IndentedTextWriter writer) - { - var symbolName = model.TypeName; + var symbolName = model.TypeName; - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"public bool Equals({symbolName} other)"); - writer.AppendOpenBracket(); + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"public bool Equals({symbolName} other)"); + writer.AppendOpenBracket(); - writer.WriteLine("return true"); + writer.WriteLine("return true"); - writer.Indent++; - BuildMembersEquality(model.BuildEqualityModels, writer); - writer.WriteLine(";"); - writer.Indent--; + writer.Indent++; + EqualityMethodGenerator.BuildMembersEquality(model.BuildEqualityModels, writer); + writer.WriteLine(";"); + writer.Indent--; - writer.AppendCloseBracket(); - } + writer.AppendCloseBracket(); + } - private static void BuildGetHashCode( - EqualityTypeModel model, - IndentedTextWriter writer) - { - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine(@"public override int GetHashCode()"); - writer.AppendOpenBracket(); + private static void BuildGetHashCode( + EqualityTypeModel model, + IndentedTextWriter writer + ) + { + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine(@"public override int GetHashCode()"); + writer.AppendOpenBracket(); - writer.WriteLine(@"var hashCode = new global::System.HashCode();"); - writer.WriteLine(); + writer.WriteLine(@"var hashCode = new global::System.HashCode();"); + writer.WriteLine(); - BuildMembersHashCode(model.BuildEqualityModels, writer); + HashCodeMethodGenerator.BuildMembersHashCode(model.BuildEqualityModels, writer); - writer.WriteLine(); - writer.WriteLine("return hashCode.ToHashCode();"); - writer.AppendCloseBracket(); - } + writer.WriteLine(); + writer.WriteLine("return hashCode.ToHashCode();"); + writer.AppendCloseBracket(); + } - public static string Generate(EqualityTypeModel model) + public static string Generate(EqualityTypeModel model) + => ContainingTypesBuilder.Build(model, static (model, writer) => { - var code = ContainingTypesBuilder.Build( - model.ContainingSymbols, - content: writer => - { - BuildEquals(model, writer); - writer.WriteLine(); - BuildGetHashCode(model, writer); - }); - - return code; - } - } + BuildEquals(model, writer); + writer.WriteLine(); + BuildGetHashCode(model, writer); + }); } \ No newline at end of file diff --git a/Generator.Equals/Generators/StructEqualityGenerator.cs b/Generator.Equals/Generators/StructEqualityGenerator.cs index cd656e1..7928cf4 100644 --- a/Generator.Equals/Generators/StructEqualityGenerator.cs +++ b/Generator.Equals/Generators/StructEqualityGenerator.cs @@ -1,89 +1,81 @@ using System.CodeDom.Compiler; - +using System.Text; +using Generator.Equals.Generators.Core; using Generator.Equals.Models; -namespace Generator.Equals.Generators +namespace Generator.Equals.Generators; + +internal sealed class StructGenerator { - internal sealed class StructEqualityGenerator : EqualityGeneratorBase + private static void BuildEquals(EqualityTypeModel model, IndentedTextWriter writer) { - private static void BuildEquals(EqualityTypeModel model, IndentedTextWriter writer) - { - var symbolName = model.TypeName; - - writer.WriteLines(EqualsOperatorCodeComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine("public static bool operator ==("); - writer.WriteLine(1, $"{symbolName} left,"); - writer.WriteLine(1, $"{symbolName} right) =>"); - writer.WriteLine(1, $"global::Generator.Equals.DefaultEqualityComparer<{symbolName}>.Default"); - writer.WriteLine(2, $".Equals(left, right);"); - writer.WriteLine(); - - writer.WriteLines(NotEqualsOperatorCodeComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"public static bool operator !=({symbolName} left, {symbolName} right) =>"); - writer.WriteLine(1, "!(left == right);"); - writer.WriteLine(); - - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine("public override bool Equals(object? obj) =>"); - writer.WriteLine(1, $"obj is {symbolName} o && Equals(o);"); - writer.WriteLine(); + var symbolName = model.TypeName; + + writer.WriteLines(GeneratorConstants.EqualsOperatorCodeComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine("public static bool operator ==("); + writer.WriteLine(1, $"{symbolName} left,"); + writer.WriteLine(1, $"{symbolName} right) =>"); + writer.WriteLine(1, $"global::Generator.Equals.DefaultEqualityComparer<{symbolName}>.Default"); + writer.WriteLine(2, $".Equals(left, right);"); + writer.WriteLine(); + + writer.WriteLines(GeneratorConstants.NotEqualsOperatorCodeComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"public static bool operator !=({symbolName} left, {symbolName} right) =>"); + writer.WriteLine(1, "!(left == right);"); + writer.WriteLine(); + + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine("public override bool Equals(object? obj) =>"); + writer.WriteLine(1, $"obj is {symbolName} o && Equals(o);"); + writer.WriteLine(); + + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine($"public bool Equals({symbolName} other)"); + writer.AppendOpenBracket(); + + writer.WriteLine("return true"); + + writer.Indent++; + EqualityMethodGenerator.BuildMembersEquality(model.BuildEqualityModels, writer); + + writer.WriteLine(";"); + writer.Indent--; + + writer.AppendCloseBracket(); + } - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine($"public bool Equals({symbolName} other)"); - writer.AppendOpenBracket(); + private static void BuildGetHashCode(EqualityTypeModel model, IndentedTextWriter writer) + { + writer.WriteLine(GeneratorConstants.InheritDocComment); + writer.WriteLine(GeneratorConstants.GeneratedCodeAttributeDeclaration); + writer.WriteLine(@"public override int GetHashCode()"); + writer.AppendOpenBracket(); - writer.WriteLine("return true"); + writer.WriteLine(@"var hashCode = new global::System.HashCode();"); + writer.WriteLine(); - writer.Indent++; - BuildMembersEquality(model.BuildEqualityModels, writer); + HashCodeMethodGenerator.BuildMembersHashCode(model.BuildEqualityModels, writer); - writer.WriteLine(";"); - writer.Indent--; + writer.WriteLine(); + writer.WriteLine("return hashCode.ToHashCode();"); - writer.AppendCloseBracket(); - } + writer.AppendCloseBracket(); + } - private static void BuildGetHashCode(EqualityTypeModel model, IndentedTextWriter writer) + public static string Generate(EqualityTypeModel model) + => ContainingTypesBuilder.Build(model, static (model, writer) => { - writer.WriteLine(InheritDocComment); - writer.WriteLine(GeneratedCodeAttributeDeclaration); - writer.WriteLine(@"public override int GetHashCode()"); + writer.WriteLine($"partial struct {model.TypeName} : global::System.IEquatable<{model.TypeName}>"); writer.AppendOpenBracket(); - writer.WriteLine(@"var hashCode = new global::System.HashCode();"); - writer.WriteLine(); - - BuildMembersHashCode(model.BuildEqualityModels, writer); + BuildEquals(model, writer); writer.WriteLine(); - writer.WriteLine("return hashCode.ToHashCode();"); - - writer.AppendCloseBracket(); - } - - public static string Generate(EqualityTypeModel model) - { - // Generate the code using the custom model instead of Roslyn types - var code = ContainingTypesBuilder.Build(model.ContainingSymbols, content: writer => - { - writer.WriteLine($"partial struct {model.TypeName} : global::System.IEquatable<{model.TypeName}>"); - writer.AppendOpenBracket(); - // BuildEquals and BuildGetHashCode are adjusted to accept the custom model instead of Roslyn symbols - BuildEquals(model, writer); - - writer.WriteLine(); - - BuildGetHashCode(model, writer); - - writer.AppendCloseBracket(); - }); - - return code; - } - } + BuildGetHashCode(model, writer); + }); } \ No newline at end of file diff --git a/Generator.Equals/Models/EqualityMemberModel.cs b/Generator.Equals/Models/EqualityMemberModel.cs index ee24d23..baf3b00 100644 --- a/Generator.Equals/Models/EqualityMemberModel.cs +++ b/Generator.Equals/Models/EqualityMemberModel.cs @@ -12,4 +12,12 @@ internal sealed record EqualityMemberModel public bool Ignored { get; init; } public bool ComparerHasStaticInstance { get; init; } + + + /// + /// OrderedEquality vs UnorderedEquality. + /// Specifies that the equality type is not explicitly set and can be overridden by global settings. + /// + public bool IsDefaultEqualityType { get; init; } + public bool IsDefaultStringComparer { get; init; } } \ No newline at end of file diff --git a/Generator.Equals/Models/EqualityTypeModel.cs b/Generator.Equals/Models/EqualityTypeModel.cs index 41eb694..b21bc8d 100644 --- a/Generator.Equals/Models/EqualityTypeModel.cs +++ b/Generator.Equals/Models/EqualityTypeModel.cs @@ -1,4 +1,5 @@ -using Microsoft.CodeAnalysis.CSharp; +using System.Linq; +using Microsoft.CodeAnalysis.CSharp; namespace Generator.Equals.Models; @@ -15,4 +16,52 @@ internal sealed record EqualityTypeModel public required bool IgnoreInheritedMembers { get; init; } public required EquatableImmutableArray BuildEqualityModels { get; init; } public required string Fullname { get; init; } + + // public GeneratorOptions? GlobalOptions { get; init; } + + public EqualityTypeModel WithGeneratorOptions(GeneratorOptions options) + { + EqualityMemberModel[] newModels = null; + + for (var index = 0; index < BuildEqualityModels.Items.Length; index++) + { + var equalityModel = BuildEqualityModels.Items[index]; + + if (equalityModel.IsDefaultEqualityType + && equalityModel.EqualityType == EqualityType.OrderedEquality) + { + if (options.ArrayCompare == ArrayComparison.Unordered) + { + newModels ??= BuildEqualityModels.Items.ToArray(); + + equalityModel = equalityModel with + { + EqualityType = EqualityType.UnorderedEquality + }; + + newModels[index] = equalityModel; + } + } + + if (equalityModel.IsDefaultStringComparer) + { + newModels ??= BuildEqualityModels.Items.ToArray(); + + equalityModel = equalityModel with + { + StringComparer = options.DefaultStringComparison.ToString() + }; + + newModels[index] = equalityModel; + } + } + + return this with + { + BuildEqualityModels = newModels is null + ? BuildEqualityModels + : new EquatableImmutableArray(newModels) + }; + } + } \ No newline at end of file diff --git a/Generator.Equals/Models/EquatableImmutableArray.cs b/Generator.Equals/Models/EquatableImmutableArray.cs index a766298..109024c 100644 --- a/Generator.Equals/Models/EquatableImmutableArray.cs +++ b/Generator.Equals/Models/EquatableImmutableArray.cs @@ -12,6 +12,11 @@ public EquatableImmutableArray() : this(ImmutableArray.Empty) { } + public EquatableImmutableArray(T[] items) : this(items.ToImmutableArray()) + { + } + + public bool Equals(EquatableImmutableArray other) { return Items.SequenceEqual(other.Items); diff --git a/README.md b/README.md index 432da3a..e1ea881 100644 --- a/README.md +++ b/README.md @@ -222,3 +222,16 @@ partial class Doctor : Person public string Specialization { get; set; } } ``` + +### Global Options + +You can configure global options by creating an ``.globalconfig`` file which looks like this: + +```ini +generator_equals_comparison_string = OrdinalIgnoreCase +generator_equals_comparison_enumerable = Unordered +``` + +generator_equals_comparison_string: Possible values are `CurrentCulture`, `CurrentCultureIgnoreCase`, `InvariantCulture`, `InvariantCultureIgnoreCase`, `Ordinal`, `OrdinalIgnoreCase`. Default is `Ordinal`. + +generator_equals_comparison_enumerable: Possible values are `Ordered`, `Unordered`. Default is `Ordered`. \ No newline at end of file