Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS without cstruct #497

Merged
merged 22 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions async/io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
type t =
{ fd : Fd.t
; mutable state : State.t
; mutable linger : Cstruct.t option
; recv_buf : Cstruct.t
; mutable linger : string option
; recv_buf : bytes
}

let tls_error = Fn.compose Deferred.Or_error.error_s Tls_error.sexp_of_t
Expand Down Expand Up @@ -73,18 +73,17 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
| Active _, `Eof ->
t.state <- Eof;
return `Eof
| Active tls, `Ok n -> handle tls (Cstruct.sub t.recv_buf 0 n)
| Active tls, `Ok n -> handle tls (Stdlib.Bytes.sub_string t.recv_buf 0 n)
| Error e, _ -> tls_error e
| Eof, _ -> return `Eof)
;;

let rec read t buf =
let writeout res =
let open Cstruct in
let rlen = length res in
let n = min (length buf) rlen in
blit res 0 buf 0 n;
t.linger <- (if n < rlen then Some (sub res n (rlen - n)) else None);
let rlen = String.length res in
let n = min (Bytes.length buf) rlen in
Stdlib.Bytes.blit_string res 0 buf 0 n;
t.linger <- (if n < rlen then Some (Stdlib.String.sub res n (rlen - n)) else None);
return n
in
match t.linger with
Expand Down Expand Up @@ -120,7 +119,7 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
match mcs, t.linger with
| None, _ -> ()
| scs, None -> t.linger <- scs
| Some cs, Some l -> t.linger <- Some (Cstruct.append l cs)
| Some cs, Some l -> t.linger <- Some (l ^ cs)
in
match t.state with
| Active tls when not (Tls.Engine.handshake_in_progress tls) -> return t
Expand Down Expand Up @@ -173,7 +172,7 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
{ state = Active (Tls.Engine.server config)
; fd
; linger = None
; recv_buf = Cstruct.create 4096
; recv_buf = Bytes.create 4096
}
;;

Expand All @@ -183,7 +182,7 @@ module Make (Fd : Fd) : S with module Fd := Fd = struct
| None -> config
| Some host -> Tls.Config.peer config host
in
let t = { state = Eof; fd; linger = None; recv_buf = Cstruct.create 4096 } in
let t = { state = Eof; fd; linger = None; recv_buf = Bytes.create 4096 } in
let tls, init = Tls.Engine.client config' in
let t = { t with state = Active tls } in
let%bind () = Fd.write_full t.fd init in
Expand Down
8 changes: 4 additions & 4 deletions async/io_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ open! Async
module type Fd = sig
type t

val read : t -> Cstruct.t -> [ `Ok of int | `Eof ] Deferred.Or_error.t
val write_full : t -> Cstruct.t -> unit Deferred.Or_error.t
val read : t -> bytes -> [ `Ok of int | `Eof ] Deferred.Or_error.t
val write_full : t -> string -> unit Deferred.Or_error.t
end

module type S = sig
Expand All @@ -32,10 +32,10 @@ module type S = sig

(** [read t buffer] is [length], the number of bytes read into
[buffer]. *)
val read : t -> Cstruct.t -> int Deferred.Or_error.t
val read : t -> bytes -> int Deferred.Or_error.t

(** [writev t buffers] writes the [buffers] to the session. *)
val writev : t -> Cstruct.t list -> unit Deferred.Or_error.t
val writev : t -> string list -> unit Deferred.Or_error.t

(** [close t] closes the TLS session by sending a close notify to the peer. *)
val close_tls : t -> unit Deferred.Or_error.t
Expand Down
8 changes: 4 additions & 4 deletions async/session.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@ module Fd = struct
type t = Reader.t * Writer.t

let read (reader, (_ : Writer.t)) buf =
Deferred.Or_error.try_with (fun () -> Async_cstruct.read reader buf)
Deferred.Or_error.try_with (fun () -> Reader.read reader buf)
;;

let write ((_ : Reader.t), writer) buf =
Deferred.Or_error.try_with (fun () ->
Async_cstruct.schedule_write writer buf;
Writer.write writer buf;
Writer.flushed writer)
;;

let rec write_full fd buf =
let open Deferred.Or_error.Let_syntax in
match Cstruct.length buf with
match String.length buf with
| 0 -> return ()
| len ->
let%bind () = write fd buf in
write_full fd (Cstruct.shift buf len)
write_full fd (String.sub buf ~pos:len ~len:(String.length buf - len))
;;
end

Expand Down
6 changes: 3 additions & 3 deletions async/tls_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ let try_to_close t =
;;

let pipe t =
let b_reader = Cstruct.create 0x8000 in
let b_reader = Bytes.create 0x8000 in
let rec f_reader writer =
match%bind Session.read t b_reader with
| Ok 0 ->
Pipe.close writer;
return ()
| Ok len ->
let%bind () = Pipe.write writer (Cstruct.to_string (Cstruct.sub b_reader 0 len)) in
let%bind () = Pipe.write writer (Stdlib.Bytes.sub_string b_reader 0 len) in
f_reader writer
| Error read_error ->
Log.Global.error_s [%sexp (read_error : Error.t)];
Expand All @@ -28,7 +28,7 @@ let pipe t =
let%bind pipe_read = Pipe.read reader in
match pipe_read with
| `Ok s ->
(match%bind Session.writev t [ Cstruct.of_string s ] with
(match%bind Session.writev t [ s ] with
| Ok () -> f_writer reader
| Error (_ : Error.t) -> try_to_close t)
| `Eof -> try_to_close t
Expand Down
55 changes: 24 additions & 31 deletions async/x509_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ module Or_error = struct
let of_result ~to_string = Result.map_error ~f:(Fn.compose Error.of_string to_string)
let of_result_msg x = of_result x ~to_string:(fun (`Msg msg) -> msg)

