Skip to content

Commit

Permalink
Add Synchronous to MySqlConnectionStringBuilder.
Browse files Browse the repository at this point in the history
Setting Synchronous to true forces async requests to block on the calling thread to improve debugging.
  • Loading branch information
ejball committed Jul 11, 2016
1 parent 591f2ab commit d300616
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/MySqlConnector/MySqlClient/MySqlCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha
var preparer = new MySqlStatementPreparer(CommandText, m_parameterCollection, connection.AllowUserVariables ? StatementPreparerOptions.AllowUserVariables : StatementPreparerOptions.None);
preparer.BindParameters();
var payload = new PayloadData(new ArraySegment<byte>(Payload.CreateEofStringPayload(CommandKind.Query, preparer.PreparedSql)));
await Session.SendAsync(payload, cancellationToken).ConfigureAwait(false);
await Connection.AdaptTask(Session.SendAsync(payload, cancellationToken)).ConfigureAwait(false);
reader = await MySqlDataReader.CreateAsync(this, behavior, cancellationToken).ConfigureAwait(false);
return reader;
}
Expand Down
46 changes: 36 additions & 10 deletions src/MySqlConnector/MySqlClient/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ public override async Task OpenAsync(CancellationToken cancellationToken)
if (m_session == null)
{
m_session = new MySqlSession(pool);
var connected = await m_session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port).ConfigureAwait(false);
var connected = await AdaptTask(m_session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port)).ConfigureAwait(false);
if (!connected)
{
SetState(ConnectionState.Closed);
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");
}

var payload = await m_session.ReceiveAsync(cancellationToken).ConfigureAwait(false);
var payload = await AdaptTask(m_session.ReceiveAsync(cancellationToken)).ConfigureAwait(false);
var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count);
var initialHandshake = new InitialHandshakePacket(reader);
if (initialHandshake.AuthPluginName != "mysql_native_password")
Expand All @@ -130,32 +130,32 @@ public override async Task OpenAsync(CancellationToken cancellationToken)

var response = HandshakeResponse41Packet.Create(initialHandshake, m_connectionStringBuilder.UserID, m_connectionStringBuilder.Password, m_database);
payload = new PayloadData(new ArraySegment<byte>(response));
await m_session.SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
await AdaptTask(m_session.SendReplyAsync(payload, cancellationToken)).ConfigureAwait(false);
await AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);
// TODO: Check success
}
else if (m_connectionStringBuilder.ConnectionReset)
{
if (m_session.ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0)
{
await m_session.SendAsync(ResetConnectionPayload.Create(), cancellationToken).ConfigureAwait(false);
var payload = await m_session.ReceiveReplyAsync(cancellationToken);
await AdaptTask(m_session.SendAsync(ResetConnectionPayload.Create(), cancellationToken)).ConfigureAwait(false);
var payload = await AdaptTask(m_session.ReceiveReplyAsync(cancellationToken));
OkPayload.Create(payload);
}
else
{
// MySQL doesn't appear to accept a replayed hashed password (using the challenge from the initial handshake), so just send zeroes
// and expect to get a new challenge
var payload = ChangeUserPayload.Create(m_connectionStringBuilder.UserID, new byte[20], m_database);
await m_session.SendAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
await AdaptTask(m_session.SendAsync(payload, cancellationToken)).ConfigureAwait(false);
payload = await AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);
var switchRequest = AuthenticationMethodSwitchRequestPayload.Create(payload);
if (switchRequest.Name != "mysql_native_password")
throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported.");
var hashedPassword = AuthenticationUtility.HashPassword(switchRequest.Data, 0, m_connectionStringBuilder.Password);
payload = new PayloadData(new ArraySegment<byte>(hashedPassword));
await m_session.SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
await AdaptTask(m_session.SendReplyAsync(payload, cancellationToken)).ConfigureAwait(false);
payload = await AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);
OkPayload.Create(payload);
}
}
Expand Down Expand Up @@ -240,6 +240,32 @@ internal MySqlSession Session
internal bool AllowUserVariables => m_connectionStringBuilder.AllowUserVariables;
internal bool ConvertZeroDateTime => m_connectionStringBuilder.ConvertZeroDateTime;
internal bool OldGuids => m_connectionStringBuilder.OldGuids;
internal bool Synchronous => m_connectionStringBuilder.Synchronous;

