diff --git a/src/MySqlConnector/IOKind.cs b/src/MySqlConnector/IOKind.cs new file mode 100644 index 000000000..52e503a90 --- /dev/null +++ b/src/MySqlConnector/IOKind.cs @@ -0,0 +1,19 @@ +namespace MySql.Data +{ + /// + /// Specifies whether to perform synchronous or asynchronous I/O. + /// + internal enum IOKind + { + /// + /// Use synchronous I/O. + /// + Synchronous, + + /// + /// Use asynchronous I/O. + /// + + Asynchronous, + } +} diff --git a/src/MySqlConnector/MySqlClient/ConnectionPool.cs b/src/MySqlConnector/MySqlClient/ConnectionPool.cs index 1c5a38e99..24a0eb547 100644 --- a/src/MySqlConnector/MySqlClient/ConnectionPool.cs +++ b/src/MySqlConnector/MySqlClient/ConnectionPool.cs @@ -1,5 +1,4 @@ -using System; -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -9,8 +8,7 @@ namespace MySql.Data.MySqlClient { internal sealed class ConnectionPool { - - public async Task GetSessionAsync(CancellationToken cancellationToken) + public async Task GetSessionAsync(IOKind ioKind, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -26,17 +24,17 @@ public async Task GetSessionAsync(CancellationToken cancellationTo // check for a pooled session if (m_sessions.TryDequeue(out session)) { - if (!await session.TryPingAsync(cancellationToken).ConfigureAwait(false)) + if (!await session.TryPingAsync(ioKind, cancellationToken).ConfigureAwait(false)) { // session is not valid - await session.DisposeAsync(cancellationToken).ConfigureAwait(false); + await session.DisposeAsync(ioKind, cancellationToken).ConfigureAwait(false); } else { // session is valid, reset if supported if (m_resetConnections) { - await session.ResetConnectionAsync(m_userId, m_password, m_database, cancellationToken).ConfigureAwait(false); + await session.ResetConnectionAsync(m_userId, m_password, m_database, ioKind, cancellationToken).ConfigureAwait(false); } // pooled session is ready to be used; return it return session; @@ -44,7 +42,7 @@ public async Task GetSessionAsync(CancellationToken cancellationTo } session = new MySqlSession(this); - await session.ConnectAsync(m_servers, m_port, m_userId, m_password, m_database, m_connectionTimeout, cancellationToken).ConfigureAwait(false); + await session.ConnectAsync(m_servers, m_port, m_userId, m_password, m_database, m_connectionTimeout, ioKind, cancellationToken).ConfigureAwait(false); return session; } catch @@ -66,7 +64,7 @@ public void Return(MySqlSession session) } } - public async Task ClearAsync(CancellationToken cancellationToken) + public async Task ClearAsync(IOKind ioKind, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); try @@ -84,7 +82,7 @@ public async Task ClearAsync(CancellationToken cancellationToken) MySqlSession session; while (m_sessions.TryDequeue(out session)) { - tasks.Add(session.DisposeAsync(cancellationToken)); + tasks.Add(session.DisposeAsync(ioKind, cancellationToken)); } if (tasks.Count > 0) { @@ -118,12 +116,12 @@ public static ConnectionPool GetPool(MySqlConnectionStringBuilder csb) return pool; } - public static async Task ClearPoolsAsync(CancellationToken cancellationToken) + public static async Task ClearPoolsAsync(IOKind ioKind, CancellationToken cancellationToken) { var pools = new List(s_pools.Values); foreach (var pool in pools) - await pool.ClearAsync(cancellationToken).ConfigureAwait(false); + await pool.ClearAsync(ioKind, cancellationToken).ConfigureAwait(false); } private ConnectionPool(IEnumerable servers, int port, string userId, string password, string database, int connectionTimeout, diff --git a/src/MySqlConnector/MySqlClient/MySqlCommand.cs b/src/MySqlConnector/MySqlClient/MySqlCommand.cs index 218b5b58f..58bbdcf8a 100644 --- a/src/MySqlConnector/MySqlClient/MySqlCommand.cs +++ b/src/MySqlConnector/MySqlClient/MySqlCommand.cs @@ -53,10 +53,10 @@ public override void Cancel() } public override int ExecuteNonQuery() - => ExecuteNonQueryAsync(CancellationToken.None).GetAwaiter().GetResult(); + => ExecuteNonQueryAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); public override object ExecuteScalar() - => ExecuteScalarAsync(CancellationToken.None).GetAwaiter().GetResult(); + => ExecuteScalarAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); public override void Prepare() { @@ -103,37 +103,46 @@ protected override DbParameter CreateDbParameter() } protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) - => ExecuteDbDataReaderAsync(behavior, CancellationToken.None).GetAwaiter().GetResult(); + => ExecuteReaderAsync(behavior, IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task ExecuteNonQueryAsync(CancellationToken cancellationToken) + public override Task ExecuteNonQueryAsync(CancellationToken cancellationToken) => + ExecuteNonQueryAsync(IOKind.Asynchronous, cancellationToken); + + internal async Task ExecuteNonQueryAsync(IOKind ioKind, CancellationToken cancellationToken) { - using (var reader = await ExecuteReaderAsync(cancellationToken).ConfigureAwait(false)) + using (var reader = (MySqlDataReader) await ExecuteReaderAsync(CommandBehavior.Default, ioKind, cancellationToken).ConfigureAwait(false)) { do { - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + while (await reader.ReadAsync(ioKind, cancellationToken).ConfigureAwait(false)) { } - } while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false)); + } while (await reader.NextResultAsync(ioKind, cancellationToken).ConfigureAwait(false)); return reader.RecordsAffected; } } - public override async Task ExecuteScalarAsync(CancellationToken cancellationToken) + public override Task ExecuteScalarAsync(CancellationToken cancellationToken) => + ExecuteScalarAsync(IOKind.Asynchronous, cancellationToken); + + internal async Task ExecuteScalarAsync(IOKind ioKind, CancellationToken cancellationToken) { object result = null; - using (var reader = await ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow, cancellationToken).ConfigureAwait(false)) + using (var reader = (MySqlDataReader) await ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow, ioKind, cancellationToken).ConfigureAwait(false)) { do { - if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + if (await reader.ReadAsync(ioKind, cancellationToken).ConfigureAwait(false)) result = reader.GetValue(0); - } while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false)); + } while (await reader.NextResultAsync(ioKind, cancellationToken).ConfigureAwait(false)); } return result; } - protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) + protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) => + ExecuteReaderAsync(behavior, IOKind.Asynchronous, cancellationToken); + + internal async Task ExecuteReaderAsync(CommandBehavior behavior, IOKind ioKind, CancellationToken cancellationToken) { VerifyValid(); Connection.HasActiveReader = true; @@ -151,8 +160,8 @@ protected override async Task ExecuteDbDataReaderAsync(CommandBeha var preparer = new MySqlStatementPreparer(CommandText, m_parameterCollection, statementPreparerOptions); preparer.BindParameters(); var payload = new PayloadData(new ArraySegment(Payload.CreateEofStringPayload(CommandKind.Query, preparer.PreparedSql))); - await Session.SendAsync(payload, cancellationToken).ConfigureAwait(false); - reader = await MySqlDataReader.CreateAsync(this, behavior, cancellationToken).ConfigureAwait(false); + await Session.SendAsync(payload, ioKind, cancellationToken).ConfigureAwait(false); + reader = await MySqlDataReader.CreateAsync(this, behavior, ioKind, cancellationToken).ConfigureAwait(false); return reader; } finally diff --git a/src/MySqlConnector/MySqlClient/MySqlConnection.cs b/src/MySqlConnector/MySqlClient/MySqlConnection.cs index 2ab9417df..6a2caae9a 100644 --- a/src/MySqlConnector/MySqlClient/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlClient/MySqlConnection.cs @@ -24,15 +24,15 @@ public MySqlConnection(string connectionString) public new MySqlTransaction BeginTransaction() => (MySqlTransaction) base.BeginTransaction(); public Task BeginTransactionAsync(CancellationToken cancellationToken = default(CancellationToken)) => - BeginDbTransactionAsync(IsolationLevel.Unspecified, cancellationToken); + BeginDbTransactionAsync(IsolationLevel.Unspecified, IOKind.Asynchronous, cancellationToken); public Task BeginTransactionAsync(IsolationLevel isolationLevel, CancellationToken cancellationToken = default(CancellationToken)) => - BeginDbTransactionAsync(isolationLevel, cancellationToken); + BeginDbTransactionAsync(isolationLevel, IOKind.Asynchronous, cancellationToken); protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => - BeginDbTransactionAsync(isolationLevel).GetAwaiter().GetResult(); + BeginDbTransactionAsync(isolationLevel, IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - private async Task BeginDbTransactionAsync(IsolationLevel isolationLevel, CancellationToken cancellationToken = default(CancellationToken)) + private async Task BeginDbTransactionAsync(IsolationLevel isolationLevel, IOKind ioKind, CancellationToken cancellationToken) { if (State != ConnectionState.Open) throw new InvalidOperationException("Connection is not open."); @@ -67,7 +67,7 @@ protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLeve } using (var cmd = new MySqlCommand("set session transaction isolation level " + isolationLevelValue + "; start transaction;", this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioKind, cancellationToken).ConfigureAwait(false); var transaction = new MySqlTransaction(this, isolationLevel); CurrentTransaction = transaction; @@ -88,9 +88,12 @@ public override void ChangeDatabase(string databaseName) throw new NotImplementedException(); } - public override void Open() => OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); + public override void Open() => OpenAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task OpenAsync(CancellationToken cancellationToken) + public override Task OpenAsync(CancellationToken cancellationToken) => + OpenAsync(IOKind.Asynchronous, cancellationToken); + + private async Task OpenAsync(IOKind ioKind, CancellationToken cancellationToken) { VerifyNotDisposed(); if (State != ConnectionState.Closed) @@ -104,7 +107,7 @@ public override async Task OpenAsync(CancellationToken cancellationToken) try { - m_session = await CreateSessionAsync(cancellationToken).ConfigureAwait(false); + m_session = await CreateSessionAsync(ioKind, cancellationToken).ConfigureAwait(false); m_hasBeenOpened = true; SetState(ConnectionState.Open); @@ -147,20 +150,21 @@ public override string ConnectionString public override string ServerVersion => m_session.ServerVersion.OriginalString; - public static void ClearPool(MySqlConnection connection) => ClearPoolAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); - public static void ClearAllPools() => ClearAllPoolsAsync(CancellationToken.None).GetAwaiter().GetResult(); - public static Task ClearPoolAsync(MySqlConnection connection) => ClearPoolAsync(connection, CancellationToken.None); - public static Task ClearAllPoolsAsync() => ClearAllPoolsAsync(CancellationToken.None); - public static Task ClearAllPoolsAsync(CancellationToken cancellationToken) => ConnectionPool.ClearPoolsAsync(cancellationToken); + public static void ClearPool(MySqlConnection connection) => ClearPoolAsync(connection, IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); + public static Task ClearPoolAsync(MySqlConnection connection) => ClearPoolAsync(connection, IOKind.Asynchronous, CancellationToken.None); + public static Task ClearPoolAsync(MySqlConnection connection, CancellationToken cancellationToken) => ClearPoolAsync(connection, IOKind.Asynchronous, cancellationToken); + public static void ClearAllPools() => ConnectionPool.ClearPoolsAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); + public static Task ClearAllPoolsAsync() => ConnectionPool.ClearPoolsAsync(IOKind.Asynchronous, CancellationToken.None); + public static Task ClearAllPoolsAsync(CancellationToken cancellationToken) => ConnectionPool.ClearPoolsAsync(IOKind.Asynchronous, cancellationToken); - public static async Task ClearPoolAsync(MySqlConnection connection, CancellationToken cancellationToken) + private static async Task ClearPoolAsync(MySqlConnection connection, IOKind ioKind, CancellationToken cancellationToken) { if (connection == null) throw new ArgumentNullException(nameof(connection)); var pool = ConnectionPool.GetPool(connection.m_connectionStringBuilder); if (pool != null) - await pool.ClearAsync(cancellationToken).ConfigureAwait(false); + await pool.ClearAsync(ioKind, cancellationToken).ConfigureAwait(false); } protected override DbCommand CreateDbCommand() => new MySqlCommand(this, CurrentTransaction); @@ -203,20 +207,20 @@ internal MySqlSession Session internal bool ConvertZeroDateTime => m_connectionStringBuilder.ConvertZeroDateTime; internal bool OldGuids => m_connectionStringBuilder.OldGuids; - private async Task CreateSessionAsync(CancellationToken cancellationToken) + private async Task CreateSessionAsync(IOKind ioKind, CancellationToken cancellationToken) { // get existing session from the pool if possible if (m_connectionStringBuilder.Pooling) { var pool = ConnectionPool.GetPool(m_connectionStringBuilder); // this returns an open session - return await pool.GetSessionAsync(cancellationToken).ConfigureAwait(false); + return await pool.GetSessionAsync(ioKind, cancellationToken).ConfigureAwait(false); } else { var session = new MySqlSession(null); await session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port, m_connectionStringBuilder.UserID, - m_connectionStringBuilder.Password, m_connectionStringBuilder.Database, (int) m_connectionStringBuilder.ConnectionTimeout, cancellationToken).ConfigureAwait(false); + m_connectionStringBuilder.Password, m_connectionStringBuilder.Database, (int) m_connectionStringBuilder.ConnectionTimeout, ioKind, cancellationToken).ConfigureAwait(false); return session; } } @@ -251,7 +255,7 @@ private void DoClose() if (m_connectionStringBuilder.Pooling) m_session.ReturnToPool(); else - m_session.DisposeAsync(CancellationToken.None).GetAwaiter().GetResult(); + m_session.DisposeAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); m_session = null; } SetState(ConnectionState.Closed); diff --git a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs index b9d5ac2b7..a8d55918a 100644 --- a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs +++ b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs @@ -12,17 +12,18 @@ namespace MySql.Data.MySqlClient { public sealed class MySqlDataReader : DbDataReader { - public override bool NextResult() - { - return NextResultAsync(CancellationToken.None).GetAwaiter().GetResult(); - } + public override bool NextResult() => + NextResultAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); - public override async Task NextResultAsync(CancellationToken cancellationToken) + public override Task NextResultAsync(CancellationToken cancellationToken) => + NextResultAsync(IOKind.Asynchronous, cancellationToken); + + internal async Task NextResultAsync(IOKind ioKind, CancellationToken cancellationToken) { VerifyNotDisposed(); while (m_state == State.ReadingRows || m_state == State.ReadResultSetHeader) - await ReadAsync(cancellationToken).ConfigureAwait(false); + await ReadAsync(ioKind, cancellationToken).ConfigureAwait(false); var oldState = m_state; Reset(); @@ -31,17 +32,20 @@ public override async Task NextResultAsync(CancellationToken cancellationT if (oldState != State.HasMoreData) throw new InvalidOperationException("Invalid state: {0}".FormatInvariant(oldState)); - await ReadResultSetHeaderAsync(cancellationToken).ConfigureAwait(false); + await ReadResultSetHeaderAsync(ioKind, cancellationToken).ConfigureAwait(false); return true; } public override bool Read() { VerifyNotDisposed(); - return ReadAsync(CancellationToken.None).GetAwaiter().GetResult(); + return ReadAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); } - public override Task ReadAsync(CancellationToken cancellationToken) + public override Task ReadAsync(CancellationToken cancellationToken) => + ReadAsync(IOKind.Asynchronous, cancellationToken); + + internal Task ReadAsync(IOKind ioKind, CancellationToken cancellationToken) { VerifyNotDisposed(); @@ -51,7 +55,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) if (m_state != State.AlreadyReadFirstRow) { - var payloadTask = m_session.ReceiveReplyAsync(cancellationToken); + var payloadTask = m_session.ReceiveReplyAsync(ioKind, cancellationToken); if (payloadTask.IsCompletedSuccessfully) return ReadAsyncRemainder(payloadTask.Result) ? s_trueTask : s_falseTask; return ReadAsyncAwaited(payloadTask.AsTask()); @@ -645,10 +649,10 @@ private void DoClose() } } - internal static async Task CreateAsync(MySqlCommand command, CommandBehavior behavior, CancellationToken cancellationToken) + internal static async Task CreateAsync(MySqlCommand command, CommandBehavior behavior, IOKind ioKind, CancellationToken cancellationToken) { var dataReader = new MySqlDataReader(command, behavior); - await dataReader.ReadResultSetHeaderAsync(cancellationToken).ConfigureAwait(false); + await dataReader.ReadResultSetHeaderAsync(ioKind, cancellationToken).ConfigureAwait(false); return dataReader; } @@ -709,11 +713,11 @@ private MySqlDataReader(MySqlCommand command, CommandBehavior behavior) private MySqlConnection Connection => m_command.Connection; - private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken) + private async Task ReadResultSetHeaderAsync(IOKind ioKind, CancellationToken cancellationToken) { while (true) { - var payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + var payload = await m_session.ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); var firstByte = payload.HeaderByte; if (firstByte == OkPayload.Signature) @@ -740,11 +744,11 @@ private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken) for (var column = 0; column < m_columnDefinitions.Length; column++) { - payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + payload = await m_session.ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); m_columnDefinitions[column] = ColumnDefinitionPayload.Create(payload); } - payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + payload = await m_session.ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); EofPayload.Create(payload); m_command.LastInsertedId = -1; diff --git a/src/MySqlConnector/MySqlClient/MySqlTransaction.cs b/src/MySqlConnector/MySqlClient/MySqlTransaction.cs index fc2e338cf..cec416eb4 100644 --- a/src/MySqlConnector/MySqlClient/MySqlTransaction.cs +++ b/src/MySqlConnector/MySqlClient/MySqlTransaction.cs @@ -10,10 +10,13 @@ public class MySqlTransaction : DbTransaction { public override void Commit() { - CommitAsync().GetAwaiter().GetResult(); + CommitAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); } - public async Task CommitAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task CommitAsync(CancellationToken cancellationToken = default(CancellationToken)) + => CommitAsync(IOKind.Asynchronous, cancellationToken); + + internal async Task CommitAsync(IOKind ioKind, CancellationToken cancellationToken) { VerifyNotDisposed(); if (m_isFinished) @@ -22,7 +25,7 @@ public override void Commit() if (m_connection.CurrentTransaction == this) { using (var cmd = new MySqlCommand("commit", m_connection, this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioKind, cancellationToken).ConfigureAwait(false); m_connection.CurrentTransaction = null; m_isFinished = true; } @@ -38,10 +41,13 @@ public override void Commit() public override void Rollback() { - RollbackAsync().GetAwaiter().GetResult(); + RollbackAsync(IOKind.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); } - public async Task RollbackAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task RollbackAsync(CancellationToken cancellationToken = default(CancellationToken)) + => RollbackAsync(IOKind.Asynchronous, cancellationToken); + + internal async Task RollbackAsync(IOKind ioKind, CancellationToken cancellationToken) { VerifyNotDisposed(); if (m_isFinished) @@ -50,7 +56,7 @@ public override void Rollback() if (m_connection.CurrentTransaction == this) { using (var cmd = new MySqlCommand("rollback", m_connection, this)) - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(ioKind, cancellationToken).ConfigureAwait(false); m_connection.CurrentTransaction = null; m_isFinished = true; } diff --git a/src/MySqlConnector/Serialization/MySqlSession.cs b/src/MySqlConnector/Serialization/MySqlSession.cs index 411bdcb90..57aaaa8ef 100644 --- a/src/MySqlConnector/Serialization/MySqlSession.cs +++ b/src/MySqlConnector/Serialization/MySqlSession.cs @@ -23,14 +23,14 @@ public MySqlSession(ConnectionPool pool) public void ReturnToPool() => Pool?.Return(this); - public async Task DisposeAsync(CancellationToken cancellationToken) + public async Task DisposeAsync(IOKind ioKind, CancellationToken cancellationToken) { if (m_transmitter != null) { try { - await m_transmitter.SendAsync(QuitPayload.Create(), cancellationToken).ConfigureAwait(false); - await m_transmitter.TryReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await m_transmitter.SendAsync(QuitPayload.Create(), ioKind, cancellationToken).ConfigureAwait(false); + await m_transmitter.TryReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); } catch (SocketException) { @@ -54,13 +54,13 @@ public async Task DisposeAsync(CancellationToken cancellationToken) m_state = State.Closed; } - public async Task ConnectAsync(IEnumerable hosts, int port, string userId, string password, string database, int timeoutSeconds, CancellationToken cancellationToken) + public async Task ConnectAsync(IEnumerable hosts, int port, string userId, string password, string database, int timeoutSeconds, IOKind ioKind, CancellationToken cancellationToken) { var connected = await OpenSocketAsync(hosts, port, timeoutSeconds).ConfigureAwait(false); if (!connected) throw new MySqlException("Unable to connect to any of the specified MySQL hosts."); - var payload = await ReceiveAsync(cancellationToken).ConfigureAwait(false); + var payload = await ReceiveAsync(ioKind, cancellationToken).ConfigureAwait(false); var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count); var initialHandshake = new InitialHandshakePacket(reader); if (initialHandshake.AuthPluginName != "mysql_native_password") @@ -70,16 +70,16 @@ public async Task ConnectAsync(IEnumerable hosts, int port, string userI var response = HandshakeResponse41Packet.Create(initialHandshake, userId, password, database); payload = new PayloadData(new ArraySegment(response)); - await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false); - await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendReplyAsync(payload, ioKind, cancellationToken).ConfigureAwait(false); + await ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); } - public async Task ResetConnectionAsync(string userId, string password, string database, CancellationToken cancellationToken) + public async Task ResetConnectionAsync(string userId, string password, string database, IOKind ioKind, CancellationToken cancellationToken) { if (ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) { - await SendAsync(ResetConnectionPayload.Create(), cancellationToken).ConfigureAwait(false); - var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendAsync(ResetConnectionPayload.Create(), ioKind, cancellationToken).ConfigureAwait(false); + var payload = await ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); OkPayload.Create(payload); } else @@ -87,8 +87,8 @@ public async Task ResetConnectionAsync(string userId, string password, string da // optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL) var hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(AuthPluginData, 0, password); var payload = ChangeUserPayload.Create(userId, hashedPassword, database); - await SendAsync(payload, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendAsync(payload, ioKind, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature) { // if the server didn't support the hashed password; rehash with the new challenge @@ -97,19 +97,19 @@ public async Task ResetConnectionAsync(string userId, string password, string da throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported."); hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(switchRequest.Data, 0, password); payload = new PayloadData(new ArraySegment(hashedPassword)); - await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + await SendReplyAsync(payload, ioKind, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); } OkPayload.Create(payload); } } - public async Task TryPingAsync(CancellationToken cancellationToken) + public async Task TryPingAsync(IOKind ioKind, CancellationToken cancellationToken) { - await SendAsync(PingPayload.Create(), cancellationToken).ConfigureAwait(false); + await SendAsync(PingPayload.Create(), ioKind, cancellationToken).ConfigureAwait(false); try { - var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false); + var payload = await ReceiveReplyAsync(ioKind, cancellationToken).ConfigureAwait(false); OkPayload.Create(payload); return true; } @@ -124,20 +124,20 @@ public async Task TryPingAsync(CancellationToken cancellationToken) } // Starts a new conversation with the server by sending the first packet. - public Task SendAsync(PayloadData payload, CancellationToken cancellationToken) - => TryAsync(m_transmitter.SendAsync, payload, cancellationToken); + public Task SendAsync(PayloadData payload, IOKind ioKind, CancellationToken cancellationToken) + => TryAsync(m_transmitter.SendAsync, payload, ioKind, cancellationToken); // Starts a new conversation with the server by receiving the first packet. - public ValueTask ReceiveAsync(CancellationToken cancellationToken) - => TryAsync(m_transmitter.ReceiveAsync, cancellationToken); + public ValueTask ReceiveAsync(IOKind ioKind, CancellationToken cancellationToken) + => TryAsync(m_transmitter.ReceiveAsync, ioKind, cancellationToken); // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask ReceiveReplyAsync(CancellationToken cancellationToken) - => TryAsync(m_transmitter.ReceiveReplyAsync, cancellationToken); + public ValueTask ReceiveReplyAsync(IOKind ioKind, CancellationToken cancellationToken) + => TryAsync(m_transmitter.ReceiveReplyAsync, ioKind, cancellationToken); // Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'. - public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken) - => TryAsync(m_transmitter.SendReplyAsync, payload, cancellationToken); + public Task SendReplyAsync(PayloadData payload, IOKind ioKind, CancellationToken cancellationToken) + => TryAsync(m_transmitter.SendReplyAsync, payload, ioKind, cancellationToken); private void VerifyConnected() { @@ -212,10 +212,10 @@ private async Task OpenSocketAsync(IEnumerable hostnames, int port } - private Task TryAsync(Func func, TArg arg, CancellationToken cancellationToken) + private Task TryAsync(Func func, TArg arg, IOKind ioKind, CancellationToken cancellationToken) { VerifyConnected(); - var task = func(arg, cancellationToken); + var task = func(arg, ioKind, cancellationToken); if (task.Status == TaskStatus.RanToCompletion) return task; @@ -231,10 +231,10 @@ private void TryAsyncContinuation(Task task) } } - private ValueTask TryAsync(Func> func, CancellationToken cancellationToken) + private ValueTask TryAsync(Func> func, IOKind ioKind, CancellationToken cancellationToken) { VerifyConnected(); - var task = func(cancellationToken); + var task = func(ioKind, cancellationToken); if (task.IsCompletedSuccessfully) { if (task.Result.HeaderByte != ErrorPayload.Signature) diff --git a/src/MySqlConnector/Serialization/PacketTransmitter.cs b/src/MySqlConnector/Serialization/PacketTransmitter.cs index 27969f990..33e61cd5a 100644 --- a/src/MySqlConnector/Serialization/PacketTransmitter.cs +++ b/src/MySqlConnector/Serialization/PacketTransmitter.cs @@ -18,32 +18,32 @@ public PacketTransmitter(Socket socket) } // Starts a new conversation with the server by sending the first packet. - public Task SendAsync(PayloadData payload, CancellationToken cancellationToken) + public Task SendAsync(PayloadData payload, IOKind ioKind, CancellationToken cancellationToken) { m_sequenceId = 0; - return DoSendAsync(payload, cancellationToken); + return DoSendAsync(payload, ioKind, cancellationToken); } // Starts a new conversation with the server by receiving the first packet. - public ValueTask ReceiveAsync(CancellationToken cancellationToken) + public ValueTask ReceiveAsync(IOKind ioKind, CancellationToken cancellationToken) { m_sequenceId = 0; - return DoReceiveAsync(cancellationToken); + return DoReceiveAsync(cancellationToken, ioKind, optional: false); } // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask ReceiveReplyAsync(CancellationToken cancellationToken) - => DoReceiveAsync(cancellationToken); + public ValueTask ReceiveReplyAsync(IOKind ioKind, CancellationToken cancellationToken) + => DoReceiveAsync(cancellationToken, ioKind, optional: false); // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask TryReceiveReplyAsync(CancellationToken cancellationToken) - => DoReceiveAsync(cancellationToken, optional: true); + public ValueTask TryReceiveReplyAsync(IOKind ioKind, CancellationToken cancellationToken) + => DoReceiveAsync(cancellationToken, ioKind, optional: true); // Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'. - public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken) - => DoSendAsync(payload, cancellationToken); + public Task SendReplyAsync(PayloadData payload, IOKind ioKind, CancellationToken cancellationToken) + => DoSendAsync(payload, ioKind, cancellationToken); - private async Task DoSendAsync(PayloadData payload, CancellationToken cancellationToken) + private async Task DoSendAsync(PayloadData payload, IOKind ioKind, CancellationToken cancellationToken) { var bytesSent = 0; var data = payload.ArraySegment; @@ -61,23 +61,39 @@ private async Task DoSendAsync(PayloadData payload, CancellationToken cancellati if (bytesToSend <= m_buffer.Length - 4) { Buffer.BlockCopy(data.Array, data.Offset + bytesSent, m_buffer, 4, bytesToSend); - m_socketAwaitable.EventArgs.SetBuffer(0, bytesToSend + 4); - await m_socket.SendAsync(m_socketAwaitable); + var count = bytesToSend + 4; + if (ioKind == IOKind.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(0, count); + await m_socket.SendAsync(m_socketAwaitable); + } + else + { + m_socket.Send(m_buffer, 0, count, SocketFlags.None); + } } else { - m_socketAwaitable.EventArgs.SetBuffer(null, 0, 0); - m_socketAwaitable.EventArgs.BufferList = new[] { new ArraySegment(m_buffer, 0, 4), new ArraySegment(data.Array, data.Offset + bytesSent, bytesToSend) }; - await m_socket.SendAsync(m_socketAwaitable); - m_socketAwaitable.EventArgs.BufferList = null; - m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + if (ioKind == IOKind.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(null, 0, 0); + m_socketAwaitable.EventArgs.BufferList = new[] { new ArraySegment(m_buffer, 0, 4), new ArraySegment(data.Array, data.Offset + bytesSent, bytesToSend) }; + await m_socket.SendAsync(m_socketAwaitable); + m_socketAwaitable.EventArgs.BufferList = null; + m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + } + else + { + m_socket.Send(m_buffer, 0, 4, SocketFlags.None); + m_socket.Send(data.Array, data.Offset + bytesSent, bytesToSend, SocketFlags.None); + } } bytesSent += bytesToSend; } while (bytesToSend == c_maxPacketSize); } - private ValueTask DoReceiveAsync(CancellationToken cancellationToken, bool optional = false) + private ValueTask DoReceiveAsync(CancellationToken cancellationToken, IOKind ioKind, bool optional) { if (m_end - m_offset > 4) { @@ -100,13 +116,13 @@ private ValueTask DoReceiveAsync(CancellationToken cancellationToke } } - return new ValueTask(DoReceiveAsync2(cancellationToken, optional)); + return new ValueTask(DoReceiveAsync2(cancellationToken, ioKind, optional)); } - private async Task DoReceiveAsync2(CancellationToken cancellationToken, bool optional = false) + private async Task DoReceiveAsync2(CancellationToken cancellationToken, IOKind ioKind, bool optional = false) { // common case: the payload is contained within one packet - var payload = await ReceivePacketAsync(cancellationToken, optional).ConfigureAwait(false); + var payload = await ReceivePacketAsync(cancellationToken, ioKind, optional).ConfigureAwait(false); if (payload == null || payload.ArraySegment.Count != c_maxPacketSize) return payload; @@ -117,7 +133,7 @@ private async Task DoReceiveAsync2(CancellationToken cancellationTo do { - payload = await ReceivePacketAsync(cancellationToken, optional).ConfigureAwait(false); + payload = await ReceivePacketAsync(cancellationToken, ioKind, optional).ConfigureAwait(false); var oldLength = payloadBytes.Length; Array.Resize(ref payloadBytes, payloadBytes.Length + payload.ArraySegment.Count); @@ -127,7 +143,7 @@ private async Task DoReceiveAsync2(CancellationToken cancellationTo return new PayloadData(new ArraySegment(payloadBytes)); } - private async Task ReceivePacketAsync(CancellationToken cancellationToken, bool optional) + private async Task ReceivePacketAsync(CancellationToken cancellationToken, IOKind ioKind, bool optional) { if (m_end - m_offset < 4) { @@ -142,9 +158,18 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio int count = m_buffer.Length - m_end; while (m_end - m_offset < 4) { - m_socketAwaitable.EventArgs.SetBuffer(offset, count); - await m_socket.ReceiveAsync(m_socketAwaitable); - int bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + int bytesRead; + if (ioKind == IOKind.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(offset, count); + await m_socket.ReceiveAsync(m_socketAwaitable); + bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + } + else + { + bytesRead = m_socket.Receive(m_buffer, offset, count, SocketFlags.None); + } + if (bytesRead <= 0) { if (optional) @@ -179,7 +204,8 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio if (payloadLength > m_buffer.Length) { readData = new byte[payloadLength]; - m_socketAwaitable.EventArgs.SetBuffer(readData, 0, 0); + if (ioKind == IOKind.Asynchronous) + m_socketAwaitable.EventArgs.SetBuffer(readData, 0, 0); } Buffer.BlockCopy(m_buffer, m_offset, readData, 0, m_end - m_offset); m_end -= m_offset; @@ -190,9 +216,17 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio count = readData.Length - m_end; while (m_end < payloadLength) { - m_socketAwaitable.EventArgs.SetBuffer(offset, count); - await m_socket.ReceiveAsync(m_socketAwaitable); - int bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + int bytesRead; + if (ioKind == IOKind.Asynchronous) + { + m_socketAwaitable.EventArgs.SetBuffer(offset, count); + await m_socket.ReceiveAsync(m_socketAwaitable); + bytesRead = m_socketAwaitable.EventArgs.BytesTransferred; + } + else + { + bytesRead = m_socket.Receive(readData, offset, count, SocketFlags.None); + } if (bytesRead <= 0) throw new EndOfStreamException(); offset += bytesRead; @@ -203,7 +237,8 @@ private async Task ReceivePacketAsync(CancellationToken cancellatio // switch back to original buffer if a larger one was allocated if (payloadLength > m_buffer.Length) { - m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); + if (ioKind == IOKind.Asynchronous) + m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0); m_end = 0; }