Skip to content

Commit

Permalink
Filled out reader to accomodate more types.
Browse files Browse the repository at this point in the history
Added manifest
  • Loading branch information
SteveDunn committed Nov 12, 2024
1 parent e29a05c commit 7884c8a
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 27 deletions.
105 changes: 85 additions & 20 deletions src/Vogen/GenerateCodeForMessagePack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

namespace Vogen;


internal class GenerateCodeForMessagePack
{
public static void GenerateForAMarkerClass(SourceProductionContext context, Compilation compilation, MarkerClassDefinition markerClass)
Expand All @@ -18,11 +17,11 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp
{
return;
}

string pns = markerClassSymbol.FullNamespace() ?? "";

string ns = pns.Length == 0 ? "" : $"namespace {pns};";

var isPublic = markerClassSymbol.DeclaredAccessibility.HasFlag(Accessibility.Public);
var accessor = isPublic ? "public" : "internal";

Expand All @@ -34,7 +33,8 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp

{{accessor}} partial class {{markerClassSymbol.Name}}
{
{{GenerateForEachAttribute()}}
{{GenerateManifest()}}
{{GenerateFormatters()}}
}

""";
Expand All @@ -47,11 +47,40 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp

return;

string GenerateForEachAttribute()
string GenerateManifest()
{
return
$$"""
{{accessor}} static global::MessagePack.Formatters.IMessagePackFormatter[] MessagePackFormatters => new global::MessagePack.Formatters.IMessagePackFormatter[]
{
{{GenerateEach()}}
};
""";

string GenerateEach()
{
string?[] names = markerClass.AttributeDefinitions.Where(
m => m.Marker?.Kind is ConversionMarkerKind.MessagePack).Select(
x =>
{
if (x is null) return null;
if (x.Marker is null) return null;
string? wrapperNameShort = x.Marker.VoSymbol.Name;
return $"new {wrapperNameShort}MessagePackFormatter()";
}).ToArray();

return string.Join(", ", names);
}
}

string GenerateFormatters()
{
StringBuilder sb = new();

foreach (MarkerAttributeDefinition eachMarker in markerClass.AttributeDefinitions.Where(m => m.Marker?.Kind is ConversionMarkerKind.MessagePack))
foreach (MarkerAttributeDefinition eachMarker in markerClass.AttributeDefinitions.Where(
m => m.Marker?.Kind is ConversionMarkerKind.MessagePack))
{
sb.AppendLine(
$$"""
Expand All @@ -63,8 +92,9 @@ string GenerateForEachAttribute()
}
}


