Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX | Removing BinaryFormatter from NetFx #869

Merged
merged 14 commits into from
Feb 18, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.Remoting;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Runtime.Versioning;
using System.Security.Permissions;
using System.Text;
Expand Down Expand Up @@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair)
// END EventContextPair private class.
// ----------------------------------------

JRahnama marked this conversation as resolved.
Show resolved Hide resolved
// ----------------------------------------
// Private class for restricting allowed types from deserialization.
// ----------------------------------------

private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder
//-----------------------------------------------
// Private Class to add ObjRef as DataContract
//-----------------------------------------------
[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
[DataContract]
private class SqlClientObjRef
{
public override Type BindToType(string assemblyName, string typeName)
[DataMember]
private static ObjRef s_sqlObjRef;
internal static IRemotingTypeInfo _typeInfo;

private SqlClientObjRef() { }

public SqlClientObjRef(SqlDependencyProcessDispatcher dispatcher) : base()
{
// Deserializing an unexpected type can inject objects with malicious side effects.
// If the type is unexpected, throw an exception to stop deserialization.
if (typeName == nameof(SqlDependencyProcessDispatcher))
{
return typeof(SqlDependencyProcessDispatcher);
}
else
{
throw new ArgumentException("Unexpected type", nameof(typeName));
}
s_sqlObjRef = RemotingServices.Marshal(dispatcher);
_typeInfo = s_sqlObjRef.TypeInfo;
}

internal static bool CanCastToSqlDependencyProcessDispatcher()
{
return _typeInfo.CanCastTo(typeof(SqlDependencyProcessDispatcher), s_sqlObjRef);
}

internal ObjRef GetObjRef()
{
return s_sqlObjRef;
}

}
// ----------------------------------------
// END SqlDependencyProcessDispatcherSerializationBinder private class.
// ----------------------------------------
// ------------------------------------------
// End SqlClientObjRef private class.
// -------------------------------------------

// ----------------
// Instance members
Expand Down Expand Up @@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName)
private static readonly string _typeName = (typeof(SqlDependencyProcessDispatcher)).FullName;

// -----------
// BID members
// EventSource members
// -----------


private readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
private static int _objectTypeCount; // EventSource Counter
internal int ObjectID
Expand All @@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency
}

/// <include file='..\..\..\..\..\..\..\doc\snippets\Microsoft.Data.SqlClient\SqlDependency.xml' path='docs/members[@name="SqlDependency"]/ctorCommandOptionsTimeout/*' />
[System.Security.Permissions.HostProtectionAttribute(ExternalThreading = true)]
[HostProtection(ExternalThreading = true)]
public SqlDependency(SqlCommand command, string options, int timeout)
{
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency|DEP> {0}, options: '{1}', timeout: '{2}'", ObjectID, options, timeout);
Expand Down Expand Up @@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher()
_processDispatcher = dependency.SingletonProcessDispatcher; // Set to static instance.

// Serialize and set in native.
ObjRef objRef = GetObjRef(_processDispatcher);
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream();
GetSerializedObject(objRef, formatter, stream);
SNINativeMethodWrapper.SetData(stream.GetBuffer()); // Native will be forced to synchronize and not overwrite.
using (MemoryStream stream = new MemoryStream())
{
SqlClientObjRef objRef = new SqlClientObjRef(_processDispatcher);
DataContractSerializer serializer = new DataContractSerializer(objRef.GetType());
GetSerializedObject(objRef, serializer, stream);
SNINativeMethodWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite.
}
}
else
{
Expand All @@ -628,37 +638,39 @@ private static void ObtainProcessDispatcher()
#if DEBUG // Possibly expensive, limit to debug.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName);
#endif
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream(nativeStorage);
_processDispatcher = GetDeserializedObject(formatter, stream); // Deserialize and set for appdomain.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
using (MemoryStream stream = new MemoryStream(nativeStorage))
{
DataContractSerializer serializer = new DataContractSerializer(typeof(SqlClientObjRef));
if (SqlClientObjRef.CanCastToSqlDependencyProcessDispatcher())
{
// Deserialize and set for appdomain.
_processDispatcher = GetDeserializedObject(serializer, stream);
}
else
{
throw new ArgumentException(Strings.SqlDependency_UnexpectedValueOnDeserialize);
}
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
}
}
}

// ---------------------------------------------------------
// Static security asserted methods - limit scope of assert.
// ---------------------------------------------------------

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
private static ObjRef GetObjRef(SqlDependencyProcessDispatcher _processDispatcher)
{
return RemotingServices.Marshal(_processDispatcher);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static void GetSerializedObject(ObjRef objRef, BinaryFormatter formatter, MemoryStream stream)
private static void GetSerializedObject(SqlClientObjRef objRef, DataContractSerializer serializer, MemoryStream stream)
{
formatter.Serialize(stream, objRef);
serializer.WriteObject(stream, objRef);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static SqlDependencyProcessDispatcher GetDeserializedObject(BinaryFormatter formatter, MemoryStream stream)
private static SqlDependencyProcessDispatcher GetDeserializedObject(DataContractSerializer serializer, MemoryStream stream)
{
// Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher.
formatter.Binder = new SqlDependencyProcessDispatcherSerializationBinder();
object result = formatter.Deserialize(stream);
Debug.Assert(result.GetType() == typeof(SqlDependencyProcessDispatcher), "Unexpected type stored in native!");
return (SqlDependencyProcessDispatcher)result;
object refResult = serializer.ReadObject(stream);
var result = RemotingServices.Unmarshal((refResult as SqlClientObjRef).GetObjRef());
return result as SqlDependencyProcessDispatcher;
}

// -------------------------
Expand Down Expand Up @@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd)
{
if (cmd != null)
{
// Don't bother with BID if command null.
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency.AddCommandInternal|DEP> {0}, SqlCommand: {1}", ObjectID, cmd.ObjectID);
try
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -4602,4 +4602,7 @@
<data name="Azure_RetryFailure" xml:space="preserve">
<value>Failed after 5 retries.</value>
</data>
</root>
<data name="SqlDependency_UnexpectedValueOnDeserialize" xml:space="preserve">
<value>Unexpected type detected on deserialize.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ public void SerializationTest()
Assert.Equal(e.StackTrace, sqlEx.StackTrace);
}


[Fact]
[ActiveIssue("12161", TestPlatforms.AnyUnix)]
public static void SqlExcpetionSerializationTest()
Expand Down