Skip to content

Commit

Permalink
Add mysql_old_password plugin support
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Jun 29, 2021
1 parent 8cee32d commit e18f47d
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ futures-sink = "0.3"
lazy_static = "1"
lru = "0.6.0"
mio = "0.7.7"
mysql_common = { version = "0.27.1", default-features = false }
mysql_common = { version = "0.27.2", default-features = false }
native-tls = "0.2"
once_cell = "1.7.2"
pem = "0.8.1"
Expand Down
8 changes: 6 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,20 @@ jobs:
docker --version
displayName: Install docker
- bash: |
docker run -d --name container -v `pwd`:/root -p 3307:3306 -e MYSQL_ROOT_PASSWORD=password mysql:$(DB_VERSION) --max-allowed-packet=36700160 --local-infile --log-bin=mysql-bin --log-slave-updates --gtid_mode=ON --enforce_gtid_consistency=ON --server-id=1
if [[ "5.6" == "$(DB_VERSION)" ]]; then ARG="--secure-auth=OFF"; fi
docker run -d --name container -v `pwd`:/root -p 3307:3306 -e MYSQL_ROOT_PASSWORD=password mysql:$(DB_VERSION) --max-allowed-packet=36700160 --local-infile --log-bin=mysql-bin --log-slave-updates --gtid_mode=ON --enforce_gtid_consistency=ON --server-id=1 $ARG
while ! nc -W 1 localhost 3307 | grep -q -P '.+'; do sleep 1; done
displayName: Run MySql in Docker
- bash: |
docker exec container bash -l -c "mysql -uroot -ppassword -e \"SET old_passwords = 1; GRANT ALL PRIVILEGES ON *.* TO 'root2'@'%' IDENTIFIED WITH mysql_old_password AS 'password'; SET PASSWORD FOR 'root2'@'%' = OLD_PASSWORD('password')\"";
condition: eq(variables['DB_VERSION'], '5.6')
- bash: |
docker exec container bash -l -c "apt-get update"
docker exec container bash -l -c "apt-get install -y curl clang libssl-dev pkg-config"
docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable"
displayName: Install Rust in docker
- bash: |
if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; fi
if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; else DATABASE_URL="mysql://root2:password@127.0.0.1/mysql?secure_auth=false"; fi
docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test"
docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test"
docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test"
Expand Down
39 changes: 31 additions & 8 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use mysql_common::{
io::ParseBuf,
packets::{
binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket,
HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, ResultSetTerminator,
SslRequest,
HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest,
ResultSetTerminator, SslRequest,
},
proto::MySerialize,
};
Expand Down Expand Up @@ -390,7 +390,9 @@ impl Conn {
self.inner.id = handshake.connection_id();
self.inner.status = handshake.status_flags();
self.inner.auth_plugin = match handshake.auth_plugin() {
Some(AuthPlugin::MysqlNativePassword) => AuthPlugin::MysqlNativePassword,
Some(AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword) => {
AuthPlugin::MysqlNativePassword
}
Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
Some(AuthPlugin::Other(ref name)) => {
let name = String::from_utf8_lossy(name).into();
Expand Down Expand Up @@ -454,18 +456,31 @@ impl Conn {
) -> Result<()> {
if !self.inner.auth_switched {
self.inner.auth_switched = true;
self.inner.nonce = auth_switch_request.plugin_data().into();

if matches!(
auth_switch_request.auth_plugin(),
AuthPlugin::MysqlOldPassword
) {
if self.inner.opts.secure_auth() {
return Err(DriverError::MysqlOldPasswordDisabled.into());
}
}

self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();

let plugin_data = self
.inner
.auth_plugin
.gen_data(self.inner.opts.pass(), &*self.inner.nonce);

if let Some(plugin_data) = plugin_data {
self.write_struct(&plugin_data).await?;
} else {
self.write_packet(crate::BUFFER_POOL.get()).await?;
}

self.continue_auth().await?;

Ok(())
} else {
unreachable!("auth_switched flag should be checked by caller")
Expand All @@ -477,7 +492,7 @@ impl Conn {
// see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782
Box::pin(async move {
match self.inner.auth_plugin {
AuthPlugin::MysqlNativePassword => {
AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
self.continue_mysql_native_password_auth().await?;
Ok(())
}
Expand Down Expand Up @@ -561,9 +576,17 @@ impl Conn {
match packet.get(0) {
Some(0x00) => Ok(()),
Some(0xfe) if !self.inner.auth_switched => {
let auth_switch_request = ParseBuf(&*packet).parse::<AuthSwitchRequest>(())?;
self.perform_auth_switch(auth_switch_request).await?;
Ok(())
let auth_switch = if packet.len() > 1 {
ParseBuf(&*packet).parse(())?
} else {
let _ = ParseBuf(&*packet).parse::<OldAuthSwitchRequest>(())?;
// map OldAuthSwitch to AuthSwitch with mysql_old_password plugin
AuthSwitchRequest::new(
"mysql_old_password".as_bytes(),
self.inner.nonce.clone(),
)
};
self.perform_auth_switch(auth_switch).await
}
_ => Err(DriverError::UnexpectedPacket {
payload: packet.to_vec(),
Expand Down
15 changes: 12 additions & 3 deletions src/conn/pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tokio::sync::mpsc;

use std::{
collections::VecDeque,
convert::TryFrom,
pin::Pin,
str::FromStr,
sync::{atomic, Arc, Mutex},
Expand Down Expand Up @@ -111,8 +112,16 @@ pub struct Pool {

impl Pool {
/// Creates a new pool of connections.
pub fn new<O: Into<Opts>>(opts: O) -> Pool {
let opts = opts.into();
///
/// # Panic
///
/// It'll panic if `Opts::try_from(opts)` returns error.
pub fn new<O>(opts: O) -> Pool
where
Opts: TryFrom<O>,
<Opts as TryFrom<O>>::Error: std::error::Error,
{
let opts = Opts::try_from(opts).unwrap();
let pool_opts = opts.pool_opts().clone();
let (tx, rx) = mpsc::unbounded_channel();
Pool {
Expand Down Expand Up @@ -577,7 +586,7 @@ mod test {
#[tokio::test]
async fn should_hold_bounds_on_error() -> super::Result<()> {
// Should not be possible to connect to broadcast address.
let pool = Pool::new(String::from("mysql://255.255.255.255"));
let pool = Pool::new("mysql://255.255.255.255");

assert!(try_join!(pool.get_conn(), pool.get_conn()).is_err());
assert_eq!(ex_field!(pool, exist), 0);
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ pub enum DriverError {

#[error("Named pipe connections temporary disabled (see tokio-rs/tokio#3118)")]
NamedPipesDisabled,

#[error("`mysql_old_password` plugin is insecure and disabled by default")]
MysqlOldPasswordDisabled,
}

impl From<DriverError> for Error {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ pub mod test_misc {
}

pub fn get_opts() -> OptsBuilder {
let mut builder = OptsBuilder::from_opts(&**DATABASE_URL);
let mut builder = OptsBuilder::from_opts(Opts::from_url(&**DATABASE_URL).unwrap());
if test_ssl() {
let ssl_opts = SslOpts::default()
.with_danger_skip_domain_validation(true)
Expand Down
60 changes: 52 additions & 8 deletions src/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use url::{Host, Url};

use std::{
borrow::Cow,
convert::TryFrom,
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
path::Path,
Expand Down Expand Up @@ -399,6 +400,11 @@ pub(crate) struct MysqlOpts {
/// By default `Conn` will query this value from the server. One can avoid this step
/// by explicitly specifying it.
wait_timeout: Option<usize>,

/// Disables `mysql_old_password` plugin (defaults to `true`).
///
/// Available via `secure_auth` connection url parameter.
secure_auth: bool,
}

/// Mysql connection options.
Expand Down Expand Up @@ -689,6 +695,13 @@ impl Opts {
self.inner.mysql_opts.wait_timeout
}

/// Disables `mysql_old_password` plugin (defaults to `true`).
///
/// Available via `secure_auth` connection url parameter.
pub fn secure_auth(&self) -> bool {
self.inner.mysql_opts.secure_auth
}

pub(crate) fn get_capabilities(&self) -> CapabilityFlags {
let mut out = CapabilityFlags::CLIENT_PROTOCOL_41
| CapabilityFlags::CLIENT_SECURE_CONNECTION
Expand Down Expand Up @@ -734,6 +747,7 @@ impl Default for MysqlOpts {
compression: None,
max_allowed_packet: None,
wait_timeout: None,
secure_auth: true,
}
}
}
Expand Down Expand Up @@ -831,8 +845,16 @@ impl Default for OptsBuilder {

impl OptsBuilder {
/// Creates new builder from the given `Opts`.
pub fn from_opts<T: Into<Opts>>(opts: T) -> Self {
let opts = opts.into();
///
/// # Panic
///
/// It'll panic if `Opts::try_from(opts)` returns error.
pub fn from_opts<T>(opts: T) -> Self
where
Opts: TryFrom<T>,
<Opts as TryFrom<T>>::Error: std::error::Error,
{
let opts = Opts::try_from(opts).unwrap();

OptsBuilder {
tcp_port: opts.inner.address.get_tcp_port(),
Expand Down Expand Up @@ -968,6 +990,14 @@ impl OptsBuilder {
});
self
}

/// Disables `mysql_old_password` plugin (defaults to `true`).
///
/// Available via `secure_auth` connection url parameter.
pub fn secure_auth(mut self, secure_auth: bool) -> Self {
self.opts.secure_auth = secure_auth;
self
}
}

impl From<OptsBuilder> for Opts {
Expand Down Expand Up @@ -1180,6 +1210,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
});
}
}
} else if key == "secure_auth" {
match bool::from_str(&*value) {
Ok(secure_auth) => {
opts.secure_auth = secure_auth;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "secure_auth".into(),
value,
});
}
}
} else if key == "socket" {
opts.socket = Some(value)
} else if key == "compression" {
Expand Down Expand Up @@ -1224,9 +1266,11 @@ impl FromStr for Opts {
}
}

impl<T: AsRef<str> + Sized> From<T> for Opts {
fn from(url: T) -> Opts {
Opts::from_url(url.as_ref()).unwrap()
impl<'a> TryFrom<&'a str> for Opts {
type Error = UrlError;

fn try_from(s: &str) -> std::result::Result<Self, UrlError> {
Opts::from_url(s)
}
}

Expand Down Expand Up @@ -1306,21 +1350,21 @@ mod test {
#[should_panic]
fn should_panic_on_invalid_url() {
let opts = "42";
let _: Opts = opts.into();
let _: Opts = Opts::from_str(opts).unwrap();
}

#[test]
#[should_panic]
fn should_panic_on_invalid_scheme() {
let opts = "postgres://localhost";
let _: Opts = opts.into();
let _: Opts = Opts::from_str(opts).unwrap();
}

#[test]
#[should_panic]
fn should_panic_on_unknown_query_param() {
let opts = "mysql://localhost/foo?bar=baz";
let _: Opts = opts.into();
let _: Opts = Opts::from_str(opts).unwrap();
}

#[test]
Expand Down

0 comments on commit e18f47d

Please sign in to comment.