diff --git a/QueryBuilder.Tests/InsertTests.cs b/QueryBuilder.Tests/InsertTests.cs index 6e044e7b..9ae1d9b6 100644 --- a/QueryBuilder.Tests/InsertTests.cs +++ b/QueryBuilder.Tests/InsertTests.cs @@ -25,6 +25,52 @@ public void InsertObject() Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]); } + [Fact] + public void InsertReturning() + { + var query = new Query("Table").AsInsert(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, new [] { "Name", "Age" }); + + var c = Compile(query); + + Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01') RETURNING \"Name\", \"Age\"", c[EngineCodes.PostgreSql]); + Assert.Equal("INSERT INTO [Table] ([Name], [Age]) OUTPUT inserted.[Name], inserted.[Age] VALUES ('The User', '2018-01-01')", c[EngineCodes.SqlServer]); + Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]); + } + + [Fact] + public void InsertReturningNull() + { + var query = new Query("Table").AsInsert(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, null); + + var c = Compile(query); + + Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01')", c[EngineCodes.PostgreSql]); + Assert.Equal("INSERT INTO [Table] ([Name], [Age]) VALUES ('The User', '2018-01-01')", c[EngineCodes.SqlServer]); + } + + [Fact] + public void InsertReturningAll() + { + var query = new Query("Table").AsInsert(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, new [] { "*" }); + + var c = Compile(query); + + Assert.Equal("INSERT INTO \"Table\" (\"Name\", \"Age\") VALUES ('The User', '2018-01-01') RETURNING *", c[EngineCodes.PostgreSql]); + Assert.Equal("INSERT INTO \"TABLE\" (\"NAME\", \"AGE\") VALUES ('The User', '2018-01-01')", c[EngineCodes.Firebird]); + } + [Fact] public void InsertFromSubQueryWithCte() { @@ -76,6 +122,36 @@ public void InsertMultiRecords() c[EngineCodes.Firebird]); } + [Fact] + public void InsertReturningMultiRecords() + { + var query = new Query("expensive_cars") + .AsInsert( + new[] { "name", "brand", "year" }, + new[] + { + new object[] {"Chiron", "Bugatti", null}, + new object[] {"Huayra", "Pagani", 2012}, + new object[] {"Reventon roadster", "Lamborghini", 2009} + }, + new[] { "name" } + ); + + var c = Compile(query); + + Assert.Equal( + "INSERT INTO \"expensive_cars\" (\"name\", \"brand\", \"year\") VALUES ('Chiron', 'Bugatti', NULL), ('Huayra', 'Pagani', 2012), ('Reventon roadster', 'Lamborghini', 2009) RETURNING \"name\"", + c[EngineCodes.PostgreSql]); + + Assert.Equal( + "INSERT INTO [expensive_cars] ([name], [brand], [year]) OUTPUT inserted.[name] VALUES ('Chiron', 'Bugatti', NULL), ('Huayra', 'Pagani', 2012), ('Reventon roadster', 'Lamborghini', 2009)", + c[EngineCodes.SqlServer]); + + Assert.Equal( + "INSERT INTO \"EXPENSIVE_CARS\" (\"NAME\", \"BRAND\", \"YEAR\") SELECT 'Chiron', 'Bugatti', NULL FROM RDB$DATABASE UNION ALL SELECT 'Huayra', 'Pagani', 2012 FROM RDB$DATABASE UNION ALL SELECT 'Reventon roadster', 'Lamborghini', 2009 FROM RDB$DATABASE", + c[EngineCodes.Firebird]); + } + [Fact] public void InsertWithNullValues() { diff --git a/QueryBuilder.Tests/UpdateTests.cs b/QueryBuilder.Tests/UpdateTests.cs index 7b2b0ec3..d31873d6 100644 --- a/QueryBuilder.Tests/UpdateTests.cs +++ b/QueryBuilder.Tests/UpdateTests.cs @@ -25,6 +25,51 @@ public void UpdateObject() Assert.Equal("UPDATE \"TABLE\" SET \"NAME\" = 'The User', \"AGE\" = '2018-01-01'", c[EngineCodes.Firebird]); } + [Fact] + public void UpdateReturning() + { + var query = new Query("Table").AsUpdate(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, new [] { "Name", "Age" }); + + var c = Compile(query); + + Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01' RETURNING \"Name\", \"Age\"", c[EngineCodes.PostgreSql]); + Assert.Equal("UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01' OUTPUT inserted.[Name], inserted.[Age]", c[EngineCodes.SqlServer]); + Assert.Equal("UPDATE \"TABLE\" SET \"NAME\" = 'The User', \"AGE\" = '2018-01-01'", c[EngineCodes.Firebird]); + } + + [Fact] + public void UpdateReturningNull() + { + var query = new Query("Table").AsUpdate(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, null); + + var c = Compile(query); + + Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01'", c[EngineCodes.PostgreSql]); + Assert.Equal("UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01'", c[EngineCodes.SqlServer]); + } + + [Fact] + public void UpdateReturningAll() + { + var query = new Query("Table").AsUpdate(new + { + Name = "The User", + Age = new DateTime(2018, 1, 1), + }, new [] { "*" }); + + var c = Compile(query); + + Assert.Equal("UPDATE \"Table\" SET \"Name\" = 'The User', \"Age\" = '2018-01-01' RETURNING *", c[EngineCodes.PostgreSql]); + } + [Fact] public void UpdateWithNullValues() { diff --git a/QueryBuilder/Clauses/InsertClause.cs b/QueryBuilder/Clauses/InsertClause.cs index 41dfd74b..7d047451 100644 --- a/QueryBuilder/Clauses/InsertClause.cs +++ b/QueryBuilder/Clauses/InsertClause.cs @@ -12,6 +12,7 @@ public class InsertClause : AbstractInsertClause public List Columns { get; set; } public List Values { get; set; } public bool ReturnId { get; set; } = false; + public List ReturnColumns { get; set; } = new List(); public override AbstractClause Clone() { @@ -22,6 +23,7 @@ public override AbstractClause Clone() Columns = Columns, Values = Values, ReturnId = ReturnId, + ReturnColumns = ReturnColumns, }; } } diff --git a/QueryBuilder/Compilers/Compiler.cs b/QueryBuilder/Compilers/Compiler.cs index c042fff4..d7a9bb24 100644 --- a/QueryBuilder/Compilers/Compiler.cs +++ b/QueryBuilder/Compilers/Compiler.cs @@ -303,11 +303,16 @@ private SqlResult CompileUpdateQuery(Query query) var sets = string.Join(", ", parts); - ctx.RawSql = $"UPDATE {table} SET {sets}{where}"; + ctx.RawSql = CompileUpdateQueryString(table, sets, where, toUpdate.ReturnColumns ?? new List()); return ctx; } + protected virtual string CompileUpdateQueryString(string table, string sets, string where, List returnColumns) + { + return $"UPDATE {table} SET {sets}{where}"; + } + protected virtual SqlResult CompileInsertQuery(Query query) { var ctx = new SqlResult @@ -346,13 +351,20 @@ protected virtual SqlResult CompileInsertQuery(Query query) } var inserts = ctx.Query.GetComponents("insert", EngineCode); + var returningColumns = inserts.Where(clause => (clause as InsertClause)?.ReturnColumns != null).SelectMany(clause => (clause as InsertClause).ReturnColumns).Distinct().ToList(); if (inserts[0] is InsertClause insertClause) { - var columns = string.Join(", ", WrapArray(insertClause.Columns)); - var values = string.Join(", ", Parameterize(ctx, insertClause.Values)); + var values = new StringBuilder("VALUES"); + + foreach (var clauseValues in inserts.Select(clause => (clause as InsertClause).Values)) + { + values.Append(" ("); + values.Append(string.Join(", ", Parameterize(ctx, clauseValues))); + values.Append("),"); + } - ctx.RawSql = $"INSERT INTO {table} ({columns}) VALUES ({values})"; + ctx.RawSql = CompileInsertQueryString(table, insertClause.Columns, values.ToString(0, values.Length - 1), returningColumns); if (insertClause.ReturnId && !string.IsNullOrEmpty(LastId)) { @@ -362,35 +374,21 @@ protected virtual SqlResult CompileInsertQuery(Query query) else { var clause = inserts[0] as InsertQueryClause; - - var columns = ""; - - if (clause.Columns.Any()) - { - columns = $" ({string.Join(", ", WrapArray(clause.Columns))}) "; - } - var subCtx = CompileSelectQuery(clause.Query); - ctx.Bindings.AddRange(subCtx.Bindings); - ctx.RawSql = $"INSERT INTO {table}{columns}{subCtx.RawSql}"; - } - - if (inserts.Count > 1) - { - foreach (var insert in inserts.GetRange(1, inserts.Count - 1)) - { - var clause = insert as InsertClause; - - ctx.RawSql += ", (" + string.Join(", ", Parameterize(ctx, clause.Values)) + ")"; - - } + ctx.Bindings.AddRange(subCtx.Bindings); + ctx.RawSql = CompileInsertQueryString(table, clause.Columns, subCtx.RawSql, returningColumns); } - return ctx; } + protected virtual string CompileInsertQueryString(string table, List columns, string rawValues, List returnColumns) + { + var columnsSql = columns.Count > 0 ? " (" + string.Join(", ", WrapArray(columns)) + ") " : ""; + + return $"INSERT INTO {table}{columnsSql}{rawValues}"; + } protected virtual SqlResult CompileCteQuery(SqlResult ctx, Query query) { diff --git a/QueryBuilder/Compilers/PostgresCompiler.cs b/QueryBuilder/Compilers/PostgresCompiler.cs index ea04e9d0..0b8ca220 100644 --- a/QueryBuilder/Compilers/PostgresCompiler.cs +++ b/QueryBuilder/Compilers/PostgresCompiler.cs @@ -1,3 +1,5 @@ +using System.Collections.Generic; + namespace SqlKata.Compilers { public class PostgresCompiler : Compiler @@ -37,5 +39,31 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond return sql; } + + protected override string CompileInsertQueryString(string table, List columns, string rawValues, List returnColumns) + { + var rawSql = base.CompileInsertQueryString(table, columns, rawValues, returnColumns); + + if (returnColumns.Count > 0) + { + var returning = string.Join(", ", WrapArray(returnColumns)); + rawSql += $" RETURNING {returning}"; + } + + return rawSql; + } + + protected override string CompileUpdateQueryString(string table, string sets, string where, List returnColumns) + { + var rawSql = $"UPDATE {table} SET {sets}{where}"; + + if (returnColumns.Count > 0) + { + var returning = string.Join(", ", WrapArray(returnColumns)); + rawSql += $" RETURNING {returning}"; + } + + return rawSql; + } } } diff --git a/QueryBuilder/Compilers/SqlServerCompiler.cs b/QueryBuilder/Compilers/SqlServerCompiler.cs index 13bc8264..7bf490aa 100644 --- a/QueryBuilder/Compilers/SqlServerCompiler.cs +++ b/QueryBuilder/Compilers/SqlServerCompiler.cs @@ -1,3 +1,5 @@ +using System.Collections.Generic; + namespace SqlKata.Compilers { public class SqlServerCompiler : Compiler @@ -168,5 +170,34 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond return sql; } + + protected override string CompileInsertQueryString(string table, List columns, string rawValues, List returnColumns) + { + var columnsSql = columns.Count > 0 ? " (" + string.Join(", ", WrapArray(columns)) + ") " : ""; + var rawSql = $"INSERT INTO {table}{columnsSql}"; + + if (returnColumns.Count > 0) + { + const string prefix = "inserted."; + var output = string.Join($", {prefix}", WrapArray(returnColumns)); + rawSql += $"OUTPUT {prefix}{output} "; + } + + return rawSql + rawValues; + } + + protected override string CompileUpdateQueryString(string table, string sets, string where, List returnColumns) + { + var rawSql = $"UPDATE {table} SET {sets}"; + + if (returnColumns.Count > 0) + { + const string prefix = "inserted."; + var output = string.Join($", {prefix}", WrapArray(returnColumns)); + rawSql += $" OUTPUT {prefix}{output}"; + } + + return rawSql + where; + } } } diff --git a/QueryBuilder/Query.Insert.cs b/QueryBuilder/Query.Insert.cs index c6545ea8..3d03049c 100644 --- a/QueryBuilder/Query.Insert.cs +++ b/QueryBuilder/Query.Insert.cs @@ -7,11 +7,16 @@ namespace SqlKata { public partial class Query { - public Query AsInsert(object data, bool returnId = false) + public Query AsInsert(object data, bool returnId = false, IEnumerable returnColumns = null) { var dictionary = BuildDictionaryFromObject(data); - return AsInsert(dictionary, returnId); + return AsInsert(dictionary, returnId, returnColumns); + } + + public Query AsInsert(object data, IEnumerable returnColumns) + { + return AsInsert(data, false, returnColumns); } public Query AsInsert(IEnumerable columns, IEnumerable values) @@ -40,7 +45,7 @@ public Query AsInsert(IEnumerable columns, IEnumerable values) return this; } - public Query AsInsert(IReadOnlyDictionary data, bool returnId = false) + public Query AsInsert(IReadOnlyDictionary data, bool returnId = false, IEnumerable returnColumns = null) { if (data == null || data.Count == 0) { @@ -54,21 +59,29 @@ public Query AsInsert(IReadOnlyDictionary data, bool returnId = Columns = data.Keys.ToList(), Values = data.Values.ToList(), ReturnId = returnId, + ReturnColumns = returnColumns?.ToList(), }); return this; } + public Query AsInsert(IReadOnlyDictionary data, IEnumerable returnColumns) + { + return AsInsert(data, false, returnColumns); + } + /// /// Produces insert multi records /// /// /// + /// /// - public Query AsInsert(IEnumerable columns, IEnumerable> valuesCollection) + public Query AsInsert(IEnumerable columns, IEnumerable> valuesCollection, IEnumerable returnColumns = null) { var columnsList = columns?.ToList(); var valuesCollectionList = valuesCollection?.ToList(); + var returnColumnsList = returnColumns?.ToList(); if ((columnsList?.Count ?? 0) == 0 || (valuesCollectionList?.Count ?? 0) == 0) { @@ -90,7 +103,8 @@ public Query AsInsert(IEnumerable columns, IEnumerable returnColumns = null) { var dictionary = BuildDictionaryFromObject(data, considerKeys: true); - return AsUpdate(dictionary); + return AsUpdate(dictionary, returnColumns); } - public Query AsUpdate(IEnumerable columns, IEnumerable values) + public Query AsUpdate(IEnumerable columns, IEnumerable values, IEnumerable returnColumns = null) { if ((columns?.Count() ?? 0) == 0 || (values?.Count() ?? 0) == 0) @@ -34,13 +34,14 @@ public Query AsUpdate(IEnumerable columns, IEnumerable values) ClearComponent("update").AddComponent("update", new InsertClause { Columns = columns.ToList(), - Values = values.ToList() + Values = values.ToList(), + ReturnColumns = returnColumns?.ToList(), }); return this; } - public Query AsUpdate(IReadOnlyDictionary data) + public Query AsUpdate(IReadOnlyDictionary data, IEnumerable returnColumns = null) { if (data == null || data.Count == 0) @@ -54,6 +55,7 @@ public Query AsUpdate(IReadOnlyDictionary data) { Columns = data.Keys.ToList(), Values = data.Values.ToList(), + ReturnColumns = returnColumns?.ToList(), }); return this; diff --git a/SqlKata.Execution/Query.Extensions.Async.cs b/SqlKata.Execution/Query.Extensions.Async.cs index 8d5f185c..0cc5e787 100644 --- a/SqlKata.Execution/Query.Extensions.Async.cs +++ b/SqlKata.Execution/Query.Extensions.Async.cs @@ -71,17 +71,18 @@ public static async Task ChunkAsync(this Query query, int chunkSize, Action InsertAsync( this Query query, - IReadOnlyDictionary values + IReadOnlyDictionary values, + IEnumerable returnColumns = null ) { return await QueryHelper.CreateQueryFactory(query) - .ExecuteAsync(query.AsInsert(values)); + .ExecuteAsync(query.AsInsert(values, returnColumns)); } - public static async Task InsertAsync(this Query query, object data) + public static async Task InsertAsync(this Query query, object data, IEnumerable returnColumns = null) { return await QueryHelper.CreateQueryFactory(query) - .ExecuteAsync(query.AsInsert(data)); + .ExecuteAsync(query.AsInsert(data, returnColumns)); } public static async Task InsertGetIdAsync(this Query query, object data) @@ -110,16 +111,16 @@ Query fromQuery .ExecuteAsync(query.AsInsert(columns, fromQuery)); } - public static async Task UpdateAsync(this Query query, IReadOnlyDictionary values) + public static async Task UpdateAsync(this Query query, IReadOnlyDictionary values, IEnumerable returnColumns = null) { return await QueryHelper.CreateQueryFactory(query) - .ExecuteAsync(query.AsUpdate(values)); + .ExecuteAsync(query.AsUpdate(values, returnColumns)); } - public static async Task UpdateAsync(this Query query, object data) + public static async Task UpdateAsync(this Query query, object data, IEnumerable returnColumns = null) { return await QueryHelper.CreateQueryFactory(query) - .ExecuteAsync(query.AsUpdate(data)); + .ExecuteAsync(query.AsUpdate(data, returnColumns)); } public static async Task DeleteAsync(this Query query) diff --git a/SqlKata.Execution/Query.Extensions.cs b/SqlKata.Execution/Query.Extensions.cs index 91ccca2b..46c77110 100644 --- a/SqlKata.Execution/Query.Extensions.cs +++ b/SqlKata.Execution/Query.Extensions.cs @@ -71,18 +71,19 @@ public static void Chunk(this Query query, int chunkSize, Action(chunkSize, action); } - public static int Insert(this Query query, IReadOnlyDictionary values) + public static int Insert(this Query query, IReadOnlyDictionary values, IEnumerable returnColumns = null) { - return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(values)); + return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(values, returnColumns)); } public static int Insert( this Query query, IEnumerable columns, - IEnumerable> valuesCollection + IEnumerable> valuesCollection, + IEnumerable returnColumns = null ) { - return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(columns, valuesCollection)); + return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(columns, valuesCollection, returnColumns)); } public static int Insert(this Query query, IEnumerable columns, Query fromQuery) @@ -90,9 +91,9 @@ public static int Insert(this Query query, IEnumerable columns, Query fr return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(columns, fromQuery)); } - public static int Insert(this Query query, object data) + public static int Insert(this Query query, object data, IEnumerable returnColumns = null) { - return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(data)); + return QueryHelper.CreateQueryFactory(query).Execute(query.AsInsert(data, returnColumns)); } public static T InsertGetId(this Query query, object data) @@ -104,14 +105,14 @@ public static T InsertGetId(this Query query, object data) return row.Id; } - public static int Update(this Query query, IReadOnlyDictionary values) + public static int Update(this Query query, IReadOnlyDictionary values, IEnumerable returnColumns = null) { - return QueryHelper.CreateQueryFactory(query).Execute(query.AsUpdate(values)); + return QueryHelper.CreateQueryFactory(query).Execute(query.AsUpdate(values, returnColumns)); } - public static int Update(this Query query, object data) + public static int Update(this Query query, object data, IEnumerable returnColumns = null) { - return QueryHelper.CreateQueryFactory(query).Execute(query.AsUpdate(data)); + return QueryHelper.CreateQueryFactory(query).Execute(query.AsUpdate(data, returnColumns)); } public static int Delete(this Query query)