diff --git a/src/xunit.analyzers.fixes/X1000/ClassDataAttributeMustPointAtValidClassFixer.cs b/src/xunit.analyzers.fixes/X1000/ClassDataAttributeMustPointAtValidClassFixer.cs index 0788dec8..91e990ef 100644 --- a/src/xunit.analyzers.fixes/X1000/ClassDataAttributeMustPointAtValidClassFixer.cs +++ b/src/xunit.analyzers.fixes/X1000/ClassDataAttributeMustPointAtValidClassFixer.cs @@ -29,7 +29,21 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) if (semanticModel is null) return; - var typeOfExpression = root.FindNode(context.Span).FirstAncestorOrSelf(); + // The context wraps "ClassData(typeof(T))", which is an attribute syntax + var attribute = root.FindNode(context.Span); + + // Dig in to find the attribute arguments + var attributeArgumentList = attribute.ChildNodes().OfType().FirstOrDefault(); + if (attributeArgumentList is null) + return; + + // Dig into that to get the attribute argument + var attributeArgument = attributeArgumentList.ChildNodes().OfType().FirstOrDefault(); + if (attributeArgument is null) + return; + + // And finally, dig in to get the typeof expression + var typeOfExpression = attributeArgument.ChildNodes().OfType().FirstOrDefault(); if (typeOfExpression is null) return; diff --git a/src/xunit.analyzers.tests/Analyzers/X1000/ClassDataAttributeMustPointAtValidClassTests.cs b/src/xunit.analyzers.tests/Analyzers/X1000/ClassDataAttributeMustPointAtValidClassTests.cs index d8610144..277ffcb3 100644 --- a/src/xunit.analyzers.tests/Analyzers/X1000/ClassDataAttributeMustPointAtValidClassTests.cs +++ b/src/xunit.analyzers.tests/Analyzers/X1000/ClassDataAttributeMustPointAtValidClassTests.cs @@ -1,23 +1,28 @@ using System.Threading.Tasks; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Xunit; using Verify = CSharpVerifier; public class ClassDataAttributeMustPointAtValidClassTests { - static readonly string TestMethodSource = @" + static string TestMethodSource(string testMethodParams = "(int n)") => @$" +#nullable enable + using Xunit; -public class TestClass { +public class TestClass {{ [Theory] [ClassData(typeof(DataClass))] - public void TestMethod() { } -}"; + public void TestMethod{testMethodParams} {{ }} +}}"; - [Fact] - public async Task SuccessCaseV2() + public class SuccessCases { - var dataClassSource = @" + [Fact] + public async Task SuccessCaseV2() + { + var dataClassSource = @" using System.Collections; using System.Collections.Generic; @@ -26,13 +31,13 @@ class DataClass: IEnumerable { IEnumerator IEnumerable.GetEnumerator() => null; }"; - await Verify.VerifyAnalyzerV2(LanguageVersion.CSharp7_1, [TestMethodSource, dataClassSource]); - } + await Verify.VerifyAnalyzerV2(LanguageVersion.CSharp9, [TestMethodSource(), dataClassSource]); + } - public static TheoryData SuccessCasesV3Data = new() - { - // IEnumerable> - @" + public static TheoryData SuccessCasesV3Data = new() + { + // IEnumerable> maps to int + { "(int n)", @" using System.Collections; using System.Collections.Generic; using Xunit; @@ -40,36 +45,146 @@ class DataClass: IEnumerable { class DataClass: IEnumerable> { public IEnumerator> GetEnumerator() => null; IEnumerator IEnumerable.GetEnumerator() => null; -}", - // IAsyncEnumerable> - @" +}" }, + // IAsyncEnumerable> maps to int + { "(int n)", @" using System.Collections.Generic; using System.Threading; using Xunit; public class DataClass : IAsyncEnumerable> { public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; -}", - // IAsyncEnumerable - @" +}" }, + // IAsyncEnumerable maps to int + { "(int n)", @" using System.Collections.Generic; using System.Threading; using Xunit; +public class DerivedTheoryDataRow : TheoryDataRow { + public DerivedTheoryDataRow(int t) : base(t) { } +} + public class DataClass : IAsyncEnumerable { public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps to int + { "(int n)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DerivedTheoryDataRow : TheoryDataRow { + public DerivedTheoryDataRow(T t) : base(t) { } } -public class DerivedTheoryDataRow : TheoryDataRow { - public DerivedTheoryDataRow(int t) : base(t) { } -}", - }; +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps to int + { "(int n)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; - [Theory] - [MemberData(nameof(SuccessCasesV3Data))] - public async Task SuccessCasesV3(string dataClassSource) - { - await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp7_1, [TestMethodSource, dataClassSource]); +public class DerivedTheoryDataRow : TheoryDataRow { + public DerivedTheoryDataRow(T t, U u) : base(t) { } +} + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> with optional parameter + { "(int n, int p = 0)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> with params array (no values) + { "(int n, params int[] a)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> with params array (one value) + { "(int n, params string[] a)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> with params array (multiple values) + { "(int n, params string[] a)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> with params array (array for params array) + { "(int n, params string[] a)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps to generic T + { "(T t)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps to generic T? + { "(T? t)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps unnamed tuple to named tuple + { "((int x, int y) point)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + // IAsyncEnumerable> maps tuples with mismatched names + { "((int x, int y) point)", @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}" }, + }; + + [Theory] + [MemberData(nameof(SuccessCasesV3Data))] + public async Task SuccessCasesV3( + string methodParams, + string dataClassSource) + { + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource(methodParams), dataClassSource]); + } } public class X1007_ClassDataAttributeMustPointAtValidClass @@ -133,16 +248,16 @@ public async Task FailureCases(string dataClassSource) var expectedV2 = Verify .Diagnostic("xUnit1007") - .WithSpan(6, 23, 6, 32) + .WithSpan(8, 6, 8, 34) .WithArguments("DataClass", "IEnumerable"); var expectedV3 = Verify .Diagnostic("xUnit1007") - .WithSpan(6, 23, 6, 32) + .WithSpan(8, 6, 8, 34) .WithArguments("DataClass", "IEnumerable, IAsyncEnumerable, IEnumerable, or IAsyncEnumerable"); - await Verify.VerifyAnalyzerV2([TestMethodSource, dataClassSource], expectedV2); - await Verify.VerifyAnalyzerV3([TestMethodSource, dataClassSource], expectedV3); + await Verify.VerifyAnalyzerV2(LanguageVersion.CSharp9, [TestMethodSource(), dataClassSource], expectedV2); + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource(), dataClassSource], expectedV3); } [Fact] @@ -158,10 +273,166 @@ public class DataClass : IAsyncEnumerable { var expected = Verify .Diagnostic("xUnit1007") - .WithSpan(6, 23, 6, 32) + .WithSpan(8, 6, 8, 34) .WithArguments("DataClass", "IEnumerable"); - await Verify.VerifyAnalyzerV2(LanguageVersion.CSharp7_1, [TestMethodSource, dataClassSource], expected); + await Verify.VerifyAnalyzerV2(LanguageVersion.CSharp9, [TestMethodSource(), dataClassSource], expected); + } + } + + public class X1037_TheoryDataTypeArgumentsMustMatchTestMethodParameters_TooFewTypeParameters + { + [Theory] + [InlineData("TheoryDataRow")] + [InlineData("DerivedTheoryDataRow")] + public async Task NotEnoughTypeParameters_Triggers(string theoryDataRowType) + { + var source = $@" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DerivedTheoryDataRow : TheoryDataRow {{ + public DerivedTheoryDataRow(T t, U u) : base(t) {{ }} +}} + +public class DataClass : IAsyncEnumerable<{theoryDataRowType}> {{ + public IAsyncEnumerator<{theoryDataRowType}> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}}"; + + var expected = + Verify + .Diagnostic("xUnit1037") + .WithSpan(8, 6, 8, 34) + .WithSeverity(DiagnosticSeverity.Error) + .WithArguments("Xunit.TheoryDataRow"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(int n, string f)"), source], expected); + } + } + + public class X1038_TheoryDataTypeArgumentsMustMatchTestMethodParameters_ExtraTypeParameters + { + [Theory] + [InlineData("TheoryDataRow")] + [InlineData("DerivedTheoryDataRow")] + public async Task TooManyTypeParameters_Triggers(string theoryDataRowType) + { + var source = $@" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DerivedTheoryDataRow : TheoryDataRow {{ + public DerivedTheoryDataRow(T t) : base(t, 21.12) {{ }} +}} + +public class DataClass : IAsyncEnumerable<{theoryDataRowType}> {{ + public IAsyncEnumerator<{theoryDataRowType}> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}}"; + + var expected = + Verify + .Diagnostic("xUnit1038") + .WithSpan(8, 6, 8, 34) + .WithSeverity(DiagnosticSeverity.Error) + .WithArguments("Xunit.TheoryDataRow"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(int n)"), source], expected); + } + + [Fact] + public async Task ExtraDataPastParamsArray_Triggers() + { + var source = $@" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> {{ + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}}"; + + var expected = + Verify + .Diagnostic("xUnit1038") + .WithSpan(8, 6, 8, 34) + .WithSeverity(DiagnosticSeverity.Error) + .WithArguments("Xunit.TheoryDataRow"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(int n, params double[] d)"), source], expected); + } + } + + public class X1039_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleTypes + { + [Fact] + public async Task WithIncompatibleType_Triggers() + { + var source = @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}"; + + var expected = + Verify + .Diagnostic("xUnit1039") + .WithSpan(9, 35, 9, 41) + .WithSeverity(DiagnosticSeverity.Error) + .WithArguments("string", "DataClass", "d"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(int n, double d)"), source], expected); + } + + [Fact] + public async Task WithExtraValueNotCompatibleWithParamsArray_Triggers() + { + var source = @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}"; + + var expected = + Verify + .Diagnostic("xUnit1039") + .WithSpan(9, 42, 9, 50) + .WithSeverity(DiagnosticSeverity.Error) + .WithArguments("int", "DataClass", "s"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(int n, params string[] s)"), source], expected); + } + } + + public class X1040_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleNullability + { + [Fact] + public async Task ValidTheoryDataRowMemberWithMismatchedNullability_Triggers() + { + var source = @" +using System.Collections.Generic; +using System.Threading; +using Xunit; + +public class DataClass : IAsyncEnumerable> { + public IAsyncEnumerator> GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; +}"; + + var expected = + Verify + .Diagnostic("xUnit1040") + .WithSpan(9, 28, 9, 34) + .WithSeverity(DiagnosticSeverity.Warning) + .WithArguments("string?", "DataClass", "s"); + + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource("(string s)"), source], expected); } } @@ -235,9 +506,9 @@ public async Task FailureCases(string dataClassSource) var expected = Verify .Diagnostic("xUnit1050") - .WithSpan(6, 23, 6, 32); + .WithSpan(8, 6, 8, 34); - await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp7_1, [TestMethodSource, dataClassSource], expected); + await Verify.VerifyAnalyzerV3(LanguageVersion.CSharp9, [TestMethodSource(), dataClassSource], expected); } } } diff --git a/src/xunit.analyzers.tests/Fixes/X1000/ClassDataAttributeMustPointAtValidClassFixerTests.cs b/src/xunit.analyzers.tests/Fixes/X1000/ClassDataAttributeMustPointAtValidClassFixerTests.cs index 966ba1b8..f839dd6d 100644 --- a/src/xunit.analyzers.tests/Fixes/X1000/ClassDataAttributeMustPointAtValidClassFixerTests.cs +++ b/src/xunit.analyzers.tests/Fixes/X1000/ClassDataAttributeMustPointAtValidClassFixerTests.cs @@ -17,7 +17,7 @@ public class TestData { public class TestClass { [Theory] - [ClassData(typeof({|xUnit1007:TestData|}))] + [{|xUnit1007:ClassData(typeof(TestData))|}] public void TestMethod(int _) { } }"; var afterV2 = @" @@ -33,7 +33,7 @@ public class TestClass { [ClassData(typeof(TestData))] public void TestMethod(int _) { } }"; - var afterV3 = afterV2.Replace("typeof(TestData)", "typeof({|xUnit1050:TestData|})"); + var afterV3 = afterV2.Replace("ClassData(typeof(TestData))", "{|xUnit1050:ClassData(typeof(TestData))|}"); await Verify.VerifyCodeFixV2(before, afterV2, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); await Verify.VerifyCodeFixV3(before, afterV3, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); @@ -56,7 +56,7 @@ public class TestData : IEnumerable { public class TestClass { [Theory] - [ClassData(typeof({|xUnit1007:TestData|}))] + [{|xUnit1007:ClassData(typeof(TestData))|}] public void TestMethod(int _) { } }"; var afterV2 = @" @@ -76,7 +76,7 @@ public class TestClass { [ClassData(typeof(TestData))] public void TestMethod(int _) { } }"; - var afterV3 = afterV2.Replace("typeof(TestData)", "typeof({|xUnit1050:TestData|})"); + var afterV3 = afterV2.Replace("ClassData(typeof(TestData))", "{|xUnit1050:ClassData(typeof(TestData))|}"); await Verify.VerifyCodeFixV2(before, afterV2, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); await Verify.VerifyCodeFixV3(before, afterV3, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); @@ -99,7 +99,7 @@ public class TestData : IEnumerable { public class TestClass { [Theory] - [ClassData(typeof({|xUnit1007:TestData|}))] + [{|xUnit1007:ClassData(typeof(TestData))|}] public void TestMethod(int _) { } }"; var afterV2 = @" @@ -123,7 +123,7 @@ public class TestClass { [ClassData(typeof(TestData))] public void TestMethod(int _) { } }"; - var afterV3 = afterV2.Replace("typeof(TestData)", "typeof({|xUnit1050:TestData|})"); + var afterV3 = afterV2.Replace("ClassData(typeof(TestData))", "{|xUnit1050:ClassData(typeof(TestData))|}"); await Verify.VerifyCodeFixV2(before, afterV2, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); await Verify.VerifyCodeFixV3(before, afterV3, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); @@ -144,7 +144,7 @@ public abstract class TestData : IEnumerable { public class TestClass { [Theory] - [ClassData(typeof({|xUnit1007:TestData|}))] + [{|xUnit1007:ClassData(typeof(TestData))|}] public void TestMethod(int _) { } }"; var afterV2 = @" @@ -162,7 +162,7 @@ public class TestClass { [ClassData(typeof(TestData))] public void TestMethod(int _) { } }"; - var afterV3 = afterV2.Replace("typeof(TestData)", "typeof({|xUnit1050:TestData|})"); + var afterV3 = afterV2.Replace("ClassData(typeof(TestData))", "{|xUnit1050:ClassData(typeof(TestData))|}"); await Verify.VerifyCodeFixV2(before, afterV2, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); await Verify.VerifyCodeFixV3(before, afterV3, ClassDataAttributeMustPointAtValidClassFixer.Key_FixDataClass); diff --git a/src/xunit.analyzers/X1000/ClassDataAttributeMustPointAtValidClass.cs b/src/xunit.analyzers/X1000/ClassDataAttributeMustPointAtValidClass.cs index 402b1f3f..5a9ec694 100644 --- a/src/xunit.analyzers/X1000/ClassDataAttributeMustPointAtValidClass.cs +++ b/src/xunit.analyzers/X1000/ClassDataAttributeMustPointAtValidClass.cs @@ -1,4 +1,7 @@ +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Reflection; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -15,6 +18,10 @@ public class ClassDataAttributeMustPointAtValidClass : XunitDiagnosticAnalyzer public ClassDataAttributeMustPointAtValidClass() : base( Descriptors.X1007_ClassDataAttributeMustPointAtValidClass, + Descriptors.X1037_TheoryDataTypeArgumentsMustMatchTestMethodParameters_TooFewTypeParameters, + Descriptors.X1038_TheoryDataTypeArgumentsMustMatchTestMethodParameters_ExtraTypeParameters, + Descriptors.X1039_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleTypes, + Descriptors.X1040_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleNullability, Descriptors.X1050_ClassDataTheoryDataRowIsRecommendedForStronglyTypedAnalysis ) { } @@ -31,71 +38,277 @@ public override void AnalyzeCompilation( var iEnumerableOfTheoryDataRow = TypeSymbolFactory.IEnumerableOfITheoryDataRow(compilation); var iAsyncEnumerableOfObjectArray = TypeSymbolFactory.IAsyncEnumerableOfObjectArray(compilation); var iAsyncEnumerableOfTheoryDataRow = TypeSymbolFactory.IAsyncEnumerableOfITheoryDataRow(compilation); - var theoryDataRows = - TypeSymbolFactory - .TheoryDataRow_ByGenericArgumentCount(compilation) - .Where(kvp => kvp.Key != 0) - .Select(kvp => kvp.Value.ConstructUnboundGenericType()) - .ToArray(); + var theoryDataRowTypes = TypeSymbolFactory.TheoryDataRow_ByGenericArgumentCount(compilation); context.RegisterSyntaxNodeAction(context => { - if (context.Node is not AttributeSyntax attribute) - return; - if (attribute.ArgumentList?.Arguments.FirstOrDefault()?.Expression is not TypeOfExpressionSyntax argumentExpression) + if (context.Node is not MethodDeclarationSyntax testMethod) return; + var attributeLists = testMethod.AttributeLists; var semanticModel = context.SemanticModel; - if (!SymbolEqualityComparer.Default.Equals(semanticModel.GetTypeInfo(attribute).Type, xunitContext.Core.ClassDataAttributeType)) - return; - - if (semanticModel.GetTypeInfo(argumentExpression.Type).Type is not INamedTypeSymbol classType) - return; - if (classType.Kind == SymbolKind.ErrorType) - return; - var missingInterface = !iEnumerableOfObjectArray.IsAssignableFrom(classType); - if (xunitContext.HasV3References) + foreach (var attributeSyntax in attributeLists.WhereNotNull().SelectMany(attList => attList.Attributes)) { - if (missingInterface && iEnumerableOfTheoryDataRow is not null) - missingInterface = !iEnumerableOfTheoryDataRow.IsAssignableFrom(classType); - if (missingInterface && iAsyncEnumerableOfObjectArray is not null) - missingInterface = !iAsyncEnumerableOfObjectArray.IsAssignableFrom(classType); - if (missingInterface && iAsyncEnumerableOfTheoryDataRow is not null) - missingInterface = !iAsyncEnumerableOfTheoryDataRow.IsAssignableFrom(classType); - } + context.CancellationToken.ThrowIfCancellationRequested(); - var isAbstract = classType.IsAbstract; - var noValidConstructor = !classType.InstanceConstructors.Any(c => c.Parameters.IsEmpty && c.DeclaredAccessibility == Accessibility.Public); + // Only work against ClassDataAttribute + if (!SymbolEqualityComparer.Default.Equals(semanticModel.GetTypeInfo(attributeSyntax).Type, xunitContext.Core.ClassDataAttributeType)) + continue; - if (missingInterface || isAbstract || noValidConstructor) - { - context.ReportDiagnostic( - Diagnostic.Create( - Descriptors.X1007_ClassDataAttributeMustPointAtValidClass, - argumentExpression.Type.GetLocation(), - classType.Name, - xunitContext.HasV3References ? typesV3 : typesV2 - ) - ); - return; + // Need the referenced type to do anything + if (attributeSyntax.ArgumentList is null) + continue; + if (attributeSyntax.ArgumentList.Arguments[0].Expression is not TypeOfExpressionSyntax typeOfExpression) + continue; + if (semanticModel.GetTypeInfo(typeOfExpression.Type).Type is not INamedTypeSymbol classType) + continue; + if (classType.Kind == SymbolKind.ErrorType) + continue; + + // Make sure the class implements a compatible interface + var isValidDeclaration = VerifyDataSourceDeclaration(context, compilation, xunitContext, classType, attributeSyntax); + + // Everything from here is based on ensuring I(Async)Enumerable>, which is + // only available in v3. + if (!xunitContext.HasV3References) + continue; + + var rowType = classType.UnwrapEnumerable(compilation); + if (rowType is null) + continue; + + if (IsGenericTheoryDataRowType(rowType, theoryDataRowTypes, out var theoryDataReturnType)) + VerifyGenericArgumentTypes(semanticModel, context, testMethod, theoryDataRowTypes[0], theoryDataReturnType, classType, attributeSyntax); + else if (isValidDeclaration) + ReportClassReturnsUnsafeTypeValue(context, attributeSyntax); } + }, SyntaxKind.MethodDeclaration); + } - // For v3, recommend I(Async)Enumerable> over I(Async)Enumerable. - // Can't make any recommendations for v2 projects, since deriving from TheoryData<> doesn't make sense. - if (!xunitContext.HasV3References) - return; + static bool IsGenericTheoryDataRowType( + ITypeSymbol? rowType, + Dictionary theoryDataRowTypes, + [NotNullWhen(true)] out INamedTypeSymbol? theoryReturnType) + { + theoryReturnType = default; - var rowType = classType.UnwrapEnumerable(compilation); - if (rowType is null || theoryDataRows.Any(tdr => tdr.IsAssignableFrom(rowType.OriginalDefinition))) - return; + var working = rowType as INamedTypeSymbol; + for (; working is not null; working = working.BaseType) + { + var returnTypeArguments = working.TypeArguments; + if (returnTypeArguments.Length != 0 + && theoryDataRowTypes.TryGetValue(returnTypeArguments.Length, out var theoryDataType) + && SymbolEqualityComparer.Default.Equals(theoryDataType, working.OriginalDefinition)) + break; + } + + if (working is null) + return false; + + theoryReturnType = working; + return true; + } + static void ReportClassReturnsUnsafeTypeValue( + SyntaxNodeAnalysisContext context, + AttributeSyntax attribute) => context.ReportDiagnostic( Diagnostic.Create( Descriptors.X1050_ClassDataTheoryDataRowIsRecommendedForStronglyTypedAnalysis, - argumentExpression.Type.GetLocation() + attribute.GetLocation() + ) + ); + + static void ReportExtraTypeArguments( + SyntaxNodeAnalysisContext context, + AttributeSyntax attribute, + INamedTypeSymbol theoryDataType) => + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X1038_TheoryDataTypeArgumentsMustMatchTestMethodParameters_ExtraTypeParameters, + attribute.GetLocation(), + SymbolDisplay.ToDisplayString(theoryDataType) ) ); - }, SyntaxKind.Attribute); + + static void ReportIncompatibleType( + SyntaxNodeAnalysisContext context, + TypeSyntax parameterType, + ITypeSymbol theoryDataTypeParameter, + INamedTypeSymbol namedClassType, + IParameterSymbol parameter) => + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X1039_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleTypes, + parameterType.GetLocation(), + SymbolDisplay.ToDisplayString(theoryDataTypeParameter), + SymbolDisplay.ToDisplayString(namedClassType), + parameter.Name + ) + ); + + static void ReportIncorrectImplementationType( + SyntaxNodeAnalysisContext context, + string validSymbols, + AttributeSyntax attribute, + ITypeSymbol classType) => + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X1007_ClassDataAttributeMustPointAtValidClass, + attribute.GetLocation(), + classType.Name, + validSymbols + ) + ); + + static void ReportNullabilityMismatch( + SyntaxNodeAnalysisContext context, + TypeSyntax parameterType, + ITypeSymbol theoryDataTypeParameter, + INamedTypeSymbol namedClassType, + IParameterSymbol parameter) => + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X1040_TheoryDataTypeArgumentsMustMatchTestMethodParameters_IncompatibleNullability, + parameterType.GetLocation(), + SymbolDisplay.ToDisplayString(theoryDataTypeParameter), + SymbolDisplay.ToDisplayString(namedClassType), + parameter.Name + ) + ); + + static void ReportTooFewTypeArguments( + SyntaxNodeAnalysisContext context, + AttributeSyntax attribute, + INamedTypeSymbol theoryDataType) => + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X1037_TheoryDataTypeArgumentsMustMatchTestMethodParameters_TooFewTypeParameters, + attribute.GetLocation(), + SymbolDisplay.ToDisplayString(theoryDataType) + ) + ); + + static bool VerifyDataSourceDeclaration( + SyntaxNodeAnalysisContext context, + Compilation compilation, + XunitContext xunitContext, + INamedTypeSymbol classType, + AttributeSyntax attribute) + { + var v3 = xunitContext.HasV3References; + var iEnumerableOfObjectArrayType = TypeSymbolFactory.IEnumerableOfObjectArray(compilation); + var iEnumerableOfTheoryDataRowType = TypeSymbolFactory.IEnumerableOfITheoryDataRow(compilation); + var iAsyncEnumerableOfObjectArrayType = TypeSymbolFactory.IAsyncEnumerableOfObjectArray(compilation); + var iAsyncEnumerableOfTheoryDataRowType = TypeSymbolFactory.IAsyncEnumerableOfITheoryDataRow(compilation); + + // Make sure we implement one of the interfaces + var valid = iEnumerableOfObjectArrayType.IsAssignableFrom(classType); + + if (!valid && v3 && iAsyncEnumerableOfObjectArrayType is not null) + valid = iAsyncEnumerableOfObjectArrayType.IsAssignableFrom(classType); + + if (!valid && v3 && iEnumerableOfTheoryDataRowType is not null) + valid = iEnumerableOfTheoryDataRowType.IsAssignableFrom(classType); + + if (!valid && v3 && iAsyncEnumerableOfTheoryDataRowType is not null) + valid = iAsyncEnumerableOfTheoryDataRowType.IsAssignableFrom(classType); + + // Also make sure we're non-abstract and have an empty constructor + valid = + valid && + !classType.IsAbstract && + classType.InstanceConstructors.Any(c => c.Parameters.IsEmpty && c.DeclaredAccessibility == Accessibility.Public); + + if (!valid) + ReportIncorrectImplementationType( + context, + xunitContext.HasV3References ? typesV3 : typesV2, + attribute, + classType + ); + + return valid; + } + + static void VerifyGenericArgumentTypes( + SemanticModel semanticModel, + SyntaxNodeAnalysisContext context, + MethodDeclarationSyntax testMethod, + INamedTypeSymbol theoryDataType, + INamedTypeSymbol theoryReturnType, + ITypeSymbol classType, + AttributeSyntax attribute) + { + if (classType is not INamedTypeSymbol namedClassType) + return; + + var returnTypeArguments = theoryReturnType.TypeArguments; + var testMethodSymbol = semanticModel.GetDeclaredSymbol(testMethod, context.CancellationToken); + if (testMethodSymbol is null) + return; + + var testMethodParameterSymbols = testMethodSymbol.Parameters; + var testMethodParameterSyntaxes = testMethod.ParameterList.Parameters; + + if (testMethodParameterSymbols.Length > returnTypeArguments.Length + && testMethodParameterSymbols.Skip(returnTypeArguments.Length).Any(p => !p.IsOptional && !p.IsParams)) + { + ReportTooFewTypeArguments(context, attribute, theoryDataType); + return; + } + + int typeArgumentIdx = 0, parameterTypeIdx = 0; + for (; typeArgumentIdx < returnTypeArguments.Length && parameterTypeIdx < testMethodParameterSymbols.Length; typeArgumentIdx++) + { + var parameterSyntax = testMethodParameterSyntaxes[parameterTypeIdx]; + if (parameterSyntax.Type is null) + continue; + + var parameter = testMethodParameterSymbols[parameterTypeIdx]; + if (parameter.Type is null) + continue; + + var parameterType = + parameter.IsParams && parameter.Type is IArrayTypeSymbol paramsArraySymbol + ? paramsArraySymbol.ElementType + : parameter.Type; + + var typeArgument = returnTypeArguments[typeArgumentIdx]; + if (typeArgument is null) + continue; + + if (parameterType.Kind != SymbolKind.TypeParameter && !parameterType.IsAssignableFrom(typeArgument)) + { + bool report = true; + + // The user might be providing the full array for 'params'; if they do, we need to move + // the parameter type index forward because it's been consumed by the array + if (parameter.IsParams && parameter.Type.IsAssignableFrom(typeArgument)) + { + report = false; + parameterTypeIdx++; + } + + if (report) + ReportIncompatibleType(context, parameterSyntax.Type, typeArgument, namedClassType, parameter); + } + + // Nullability of value types is handled by the type compatibility test, + // but nullability of reference types isn't + if (parameterType.IsReferenceType + && typeArgument.IsReferenceType + && parameterType.NullableAnnotation == NullableAnnotation.NotAnnotated + && typeArgument.NullableAnnotation == NullableAnnotation.Annotated) + ReportNullabilityMismatch(context, parameterSyntax.Type, typeArgument, namedClassType, parameter); + + // Only move the parameter type index forward when the current parameter is not a 'params' + if (!parameter.IsParams) + parameterTypeIdx++; + } + + if (typeArgumentIdx < returnTypeArguments.Length) + ReportExtraTypeArguments(context, attribute, theoryDataType); } }