Skip to content

Commit

Permalink
Merge pull request #227 from SteveDunn/prohibit-reflection
Browse files Browse the repository at this point in the history
Prohibit reflection with Activator.CreateInstance
  • Loading branch information
SteveDunn authored Sep 25, 2022
2 parents 0686f20 + c1c451c commit e076189
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 10 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public partial struct CustomerId {
```csharp
// catches object creation expressions
var c = new CustomerId(); // error VOG010: Type 'CustomerId' cannot be constructed with 'new' as it is prohibited
var c = Activator.CreateInstance<CustomerId>(); // error VOG025: Type 'CustomerId' cannot be constructed via Reflection as it is prohibited.
CustomerId c = default; // error VOG009: Type 'CustomerId' cannot be constructed with default as it is prohibited.
var c = default(CustomerId); // error VOG009: Type 'CustomerId' cannot be constructed with default as it is prohibited.
var c = GetCustomerId(); // error VOG010: Type 'CustomerId' cannot be constructed with 'new' as it is prohibited
Expand Down Expand Up @@ -649,5 +650,7 @@ To test in VS, you'll have a new 'launch profile':

Select the Vogen project as the active project, and from the dropdown, select 'Roslyn'. Then just F5 to start debugging.

To debug an analyzer, write a unit test using `DisallowNewTests` as a template, not forgetting to change `CreationUsingImplicitNewAnalyzer` to the analyzer you want to test.

### Can I get it to throw my own exception?
Yes, by specifying the exception type in either the `ValueObject` attribute, or globally, with `VogenConfiguration`.
95 changes: 95 additions & 0 deletions src/Vogen/Analyzers/CreationUsingReflectionAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System.Collections.Immutable;
using System.Linq;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Vogen.Diagnostics;

namespace Vogen.Analyzers;

/// <summary>
/// An analyzer that stops `CustomerId = default;`.
/// </summary>
[Generator]
public class CreationUsingReflectionAnalyzer : IIncrementalGenerator
{
public record struct FoundItem(Location Location, INamedTypeSymbol VoClass);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
IncrementalValuesProvider<FoundItem?> targets = GetTargets(context);

IncrementalValueProvider<(Compilation, ImmutableArray<FoundItem?>)> compilationAndTypes
= context.CompilationProvider.Combine(targets.Collect());

context.RegisterSourceOutput(compilationAndTypes,
static (spc, source) => Execute(source.Item2, spc));
}

private static IncrementalValuesProvider<FoundItem?> GetTargets(IncrementalGeneratorInitializationContext context)
{
return context.SyntaxProvider.CreateSyntaxProvider(
predicate: static (s, _) => s is InvocationExpressionSyntax,
transform: static (ctx, _) => TryGetTarget(ctx))
.Where(static m => m is not null);
}

private static FoundItem? TryGetTarget(GeneratorSyntaxContext ctx)
{
var syntax = (InvocationExpressionSyntax) ctx.Node;
var methodSymbol = (ctx.SemanticModel.GetSymbolInfo(syntax).Symbol as IMethodSymbol);
if (methodSymbol == null)
{
return null;
}

if (methodSymbol.ReceiverType?.FullNamespace() != "System") return null;
if (methodSymbol.ReceiverType.Name != "Activator") return null;

if (methodSymbol.Name != "CreateInstance")
{
return null;
}

if (methodSymbol.Parameters.Length == 0)
{
if (!methodSymbol.IsGenericMethod) return null;

var returnType = methodSymbol.ReturnType as INamedTypeSymbol;

if (returnType == null) return null;

if (!VoFilter.IsTarget(returnType)) return null;

return new FoundItem(syntax.GetLocation(), returnType);
}

if (methodSymbol.Parameters.Length == 1)
{
var childNodes = syntax.DescendantNodes().OfType<TypeOfExpressionSyntax>().ToList();

if (childNodes.Count != 1) return null;

TypeInfo xxy = ctx.SemanticModel.GetTypeInfo(childNodes[0].Type);

var syntaxNode = xxy.Type as INamedTypeSymbol;

if (!VoFilter.IsTarget(syntaxNode)) return null;

return new FoundItem(syntax.GetLocation(), syntaxNode!);
}

return null;
}

static void Execute(ImmutableArray<FoundItem?> typeDeclarations, SourceProductionContext context)
{
foreach (FoundItem? eachFoundItem in typeDeclarations)
{
if (eachFoundItem is not null)
{
context.ReportDiagnostic(DiagnosticItems.UsingActivatorProhibited(eachFoundItem.Value.Location, eachFoundItem.Value.VoClass.Name));
}
}
}
}
3 changes: 2 additions & 1 deletion src/Vogen/Diagnostics/DiagnosticCode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ public enum DiagnosticCode
TypeShouldBePartial = 21,
InvalidDeserializationStrictness = 22,
InstanceValueCannotBeConverted = 23,
DuplicateTypesFound = 24
DuplicateTypesFound = 24,
UsingActivatorProhibited = 25
}
8 changes: 8 additions & 0 deletions src/Vogen/Diagnostics/DiagnosticItems.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ internal static class DiagnosticItems
"Using default of Value Objects is prohibited",
"Type '{0}' cannot be constructed with default as it is prohibited.");

private static readonly DiagnosticDescriptor _usingActivatorProhibited = CreateDescriptor(
DiagnosticCode.UsingActivatorProhibited,
"Using Reflection to create Value Objects is prohibited",
"Type '{0}' cannot be constructed via Reflection as it is prohibited.");

private static readonly DiagnosticDescriptor _usingNewProhibited = CreateDescriptor(
DiagnosticCode.UsingNewProhibited,
"Using new to create Value Objects is prohibited. Please use the From method for creation.",
Expand Down Expand Up @@ -152,6 +157,9 @@ public static Diagnostic NormalizeInputMethodMustBeStatic(MethodDeclarationSynta
public static Diagnostic UsingDefaultProhibited(Location locationOfDefaultStatement, string voClassName) =>
BuildDiagnostic(_usingDefaultProhibited, voClassName, locationOfDefaultStatement);

public static Diagnostic UsingActivatorProhibited(Location locationOfDefaultStatement, string voClassName) =>
BuildDiagnostic(_usingActivatorProhibited, voClassName, locationOfDefaultStatement);

public static Diagnostic UsingNewProhibited(Location location, string voClassName) =>
BuildDiagnostic(_usingNewProhibited, voClassName, location);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System.Linq;
using FluentAssertions;
using FluentAssertions.Execution;
using Microsoft.CodeAnalysis;
using Vogen.Analyzers;
using Xunit;

namespace NotSystem
{
public static class Activator
{
public static T? CreateInstance<T>() => default(T);
}
}

namespace SmallTests.DiagnosticsTests
{
public class DisallowCreationWithReflectionTests
{
[Fact]
public void Allows_using_Activate_CreateInstance_from_another_namespace()
{
var x = NotSystem.Activator.CreateInstance<string>();
}


[Theory]
[InlineData("partial class")]
[InlineData("partial struct")]
[InlineData("readonly partial struct")]
[InlineData("partial record class")]
[InlineData("partial record struct")]
[InlineData("readonly partial record struct")]
public void Disallows_generic_method(string type)
{
var source = $@"using Vogen;
using System;
namespace Whatever;
[ValueObject(typeof(int))]
public {type} CustomerId {{ }}
var c = Activator.CreateInstance<CustomerId>();
";

var (diagnostics, _) = TestHelper.GetGeneratedOutput<CreationUsingReflectionAnalyzer>(source);

using var _ = new AssertionScope();

diagnostics.Should().HaveCount(1);
Diagnostic diagnostic = diagnostics.Single();

diagnostic.Id.Should().Be("VOG025");
diagnostic.ToString().Should().Match("*error VOG025: Type 'CustomerId' cannot be constructed via Reflection as it is prohibited.");
}

[Theory]
[InlineData("partial class")]
[InlineData("partial struct")]
[InlineData("readonly partial struct")]
[InlineData("partial record class")]
[InlineData("partial record struct")]
[InlineData("readonly partial record struct")]
public void Disallows_non_generic_method(string type)
{
var source = $@"using Vogen;
using System;
namespace Whatever;
[ValueObject(typeof(int))]
public {type} CustomerId {{ }}
var c = Activator.CreateInstance(typeof(CustomerId));
";

var (diagnostics, _) = TestHelper.GetGeneratedOutput<CreationUsingReflectionAnalyzer>(source);

using var _ = new AssertionScope();

diagnostics.Should().HaveCount(1);
Diagnostic diagnostic = diagnostics.Single();

diagnostic.Id.Should().Be("VOG025");
diagnostic.ToString().Should().Match("*error VOG025: Type 'CustomerId' cannot be constructed via Reflection as it is prohibited.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using FluentAssertions;
using Microsoft.CodeAnalysis;
using Vogen.Analyzers;
using Vogen.Tests;
using Xunit;

namespace SmallTests.DiagnosticsTests;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using FluentAssertions;
using Vogen;
using Vogen.Tests;
using Xunit;

namespace SmallTests.DiagnosticsTests.LocalConfig;
Expand Down
1 change: 0 additions & 1 deletion tests/SmallTests/DiagnosticsTests/LocalConfig/SadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using FluentAssertions;
using Microsoft.CodeAnalysis;
using Vogen;
using Vogen.Tests;
using Xunit;

namespace SmallTests.DiagnosticsTests.LocalConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using FluentAssertions.Execution;
using Microsoft.CodeAnalysis;
using Vogen;
using Vogen.Tests;
using Xunit;

namespace SmallTests.DiagnosticsTests;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using FluentAssertions.Execution;
using Microsoft.CodeAnalysis;
using Vogen.Analyzers;
using Vogen.Tests;
using Xunit;

namespace SmallTests.DiagnosticsTests;
Expand Down
4 changes: 3 additions & 1 deletion tests/SmallTests/RecordClassCreationTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma warning disable VOG025

using System;
using FluentAssertions;
using Vogen;
Expand Down Expand Up @@ -47,7 +49,7 @@ public void Creation_Unhappy_Path_MyInt()
[Fact]
public void Default_vo_throws_at_runtime()
{
MyRecordClassInt vo = (MyRecordClassInt) Activator.CreateInstance(typeof(MyRecordClassInt))!;
MyRecordClassInt vo = (MyRecordClassInt) Activator.CreateInstance(Type.GetType("Vogen.Tests.Types.MyRecordClassInt")!)!;
Func<int> action = () => vo.Value;

action.Should().Throw<ValueObjectValidationException>().WithMessage("Use of uninitialized Value Object*");
Expand Down
3 changes: 2 additions & 1 deletion tests/SmallTests/RecordStructCreationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public void Creation_Unhappy_Path_MyInt()
[Fact]
public void Default_vo_throws_at_runtime()
{
MyRecordStructInt vo = (MyRecordStructInt) Activator.CreateInstance(typeof(MyRecordStructInt))!;
MyRecordStructInt vo =
(MyRecordStructInt) Activator.CreateInstance(Type.GetType("Vogen.Tests.Types.MyRecordStructInt")!)!;
Func<int> action = () => vo.Value;

action.Should().Throw<ValueObjectValidationException>().WithMessage("Use of uninitialized Value Object*");
Expand Down
5 changes: 3 additions & 2 deletions tests/Testbench/Program.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.Threading.Tasks;
using Vogen;
#pragma warning disable CS0219

namespace Testbench;

Expand All @@ -11,7 +10,9 @@ public static async Task Main()
{
await Task.CompletedTask;

bool b = MyIntVo.TryParse("123", out var r);
// var vo = Activator.CreateInstance<MyIntVo>();
//var vo = (MyIntVo)Activator.CreateInstance(typeof(MyIntVo))!;
//Console.WriteLine(vo.Value);
}
}

Expand Down

0 comments on commit e076189

Please sign in to comment.