Skip to content

Commit

Permalink
Added IF and WHILE support
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Feb 5, 2022
1 parent 70ac8bb commit 6842126
Show file tree
Hide file tree
Showing 12 changed files with 441 additions and 65 deletions.
43 changes: 43 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -281,5 +281,48 @@ public void LoadToDataTable()
}
}
}

[TestMethod]
public void ControlOfFlow()
{
using (var con = new Sql4CdsConnection(_localDataSource.Values.ToList(), this))
using (var cmd = con.CreateCommand())
{
cmd.CommandText = @"
IF @param1 = 1
SELECT 'a'
IF @param1 = 2
SELECT 'b'
WHILE @param1 < 10
BEGIN
SELECT @param1
SET @param1 += 1
END";
cmd.Parameters.Add(new Sql4CdsParameter("@param1", 1));

var log = "";
var results = new List<string>();

((Sql4CdsCommand)cmd).StatementCompleted += (s, e) => log += e.Statement.Sql + "\r\n";

using (var reader = cmd.ExecuteReader())
{
while (!reader.IsClosed)
{
var table = new DataTable();
table.Load(reader);

Assert.AreEqual(1, table.Columns.Count);
Assert.AreEqual(1, table.Rows.Count);
results.Add(table.Rows[0][0].ToString());
}
}

Assert.AreEqual("SELECT 'a'\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\nSELECT @param1\r\nSET @param1 += 1\r\n", log);
CollectionAssert.AreEqual(new[] { "a", "1", "2", "3", "4", "5", "6", "7", "8", "9" }, results);
}
}
}
}
63 changes: 63 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3674,5 +3674,68 @@ public void TableVariableNotSupported()

planBuilder.Build(query, null, out _);
}

[TestMethod]
public void IfStatement()
{
var metadata = new AttributeMetadataCache(_service);
var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), this);

var query = @"
IF @param1 = 1
BEGIN
INSERT INTO account (name) VALUES ('one')
DELETE FROM account WHERE accountid = @@IDENTITY
END
ELSE
SELECT name FROM account";

var parameters = new Dictionary<string, DataTypeReference>
{
["@param1"] = typeof(SqlInt32).ToSqlType()
};
var plans = planBuilder.Build(query, parameters, out _);

Assert.AreEqual(1, plans.Length);

var cond = AssertNode<IfNode>(plans[0]);
Assert.AreEqual("@param1 = 1", cond.Condition.ToSql());

Assert.AreEqual(2, cond.TrueStatements.Length);
AssertNode<InsertNode>(cond.TrueStatements[0]);
AssertNode<DeleteNode>(cond.TrueStatements[1]);

Assert.AreEqual(1, cond.FalseStatements.Length);
AssertNode<SelectNode>(cond.FalseStatements[0]);
}

[TestMethod]
public void WhileStatement()
{
var metadata = new AttributeMetadataCache(_service);
var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), this);

var query = @"
WHILE @param1 < 10
BEGIN
INSERT INTO account (name) VALUES (@param1)
SET @param1 += 1
END";

var parameters = new Dictionary<string, DataTypeReference>
{
["@param1"] = typeof(SqlInt32).ToSqlType()
};
var plans = planBuilder.Build(query, parameters, out _);

Assert.AreEqual(1, plans.Length);

var cond = AssertNode<WhileNode>(plans[0]);
Assert.AreEqual("@param1 < 10", cond.Condition.ToSql());

Assert.AreEqual(2, cond.Statements.Length);
AssertNode<InsertNode>(cond.Statements[0]);
AssertNode<AssignVariablesNode>(cond.Statements[1]);
}
}
}
70 changes: 46 additions & 24 deletions MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Linq.Expressions;
using System.Text;
using MarkMpn.Sql4Cds.Engine.ExecutionPlan;
using Microsoft.SqlServer.TransactSql.ScriptDom;
using Microsoft.Xrm.Sdk;

namespace MarkMpn.Sql4Cds.Engine
Expand All @@ -20,7 +21,7 @@ class Sql4CdsDataReader : DbDataReader, ISql4CdsDataReader
private readonly CommandBehavior _behavior;
private readonly List<DataTable> _results;
private readonly List<IRootExecutionPlanNode> _resultQueries;
private readonly int _recordsAffected;
private int _recordsAffected;
private int _resultIndex;
private int _rowIndex;
private bool _closed;
Expand Down Expand Up @@ -77,38 +78,59 @@ public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options,

foreach (var plan in command.Plan)
{
if (plan is IDataSetExecutionPlanNode dataSetNode)
{
var table = dataSetNode.Execute(_connection.DataSources, options, parameterTypes, parameterValues);
_results.Add(table);
_resultQueries.Add(dataSetNode);
Execute(plan, parameterTypes, parameterValues);
}

_connection.OnInfoMessage(plan, $"({table.Rows.Count} row{(table.Rows.Count == 1 ? "" : "s")} affected)");
command.OnStatementCompleted(plan, -1);
}
else if (plan is IDmlQueryExecutionPlanNode dmlNode)
{
var msg = dmlNode.Execute(_connection.DataSources, options, parameterTypes, parameterValues, out var recordsAffected);
_resultIndex = -1;

if (!NextResult())
Close();
}

