Skip to content

Commit

Permalink
Provide a code-fix to add missing stateful marshaller shape methods (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky authored Aug 5, 2022
1 parent 5d4526d commit 7e967b1
Show file tree
Hide file tree
Showing 8 changed files with 763 additions and 41 deletions.
2 changes: 1 addition & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
<FsCheckVersion>2.14.3</FsCheckVersion>
<!-- Uncomment to set a fixed version, else the latest is used -->
<!--<SdkVersionForWorkloadTesting>7.0.100-rc.1.22402.35</SdkVersionForWorkloadTesting>-->
<CompilerPlatformTestingVersion>1.1.2-beta1.22205.2</CompilerPlatformTestingVersion>
<CompilerPlatformTestingVersion>1.1.2-beta1.22403.2</CompilerPlatformTestingVersion>
<!-- Docs -->
<MicrosoftPrivateIntellisenseVersion>7.0.0-preview-20220721.1</MicrosoftPrivateIntellisenseVersion>
<!-- ILLink -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatelessLinearCollectionRequiresTwoParameterAllocateContainerForManagedElementsDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresFromManagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromManagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromManagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -479,7 +479,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresFromManagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresToUnmanagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresToUnmanagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresToUnmanagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -490,7 +490,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresToUnmanagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresToManagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresToManagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresToManagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -501,7 +501,7 @@ public static class DefaultMarshalModeDiagnostics
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresToManagedDescription)));

/// <inheritdoc cref="CustomMarshallerAttributeAnalyzer.StatefulMarshallerRequiresFromUnmanagedRule" />
public static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromUnmanagedRule =
private static readonly DiagnosticDescriptor StatefulMarshallerRequiresFromUnmanagedRule =
new DiagnosticDescriptor(
Ids.CustomMarshallerTypeMustHaveRequiredShape,
GetResourceString(nameof(SR.CustomMarshallerTypeMustHaveRequiredShapeTitle)),
Expand All @@ -511,7 +511,7 @@ public static class DefaultMarshalModeDiagnostics
isEnabledByDefault: true,
description: GetResourceString(nameof(SR.StatefulMarshallerRequiresFromUnmanagedDescription)));

internal static DiagnosticDescriptor GetDefaultMarshalModeDiagnostic(DiagnosticDescriptor errorDescriptor)
public static DiagnosticDescriptor GetDefaultMarshalModeDiagnostic(DiagnosticDescriptor errorDescriptor)
{
if (ReferenceEquals(errorDescriptor, CustomMarshallerAttributeAnalyzer.StatelessValueInRequiresConvertToUnmanagedRule))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ private static void AddMissingMembers(
{
AddMissingMembersToStatelessMarshaller(editor, declaringSyntax, marshallerType, managedType, missingMemberNames, isLinearCollectionMarshaller);
}
if (marshallerType.IsValueType)
{
AddMissingMembersToStatefulMarshaller(editor, declaringSyntax, marshallerType, managedType, missingMemberNames, isLinearCollectionMarshaller);
}
}

private static void AddMissingMembersToStatelessMarshaller(DocumentEditor editor, SyntaxNode declaringSyntax, INamedTypeSymbol marshallerType, ITypeSymbol managedType, HashSet<string> missingMemberNames, bool isLinearCollectionMarshaller)
Expand Down Expand Up @@ -398,6 +402,173 @@ ITypeSymbol CreateManagedElementTypeSymbol()
}
}

private static void AddMissingMembersToStatefulMarshaller(DocumentEditor editor, SyntaxNode declaringSyntax, INamedTypeSymbol marshallerType, ITypeSymbol managedType, HashSet<string> missingMemberNames, bool isLinearCollectionMarshaller)
{
SyntaxGenerator gen = editor.Generator;
// Get the methods of the shape so we can use them to determine what types to use in signatures that are not obvious.
var (_, methods) = StatefulMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, isLinearCollectionMarshaller, editor.SemanticModel.Compilation);
INamedTypeSymbol spanOfT = editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.System_Span_Metadata)!;
INamedTypeSymbol readOnlySpanOfT = editor.SemanticModel.Compilation.GetBestTypeByMetadataName(TypeNames.System_ReadOnlySpan_Metadata)!;
var (typeParameters, _) = marshallerType.GetAllTypeArgumentsIncludingInContainingTypes();

