Skip to content

Commit

Permalink
Shutdown (#488)
Browse files Browse the repository at this point in the history
* tls-mirage: do not FLOW.close on error

this is the responsibility of the caller (esp. since we'll have half-closed
flows soon)

* refine engine for shutdown

* Engine: guard functions with write_closed

* lwt: adapt to shutdown API

* lwt: only open Lwt.Infix, not Lwt

* tls-lwt: improve failure message

* tls-lwt: use reraise instead of fail

* tls-mirage support for shutdown

* tls-mirage: eof on read_closed

* tls-mirage: only close in close

* minimal changes to get eio and async compiling

* randomconv dependency: move to tls-lwt, restrict to < 0.2.0

* tls-lwt: add ptime dependency

* read: a closed results in eof

* tls-eio: update for new shutdown system

This is a direct port of the changes to tls-lwt to tls-eio.

* tls-eio: update fuzz tests to test half-shutdown

---------

Co-authored-by: Thomas Leonard <talex5@gmail.com>
  • Loading branch information
hannesm and talex5 committed Mar 26, 2024
1 parent bd11882 commit e6a52d8
Show file tree
Hide file tree
Showing 14 changed files with 395 additions and 323 deletions.
11 changes: 5 additions & 6 deletions async/io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,19 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
let rec read_react t =
let handle tls buf =
match Tls.Engine.handle_tls tls buf with
| Ok (state, `Response resp, `Data data) ->
| Ok (state, eof, `Response resp, `Data data) ->
t.state
<- (match state with
| `Ok tls -> Active tls
| `Eof -> Eof
| `Alert a -> Error (Tls_alert a));
<- (match eof with
| None -> Active state
| Some `Eof -> Eof);
let%map () =
match resp with
| None -> return ()
| Some resp -> Fd.write_full t.fd resp
in
`Ok data
| Error (alert, `Response resp) ->
t.state <- Error (Tls_failure alert);
t.state <- Error (match alert with `Alert a -> Tls_alert a | f -> Tls_failure f);
let%bind () = Fd.write_full t.fd resp in
read_react t
in
Expand Down
40 changes: 9 additions & 31 deletions eio/tests/fuzz.ml
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ let dir =
the network draining any queued traffic.
Once all fibers have finished, we check that what was sent matches the data
that has been received.
However, due to #452, we currently skip the check on the receiving side if
the receiver has shut down its sending side by then. *)
that has been received. *)

let action =
Crowbar.option (Crowbar.pair dir op) (* None means yield *)
Expand All @@ -85,8 +82,6 @@ module Path : sig
val create :
sender:(Tls_eio.t, exn) result Promise.t ->
receiver:(Tls_eio.t, exn) result Promise.t ->
sender_closed:bool ref ->
receiver_closed:bool ref ->
transmit:(transmit_amount -> unit) ->
dir -> string -> t
(** Create a test driver for one direction, from [sender] to [receiver].
Expand All @@ -113,20 +108,16 @@ end = struct
send_commands : [`Send of int | `Exit] Eio.Stream.t; (* Commands for the sending fiber *)
recv_commands : [`Recv | `Drain] Eio.Stream.t; (* Commands for the receiving fiber *)
transmit : transmit_amount -> unit;
(* FIXME: We shouldn't need to care about these, but see issue #452: *)
sender_closed : bool ref;
receiver_closed : bool ref;
}

let pp_dir f t =
pp_dir f t.dir

let create ~sender ~receiver ~sender_closed ~receiver_closed ~transmit dir message =
let create ~sender ~receiver ~transmit dir message =
let send_commands = Eio.Stream.create max_int in
let recv_commands = Eio.Stream.create max_int in
{ dir; message; sender; receiver; sent = 0; recv = 0;
send_commands; recv_commands;
transmit; sender_closed; receiver_closed }
send_commands; recv_commands; transmit }

let shutdown t =
Eio.Stream.add t.send_commands `Exit
Expand All @@ -143,7 +134,6 @@ end = struct
match Eio.Stream.take t.send_commands with
| `Exit ->
Log.info (fun f -> f "%a: shutdown send (Tls level)" pp_dir t);
t.sender_closed := true;
Eio.Flow.shutdown sender `Send
| `Send len ->
let available = String.length t.message - t.sent in
Expand All @@ -156,12 +146,7 @@ end = struct
);
aux ()
in
try
aux()
with Eio.Io (Eio.Net.E Connection_reset _, _) ->
(* Due to #452, if we get told that the receiver will no longer send, then
we can't send either. *)
assert !(t.receiver_closed)
aux()

let run_recv_thread t =
let recv = Promise.await_exn t.receiver in
Expand All @@ -185,12 +170,11 @@ end = struct
t.recv <- t.recv + got
done
with End_of_file ->
if not !(t.receiver_closed) then (
if t.recv <> t.sent then
Fmt.failwith "%a: Sender sent %d bytes, but receiver got EOF after reading only %d"
pp_dir t
t.sent
t.recv;
if t.recv <> t.sent then (
Fmt.failwith "%a: Sender sent %d bytes, but receiver got EOF after reading only %d"
pp_dir t
t.sent
t.recv
);
Log.info (fun f -> f "%a: recv thread done (got EOF)" pp_dir t)

Expand Down Expand Up @@ -288,22 +272,16 @@ let main client_message server_message quickstart actions =
let client_socket, server_socket = Mock_socket.create_pair () in
let server_flow = Fiber.fork_promise ~sw (fun () -> Tls_eio.server_of_flow Config.server server_socket) in
let client_flow = Fiber.fork_promise ~sw (fun () -> Tls_eio.client_of_flow Config.client client_socket) in
let server_closed = ref false in
let client_closed = ref false in
let to_server =
Path.create
~sender:client_flow
~receiver:server_flow
~sender_closed:client_closed
~receiver_closed:server_closed
~transmit:(Mock_socket.transmit client_socket)
To_server client_message in
let to_client =
Path.create
~sender:server_flow
~receiver:client_flow
~sender_closed:server_closed
~receiver_closed:client_closed
~transmit:(Mock_socket.transmit server_socket)
To_client server_message
in
Expand Down
121 changes: 65 additions & 56 deletions eio/tls_eio.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ module Raw = struct
type t = {
flow : Flow.two_way_ty r;
mutable state : [ `Active of Tls.Engine.state
| `Eof
| `Read_closed of Tls.Engine.state
| `Write_closed of Tls.Engine.state
| `Closed
| `Error of exn ] ;
mutable linger : Cstruct.t option ;
recv_buf : Cstruct.t ;
}

let read_t t cs =
try Flow.single_read t.flow cs
with
| End_of_file as ex ->
t.state <- `Eof;
raise ex
| exn ->
(match t.state with
| `Error _ | `Eof -> ()
| `Active _ -> t.state <- `Error exn) ;
raise exn
let half_close state mode =
match state, mode with
| `Active tls, `read -> `Read_closed tls
| `Active tls, `write -> `Write_closed tls
| `Active _, `read_write -> `Closed
| `Read_closed tls, `read -> `Read_closed tls
| `Read_closed _, (`write | `read_write) -> `Closed
| `Write_closed tls, `write -> `Write_closed tls
| `Write_closed _, (`read | `read_write) -> `Closed
| (`Closed | `Error _) as e, (`read | `write | `read_write) -> e

