Skip to content

Commit

Permalink
Preserve global variable values between multiple commands
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Mar 27, 2022
1 parent eaab8e6 commit 43fe8b7
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 5 deletions.
20 changes: 20 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,25 @@ select @param1
CollectionAssert.AreEqual(new[] { "3", "4", "end" }, results);
}
}

[TestMethod]
public void GlobalVariablesPreservedBetweenCommands()
{
using (var con = new Sql4CdsConnection(_localDataSource.Values.ToArray()))
using (var cmd = con.CreateCommand())
{
cmd.CommandText = "INSERT INTO account (name) VALUES ('test')";
cmd.ExecuteNonQuery();

cmd.CommandText = "SELECT @@IDENTITY";
var accountId = (SqlEntityReference)cmd.ExecuteScalar();

cmd.CommandText = "SELECT @@ROWCOUNT";
var rowCount = (int)cmd.ExecuteScalar();

Assert.AreEqual("test", _context.Data["account"][accountId.Id].GetAttributeValue<string>("name"));
Assert.AreEqual(1, rowCount);
}
}
}
}
1 change: 1 addition & 0 deletions MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
var cmd = con.CreateCommand();
cmd.CommandTimeout = (int)TimeSpan.FromMinutes(2).TotalSeconds;
cmd.CommandText = CommandText;
cmd.StatementCompleted += (_, e) => _connection.GlobalVariableValues["@@ROWCOUNT"] = (SqlInt32) e.RecordCount;

if (Parameters.Count > 0)
{
Expand Down
21 changes: 20 additions & 1 deletion MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using Microsoft.Crm.Sdk.Messages;
using MarkMpn.Sql4Cds.Engine.ExecutionPlan;
using System.Threading;
using Microsoft.SqlServer.TransactSql.ScriptDom;
using System.Data.SqlTypes;
#if NETCOREAPP
using Microsoft.PowerPlatform.Dataverse.Client;
#else
Expand All @@ -24,6 +26,8 @@ public class Sql4CdsConnection : DbConnection
{
private readonly Dictionary<string, DataSource> _dataSources;
private readonly ChangeDatabaseOptionsWrapper _options;
private readonly Dictionary<string, DataTypeReference> _globalVariableTypes;
private readonly Dictionary<string, object> _globalVariableValues;

/// <summary>
/// Creates a new <see cref="Sql4CdsConnection"/> using the specified XRM connection string
Expand Down Expand Up @@ -58,6 +62,17 @@ public Sql4CdsConnection(params DataSource[] dataSources)

_dataSources = dataSources.ToDictionary(ds => ds.Name, StringComparer.OrdinalIgnoreCase);
_options = new ChangeDatabaseOptionsWrapper(this, options);

_globalVariableTypes = new Dictionary<string, DataTypeReference>(StringComparer.OrdinalIgnoreCase)
{
["@@IDENTITY"] = typeof(SqlEntityReference).ToSqlType(),
["@@ROWCOUNT"] = typeof(SqlInt32).ToSqlType()
};
_globalVariableValues = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase)
{
["@@IDENTITY"] = SqlEntityReference.Null,
["@@ROWCOUNT"] = (SqlInt32)0
};
}

private static IOrganizationService Connect(string connectionString)
Expand Down Expand Up @@ -196,6 +211,10 @@ public bool QuotedIdentifiers
set => _options.QuotedIdentifiers = value;
}

internal Dictionary<string, DataTypeReference> GlobalVariableTypes => _globalVariableTypes;

internal Dictionary<string, object> GlobalVariableValues => _globalVariableValues;

/// <summary>
/// Triggered before one or more records are about to be deleted
/// </summary>
Expand Down Expand Up @@ -323,7 +342,7 @@ public override void Open()
{
}

protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
protected override DbTransaction BeginDbTransaction(System.Data.IsolationLevel isolationLevel)
{
throw new NotSupportedException("Transactions are not supported");
}
Expand Down
14 changes: 10 additions & 4 deletions MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options,
_parameterTypes = ((Sql4CdsParameterCollection)command.Parameters).GetParameterTypes();
_parameterValues = ((Sql4CdsParameterCollection)command.Parameters).GetParameterValues();

_parameterTypes["@@IDENTITY"] = typeof(SqlEntityReference).ToSqlType();
_parameterTypes["@@ROWCOUNT"] = typeof(SqlInt32).ToSqlType();
_parameterValues["@@IDENTITY"] = SqlEntityReference.Null;
_parameterValues["@@ROWCOUNT"] = (SqlInt32)0;
foreach (var paramType in _connection.GlobalVariableTypes)
_parameterTypes[paramType.Key] = paramType.Value;

foreach (var paramValue in _connection.GlobalVariableValues)
_parameterValues[paramValue.Key] = paramValue.Value;

_labelIndexes = command.Plan
.Select((node, index) => new { node, index })
Expand Down Expand Up @@ -120,6 +121,11 @@ private bool Execute(Dictionary<string, DataTypeReference> parameterTypes, Dicti
_error = true;
throw;
}
finally
{
foreach (var paramName in _connection.GlobalVariableValues.Keys.ToArray())
_connection.GlobalVariableValues[paramName] = parameterValues[paramName];
}

return false;
}
Expand Down

0 comments on commit 43fe8b7

Please sign in to comment.