// Use a lazy factory for the type syntaxes to avoid re-checking the various methods and reconstructing the syntax.
Lazy<SyntaxNode> unmanagedTypeSyntax = new(CreateUnmanagedTypeSyntax, isThreadSafe: false);
Lazy<ITypeSymbol> managedElementTypeSymbol = new(CreateManagedElementTypeSymbol, isThreadSafe: false);

List<SyntaxNode> newMembers = new();

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.FromManaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.FromManaged,
parameters: new[] { gen.ParameterDeclaration("managed", gen.TypeExpression(managedType)) },
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.ToUnmanaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.ToUnmanaged,
returnType: unmanagedTypeSyntax.Value,
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.FromUnmanaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.FromUnmanaged,
parameters: new[] { gen.ParameterDeclaration("unmanaged", unmanagedTypeSyntax.Value) },
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Value.Stateful.ToManaged))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.ToManaged,
returnType: gen.TypeExpression(managedType),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.BufferSize))
{
newMembers.Add(
gen.WithAccessorDeclarations(
gen.PropertyDeclaration(ShapeMemberNames.BufferSize,
gen.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Int32)),
Accessibility.Public,
DeclarationModifiers.Static),
gen.GetAccessorDeclaration(statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) })));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateful.GetManagedValuesSource))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateful.GetManagedValuesSource,
returnType: gen.TypeExpression(readOnlySpanOfT.Construct(managedElementTypeSymbol.Value)),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateful.GetUnmanagedValuesDestination))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateful.GetUnmanagedValuesDestination,
returnType: gen.TypeExpression(spanOfT.Construct(typeParameters[typeParameters.Length - 1])),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateful.GetUnmanagedValuesSource))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateful.GetUnmanagedValuesSource,
parameters: new[]
{
gen.ParameterDeclaration("numElements", gen.TypeExpression(SpecialType.System_Int32))
},
returnType: gen.TypeExpression(readOnlySpanOfT.Construct(typeParameters[typeParameters.Length - 1])),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.LinearCollection.Stateful.GetManagedValuesDestination))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.LinearCollection.Stateful.GetManagedValuesDestination,
parameters: new[]
{
gen.ParameterDeclaration("numElements", gen.TypeExpression(SpecialType.System_Int32))
},
returnType: gen.TypeExpression(spanOfT.Construct(managedElementTypeSymbol.Value)),
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

if (missingMemberNames.Contains(ShapeMemberNames.Free))
{
newMembers.Add(
gen.MethodDeclaration(
ShapeMemberNames.Value.Stateful.Free,
accessibility: Accessibility.Public,
statements: new[] { DefaultMethodStatement(gen, editor.SemanticModel.Compilation) }));
}

editor.ReplaceNode(declaringSyntax, (declaringSyntax, gen) => gen.AddMembers(declaringSyntax, newMembers));

SyntaxNode CreateUnmanagedTypeSyntax()
{
ITypeSymbol? unmanagedType = null;
if (methods.ToUnmanaged is not null)
{
unmanagedType = methods.ToUnmanaged.ReturnType;
}
else if (methods.FromUnmanaged is not null)
{
unmanagedType = methods.FromUnmanaged.Parameters[0].Type;
}
else if (methods.UnmanagedValuesSource is not null)
{
unmanagedType = methods.UnmanagedValuesSource.Parameters[0].Type;
}
else if (methods.UnmanagedValuesDestination is not null)
{
unmanagedType = methods.UnmanagedValuesDestination.Parameters[0].Type;
}

if (unmanagedType is not null)
{
return gen.TypeExpression(unmanagedType);
}
return gen.TypeExpression(editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_IntPtr));
}

ITypeSymbol CreateManagedElementTypeSymbol()
{
if (methods.ManagedValuesSource is not null)
{
return ((INamedTypeSymbol)methods.ManagedValuesSource.ReturnType).TypeArguments[0];
}
if (methods.ManagedValuesDestination is not null)
{
return ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
}

return editor.SemanticModel.Compilation.GetSpecialType(SpecialType.System_IntPtr);
}
}

private static SyntaxNode DefaultMethodStatement(SyntaxGenerator generator, Compilation compilation)
{
return generator.ThrowStatement(generator.ObjectCreationExpression(
Expand Down
Loading

0 comments on commit 7e967b1

Please sign in to comment.