let inject_state tls = function
| `Active _ -> `Active tls
| `Read_closed _ -> `Read_closed tls
| `Write_closed _ -> `Write_closed tls
| (`Closed | `Error _) as e -> e

let write_t t cs =
try Flow.copy (Flow.cstruct_source [cs]) t.flow
with exn ->
(match t.state with
| `Error _ | `Eof -> ()
| `Active _ -> t.state <- `Error exn) ;
| `Error _ -> ()
| _ -> t.state <- `Error exn) ;
raise exn

let try_write_t t cs =
Expand All @@ -55,30 +62,38 @@ module Raw = struct

let handle tls buf =
match Tls.Engine.handle_tls tls buf with
| Ok (state', `Response resp, `Data data) ->
let state' = match state' with
| `Ok tls -> `Active tls
| `Eof -> `Eof
| `Alert a -> `Error (Tls_alert a)
in
| Ok (state', eof, `Response resp, `Data data) ->
let state' = inject_state state' t.state in
let state' = Option.(value ~default:state' (map (fun `Eof -> half_close state' `read) eof)) in
t.state <- state' ;
Option.iter (try_write_t t) resp;
data

| Error (alert, `Response resp) ->
t.state <- `Error (Tls_failure alert) ;
| Error (fail, `Response resp) ->
t.state <- `Error (match fail with `Alert a -> Tls_alert a | f -> Tls_failure f) ;
write_t t resp; read_react t
in

match t.state with
| `Error e -> raise e
| `Eof -> raise End_of_file
| `Active _ ->
let n = read_t t t.recv_buf in
match (t.state, n) with
| (`Active tls, n) -> handle tls (Cstruct.sub t.recv_buf 0 n)
| (`Error e, _) -> raise e
| (`Eof, _) -> raise End_of_file
| `Closed
| `Read_closed _ -> raise End_of_file
| _ ->
match Flow.single_read t.flow t.recv_buf with
| exception End_of_file ->
t.state <- half_close t.state `read;
raise End_of_file
| exception exn ->
(match t.state with
| `Error _ -> ()
| _ -> t.state <- `Error exn) ;
raise exn
| n ->
match t.state with
| `Error e -> raise e
| `Active tls | `Read_closed tls | `Write_closed tls ->
handle tls (Cstruct.sub t.recv_buf 0 n)
| `Closed -> raise End_of_file

let rec single_read t buf =

Expand All @@ -101,11 +116,11 @@ module Raw = struct
let writev t css =
match t.state with
| `Error err -> raise err
| `Eof -> raise (Eio.Net.err (Connection_reset Tls_socket_closed))
| `Active tls ->
| `Write_closed _ | `Closed -> raise (Eio.Net.err (Connection_reset Tls_socket_closed))
| `Active tls | `Read_closed tls ->
match Tls.Engine.send_application_data tls css with
| Some (tls, tlsdata) ->
( t.state <- `Active tls ; write_t t tlsdata )
( t.state <- inject_state tls t.state ; write_t t tlsdata )
| None -> invalid_arg "tls: write: socket not ready"

let single_write t bufs =
Expand Down Expand Up @@ -136,41 +151,36 @@ module Raw = struct
let reneg ?authenticator ?acceptable_cas ?cert ?(drop = true) t =
match t.state with
| `Error err -> raise err
| `Eof -> raise End_of_file
| `Closed | `Read_closed _ | `Write_closed _ -> invalid_arg "tls: closed socket"
| `Active tls ->
match Tls.Engine.reneg ?authenticator ?acceptable_cas ?cert tls with
| None -> invalid_arg "tls: can't renegotiate"
| Some (tls', buf) ->
if drop then t.linger <- None ;
t.state <- `Active tls' ;
t.state <- inject_state tls' t.state ;
write_t t buf;
ignore (drain_handshake t : t)

let key_update ?request t =
match t.state with
| `Error err -> raise err
| `Eof -> raise End_of_file
| `Active tls ->
| `Write_closed _ | `Closed -> invalid_arg "tls: closed socket"
| `Active tls | `Read_closed tls ->
match Tls.Engine.key_update ?request tls with
| Error _ -> invalid_arg "tls: can't update key"
| Error f -> Fmt.invalid_arg "tls: can't update key: %a" Tls.Engine.pp_failure f
| Ok (tls', buf) ->
t.state <- `Active tls' ;
write_t t buf

let close_tls t =
match t.state with
| `Active tls ->
let (_, buf) = Tls.Engine.send_close_notify tls in
t.state <- `Eof ; (* XXX: this looks wrong - we're only trying to close the sending side *)
t.state <- inject_state tls' t.state ;
write_t t buf
| _ -> ()

(* Not sure if we need to keep both directions open on the underlying flow when closing
one direction at the TLS level. *)
let shutdown t = function
| `Send -> close_tls t
| `All -> close_tls t; Flow.shutdown t.flow `All
| `Receive -> () (* Not obvious how to do this with TLS, so ignore for now. *)
| `Receive -> ()
| `Send | `All ->
match t.state with
| `Active tls | `Read_closed tls ->
let tls', buf = Tls.Engine.send_close_notify tls in
t.state <- inject_state tls' (half_close t.state `write) ;
write_t t buf
| _ -> ()

let server_of_flow config flow =
drain_handshake {
Expand All @@ -185,22 +195,21 @@ module Raw = struct
| None -> config
| Some host -> Tls.Config.peer config host
in
let (tls, init) = Tls.Engine.client config' in
let t = {
state = `Eof ;
state = `Active tls ;
flow = (flow :> Flow.two_way_ty r);
linger = None ;
recv_buf = Cstruct.create 4096
} in
let (tls, init) = Tls.Engine.client config' in
let t = { t with state = `Active tls } in
write_t t init;
drain_handshake t


let epoch t =
match t.state with
| `Active tls -> Tls.Engine.epoch tls
| `Eof | `Error _ -> Error ()
| `Active tls | `Read_closed tls | `Write_closed tls -> Tls.Engine.epoch tls
| `Closed | `Error _ -> Error ()

let copy t ~src = Eio.Flow.Pi.simple_copy ~single_write t ~src

Expand Down
Loading

0 comments on commit e6a52d8

Please sign in to comment.