public static void GenerateForApplicableValueObjects(SourceProductionContext context, Compilation compilation, List<VoWorkItem> valueObjects)
public static void GenerateForApplicableValueObjects(SourceProductionContext context,
Compilation compilation,
List<VoWorkItem> valueObjects)
{
if (!compilation.IsAtLeastCSharpVersion(LanguageVersion.CSharp12))
{
Expand All @@ -77,7 +107,7 @@ public static void GenerateForApplicableValueObjects(SourceProductionContext con

List<FormatterSourceAndFilename> toWrite = items.Select(
p => GenerateSourceAndFilename(p.WrapperAccessibility, p.WrapperType, p.ContainerNamespace, p.UnderlyingType)).ToList();

foreach (var eachToWrite in toWrite)
{
SourceText sourceText = Util.FormatSource(eachToWrite.SourceCode);
Expand All @@ -89,9 +119,9 @@ public static void GenerateForApplicableValueObjects(SourceProductionContext con
public record FormatterSourceAndFilename(string FormatterFullyQualifiedName, string Filename, string SourceCode);

private static FormatterSourceAndFilename GenerateSourceAndFilename(
string accessibility,
INamedTypeSymbol wrapperSymbol,
string theNamespace,
string accessibility,
INamedTypeSymbol wrapperSymbol,
string theNamespace,
INamedTypeSymbol underlyingSymbol)
{
string wrapperName = Util.EscapeIfRequired(wrapperSymbol.Name);
Expand All @@ -101,9 +131,9 @@ private static FormatterSourceAndFilename GenerateSourceAndFilename(
string sb =
$$"""
{{GeneratedCodeSegments.Preamble}}

{{ns}}

{{GenerateSource(accessibility, wrapperSymbol, underlyingSymbol)}}
""";

Expand All @@ -124,6 +154,13 @@ private static string GenerateSource(string accessibility, INamedTypeSymbol wrap
string wrapperName = Util.EscapeIfRequired(wrapperSymbol.FullName() ?? wrapperSymbol.Name);

string underlyingTypeName = underlyingSymbol.FullName() ?? wrapperSymbol.Name;

string readMethod = GenerateReadMethod();

if (readMethod.Length == 0)
{
return "#error unsupported underlying type " + underlyingSymbol.SpecialType;
}

string sb =
$$"""
Expand All @@ -133,7 +170,7 @@ public void Serialize(ref global::MessagePack.MessagePackWriter writer, {{wrappe
writer.Write(value.Value);
public {{wrapperName}} Deserialize(ref global::MessagePack.MessagePackReader reader, global::MessagePack.MessagePackSerializerOptions options) =>
Deserialize(reader.{{GenerateReadMethod()}});
Deserialize(reader.{{readMethod}});

static {{wrapperName}} Deserialize({{underlyingTypeName}} value) => UnsafeDeserialize(default, value);

Expand All @@ -143,10 +180,38 @@ public void Serialize(ref global::MessagePack.MessagePackWriter writer, {{wrappe
""";

return sb;
}

private static string GenerateReadMethod()
{
return "ReadInt32()";
string GenerateReadMethod()
{
if(underlyingSymbol.SpecialType == SpecialType.System_Boolean)
return "ReadBoolean()";
if(underlyingSymbol.SpecialType == SpecialType.System_SByte)
return "ReadSByte()";
if(underlyingSymbol.SpecialType == SpecialType.System_Byte)
return "ReadByte()";
if(underlyingSymbol.SpecialType == SpecialType.System_Char)
return "ReadChar()";
if(underlyingSymbol.SpecialType == SpecialType.System_DateTime)
return "ReadDateTime()";
if(underlyingSymbol.SpecialType == SpecialType.System_Double)
return "ReadDouble()";
if(underlyingSymbol.SpecialType == SpecialType.System_Single)
return "ReadSingle()";
if(underlyingSymbol.SpecialType == SpecialType.System_String)
return "ReadString()";
if(underlyingSymbol.SpecialType == SpecialType.System_Int16)
return "ReadInt16()";
if(underlyingSymbol.SpecialType == SpecialType.System_Int32)
return "ReadInt32()";
if(underlyingSymbol.SpecialType == SpecialType.System_Int64)
return "ReadInt64()";
if(underlyingSymbol.SpecialType == SpecialType.System_UInt16)
return "ReadUInt16()";
if(underlyingSymbol.SpecialType == SpecialType.System_UInt32)
return "ReadUInt32()";
if(underlyingSymbol.SpecialType == SpecialType.System_UInt64)
return "ReadUInt64()";
return "";
}
}
}
66 changes: 59 additions & 7 deletions tests/Testbench/Program.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq;
using System.Xml.Serialization;
using MessagePack;
using MessagePack.Formatters;
using N1;
using N2;
using Vogen;
Expand All @@ -13,13 +9,69 @@
namespace Testbench;

[MessagePack<MyInt>()]
[MessagePack<MyId>()]
[MessagePack<MyBool>()]
[MessagePack<Name>()]
[MessagePack<MyString>()]
[EfCoreConverter<MyInt>]
public partial class MyMarkers;
internal partial class MyMarkers;

[ValueObject<bool>]
public partial struct MyBool
{
}

public static class Program
{
public static void Main()
{
// Create an instance of the sample class
var originalObject = new Sample
{
Id = MyId.From(123),
Name = Name.From("Test"),
Active = MyBool.From(true)
};


// Caret is currently at line 47

// Create custom resolver with the MyIdFormatter
var customResolver = MessagePack.Resolvers.CompositeResolver.Create(
MyMarkers.MessagePackFormatters,
// new IMessagePackFormatter[] { new MyMarkers.MyIdMessagePackFormatter(), new MyMarkers.NameMessagePackFormatter(), new MyMarkers.MyBoolMessagePackFormatter() },
new IFormatterResolver[] { MessagePack.Resolvers.StandardResolver.Instance }
);

var options = MessagePackSerializerOptions.Standard.WithResolver(customResolver);

byte[] serializedObject = MessagePackSerializer.Serialize(originalObject, options);


// Deserialize the byte array back to the Sample object using the custom options
var deserializedObject = MessagePackSerializer.Deserialize<Sample>(serializedObject, options);

// Display the deserialized object
Console.WriteLine($"Id: {deserializedObject.Id}, Name: {deserializedObject.Name}, Active: {deserializedObject.Active}");

}
}
}


[MessagePackObject]
public class Sample
{
[MessagePack.Key(0)]
public MyId Id { get; set; }

[MessagePack.Key(1)] public Name Name { get; set; } = Name.From("");
[MessagePack.Key(2)] public MyBool Active { get; set; } = MyBool.From(false);
}

[ValueObject<int>]
public partial struct MyId;

[ValueObject<string>]
public partial struct Name;


0 comments on commit 7884c8a

Please sign in to comment.