diff --git a/src/Vogen/GenerateCodeForMessagePack.cs b/src/Vogen/GenerateCodeForMessagePack.cs index d1f7619835..9c30e0a256 100644 --- a/src/Vogen/GenerateCodeForMessagePack.cs +++ b/src/Vogen/GenerateCodeForMessagePack.cs @@ -7,7 +7,6 @@ namespace Vogen; - internal class GenerateCodeForMessagePack { public static void GenerateForAMarkerClass(SourceProductionContext context, Compilation compilation, MarkerClassDefinition markerClass) @@ -18,11 +17,11 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp { return; } - + string pns = markerClassSymbol.FullNamespace() ?? ""; - + string ns = pns.Length == 0 ? "" : $"namespace {pns};"; - + var isPublic = markerClassSymbol.DeclaredAccessibility.HasFlag(Accessibility.Public); var accessor = isPublic ? "public" : "internal"; @@ -34,7 +33,8 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp {{accessor}} partial class {{markerClassSymbol.Name}} { - {{GenerateForEachAttribute()}} + {{GenerateManifest()}} + {{GenerateFormatters()}} } """; @@ -47,11 +47,40 @@ public static void GenerateForAMarkerClass(SourceProductionContext context, Comp return; - string GenerateForEachAttribute() + string GenerateManifest() + { + return + $$""" + {{accessor}} static global::MessagePack.Formatters.IMessagePackFormatter[] MessagePackFormatters => new global::MessagePack.Formatters.IMessagePackFormatter[] + { + {{GenerateEach()}} + }; + """; + + string GenerateEach() + { + string?[] names = markerClass.AttributeDefinitions.Where( + m => m.Marker?.Kind is ConversionMarkerKind.MessagePack).Select( + x => + { + if (x is null) return null; + if (x.Marker is null) return null; + + string? wrapperNameShort = x.Marker.VoSymbol.Name; + + return $"new {wrapperNameShort}MessagePackFormatter()"; + }).ToArray(); + + return string.Join(", ", names); + } + } + + string GenerateFormatters() { StringBuilder sb = new(); - foreach (MarkerAttributeDefinition eachMarker in markerClass.AttributeDefinitions.Where(m => m.Marker?.Kind is ConversionMarkerKind.MessagePack)) + foreach (MarkerAttributeDefinition eachMarker in markerClass.AttributeDefinitions.Where( + m => m.Marker?.Kind is ConversionMarkerKind.MessagePack)) { sb.AppendLine( $$""" @@ -63,8 +92,9 @@ string GenerateForEachAttribute() } } - - public static void GenerateForApplicableValueObjects(SourceProductionContext context, Compilation compilation, List valueObjects) + public static void GenerateForApplicableValueObjects(SourceProductionContext context, + Compilation compilation, + List valueObjects) { if (!compilation.IsAtLeastCSharpVersion(LanguageVersion.CSharp12)) { @@ -77,7 +107,7 @@ public static void GenerateForApplicableValueObjects(SourceProductionContext con List toWrite = items.Select( p => GenerateSourceAndFilename(p.WrapperAccessibility, p.WrapperType, p.ContainerNamespace, p.UnderlyingType)).ToList(); - + foreach (var eachToWrite in toWrite) { SourceText sourceText = Util.FormatSource(eachToWrite.SourceCode); @@ -89,9 +119,9 @@ public static void GenerateForApplicableValueObjects(SourceProductionContext con public record FormatterSourceAndFilename(string FormatterFullyQualifiedName, string Filename, string SourceCode); private static FormatterSourceAndFilename GenerateSourceAndFilename( - string accessibility, - INamedTypeSymbol wrapperSymbol, - string theNamespace, + string accessibility, + INamedTypeSymbol wrapperSymbol, + string theNamespace, INamedTypeSymbol underlyingSymbol) { string wrapperName = Util.EscapeIfRequired(wrapperSymbol.Name); @@ -101,9 +131,9 @@ private static FormatterSourceAndFilename GenerateSourceAndFilename( string sb = $$""" {{GeneratedCodeSegments.Preamble}} - + {{ns}} - + {{GenerateSource(accessibility, wrapperSymbol, underlyingSymbol)}} """; @@ -124,6 +154,13 @@ private static string GenerateSource(string accessibility, INamedTypeSymbol wrap string wrapperName = Util.EscapeIfRequired(wrapperSymbol.FullName() ?? wrapperSymbol.Name); string underlyingTypeName = underlyingSymbol.FullName() ?? wrapperSymbol.Name; + + string readMethod = GenerateReadMethod(); + + if (readMethod.Length == 0) + { + return "#error unsupported underlying type " + underlyingSymbol.SpecialType; + } string sb = $$""" @@ -133,7 +170,7 @@ public void Serialize(ref global::MessagePack.MessagePackWriter writer, {{wrappe writer.Write(value.Value); public {{wrapperName}} Deserialize(ref global::MessagePack.MessagePackReader reader, global::MessagePack.MessagePackSerializerOptions options) => - Deserialize(reader.{{GenerateReadMethod()}}); + Deserialize(reader.{{readMethod}}); static {{wrapperName}} Deserialize({{underlyingTypeName}} value) => UnsafeDeserialize(default, value); @@ -143,10 +180,38 @@ public void Serialize(ref global::MessagePack.MessagePackWriter writer, {{wrappe """; return sb; - } - private static string GenerateReadMethod() - { - return "ReadInt32()"; + string GenerateReadMethod() + { + if(underlyingSymbol.SpecialType == SpecialType.System_Boolean) + return "ReadBoolean()"; + if(underlyingSymbol.SpecialType == SpecialType.System_SByte) + return "ReadSByte()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Byte) + return "ReadByte()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Char) + return "ReadChar()"; + if(underlyingSymbol.SpecialType == SpecialType.System_DateTime) + return "ReadDateTime()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Double) + return "ReadDouble()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Single) + return "ReadSingle()"; + if(underlyingSymbol.SpecialType == SpecialType.System_String) + return "ReadString()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Int16) + return "ReadInt16()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Int32) + return "ReadInt32()"; + if(underlyingSymbol.SpecialType == SpecialType.System_Int64) + return "ReadInt64()"; + if(underlyingSymbol.SpecialType == SpecialType.System_UInt16) + return "ReadUInt16()"; + if(underlyingSymbol.SpecialType == SpecialType.System_UInt32) + return "ReadUInt32()"; + if(underlyingSymbol.SpecialType == SpecialType.System_UInt64) + return "ReadUInt64()"; + return ""; + } } } \ No newline at end of file diff --git a/tests/Testbench/Program.cs b/tests/Testbench/Program.cs index cb6942e50a..606582a606 100644 --- a/tests/Testbench/Program.cs +++ b/tests/Testbench/Program.cs @@ -1,10 +1,6 @@ using System; -using System.Collections.Generic; -using System.ComponentModel.DataAnnotations; -using System.IO; -using System.Linq; -using System.Xml.Serialization; using MessagePack; +using MessagePack.Formatters; using N1; using N2; using Vogen; @@ -13,13 +9,69 @@ namespace Testbench; [MessagePack()] +[MessagePack()] +[MessagePack()] +[MessagePack()] [MessagePack()] [EfCoreConverter] -public partial class MyMarkers; +internal partial class MyMarkers; + +[ValueObject] +public partial struct MyBool +{ +} public static class Program { public static void Main() { + // Create an instance of the sample class + var originalObject = new Sample + { + Id = MyId.From(123), + Name = Name.From("Test"), + Active = MyBool.From(true) + }; + + +// Caret is currently at line 47 + +// Create custom resolver with the MyIdFormatter + var customResolver = MessagePack.Resolvers.CompositeResolver.Create( + MyMarkers.MessagePackFormatters, +// new IMessagePackFormatter[] { new MyMarkers.MyIdMessagePackFormatter(), new MyMarkers.NameMessagePackFormatter(), new MyMarkers.MyBoolMessagePackFormatter() }, + new IFormatterResolver[] { MessagePack.Resolvers.StandardResolver.Instance } + ); + + var options = MessagePackSerializerOptions.Standard.WithResolver(customResolver); + + byte[] serializedObject = MessagePackSerializer.Serialize(originalObject, options); + + +// Deserialize the byte array back to the Sample object using the custom options + var deserializedObject = MessagePackSerializer.Deserialize(serializedObject, options); + +// Display the deserialized object + Console.WriteLine($"Id: {deserializedObject.Id}, Name: {deserializedObject.Name}, Active: {deserializedObject.Active}"); + } -} \ No newline at end of file +} + + +[MessagePackObject] +public class Sample +{ + [MessagePack.Key(0)] + public MyId Id { get; set; } + + [MessagePack.Key(1)] public Name Name { get; set; } = Name.From(""); + [MessagePack.Key(2)] public MyBool Active { get; set; } = MyBool.From(false); +} + +[ValueObject] +public partial struct MyId; + +[ValueObject] +public partial struct Name; + +