let lift_result_msg_of_cstruct f ~contents =
f (Cstruct.of_string contents) |> of_result_msg
let lift_result_msg_of_string f ~contents =
f contents |> of_result_msg
;;

let lift_asn_error_of_cstruct f ~contents =
f (Cstruct.of_string contents) |> of_result ~to_string:(fun (`Parse msg) -> msg)
let lift_asn_error_of_string f ~contents =
f contents |> of_result ~to_string:(fun (`Parse msg) -> msg)
;;
end

module CRL = struct
include X509.CRL

let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_der = Or_error.lift_result_msg_of_string decode_der

let revoke ?digest ~issuer ~this_update ?next_update ?extensions revoked_certs key =
revoke ?digest ~issuer ~this_update ?next_update ?extensions revoked_certs key
Expand All @@ -58,9 +58,9 @@ module Certificate = struct
include X509.Certificate
open Deferred.Or_error.Let_syntax

let decode_pem_multiple = Or_error.lift_result_msg_of_cstruct decode_pem_multiple
let decode_pem = Or_error.lift_result_msg_of_cstruct decode_pem
let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_pem_multiple = Or_error.lift_result_msg_of_string decode_pem_multiple
let decode_pem = Or_error.lift_result_msg_of_string decode_pem
let decode_der = Or_error.lift_result_msg_of_string decode_der

let of_pem_file ca_file =
let%bind contents = file_contents ca_file in
Expand All @@ -81,7 +81,7 @@ module Authenticator = struct
module Chain_of_trust = struct
type t =
{ trust_anchors : [ `File of Filename.t | `Directory of Filename.t ]
; allowed_hashes : Mirage_crypto.Hash.hash list option
; allowed_hashes : Digestif.hash' list option
; crls : Filename.t option
}

Expand All @@ -93,8 +93,8 @@ module Authenticator = struct

type t =
| Chain_of_trust of Chain_of_trust.t
| Cert_fingerprint of Mirage_crypto.Hash.hash * string
| Key_fingerprint of Mirage_crypto.Hash.hash * string
| Cert_fingerprint of Digestif.hash' * string
| Key_fingerprint of Digestif.hash' * string

