Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning columns in Insert / Update clauses #372

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions QueryBuilder.Tests/InsertTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,52 @@ public void InsertObject()
Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]);
}

[Fact]
public void InsertReturning()
{
var query = new Query("Table").AsInsert(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, new [] { "Name", "Age" });

var c = Compile(query);

Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01') RETURNING \"Name\", \"Age\"", c[EngineCodes.PostgreSql]);
Assert.Equal("INSERT INTO [Table] ([Name], [Age]) OUTPUT inserted.[Name], inserted.[Age] VALUES ('The User', '2018-01-01')", c[EngineCodes.SqlServer]);
Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]);
}

[Fact]
public void InsertReturningNull()
{
var query = new Query("Table").AsInsert(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, null);

var c = Compile(query);

Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01')", c[EngineCodes.PostgreSql]);
Assert.Equal("INSERT INTO [Table] ([Name], [Age]) VALUES ('The User', '2018-01-01')", c[EngineCodes.SqlServer]);
}

[Fact]
public void InsertReturningAll()
{
var query = new Query("Table").AsInsert(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, new [] { "*" });

var c = Compile(query);

Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01') RETURNING *", c[EngineCodes.PostgreSql]);
Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]);
}

[Fact]
public void InsertFromSubQueryWithCte()
{
Expand Down Expand Up @@ -76,6 +122,36 @@ public void InsertMultiRecords()
c[EngineCodes.Firebird]);
}

[Fact]
public void InsertReturningMultiRecords()
{
var query = new Query("expensive_cars")
.AsInsert(
new[] { "name", "brand", "year" },
new[]
{
new object[] {"Chiron", "Bugatti", null},
new object[] {"Huayra", "Pagani", 2012},
new object[] {"Reventon roadster", "Lamborghini", 2009}
},
new[] { "name" }
);

var c = Compile(query);

Assert.Equal(
"INSERT INTO \"expensive_cars\" (\"name\", \"brand\", \"year\") VALUES ('Chiron', 'Bugatti', NULL), ('Huayra', 'Pagani', 2012), ('Reventon roadster', 'Lamborghini', 2009) RETURNING \"name\"",
c[EngineCodes.PostgreSql]);

Assert.Equal(
"INSERT INTO [expensive_cars] ([name], [brand], [year]) OUTPUT inserted.[name] VALUES ('Chiron', 'Bugatti', NULL), ('Huayra', 'Pagani', 2012), ('Reventon roadster', 'Lamborghini', 2009)",
c[EngineCodes.SqlServer]);

Assert.Equal(
"INSERT INTO \"EXPENSIVE_CARS\" (\"NAME\", \"BRAND\", \"YEAR\") SELECT 'Chiron', 'Bugatti', NULL FROM RDB$DATABASE UNION ALL SELECT 'Huayra', 'Pagani', 2012 FROM RDB$DATABASE UNION ALL SELECT 'Reventon roadster', 'Lamborghini', 2009 FROM RDB$DATABASE",
c[EngineCodes.Firebird]);
}

[Fact]
public void InsertWithNullValues()
{
Expand Down
45 changes: 45 additions & 0 deletions QueryBuilder.Tests/UpdateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,51 @@ public void UpdateObject()
Assert.Equal("UPDATE \"TABLE\" SET \"NAME\" = 'The User', \"AGE\" = '2018-01-01'", c[EngineCodes.Firebird]);
}

[Fact]
public void UpdateReturning()
{
var query = new Query("Table").AsUpdate(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, new [] { "Name", "Age" });

var c = Compile(query);

Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01' RETURNING \"Name\", \"Age\"", c[EngineCodes.PostgreSql]);
Assert.Equal("UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01' OUTPUT inserted.[Name], inserted.[Age]", c[EngineCodes.SqlServer]);
Assert.Equal("UPDATE \"TABLE\" SET \"NAME\" = 'The User', \"AGE\" = '2018-01-01'", c[EngineCodes.Firebird]);
}

[Fact]
public void UpdateReturningNull()
{
var query = new Query("Table").AsUpdate(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, null);

var c = Compile(query);

Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01'", c[EngineCodes.PostgreSql]);
Assert.Equal("UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01'", c[EngineCodes.SqlServer]);
}

[Fact]
public void UpdateReturningAll()
{
var query = new Query("Table").AsUpdate(new
{
Name = "The User",
Age = new DateTime(2018, 1, 1),
}, new [] { "*" });

var c = Compile(query);

Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01' RETURNING *", c[EngineCodes.PostgreSql]);
}

