Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow arbitrary config params with more efficient repr #30

Merged
merged 12 commits into from
Jun 20, 2024
4 changes: 2 additions & 2 deletions postgres-protocol/src/authentication/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(Box<ScramKeys<N>>),
Keys(ScramKeys<N>),
}

enum State {
Expand Down Expand Up @@ -176,7 +176,7 @@ impl ScramSha256 {

/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys.into());
let password = Credentials::Keys(keys);
ScramSha256::new_inner(password, channel_binding, nonce())
}

Expand Down
2 changes: 1 addition & 1 deletion postgres-protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ macro_rules! from_usize {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::max_value() as usize {
if x > <$t>::MAX as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
Expand Down
60 changes: 60 additions & 0 deletions postgres-protocol/src/message/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,66 @@ where
})
}

#[inline]
pub fn startup_message_cstr(
parameters: &StartupMessageParams,
buf: &mut BytesMut,
) -> io::Result<()> {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}

#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StartupMessageParams {
params: BytesMut,
}

impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) -> Result<(), io::Error> {
if name.contains('\0') | value.contains('\0') {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
self.params.put(name.as_bytes());
self.params.put(&b"\0"[..]);
self.params.put(value.as_bytes());
self.params.put(&b"\0"[..]);
Ok(())
}

pub fn str_iter(&self) -> impl Iterator<Item = (&str, &str)> {
let params =
std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
StrParamsIter(params)
}

/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.str_iter().find_map(|(k, v)| (k == name).then_some(v))
}
}

struct StrParamsIter<'a>(&'a str);

impl<'a> Iterator for StrParamsIter<'a> {
type Item = (&'a str, &'a str);

fn next(&mut self) -> Option<Self::Item> {
let (key, r) = self.0.split_once('\0')?;
let (value, r) = r.split_once('\0')?;
self.0 = r;
Some((key, value))
}
}

#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
Expand Down
2 changes: 1 addition & 1 deletion postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ impl ToSql for IpAddr {
}

fn downcast(len: usize) -> Result<i32, Box<dyn Error + Sync + Send>> {
if len > i32::max_value() as usize {
if len > i32::MAX as usize {
Err("value too large to transmit".into())
} else {
Ok(len as i32)
Expand Down
1 change: 0 additions & 1 deletion postgres-types/src/special.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use bytes::BytesMut;
use postgres_protocol::types;
use std::error::Error;
use std::{i32, i64};

use crate::{FromSql, IsNull, ToSql, Type};

Expand Down
24 changes: 0 additions & 24 deletions postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,6 @@ impl Config {
self
}

/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}

/// Sets precomputed protocol-specific keys to authenticate with.
/// When set, this option will override `password`.
/// See [`AuthKeys`] for more information.
Expand All @@ -159,12 +153,6 @@ impl Config {
self
}

/// Gets precomputed protocol-specific keys to authenticate with.
/// if one has been configured with the `auth_keys` method.
pub fn get_auth_keys(&self) -> Option<AuthKeys> {
self.config.get_auth_keys()
}

/// Sets the name of the database to connect to.
///
/// Defaults to the user.
Expand All @@ -185,24 +173,12 @@ impl Config {
self
}

/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}

/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}

/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}

/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
Expand Down
Loading
Loading