let ca_file ?allowed_hashes ?crls filename () =
let trust_anchors = `File filename in
Expand All @@ -114,7 +114,7 @@ module Authenticator = struct
let known_delimiters = [ ':'; ' ' ] in
String.filter fingerprint ~f:(fun c ->
not (List.exists known_delimiters ~f:(Char.equal c)))
|> Cstruct.of_hex
|> Ohex.decode
;;

let of_cas ~time ({ trust_anchors; allowed_hashes; crls } : Chain_of_trust.t) =
Expand All @@ -132,12 +132,12 @@ module Authenticator = struct

let of_cert_fingerprint ~time hash fingerprint =
let fingerprint = cleanup_fingerprint fingerprint in
X509.Authenticator.server_cert_fingerprint ~time ~hash ~fingerprint
X509.Authenticator.cert_fingerprint ~time ~hash ~fingerprint
;;

let of_key_fingerprint ~time hash fingerprint =
let fingerprint = cleanup_fingerprint fingerprint in
X509.Authenticator.server_key_fingerprint ~time ~hash ~fingerprint
X509.Authenticator.key_fingerprint ~time ~hash ~fingerprint
;;

let time = Fn.compose Ptime.of_float_s Unix.gettimeofday
Expand All @@ -156,7 +156,7 @@ end
module Distinguished_name = struct
include X509.Distinguished_name

let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_der = Or_error.lift_result_msg_of_string decode_der
end

module OCSP = struct
Expand All @@ -169,7 +169,7 @@ module OCSP = struct
create ?certs ?digest ?requestor_name ?key cert_ids |> Or_error.of_result_msg
;;

let decode_der = Or_error.lift_asn_error_of_cstruct decode_der
let decode_der = Or_error.lift_asn_error_of_string decode_der
end

module Response = struct
Expand All @@ -196,14 +196,14 @@ module OCSP = struct
;;

let responses t = responses t |> Or_error.of_result_msg
let decode_der = Or_error.lift_asn_error_of_cstruct decode_der
let decode_der = Or_error.lift_asn_error_of_string decode_der
end
end

module PKCS12 = struct
include X509.PKCS12

let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_der = Or_error.lift_result_msg_of_string decode_der
let verify password t = verify password t |> Or_error.of_result_msg
end

Expand All @@ -213,11 +213,10 @@ module Private_key = struct
let sign hash ?scheme key data =
sign hash ?scheme key data
|> Or_error.of_result_msg
|> Or_error.map ~f:Cstruct.to_string
;;

let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_pem = Or_error.lift_result_msg_of_cstruct decode_pem
let decode_der = Or_error.lift_result_msg_of_string decode_der
let decode_pem = Or_error.lift_result_msg_of_string decode_pem

let of_pem_file file =
let%map contents = Reader.file_contents file in
Expand All @@ -229,27 +228,21 @@ module Public_key = struct
include X509.Public_key

let verify hash ?scheme ~signature key data =
let signature = Cstruct.of_string signature in
let data =
match data with
| `Digest data -> `Digest (Cstruct.of_string data)
| `Message data -> `Message (Cstruct.of_string data)
in
verify hash ?scheme ~signature key data |> Or_error.of_result_msg
;;

let decode_der = Or_error.lift_result_msg_of_cstruct decode_der
let decode_pem = Or_error.lift_result_msg_of_cstruct decode_pem
let decode_der = Or_error.lift_result_msg_of_string decode_der
let decode_pem = Or_error.lift_result_msg_of_string decode_pem
end

module Signing_request = struct
include X509.Signing_request

let decode_der ?allowed_hashes der =
Cstruct.of_string der |> decode_der ?allowed_hashes |> Or_error.of_result_msg
decode_der ?allowed_hashes der |> Or_error.of_result_msg
;;

let decode_pem pem = Cstruct.of_string pem |> decode_pem |> Or_error.of_result_msg
let decode_pem pem = decode_pem pem |> Or_error.of_result_msg

let create subject ?digest ?extensions key =
create subject ?digest ?extensions key |> Or_error.of_result_msg
Expand Down
Loading