diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index c3ebf3d4d7..bbb0e0e446 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -4363,7 +4363,7 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi SqlParameter sqlParameter = rpc.userParams[index]; Debug.Assert(sqlParameter != null, "sqlParameter should not be null."); - if (sqlParameter.ParameterNameFixed.Equals(parameterName, StringComparison.Ordinal)) + if (SqlParameter.ParameterNamesEqual(sqlParameter.ParameterName,parameterName,StringComparison.Ordinal)) { Debug.Assert(sqlParameter.CipherMetadata == null, "param.CipherMetadata should be null."); sqlParameter.HasReceivedMetadata = true; @@ -5457,7 +5457,7 @@ internal void OnReturnValue(SqlReturnValue rec, TdsParserStateObject stateObj) { if (rec.tdsType != TdsEnums.SQLBIGVARBINARY) { - throw SQL.InvalidDataTypeForEncryptedParameter(thisParam.ParameterNameFixed, rec.tdsType, TdsEnums.SQLBIGVARBINARY); + throw SQL.InvalidDataTypeForEncryptedParameter(thisParam.GetPrefixedParameterName(), rec.tdsType, TdsEnums.SQLBIGVARBINARY); } // Decrypt the ciphertext @@ -5487,7 +5487,7 @@ internal void OnReturnValue(SqlReturnValue rec, TdsParserStateObject stateObj) } catch (Exception e) { - throw SQL.ParamDecryptionFailed(thisParam.ParameterNameFixed, null, e); + throw SQL.ParamDecryptionFailed(thisParam.GetPrefixedParameterName(), null, e); } } else @@ -5628,7 +5628,11 @@ private SqlParameter GetParameterForOutputValueExtraction(SqlParameterCollection { thisParam = parameters[i]; // searching for Output or InputOutput or ReturnValue with matching name - if (thisParam.Direction != ParameterDirection.Input && thisParam.Direction != ParameterDirection.ReturnValue && paramName == thisParam.ParameterNameFixed) + if ( + thisParam.Direction != ParameterDirection.Input && + thisParam.Direction != ParameterDirection.ReturnValue && + SqlParameter.ParameterNamesEqual(paramName, thisParam.ParameterName,StringComparison.Ordinal) + ) { foundParam = true; break; // found it @@ -5999,11 +6003,11 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Find the return value parameter (if any). SqlParameter returnValueParameter = null; - foreach (SqlParameter parameter in parameters) + foreach (SqlParameter param in parameters) { - if (parameter.Direction == ParameterDirection.ReturnValue) + if (param.Direction == ParameterDirection.ReturnValue) { - returnValueParameter = parameter; + returnValueParameter = param; break; } } @@ -6012,7 +6016,8 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // EXEC @returnValue = moduleName [parameters] if (returnValueParameter != null) { - execStatement.AppendFormat(@"{0}=", returnValueParameter.ParameterNameFixed); + SqlParameter.AppendPrefixedParameterName(execStatement, returnValueParameter.ParameterName); + execStatement.Append('='); } execStatement.Append(ParseAndQuoteIdentifier(storedProcedureName, false)); @@ -6023,6 +6028,7 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Append the first parameter int index = 0; int count = parameters.Count; + SqlParameter parameter; if (count > 0) { // Skip the return value parameters. @@ -6033,15 +6039,19 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto if (index < count) { + parameter = parameters[index]; // Possibility of a SQL Injection issue through parameter names and how to construct valid identifier for parameters. // Since the parameters comes from application itself, there should not be a security vulnerability. // Also since the query is not executed, but only analyzed there is no possibility for elevation of privilege, but only for // incorrect results which would only affect the user that attempts the injection. - execStatement.AppendFormat(@" {0}={0}", parameters[index].ParameterNameFixed); + execStatement.Append(' '); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); + execStatement.Append('='); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); // InputOutput and Output parameters need to be marked as such. - if (parameters[index].Direction == ParameterDirection.Output || - parameters[index].Direction == ParameterDirection.InputOutput) + if (parameter.Direction == ParameterDirection.Output || + parameter.Direction == ParameterDirection.InputOutput) { execStatement.AppendFormat(@" OUTPUT"); } @@ -6054,14 +6064,18 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Append the rest of parameters for (; index < count; index++) { - if (parameters[index].Direction != ParameterDirection.ReturnValue) + parameter = parameters[index]; + if (parameter.Direction != ParameterDirection.ReturnValue) { - execStatement.AppendFormat(@", {0}={0}", parameters[index].ParameterNameFixed); + execStatement.Append(", "); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); + execStatement.Append('='); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); // InputOutput and Output parameters need to be marked as such. if ( - parameters[index].Direction == ParameterDirection.Output || - parameters[index].Direction == ParameterDirection.InputOutput + parameter.Direction == ParameterDirection.Output || + parameter.Direction == ParameterDirection.InputOutput ) { execStatement.AppendFormat(@" OUTPUT"); @@ -6095,9 +6109,11 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete // add our separator for the ith parameter if (fAddSeparator) + { paramList.Append(','); + } - paramList.Append(sqlParam.ParameterNameFixed); + SqlParameter.AppendPrefixedParameterName(paramList, sqlParam.ParameterName); MetaType mt = sqlParam.InternalMetaType; @@ -6120,7 +6136,7 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete string typeName = sqlParam.TypeName; if (string.IsNullOrEmpty(typeName)) { - throw SQL.MustSetTypeNameForParam(mt.TypeName, sqlParam.ParameterNameFixed); + throw SQL.MustSetTypeNameForParam(mt.TypeName, sqlParam.GetPrefixedParameterName()); } paramList.Append(ParseAndQuoteIdentifier(typeName, false /* is not UdtTypeName*/)); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 4998251691..e3de792213 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -9312,7 +9312,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet } } - WriteParameterName(param.ParameterNameFixed, stateObj, isAnonymous); + WriteParameterName(param.ParameterName, stateObj, isAnonymous); // Write parameter status stateObj.WriteByte(options); @@ -9833,17 +9833,34 @@ private void ExecuteFlushTaskCallback(Task tsk, TdsParserStateObject stateObj, T } } - - private void WriteParameterName(string parameterName, TdsParserStateObject stateObj, bool isAnonymous) + /// + /// Will check the parameter name for the required @ prefix and then write the correct prefixed + /// form and correct character length to the output buffer + /// + private void WriteParameterName(string rawParameterName, TdsParserStateObject stateObj, bool isAnonymous) { // paramLen // paramName - if (!isAnonymous && !string.IsNullOrEmpty(parameterName)) + if (!isAnonymous && !string.IsNullOrEmpty(rawParameterName)) { - Debug.Assert(parameterName.Length <= 0xff, "parameter name can only be 255 bytes, shouldn't get to TdsParser!"); - int tempLen = parameterName.Length & 0xff; - stateObj.WriteByte((byte)tempLen); - WriteString(parameterName, tempLen, 0, stateObj); + int nameLength = rawParameterName.Length; + int totalLength = nameLength; + bool writePrefix = false; + if (nameLength > 0) + { + if (rawParameterName[0] != '@') + { + writePrefix = true; + totalLength += 1; + } + } + Debug.Assert(totalLength <= 0xff, "parameter name can only be 255 bytes, shouldn't get to TdsParser!"); + stateObj.WriteByte((byte)(totalLength & 0xFF)); + if (writePrefix) + { + WriteString("@", 1, 0, stateObj); + } + WriteString(rawParameterName, nameLength, 0, stateObj); } else { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 5aa96c540a..d5efaacc93 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -4881,7 +4881,7 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi SqlParameter sqlParameter = rpc.userParams[index]; Debug.Assert(sqlParameter != null, "sqlParameter should not be null."); - if (sqlParameter.ParameterNameFixed.Equals(parameterName, StringComparison.Ordinal)) + if (SqlParameter.ParameterNamesEqual(sqlParameter.ParameterName, parameterName, StringComparison.Ordinal)) { Debug.Assert(sqlParameter.CipherMetadata == null, "param.CipherMetadata should be null."); sqlParameter.HasReceivedMetadata = true; @@ -6239,7 +6239,7 @@ internal void OnReturnValue(SqlReturnValue rec, TdsParserStateObject stateObj) { if (rec.tdsType != TdsEnums.SQLBIGVARBINARY) { - throw SQL.InvalidDataTypeForEncryptedParameter(thisParam.ParameterNameFixed, rec.tdsType, TdsEnums.SQLBIGVARBINARY); + throw SQL.InvalidDataTypeForEncryptedParameter(thisParam.GetPrefixedParameterName(), rec.tdsType, TdsEnums.SQLBIGVARBINARY); } // Decrypt the ciphertext @@ -6269,7 +6269,7 @@ internal void OnReturnValue(SqlReturnValue rec, TdsParserStateObject stateObj) } catch (Exception e) { - throw SQL.ParamDecryptionFailed(thisParam.ParameterNameFixed, null, e); + throw SQL.ParamDecryptionFailed(thisParam.GetPrefixedParameterName(), null, e); } } else @@ -6462,7 +6462,11 @@ private SqlParameter GetParameterForOutputValueExtraction(SqlParameterCollection { thisParam = parameters[i]; // searching for Output or InputOutput or ReturnValue with matching name - if (thisParam.Direction != ParameterDirection.Input && thisParam.Direction != ParameterDirection.ReturnValue && paramName == thisParam.ParameterNameFixed) + if ( + thisParam.Direction != ParameterDirection.Input && + thisParam.Direction != ParameterDirection.ReturnValue && + SqlParameter.ParameterNamesEqual(paramName, thisParam.ParameterName,StringComparison.Ordinal) + ) { foundParam = true; break; // found it @@ -6850,11 +6854,11 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Find the return value parameter (if any). SqlParameter returnValueParameter = null; - foreach (SqlParameter parameter in parameters) + foreach (SqlParameter param in parameters) { - if (parameter.Direction == ParameterDirection.ReturnValue) + if (param.Direction == ParameterDirection.ReturnValue) { - returnValueParameter = parameter; + returnValueParameter = param; break; } } @@ -6863,7 +6867,8 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // EXEC @returnValue = moduleName [parameters] if (returnValueParameter != null) { - execStatement.AppendFormat(@"{0}=", returnValueParameter.ParameterNameFixed); + SqlParameter.AppendPrefixedParameterName(execStatement, returnValueParameter.ParameterName); + execStatement.Append('='); } execStatement.Append(ParseAndQuoteIdentifier(storedProcedureName, false)); @@ -6874,6 +6879,7 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Append the first parameter int index = 0; int count = parameters.Count; + SqlParameter parameter; if (count > 0) { // Skip the return value parameters. @@ -6884,16 +6890,20 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto if (index < count) { + parameter = parameters[index]; // Possibility of a SQL Injection issue through parameter names and how to construct valid identifier for parameters. // Since the parameters comes from application itself, there should not be a security vulnerability. // Also since the query is not executed, but only analyzed there is no possibility for elevation of priviledge, but only for // incorrect results which would only affect the user that attempts the injection. - execStatement.AppendFormat(@" {0}={0}", parameters[index].ParameterNameFixed); + execStatement.Append(' '); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); + execStatement.Append('='); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); // InputOutput and Output parameters need to be marked as such. if ( - parameters[index].Direction == ParameterDirection.Output || - parameters[index].Direction == ParameterDirection.InputOutput + parameter.Direction == ParameterDirection.Output || + parameter.Direction == ParameterDirection.InputOutput ) { execStatement.AppendFormat(@" OUTPUT"); @@ -6907,14 +6917,18 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto // Append the rest of parameters for (; index < count; index++) { - if (parameters[index].Direction != ParameterDirection.ReturnValue) + parameter = parameters[index]; + if (parameter.Direction != ParameterDirection.ReturnValue) { - execStatement.AppendFormat(@", {0}={0}", parameters[index].ParameterNameFixed); + execStatement.Append(", "); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); + execStatement.Append('='); + SqlParameter.AppendPrefixedParameterName(execStatement, parameter.ParameterName); // InputOutput and Output parameters need to be marked as such. if ( - parameters[index].Direction == ParameterDirection.Output || - parameters[index].Direction == ParameterDirection.InputOutput + parameter.Direction == ParameterDirection.Output || + parameter.Direction == ParameterDirection.InputOutput ) { execStatement.AppendFormat(@" OUTPUT"); @@ -6946,9 +6960,10 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete // add our separator for the ith parameter if (fAddSeparator) + { paramList.Append(','); - - paramList.Append(sqlParam.ParameterNameFixed); + } + SqlParameter.AppendPrefixedParameterName(paramList, sqlParam.ParameterName); MetaType mt = sqlParam.InternalMetaType; @@ -6957,7 +6972,7 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete // paragraph above doesn't seem to be correct. Server won't find the type // if we don't provide a fully qualified name - paramList.Append(" "); + paramList.Append(' '); if (mt.SqlDbType == SqlDbType.Udt) { string fullTypeName = sqlParam.UdtTypeName; @@ -6971,7 +6986,7 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete string typeName = sqlParam.TypeName; if (ADP.IsEmpty(typeName)) { - throw SQL.MustSetTypeNameForParam(mt.TypeName, sqlParam.ParameterNameFixed); + throw SQL.MustSetTypeNameForParam(mt.TypeName, sqlParam.GetPrefixedParameterName()); } paramList.Append(ParseAndQuoteIdentifier(typeName, false /* is not UdtTypeName*/)); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 2bb64c1986..c407e1c6e9 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -10110,7 +10110,7 @@ internal Task TdsExecuteRPC(SqlCommand cmd, _SqlRPC[] rpcArray, int timeout, boo } } - WriteParameterName(param.ParameterNameFixed, stateObj, enableOptimizedParameterBinding); + WriteParameterName(param.ParameterName, stateObj, enableOptimizedParameterBinding); // Write parameter status stateObj.WriteByte(options); @@ -10791,17 +10791,34 @@ private void ExecuteFlushTaskCallback(Task tsk, TdsParserStateObject stateObj, T } } - - private void WriteParameterName(string parameterName, TdsParserStateObject stateObj, bool isAnonymous) + /// + /// Will check the parameter name for the required @ prefix and then write the correct prefixed + /// form and correct character length to the output buffer + /// + private void WriteParameterName(string rawParameterName, TdsParserStateObject stateObj, bool isAnonymous) { // paramLen // paramName - if (!isAnonymous && !string.IsNullOrEmpty(parameterName)) + if (!isAnonymous && !string.IsNullOrEmpty(rawParameterName)) { - Debug.Assert(parameterName.Length <= 0xff, "parameter name can only be 255 bytes, shouldn't get to TdsParser!"); - int tempLen = parameterName.Length & 0xff; - stateObj.WriteByte((byte)tempLen); - WriteString(parameterName, tempLen, 0, stateObj); + int nameLength = rawParameterName.Length; + int totalLength = nameLength; + bool writePrefix = false; + if (nameLength > 0) + { + if (rawParameterName[0] != '@') + { + writePrefix = true; + totalLength += 1; + } + } + Debug.Assert(totalLength <= 0xff, "parameter name can only be 255 bytes, shouldn't get to TdsParser!"); + stateObj.WriteByte((byte)(totalLength & 0xFF)); + if (writePrefix) + { + WriteString("@", 1, 0, stateObj); + } + WriteString(rawParameterName, nameLength, 0, stateObj); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs index 3eb0ce5ba2..ee3a09c17a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -989,17 +989,65 @@ internal bool ParameterIsSqlType set => SetFlag(SqlParameterFlags.IsSqlParameterSqlType, value); } - internal string ParameterNameFixed + internal string GetPrefixedParameterName() { - get + string parameterName = ParameterName; + if ((parameterName.Length > 0) && (parameterName[0] != '@')) + { + parameterName = "@" + parameterName; + } + Debug.Assert(parameterName.Length <= TdsEnums.MAX_PARAMETER_NAME_LENGTH, "parameter name too long"); + return parameterName; + } + + /// + /// Checks the parameter name for the @ prefix and appends it if it is missing, then apends the parameter name + /// + /// + /// + internal static void AppendPrefixedParameterName(StringBuilder builder, string rawParameterName) + { + if (!string.IsNullOrEmpty(rawParameterName)) + { + if (rawParameterName[0] != '@') + { + builder.Append('@'); + } + builder.Append(rawParameterName); + } + } + + /// + /// Compares the two input names for equality discounting the @ prefix on either or both arguments + /// + /// + internal static bool ParameterNamesEqual(string lhs, string rhs, StringComparison comparison = StringComparison.Ordinal) + { + if (!string.IsNullOrEmpty(lhs)) { - string parameterName = ParameterName; - if ((parameterName.Length > 0) && (parameterName[0] != '@')) + if (string.IsNullOrEmpty(rhs)) + { + return false; + } + else { - parameterName = "@" + parameterName; + ReadOnlySpan lhsSpan = lhs.AsSpan(); + if (lhs[0] == '@') + { + lhsSpan = lhsSpan.Slice(1); + } + ReadOnlySpan rhsSpan = rhs.AsSpan(); + if (rhsSpan[0] == '@') + { + rhsSpan = rhsSpan.Slice(1); + } + return MemoryExtensions.Equals(lhsSpan, rhsSpan, comparison); } - Debug.Assert(parameterName.Length <= TdsEnums.MAX_PARAMETER_NAME_LENGTH, "parameter name too long"); - return parameterName; + } + else + { + // lhs is null or empty so equality is only possible if the rhs is the same + return string.IsNullOrEmpty(rhs); } } @@ -1804,7 +1852,7 @@ internal SmiParameterMetaData MetaDataForSmi(out ParameterPeekAheadValue peekAhe SqlDbType.Structured == mt.SqlDbType, fields, extendedProperties, - ParameterNameFixed, + GetPrefixedParameterName(), typeSpecificNamePart1, typeSpecificNamePart2, typeSpecificNamePart3, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlQueryMetadataCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlQueryMetadataCache.cs index 012065867b..5475eb5a0c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlQueryMetadataCache.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlQueryMetadataCache.cs @@ -73,7 +73,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand) // Iterate over all the parameters and try to get their cipher MD. foreach (SqlParameter param in sqlCommand.Parameters) { - bool found = cipherMetadataDictionary.TryGetValue(param.ParameterNameFixed, out SqlCipherMetadata paramCiperMetadata); + bool found = cipherMetadataDictionary.TryGetValue(param.GetPrefixedParameterName(), out SqlCipherMetadata paramCiperMetadata); // If we failed to identify the encryption for a specific parameter, clear up the cipher MD of all parameters and exit. if (!found) @@ -211,7 +211,7 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu // Cached cipher MD should never have an initialized algorithm since this would contain the key. Debug.Assert(cipherMdCopy is null || !cipherMdCopy.IsAlgorithmInitialized()); - cipherMetadataDictionary.Add(param.ParameterNameFixed, cipherMdCopy); + cipherMetadataDictionary.Add(param.GetPrefixedParameterName(), cipherMdCopy); } // If the size of the cache exceeds the threshold, set that we are in trimming and trim the cache accordingly.