diff --git a/src/unix/lwt_bytes.ml b/src/unix/lwt_bytes.ml index b3cc1a7550..b1936952a9 100644 --- a/src/unix/lwt_bytes.ml +++ b/src/unix/lwt_bytes.ml @@ -5,8 +5,6 @@ open Bigarray -open Lwt.Infix - type t = (char, int8_unsigned_elt, c_layout) Array1.t let create size = Array1.create char c_layout size @@ -101,33 +99,11 @@ let copy buf = open Lwt_unix -external stub_read : Unix.file_descr -> t -> int -> int -> int = "lwt_unix_bytes_read" -external read_job : Unix.file_descr -> t -> int -> int -> int job = "lwt_unix_bytes_read_job" - -let read fd buf pos len = - if pos < 0 || len < 0 || pos > length buf - len then - invalid_arg "Lwt_bytes.read" - else - blocking fd >>= function - | true -> - wait_read fd >>= fun () -> - run_job (read_job (unix_file_descr fd) buf pos len) - | false -> - wrap_syscall Read fd (fun () -> stub_read (unix_file_descr fd) buf pos len) +let read = + Lwt_unix.read_bigarray "Lwt_bytes.read" [@ocaml.warning "-3"] -external stub_write : Unix.file_descr -> t -> int -> int -> int = "lwt_unix_bytes_write" -external write_job : Unix.file_descr -> t -> int -> int -> int job = "lwt_unix_bytes_write_job" - -let write fd buf pos len = - if pos < 0 || len < 0 || pos > length buf - len then - invalid_arg "Lwt_bytes.write" - else - blocking fd >>= function - | true -> - wait_write fd >>= fun () -> - run_job (write_job (unix_file_descr fd) buf pos len) - | false -> - wrap_syscall Write fd (fun () -> stub_write (unix_file_descr fd) buf pos len) +let write = + Lwt_unix.write_bigarray "Lwt_bytes.write" [@ocaml.warning "-3"] external stub_recv : Unix.file_descr -> t -> int -> int -> Unix.msg_flag list -> int = "lwt_unix_bytes_recv" diff --git a/src/unix/lwt_unix.cppo.ml b/src/unix/lwt_unix.cppo.ml index 341cb5efda..0863d37124 100644 --- a/src/unix/lwt_unix.cppo.ml +++ b/src/unix/lwt_unix.cppo.ml @@ -617,6 +617,9 @@ let close ch = else run_job (close_job ch.fd) +type bigarray = + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t + let wait_read ch = Lwt.catch (fun () -> @@ -640,6 +643,24 @@ let read ch buf pos len = | false -> wrap_syscall Read ch (fun () -> stub_read ch.fd buf pos len) +external stub_read_bigarray : + Unix.file_descr -> bigarray -> int -> int -> int = "lwt_unix_bytes_read" +external read_bigarray_job : + Unix.file_descr -> bigarray -> int -> int -> int job = + "lwt_unix_bytes_read_job" + +let read_bigarray function_name fd buf pos len = + if pos < 0 || len < 0 || pos > Bigarray.Array1.dim buf - len then + invalid_arg function_name + else + blocking fd >>= function + | true -> + wait_read fd >>= fun () -> + run_job (read_bigarray_job (unix_file_descr fd) buf pos len) + | false -> + wrap_syscall Read fd (fun () -> + stub_read_bigarray (unix_file_descr fd) buf pos len) + let wait_write ch = Lwt.catch (fun () -> @@ -667,10 +688,27 @@ let write_string ch buf pos len = let buf = Bytes.unsafe_of_string buf in write ch buf pos len +external stub_write_bigarray : + Unix.file_descr -> bigarray -> int -> int -> int = "lwt_unix_bytes_write" +external write_bigarray_job : + Unix.file_descr -> bigarray -> int -> int -> int job = + "lwt_unix_bytes_write_job" + +let write_bigarray function_name fd buf pos len = + if pos < 0 || len < 0 || pos > Bigarray.Array1.dim buf - len then + invalid_arg function_name + else + blocking fd >>= function + | true -> + wait_write fd >>= fun () -> + run_job (write_bigarray_job (unix_file_descr fd) buf pos len) + | false -> + wrap_syscall Write fd (fun () -> + stub_write_bigarray (unix_file_descr fd) buf pos len) + module IO_vectors = struct - type _bigarray = - (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t + type _bigarray = bigarray type buffer = | Bytes of bytes @@ -780,13 +818,25 @@ external readv_job : Unix.file_descr -> IO_vectors.t -> int -> int job = let readv fd io_vectors = let count = check_io_vectors "Lwt_unix.readv" io_vectors in - Lazy.force fd.blocking >>= function - | true -> - wait_read fd >>= fun () -> - run_job (readv_job fd.fd io_vectors count) - | false -> - wrap_syscall Read fd (fun () -> - stub_readv fd.fd io_vectors.IO_vectors.prefix count) + if Sys.win32 then + match io_vectors.IO_vectors.prefix with + | [] -> + Lwt.return 0 + | first::_ -> + match first.buffer with + | Bytes buffer -> + read fd buffer first.offset first.length + | Bigarray buffer -> + read_bigarray "Lwt_unix.readv" fd buffer first.offset first.length + + else + Lazy.force fd.blocking >>= function + | true -> + wait_read fd >>= fun () -> + run_job (readv_job fd.fd io_vectors count) + | false -> + wrap_syscall Read fd (fun () -> + stub_readv fd.fd io_vectors.IO_vectors.prefix count) external stub_writev : Unix.file_descr -> IO_vectors.io_vector list -> int -> int = @@ -798,13 +848,25 @@ external writev_job : Unix.file_descr -> IO_vectors.t -> int -> int job = let writev fd io_vectors = let count = check_io_vectors "Lwt_unix.writev" io_vectors in - Lazy.force fd.blocking >>= function - | true -> - wait_write fd >>= fun () -> - run_job (writev_job fd.fd io_vectors count) - | false -> - wrap_syscall Write fd (fun () -> - stub_writev fd.fd io_vectors.IO_vectors.prefix count) + if Sys.win32 then + match io_vectors.IO_vectors.prefix with + | [] -> + Lwt.return 0 + | first::_ -> + match first.buffer with + | Bytes buffer -> + write fd buffer first.offset first.length + | Bigarray buffer -> + write_bigarray "Lwt_unix.writev" fd buffer first.offset first.length + + else + Lazy.force fd.blocking >>= function + | true -> + wait_write fd >>= fun () -> + run_job (writev_job fd.fd io_vectors count) + | false -> + wrap_syscall Write fd (fun () -> + stub_writev fd.fd io_vectors.IO_vectors.prefix count) (* +-----------------------------------------------------------------+ | Seeking and truncating | diff --git a/src/unix/lwt_unix.cppo.mli b/src/unix/lwt_unix.cppo.mli index 54e8f26add..ee7f1f54b7 100644 --- a/src/unix/lwt_unix.cppo.mli +++ b/src/unix/lwt_unix.cppo.mli @@ -1541,3 +1541,11 @@ val somaxconn : unit -> int val retained : 'a -> bool ref (** @deprecated Used for testing. *) + +val read_bigarray : + string -> file_descr -> IO_vectors._bigarray -> int -> int -> int Lwt.t + [@@ocaml.deprecated " This is an internal function."] + +val write_bigarray : + string -> file_descr -> IO_vectors._bigarray -> int -> int -> int Lwt.t + [@@ocaml.deprecated " This is an internal function."] diff --git a/test/unix/test_lwt_unix.cppo.ml b/test/unix/test_lwt_unix.cppo.ml index 5fbec41337..3217a836a6 100644 --- a/test/unix/test_lwt_unix.cppo.ml +++ b/test/unix/test_lwt_unix.cppo.ml @@ -378,13 +378,22 @@ let readv_tests = Lwt.return (bytes_written = Bytes.length data) in - let reader read_fd io_vectors underlying expected_count expected_data = + let reader + ?(close = true) + read_fd + io_vectors + underlying + expected_count + expected_data = fun () -> Gc.full_major (); let t = Lwt_unix.readv read_fd io_vectors in Gc.full_major (); t >>= fun bytes_read -> - Lwt_unix.close read_fd >>= fun () -> + (if close then + Lwt_unix.close read_fd + else + Lwt.return ()) >>= fun () -> let actual = List.fold_left (fun acc -> function @@ -497,6 +506,24 @@ let readv_tests = Lwt_list.for_all_s (fun t -> t ()) [writer write_fd (expected ^ "a"); reader read_fd io_vectors underlying limit (expected ^ "_")]); + + test "readv: windows" ~only_if:(fun () -> Sys.win32) begin fun () -> + let read_fd, write_fd = Lwt_unix.pipe () in + + let io_vectors, underlying = + make_io_vectors [ + `Bytes (1, 3, 1); + `Bigarray (1, 4, 1) + ] + in + + Lwt_list.for_all_s (fun t -> t ()) [ + writer write_fd "foobar"; + reader ~close:false read_fd io_vectors underlying 3 "_foo_______"; + (fun () -> Lwt_unix.IO_vectors.drop io_vectors 3; Lwt.return true); + reader read_fd io_vectors underlying 3 "_foo__bar__"; + ] + end; ] let writev_tests = @@ -512,13 +539,17 @@ let writev_tests = io_vectors in - let writer ?blocking write_fd io_vectors data_length = fun () -> + let writer ?(close = true) ?blocking write_fd io_vectors data_length = + fun () -> Lwt_unix.blocking write_fd >>= fun is_blocking -> Gc.full_major (); let t = Lwt_unix.writev write_fd io_vectors in Gc.full_major (); t >>= fun bytes_written -> - Lwt_unix.close write_fd >>= fun () -> + (if close then + Lwt_unix.close write_fd + else + Lwt.return ()) >>= fun () -> let blocking_matches = match blocking, is_blocking with | Some v, v' when v <> v' -> @@ -743,6 +774,24 @@ let writev_tests = Lwt.return (io_correct && not (Lwt_unix.IO_vectors.is_empty io_vectors))); + + test "writev: windows" ~only_if:(fun () -> Sys.win32) begin fun () -> + let io_vectors = + make_io_vectors [ + `Bytes ("foo", 0, 3); + `Bigarray ("bar", 0, 3); + ] + in + + let read_fd, write_fd = Lwt_unix.pipe () in + + Lwt_list.for_all_s (fun t -> t ()) [ + writer ~close:false write_fd io_vectors 3; + (fun () -> Lwt_unix.IO_vectors.drop io_vectors 3; Lwt.return true); + writer write_fd io_vectors 3; + reader read_fd "foobar"; + ] + end; ] let send_recv_msg_tests = [