Skip to content

Commit

Permalink
CSHARP-4337: Use correct serializer for conditional result.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam committed Oct 7, 2022
1 parent 3d67e80 commit ca00c14
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 46 deletions.
14 changes: 14 additions & 0 deletions src/MongoDB.Bson/Serialization/Serializers/EnumSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ public override TEnum Deserialize(BsonDeserializationContext context, BsonDeseri
}
}

/// <inheritdoc/>
public override bool Equals(object obj)
{
return
obj is EnumSerializer<TEnum> other &&
_representation == other.Representation;
}

/// <inheritdoc/>
public override int GetHashCode()
{
return _representation.GetHashCode();
}

/// <summary>
/// Serializes a value.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,42 @@ public override Expression Visit(Expression node)

var result = base.Visit(node);
_registry.Add(node, _currentKnownSerializersNode);
_currentKnownSerializersNode = _currentKnownSerializersNode.Parent;

var parent = _currentKnownSerializersNode.Parent;
if (ShouldPropagateKnownSerializersToParent(parent))
{
parent.AddKnownSerializersFromChild(_currentKnownSerializersNode);
}
_currentKnownSerializersNode = parent;

return result;
}

protected override Expression VisitConditional(ConditionalExpression node)
{
var result = base.VisitConditional(node);

if (_currentKnownSerializersNode.KnownSerializers.TryGetValue(node.Type, out var resultSerializers) &&
resultSerializers.Count > 1)
{
var ifTrueSerializer = _registry.GetSerializerAtThisLevel(node.IfTrue);
var ifFalseSerializer = _registry.GetSerializerAtThisLevel(node.IfFalse);

if (ifTrueSerializer != null && ifFalseSerializer != null && !ifTrueSerializer.Equals(ifFalseSerializer))
{
throw new ExpressionNotSupportedException(node, because: "IfTrue and IfFalse expressions have different serializers");
}

if (ifTrueSerializer != null)
{
_currentKnownSerializersNode.SetKnownSerializerForType(node.Type, ifTrueSerializer);
}
else if (ifFalseSerializer != null)
{
_currentKnownSerializersNode.SetKnownSerializerForType(node.Type, ifFalseSerializer);
}
}

return result;
}

Expand All @@ -87,14 +122,14 @@ protected override Expression VisitBinary(BinaryExpression node)
{
var rightExpressionSerializer = _registry.GetSerializer(rightExpression);
var leftExpressionSerializer = EnumUnderlyingTypeSerializer.Create(rightExpressionSerializer);
_registry.AddKnownSerializer(leftExpression, leftExpressionSerializer, allowPropagation: false);
_registry.SetNodeSerializer(leftExpression, leftExpressionSerializer);
}

if (rightExpression is ConstantExpression rightConstantExpression)
{
var leftExpressionSerializer = _registry.GetSerializer(leftExpression);
var rightExpressionSerializer = EnumUnderlyingTypeSerializer.Create(leftExpressionSerializer);
_registry.AddKnownSerializer(rightExpression, rightExpressionSerializer, allowPropagation: false);
_registry.SetNodeSerializer(rightExpression, rightExpressionSerializer);
}
}
}
Expand Down Expand Up @@ -202,5 +237,20 @@ protected override Expression VisitParameter(ParameterExpression node)

return result;
}

