diff --git a/src/SQLite.cs b/src/SQLite.cs index cbc8d22a..c4fc2313 100644 --- a/src/SQLite.cs +++ b/src/SQLite.cs @@ -2775,194 +2775,119 @@ public static bool IsMarkedNotNull (MemberInfo p) } } - public partial class SQLiteCommand - { - SQLiteConnection _conn; - private List _bindings; + public class SQLitePreparedStatement : IDisposable { + Sqlite3Statement Statement; + SQLiteConnection Connection { get; set; } + Sqlite3DatabaseHandle Handle => Connection.Handle; + public string CommandText { get; private set; } - public string CommandText { get; set; } + public SQLitePreparedStatement(SQLiteConnection conn, string sql) + { + CommandText = sql; + Connection = conn; + Statement = SQLite3.Prepare2(Handle, CommandText); + } - public SQLiteCommand (SQLiteConnection conn) + bool disposed = false; + public void Dispose() { - _conn = conn; - _bindings = new List (); - CommandText = ""; + if (disposed) + return; + disposed = true; + SQLite3.Finalize(Statement); } - public int ExecuteNonQuery () + public int ExecuteNonQuery (params object[] args) { - if (_conn.Trace) { - _conn.Tracer?.Invoke ("Executing: " + this); + ResetAndBind(args); + if (Connection.Trace) { + Connection.Tracer?.Invoke ("Executing: " + ToString(args)); } var r = SQLite3.Result.OK; - var stmt = Prepare (); - r = SQLite3.Step (stmt); - Finalize (stmt); + r = SQLite3.Step (Statement); if (r == SQLite3.Result.Done) { - int rowsAffected = SQLite3.Changes (_conn.Handle); + int rowsAffected = SQLite3.Changes (Handle); return rowsAffected; } else if (r == SQLite3.Result.Error) { - string msg = SQLite3.GetErrmsg (_conn.Handle); + string msg = SQLite3.GetErrmsg (Handle); throw SQLiteException.New (r, msg); } else if (r == SQLite3.Result.Constraint) { - if (SQLite3.ExtendedErrCode (_conn.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { - throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + if (SQLite3.ExtendedErrCode (Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (Handle)); } } - throw SQLiteException.New (r, r.ToString ()); } - public IEnumerable ExecuteDeferredQuery () - { - return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))); - } - - public List ExecuteQuery () - { - return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))).ToList (); - } - - public List ExecuteQuery (TableMapping map) - { - return ExecuteDeferredQuery (map).ToList (); - } - - /// - /// Invoked every time an instance is loaded from the database. - /// - /// - /// The newly created object. - /// - /// - /// This can be overridden in combination with the - /// method to hook into the life-cycle of objects. - /// - protected virtual void OnInstanceCreated (object obj) - { - // Can be overridden. - } - - public IEnumerable ExecuteDeferredQuery (TableMapping map) + public IEnumerable ExecuteDeferredQuery (TableMapping map, params object[] args) { - if (_conn.Trace) { - _conn.Tracer?.Invoke ("Executing Query: " + this); + ResetAndBind(args); + if (Connection.Trace) { + Connection.Tracer?.Invoke ("Executing Query: " + ToString(args)); } + var cols = new SQLite.TableMapping.Column[SQLite3.ColumnCount(Statement)]; - var stmt = Prepare (); - try { - var cols = new TableMapping.Column[SQLite3.ColumnCount (stmt)]; + for (int i = 0; i < cols.Length; i++) { + var name = SQLite3.ColumnName16(Statement, i); + cols[i] = map.FindColumn(name); + } + while (SQLite3.Step(Statement) == SQLite3.Result.Row) { + var obj = Activator.CreateInstance(typeof(T)); for (int i = 0; i < cols.Length; i++) { - var name = SQLite3.ColumnName16 (stmt, i); - cols[i] = map.FindColumn (name); + if (cols[i] == null) + continue; + var colType = SQLite3.ColumnType(Statement, i); + var val = ReadCol(Statement, i, colType, cols[i].ColumnType, Connection); + cols[i].SetValue(obj, val); } - - while (SQLite3.Step (stmt) == SQLite3.Result.Row) { - var obj = Activator.CreateInstance (map.MappedType); - for (int i = 0; i < cols.Length; i++) { - if (cols[i] == null) - continue; - var colType = SQLite3.ColumnType (stmt, i); - var val = ReadCol (stmt, i, colType, cols[i].ColumnType); - cols[i].SetValue (obj, val); - } - OnInstanceCreated (obj); - yield return (T)obj; - } - } - finally { - SQLite3.Finalize (stmt); + yield return (T)obj; } } - public T ExecuteScalar () + public IEnumerable ExecuteDeferredQuery (params object[] args) + { + var map = new TableMapping(typeof(T), CreateFlags.None); + return ExecuteDeferredQuery(map, args); + } + + public T ExecuteScalar (params object[] args) { - if (_conn.Trace) { - _conn.Tracer?.Invoke ("Executing Query: " + this); + ResetAndBind(args); + if (Connection.Trace) { + Connection.Tracer?.Invoke ("Executing Query: " + ToString(args)); } T val = default (T); - var stmt = Prepare (); - - try { - var r = SQLite3.Step (stmt); - if (r == SQLite3.Result.Row) { - var colType = SQLite3.ColumnType (stmt, 0); - val = (T)ReadCol (stmt, 0, colType, typeof (T)); - } - else if (r == SQLite3.Result.Done) { - } - else { - throw SQLiteException.New (r, SQLite3.GetErrmsg (_conn.Handle)); - } + var r = SQLite3.Step (Statement); + if (r == SQLite3.Result.Row) { + var colType = SQLite3.ColumnType (Statement, 0); + val = (T)ReadCol (Statement, 0, colType, typeof (T), Connection); } - finally { - Finalize (stmt); + else if (r == SQLite3.Result.Done) { } - - return val; - } - - public void Bind (string name, object val) - { - _bindings.Add (new Binding { - Name = name, - Value = val - }); - } - - public void Bind (object val) - { - Bind (null, val); - } - - public override string ToString () - { - var parts = new string[1 + _bindings.Count]; - parts[0] = CommandText; - var i = 1; - foreach (var b in _bindings) { - parts[i] = string.Format (" {0}: {1}", i - 1, b.Value); - i++; + else { + throw SQLiteException.New (r, SQLite3.GetErrmsg (Handle)); } - return string.Join (Environment.NewLine, parts); - } - - Sqlite3Statement Prepare () - { - var stmt = SQLite3.Prepare2 (_conn.Handle, CommandText); - BindAll (stmt); - return stmt; - } - void Finalize (Sqlite3Statement stmt) - { - SQLite3.Finalize (stmt); + return val; } - void BindAll (Sqlite3Statement stmt) - { - int nextIdx = 1; - foreach (var b in _bindings) { - if (b.Name != null) { - b.Index = SQLite3.BindParameterIndex (stmt, b.Name); - } - else { - b.Index = nextIdx++; - } - - BindParameter (stmt, b.Index, b.Value, _conn.StoreDateTimeAsTicks, _conn.DateTimeStringFormat, _conn.StoreTimeSpanAsTicks); + void ResetAndBind(params object[] args) { + SQLite3.Reset(Statement); + SQLite3.ClearBindings(Statement); + for (int i = 0; i < args.Length; i++) { + BindParameter(Statement, i + 1, args[i], Connection); } } static IntPtr NegativePointer = new IntPtr (-1); - internal static void BindParameter (Sqlite3Statement stmt, int index, object value, bool storeDateTimeAsTicks, string dateTimeStringFormat, bool storeTimeSpanAsTicks) + internal static void BindParameter (Sqlite3Statement stmt, int index, object value, SQLiteConnection connection) { if (value == null) { SQLite3.BindNull (stmt, index); @@ -2987,7 +2912,7 @@ internal static void BindParameter (Sqlite3Statement stmt, int index, object val SQLite3.BindDouble (stmt, index, Convert.ToDouble (value)); } else if (value is TimeSpan) { - if (storeTimeSpanAsTicks) { + if (connection.StoreTimeSpanAsTicks) { SQLite3.BindInt64 (stmt, index, ((TimeSpan)value).Ticks); } else { @@ -2995,11 +2920,11 @@ internal static void BindParameter (Sqlite3Statement stmt, int index, object val } } else if (value is DateTime) { - if (storeDateTimeAsTicks) { + if (connection.StoreDateTimeAsTicks) { SQLite3.BindInt64 (stmt, index, ((DateTime)value).Ticks); } else { - SQLite3.BindText (stmt, index, ((DateTime)value).ToString (dateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture), -1, NegativePointer); + SQLite3.BindText (stmt, index, ((DateTime)value).ToString (connection.DateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture), -1, NegativePointer); } } else if (value is DateTimeOffset) { @@ -3038,16 +2963,7 @@ internal static void BindParameter (Sqlite3Statement stmt, int index, object val } } - class Binding - { - public string Name { get; set; } - - public object Value { get; set; } - - public int Index { get; set; } - } - - object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clrType) + internal static object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clrType, SQLiteConnection connection) { if (type == SQLite3.ColType.Null) { return null; @@ -3075,7 +2991,7 @@ object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clr return (float)SQLite3.ColumnDouble (stmt, index); } else if (clrType == typeof (TimeSpan)) { - if (_conn.StoreTimeSpanAsTicks) { + if (connection.StoreTimeSpanAsTicks) { return new TimeSpan (SQLite3.ColumnInt64 (stmt, index)); } else { @@ -3088,13 +3004,13 @@ object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clr } } else if (clrType == typeof (DateTime)) { - if (_conn.StoreDateTimeAsTicks) { + if (connection.StoreDateTimeAsTicks) { return new DateTime (SQLite3.ColumnInt64 (stmt, index)); } else { var text = SQLite3.ColumnString (stmt, index); DateTime resultDate; - if (!DateTime.TryParseExact (text, _conn.DateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture, _conn.DateTimeStyle, out resultDate)) { + if (!DateTime.TryParseExact (text, connection.DateTimeStringFormat, System.Globalization.CultureInfo.InvariantCulture, connection.DateTimeStyle, out resultDate)) { resultDate = DateTime.Parse (text); } return resultDate; @@ -3156,6 +3072,123 @@ object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clr } } } + + public string ToString (params object[] args) + { + var parts = new string[1 + args.Count()]; + parts[0] = CommandText; + var i = 1; + foreach (var b in args) { + parts[i] = string.Format (" {0}: {1}", i - 1, b); + i++; + } + return string.Join (Environment.NewLine, parts); + } + } + + public partial class SQLiteCommand + { + SQLiteConnection _conn; + private List _bindings; + object[] BindingParams => _bindings.Select(b => b.Value).ToArray(); + + public string CommandText { get; set; } + + public SQLiteCommand (SQLiteConnection conn) + { + _conn = conn; + _bindings = new List (); + CommandText = ""; + } + + public int ExecuteNonQuery () + { + using (var stmt = new SQLitePreparedStatement(_conn, CommandText)) { + return stmt.ExecuteNonQuery(BindingParams); + } + } + + public IEnumerable ExecuteDeferredQuery () + { + return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))); + } + + public List ExecuteQuery () + { + return ExecuteDeferredQuery (_conn.GetMapping (typeof (T))).ToList (); + } + + public List ExecuteQuery (TableMapping map) + { + return ExecuteDeferredQuery (map).ToList (); + } + + /// + /// Invoked every time an instance is loaded from the database. + /// + /// + /// The newly created object. + /// + /// + /// This can be overridden in combination with the + /// method to hook into the life-cycle of objects. + /// + protected virtual void OnInstanceCreated (object obj) + { + // Can be overridden. + } + + public IEnumerable ExecuteDeferredQuery (TableMapping map) + { + using (var stmt = new SQLitePreparedStatement(_conn, CommandText)) { + foreach (var obj in stmt.ExecuteDeferredQuery(map, BindingParams)) { + OnInstanceCreated(obj); + yield return obj; + } + } + } + + public T ExecuteScalar () + { + using (var stmt = new SQLitePreparedStatement(_conn, CommandText)) { + return stmt.ExecuteScalar(BindingParams); + } + } + + public void Bind (string name, object val) + { + _bindings.Add (new Binding { + Name = name, + Value = val + }); + } + + public void Bind (object val) + { + Bind (null, val); + } + + public override string ToString () + { + var parts = new string[1 + _bindings.Count]; + parts[0] = CommandText; + var i = 1; + foreach (var b in _bindings) { + parts[i] = string.Format (" {0}: {1}", i - 1, b.Value); + i++; + } + return string.Join (Environment.NewLine, parts); + } + + class Binding + { + public string Name { get; set; } + + public object Value { get; set; } + + public int Index { get; set; } + } + } /// @@ -3198,7 +3231,7 @@ public int ExecuteNonQuery (object[] source) //bind the values. if (source != null) { for (int i = 0; i < source.Length; i++) { - SQLiteCommand.BindParameter (Statement, i + 1, source[i], Connection.StoreDateTimeAsTicks, Connection.DateTimeStringFormat, Connection.StoreTimeSpanAsTicks); + SQLitePreparedStatement.BindParameter (Statement, i + 1, source[i], Connection); } } r = SQLite3.Step (Statement); @@ -4089,6 +4122,9 @@ public static IntPtr Prepare2 (IntPtr db, string query) [DllImport(LibraryPath, EntryPoint = "sqlite3_finalize", CallingConvention=CallingConvention.Cdecl)] public static extern Result Finalize (IntPtr stmt); + [DllImport(LibraryPath, EntryPoint = "sqlite3_clear_bindings", CallingConvention=CallingConvention.Cdecl)] + public static extern int ClearBindings (IntPtr stmt); + [DllImport(LibraryPath, EntryPoint = "sqlite3_last_insert_rowid", CallingConvention=CallingConvention.Cdecl)] public static extern long LastInsertRowid (IntPtr db); @@ -4254,6 +4290,11 @@ public static Result Finalize (Sqlite3Statement stmt) return (Result)Sqlite3.sqlite3_finalize (stmt); } + public static Result ClearBindings (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_clear_bindings (stmt); + } + public static long LastInsertRowid (Sqlite3DatabaseHandle db) { return Sqlite3.sqlite3_last_insert_rowid (db); diff --git a/src/SQLiteAsync.cs b/src/SQLiteAsync.cs index 93d27ff1..42ab4131 100644 --- a/src/SQLiteAsync.cs +++ b/src/SQLiteAsync.cs @@ -1470,5 +1470,106 @@ public void Dispose () } } } + + /// + /// Async wrapper for prepared statements. + /// + public class SQLiteAsyncPreparedStatement : IDisposable { + SQLiteConnectionWithLock Connection { get; set; } + + /// Get the inner, synchronous, statement; + public SQLitePreparedStatement Inner { get; private set; } + + /// + /// Creates a prepared statement given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method when return primitive values. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// The Connection. + /// + /// The fully escaped SQL. + /// + public SQLiteAsyncPreparedStatement(SQLiteAsyncConnection conn, string query) : this(conn.GetConnection(), query) { } + + /// + /// Creates a prepared statement given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method when return primitive values. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// The Connection. + /// + /// The fully escaped SQL. + /// + public SQLiteAsyncPreparedStatement(SQLiteConnectionWithLock conn, string query) + { + Connection = conn; + Inner = new SQLitePreparedStatement(Connection as SQLiteConnection, query); + } + + /// + /// Executes a prepared statement. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public Task ExecuteNonQueryAsync (params object[] args) + { + return WrapAsync(() => Inner.ExecuteNonQuery(args)); + } + + /// + /// Executes a prepared statement. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator (retrieved by calling GetEnumerator() on the result of this method) + /// will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public Task> ExecuteDeferredQueryAsync (params object[] args) + { + return WrapAsync>(() => Inner.ExecuteDeferredQuery(args).ToList()); + } + + /// + /// Executes a prepared statement. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The resultant object of this execution. + /// + public Task ExecuteScalarAsync (params object[] args) + { + return WrapAsync(() => Inner.ExecuteScalar(args)); + } + + /// + /// Dispose + /// + public void Dispose() => Inner.Dispose(); + + Task WrapAsync (Func exe) + { + return Task.Factory.StartNew (() => { + using (Connection.Lock ()) { + return exe(); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + } }