Skip to content

Commit

Permalink
Move ConnectAsync, ResetConnectionAsync into MySqlSession.
Browse files Browse the repository at this point in the history
This will make it easier to create idle connections (to fill a connection pool).
  • Loading branch information
bgrainger committed Sep 16, 2016
1 parent a0ca58a commit 54e3a8d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 98 deletions.
77 changes: 4 additions & 73 deletions src/MySqlConnector/MySqlClient/MySqlConnection.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using System;
using System.Data;
using System.Data.Common;
using System.IO;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MySql.Data.Serialization;
Expand Down Expand Up @@ -105,35 +103,17 @@ public override async Task OpenAsync(CancellationToken cancellationToken)
if (m_session != null)
{
// test that session is still valid and (optionally) reset it
if (!await TryPingAsync(m_session, cancellationToken).ConfigureAwait(false))
if (!await m_session.TryPingAsync(cancellationToken).ConfigureAwait(false))
Utility.Dispose(ref m_session);
else if (m_connectionStringBuilder.ConnectionReset)
await ResetConnectionAsync(cancellationToken).ConfigureAwait(false);
await m_session.ResetConnectionAsync(m_connectionStringBuilder.UserID, m_connectionStringBuilder.Password, m_database, cancellationToken).ConfigureAwait(false);
}

if (m_session == null)
{
m_session = new MySqlSession(pool);
var connected = await m_session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port).ConfigureAwait(false);
if (!connected)
{
SetState(ConnectionState.Closed);
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");
}

var payload = await m_session.ReceiveAsync(cancellationToken).ConfigureAwait(false);
var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count);
var initialHandshake = new InitialHandshakePacket(reader);
if (initialHandshake.AuthPluginName != "mysql_native_password")
throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported.");
m_session.ServerVersion = new ServerVersion(Encoding.ASCII.GetString(initialHandshake.ServerVersion));
m_session.AuthPluginData = initialHandshake.AuthPluginData;

var response = HandshakeResponse41Packet.Create(initialHandshake, m_connectionStringBuilder.UserID, m_connectionStringBuilder.Password, m_database);
payload = new PayloadData(new ArraySegment<byte>(response));
await m_session.SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
// TODO: Check success
await m_session.ConnectAsync(m_connectionStringBuilder.Server.Split(','), (int) m_connectionStringBuilder.Port, m_connectionStringBuilder.UserID,
m_connectionStringBuilder.Password, m_database, cancellationToken).ConfigureAwait(false);
}

m_hasBeenOpened = true;
Expand Down Expand Up @@ -274,55 +254,6 @@ private void DoClose()
}
}

private async Task ResetConnectionAsync(CancellationToken cancellationToken)
{
if (m_session.ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0)
{
await m_session.SendAsync(ResetConnectionPayload.Create(), cancellationToken).ConfigureAwait(false);
var payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
OkPayload.Create(payload);
}
else
{
// optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL)
var hashedPassword = AuthenticationUtility.HashPassword(m_session.AuthPluginData, 0, m_connectionStringBuilder.Password);
var payload = ChangeUserPayload.Create(m_connectionStringBuilder.UserID, hashedPassword, m_database);
await m_session.SendAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature)
{
// if the server didn't support the hashed password; rehash with the new challenge
var switchRequest = AuthenticationMethodSwitchRequestPayload.Create(payload);
if (switchRequest.Name != "mysql_native_password")
throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported.");
hashedPassword = AuthenticationUtility.HashPassword(switchRequest.Data, 0, m_connectionStringBuilder.Password);
payload = new PayloadData(new ArraySegment<byte>(hashedPassword));
await m_session.SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await m_session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
}
OkPayload.Create(payload);
}
}

private static async Task<bool> TryPingAsync(MySqlSession session, CancellationToken cancellationToken)
{
await session.SendAsync(PingPayload.Create(), cancellationToken).ConfigureAwait(false);
try
{
var payload = await session.ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
OkPayload.Create(payload);
return true;
}
catch (EndOfStreamException)
{
}
catch (SocketException)
{
}

return false;
}

MySqlConnectionStringBuilder m_connectionStringBuilder;
MySqlSession m_session;
ConnectionState m_connectionState;
Expand Down
121 changes: 96 additions & 25 deletions src/MySqlConnector/Serialization/MySqlSession.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MySql.Data.MySqlClient;
Expand Down Expand Up @@ -49,7 +51,100 @@ public async Task DisposeAsync(CancellationToken cancellationToken)
m_state = State.Closed;
}

public async Task<bool> ConnectAsync(IEnumerable<string> hostnames, int port)
public async Task ConnectAsync(IEnumerable<string> hosts, int port, string userId, string password, string database, CancellationToken cancellationToken)
{
var connected = await OpenSocketAsync(hosts, port).ConfigureAwait(false);
if (!connected)
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");

var payload = await ReceiveAsync(cancellationToken).ConfigureAwait(false);
var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count);
var initialHandshake = new InitialHandshakePacket(reader);
if (initialHandshake.AuthPluginName != "mysql_native_password")
throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported.");
ServerVersion = new ServerVersion(Encoding.ASCII.GetString(initialHandshake.ServerVersion));
AuthPluginData = initialHandshake.AuthPluginData;

var response = HandshakeResponse41Packet.Create(initialHandshake, userId, password, database);
payload = new PayloadData(new ArraySegment<byte>(response));
await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
}

public async Task ResetConnectionAsync(string userId, string password, string database, CancellationToken cancellationToken)
{
if (ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0)
{
await SendAsync(ResetConnectionPayload.Create(), cancellationToken).ConfigureAwait(false);
var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
OkPayload.Create(payload);
}
else
{
// optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL)
var hashedPassword = AuthenticationUtility.HashPassword(AuthPluginData, 0, password);
var payload = ChangeUserPayload.Create(userId, hashedPassword, database);
await SendAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature)
{
// if the server didn't support the hashed password; rehash with the new challenge
var switchRequest = AuthenticationMethodSwitchRequestPayload.Create(payload);
if (switchRequest.Name != "mysql_native_password")
throw new NotSupportedException("Only 'mysql_native_password' authentication method is supported.");
hashedPassword = AuthenticationUtility.HashPassword(switchRequest.Data, 0, password);
payload = new PayloadData(new ArraySegment<byte>(hashedPassword));
await SendReplyAsync(payload, cancellationToken).ConfigureAwait(false);
payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
}
OkPayload.Create(payload);
}
}

public async Task<bool> TryPingAsync(CancellationToken cancellationToken)
{
await SendAsync(PingPayload.Create(), cancellationToken).ConfigureAwait(false);
try
{
var payload = await ReceiveReplyAsync(cancellationToken).ConfigureAwait(false);
OkPayload.Create(payload);
return true;
}
catch (EndOfStreamException)
{
}
catch (SocketException)
{
}

return false;
}

// Starts a new conversation with the server by sending the first packet.
public Task SendAsync(PayloadData payload, CancellationToken cancellationToken)
=> TryAsync(m_transmitter.SendAsync, payload, cancellationToken);

// Starts a new conversation with the server by receiving the first packet.
public ValueTask<PayloadData> ReceiveAsync(CancellationToken cancellationToken)
=> TryAsync(m_transmitter.ReceiveAsync, cancellationToken);

// Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'.
public ValueTask<PayloadData> ReceiveReplyAsync(CancellationToken cancellationToken)
=> TryAsync(m_transmitter.ReceiveReplyAsync, cancellationToken);

// Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'.
public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken)
=> TryAsync(m_transmitter.SendReplyAsync, payload, cancellationToken);

private void VerifyConnected()
{
if (m_state == State.Closed)
throw new ObjectDisposedException(nameof(MySqlSession));
if (m_state != State.Connected)
throw new InvalidOperationException("MySqlSession is not connected.");
}

private async Task<bool> OpenSocketAsync(IEnumerable<string> hostnames, int port)
{
foreach (var hostname in hostnames)
{
Expand Down Expand Up @@ -93,30 +188,6 @@ public async Task<bool> ConnectAsync(IEnumerable<string> hostnames, int port)
return false;
}

// Starts a new conversation with the server by sending the first packet.
public Task SendAsync(PayloadData payload, CancellationToken cancellationToken)
=> TryAsync(m_transmitter.SendAsync, payload, cancellationToken);

// Starts a new conversation with the server by receiving the first packet.
public ValueTask<PayloadData> ReceiveAsync(CancellationToken cancellationToken)
=> TryAsync(m_transmitter.ReceiveAsync, cancellationToken);

// Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'.
public ValueTask<PayloadData> ReceiveReplyAsync(CancellationToken cancellationToken)
=> TryAsync(m_transmitter.ReceiveReplyAsync, cancellationToken);

// Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'.
public Task SendReplyAsync(PayloadData payload, CancellationToken cancellationToken)
=> TryAsync(m_transmitter.SendReplyAsync, payload, cancellationToken);


private void VerifyConnected()
{
if (m_state == State.Closed)
throw new ObjectDisposedException(nameof(MySqlSession));
if (m_state != State.Connected)
throw new InvalidOperationException("MySqlSession is not connected.");
}

private Task TryAsync<TArg>(Func<TArg, CancellationToken, Task> func, TArg arg, CancellationToken cancellationToken)
{
Expand Down

0 comments on commit 54e3a8d

Please sign in to comment.