internal Task AdaptTask(Task task)
{
if (!Synchronous)
return task;

task.GetAwaiter().GetResult();
return Task.FromResult<object>(null);
}

internal Task<T> AdaptTask<T>(Task<T> task)
{
if (!Synchronous)
return task;

return Task.FromResult(task.GetAwaiter().GetResult());
}

internal ValueTask<T> AdaptTask<T>(ValueTask<T> task)
{
if (!Synchronous)
return task;

return new ValueTask<T>(task.AsTask().GetAwaiter().GetResult());
}

private void SetState(ConnectionState newState)
{
Expand Down
11 changes: 11 additions & 0 deletions src/MySqlConnector/MySqlClient/MySqlConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ public uint MaximumPoolSize
set { MySqlConnectionStringOption.MaximumPoolSize.SetValue(this, value); }
}

public bool Synchronous
{
get { return MySqlConnectionStringOption.Synchronous.GetValue(this); }
set { MySqlConnectionStringOption.Synchronous.SetValue(this, value); }
}

public override bool ContainsKey(string key)
{
var option = MySqlConnectionStringOption.TryGetOptionForKey(key);
Expand Down Expand Up @@ -165,6 +171,7 @@ internal abstract class MySqlConnectionStringOption
public static readonly MySqlConnectionStringOption<bool> ConnectionReset;
public static readonly MySqlConnectionStringOption<uint> MinimumPoolSize;
public static readonly MySqlConnectionStringOption<uint> MaximumPoolSize;
public static readonly MySqlConnectionStringOption<bool> Synchronous;

public static MySqlConnectionStringOption TryGetOptionForKey(string key)
{
Expand Down Expand Up @@ -259,6 +266,10 @@ static MySqlConnectionStringOption()
AddOption(MaximumPoolSize = new MySqlConnectionStringOption<uint>(
keys: new[] { "Maximum Pool Size", "Max Pool Size", "MaximumPoolSize", "maxpoolsize" },
defaultValue: 100));

AddOption(Synchronous = new MySqlConnectionStringOption<bool>(
keys: new[] { "Synchronous" },
defaultValue: false));
}

static readonly Dictionary<string, MySqlConnectionStringOption> s_options;
Expand Down
8 changes: 4 additions & 4 deletions src/MySqlConnector/MySqlClient/MySqlDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public override Task<bool> ReadAsync(CancellationToken cancellationToken)

if (m_state != State.AlreadyReadFirstRow)
{
var payloadTask = m_session.ReceiveReplyAsync(cancellationToken);
var payloadTask = Connection.AdaptTask(m_session.ReceiveReplyAsync(cancellationToken));
if (payloadTask.IsCompletedSuccessfully)
return ReadAsyncRemainder(payloadTask.Result) ? s_trueTask : s_falseTask;
return ReadAsyncAwaited(payloadTask.AsTask());
Expand Down Expand Up @@ -566,7 +566,7 @@ private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken)
{
while (true)
{
var payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
var payload = await Connection.AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);

var firstByte = payload.HeaderByte;
if (firstByte == OkPayload.Signature)
Expand All @@ -593,11 +593,11 @@ private async Task ReadResultSetHeaderAsync(CancellationToken cancellationToken)

for (var column = 0; column < m_columnDefinitions.Length; column++)
{
payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
payload = await Connection.AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);
m_columnDefinitions[column] = ColumnDefinitionPayload.Create(payload);
}

payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
payload = await Connection.AdaptTask(m_session.ReceiveReplyAsync(cancellationToken)).ConfigureAwait(false);
EofPayload.Create(payload);

m_command.LastInsertedId = -1;
Expand Down

0 comments on commit d300616

Please sign in to comment.