private void Execute(IRootExecutionPlanNode plan, Dictionary<string, DataTypeReference> parameterTypes, Dictionary<string, object> parameterValues)
{
if (plan is IDataSetExecutionPlanNode dataSetNode)
{
var table = dataSetNode.Execute(_connection.DataSources, _options, parameterTypes, parameterValues);
_results.Add(table);
_resultQueries.Add(dataSetNode);

if (!String.IsNullOrEmpty(msg))
_connection.OnInfoMessage(plan, msg);
_connection.OnInfoMessage(plan, $"({table.Rows.Count} row{(table.Rows.Count == 1 ? "" : "s")} affected)");
_command.OnStatementCompleted(plan, -1);
}
else if (plan is IDmlQueryExecutionPlanNode dmlNode)
{
var msg = dmlNode.Execute(_connection.DataSources, _options, parameterTypes, parameterValues, out var recordsAffected);

if (!String.IsNullOrEmpty(msg))
_connection.OnInfoMessage(plan, msg);

command.OnStatementCompleted(plan, recordsAffected);
_command.OnStatementCompleted(plan, recordsAffected);

if (recordsAffected != -1)
{
if (_recordsAffected == -1)
_recordsAffected = 0;
if (recordsAffected != -1)
{
if (_recordsAffected == -1)
_recordsAffected = 0;

_recordsAffected += recordsAffected;
}
_recordsAffected += recordsAffected;
}
}
else if (plan is IControlOfFlowNode cond)
{
while (true)
{
var childNodes = cond.Execute(_connection.DataSources, _options, parameterTypes, parameterValues, out var rerun);

_resultIndex = -1;
if (childNodes == null)
break;

if (!NextResult())
Close();
foreach (var child in childNodes)
Execute(child, parameterTypes, parameterValues);

if (!rerun)
break;
}
}
}

public IRootExecutionPlanNode CurrentResultQuery => _resultQueries[_resultIndex];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, INodeSc
var variableVisitor = new VariableCollectingVisitor();
expr.Accept(variableVisitor);

if (variableVisitor.Variables.Count > 0)
if (variableVisitor.Variables.Count > 0 || variableVisitor.GlobalVariables.Count > 0)
return false;

var parameterlessVisitor = new ParameterlessCollectingVisitor();
Expand Down
24 changes: 24 additions & 0 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/IControlOfFlowNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;

namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
{
/// <summary>
/// Describes a node which moves execution to another node
/// </summary>
interface IControlOfFlowNode : IRootExecutionPlanNodeInternal
{
/// <summary>
/// Checks which nodes should be executed next
/// </summary>
/// <param name="dataSources">The data sources that can be accessed by the query</param>
/// <param name="options">The options which describe how the query should be executed</param>
/// <param name="parameterTypes">The types of any parameters available to the query</param>
/// <param name="parameterValues">The values of any parameters available to the query</param>
/// <param name="rerun">Indicates if this node should be executed again before moving on to the next statement in the batch</param>
/// <returns>The nodes which should be executed next. If <c>null</c>, the entire node is finished and control should move on to the next statement in the batch</returns>
IRootExecutionPlanNodeInternal[] Execute(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, DataTypeReference> parameterTypes, IDictionary<string, object> parameterValues, out bool rerun);
}
}
83 changes: 83 additions & 0 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/IfNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;

namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
{
class IfNode : BaseNode, IControlOfFlowNode
{
private int _executionCount;
private readonly Timer _timer = new Timer();

public override int ExecutionCount => _executionCount;

public override TimeSpan Duration => _timer.Duration;

[Browsable(false)]
public string Sql { get; set; }

[Browsable(false)]
public int Index { get; set; }

[Browsable(false)]
public int Length { get; set; }

[Category("If")]
[Description("The condition that must be true for execution to continue")]
public BooleanExpression Condition { get; set; }

[Browsable(false)]
public IRootExecutionPlanNodeInternal[] TrueStatements { get; set; }

[Browsable(false)]
public IRootExecutionPlanNodeInternal[] FalseStatements { get; set; }

public override void AddRequiredColumns(IDictionary<string, DataSource> dataSources, IDictionary<string, DataTypeReference> parameterTypes, IList<string> requiredColumns)
{
foreach (var node in TrueStatements)
node.AddRequiredColumns(dataSources, parameterTypes, new List<string>(requiredColumns));

if (FalseStatements != null)
{
foreach (var node in FalseStatements)
node.AddRequiredColumns(dataSources, parameterTypes, new List<string>(requiredColumns));
}
}

public IRootExecutionPlanNodeInternal[] Execute(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, DataTypeReference> parameterTypes, IDictionary<string, object> parameterValues, out bool rerun)
{
rerun = false;
var expr = Condition.Compile(null, parameterTypes);

if (expr(null, parameterValues, options))
return TrueStatements;
else
return FalseStatements;
}

public IRootExecutionPlanNodeInternal FoldQuery(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, DataTypeReference> parameterTypes, IList<OptimizerHint> hints)
{
for (var i = 0; i < TrueStatements.Length; i++)
TrueStatements[i] = TrueStatements[i].FoldQuery(dataSources, options, parameterTypes, hints);

if (FalseStatements != null)
{
for (var i = 0; i < FalseStatements.Length; i++)
FalseStatements[i] = FalseStatements[i].FoldQuery(dataSources, options, parameterTypes, hints);
}

return this;
}

public override IEnumerable<IExecutionPlanNode> GetSources()
{
if (FalseStatements == null)
return TrueStatements;

return TrueStatements.Concat(FalseStatements);
}
}
}
2 changes: 1 addition & 1 deletion MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public DataTable Execute(IDictionary<string, DataSource> dataSources, IQueryExec
}
}

if (Parent == null)
if (Parent == null || Parent is IControlOfFlowNode)
parameterValues["@@ROWCOUNT"] = (SqlInt32)sqlTable.Rows.Count;

return sqlTable;
Expand Down
Loading

0 comments on commit 6842126

Please sign in to comment.