Skip to content

Commit

Permalink
feat: implement IMAP COMPRESS
Browse files Browse the repository at this point in the history
  • Loading branch information
link2xt committed Oct 10, 2024
1 parent 1954ce4 commit 82e247c
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ jobs:
- name: check tokio
run: cargo check --workspace --all-targets --no-default-features --features runtime-tokio

- name: check compress feature with tokio
run: cargo check --workspace --all-targets --no-default-features --features runtime-tokio,compress

- name: check compress feature with async-std
run: cargo check --workspace --all-targets --no-default-features --features runtime-async-std,compress

- name: check async-std examples
working-directory: examples
run: cargo check --workspace --all-targets --no-default-features --features runtime-async-std
Expand Down
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ is-it-maintained-open-issues = { repository = "async-email/async-imap" }

[features]
default = ["runtime-async-std"]
compress = ["async-compression"]

runtime-async-std = ["async-std"]
runtime-tokio = ["tokio"]
runtime-async-std = ["async-std", "async-compression?/futures-io"]
runtime-tokio = ["tokio", "async-compression?/tokio"]

[dependencies]
async-channel = "2.0.0"
async-compression = { git = "https://github.com/Nullus157/async-compression.git", default-features = false, features = ["deflate"], optional = true }
async-std = { version = "1.8.0", default-features = false, features = ["std", "unstable"], optional = true }
base64 = "0.21"
bytes = "1"
Expand All @@ -35,6 +37,7 @@ imap-proto = "0.16.4"
log = "0.4.8"
nom = "7.0"
once_cell = "1.8.0"
pin-project = "1"
pin-utils = "0.1.0-alpha.4"
self_cell = "1.0.1"
stop-token = "0.7"
Expand Down
168 changes: 168 additions & 0 deletions src/extensions/compress.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
//! IMAP COMPRESS extension specified in [RFC4978](https://www.rfc-editor.org/rfc/rfc4978.html).

use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

use pin_project::pin_project;

use crate::client::Session;
use crate::error::Result;
use crate::imap_stream::ImapStream;
use crate::shared_stream::SharedStream;
use crate::types::IdGenerator;
use crate::Connection;

#[cfg(feature = "runtime-async-std")]
use async_std::io::{BufReader, Read, Write};
#[cfg(feature = "runtime-tokio")]
use tokio::io::{AsyncRead as Read, AsyncWrite as Write, BufReader, ReadBuf};

#[cfg(feature = "runtime-tokio")]
use async_compression::tokio::bufread::DeflateDecoder;
#[cfg(feature = "runtime-tokio")]
use async_compression::tokio::write::DeflateEncoder;

#[cfg(feature = "runtime-async-std")]
use async_compression::futures::bufread::DeflateDecoder;
#[cfg(feature = "runtime-async-std")]
use async_compression::futures::write::DeflateEncoder;

/// IMAP stream
#[derive(Debug)]
#[pin_project]
pub struct DeflateStream<T: Read + Write + Unpin + fmt::Debug> {
/// Shared stream reference to allow direct access
/// to the underlying stream.
stream: SharedStream<T>,

#[pin]
decoder: DeflateDecoder<BufReader<SharedStream<T>>>,

#[pin]
encoder: DeflateEncoder<SharedStream<T>>,
}

impl<T: Read + Write + Unpin + fmt::Debug> DeflateStream<T> {
pub(crate) fn new(stream: T) -> Self {
let stream = SharedStream::new(stream);
let decoder = DeflateDecoder::new(BufReader::new(stream.clone()));
let encoder = DeflateEncoder::new(stream.clone());
Self {
stream,
decoder,
encoder,
}
}

/// Runs provided function while holding the lock on the underlying stream.
///
/// This allows to access the underlying stream while ensuring
/// that no data is read from the stream or written into the stream at the same time.
pub fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
self.stream.with_lock(f)
}
}

#[cfg(feature = "runtime-tokio")]
impl<T: Read + Write + Unpin + fmt::Debug> Read for DeflateStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.project().decoder.poll_read(cx, buf)
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: Read + Write + Unpin + fmt::Debug> Read for DeflateStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<async_std::io::Result<usize>> {
self.project().decoder.poll_read(cx, buf)
}
}

