Skip to content

Commit

Permalink
FIX | Removing BinaryFormatter from NetFx (dotnet#869)
Browse files Browse the repository at this point in the history
* Removing BinaryFormatter from NetFx

* review comments

* fix version typo

* remove extra line

* Reverted SqlException Test

* review comments

* Review comment

* Desrialize

* addressing review comments

* Fix exception in deserialization (#1)

* review comments

* add extra line to the end of strings designer

* end of line

Co-authored-by: jJRahnama <jrahnama@simba.com>
Co-authored-by: Karina Zhou <v-jizho2@microsoft.com>
  • Loading branch information
3 people authored Feb 18, 2021
1 parent b5d7bb6 commit 25cde90
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.Remoting;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Runtime.Versioning;
using System.Security.Permissions;
using System.Text;
Expand Down Expand Up @@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair)
// END EventContextPair private class.
// ----------------------------------------

// ----------------------------------------
// Private class for restricting allowed types from deserialization.
// ----------------------------------------

private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder
//-----------------------------------------------
// Private Class to add ObjRef as DataContract
//-----------------------------------------------
[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
[DataContract]
private class SqlClientObjRef
{
public override Type BindToType(string assemblyName, string typeName)
[DataMember]
private static ObjRef s_sqlObjRef;
internal static IRemotingTypeInfo _typeInfo;

private SqlClientObjRef() { }

public SqlClientObjRef(SqlDependencyProcessDispatcher dispatcher) : base()
{
// Deserializing an unexpected type can inject objects with malicious side effects.
// If the type is unexpected, throw an exception to stop deserialization.
if (typeName == nameof(SqlDependencyProcessDispatcher))
{
return typeof(SqlDependencyProcessDispatcher);
}
else
{
throw new ArgumentException("Unexpected type", nameof(typeName));
}
s_sqlObjRef = RemotingServices.Marshal(dispatcher);
_typeInfo = s_sqlObjRef.TypeInfo;
}

internal static bool CanCastToSqlDependencyProcessDispatcher()
{
return _typeInfo.CanCastTo(typeof(SqlDependencyProcessDispatcher), s_sqlObjRef);
}

internal ObjRef GetObjRef()
{
return s_sqlObjRef;
}

}
// ----------------------------------------
// END SqlDependencyProcessDispatcherSerializationBinder private class.
// ----------------------------------------
// ------------------------------------------
// End SqlClientObjRef private class.
// -------------------------------------------

// ----------------
// Instance members
Expand Down Expand Up @@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName)
private static readonly string _typeName = (typeof(SqlDependencyProcessDispatcher)).FullName;

// -----------
// BID members
// EventSource members
// -----------


private readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
private static int _objectTypeCount; // EventSource Counter
internal int ObjectID
Expand All @@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency
}

/// <include file='..\..\..\..\..\..\..\doc\snippets\Microsoft.Data.SqlClient\SqlDependency.xml' path='docs/members[@name="SqlDependency"]/ctorCommandOptionsTimeout/*' />
[System.Security.Permissions.HostProtectionAttribute(ExternalThreading = true)]
[HostProtection(ExternalThreading = true)]
public SqlDependency(SqlCommand command, string options, int timeout)
{
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency|DEP> {0}, options: '{1}', timeout: '{2}'", ObjectID, options, timeout);
Expand Down Expand Up @@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher()
_processDispatcher = dependency.SingletonProcessDispatcher; // Set to static instance.

// Serialize and set in native.
ObjRef objRef = GetObjRef(_processDispatcher);
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream();
GetSerializedObject(objRef, formatter, stream);
SNINativeMethodWrapper.SetData(stream.GetBuffer()); // Native will be forced to synchronize and not overwrite.
using (MemoryStream stream = new MemoryStream())
{
SqlClientObjRef objRef = new SqlClientObjRef(_processDispatcher);
DataContractSerializer serializer = new DataContractSerializer(objRef.GetType());
GetSerializedObject(objRef, serializer, stream);
SNINativeMethodWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite.
}
}
else
{
Expand All @@ -628,37 +638,39 @@ private static void ObtainProcessDispatcher()
#if DEBUG // Possibly expensive, limit to debug.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName);
#endif
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream(nativeStorage);
_processDispatcher = GetDeserializedObject(formatter, stream); // Deserialize and set for appdomain.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
using (MemoryStream stream = new MemoryStream(nativeStorage))
{
DataContractSerializer serializer = new DataContractSerializer(typeof(SqlClientObjRef));
if (SqlClientObjRef.CanCastToSqlDependencyProcessDispatcher())
{
// Deserialize and set for appdomain.
_processDispatcher = GetDeserializedObject(serializer, stream);
}
else
{
throw new ArgumentException(Strings.SqlDependency_UnexpectedValueOnDeserialize);
}
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
}
}
}

// ---------------------------------------------------------
// Static security asserted methods - limit scope of assert.
// ---------------------------------------------------------

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
private static ObjRef GetObjRef(SqlDependencyProcessDispatcher _processDispatcher)
{
return RemotingServices.Marshal(_processDispatcher);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static void GetSerializedObject(ObjRef objRef, BinaryFormatter formatter, MemoryStream stream)
private static void GetSerializedObject(SqlClientObjRef objRef, DataContractSerializer serializer, MemoryStream stream)
{
formatter.Serialize(stream, objRef);
serializer.WriteObject(stream, objRef);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static SqlDependencyProcessDispatcher GetDeserializedObject(BinaryFormatter formatter, MemoryStream stream)
private static SqlDependencyProcessDispatcher GetDeserializedObject(DataContractSerializer serializer, MemoryStream stream)
{
// Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher.
formatter.Binder = new SqlDependencyProcessDispatcherSerializationBinder();
object result = formatter.Deserialize(stream);
Debug.Assert(result.GetType() == typeof(SqlDependencyProcessDispatcher), "Unexpected type stored in native!");
return (SqlDependencyProcessDispatcher)result;
object refResult = serializer.ReadObject(stream);
var result = RemotingServices.Unmarshal((refResult as SqlClientObjRef).GetObjRef());
return result as SqlDependencyProcessDispatcher;
}

// -------------------------
Expand Down Expand Up @@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd)
{
if (cmd != null)
{
// Don't bother with BID if command null.
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency.AddCommandInternal|DEP> {0}, SqlCommand: {1}", ObjectID, cmd.ObjectID);
try
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -4602,4 +4602,7 @@
<data name="Azure_RetryFailure" xml:space="preserve">
<value>Failed after 5 retries.</value>
</data>
</root>
<data name="SqlDependency_UnexpectedValueOnDeserialize" xml:space="preserve">
<value>Unexpected type detected on deserialize.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ public void SerializationTest()
Assert.Equal(e.StackTrace, sqlEx.StackTrace);
}


[Fact]
[ActiveIssue("12161", TestPlatforms.AnyUnix)]
public static void SqlExcpetionSerializationTest()
Expand Down

0 comments on commit 25cde90

Please sign in to comment.