[Fact]
public void UpdateWithNullValues()
{
Expand Down
2 changes: 2 additions & 0 deletions QueryBuilder/Clauses/InsertClause.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class InsertClause : AbstractInsertClause
public List<string> Columns { get; set; }
public List<object> Values { get; set; }
public bool ReturnId { get; set; } = false;
public List<string> ReturnColumns { get; set; } = new List<string>();

public override AbstractClause Clone()
{
Expand All @@ -22,6 +23,7 @@ public override AbstractClause Clone()
Columns = Columns,
Values = Values,
ReturnId = ReturnId,
ReturnColumns = ReturnColumns,
};
}
}
Expand Down
50 changes: 24 additions & 26 deletions QueryBuilder/Compilers/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,16 @@ private SqlResult CompileUpdateQuery(Query query)

var sets = string.Join(", ", parts);

ctx.RawSql = $"UPDATE {table} SET {sets}{where}";
ctx.RawSql = CompileUpdateQueryString(table, sets, where, toUpdate.ReturnColumns ?? new List<string>());

return ctx;
}

protected virtual string CompileUpdateQueryString(string table, string sets, string where, List<string> returnColumns)
{
return $"UPDATE {table} SET {sets}{where}";
}

protected virtual SqlResult CompileInsertQuery(Query query)
{
var ctx = new SqlResult
Expand Down Expand Up @@ -346,13 +351,20 @@ protected virtual SqlResult CompileInsertQuery(Query query)
}

var inserts = ctx.Query.GetComponents<AbstractInsertClause>("insert", EngineCode);
var returningColumns = inserts.Where(clause => (clause as InsertClause)?.ReturnColumns != null).SelectMany(clause => (clause as InsertClause).ReturnColumns).Distinct().ToList();

if (inserts[0] is InsertClause insertClause)
{
var columns = string.Join(", ", WrapArray(insertClause.Columns));
var values = string.Join(", ", Parameterize(ctx, insertClause.Values));
var values = new StringBuilder("VALUES");

foreach (var clauseValues in inserts.Select(clause => (clause as InsertClause).Values))
{
values.Append(" (");
values.Append(string.Join(", ", Parameterize(ctx, clauseValues)));
values.Append("),");
}

ctx.RawSql = $"INSERT INTO {table} ({columns}) VALUES ({values})";
ctx.RawSql = CompileInsertQueryString(table, insertClause.Columns, values.ToString(0, values.Length - 1), returningColumns);

if (insertClause.ReturnId && !string.IsNullOrEmpty(LastId))
{
Expand All @@ -362,35 +374,21 @@ protected virtual SqlResult CompileInsertQuery(Query query)
else
{
var clause = inserts[0] as InsertQueryClause;

var columns = "";

if (clause.Columns.Any())
{
columns = $" ({string.Join(", ", WrapArray(clause.Columns))}) ";
}

var subCtx = CompileSelectQuery(clause.Query);
ctx.Bindings.AddRange(subCtx.Bindings);

ctx.RawSql = $"INSERT INTO {table}{columns}{subCtx.RawSql}";
}

if (inserts.Count > 1)
{
foreach (var insert in inserts.GetRange(1, inserts.Count - 1))
{
var clause = insert as InsertClause;

ctx.RawSql += ", (" + string.Join(", ", Parameterize(ctx, clause.Values)) + ")";

}
ctx.Bindings.AddRange(subCtx.Bindings);
ctx.RawSql = CompileInsertQueryString(table, clause.Columns, subCtx.RawSql, returningColumns);
}


return ctx;
}

protected virtual string CompileInsertQueryString(string table, List<string> columns, string rawValues, List<string> returnColumns)
{
var columnsSql = columns.Count > 0 ? " (" + string.Join(", ", WrapArray(columns)) + ") " : "";

return $"INSERT INTO {table}{columnsSql}{rawValues}";
}

protected virtual SqlResult CompileCteQuery(SqlResult ctx, Query query)
{
Expand Down
28 changes: 28 additions & 0 deletions QueryBuilder/Compilers/PostgresCompiler.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Collections.Generic;

namespace SqlKata.Compilers
{
public class PostgresCompiler : Compiler
Expand Down Expand Up @@ -37,5 +39,31 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond

return sql;
}

protected override string CompileInsertQueryString(string table, List<string> columns, string rawValues, List<string> returnColumns)
{
var rawSql = base.CompileInsertQueryString(table, columns, rawValues, returnColumns);

if (returnColumns.Count > 0)
{
var returning = string.Join(", ", WrapArray(returnColumns));
rawSql += $" RETURNING {returning}";
}

return rawSql;
}

protected override string CompileUpdateQueryString(string table, string sets, string where, List<string> returnColumns)
{
var rawSql = $"UPDATE {table} SET {sets}{where}";

if (returnColumns.Count > 0)
{
var returning = string.Join(", ", WrapArray(returnColumns));
rawSql += $" RETURNING {returning}";
}

return rawSql;
}
}
}
31 changes: 31 additions & 0 deletions QueryBuilder/Compilers/SqlServerCompiler.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Collections.Generic;