#[cfg(feature = "runtime-tokio")]
impl<T: Read + Write + Unpin + fmt::Debug> Write for DeflateStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().encoder.poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().encoder.poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().encoder.poll_shutdown(cx)
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: Read + Write + Unpin + fmt::Debug> Write for DeflateStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<async_std::io::Result<usize>> {
self.project().encoder.poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<async_std::io::Result<()>> {
self.project().encoder.poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<async_std::io::Result<()>> {
self.project().encoder.poll_close(cx)
}
}

impl<T: Read + Write + Unpin + fmt::Debug + Send> Session<T> {
/// Runs `COMPRESS DEFLATE` command.
pub async fn compress<F, S>(self, f: F) -> Result<Session<S>>
where
S: Read + Write + Unpin + fmt::Debug,
F: FnOnce(DeflateStream<T>) -> S,
{
let Self {
mut conn,
unsolicited_responses_tx,
unsolicited_responses,
} = self;
conn.run_command_and_check_ok("COMPRESS DEFLATE", Some(unsolicited_responses_tx.clone()))
.await?;

let stream = conn.into_inner();
let deflate_stream = DeflateStream::new(stream);
let stream = ImapStream::new(f(deflate_stream));
let conn = Connection {
stream,
request_ids: IdGenerator::new(),
};
let session = Session {
conn,
unsolicited_responses_tx,
unsolicited_responses,
};
Ok(session)
}
}
3 changes: 3 additions & 0 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Implementations of various IMAP extensions.
#[cfg(feature = "compress")]
pub mod compress;

pub mod idle;

pub mod quota;
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ mod imap_stream;
mod parse;
pub mod types;

#[cfg(feature = "compress")]
pub use crate::extensions::compress::DeflateStream;

#[cfg(feature = "compress")]
mod shared_stream;

pub use crate::authenticator::Authenticator;
pub use crate::client::*;

Expand Down
157 changes: 157 additions & 0 deletions src/shared_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

#[cfg(feature = "runtime-async-std")]
use async_std::io::{Read, Write};
#[cfg(feature = "runtime-tokio")]
use tokio::io::{AsyncRead as Read, AsyncWrite as Write, ReadBuf};

#[cfg(feature = "runtime-tokio")]
#[derive(Debug)]
pub(crate) struct SharedStream<T: std::fmt::Debug> {
inner: Arc<Mutex<T>>,
is_write_vectored: bool,
}

#[cfg(feature = "runtime-async-std")]
#[derive(Debug)]
pub(crate) struct SharedStream<T: std::fmt::Debug> {
inner: Arc<Mutex<T>>,
}

#[cfg(feature = "runtime-tokio")]
impl<T: std::fmt::Debug> Clone for SharedStream<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
is_write_vectored: self.is_write_vectored,
}
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: std::fmt::Debug> Clone for SharedStream<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}

#[cfg(feature = "runtime-tokio")]
impl<T: std::fmt::Debug> SharedStream<T>
where
T: Read + Write,
{
pub(crate) fn new(stream: T) -> SharedStream<T> {
let is_write_vectored = stream.is_write_vectored();

let inner = Arc::new(Mutex::new(stream));

Self {
inner,
is_write_vectored,
}
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: std::fmt::Debug> SharedStream<T>
where
T: Read + Write,
{
pub(crate) fn new(stream: T) -> SharedStream<T> {
let inner = Arc::new(Mutex::new(stream));

Self { inner }
}
}

impl<T: Unpin + std::fmt::Debug> SharedStream<T> {
pub(crate) fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
let mut guard = self.inner.lock().unwrap();
let stream = Pin::new(&mut *guard);
f(stream)
}
}

#[cfg(feature = "runtime-tokio")]
impl<T: Read + Unpin + std::fmt::Debug> Read for SharedStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.with_lock(|stream| stream.poll_read(cx, buf))
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: Read + Unpin + std::fmt::Debug> Read for SharedStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<async_std::io::Result<usize>> {
self.with_lock(|stream| stream.poll_read(cx, buf))
}
}

#[cfg(feature = "runtime-tokio")]
impl<T: Write + Unpin + std::fmt::Debug> Write for SharedStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.with_lock(|stream| stream.poll_write(cx, buf))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.with_lock(|stream| stream.poll_flush(cx))
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.with_lock(|stream| stream.poll_shutdown(cx))
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.with_lock(|stream| stream.poll_write_vectored(cx, bufs))
}

fn is_write_vectored(&self) -> bool {
self.is_write_vectored
}
}

#[cfg(feature = "runtime-async-std")]
impl<T: Write + Unpin + std::fmt::Debug> Write for SharedStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.with_lock(|stream| stream.poll_write(cx, buf))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<async_std::io::Result<()>> {
self.with_lock(|stream| stream.poll_flush(cx))
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<async_std::io::Result<()>> {
self.with_lock(|stream| stream.poll_close(cx))
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[async_std::io::IoSlice<'_>],
) -> Poll<async_std::io::Result<usize>> {
self.with_lock(|stream| stream.poll_write_vectored(cx, bufs))
}
}

0 comments on commit 82e247c

Please sign in to comment.