private bool ShouldPropagateKnownSerializersToParent(KnownSerializersNode parent)
{
if (parent == null)
{
return false;
}

return parent.Expression.NodeType switch
{
ExpressionType.MemberInit => false,
ExpressionType.New => false,
_ => true
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class KnownSerializersNode
// private fields
private readonly Expression _expression;
private readonly Dictionary<Type, HashSet<IBsonSerializer>> _knownSerializers = new Dictionary<Type, HashSet<IBsonSerializer>>();
private IBsonSerializer _nodeSerializer; // a serializer used only for this node (not propagated upwards)
private readonly KnownSerializersNode _parent;

// constructors
Expand All @@ -42,7 +43,16 @@ public KnownSerializersNode(Expression expression, KnownSerializersNode parent)
public KnownSerializersNode Parent => _parent;

// public methods
public void AddKnownSerializer(Type type, IBsonSerializer serializer, bool allowPropagation = true)
public void AddKnownSerializersFromChild(KnownSerializersNode child)
{
foreach (var type in child.KnownSerializers.Keys)
foreach (var serializer in child.KnownSerializers[type])
{
AddKnownSerializer(type, serializer);
}
}

public void AddKnownSerializer(Type type, IBsonSerializer serializer)
{
if (!_knownSerializers.TryGetValue(type, out var set))
{
Expand All @@ -51,15 +61,35 @@ public void AddKnownSerializer(Type type, IBsonSerializer serializer, bool allow
}

set.Add(serializer);
}

public void SetKnownSerializerForType(Type type, IBsonSerializer serializer)
{
if (serializer.ValueType != type)
{
throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expected type {type}.");
}

_knownSerializers[type] = new HashSet<IBsonSerializer> { serializer };
}

if (allowPropagation && ShouldPropagateKnownSerializerToParent())
public void SetNodeSerializer(IBsonSerializer serializer)
{
if (serializer.ValueType != _expression.Type)
{
_parent.AddKnownSerializer(type, serializer);
throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression type {_expression.Type}.");
}

_nodeSerializer = serializer;
}

public HashSet<IBsonSerializer> GetPossibleSerializers(Type type)
{
if (_nodeSerializer != null && _nodeSerializer.ValueType == type)
{
return new HashSet<IBsonSerializer> { _nodeSerializer };
}

var possibleSerializers = GetPossibleSerializersAtThisLevel(type);
if (possibleSerializers.Count > 0)
{
Expand Down Expand Up @@ -115,20 +145,5 @@ private HashSet<IBsonSerializer> GetPossibleSerializersAtThisLevel(Type type)

return possibleSerializers;
}

private bool ShouldPropagateKnownSerializerToParent()
{
if (_parent == null)
{
return false;
}

return _parent.Expression.NodeType switch
{
ExpressionType.MemberInit => false,
ExpressionType.New => false,
_ => true
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ public void Add(Expression expression, KnownSerializersNode knownSerializers)
_registry.Add(expression, knownSerializers);
}

public void AddKnownSerializer(Expression expression, IBsonSerializer knownSerializer, bool allowPropagation = true)
public void SetNodeSerializer(Expression expression, IBsonSerializer nodeSerializer)
{
if (knownSerializer.ValueType != expression.Type)
if (nodeSerializer.ValueType != expression.Type)
{
throw new ArgumentException($"Serializer value type {knownSerializer.ValueType} does not match expresion type {expression.Type}.", nameof(knownSerializer));
throw new ArgumentException($"Serializer value type {nodeSerializer.ValueType} does not match expresion type {expression.Type}.", nameof(nodeSerializer));
}

if (!_registry.TryGetValue(expression, out var knownSerializers))
{
throw new InvalidOperationException("KnownSerializersNode does not exist yet for expression: {expression}.");
}

knownSerializers.AddKnownSerializer(expression.Type, knownSerializer, allowPropagation);
knownSerializers.SetNodeSerializer(nodeSerializer);
}

public IBsonSerializer GetSerializer(Expression expression, IBsonSerializer defaultSerializer = null)
Expand All @@ -74,6 +74,18 @@ public IBsonSerializer GetSerializer(Expression expression, Type type, IBsonSeri
};
}

public IBsonSerializer GetSerializerAtThisLevel(Expression expression)
{
var expressionType = expression is LambdaExpression lambdaExpression ? lambdaExpression.ReturnType : expression.Type;
return GetSerializerAtThisLevel(expression, expressionType);
}

public IBsonSerializer GetSerializerAtThisLevel(Expression expression, Type type)
{
var possibleSerializers = _registry.TryGetValue(expression, out var knownSerializers) ? knownSerializers.GetPossibleSerializers(type) : new HashSet<IBsonSerializer>();
return possibleSerializers.Count == 1 ? possibleSerializers.Single() : null;
}

private IBsonSerializer LookupSerializer(Expression expression, Type type)
{
if (type.IsConstructedGenericType &&
Expand Down
Loading

0 comments on commit ca00c14

Please sign in to comment.