diff --git a/src/hostnet_test/forwarding.ml b/src/hostnet_test/forwarding.ml index 9580060b0..02eefbd7f 100644 --- a/src/hostnet_test/forwarding.ml +++ b/src/hostnet_test/forwarding.ml @@ -13,25 +13,29 @@ let (>>*=) m f = m >>= function module Make(Host: Sig.HOST) = struct - let run ?(timeout=60.) t = + let run ?(timeout=Duration.of_sec 60) t = let timeout = - Host.Time.sleep timeout >>= fun () -> + Host.Time.sleep_ns timeout >>= fun () -> Lwt.fail_with "timeout" in Host.Main.run @@ Lwt.pick [ timeout; t ] - module Channel = Channel.Make(Host.Sockets.Stream.Tcp) + module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) module ForwardServer = struct (** Accept connections, read the forwarding header and run a proxy *) + module Proxy = + Mirage_flow_lwt.Proxy + (Mclock)(Host.Sockets.Stream.Tcp)(Host.Sockets.Stream.Tcp) + let accept flow = let sizeof = 1 + 2 + 4 + 2 in let header = Cstruct.create sizeof in Host.Sockets.Stream.Tcp.read_into flow header >>= function - | `Eof -> failwith "EOF" - | `Error e -> failwith (Host.Sockets.Stream.Tcp.error_message e) - | `Ok () -> + | Ok `Eof -> failwith "EOF" + | Error e -> Fmt.kstrf failwith "%a" Host.Sockets.Stream.Tcp.pp_error e + | Ok (`Data ()) -> let ip_len = Cstruct.LE.get_uint16 header 1 in let ip = let bytes = Cstruct.(to_string @@ sub header 3 ip_len) in @@ -44,14 +48,11 @@ module Make(Host: Sig.HOST) = struct Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 ip, port) >>= function | Error (`Msg x) -> failwith x | Ok remote -> + Mclock.connect () >>= fun clock -> Lwt.finalize (fun () -> - Mirage_flow.proxy - (module Clock) - (module Host.Sockets.Stream.Tcp) flow - (module Host.Sockets.Stream.Tcp) remote () - >>= function - | `Error (`Msg m) -> failwith m - | `Ok (_l_stats, _r_stats) -> Lwt.return () + Proxy.proxy clock flow remote >>= function + | Error e -> Fmt.kstrf failwith "%a" Proxy.pp_error e + | Ok (_l_stats, _r_stats) -> Lwt.return () ) (fun () -> Host.Sockets.Stream.Tcp.close remote ) @@ -69,7 +70,7 @@ module Make(Host: Sig.HOST) = struct } end - module Forward = Forward.Make(struct + module Forward = Forward.Make(Mclock)(struct include Host.Sockets.Stream.Tcp open Lwt.Infix @@ -89,7 +90,8 @@ module Make(Host: Sig.HOST) = struct module Server = Protocol_9p.Server.Make(Log)(Host.Sockets.Stream.Tcp)(Ports) let with_server f = - let ports = Ports.make () in + Mclock.connect () >>= fun clock -> + let ports = Ports.make clock in Ports.set_context ports ""; Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) >>= fun server -> @@ -120,11 +122,14 @@ module Make(Host: Sig.HOST) = struct let read_http ch = let rec loop acc = - Channel.read_line ch >>= fun bufs -> - let txt = Cstruct.(to_string (concat bufs)) in - if txt = "" - then Lwt.return acc - else loop (acc ^ txt) + Channel.read_line ch >>= function + | Ok `Eof + | Error _ -> Lwt.return acc + | Ok (`Data bufs) -> + let txt = Cstruct.(to_string (concat bufs)) in + if txt = "" + then Lwt.return acc + else loop (acc ^ txt) in loop "" @@ -141,7 +146,9 @@ module Make(Host: Sig.HOST) = struct then failwith (Printf.sprintf "unrecognised HTTP GET: [%s]" request); let response = "HTTP/1.0 404 Not found\r\ncontent-length: 0\r\n\r\n" in Channel.write_string ch response 0 (String.length response); - Channel.flush ch + Channel.flush ch >|= function + | Ok () -> () + | Error e -> Fmt.kstrf failwith "%a" Channel.pp_write_error e let create () = Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) @@ -181,13 +188,13 @@ module Make(Host: Sig.HOST) = struct type forward = { t: t; - fid: Protocol_9p_types.Fid.t; + fid: Protocol_9p.Types.Fid.t; ip: Ipaddr.V4.t; port: int; } let create t string = - let mode = Protocol_9p_types.FileMode.make ~is_directory:true + let mode = Protocol_9p.Types.FileMode.make ~is_directory:true ~owner:[`Read; `Write; `Execute] ~group:[`Read; `Execute] ~other:[`Read; `Execute ] () in Client.mkdir t.ninep [] string mode @@ -196,7 +203,7 @@ module Make(Host: Sig.HOST) = struct >>*= fun fid -> Client.walk_from_root t.ninep fid [ string; "ctl" ] >>*= fun _walk -> - Client.LowLevel.openfid t.ninep fid Protocol_9p_types.OpenMode.read_write + Client.LowLevel.openfid t.ninep fid Protocol_9p.Types.OpenMode.read_write >>*= fun _open -> let buf = Cstruct.create (String.length string) in Cstruct.blit_from_string string 0 buf 0 (String.length string); @@ -204,7 +211,7 @@ module Make(Host: Sig.HOST) = struct >>*= fun _write -> Client.LowLevel.read t.ninep fid 0L 1024l >>*= fun read -> - let response = Cstruct.to_string read.Protocol_9p_response.Read.data in + let response = Cstruct.to_string read.Protocol_9p.Response.Read.data in if Astring.String.is_prefix ~affix:"OK " response then begin let line = String.sub response 3 (String.length response - 3) in (* tcp:127.0.0.1:64500:tcp:127.0.0.1:64499 *) @@ -229,15 +236,15 @@ module Make(Host: Sig.HOST) = struct let ch = Channel.create flow in let message = "GET / HTTP/1.0\r\nconnection: close\r\n\r\n" in Channel.write_string ch message 0 (String.length message); - Channel.flush ch - >>= fun () -> - Host.Sockets.Stream.Tcp.shutdown_write flow - >>= fun () -> - read_http ch - >>= fun response -> - if not(Astring.String.is_prefix ~affix:"HTTP" response) - then failwith (Printf.sprintf "unrecognised HTTP response: [%s]" response); - Lwt.return () + Channel.flush ch >>= function + | Error e -> Fmt.kstrf failwith "%a" Channel.pp_write_error e + | Ok () -> + Host.Sockets.Stream.Tcp.shutdown_write flow + >>= fun () -> + read_http ch + >|= fun response -> + if not(Astring.String.is_prefix ~affix:"HTTP" response) + then failwith (Printf.sprintf "unrecognised HTTP response: [%s]" response) let test_one_forward () = let t = LocalServer.with_server (fun server -> diff --git a/src/hostnet_test/jbuild b/src/hostnet_test/jbuild index 554d46823..bcc0a5668 100644 --- a/src/hostnet_test/jbuild +++ b/src/hostnet_test/jbuild @@ -3,7 +3,7 @@ (executables ((names (main_lwt main_uwt)) (libraries ( - hostnet cmdliner alcotest lwt.unix logs.fmt - dns.mirage lwt.preemptive uwt.preemptive + hostnet cmdliner alcotest lwt.unix logs.fmt protocol-9p + mirage-dns lwt.preemptive uwt.preemptive mirage-clock-unix )) (preprocess no_preprocessing))) diff --git a/src/hostnet_test/slirp_stack.ml b/src/hostnet_test/slirp_stack.ml index 3a6ce90a2..299d5707d 100644 --- a/src/hostnet_test/slirp_stack.ml +++ b/src/hostnet_test/slirp_stack.ml @@ -66,48 +66,38 @@ module Make(Host: Sig.HOST) = struct module VMNET = Vmnet.Make(Host.Sockets.Stream.Tcp) module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) module Vnet = Basic_backend.Make - module Slirp_stack = Slirp.Make(Config)(VMNET)(Dns_policy)(Host)(Vnet) + module Slirp_stack = + Slirp.Make(Config)(VMNET)(Dns_policy)(Mclock)(Stdlibrandom)(Host)(Vnet) module Client = struct module Netif = VMNET module Ethif1 = Ethif.Make(Netif) - module Arpv41 = Arpv4.Make(Ethif1)(Clock)(Host.Time) - module Ipv41 = Ipv4.Make(Ethif1)(Arpv41) + module Arpv41 = Arpv4.Make(Ethif1)(Mclock)(Host.Time) + module Ipv41 = Static_ipv4.Make(Ethif1)(Arpv41) module Icmpv41 = Icmpv4.Make(Ipv41) - module Udp1 = Udp.Make(Ipv41) - module Tcp1 = Tcp.Flow.Make(Ipv41)(Host.Time)(Clock)(Random) - include Tcpip_stack_direct.Make(Console_unix)(Host.Time) - (Random)(Netif)(Ethif1)(Arpv41)(Ipv41)(Icmpv41)(Udp1)(Tcp1) + module Udp1 = Udp.Make(Ipv41)(Stdlibrandom) + module Tcp1 = Tcp.Flow.Make(Ipv41)(Host.Time)(Mclock)(Stdlibrandom) + include Tcpip_stack_direct.Make(Host.Time) + (Stdlibrandom)(Netif)(Ethif1)(Arpv41)(Ipv41)(Icmpv41)(Udp1)(Tcp1) let or_error name m = - let open Lwt.Infix in m >>= function | `Error _ -> Fmt.kstrf failwith "Failed to connect %s device" name | `Ok x -> Lwt.return x let connect (interface: VMNET.t) = - let open Lwt.Infix in - or_error "console" @@ Console_unix.connect "0" - >>= fun console -> - or_error "ethernet" @@ Ethif1.connect interface - >>= fun ethif -> - or_error "arp" @@ Arpv41.connect ethif - >>= fun arp -> - or_error "ipv4" @@ Ipv41.connect ethif arp - >>= fun ipv4 -> - or_error "icmpv4" @@ Icmpv41.connect ipv4 - >>= fun icmpv4 -> - or_error "udp" @@ Udp1.connect ipv4 - >>= fun udp4 -> - or_error "tcp" @@ Tcp1.connect ipv4 - >>= fun tcp4 -> + Ethif1.connect interface >>= fun ethif -> + Mclock.connect () >>= fun clock -> + Arpv41.connect ethif clock >>= fun arp -> + Ipv41.connect ethif arp >>= fun ipv4 -> + Icmpv41.connect ipv4 >>= fun icmpv4 -> + Udp1.connect ipv4 >>= fun udp4 -> + Tcp1.connect ipv4 clock >>= fun tcp4 -> let cfg = { - V1_LWT. name = "stackv4_ip"; - console; + Mirage_stack_lwt.name = "stackv4_ip"; interface; - mode = `DHCP; } in - or_error "stack" @@ connect cfg ethif arp ipv4 icmpv4 udp4 tcp4 + connect cfg ethif arp ipv4 icmpv4 udp4 tcp4 >>= fun stack -> Lwt.return stack end @@ -137,6 +127,7 @@ module Make(Host: Sig.HOST) = struct } let config_without_bridge = + Mclock.connect () >|= fun clock -> { Slirp.peer_ip; local_ip; @@ -150,6 +141,7 @@ module Make(Host: Sig.HOST) = struct global_arp_table; mtu = 1500; host_names = []; + clock; } (* This is a hacky way to get a hancle to the server side of the stack. *) @@ -178,7 +170,9 @@ module Make(Host: Sig.HOST) = struct ); port - let connection = start_stack (Vnet.create ()) config_without_bridge () + let connection = + config_without_bridge >>= fun config -> + start_stack (Vnet.create ()) config () let with_stack f = connection >>= fun port -> diff --git a/src/hostnet_test/suite.ml b/src/hostnet_test/suite.ml index 83e6d66ae..aab74f656 100644 --- a/src/hostnet_test/suite.ml +++ b/src/hostnet_test/suite.ml @@ -16,9 +16,9 @@ module Make(Host: Sig.HOST) = struct module Slirp_stack = Slirp_stack.Make(Host) open Slirp_stack - let run_test ?(timeout=60.) t = + let run_test ?(timeout=Duration.of_sec 60) t = let timeout = - Host.Time.sleep timeout >>= fun () -> + Host.Time.sleep_ns timeout >>= fun () -> Lwt.fail_with "timeout" in Host.Main.run @@ Lwt.pick [ timeout; t ] @@ -34,14 +34,15 @@ module Make(Host: Sig.HOST) = struct run t let set_dns_policy ?host_names use_host = + Mclock.connect () >|= fun clock -> Dns_policy.remove ~priority:3; Dns_policy.add ~priority:3 ~config:(if use_host then `Host else Dns_policy.google_dns); - Slirp_stack.Debug.update_dns ?host_names () + Slirp_stack.Debug.update_dns ?host_names clock let test_dns_query server use_host () = - set_dns_policy use_host; let t _ stack = + set_dns_policy use_host >>= fun () -> let resolver = DNS.create stack in DNS.gethostbyname ~server resolver "www.google.com" >|= function | (_ :: _) as ips -> @@ -54,8 +55,9 @@ module Make(Host: Sig.HOST) = struct let test_builtin_dns_query server use_host () = let name = "experimental.host.name.localhost" in - set_dns_policy ~host_names:[ Dns.Name.of_string name ] use_host; let t _ stack = + set_dns_policy ~host_names:[ Dns.Name.of_string name ] use_host + >>= fun () -> let resolver = DNS.create stack in DNS.gethostbyname ~server resolver name >>= function | (_ :: _) as ips -> @@ -68,9 +70,9 @@ module Make(Host: Sig.HOST) = struct run t let test_etc_hosts_query server use_host () = - set_dns_policy use_host; let test_name = "vpnkit.is.cool.yes.really" in let t _ stack = + set_dns_policy use_host >>= fun () -> let resolver = DNS.create stack in DNS.gethostbyname ~server resolver test_name >>= function | (_ :: _) as ips -> @@ -104,11 +106,11 @@ module Make(Host: Sig.HOST) = struct begin Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) >|= function - | `Ok _ -> + | Ok _ -> Log.err (fun f -> f "Connected to www.google.com, max_connections exceeded"); failwith "too many connections" - | `Error _ -> + | Error _ -> Log.debug (fun f -> f "Expected failure to connect to www.google.com") end @@ -118,9 +120,9 @@ module Make(Host: Sig.HOST) = struct begin Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) >|= function - | `Ok _ -> + | Ok _ -> Log.debug (fun f -> f "Connected to www.google.com"); - | `Error _ -> + | Error _ -> Log.debug (fun f -> f "Failure to connect to www.google.com: removing \ max_connections limit didn't work"); @@ -135,7 +137,7 @@ module Make(Host: Sig.HOST) = struct Lwt.return_unit ) in - run ~timeout:240.0 t + run ~timeout:(Duration.of_sec 240) t let test_http_fetch () = let t _ stack = @@ -145,33 +147,33 @@ module Make(Host: Sig.HOST) = struct begin Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) >>= function - | `Error _ -> + | Error _ -> Log.err (fun f -> f "Failed to connect to www.google.com:80"); failwith "http_fetch" - | `Ok flow -> + | Ok flow -> Log.info (fun f -> f "Connected to www.google.com:80"); let page = Io_page.(to_cstruct (get 1)) in let http_get = "GET / HTTP/1.0\nHost: anil.recoil.org\n\n" in Cstruct.blit_from_string http_get 0 page 0 (String.length http_get); let buf = Cstruct.sub page 0 (String.length http_get) in Client.TCPV4.write flow buf >>= function - | `Eof -> + | Error `Closed -> Log.err (fun f -> f "EOF writing HTTP request to www.google.com:80"); failwith "EOF on writing HTTP GET" - | `Error _ -> + | Error _ -> Log.err (fun f -> f "Failure writing HTTP request to www.google.com:80"); failwith "Failure on writing HTTP GET" - | `Ok _buf -> + | Ok () -> let rec loop total_bytes = Client.TCPV4.read flow >>= function - | `Eof -> Lwt.return total_bytes - | `Error _ -> + | Ok `Eof -> Lwt.return total_bytes + | Error _ -> Log.err (fun f -> f "Failure read HTTP response from www.google.com:80"); failwith "Failure on reading HTTP GET" - | `Ok buf -> + | Ok (`Data buf) -> Log.info (fun f -> f "Read %d bytes from www.google.com:80" (Cstruct.len buf)); Log.info (fun f -> f "%s" (Cstruct.to_string buf)); @@ -197,24 +199,24 @@ module Make(Host: Sig.HOST) = struct } let accept flow = - let module Channel = Channel.Make(Host.Sockets.Stream.Tcp) in + let module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) in let ch = Channel.create flow in (* XXX: this looks like it isn't tail recursive to me *) let rec drop_all_data count = - Lwt.catch (fun () -> - Channel.read_some ch >>= fun buffer -> - drop_all_data Int64.(add count (of_int (Cstruct.len buffer))) - ) (function - | End_of_file -> Lwt.return count - | e -> Lwt.fail e - ) + Channel.read_some ch >>= function + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_error e + | Ok `Eof -> Lwt.return count + | Ok (`Data buffer) -> + drop_all_data Int64.(add count (of_int (Cstruct.len buffer))) in drop_all_data 0L >>= fun total -> let response = Cstruct.create 8 in Cstruct.LE.set_uint64 response 0 total; Channel.write_buffer ch response; - Channel.flush ch + Channel.flush ch >>= function + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_write_error e + | Ok () -> Lwt.return_unit let create () = Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) @@ -253,12 +255,12 @@ module Make(Host: Sig.HOST) = struct Client.TCPV4.create_connection (Client.tcpv4 stack) (Ipaddr.V4.localhost, local_port) >>= function - | `Ok c -> + | Ok c -> Log.info (fun f -> f "Connected %d, total tracked connections %d" i (Host.Sockets.get_num_connections ())); loop (c :: acc) (i + 1) - | `Error _ -> + | Error _ -> Fmt.kstrf failwith "Connection %d failed, total tracked connections %d" i (Host.Sockets.get_num_connections ()) @@ -269,7 +271,7 @@ module Make(Host: Sig.HOST) = struct (List.length flows) (Host.Sockets.get_num_connections ())); (* How many connections is this? *) in - run' ~timeout:240.0 t + run' ~timeout:(Duration.of_sec 240) t let test_stream_data connections length () = let t local_port _ stack = @@ -278,18 +280,20 @@ module Make(Host: Sig.HOST) = struct Client.TCPV4.create_connection (Client.tcpv4 stack) (Ipaddr.V4.localhost, local_port) >>= function - | `Error `Refused -> + | Error `Refused -> Log.info (fun f -> f "DevNullServer Refused connection"); - Host.Time.sleep 0.2 + Host.Time.sleep_ns (Duration.of_ms 200) >>= fun () -> connect () - | `Error `Timeout -> + | Error `Timeout -> Log.err (fun f -> f "DevNullServer connection timeout"); failwith "DevNullServer connection timeout"; - | `Error (`Unknown x) -> - Log.err (fun f -> f "DevNullServer connnection failure: %s" x); - failwith x - | `Ok flow -> + | Error e -> + Log.err (fun f -> + f "DevNullServer connnection failure: %a" + Client.TCPV4.pp_error e); + Fmt.kstrf failwith "%a" Client.TCPV4.pp_error e + | Ok flow -> Log.info (fun f -> f "Connected to local server"); Lwt.return flow in @@ -304,32 +308,32 @@ module Make(Host: Sig.HOST) = struct let this_time = min remaining (Cstruct.len page) in let buf = Cstruct.sub page 0 this_time in Client.TCPV4.write flow buf >>= function - | `Eof -> + | Error `Closed -> Log.err (fun f -> f "EOF writing to DevNullServerwith %d bytes left" remaining); (* failwith "EOF on writing to DevNullServer" *) Lwt.return () - | `Error _ -> + | Error _ -> Log.err (fun f -> f "Failure writing to DevNullServer with %d bytes left" remaining); (* failwith "Failure on writing to DevNullServer" *) Lwt.return () - | `Ok () -> + | Ok () -> loop (remaining - this_time) end in loop length >>= fun () -> Client.TCPV4.close flow >>= fun () -> Client.TCPV4.read flow >|= function - | `Eof -> + | Ok `Eof -> Log.err (fun f -> f "EOF reading result from DevNullServer"); (* failwith "EOF reading result from DevNullServer" *) - | `Error _ -> + | Error _ -> Log.err (fun f -> f "Failure reading result from DevNullServer"); (* failwith "Failure on reading result from DevNullServer" *) - | `Ok buf -> + | Ok (`Data buf) -> Log.info (fun f -> f "Read %d bytes from DevNullServer" (Cstruct.len buf)); let response = Cstruct.LE.get_uint64 buf 0 in diff --git a/src/hostnet_test/test_http.ml b/src/hostnet_test/test_http.ml index 38cc24c02..fd09124fd 100644 --- a/src/hostnet_test/test_http.ml +++ b/src/hostnet_test/test_http.ml @@ -95,13 +95,13 @@ module Make(Host: Sig.HOST) = struct Lwt.finalize (fun () -> f server) (fun () -> Server.destroy server) module Outgoing = struct - module C = Channel.Make(Slirp_stack.Client.TCPV4) + module C = Mirage_channel_lwt.Make(Slirp_stack.Client.TCPV4) module IO = Cohttp_mirage_io.Make(C) module Request = Cohttp.Request.Make(IO) module Response = Cohttp.Response.Make(IO) end module Incoming = struct - module C = Channel.Make(Host.Sockets.Stream.Tcp) + module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) module IO = Cohttp_mirage_io.Make(C) module Request = Cohttp.Request.Make(IO) module Response = Cohttp.Response.Make(IO) @@ -111,12 +111,12 @@ module Make(Host: Sig.HOST) = struct let open Slirp_stack in Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) >>= function - | `Ok flow -> + | Ok flow -> Log.info (fun f -> f "Connected to %s:80" (Ipaddr.V4.to_string ip)); let oc = Outgoing.C.create flow in Outgoing.Request.write ~flush:true (fun _writer -> Lwt.return_unit) request oc - | `Error _ -> + | Error _ -> Log.err (fun f -> f "Failed to connect to %s:80" (Ipaddr.V4.to_string ip)); failwith "http_fetch" @@ -149,7 +149,8 @@ module Make(Host: Sig.HOST) = struct request >>= fun () -> Lwt.pick [ - (Host.Time.sleep 100. >>= fun () -> Lwt.return `Timeout); + (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> + `Timeout); (forwarded >>= fun x -> Lwt.return (`Result x)) ] ) @@ -247,6 +248,8 @@ module Make(Host: Sig.HOST) = struct Lwt.return () end + let err_flush e = Fmt.kstrf failwith "%a" Incoming.C.pp_write_error e + let test_http_connect () = let test_dst_ip = Ipaddr.V4.of_string_exn "1.2.3.4" in Host.Main.run begin @@ -276,10 +279,13 @@ module Make(Host: Sig.HOST) = struct so we write the header ourselves *) Incoming.C.write_line ic "HTTP/1.0 200 OK\r"; Incoming.C.write_line ic "\r"; - Incoming.C.flush ic - >>= fun () -> - Incoming.C.write_line ic "hello"; - Incoming.C.flush ic + Incoming.C.flush ic >>= function + | Error e -> err_flush e + | Ok () -> + Incoming.C.write_line ic "hello"; + Incoming.C.flush ic >|= function + | Error e -> err_flush e + | Ok () -> () ) (fun server -> Slirp_stack.Slirp_stack.Debug.update_http ~https:("127.0.0.1:" ^ (string_of_int server.Server.port)) () @@ -290,17 +296,20 @@ module Make(Host: Sig.HOST) = struct Client.TCPV4.create_connection (Client.tcpv4 stack) (test_dst_ip, 443) >>= function - | `Error _ -> + | Error _ -> Log.err (fun f -> f "TCPV4.create_connection %a:443 failed" Ipaddr.V4.pp_hum test_dst_ip); failwith "TCPV4.create_connection" - | `Ok flow -> + | Ok flow -> let ic = Outgoing.C.create flow in - Outgoing.C.read_some ~len:5 ic >>= fun buf -> - let txt = Cstruct.to_string buf in - Alcotest.check Alcotest.string "message" "hello" txt; - Lwt.return_unit + Outgoing.C.read_some ~len:5 ic >>= function + | Error e -> Fmt.kstrf failwith "%a" Outgoing.C.pp_error e + | Ok `Eof -> failwith "EOF" + | Ok (`Data buf) -> + let txt = Cstruct.to_string buf in + Alcotest.check Alcotest.string "message" "hello" txt; + Lwt.return_unit ) ) end diff --git a/src/hostnet_test/test_nat.ml b/src/hostnet_test/test_nat.ml index c535e732c..df3f9701d 100644 --- a/src/hostnet_test/test_nat.ml +++ b/src/hostnet_test/test_nat.ml @@ -9,9 +9,9 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make(Host: Sig.HOST) = struct - let run ?(timeout=60.) t = + let run ?(timeout=Duration.of_sec 60) t = let timeout = - Host.Time.sleep timeout >>= fun () -> + Host.Time.sleep_ns timeout >>= fun () -> Lwt.fail_with "timeout" in Host.Main.run @@ Lwt.pick [ timeout; t ] @@ -90,18 +90,22 @@ module Make(Host: Sig.HOST) = struct t let wait_for_data ~highest t = if t.highest < highest then begin - Lwt.pick [ Lwt_condition.wait t.c; Host.Time.sleep 1. ] + Lwt.pick [ Lwt_condition.wait t.c; + Host.Time.sleep_ns (Duration.of_sec 1) ] >>= fun () -> Lwt.return (t.highest >= highest) end else Lwt.return true let wait_for_ports ~num t = if PortSet.cardinal t.seen_source_ports < num then begin - Lwt.pick [ Lwt_condition.wait t.c; Host.Time.sleep 1. ] + Lwt.pick [ Lwt_condition.wait t.c; + Host.Time.sleep_ns (Duration.of_sec 1) ] >|= fun () -> PortSet.cardinal t.seen_source_ports >= num end else Lwt.return true end + let err_udp e = Fmt.kstrf failwith "%a" Client.UDPV4.pp_error e + (* Start a local UDP echo server, send traffic to it and listen for a response *) let test_udp () = @@ -120,13 +124,15 @@ module Make(Host: Sig.HOST) = struct f "Sending %d -> %d value %d" virtual_port local_port (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:1 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5 )) @@ -155,13 +161,15 @@ module Make(Host: Sig.HOST) = struct f "Sending %d -> %d value %d" virtual_port1 local_port (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port1 - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:1 server1 >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + ~src_port:virtual_port1 + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server1 >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5 >>= fun () -> (* Listen on a second virtual source port and count @@ -177,15 +185,17 @@ module Make(Host: Sig.HOST) = struct f "Sending %d -> %d value %d" virtual_port2 local_port (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port2 - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:2 server2 >>= fun ok2 -> - (* The server should "multicast" the packet to the - original "connection" *) - UdpServer.wait_for_data ~highest:2 server1 >>= fun ok1 -> - if ok1 && ok2 then Lwt.return_unit else loop (remaining - 1) + ~src_port:virtual_port2 + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:2 server2 >>= fun ok2 -> + (* The server should "multicast" the packet to the + original "connection" *) + UdpServer.wait_for_data ~highest:2 server1 >>= fun ok1 -> + if ok1 && ok2 then Lwt.return_unit else loop (remaining - 1) in loop 5 ) @@ -211,18 +221,20 @@ module Make(Host: Sig.HOST) = struct let rec loop remaining = if remaining = 0 then failwith "Timed-out before UDP response arrived"; - let dest_port = echoserver.EchoServer.local_port in + let dst_port = echoserver.EchoServer.local_port in Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port1 dest_port + f "Sending %d -> %d value %d" virtual_port1 dst_port (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port1 - ~dest_ip:Ipaddr.V4.localhost - ~dest_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:1 server1 >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + ~src_port:virtual_port1 + ~dst:Ipaddr.V4.localhost + ~dst_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server1 >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5 >>= fun () -> @@ -275,13 +287,15 @@ module Make(Host: Sig.HOST) = struct f "Sending %d -> %d value %d" virtual_port local_port (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:1 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5 >>= fun () -> Alcotest.(check int) "One NAT rule" 1 @@ -298,13 +312,15 @@ module Make(Host: Sig.HOST) = struct Log.debug (fun f -> f "Sending %d -> %d value %d" virtual_port local_port (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write ~source_port:virtual_port - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port udpv4 buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:2 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + Client.UDPV4.write ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:2 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5 >|= fun () -> Alcotest.(check int) "Still one NAT rule" 1 @@ -335,21 +351,25 @@ module Make(Host: Sig.HOST) = struct f "Sending %d -> %d value %d" virtual_port local_port1 (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write - ~source_port:virtual_port - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port1 udpv4 buffer - >>= fun () -> - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port2 - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~source_port:virtual_port - ~dest_ip:Ipaddr.V4.localhost - ~dest_port:local_port2 udpv4 buffer - >>= fun () -> - UdpServer.wait_for_ports ~num:2 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port1 udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port local_port2 + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port2 udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_ports ~num:2 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) in loop 5))) in