namespace SqlKata.Compilers
{
public class SqlServerCompiler : Compiler
Expand Down Expand Up @@ -168,5 +170,34 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond

return sql;
}

protected override string CompileInsertQueryString(string table, List<string> columns, string rawValues, List<string> returnColumns)
{
var columnsSql = columns.Count > 0 ? " (" + string.Join(", ", WrapArray(columns)) + ") " : "";
var rawSql = $"INSERT INTO {table}{columnsSql}";

if (returnColumns.Count > 0)
{
const string prefix = "inserted.";
var output = string.Join($", {prefix}", WrapArray(returnColumns));
rawSql += $"OUTPUT {prefix}{output} ";
}

return rawSql + rawValues;
}

protected override string CompileUpdateQueryString(string table, string sets, string where, List<string> returnColumns)
{
var rawSql = $"UPDATE {table} SET {sets}";

if (returnColumns.Count > 0)
{
const string prefix = "inserted.";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be times when a person wants the deleted columns as well.

var output = string.Join($", {prefix}", WrapArray(returnColumns));
rawSql += $" OUTPUT {prefix}{output}";
}

return rawSql + where;
}
}
}
24 changes: 19 additions & 5 deletions QueryBuilder/Query.Insert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ namespace SqlKata
{
public partial class Query
{
public Query AsInsert(object data, bool returnId = false)
public Query AsInsert(object data, bool returnId = false, IEnumerable<string> returnColumns = null)
{
var dictionary = BuildDictionaryFromObject(data);

return AsInsert(dictionary, returnId);
return AsInsert(dictionary, returnId, returnColumns);
}

public Query AsInsert(object data, IEnumerable<string> returnColumns)
{
return AsInsert(data, false, returnColumns);
}

public Query AsInsert(IEnumerable<string> columns, IEnumerable<object> values)
Expand Down Expand Up @@ -40,7 +45,7 @@ public Query AsInsert(IEnumerable<string> columns, IEnumerable<object> values)
return this;
}

public Query AsInsert(IReadOnlyDictionary<string, object> data, bool returnId = false)
public Query AsInsert(IReadOnlyDictionary<string, object> data, bool returnId = false, IEnumerable<string> returnColumns = null)
{
if (data == null || data.Count == 0)
{
Expand All @@ -54,21 +59,29 @@ public Query AsInsert(IReadOnlyDictionary<string, object> data, bool returnId =
Columns = data.Keys.ToList(),
Values = data.Values.ToList(),
ReturnId = returnId,
ReturnColumns = returnColumns?.ToList(),
});

return this;
}

public Query AsInsert(IReadOnlyDictionary<string, object> data, IEnumerable<string> returnColumns)
{
return AsInsert(data, false, returnColumns);
}

/// <summary>
/// Produces insert multi records
/// </summary>
/// <param name="columns"></param>
/// <param name="valuesCollection"></param>
/// <param name="returnColumns"></param>
/// <returns></returns>
public Query AsInsert(IEnumerable<string> columns, IEnumerable<IEnumerable<object>> valuesCollection)
public Query AsInsert(IEnumerable<string> columns, IEnumerable<IEnumerable<object>> valuesCollection, IEnumerable<string> returnColumns = null)
{
var columnsList = columns?.ToList();
var valuesCollectionList = valuesCollection?.ToList();
var returnColumnsList = returnColumns?.ToList();

if ((columnsList?.Count ?? 0) == 0 || (valuesCollectionList?.Count ?? 0) == 0)
{
Expand All @@ -90,7 +103,8 @@ public Query AsInsert(IEnumerable<string> columns, IEnumerable<IEnumerable<objec
AddComponent("insert", new InsertClause
{
Columns = columnsList,
Values = valuesList
Values = valuesList,
ReturnColumns = returnColumnsList,
});
}

Expand Down
Loading