From 9f5ea216411221d026e99a57b1f42b7217828104 Mon Sep 17 00:00:00 2001 From: Bradley Grainger Date: Wed, 28 Sep 2016 21:39:44 -0700 Subject: [PATCH] Use synchronous I/O for sync methods. Fixes #62 --- .../MySqlClient/ConnectionPool.cs | 18 +-- .../MySqlClient/MySqlCommand.cs | 43 ++++--- .../MySqlClient/MySqlConnection.cs | 42 +++---- .../MySqlClient/MySqlDataReader.cs | 36 +++--- .../MySqlClient/MySqlTransaction.cs | 28 ++--- .../Serialization/IOBehavior.cs | 18 +++ .../Serialization/MySqlSession.cs | 89 +++++++++------ .../Serialization/PacketTransmitter.cs | 105 ++++++++++++------ .../Serialization/ProtocolErrorBehavior.cs | 18 +++ 9 files changed, 253 insertions(+), 144 deletions(-) create mode 100644 src/MySqlConnector/Serialization/IOBehavior.cs create mode 100644 src/MySqlConnector/Serialization/ProtocolErrorBehavior.cs diff --git a/src/MySqlConnector/MySqlClient/ConnectionPool.cs b/src/MySqlConnector/MySqlClient/ConnectionPool.cs index 833e62c4c..d8b19681f 100644 --- a/src/MySqlConnector/MySqlClient/ConnectionPool.cs +++ b/src/MySqlConnector/MySqlClient/ConnectionPool.cs @@ -8,7 +8,7 @@ namespace MySql.Data.MySqlClient { internal sealed class ConnectionPool { - public async Task GetSessionAsync(CancellationToken cancellationToken) + public async Task GetSessionAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -24,17 +24,17 @@ public async Task GetSessionAsync(CancellationToken cancellationTo // check for a pooled session if (m_sessions.TryDequeue(out session)) { - if (!await session.TryPingAsync(cancellationToken).ConfigureAwait(false)) + if (!await session.TryPingAsync(ioBehavior, cancellationToken).ConfigureAwait(false)) { // session is not valid - await session.DisposeAsync(cancellationToken).ConfigureAwait(false); + await session.DisposeAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } else { // session is valid, reset if supported if (m_resetConnections) { - await session.ResetConnectionAsync(m_userId, m_password, m_database, cancellationToken).ConfigureAwait(false); + await session.ResetConnectionAsync(m_userId, m_password, m_database, ioBehavior, cancellationToken).ConfigureAwait(false); } // pooled session is ready to be used; return it return session; @@ -42,7 +42,7 @@ public async Task GetSessionAsync(CancellationToken cancellationTo } session = new MySqlSession(this); - await session.ConnectAsync(m_servers, m_port, m_userId, m_password, m_database, cancellationToken).ConfigureAwait(false); + await session.ConnectAsync(m_servers, m_port, m_userId, m_password, m_database, ioBehavior, cancellationToken).ConfigureAwait(false); return session; } catch @@ -64,7 +64,7 @@ public void Return(MySqlSession session) } } - public async Task ClearAsync(CancellationToken cancellationToken) + public async Task ClearAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); try @@ -82,7 +82,7 @@ public async Task ClearAsync(CancellationToken cancellationToken) MySqlSession session; while (m_sessions.TryDequeue(out session)) { - tasks.Add(session.DisposeAsync(cancellationToken)); + tasks.Add(session.DisposeAsync(ioBehavior, cancellationToken)); } if (tasks.Count > 0) { @@ -116,12 +116,12 @@ public static ConnectionPool GetPool(MySqlConnectionStringBuilder csb) return pool; } - public static async Task ClearPoolsAsync(CancellationToken cancellationToken) + public static async Task ClearPoolsAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { var pools = new List(s_pools.Values); foreach (var pool in pools) - await pool.ClearAsync(cancellationToken).ConfigureAwait(false); + await pool.ClearAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } private ConnectionPool(IEnumerable servers, int port, string userId, string password, string database, diff --git a/src/MySqlConnector/MySqlClient/MySqlCommand.cs b/src/MySqlConnector/MySqlClient/MySqlCommand.cs index 218b5b58f..f1eecc1f4 100644 --- a/src/MySqlConnector/MySqlClient/MySqlCommand.cs +++ b/src/MySqlConnector/MySqlClient/MySqlCommand.cs @@ -52,11 +52,11 @@ public override void Cancel() throw new NotSupportedException("Use the Async overloads with a CancellationToken."); } - public override int ExecuteNonQuery() - => ExecuteNonQueryAsync(CancellationToken.None).GetAwaiter().GetResult(); + public override int ExecuteNonQuery() => + ExecuteNonQueryAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override object ExecuteScalar() - => ExecuteScalarAsync(CancellationToken.None).GetAwaiter().GetResult(); + public override object ExecuteScalar() => + ExecuteScalarAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); public override void Prepare() { @@ -102,38 +102,47 @@ protected override DbParameter CreateDbParameter() return new MySqlParameter(); } - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) - => ExecuteDbDataReaderAsync(behavior, CancellationToken.None).GetAwaiter().GetResult(); + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => + ExecuteReaderAsync(behavior, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task ExecuteNonQueryAsync(CancellationToken cancellationToken) + public override Task ExecuteNonQueryAsync(CancellationToken cancellationToken) => + ExecuteNonQueryAsync(IOBehavior.Asynchronous, cancellationToken); + + internal async Task ExecuteNonQueryAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { - using (var reader = await ExecuteReaderAsync(cancellationToken).ConfigureAwait(false)) + using (var reader = (MySqlDataReader) await ExecuteReaderAsync(CommandBehavior.Default, ioBehavior, cancellationToken).ConfigureAwait(false)) { do { - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + while (await reader.ReadAsync(ioBehavior, cancellationToken).ConfigureAwait(false)) { } - } while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false)); + } while (await reader.NextResultAsync(ioBehavior, cancellationToken).ConfigureAwait(false)); return reader.RecordsAffected; } } - public override async Task ExecuteScalarAsync(CancellationToken cancellationToken) + public override Task ExecuteScalarAsync(CancellationToken cancellationToken) => + ExecuteScalarAsync(IOBehavior.Asynchronous, cancellationToken); + + internal async Task ExecuteScalarAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { object result = null; - using (var reader = await ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow, cancellationToken).ConfigureAwait(false)) + using (var reader = (MySqlDataReader) await ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow, ioBehavior, cancellationToken).ConfigureAwait(false)) { do { - if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + if (await reader.ReadAsync(ioBehavior, cancellationToken).ConfigureAwait(false)) result = reader.GetValue(0); - } while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false)); + } while (await reader.NextResultAsync(ioBehavior, cancellationToken).ConfigureAwait(false)); } return result; } - protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) + protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) => + ExecuteReaderAsync(behavior, IOBehavior.Asynchronous, cancellationToken); + + internal async Task ExecuteReaderAsync(CommandBehavior behavior, IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyValid(); Connection.HasActiveReader = true; @@ -151,8 +160,8 @@ protected override async Task ExecuteDbDataReaderAsync(CommandBeha var preparer = new MySqlStatementPreparer(CommandText, m_parameterCollection, statementPreparerOptions); preparer.BindParameters(); var payload = new PayloadData(new ArraySegment(Payload.CreateEofStringPayload(CommandKind.Query, preparer.PreparedSql))); - await Session.SendAsync(payload, cancellationToken).ConfigureAwait(false); - reader = await MySqlDataReader.CreateAsync(this, behavior, cancellationToken).ConfigureAwait(false); + await Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + reader = await MySqlDataReader.CreateAsync(this, behavior, ioBehavior, cancellationToken).ConfigureAwait(false); return reader; } finally diff --git a/src/MySqlConnector/MySqlClient/MySqlConnection.cs b/src/MySqlConnector/MySqlClient/MySqlConnection.cs index 131449bf8..77a75d0eb 100644 --- a/src/MySqlConnector/MySqlClient/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlClient/MySqlConnection.cs @@ -24,15 +24,15 @@ public MySqlConnection(string connectionString) public new MySqlTransaction BeginTransaction() => (MySqlTransaction) base.BeginTransaction(); public Task BeginTransactionAsync(CancellationToken cancellationToken = default(CancellationToken)) => - BeginDbTransactionAsync(IsolationLevel.Unspecified, cancellationToken); + BeginDbTransactionAsync(IsolationLevel.Unspecified, IOBehavior.Asynchronous, cancellationToken); public Task BeginTransactionAsync(IsolationLevel isolationLevel, CancellationToken cancellationToken = default(CancellationToken)) => - BeginDbTransactionAsync(isolationLevel, cancellationToken); + BeginDbTransactionAsync(isolationLevel, IOBehavior.Asynchronous, cancellationToken); protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => - BeginDbTransactionAsync(isolationLevel).GetAwaiter().GetResult(); + BeginDbTransactionAsync(isolationLevel, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - private async Task BeginDbTransactionAsync(IsolationLevel isolationLevel, CancellationToken cancellationToken = default(CancellationToken)) + private async Task BeginDbTransactionAsync(IsolationLevel isolationLevel, IOBehavior ioBehavior, CancellationToken cancellationToken) { if (State != ConnectionState.Open) throw new InvalidOperationException("Connection is not open."); @@ -67,7 +67,7 @@ protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLeve } using (var cmd = new MySqlCommand("set session transaction isolation level " + isolationLevelValue + "; start transaction;", this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioBehavior, cancellationToken).ConfigureAwait(false); var transaction = new MySqlTransaction(this, isolationLevel); CurrentTransaction = transaction; @@ -88,9 +88,12 @@ public override void ChangeDatabase(string databaseName) throw new NotImplementedException(); } - public override void Open() => OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); + public override void Open() => OpenAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task OpenAsync(CancellationToken cancellationToken) + public override Task OpenAsync(CancellationToken cancellationToken) => + OpenAsync(IOBehavior.Asynchronous, cancellationToken); + + private async Task OpenAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyNotDisposed(); if (State != ConnectionState.Closed) @@ -104,7 +107,7 @@ public override async Task OpenAsync(CancellationToken cancellationToken) try { - m_session = await CreateSessionAsync(cancellationToken).ConfigureAwait(false); + m_session = await CreateSessionAsync(ioBehavior, cancellationToken).ConfigureAwait(false); m_hasBeenOpened = true; SetState(ConnectionState.Open); @@ -147,20 +150,21 @@ public override string ConnectionString public override string ServerVersion => m_session.ServerVersion.OriginalString; - public static void ClearPool(MySqlConnection connection) => ClearPoolAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); - public static void ClearAllPools() => ClearAllPoolsAsync(CancellationToken.None).GetAwaiter().GetResult(); - public static Task ClearPoolAsync(MySqlConnection connection) => ClearPoolAsync(connection, CancellationToken.None); - public static Task ClearAllPoolsAsync() => ClearAllPoolsAsync(CancellationToken.None); - public static Task ClearAllPoolsAsync(CancellationToken cancellationToken) => ConnectionPool.ClearPoolsAsync(cancellationToken); + public static void ClearPool(MySqlConnection connection) => ClearPoolAsync(connection, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); + public static Task ClearPoolAsync(MySqlConnection connection) => ClearPoolAsync(connection, IOBehavior.Asynchronous, CancellationToken.None); + public static Task ClearPoolAsync(MySqlConnection connection, CancellationToken cancellationToken) => ClearPoolAsync(connection, IOBehavior.Asynchronous, cancellationToken); + public static void ClearAllPools() => ConnectionPool.ClearPoolsAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); + public static Task ClearAllPoolsAsync() => ConnectionPool.ClearPoolsAsync(IOBehavior.Asynchronous, CancellationToken.None); + public static Task ClearAllPoolsAsync(CancellationToken cancellationToken) => ConnectionPool.ClearPoolsAsync(IOBehavior.Asynchronous, cancellationToken); - public static async Task ClearPoolAsync(MySqlConnection connection, CancellationToken cancellationToken) + private static async Task ClearPoolAsync(MySqlConnection connection, IOBehavior ioBehavior, CancellationToken cancellationToken) { if (connection == null) throw new ArgumentNullException(nameof(connection)); var pool = ConnectionPool.GetPool(connection.m_connectionStringBuilder); if (pool != null) - await pool.ClearAsync(cancellationToken).ConfigureAwait(false); + await pool.ClearAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } protected override DbCommand CreateDbCommand() => new MySqlCommand(this, CurrentTransaction); @@ -203,7 +207,7 @@ internal MySqlSession Session internal bool ConvertZeroDateTime => m_connectionStringBuilder.ConvertZeroDateTime; internal bool OldGuids => m_connectionStringBuilder.OldGuids; - private async Task CreateSessionAsync(CancellationToken cancellationToken) + private async Task CreateSessionAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { var connectTimeout = m_connectionStringBuilder.ConnectionTimeout == 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromSeconds(checked((int) m_connectionStringBuilder.ConnectionTimeout)); using (var timeoutSource = new CancellationTokenSource(connectTimeout)) @@ -217,13 +221,13 @@ private async Task CreateSessionAsync(CancellationToken cancellati var pool = ConnectionPool.GetPool(m_connectionStringBuilder); // this returns an open session - return await pool.GetSessionAsync(linkedSource.Token).ConfigureAwait(false); + return await pool.GetSessionAsync(ioBehavior, linkedSource.Token).ConfigureAwait(false); } else { var session = new MySqlSession(null); await session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port, m_connectionStringBuilder.UserID, - m_connectionStringBuilder.Password, m_connectionStringBuilder.Database, linkedSource.Token).ConfigureAwait(false); + m_connectionStringBuilder.Password, m_connectionStringBuilder.Database, ioBehavior, linkedSource.Token).ConfigureAwait(false); return session; } } @@ -264,7 +268,7 @@ private void DoClose() if (m_connectionStringBuilder.Pooling) m_session.ReturnToPool(); else - m_session.DisposeAsync(CancellationToken.None).GetAwaiter().GetResult(); + m_session.DisposeAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); m_session = null; } SetState(ConnectionState.Closed); diff --git a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs index b9d5ac2b7..cdfd0a039 100644 --- a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs +++ b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs @@ -12,17 +12,18 @@ namespace MySql.Data.MySqlClient { public sealed class MySqlDataReader : DbDataReader { - public override bool NextResult() - { - return NextResultAsync(CancellationToken.None).GetAwaiter().GetResult(); - } + public override bool NextResult() => + NextResultAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task NextResultAsync(CancellationToken cancellationToken) + public override Task NextResultAsync(CancellationToken cancellationToken) => + NextResultAsync(IOBehavior.Asynchronous, cancellationToken); + + internal async Task NextResultAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyNotDisposed(); while (m_state == State.ReadingRows || m_state == State.ReadResultSetHeader) - await ReadAsync(cancellationToken).ConfigureAwait(false); + await ReadAsync(ioBehavior, cancellationToken).ConfigureAwait(false); var oldState = m_state; Reset(); @@ -31,17 +32,20 @@ public override async Task NextResultAsync(CancellationToken cancellationT if (oldState != State.HasMoreData) throw new InvalidOperationException("Invalid state: {0}".FormatInvariant(oldState)); - await ReadResultSetHeaderAsync(cancellationToken).ConfigureAwait(false); + await ReadResultSetHeaderAsync(ioBehavior, cancellationToken).ConfigureAwait(false); return true; } public override bool Read() { VerifyNotDisposed(); - return ReadAsync(CancellationToken.None).GetAwaiter().GetResult(); + return ReadAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); } - public override Task ReadAsync(CancellationToken cancellationToken) + public override Task ReadAsync(CancellationToken cancellationToken) => + ReadAsync(IOBehavior.Asynchronous, cancellationToken); + + internal Task ReadAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyNotDisposed(); @@ -51,7 +55,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) if (m_state != State.AlreadyReadFirstRow) { - var payloadTask = m_session.ReceiveReplyAsync(cancellationToken); + var payloadTask = m_session.ReceiveReplyAsync(ioBehavior, cancellationToken); if (payloadTask.IsCompletedSuccessfully) return ReadAsyncRemainder(payloadTask.Result) ? s_trueTask : s_falseTask; return ReadAsyncAwaited(payloadTask.AsTask()); @@ -645,10 +649,10 @@ private void DoClose() } } - internal static async Task CreateAsync(MySqlCommand command, CommandBehavior behavior, CancellationToken cancellationToken) + internal static async Task CreateAsync(MySqlCommand command, CommandBehavior behavior, IOBehavior ioBehavior, CancellationToken cancellationToken) { var dataReader = new MySqlDataReader(command, behavior); - await dataReader.ReadResultSetHeaderAsync(cancellationToken).ConfigureAwait(false); + await dataReader.ReadResultSetHeaderAsync(ioBehavior, cancellationToken).ConfigureAwait(false); return dataReader; } @@ -709,11 +713,11 @@ private MySqlDataReader(MySqlCommand command, CommandBehavior behavior) private MySqlConnection Connection => m_command.Connection; - private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken) + private async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { while (true) { - var payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + var payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); var firstByte = payload.HeaderByte; if (firstByte == OkPayload.Signature) @@ -740,11 +744,11 @@ private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken) for (var column = 0; column < m_columnDefinitions.Length; column++) { - payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); m_columnDefinitions[column] = ColumnDefinitionPayload.Create(payload); } - payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); EofPayload.Create(payload); m_command.LastInsertedId = -1; diff --git a/src/MySqlConnector/MySqlClient/MySqlTransaction.cs b/src/MySqlConnector/MySqlClient/MySqlTransaction.cs index fc2e338cf..210d92b4a 100644 --- a/src/MySqlConnector/MySqlClient/MySqlTransaction.cs +++ b/src/MySqlConnector/MySqlClient/MySqlTransaction.cs @@ -3,17 +3,19 @@ using System.Data.Common; using System.Threading; using System.Threading.Tasks; +using MySql.Data.Serialization; namespace MySql.Data.MySqlClient { public class MySqlTransaction : DbTransaction { - public override void Commit() - { - CommitAsync().GetAwaiter().GetResult(); - } + public override void Commit() => + CommitAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public async Task CommitAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task CommitAsync(CancellationToken cancellationToken = default(CancellationToken)) => + CommitAsync(IOBehavior.Asynchronous, cancellationToken); + + internal async Task CommitAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyNotDisposed(); if (m_isFinished) @@ -22,7 +24,7 @@ public override void Commit() if (m_connection.CurrentTransaction == this) { using (var cmd = new MySqlCommand("commit", m_connection, this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioBehavior, cancellationToken).ConfigureAwait(false); m_connection.CurrentTransaction = null; m_isFinished = true; } @@ -36,12 +38,13 @@ public override void Commit() } } - public override void Rollback() - { - RollbackAsync().GetAwaiter().GetResult(); - } + public override void Rollback() => + RollbackAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); + + public Task RollbackAsync(CancellationToken cancellationToken = default(CancellationToken)) => + RollbackAsync(IOBehavior.Asynchronous, cancellationToken); - public async Task RollbackAsync(CancellationToken cancellationToken = default(CancellationToken)) + internal async Task RollbackAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyNotDisposed(); if (m_isFinished) @@ -50,7 +53,7 @@ public override void Rollback() if (m_connection.CurrentTransaction == this) { using (var cmd = new MySqlCommand("rollback", m_connection, this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioBehavior, cancellationToken).ConfigureAwait(false); m_connection.CurrentTransaction = null; m_isFinished = true; } @@ -88,7 +91,6 @@ protected override void Dispose(bool disposing) } } - internal MySqlTransaction(MySqlConnection connection, IsolationLevel isolationLevel) { m_connection = connection; diff --git a/src/MySqlConnector/Serialization/IOBehavior.cs b/src/MySqlConnector/Serialization/IOBehavior.cs new file mode 100644 index 000000000..9f62eb030 --- /dev/null +++ b/src/MySqlConnector/Serialization/IOBehavior.cs @@ -0,0 +1,18 @@ +namespace MySql.Data.Serialization +{ + /// + /// Specifies whether to perform synchronous or asynchronous I/O. + /// + internal enum IOBehavior + { + /// + /// Use synchronous I/O. + /// + Synchronous, + + /// + /// Use asynchronous I/O. + /// + Asynchronous, + } +} diff --git a/src/MySqlConnector/Serialization/MySqlSession.cs b/src/MySqlConnector/Serialization/MySqlSession.cs index 46b76922c..1a2be681f 100644 --- a/src/MySqlConnector/Serialization/MySqlSession.cs +++ b/src/MySqlConnector/Serialization/MySqlSession.cs @@ -23,14 +23,14 @@ public MySqlSession(ConnectionPool pool) public void ReturnToPool() => Pool?.Return(this); - public async Task DisposeAsync(CancellationToken cancellationToken) + public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { if (m_transmitter != null) { try { - await m_transmitter.SendAsync(QuitPayload.Create(), cancellationToken).ConfigureAwait(false); - await m_transmitter.TryReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await m_transmitter.SendAsync(QuitPayload.Create(), ioBehavior, cancellationToken).ConfigureAwait(false); + await m_transmitter.TryReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } catch (SocketException) { @@ -54,13 +54,13 @@ public async Task DisposeAsync(CancellationToken cancellationToken) m_state = State.Closed; } - public async Task ConnectAsync(IEnumerable hosts, int port, string userId, string password, string database, CancellationToken cancellationToken) + public async Task ConnectAsync(IEnumerable hosts, int port, string userId, string password, string database, IOBehavior ioBehavior, CancellationToken cancellationToken) { - var connected = await OpenSocketAsync(hosts, port, cancellationToken).ConfigureAwait(false); + var connected = await OpenSocketAsync(hosts, port, ioBehavior, cancellationToken).ConfigureAwait(false); if (!connected) throw new MySqlException("Unable to connect to any of the specified MySQL hosts."); - var payload = await ReceiveAsync(cancellationToken).ConfigureAwait(false); + var payload = await ReceiveAsync(ioBehavior, cancellationToken).ConfigureAwait(false); var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count); var initialHandshake = new InitialHandshakePacket(reader); if (initialHandshake.AuthPluginName != "mysql_native_password") @@ -70,22 +70,22 @@ public async Task ConnectAsync(IEnumerable hosts, int port, string userI var response = HandshakeResponse41Packet.Create(initialHandshake, userId, password, database); payload = new PayloadData(new ArraySegment(response)); - await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false); - await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } - public async Task ResetConnectionAsync(string userId, string password, string database, CancellationToken cancellationToken) + public async Task ResetConnectionAsync(string userId, string password, string database, IOBehavior ioBehavior, CancellationToken cancellationToken) { if (ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) { - await SendAsync(ResetConnectionPayload.Create(), cancellationToken).ConfigureAwait(false); - var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendAsync(ResetConnectionPayload.Create(), ioBehavior, cancellationToken).ConfigureAwait(false); + var payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); OkPayload.Create(payload); // the "reset connection" packet also resets the connection charset, so we need to change that back to our default payload = new PayloadData(new ArraySegment(Payload.CreateEofStringPayload(CommandKind.Query, "SET NAMES utf8mb4;"))); - await SendAsync(payload, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); OkPayload.Create(payload); } else @@ -93,8 +93,8 @@ public async Task ResetConnectionAsync(string userId, string password, string da // optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL) var hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(AuthPluginData, 0, password); var payload = ChangeUserPayload.Create(userId, hashedPassword, database); - await SendAsync(payload, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature) { // if the server didn't support the hashed password; rehash with the new challenge @@ -103,19 +103,19 @@ public async Task ResetConnectionAsync(string userId, string password, string da throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported."); hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(switchRequest.Data, 0, password); payload = new PayloadData(new ArraySegment(hashedPassword)); - await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); } OkPayload.Create(payload); } } - public async Task TryPingAsync(CancellationToken cancellationToken) + public async Task TryPingAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { - await SendAsync(PingPayload.Create(), cancellationToken).ConfigureAwait(false); + await SendAsync(PingPayload.Create(), ioBehavior, cancellationToken).ConfigureAwait(false); try { - var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + var payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); OkPayload.Create(payload); return true; } @@ -130,20 +130,20 @@ public async Task TryPingAsync(CancellationToken cancellationToken) } // Starts a new conversation with the server by sending the first packet. - public Task SendAsync(PayloadData payload, CancellationToken cancellationToken) - => TryAsync(m_transmitter.SendAsync, payload, cancellationToken); + public Task SendAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) + => TryAsync(m_transmitter.SendAsync, payload, ioBehavior, cancellationToken); // Starts a new conversation with the server by receiving the first packet. - public ValueTask ReceiveAsync(CancellationToken cancellationToken) - => TryAsync(m_transmitter.ReceiveAsync, cancellationToken); + public ValueTask ReceiveAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + => TryAsync(m_transmitter.ReceiveAsync, ioBehavior, cancellationToken); // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask ReceiveReplyAsync(CancellationToken cancellationToken) - => TryAsync(m_transmitter.ReceiveReplyAsync, cancellationToken); + public ValueTask ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + => TryAsync(m_transmitter.ReceiveReplyAsync, ioBehavior, cancellationToken); // Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'. - public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken) - => TryAsync(m_transmitter.SendReplyAsync, payload, cancellationToken); + public Task SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) + => TryAsync(m_transmitter.SendReplyAsync, payload, ioBehavior, cancellationToken); private void VerifyConnected() { @@ -153,14 +153,26 @@ private void VerifyConnected() throw new InvalidOperationException("MySqlSession is not connected."); } - private async Task OpenSocketAsync(IEnumerable hostnames, int port, CancellationToken cancellationToken) + private async Task OpenSocketAsync(IEnumerable hostnames, int port, IOBehavior ioBehavior, CancellationToken cancellationToken) { foreach (var hostname in hostnames) { IPAddress[] ipAddresses; try { +#if NETSTANDARD1_3 + // Dns.GetHostAddresses isn't available until netstandard 2.0: https://github.com/dotnet/corefx/pull/11950 ipAddresses = await Dns.GetHostAddressesAsync(hostname).ConfigureAwait(false); +#else + if (ioBehavior == IOBehavior.Asynchronous) + { + ipAddresses = await Dns.GetHostAddressesAsync(hostname).ConfigureAwait(false); + } + else + { + ipAddresses = Dns.GetHostAddresses(hostname); + } +#endif } catch (SocketException) { @@ -179,11 +191,18 @@ private async Task OpenSocketAsync(IEnumerable hostnames, int port { try { + if (ioBehavior == IOBehavior.Asynchronous) + { #if NETSTANDARD1_3 - await socket.ConnectAsync(ipAddress, port).ConfigureAwait(false); + await socket.ConnectAsync(ipAddress, port).ConfigureAwait(false); #else - await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ipAddress, port, null).ConfigureAwait(false); + await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ipAddress, port, null).ConfigureAwait(false); #endif + } + else + { + socket.Connect(ipAddress, port); + } } catch (ObjectDisposedException ex) when (cancellationToken.IsCancellationRequested) { @@ -206,10 +225,10 @@ private async Task OpenSocketAsync(IEnumerable hostnames, int port return false; } - private Task TryAsync(Func func, TArg arg, CancellationToken cancellationToken) + private Task TryAsync(Func func, TArg arg, IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyConnected(); - var task = func(arg, cancellationToken); + var task = func(arg, ioBehavior, cancellationToken); if (task.Status == TaskStatus.RanToCompletion) return task; @@ -225,10 +244,10 @@ private void TryAsyncContinuation(Task task) } } - private ValueTask TryAsync(Func> func, CancellationToken cancellationToken) + private ValueTask TryAsync(Func> func, IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyConnected(); - var task = func(cancellationToken); + var task = func(ioBehavior, cancellationToken); if (task.IsCompletedSuccessfully) { if (task.Result.HeaderByte != ErrorPayload.Signature) diff --git a/src/MySqlConnector/Serialization/PacketTransmitter.cs b/src/MySqlConnector/Serialization/PacketTransmitter.cs index 27969f990..6e980454e 100644 --- a/src/MySqlConnector/Serialization/PacketTransmitter.cs +++ b/src/MySqlConnector/Serialization/PacketTransmitter.cs @@ -18,32 +18,32 @@ public PacketTransmitter(Socket socket) } // Starts a new conversation with the server by sending the first packet. - public Task SendAsync(PayloadData payload, CancellationToken cancellationToken) + public Task SendAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) { m_sequenceId = 0; - return DoSendAsync(payload, cancellationToken); + return DoSendAsync(payload, ioBehavior, cancellationToken); } // Starts a new conversation with the server by receiving the first packet. - public ValueTask ReceiveAsync(CancellationToken cancellationToken) + public ValueTask ReceiveAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { m_sequenceId = 0; - return DoReceiveAsync(cancellationToken); + return DoReceiveAsync(ProtocolErrorBehavior.Throw, ioBehavior, cancellationToken); } // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask ReceiveReplyAsync(CancellationToken cancellationToken) - => DoReceiveAsync(cancellationToken); + public ValueTask ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + => DoReceiveAsync(ProtocolErrorBehavior.Throw, ioBehavior, cancellationToken); // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask TryReceiveReplyAsync(CancellationToken cancellationToken) - => DoReceiveAsync(cancellationToken, optional: true); + public ValueTask TryReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + => DoReceiveAsync(ProtocolErrorBehavior.Ignore, ioBehavior, cancellationToken); // Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'. - public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken) - => DoSendAsync(payload, cancellationToken); + public Task SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) + => DoSendAsync(payload, ioBehavior, cancellationToken); - private async Task DoSendAsync(PayloadData payload, CancellationToken cancellationToken) + private async Task DoSendAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) { var bytesSent = 0; var data = payload.ArraySegment; @@ -61,23 +61,39 @@ private async Task DoSendAsync(PayloadData payload, CancellationToken cancellati if (bytesToSend <= m_buffer.Length - 4) { Buffer.BlockCopy(data.Array, data.Offset + bytesSent, m_buffer, 4, bytesToSend); - m_socketAwaitable.EventArgs.SetBuffer(0, bytesToSend + 4); - await m_socket.SendAsync(m_socketAwaitable); + var count = bytesToSend + 4; + if (ioBehavior == IOBehavior.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(0, count); + await m_socket.SendAsync(m_socketAwaitable); + } + else + { + m_socket.Send(m_buffer, 0, count, SocketFlags.None); + } } else { - m_socketAwaitable.EventArgs.SetBuffer(null, 0, 0); - m_socketAwaitable.EventArgs.BufferList = new[] { new ArraySegment(m_buffer, 0, 4), new ArraySegment(data.Array, data.Offset + bytesSent, bytesToSend) }; - await m_socket.SendAsync(m_socketAwaitable); - m_socketAwaitable.EventArgs.BufferList = null; - m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + if (ioBehavior == IOBehavior.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(null, 0, 0); + m_socketAwaitable.EventArgs.BufferList = new[] { new ArraySegment(m_buffer, 0, 4), new ArraySegment(data.Array, data.Offset + bytesSent, bytesToSend) }; + await m_socket.SendAsync(m_socketAwaitable); + m_socketAwaitable.EventArgs.BufferList = null; + m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + } + else + { + m_socket.Send(m_buffer, 0, 4, SocketFlags.None); + m_socket.Send(data.Array, data.Offset + bytesSent, bytesToSend, SocketFlags.None); + } } bytesSent += bytesToSend; } while (bytesToSend == c_maxPacketSize); } - private ValueTask DoReceiveAsync(CancellationToken cancellationToken, bool optional = false) + private ValueTask DoReceiveAsync(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior, CancellationToken cancellationToken) { if (m_end - m_offset > 4) { @@ -86,7 +102,7 @@ private ValueTask DoReceiveAsync(CancellationToken cancellationToke { if (m_buffer[m_offset + 3] != (byte) (m_sequenceId & 0xFF)) { - if (optional) + if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore) return new ValueTask(default(PayloadData)); throw new InvalidOperationException("Packet received out-of-order. Expected {0}; got {1}.".FormatInvariant(m_sequenceId & 0xFF, m_buffer[3])); } @@ -100,13 +116,13 @@ private ValueTask DoReceiveAsync(CancellationToken cancellationToke } } - return new ValueTask(DoReceiveAsync2(cancellationToken, optional)); + return new ValueTask(DoReceiveAsync2(protocolErrorBehavior, ioBehavior, cancellationToken)); } - private async Task DoReceiveAsync2(CancellationToken cancellationToken, bool optional = false) + private async Task DoReceiveAsync2(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior, CancellationToken cancellationToken) { // common case: the payload is contained within one packet - var payload = await ReceivePacketAsync(cancellationToken, optional).ConfigureAwait(false); + var payload = await ReceivePacketAsync(protocolErrorBehavior, ioBehavior, cancellationToken).ConfigureAwait(false); if (payload == null || payload.ArraySegment.Count != c_maxPacketSize) return payload; @@ -117,7 +133,7 @@ private async Task DoReceiveAsync2(CancellationToken cancellationTo do { - payload = await ReceivePacketAsync(cancellationToken, optional).ConfigureAwait(false); + payload = await ReceivePacketAsync(protocolErrorBehavior, ioBehavior, cancellationToken).ConfigureAwait(false); var oldLength = payloadBytes.Length; Array.Resize(ref payloadBytes, payloadBytes.Length + payload.ArraySegment.Count); @@ -127,7 +143,7 @@ private async Task DoReceiveAsync2(CancellationToken cancellationTo return new PayloadData(new ArraySegment(payloadBytes)); } - private async Task ReceivePacketAsync(CancellationToken cancellationToken, bool optional) + private async Task ReceivePacketAsync(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior, CancellationToken cancellationToken) { if (m_end - m_offset < 4) { @@ -142,12 +158,21 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio int count = m_buffer.Length - m_end; while (m_end - m_offset < 4) { - m_socketAwaitable.EventArgs.SetBuffer(offset, count); - await m_socket.ReceiveAsync(m_socketAwaitable); - int bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + int bytesRead; + if (ioBehavior == IOBehavior.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(offset, count); + await m_socket.ReceiveAsync(m_socketAwaitable); + bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + } + else + { + bytesRead = m_socket.Receive(m_buffer, offset, count, SocketFlags.None); + } + if (bytesRead <= 0) { - if (optional) + if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore) return null; throw new EndOfStreamException(); } @@ -160,7 +185,7 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio int payloadLength = (int) SerializationUtility.ReadUInt32(m_buffer, m_offset, 3); if (m_buffer[m_offset + 3] != (byte) (m_sequenceId & 0xFF)) { - if (optional) + if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore) return null; throw new InvalidOperationException("Packet received out-of-order. Expected {0}; got {1}.".FormatInvariant(m_sequenceId & 0xFF, m_buffer[3])); } @@ -179,7 +204,8 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio if (payloadLength > m_buffer.Length) { readData = new byte[payloadLength]; - m_socketAwaitable.EventArgs.SetBuffer(readData, 0, 0); + if (ioBehavior == IOBehavior.Asynchronous) + m_socketAwaitable.EventArgs.SetBuffer(readData, 0, 0); } Buffer.BlockCopy(m_buffer, m_offset, readData, 0, m_end - m_offset); m_end -= m_offset; @@ -190,9 +216,17 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio count = readData.Length - m_end; while (m_end < payloadLength) { - m_socketAwaitable.EventArgs.SetBuffer(offset, count); - await m_socket.ReceiveAsync(m_socketAwaitable); - int bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + int bytesRead; + if (ioBehavior == IOBehavior.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(offset, count); + await m_socket.ReceiveAsync(m_socketAwaitable); + bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + } + else + { + bytesRead = m_socket.Receive(readData, offset, count, SocketFlags.None); + } if (bytesRead <= 0) throw new EndOfStreamException(); offset += bytesRead; @@ -203,7 +237,8 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio // switch back to original buffer if a larger one was allocated if (payloadLength > m_buffer.Length) { - m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + if (ioBehavior == IOBehavior.Asynchronous) + m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); m_end = 0; } diff --git a/src/MySqlConnector/Serialization/ProtocolErrorBehavior.cs b/src/MySqlConnector/Serialization/ProtocolErrorBehavior.cs new file mode 100644 index 000000000..d78b983f6 --- /dev/null +++ b/src/MySqlConnector/Serialization/ProtocolErrorBehavior.cs @@ -0,0 +1,18 @@ +namespace MySql.Data.Serialization +{ + /// + /// Specifies how to handle protocol errors. + /// + internal enum ProtocolErrorBehavior + { + /// + /// Throw an exception when there is a protocol error. This is the default. + /// + Throw, + + /// + /// Ignore any protocol errors. + /// + Ignore, + } +}