diff --git a/src/bin/bind.ml b/src/bin/bind.ml index 5ba9f392d..43d1c3db7 100644 --- a/src/bin/bind.ml +++ b/src/bin/bind.ml @@ -8,15 +8,28 @@ module Log = (val Logs.src_log src : Logs.LOG) open Lwt.Infix open Vmnet -let errorf fmt = Fmt.kstrf (fun e -> Lwt.return (`Error (`Msg e))) fmt -let error_of_failure f = - Lwt.catch f (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e))) - let is_windows = Sys.os_type = "Win32" +let failf fmt = Fmt.kstrf (fun e -> Lwt_result.fail (`Msg e)) fmt + module Make(Socket: Sig.SOCKETS) = struct - module Channel = Channel.Make(Socket.Stream.Unix) + module Channel = Mirage_channel_lwt.Make(Socket.Stream.Unix) + + let err_eof = Lwt_result.fail (`Msg "error: got EOF") + let err_read e = failf "error while reading: %a" Channel.pp_error e + let err_flush e = failf "error while flushing: %a" Channel.pp_write_error e + + let with_read x f = + x >>= function + | Error e -> err_read e + | Ok `Eof -> err_eof + | Ok (`Data x) -> f x + + let with_flush x f = + x >>= function + | Error e -> err_flush e + | Ok () -> f () type t = { fd: Socket.Stream.Unix.flow; @@ -33,19 +46,15 @@ module Make(Socket: Sig.SOCKETS) = struct let of_fd fd = let buf = Cstruct.create Init.sizeof in let (_: Cstruct.t) = Init.marshal Init.default buf in - error_of_failure (fun () -> - let c = Channel.create fd in - Channel.write_buffer c buf; - Channel.flush c >>= fun () -> - Channel.read_exactly ~len:Init.sizeof c >>= fun bufs -> - let buf = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Init.unmarshal buf) - >>= fun (init, _) -> - Log.info (fun f -> - f "Client.negotiate: received %s" (Init.to_string init)); - Lwt_result.return { fd; c } - ) + let c = Channel.create fd in + Channel.write_buffer c buf; + with_flush (Channel.flush c) @@ fun () -> + with_read (Channel.read_exactly ~len:Init.sizeof c) @@ fun bufs -> + let buf = Cstruct.concat bufs in + let init, _ = Init.unmarshal buf in + Log.info (fun f -> + f "Client.negotiate: received %s" (Init.to_string init)); + Lwt_result.return { fd; c } let bind_ipv4 t (ipv4, port, stream) = let buf = Cstruct.create Command.sizeof in @@ -53,12 +62,12 @@ module Make(Socket: Sig.SOCKETS) = struct Command.marshal (Command.Bind_ipv4(ipv4, port, stream)) buf in Channel.write_buffer t.c buf; - Channel.flush t.c >>= fun () -> + with_flush (Channel.flush t.c) @@ fun () -> let rawfd = Socket.Stream.Unix.unsafe_get_raw_fd t.fd in let result = String.make 8 '\000' in let n, _, fd = Fd_send_recv.recv_fd rawfd result 0 8 [] in - (if n <> 8 then errorf "Message only contained %d bytes" n else + (if n <> 8 then failf "Message only contained %d bytes" n else let buf = Cstruct.create 8 in Cstruct.blit_from_string result 0 buf 0 8; Log.debug (fun f -> @@ -67,15 +76,15 @@ module Make(Socket: Sig.SOCKETS) = struct f "received result bytes: %s which is %s" (String.escaped result) (Buffer.contents b)); match Cstruct.LE.get_uint64 buf 0 with - | 0L -> Lwt.return (`Ok fd) - | 48L -> errorf "EADDRINUSE" - | 49L -> errorf "EADDRNOTAVAIL" - | n -> errorf "Failed to bind: unrecognised errno: %Ld" n + | 0L -> Lwt_result.return fd + | 48L -> failf "EADDRINUSE" + | 49L -> failf "EADDRNOTAVAIL" + | n -> failf "Failed to bind: unrecognised errno: %Ld" n ) >>= function - | `Error x -> + | Error x -> Unix.close fd; Lwt_result.fail x - | `Ok x -> + | Ok x -> Lwt_result.return x (* This implementation is OSX-only *) diff --git a/src/bin/connect.ml b/src/bin/connect.ml index a919e1fb7..a36593b2e 100644 --- a/src/bin/connect.ml +++ b/src/bin/connect.ml @@ -30,16 +30,15 @@ module Make_unix(Host: Sig.HOST) = struct Cstruct.of_string (Printf.sprintf "00000003.%08lx\n" vsock_port) in write flow address >>= function - | `Ok () -> Lwt.return flow - | `Eof -> + | Ok () -> Lwt.return flow + | Error `Closed -> Log.err (fun f -> f "vsock connect write got Eof"); close flow >>= fun () -> Lwt.fail End_of_file - | `Error e -> - let msg = error_message e in - Log.err (fun f -> f "vsock connect write got %s" msg); + | Error e -> + Log.err (fun f -> f "vsock connect write got %a" pp_write_error e); close flow >>= fun () -> - Lwt.fail_with msg + Fmt.kstrf Lwt.fail_with "%a" pp_write_error e end module Make_hvsock(Host: Sig.HOST) = struct @@ -83,8 +82,10 @@ module Make_hvsock(Host: Sig.HOST) = struct let writev t = F.writev t.flow let shutdown_read t = F.shutdown_read t.flow let shutdown_write t = F.shutdown_write t.flow - let error_message = F.error_message + let pp_error = F.pp_error + let pp_write_error = F.pp_write_error type 'a io = 'a F.io type buffer = F.buffer type error = F.error + type write_error = F.write_error end diff --git a/src/bin/jbuild b/src/bin/jbuild index 2b7acfa68..0c0b33aea 100644 --- a/src/bin/jbuild +++ b/src/bin/jbuild @@ -4,6 +4,7 @@ ((name main) (libraries ( cmdliner ofs logs.fmt hostnet hvsock hvsock.lwt-unix - datakit-server.fs9p win-eventlog asl fd-send-recv + datakit-server-9p win-eventlog asl fd-send-recv duration + mirage-clock-unix mirage-random )) (preprocess no_preprocessing))) diff --git a/src/bin/main.ml b/src/bin/main.ml index dd4d02560..98aada3d8 100644 --- a/src/bin/main.ml +++ b/src/bin/main.ml @@ -54,8 +54,8 @@ module Main(Host: Sig.HOST) = struct module Bind = Bind.Make(Host.Sockets) module Dns_policy = Hostnet_dns.Policy(Host.Files) module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) - module Forward_unix = Forward.Make(Connect_unix)(Bind) - module Forward_hvsock = Forward.Make(Connect_hvsock)(Bind) + module Forward_unix = Forward.Make(Mclock)(Connect_unix)(Bind) + module Forward_hvsock = Forward.Make(Mclock)(Connect_hvsock)(Bind) module HV = Flow_lwt_hvsock.Make(Host.Time)(Host.Fn) module Hosts = Hosts.Make(Host.Files) @@ -98,10 +98,10 @@ module Main(Host: Sig.HOST) = struct (* no need to add more delay *) | Unix.Unix_error(_, _, _) -> HV.Hvsock.close socket >>= fun () -> - Host.Time.sleep 1. + Host.Time.sleep_ns (Duration.of_sec 1) | _ -> HV.Hvsock.close socket >>= fun () -> - Host.Time.sleep 1. + Host.Time.sleep_ns (Duration.of_sec 1) ) >>= fun () -> aux () @@ -161,10 +161,11 @@ module Main(Host: Sig.HOST) = struct database key slirp/max-connections instead")); Host.Sockets.set_max_connections max_connections; let uri = Uri.of_string port_control_url in + Mclock.connect () >>= fun clock -> match Uri.scheme uri with | Some "hyperv-connect" -> let module Ports = Active_list.Make(Forward_hvsock) in - let fs = Ports.make () in + let fs = Ports.make clock in Ports.set_context fs ""; let module Server = Protocol_9p.Server.Make(Log9P)(HV)(Ports) in let sockaddr = hvsock_addr_of_uri ~default_serviceid:ports_serviceid uri in @@ -178,7 +179,7 @@ module Main(Host: Sig.HOST) = struct | Ok server -> Server.after_disconnect server) | _ -> let module Ports = Active_list.Make(Forward_unix) in - let fs = Ports.make () in + let fs = Ports.make clock in Ports.set_context fs vsock_path; let module Server = Protocol_9p.Server.Make(Log9P)(Host.Sockets.Stream.Unix)(Ports) @@ -276,6 +277,8 @@ module Main(Host: Sig.HOST) = struct List.map Dns.Name.of_string @@ Astring.String.cuts ~sep:"," host_names in + Mclock.connect () >>= fun clock -> + let hardcoded_configuration = let server_macaddr = Slirp.default_server_macaddr in let peer_ip = Ipaddr.V4.of_string_exn "192.168.65.2" in @@ -301,7 +304,8 @@ module Main(Host: Sig.HOST) = struct client_uuids; bridge_connections = true; mtu = 1500; - host_names } + host_names; + clock } in let config = match db_path with @@ -325,14 +329,15 @@ module Main(Host: Sig.HOST) = struct match Uri.scheme uri with | Some "hyperv-connect" -> let module Slirp_stack = - Slirp.Make(Config)(Vmnet.Make(HV))(Dns_policy)(Host)(Vnet) + Slirp.Make(Config)(Vmnet.Make(HV))(Dns_policy) + (Mclock)(Stdlibrandom)(Host)(Vnet) in let sockaddr = hvsock_addr_of_uri ~default_serviceid:ethernet_serviceid (Uri.of_string socket_url) in ( match config with - | Some config -> Slirp_stack.create ~host_names config + | Some config -> Slirp_stack.create ~host_names clock config | None -> Lwt.return hardcoded_configuration ) >>= fun stack_config -> hvsock_connect_forever socket_url sockaddr (fun fd -> @@ -347,11 +352,11 @@ module Main(Host: Sig.HOST) = struct | _ -> let module Slirp_stack = Slirp.Make(Config)(Vmnet.Make(Host.Sockets.Stream.Unix))(Dns_policy) - (Host)(Vnet) + (Mclock)(Stdlibrandom)(Host)(Vnet) in unix_listen socket_url >>= fun server -> ( match config with - | Some config -> Slirp_stack.create ~host_names config + | Some config -> Slirp_stack.create ~host_names clock config | None -> Lwt.return hardcoded_configuration ) >>= fun stack_config -> Host.Sockets.Stream.Unix.listen server (fun conn -> diff --git a/src/hostnet/arp.ml b/src/hostnet/arp.ml index 4e66635be..9a2534323 100644 --- a/src/hostnet/arp.ml +++ b/src/hostnet/arp.ml @@ -23,6 +23,7 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. * *) +open Lwt.Infix let src = let src = Logs.Src.create "arp" ~doc:"fixed ARP table" in @@ -31,7 +32,7 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) -module Make(Ethif: V1_LWT.ETHIF) = struct +module Make (Ethif: Mirage_protocols_lwt.ETHIF) = struct module Table = Map.Make(Ipaddr.V4) @@ -40,10 +41,9 @@ module Make(Ethif: V1_LWT.ETHIF) = struct type buffer = Cstruct.t type macaddr = Macaddr.t type t = { ethif: Ethif.t; mutable table: macaddr Table.t } - type error = unit - type id = unit + type error = Mirage_protocols.Arp.error + let pp_error = Mirage_protocols.Arp.pp_error type repr = string - type result = [ `Ok of macaddr | `Timeout ] let to_repr t = let pp_one (ip, mac) = @@ -74,11 +74,11 @@ module Make(Ethif: V1_LWT.ETHIF) = struct let query t ip = if Table.mem ip t.table - then Lwt.return (`Ok (Table.find ip t.table)) + then Lwt.return (Ok (Table.find ip t.table)) else begin Log.warn (fun f -> f "ARP table has no entry for %s" (Ipaddr.V4.to_string ip)); - Lwt.return `Timeout + Lwt.return (Error `Timeout) end type arp = { @@ -89,32 +89,8 @@ module Make(Ethif: V1_LWT.ETHIF) = struct tpa: Ipaddr.V4.t; } - let rec input t frame = - let open Arpv4_wire in - match get_arp_op frame with - |1 -> (* Request *) - let req_ipv4 = Ipaddr.V4.of_int32 (get_arp_tpa frame) in - if Table.mem req_ipv4 t.table then begin - Log.debug (fun f -> - f "ARP responding to: who-has %s?" (Ipaddr.V4.to_string req_ipv4)); - let sha = Table.find req_ipv4 t.table in - let tha = Macaddr.of_bytes_exn (copy_arp_sha frame) in - (* the requested address *) - let spa = Ipaddr.V4.of_int32 (get_arp_tpa frame) in - (* the requesting host IPv4 *) - let tpa = Ipaddr.V4.of_int32 (get_arp_spa frame) in - output t { op=`Reply; sha; tha; spa; tpa } - end else Lwt.return_unit - |2 -> (* Reply *) - (* the requested address *) - let spa = Ipaddr.V4.of_int32 (get_arp_tpa frame) in - Log.debug (fun f -> f "ARP ignoring reply %s" (Ipaddr.V4.to_string spa)); - Lwt.return_unit - |n -> - Log.debug (fun f -> f "ARP: Unknown message %d ignored" n); - Lwt.return_unit - and output t arp = + let output t arp = let open Arpv4_wire in (* Obtain a buffer to write into *) let buf = Io_page.to_cstruct (Io_page.get 1) in @@ -129,9 +105,9 @@ module Make(Ethif: V1_LWT.ETHIF) = struct |`Reply -> 2 |`Unknown n -> n in - Wire_structs.set_ethernet_dst dmac 0 buf; - Wire_structs.set_ethernet_src smac 0 buf; - Wire_structs.set_ethernet_ethertype buf 0x0806; (* ARP *) + Ethif_wire.set_ethernet_dst dmac 0 buf; + Ethif_wire.set_ethernet_src smac 0 buf; + Ethif_wire.set_ethernet_ethertype buf 0x0806; (* ARP *) let arpbuf = Cstruct.shift buf 14 in set_arp_htype arpbuf 1; set_arp_ptype arpbuf 0x0800; (* IPv4 *) @@ -143,9 +119,38 @@ module Make(Ethif: V1_LWT.ETHIF) = struct set_arp_tha dmac 0 arpbuf; set_arp_tpa arpbuf tpa; (* Resize buffer to sizeof arp packet *) - let buf = Cstruct.sub buf 0 (sizeof_arp + Wire_structs.sizeof_ethernet) in + let buf = Cstruct.sub buf 0 (sizeof_arp + Ethif_wire.sizeof_ethernet) in Ethif.write t.ethif buf + let input t frame = + let open Arpv4_wire in + match get_arp_op frame with + |1 -> (* Request *) + let req_ipv4 = Ipaddr.V4.of_int32 (get_arp_tpa frame) in + if Table.mem req_ipv4 t.table then begin + Log.debug (fun f -> + f "ARP responding to: who-has %s?" (Ipaddr.V4.to_string req_ipv4)); + let sha = Table.find req_ipv4 t.table in + let tha = Macaddr.of_bytes_exn (copy_arp_sha frame) in + (* the requested address *) + let spa = Ipaddr.V4.of_int32 (get_arp_tpa frame) in + (* the requesting host IPv4 *) + let tpa = Ipaddr.V4.of_int32 (get_arp_spa frame) in + output t { op=`Reply; sha; tha; spa; tpa } >|= function + | Ok () -> () + | Error e -> + Log.err (fun f -> + f "error while reading ARP packet: %a" Ethif.pp_error e); + end else Lwt.return_unit + |2 -> (* Reply *) + (* the requested address *) + let spa = Ipaddr.V4.of_int32 (get_arp_tpa frame) in + Log.debug (fun f -> f "ARP ignoring reply %s" (Ipaddr.V4.to_string spa)); + Lwt.return_unit + |n -> + Log.debug (fun f -> f "ARP: Unknown message %d ignored" n); + Lwt.return_unit + type ethif = Ethif.t let connect ~table ethif = @@ -154,7 +159,7 @@ module Make(Ethif: V1_LWT.ETHIF) = struct Table.add ip mac acc ) Table.empty table in - Lwt.return (`Ok { table; ethif }) + { table; ethif } let disconnect _t = Lwt.return_unit end diff --git a/src/hostnet/arp.mli b/src/hostnet/arp.mli index 405438f4b..03443105d 100644 --- a/src/hostnet/arp.mli +++ b/src/hostnet/arp.mli @@ -2,13 +2,12 @@ rely on the dynamic version which can fail with `No_route_to_host` if the other side doesn't respond *) -module Make(Ethif: V1_LWT.ETHIF): sig - include V1_LWT.ARP +module Make(Ethif: Mirage_protocols_lwt.ETHIF): sig + include Mirage_protocols_lwt.ARP type ethif = Ethif.t val connect: - table:(ipaddr * macaddr) list -> ethif - -> [ `Ok of t | `Error of error ] Lwt.t + table:(ipaddr * macaddr) list -> ethif -> t (** Construct a static ARP table *) end diff --git a/src/hostnet/capture.ml b/src/hostnet/capture.ml index 1b9848b5f..8fcaca709 100644 --- a/src/hostnet/capture.ml +++ b/src/hostnet/capture.ml @@ -1,3 +1,5 @@ +open Lwt.Infix + let src = let src = Logs.Src.create "capture" ~doc:"capture network traffic" in Logs.Src.set_level src (Some Logs.Info); @@ -7,14 +9,21 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make(Input: Sig.VMNET) = struct + type page_aligned_buffer = Io_page.t + type buffer = Cstruct.t + type macaddr = Macaddr.t + type 'a io = 'a Lwt.t type fd = Input.fd + type error = [Mirage_device.error | `Unknown of string] - type stats = { - mutable rx_bytes: int64; - mutable rx_pkts: int32; - mutable tx_bytes: int64; - mutable tx_pkts: int32; - } + let pp_error ppf = function + | #Mirage_device.error as e -> Mirage_device.pp_error ppf e + | `Unknown s -> Fmt.pf ppf "unknown: %s" s + + let lift_error = function + | Ok x -> Ok x + | Error (#Mirage_device.error as e) -> Error e + | Error e -> Fmt.kstrf (fun s -> Error (`Unknown s)) "%a" Input.pp_error e type packet = { len: int; @@ -139,7 +148,7 @@ module Make(Input: Sig.VMNET) = struct type t = { input: Input.t; rules: (string, rule) Hashtbl.t; - stats: stats; + stats: Mirage_net.stats; } let add_match ~t ~name ~limit ~snaplen ~predicate = @@ -152,15 +161,13 @@ module Make(Input: Sig.VMNET) = struct let connect input = let rules = Hashtbl.create 7 in - let stats = { - rx_bytes = 0L; rx_pkts = 0l; tx_bytes = 0L; tx_pkts = 0l; - } in + let stats = Mirage_net.Stats.create () in let t = { input; rules; stats } in (* Add a special capture rule for packets for which there is an error processing the packet captures. Ideally there should be no matches! *) add_match ~t ~name:bad_pcap ~limit:1048576 ~snaplen:1500 ~predicate:(fun _ -> false); - Lwt.return (`Ok t) + t let filesystem t = Vfs.Dir.of_list (fun () -> @@ -188,34 +195,20 @@ module Make(Input: Sig.VMNET) = struct let write t buf = record t [ buf ]; - Input.write t.input buf + Input.write t.input buf >|= lift_error + let writev t bufs = record t bufs; - Input.writev t.input bufs + Input.writev t.input bufs >|= lift_error let listen t callback = Input.listen t.input (fun buf -> record t [ buf ]; callback buf) + >|= lift_error let add_listener t callback = Input.add_listener t.input callback let mac t = Input.mac t.input - type page_aligned_buffer = Io_page.t - - type buffer = Cstruct.t - - type error = [ - | `Unknown of string - | `Unimplemented - | `Disconnected - ] - - type macaddr = Macaddr.t - - type 'a io = 'a Lwt.t - - type id = unit - let get_stats_counters t = t.stats let reset_stats_counters t = diff --git a/src/hostnet/capture.mli b/src/hostnet/capture.mli index d3602a13f..85328bdfe 100644 --- a/src/hostnet/capture.mli +++ b/src/hostnet/capture.mli @@ -2,8 +2,7 @@ module Make(Input: Sig.VMNET): sig include Sig.VMNET include Sig.RECORDER with type t := t - val connect: Input.t - -> [ `Ok of t | `Error of error ] Lwt.t + val connect: Input.t -> t (** Capture traffic from a network, match against a set of capture rules and keep a limited amount of the most recent traffic that matches. *) diff --git a/src/hostnet/dhcp.ml b/src/hostnet/dhcp.ml index 2e5545f01..357a0509d 100644 --- a/src/hostnet/dhcp.ml +++ b/src/hostnet/dhcp.ml @@ -7,9 +7,10 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) -module Make(Netif: V1_LWT.NETWORK) = struct +module Make (Clock: Mirage_clock_lwt.MCLOCK) (Netif: Mirage_net_lwt.S) = struct type t = { + clock: Clock.t; netif: Netif.t; server_macaddr: Macaddr.t; get_dhcp_configuration : unit -> Dhcp_server.Config.t; @@ -32,7 +33,7 @@ module Make(Netif: V1_LWT.NETWORK) = struct (* given some MACs and IPs, construct a usable DHCP configuration *) let make ~server_macaddr ~peer_ip ~highest_peer_ip ~local_ip ~extra_dns_ip - ~get_domain_search ~get_domain_name netif = + ~get_domain_search ~get_domain_name clock netif = let open Dhcp_server.Config in (* FIXME: We need a DHCP range to make the DHCP server happy, even though we intend only to serve IPs to one downstream host. @@ -84,9 +85,9 @@ module Make(Netif: V1_LWT.NETWORK) = struct mac_addr = server_macaddr; network = prefix; (* FIXME: this needs https://github.com/haesbaert/charrua-core/pull/31 *) - range = (peer_ip, peer_ip); (* allow one dynamic client *) + range = Some (peer_ip, peer_ip); (* allow one dynamic client *) } in - { netif; server_macaddr; get_dhcp_configuration } + { clock; netif; server_macaddr; get_dhcp_configuration } let of_interest mac dest = Macaddr.compare dest mac = 0 || not (Macaddr.is_unicast dest) @@ -96,14 +97,19 @@ module Make(Netif: V1_LWT.NETWORK) = struct let logged_bootrequest = ref false let logged_bootreply = ref false - let input net (config : Dhcp_server.Config.t) database buf = + let input clock net (config : Dhcp_server.Config.t) database buf = let open Dhcp_server in match Dhcp_wire.pkt_of_buf buf (Cstruct.len buf) with - | `Error e -> + | Error e -> Log.err (fun f -> f "failed to parse DHCP packet: %s" e); Lwt.return database - | `Ok pkt -> - match (Input.input_pkt config database pkt (Clock.time ())) with + | Ok pkt -> + let elapsed_seconds = + Clock.elapsed_ns clock + |> Duration.to_sec + |> Int32.of_int + in + match Input.input_pkt config database pkt elapsed_seconds with | Input.Silence -> Lwt.return database | Input.Update database -> Log.debug (fun f -> f "lease database updated"); @@ -122,27 +128,30 @@ module Make(Netif: V1_LWT.NETWORK) = struct (Macaddr.to_string (pkt.srcmac))); logged_bootrequest := !logged_bootrequest || (pkt.op = Dhcp_wire.BOOTREQUEST); - Netif.write net (Dhcp_wire.buf_of_pkt reply) - >>= fun () -> - let domain = List.fold_left (fun acc x -> match x with - | Domain_name y -> y - | _ -> acc) "unknown" reply.options in - let dns = List.fold_left (fun acc x -> match x with - | Dns_servers ys -> String.concat ", " (List.map Ipaddr.V4.to_string ys) - | _ -> acc) "none" reply.options in - let routers = List.fold_left (fun acc x -> match x with - | Routers ys -> String.concat ", " (List.map Ipaddr.V4.to_string ys) - | _ -> acc) "none" reply.options in - if reply.op <> Dhcp_wire.BOOTREPLY || not !logged_bootreply - then Log.info (fun f -> - f "%s to %s yiddr %a siddr %a dns %s router %s domain %s" - (op_to_string reply.op) (Macaddr.to_string (reply.dstmac)) - Ipaddr.V4.pp_hum reply.yiaddr Ipaddr.V4.pp_hum reply.siaddr - dns routers domain - ); - logged_bootreply := - !logged_bootreply || (reply.op = Dhcp_wire.BOOTREPLY); - Lwt.return database + Netif.write net (Dhcp_wire.buf_of_pkt reply) >>= function + | Error e -> + Log.err (fun f -> f "failed to parse DHCP reply: %a" Netif.pp_error e); + Lwt.return database + | Ok () -> + let domain = List.fold_left (fun acc x -> match x with + | Domain_name y -> y + | _ -> acc) "unknown" reply.options in + let dns = List.fold_left (fun acc x -> match x with + | Dns_servers ys -> String.concat ", " (List.map Ipaddr.V4.to_string ys) + | _ -> acc) "none" reply.options in + let routers = List.fold_left (fun acc x -> match x with + | Routers ys -> String.concat ", " (List.map Ipaddr.V4.to_string ys) + | _ -> acc) "none" reply.options in + if reply.op <> Dhcp_wire.BOOTREPLY || not !logged_bootreply + then Log.info (fun f -> + f "%s to %s yiddr %a siddr %a dns %s router %s domain %s" + (op_to_string reply.op) (Macaddr.to_string (reply.dstmac)) + Ipaddr.V4.pp_hum reply.yiaddr Ipaddr.V4.pp_hum reply.siaddr + dns routers domain + ); + logged_bootreply := + !logged_bootreply || (reply.op = Dhcp_wire.BOOTREPLY); + Lwt.return database let callback t buf = (* TODO: the scope of this reference ensures that the database @@ -151,12 +160,13 @@ module Make(Netif: V1_LWT.NETWORK) = struct pre-allocated IP anyway, but this will present a problem if that assumption ever changes. *) let database = ref (Dhcp_server.Lease.make_db ()) in - match (Wire_structs.parse_ethernet_frame buf) with - | Some (proto, dst, _payload) when of_interest t.server_macaddr dst -> - (match proto with - | Some Wire_structs.IPv4 -> + match Ethif_packet.Unmarshal.of_cstruct buf with + | Ok (pkt, _payload) when + of_interest t.server_macaddr pkt.Ethif_packet.destination -> + (match pkt.Ethif_packet.ethertype with + | Ethif_wire.IPv4 -> if Dhcp_wire.is_dhcp buf (Cstruct.len buf) then begin - input t.netif (t.get_dhcp_configuration ()) !database buf + input t.clock t.netif (t.get_dhcp_configuration ()) !database buf >|= fun db -> database := db end diff --git a/src/hostnet/dhcp.mli b/src/hostnet/dhcp.mli index c0bf61b85..12b05adba 100644 --- a/src/hostnet/dhcp.mli +++ b/src/hostnet/dhcp.mli @@ -1,4 +1,4 @@ -module Make(Netif: V1_LWT.NETWORK): sig +module Make (Clock: Mirage_clock_lwt.MCLOCK) (Netif: Mirage_net_lwt.S): sig type t val make: server_macaddr:Macaddr.t @@ -6,7 +6,7 @@ module Make(Netif: V1_LWT.NETWORK): sig -> local_ip:Ipaddr.V4.t -> extra_dns_ip:Ipaddr.V4.t list -> get_domain_search:(unit -> string list) -> get_domain_name:(unit -> string) - -> Netif.t -> t + -> Clock.t -> Netif.t -> t val callback: t -> Cstruct.t -> unit Lwt.t end diff --git a/src/hostnet/filter.ml b/src/hostnet/filter.ml index 8acae6c1a..7b9a42a72 100644 --- a/src/hostnet/filter.ml +++ b/src/hostnet/filter.ml @@ -1,3 +1,5 @@ +open Lwt.Infix + let src = let src = Logs.Src.create "ppp" ~doc:"point-to-point network link" in Logs.Src.set_level src (Some Logs.Info); @@ -7,38 +9,43 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make(Input: Sig.VMNET) = struct + type page_aligned_buffer = Io_page.t + type buffer = Cstruct.t + type macaddr = Macaddr.t + type 'a io = 'a Lwt.t type fd = Input.fd + type error = [Mirage_device.error | `Unknown of string] - type stats = { - mutable rx_bytes: int64; - mutable rx_pkts: int32; - mutable tx_bytes: int64; - mutable tx_pkts: int32; - } + let pp_error ppf = function + | #Mirage_device.error as e -> Mirage_device.pp_error ppf e + | `Unknown s -> Fmt.pf ppf "unknown: %s" s type t = { input: Input.t; - stats: stats; + stats: Mirage_net.stats; valid_subnets: Ipaddr.V4.Prefix.t list; valid_sources: Ipaddr.V4.t list; } let connect ~valid_subnets ~valid_sources input = - let stats = { - rx_bytes = 0L; rx_pkts = 0l; tx_bytes = 0L; tx_pkts = 0l; - } in - Lwt.return (`Ok { input; stats; valid_subnets; valid_sources }) + let stats = Mirage_net.Stats.create () in + { input; stats; valid_subnets; valid_sources } let disconnect t = Input.disconnect t.input let after_disconnect t = Input.after_disconnect t.input - let write t buf = Input.write t.input buf - let writev t bufs = Input.writev t.input bufs + let lift_error = function + | Ok x -> Ok x + | Error (#Mirage_device.error as e) -> Error e + | Error e -> Fmt.kstrf (fun s -> Error (`Unknown s)) "%a" Input.pp_error e + + let write t buf = Input.write t.input buf >|= lift_error + let writev t bufs = Input.writev t.input bufs >|= lift_error let filter valid_subnets valid_sources next buf = - match (Wire_structs.parse_ethernet_frame buf) with - | Some (Some Wire_structs.IPv4, _, payload) -> - let src = Ipaddr.V4.of_int32 @@ Wire_structs.Ipv4_wire.get_ipv4_src payload in + match Ethif_packet.Unmarshal.of_cstruct buf with + | Ok (_header, payload) -> + let src = Ipaddr.V4.of_int32 @@ Ipv4_wire.get_ipv4_src payload in let from_valid_networks = List.fold_left (fun acc network -> acc || (Ipaddr.V4.Prefix.mem src network) @@ -56,15 +63,16 @@ module Make(Input: Sig.VMNET) = struct let dst = Ipaddr.V4.to_string @@ Ipaddr.V4.of_int32 @@ - Wire_structs.Ipv4_wire.get_ipv4_dst payload + Ipv4_wire.get_ipv4_dst payload in - let body = Cstruct.shift payload Wire_structs.Ipv4_wire.sizeof_ipv4 in + let body = Cstruct.shift payload Ipv4_wire.sizeof_ipv4 in begin match - Wire_structs.Ipv4_wire.(int_to_protocol @@ get_ipv4_proto payload) + Ipv4_packet.Unmarshal.int_to_protocol + @@ Ipv4_wire.get_ipv4_proto payload with | Some `UDP -> - let src_port = Wire_structs.get_udp_source_port body in - let dst_port = Wire_structs.get_udp_dest_port body in + let src_port = Udp_wire.get_udp_source_port body in + let dst_port = Udp_wire.get_udp_dest_port body in Log.warn (fun f -> f "dropping unexpected UDP packet sent from %s:%d to %s:%d \ (valid subnets = %s; valid sources = %s)" @@ -75,8 +83,8 @@ module Make(Input: Sig.VMNET) = struct (List.map Ipaddr.V4.to_string valid_sources)) ) | Some `TCP -> - let src_port = Wire_structs.Tcp_wire.get_tcp_src_port body in - let dst_port = Wire_structs.Tcp_wire.get_tcp_dst_port body in + let src_port = Tcp.Tcp_wire.get_tcp_src_port body in + let dst_port = Tcp.Tcp_wire.get_tcp_dst_port body in Log.warn (fun f -> f "dropping unexpected TCP packet sent from %s:%d to %s:%d \ (valid subnets = %s; valid sources = %s)" @@ -90,7 +98,7 @@ module Make(Input: Sig.VMNET) = struct Log.warn (fun f -> f "dropping unknown IP protocol %d sent from %s to %s (valid \ subnets = %s; valid sources = %s)" - (Wire_structs.Ipv4_wire.get_ipv4_proto payload) src dst + (Ipv4_wire.get_ipv4_proto payload) src dst (String.concat ", " (List.map Ipaddr.V4.Prefix.to_string valid_subnets)) (String.concat ", " @@ -103,35 +111,14 @@ module Make(Input: Sig.VMNET) = struct let listen t callback = Input.listen t.input @@ filter t.valid_subnets t.valid_sources callback + >|= lift_error let add_listener t callback = Input.add_listener t.input @@ filter t.valid_subnets t.valid_sources callback let mac t = Input.mac t.input - - type page_aligned_buffer = Io_page.t - - type buffer = Cstruct.t - - type error = [ - | `Unknown of string - | `Unimplemented - | `Disconnected - ] - - type macaddr = Macaddr.t - - type 'a io = 'a Lwt.t - - type id = unit - let get_stats_counters t = t.stats - - let reset_stats_counters t = - t.stats.rx_bytes <- 0L; - t.stats.tx_bytes <- 0L; - t.stats.rx_pkts <- 0l; - t.stats.tx_pkts <- 0l + let reset_stats_counters t = Mirage_net.Stats.reset t.stats let of_fd ~client_macaddr_of_uuid:_ ~server_macaddr:_ ~mtu:_ = failwith "Filter.of_fd unimplemented" diff --git a/src/hostnet/filter.mli b/src/hostnet/filter.mli index ed78b3468..b44ee23f8 100644 --- a/src/hostnet/filter.mli +++ b/src/hostnet/filter.mli @@ -4,7 +4,7 @@ module Make(Input: Sig.VMNET): sig val connect: valid_subnets:Ipaddr.V4.Prefix.t list -> valid_sources:Ipaddr.V4.t list -> Input.t - -> [ `Ok of t | `Error of error ] Lwt.t + -> t (** Construct a filtered ethernet network which removes IP packets whose source IP is not in [valid_sources] *) end diff --git a/src/hostnet/forward.ml b/src/hostnet/forward.ml index 7b95214f3..8e627cb5c 100644 --- a/src/hostnet/forward.ml +++ b/src/hostnet/forward.ml @@ -58,7 +58,11 @@ module Port = struct end -module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct +module Make + (Clock: Mirage_clock_lwt.MCLOCK) + (Connector: Sig.Connector) + (Socket: Sig.SOCKETS) = +struct type server = [ | `Tcp of Socket.Stream.Tcp.server @@ -75,6 +79,7 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct let get_key t = t.local + type clock = Clock.t type context = string let to_string t = @@ -116,21 +121,22 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct Cstruct.LE.set_uint16 header 7 port; (* Write the header, we should be connected to the container port *) Connector.write remote header >>= function - | `Error e -> - let msg = - Fmt.strf "%s: failed to write forwarding header: %s" description - (Connector.error_message e) - in + | Ok () -> Lwt.return_unit + | Error `Closed -> + let msg = Fmt.strf "%s: EOF writing forwarding header" description in Log.err (fun f -> f "%s" msg); Lwt.fail (Failure msg) - | `Eof -> - let msg = Fmt.strf "%s: EOF writing forwarding header" description in + | Error e -> + let msg = + Fmt.strf "%s: failed to write forwarding header: %a" description + Connector.pp_write_error e + in Log.err (fun f -> f "%s" msg); Lwt.fail (Failure msg) - | `Ok () -> - Lwt.return_unit - let start_tcp_proxy description vsock_path_var remote_port server = + module Proxy = Mirage_flow_lwt.Proxy(Clock)(Connector)(Socket.Stream.Tcp) + + let start_tcp_proxy clock description vsock_path_var remote_port server = Socket.Stream.Tcp.listen server (fun local -> Active_list.Var.read vsock_path_var >>= fun _vsock_path -> Connector.connect () >>= fun remote -> @@ -138,18 +144,15 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct write_forwarding_header description remote remote_port >>= fun () -> Log.debug (fun f -> f "%s: connected" description); - Mirage_flow.proxy - (module Clock) - (module Connector) remote - (module Socket.Stream.Tcp) local () - >|= function - | `Error (`Msg m) -> - Log.err (fun f -> f "%s proxy failed with %s" description m) - | `Ok (l_stats, r_stats) -> + Proxy.proxy clock remote local >|= function + | Error e -> + Log.err (fun f -> + f "%s proxy failed with %a" description Proxy.pp_error e) + | Ok (l_stats, r_stats) -> Log.debug (fun f -> - f "%s completed: l2r = %s; r2l = %s" description - (Mirage_flow.CopyStats.to_string l_stats) - (Mirage_flow.CopyStats.to_string r_stats) + f "%s completed: l2r = %a; r2l = %a" description + Mirage_flow.pp_stats l_stats + Mirage_flow.pp_stats r_stats ) ) (fun () -> Connector.close remote @@ -161,15 +164,15 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct let conn_read flow buf = Connector.read_into flow buf >>= function - | `Eof -> Lwt.fail End_of_file - | `Error e -> Lwt.fail_with (Connector.error_message e) - | `Ok () -> Lwt.return () + | Ok `Eof -> Lwt.fail End_of_file + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Connector.pp_error e + | Ok (`Data ()) -> Lwt.return () let conn_write flow buf = Connector.write flow buf >>= function - | `Eof -> Lwt.fail End_of_file - | `Error e -> Lwt.fail_with (Connector.error_message e) - | `Ok () -> Lwt.return () + | Error `Closed -> Lwt.fail End_of_file + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Connector.pp_write_error e + | Ok () -> Lwt.return () let start_udp_proxy description vsock_path_var remote_port server = let from_internet_buffer = Cstruct.create Constants.max_udp_length in @@ -306,7 +309,7 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct log_exception_continue "udp handle" (fun () -> handle server)); Lwt.return () - let start vsock_path_var t = + let start state vsock_path_var t = match t.local with | `Tcp (local_ip, local_port) -> let description = @@ -324,7 +327,7 @@ module Make(Connector: Sig.Connector)(Socket: Sig.SOCKETS) = struct `Tcp (local_ip, port) | _ -> t.local ); - start_tcp_proxy (to_string t) vsock_path_var t.remote_port server + start_tcp_proxy state (to_string t) vsock_path_var t.remote_port server >|= fun () -> Ok t ) (function diff --git a/src/hostnet/forward.mli b/src/hostnet/forward.mli index a6069626a..df68ce787 100644 --- a/src/hostnet/forward.mli +++ b/src/hostnet/forward.mli @@ -8,10 +8,10 @@ module Port : sig val of_string: string -> (t, [ `Msg of string ]) result end -module Make(Connector: Sig.Connector)(Sockets: Sig.SOCKETS) : sig - include Active_list.Instance - with type context = string - -end +module Make + (Clock: Mirage_clock_lwt.MCLOCK) + (Connector: Sig.Connector) + (Socket: Sig.SOCKETS): + Active_list.Instance with type context = string and type clock = Clock.t val set_allowed_addresses: Ipaddr.t list option -> unit diff --git a/src/hostnet/host_lwt_unix.ml b/src/hostnet/host_lwt_unix.ml index 67f65cd8c..f54105416 100644 --- a/src/hostnet/host_lwt_unix.ml +++ b/src/hostnet/host_lwt_unix.ml @@ -19,6 +19,22 @@ let log_exception_continue description f = Lwt.return () ) +module Common = struct + (** FLOW boilerplate *) + + type 'a io = 'a Lwt.t + type buffer = Cstruct.t + type error = [`Msg of string] + type write_error = [Mirage_flow.write_error | error] + let pp_error ppf (`Msg x) = Fmt.string ppf x + + let pp_write_error ppf = function + | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e + | #error as e -> pp_error ppf e + + let errorf fmt = Fmt.kstrf (fun s -> Lwt_result.fail (`Msg s)) fmt +end + module Sockets = struct let max_connections = ref None @@ -101,8 +117,7 @@ module Sockets = struct in Lwt.catch (fun () -> Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - Lwt_unix.Versioned.bind_2 fd addr - >|= fun () -> + Lwt_unix.bind fd addr >|= fun () -> idx, fd ) (fun e -> Lwt_unix.close fd @@ -148,12 +163,7 @@ module Sockets = struct type address = Ipaddr.t * int module Udp = struct - - (* FLOW boilerplate *) - type 'a io = 'a Lwt.t - type buffer = Cstruct.t - type error = [`Msg of string] - let error_message (`Msg x) = x + include Common type flow = { mutable idx: int option; @@ -185,19 +195,19 @@ module Sockets = struct (* Win32 requires all sockets to be bound however macOS and Linux don't *) Lwt.catch (fun () -> - Lwt_unix.Versioned.bind_2 fd (Lwt_unix.ADDR_INET(addr, 0)) + Lwt_unix.bind fd (Lwt_unix.ADDR_INET(addr, 0)) ) (fun _ -> Lwt.return_unit) >|= fun () -> let sockaddr = sockaddr_of_address address in Ok (of_fd ~idx ~description ?read_buffer_size sockaddr address fd) let read t = match t.fd, t.already_read with - | None, _ -> Lwt.return `Eof + | None, _ -> Lwt.return (Ok `Eof) | Some _, Some data when Cstruct.len data > 0 -> t.already_read <- Some (Cstruct.sub data 0 0); (* next read is `Eof *) - Lwt.return (`Ok data) + Lwt.return (Ok (`Data data)) | Some _, Some _ -> - Lwt.return `Eof + Lwt.return (Ok `Eof) | Some fd, None -> let buffer = Cstruct.create t.read_buffer_size in let bytes = Bytes.make t.read_buffer_size '\000' in @@ -207,29 +217,29 @@ module Sockets = struct >>= fun (n, _) -> Cstruct.blit_from_bytes bytes 0 buffer 0 n; let response = Cstruct.sub buffer 0 n in - Lwt.return (`Ok response) + Lwt.return (Ok (`Data response)) ) (fun e -> Log.err (fun f -> f "%s: recvfrom caught %a returning Eof" (string_of_flow t) Fmt.exn e); - Lwt.return `Eof + Lwt.return (Ok `Eof) ) let write t buf = match t.fd with - | None -> Lwt.return `Eof + | None -> Lwt.return (Error `Closed) | Some fd -> Lwt.catch (fun () -> (* Lwt on Win32 doesn't support Lwt_bytes.sendto *) let bytes = Bytes.make (Cstruct.len buf) '\000' in Cstruct.blit_to_bytes buf 0 bytes 0 (Cstruct.len buf); Lwt_unix.sendto fd bytes 0 (Bytes.length bytes) [] t.sockaddr - >>= fun _n -> - Lwt.return (`Ok ()) + >|= fun _n -> + Ok () ) (fun e -> Log.err (fun f -> f "%s: sendto caught %a returning Eof" (string_of_flow t) Fmt.exn e); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let writev t bufs = write t (Cstruct.concat bufs) @@ -365,6 +375,9 @@ module Sockets = struct (* Using Lwt_unix we share an implementation across various transport types *) module Fd = struct + + include Common + type flow = { idx: int; description: string; @@ -374,10 +387,6 @@ module Sockets = struct mutable closed: bool; } - type error = [`Msg of string] - let error_message (`Msg x) = x - let errorf fmt = Fmt.kstrf (fun s -> Lwt_result.fail (`Msg s)) fmt - let of_fd ~idx ?(read_buffer_size = default_read_buffer_size) ~description fd = @@ -410,7 +419,7 @@ module Sockets = struct Lwt.return () let read t = - if t.closed then Lwt.return `Eof + if t.closed then Lwt.return (Ok `Eof) else begin if Cstruct.len t.read_buffer = 0 then t.read_buffer <- Cstruct.create t.read_buffer_size; @@ -418,43 +427,43 @@ module Sockets = struct Lwt_bytes.read t.fd t.read_buffer.Cstruct.buffer t.read_buffer.Cstruct.off t.read_buffer.Cstruct.len >|= function - | 0 -> `Eof + | 0 -> Ok `Eof | n -> let results = Cstruct.sub t.read_buffer 0 n in t.read_buffer <- Cstruct.shift t.read_buffer n; - `Ok results + Ok (`Data results) ) (fun e -> Log.err (fun f -> f "Socket.TCPV4.read %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Ok `Eof) ) end let read_into t buffer = - if t.closed then Lwt.return `Eof + if t.closed then Lwt.return (Ok `Eof) else Lwt.catch (fun () -> Lwt_cstruct.(complete (read t.fd) buffer) >|= fun () -> - `Ok () - ) (fun _e -> Lwt.return `Eof) + Ok (`Data ()) + ) (fun _e -> Lwt.return (Ok `Eof)) let write t buf = - if t.closed then Lwt.return `Eof + if t.closed then Lwt.return (Error `Closed) else Lwt.catch (fun () -> Lwt_cstruct.(complete (write t.fd) buf) >|= fun () -> - `Ok () + Ok () ) (fun e -> Log.err (fun f -> f "Socket.TCPV4.write %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let writev t bufs = let rec loop = function - | [] -> Lwt.return (`Ok ()) + | [] -> Lwt.return (Ok ()) | buf :: bufs -> - if t.closed then Lwt.return `Eof + if t.closed then Lwt.return (Error `Closed) else Lwt_cstruct.(complete (write t.fd) buf) >>= fun () -> loop bufs @@ -465,7 +474,7 @@ module Sockets = struct Log.err (fun f -> f "Socket.TCPV4.writev %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let close t = @@ -561,9 +570,6 @@ module Sockets = struct ) ) server.listening_fds - (* FLOW boilerplate *) - type 'a io = 'a Lwt.t - type buffer = Cstruct.t end module Tcp = struct @@ -627,8 +633,7 @@ module Sockets = struct register_connection description >>= fun idx -> let s = Lwt_unix.socket Lwt_unix.PF_UNIX Lwt_unix.SOCK_STREAM 0 in Lwt.catch (fun () -> - Lwt_unix.Versioned.bind_2 s (Lwt_unix.ADDR_UNIX path) - >|= fun () -> + Lwt_unix.bind s (Lwt_unix.ADDR_UNIX path) >|= fun () -> make ~path [ idx, s ] ) (fun e -> Lwt_unix.close s >>= fun () -> @@ -739,29 +744,7 @@ module Files = struct let unwatch = Lwt.cancel end -module Time = struct - type 'a io = 'a Lwt.t - - let sleep = Lwt_unix.sleep -end - -module Clock = struct - type tm = - { tm_sec: int; - tm_min: int; - tm_hour: int; - tm_mday: int; - tm_mon: int; - tm_year: int; - tm_wday: int; - tm_yday: int; - tm_isdst: bool; - } - - let time = Unix.gettimeofday - - let gmtime _ = failwith "gmtime unimplemented" -end +module Time = Time module Dns = struct diff --git a/src/hostnet/host_uwt.ml b/src/hostnet/host_uwt.ml index e8434d351..3c2021c04 100644 --- a/src/hostnet/host_uwt.ml +++ b/src/hostnet/host_uwt.ml @@ -29,12 +29,15 @@ let sockaddr_of_address (dst, dst_port) = module Common = struct (** FLOW boilerplate *) - type error = [ - | `Msg of string - ] + type 'a io = 'a Lwt.t + type buffer = Cstruct.t + type error = [`Msg of string] + type write_error = [Mirage_flow.write_error | error] + let pp_error ppf (`Msg x) = Fmt.string ppf x - let error_message = function - | `Msg x -> x + let pp_write_error ppf = function + | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e + | #error as e -> pp_error ppf e let errorf fmt = Fmt.kstrf (fun s -> Lwt_result.fail (`Msg s)) fmt @@ -44,9 +47,6 @@ module Common = struct Some (Ipaddr.of_string @@ Unix.string_of_inet_addr ip, port) | _ -> None with _ -> None - - type 'a io = 'a Lwt.t - type buffer = Cstruct.t end module Sockets = struct @@ -160,12 +160,12 @@ module Sockets = struct ) let rec read t = match t.fd, t.already_read with - | None, _ -> Lwt.return `Eof + | None, _ -> Lwt.return (Ok `Eof) | Some _, Some data when Cstruct.len data > 0 -> t.already_read <- Some (Cstruct.sub data 0 0); (* next read is `Eof *) - Lwt.return (`Ok data) + Lwt.return (Ok (`Data data)) | Some _, Some _ -> - Lwt.return `Eof + Lwt.return (Ok `Eof) | Some fd, None -> let buf = Cstruct.create t.read_buffer_size in Lwt.catch (fun () -> @@ -178,11 +178,11 @@ module Sockets = struct was %d bytes)" t.label (Cstruct.len buf)); read t end else - Lwt.return (`Ok (Cstruct.sub buf 0 recv.Uwt.Udp.recv_len)) + Lwt.return (Ok (`Data (Cstruct.sub buf 0 recv.Uwt.Udp.recv_len))) ) (function | Unix.Unix_error(e, _, _) when Uwt.of_unix_error e = Uwt.ECANCELED -> (* happens on normal timeout *) - Lwt.return `Eof + Lwt.return (Ok `Eof) | e -> Log.err (fun f -> f "Socket.%s.recvfrom: %s caught %s returning Eof" @@ -190,21 +190,21 @@ module Sockets = struct (string_of_flow t) (Printexc.to_string e) ); - Lwt.return `Eof + Lwt.return (Ok `Eof) ) let write t buf = match t.fd with - | None -> Lwt.return `Eof + | None -> Lwt.return (Error `Closed) | Some fd -> Lwt.catch (fun () -> Uwt.Udp.send_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len ~buf:buf.Cstruct.buffer fd t.sockaddr >>= fun () -> - Lwt.return (`Ok ()) + Lwt.return (Ok ()) ) (fun e -> Log.err (fun f -> f "Socket.%s.write %s: caught %s returning Eof" t.label t.description (Printexc.to_string e)); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let writev t bufs = write t (Cstruct.concat bufs) @@ -441,14 +441,13 @@ module Sockets = struct let read_into t buf = let rec loop buf = if Cstruct.len buf = 0 - then Lwt.return (`Ok ()) + then Lwt.return (Ok (`Data ())) else Uwt.Tcp.read_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer >>= function - | 0 -> Lwt.return `Eof - | n -> - loop (Cstruct.shift buf n) + | 0 -> Lwt.return (Ok `Eof) + | n -> loop (Cstruct.shift buf n) in loop buf @@ -460,21 +459,21 @@ module Sockets = struct ~len:t.read_buffer.Cstruct.len t.fd ~buf:t.read_buffer.Cstruct.buffer >>= function - | 0 -> Lwt.return `Eof + | 0 -> Lwt.return (Ok `Eof) | n -> let results = Cstruct.sub t.read_buffer 0 n in t.read_buffer <- Cstruct.shift t.read_buffer n; - Lwt.return (`Ok results) + Lwt.return (Ok (`Data results)) ) (function | Unix.Unix_error(Unix.ECONNRESET, _, _) -> - Lwt.return `Eof + Lwt.return (Ok `Eof) | Unix.Unix_error(e, _, _) when Uwt.of_unix_error e = Uwt.ECANCELED -> - Lwt.return `Eof + Lwt.return (Ok `Eof) | e -> Log.err (fun f -> f "Socket.%s.read %s: caught %s returning Eof" t.label t.description (Printexc.to_string e)); - Lwt.return `Eof + Lwt.return (Ok `Eof) ) let write t buf = @@ -482,23 +481,23 @@ module Sockets = struct Uwt.Tcp.write_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer >>= fun () -> - Lwt.return (`Ok ()) + Lwt.return (Ok ()) ) (function | Unix.Unix_error(Unix.ECONNRESET, _, _) -> - Lwt.return `Eof + Lwt.return (Error `Closed) | Unix.Unix_error(e, _, _) when Uwt.of_unix_error e = Uwt.ECANCELED -> - Lwt.return `Eof + Lwt.return (Error `Closed) | e -> Log.err (fun f -> f "Socket.%s.write %s: caught %s returning Eof" t.label t.description (Printexc.to_string e)); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let writev t bufs = Lwt.catch (fun () -> let rec loop = function - | [] -> Lwt.return (`Ok ()) + | [] -> Lwt.return (Ok ()) | buf :: bufs -> Uwt.Tcp.write_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer @@ -508,14 +507,14 @@ module Sockets = struct loop bufs ) (function | Unix.Unix_error(Unix.ECONNRESET, _, _) -> - Lwt.return `Eof + Lwt.return (Error `Closed) | Unix.Unix_error(e, _, _) when Uwt.of_unix_error e = Uwt.ECANCELED -> - Lwt.return `Eof + Lwt.return (Error `Closed) | e -> Log.err (fun f -> f "Socket.%s.writev %s: caught %s returning Eof" t.label t.description (Printexc.to_string e)); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let close t = @@ -767,12 +766,12 @@ module Sockets = struct let read_into t buf = let rec loop buf = if Cstruct.len buf = 0 - then Lwt.return (`Ok ()) + then Lwt.return (Ok (`Data ())) else Uwt.Pipe.read_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer >>= function - | 0 -> Lwt.return `Eof + | 0 -> Lwt.return (Ok `Eof) | n -> loop (Cstruct.shift buf n) in loop buf @@ -785,16 +784,16 @@ module Sockets = struct ~len:t.read_buffer.Cstruct.len t.fd ~buf:t.read_buffer.Cstruct.buffer >>= function - | 0 -> Lwt.return `Eof + | 0 -> Lwt.return (Ok `Eof) | n -> let results = Cstruct.sub t.read_buffer 0 n in t.read_buffer <- Cstruct.shift t.read_buffer n; - Lwt.return (`Ok results) + Lwt.return (Ok (`Data results)) ) (fun e -> Log.err (fun f -> f "Socket.Pipe.read %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Ok `Eof) ) let write t buf = @@ -802,23 +801,23 @@ module Sockets = struct Uwt.Pipe.write_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer >|= fun () -> - `Ok () + Ok () ) (function | Unix.Unix_error(Unix.EPIPE, _, _) -> (* other end has closed, this is normal *) - Lwt.return `Eof + Lwt.return (Error `Closed) | e -> (* Unexpected error *) Log.err (fun f -> f "Socket.Pipe.write %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let writev t bufs = Lwt.catch (fun () -> let rec loop = function - | [] -> Lwt.return (`Ok ()) + | [] -> Lwt.return (Ok ()) | buf :: bufs -> Uwt.Pipe.write_ba ~pos:buf.Cstruct.off ~len:buf.Cstruct.len t.fd ~buf:buf.Cstruct.buffer @@ -830,7 +829,7 @@ module Sockets = struct Log.err (fun f -> f "Socket.Pipe.writev %s: caught %a returning Eof" t.description Fmt.exn e); - Lwt.return `Eof + Lwt.return (Error `Closed) ) let close t = @@ -1000,26 +999,7 @@ end module Time = struct type 'a io = 'a Lwt.t - - let sleep secs = Uwt.Timer.sleep (int_of_float (secs *. 1000.)) -end - -module Clock = struct - type tm = - { tm_sec: int; - tm_min: int; - tm_hour: int; - tm_mday: int; - tm_mon: int; - tm_year: int; - tm_wday: int; - tm_yday: int; - tm_isdst: bool; - } - - let time = Unix.gettimeofday - - let gmtime _ = failwith "gmtime unimplemented" + let sleep_ns x = Uwt.Timer.sleep (Duration.to_ms x) end module Dns = struct diff --git a/src/hostnet/hostnet_dns.ml b/src/hostnet/hostnet_dns.ml index 7343e235d..c7ee46b1c 100644 --- a/src/hostnet/hostnet_dns.ml +++ b/src/hostnet/hostnet_dns.ml @@ -148,13 +148,13 @@ let try_builtins local_ip host_names question = | _ -> None module Make - (Ip: V1_LWT.IPV4) - (Udp:V1_LWT.UDPV4) - (Tcp:V1_LWT.TCPV4) + (Ip: Mirage_protocols_lwt.IPV4) + (Udp:Mirage_protocols_lwt.UDPV4) + (Tcp:Mirage_protocols_lwt.TCPV4) (Socket: Sig.SOCKETS) (D: Sig.DNS) - (Time: V1_LWT.TIME) - (Clock: V1.CLOCK) + (Time: Mirage_time_lwt.S) + (Clock: Mirage_clock_lwt.MCLOCK) (Recorder: Sig.RECORDER) = struct @@ -213,47 +213,43 @@ struct packet creation fn *) let frame = Io_page.to_cstruct (Io_page.get 1) in let smac = "\000\000\000\000\000\000" in - Wire_structs.set_ethernet_src smac 0 frame; - Wire_structs.set_ethernet_ethertype frame 0x0800; - let buf = Cstruct.shift frame Wire_structs.sizeof_ethernet in - Wire_structs.Ipv4_wire.set_ipv4_hlen_version buf ((4 lsl 4) + (5)); - Wire_structs.Ipv4_wire.set_ipv4_tos buf 0; - Wire_structs.Ipv4_wire.set_ipv4_ttl buf 38; - let proto = Wire_structs.Ipv4_wire.protocol_to_int `UDP in - Wire_structs.Ipv4_wire.set_ipv4_proto buf proto; - Wire_structs.Ipv4_wire.set_ipv4_src buf (Ipaddr.V4.to_int32 source_ip); - Wire_structs.Ipv4_wire.set_ipv4_dst buf (Ipaddr.V4.to_int32 dest_ip); + Ethif_wire.set_ethernet_src smac 0 frame; + Ethif_wire.set_ethernet_ethertype frame 0x0800; + let buf = Cstruct.shift frame Ethif_wire.sizeof_ethernet in + Ipv4_wire.set_ipv4_hlen_version buf ((4 lsl 4) + (5)); + Ipv4_wire.set_ipv4_tos buf 0; + Ipv4_wire.set_ipv4_ttl buf 38; + let proto = Ipv4_packet.Marshal.protocol_to_int `UDP in + Ipv4_wire.set_ipv4_proto buf proto; + Ipv4_wire.set_ipv4_src buf (Ipaddr.V4.to_int32 source_ip); + Ipv4_wire.set_ipv4_dst buf (Ipaddr.V4.to_int32 dest_ip); let header_len = - Wire_structs.sizeof_ethernet + Wire_structs.Ipv4_wire.sizeof_ipv4 + Ethif_wire.sizeof_ethernet + Ipv4_wire.sizeof_ipv4 in - let frame = - Cstruct.set_len frame (header_len + Wire_structs.sizeof_udp) - in + let frame = Cstruct.set_len frame (header_len + Udp_wire.sizeof_udp) in let udp_buf = Cstruct.shift frame header_len in - Wire_structs.set_udp_source_port udp_buf source_port; - Wire_structs.set_udp_dest_port udp_buf dest_port; - Wire_structs.set_udp_length udp_buf - (Wire_structs.sizeof_udp + Cstruct.lenv bufs); - Wire_structs.set_udp_checksum udp_buf 0; + Udp_wire.set_udp_source_port udp_buf source_port; + Udp_wire.set_udp_dest_port udp_buf dest_port; + Udp_wire.set_udp_length udp_buf (Udp_wire.sizeof_udp + Cstruct.lenv bufs); + Udp_wire.set_udp_checksum udp_buf 0; let csum = Ip.checksum frame (udp_buf :: bufs) in - Wire_structs.set_udp_checksum udp_buf csum; + Udp_wire.set_udp_checksum udp_buf csum; (* Ip.writev *) let bufs = frame :: bufs in - let tlen = Cstruct.lenv bufs - Wire_structs.sizeof_ethernet in + let tlen = Cstruct.lenv bufs - Ethif_wire.sizeof_ethernet in let dmac = String.make 6 '\000' in (* Ip.adjust_output_header *) - Wire_structs.set_ethernet_dst dmac 0 frame; + Ethif_wire.set_ethernet_dst dmac 0 frame; let buf = - Cstruct.sub frame Wire_structs.sizeof_ethernet - Wire_structs.Ipv4_wire.sizeof_ipv4 + Cstruct.sub frame Ethif_wire.sizeof_ethernet Ipv4_wire.sizeof_ipv4 in (* Set the mutable values in the ipv4 header *) - Wire_structs.Ipv4_wire.set_ipv4_len buf tlen; - Wire_structs.Ipv4_wire.set_ipv4_id buf (Random.int 65535); (* TODO *) - Wire_structs.Ipv4_wire.set_ipv4_csum buf 0; + Ipv4_wire.set_ipv4_len buf tlen; + Ipv4_wire.set_ipv4_id buf (Random.int 65535); (* TODO *) + Ipv4_wire.set_ipv4_csum buf 0; let checksum = Tcpip_checksum.ones_complement buf in - Wire_structs.Ipv4_wire.set_ipv4_csum buf checksum; + Ipv4_wire.set_ipv4_csum buf checksum; Recorder.record recorder bufs | None -> () (* nowhere to log packet *) @@ -264,7 +260,7 @@ struct f "DNS names %s will map to local IP %s" (String.concat ", " @@ List.map Dns.Name.to_string host_names) (Ipaddr.to_string local_ip)); - function + fun clock -> function | `Upstream config -> let open Dns_forward.Config.Address in let nr_servers = @@ -280,10 +276,11 @@ struct Lwt.return_unit | _ -> (* We don't know how to marshal IPv6 yet *) - Lwt.return_unit in - Dns_udp_resolver.create ~message_cb config + Lwt.return_unit + in + Dns_udp_resolver.create ~message_cb config clock >>= fun dns_udp_resolver -> - Dns_tcp_resolver.create ~message_cb config + Dns_tcp_resolver.create ~message_cb config clock >>= fun dns_tcp_resolver -> Lwt.return { local_ip; host_names; resolver = Upstream { dns_tcp_resolver; dns_udp_resolver } } @@ -291,18 +288,13 @@ struct Log.info (fun f -> f "Will use the host's DNS resolver"); Lwt.return { local_ip; host_names; resolver = Host } - let answer t is_tcp buffer = + let answer t is_tcp buf = let open Dns.Packet in - let len = Cstruct.len buffer in - let buf = Dns.Buf.of_cstruct buffer in - match Dns.Protocol.Server.parse (Dns.Buf.sub buf 0 len) with + let len = Cstruct.len buf in + match Dns.Protocol.Server.parse (Cstruct.sub buf 0 len) with | None -> Lwt.return (Error (`Msg "failed to parse DNS packet")) | Some ({ questions = [ question ]; _ } as request) -> - let marshal pkt = - let buf = Dns.Buf.create 1024 in - let buf = Dns.Packet.marshal buf pkt in - Cstruct.of_bigarray buf in let reply answers = let id = request.id in let detail = @@ -323,9 +315,9 @@ struct | None -> match is_tcp, t.resolver with | true, Upstream { dns_tcp_resolver; _ } -> - Dns_tcp_resolver.answer buffer dns_tcp_resolver + Dns_tcp_resolver.answer buf dns_tcp_resolver | false, Upstream { dns_udp_resolver; _ } -> - Dns_udp_resolver.answer buffer dns_udp_resolver + Dns_udp_resolver.answer buf dns_udp_resolver | _, Host -> D.resolve question >>= function @@ -351,8 +343,7 @@ struct let describe buf = let len = Cstruct.len buf in - let buf = Dns.Buf.of_cstruct buf in - match Dns.Protocol.Server.parse (Dns.Buf.sub buf 0 len) with + match Dns.Protocol.Server.parse (Cstruct.sub buf 0 len) with | None -> Printf.sprintf "Unparsable DNS packet length %d" len | Some request -> Dns.Packet.to_string request @@ -361,9 +352,9 @@ struct >>= function | Error (`Msg m) -> Log.warn (fun f -> f "%s lookup failed: %s" (describe buf) m); - Lwt.return_unit + Lwt.return (Ok ()) | Ok buffer -> - Udp.write ~source_port:53 ~dest_ip:src ~dest_port:src_port udp buffer + Udp.write ~src_port:53 ~dst:src ~dst_port:src_port udp buffer let handle_tcp ~t = (* FIXME: need to record the upstream request *) diff --git a/src/hostnet/hostnet_dns.mli b/src/hostnet/hostnet_dns.mli index c7f2af644..672360f91 100644 --- a/src/hostnet/hostnet_dns.mli +++ b/src/hostnet/hostnet_dns.mli @@ -11,13 +11,13 @@ module Config: sig end module Make - (Ip: V1_LWT.IPV4 with type prefix = Ipaddr.V4.t) - (Udp: V1_LWT.UDPV4) - (Tcp:V1_LWT.TCPV4) + (Ip: Mirage_protocols_lwt.IPV4) + (Udp: Mirage_protocols_lwt.UDPV4) + (Tcp: Mirage_protocols_lwt.TCPV4) (Socket: Sig.SOCKETS) (Dns_resolver: Sig.DNS) - (Time: V1_LWT.TIME) - (Clock: V1.CLOCK) + (Time: Mirage_time_lwt.S) + (Clock: Mirage_clock_lwt.MCLOCK) (Recorder: Sig.RECORDER) : sig @@ -27,7 +27,7 @@ sig val create: local_address:Dns_forward.Config.Address.t -> host_names:Dns.Name.t list -> - Config.t -> t Lwt.t + Clock.t -> Config.t -> t Lwt.t (** Create a DNS forwarding instance based on the given configuration, either [`Upstream config]: send DNS requests to the given upstream servers [`Host]: use the Host's resolver. @@ -38,7 +38,7 @@ sig val handle_udp: t:t -> udp:Udp.t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> src_port:int -> - Cstruct.t -> unit Lwt.t + Cstruct.t -> (unit, Udp.error) result Lwt.t val handle_tcp: t:t -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t diff --git a/src/hostnet/hostnet_http.ml b/src/hostnet/hostnet_http.ml index 088dad0e4..4e5b43422 100644 --- a/src/hostnet/hostnet_http.ml +++ b/src/hostnet/hostnet_http.ml @@ -88,9 +88,9 @@ module Exclude = struct end module Make - (Ip: V1_LWT.IPV4 with type prefix = Ipaddr.V4.t) - (Udp: V1_LWT.UDPV4) - (Tcp:Mirage_flow_s.SHUTDOWNABLE) + (Ip: Mirage_protocols_lwt.IPV4) + (Udp: Mirage_protocols_lwt.UDPV4) + (Tcp:Mirage_flow_lwt.SHUTDOWNABLE) (Socket: Sig.SOCKETS) (Dns_resolver: Sig.DNS) = struct @@ -200,13 +200,13 @@ module Make Lwt.return (Ok t) module Incoming = struct - module C = Channel.Make(Tcp) + module C = Mirage_channel_lwt.Make(Tcp) module IO = Cohttp_mirage_io.Make(C) module Request = Cohttp.Request.Make(IO) module Response = Cohttp.Response.Make(IO) end module Outgoing = struct - module C = Channel.Make(Socket.Stream.Tcp) + module C = Mirage_channel_lwt.Make(Socket.Stream.Tcp) module IO = Cohttp_mirage_io.Make(C) module Request = Cohttp.Request.Make(IO) module Response = Cohttp.Response.Make(IO) @@ -330,38 +330,40 @@ module Make (* forward outgoing to ingoing *) let a_t flow ~incoming ~outgoing = + let warn pp e = + Log.warn (fun f -> f "Unexpected exeption %a in proxy" pp e); + in let rec loop () = - Lwt.catch (fun () -> - Outgoing.C.read_some outgoing >>= fun buf -> + (Outgoing.C.read_some outgoing >>= function + | Ok `Eof -> Lwt.return false + | Error e -> warn Outgoing.C.pp_error e; Lwt.return false + | Ok (`Data buf) -> Incoming.C.write_buffer incoming buf; - Incoming.C.flush incoming >|= fun () -> - true - ) (function - | End_of_file -> Lwt.return false - | e -> - Log.warn (fun f -> - f "Possibly unexpected exeption %a in proxy" Fmt.exn e); - Lwt.return false) - >>= fun continue -> + Incoming.C.flush incoming >|= function + | Ok () -> true + | Error `Closed -> false + | Error e -> warn Incoming.C.pp_write_error e; false + ) >>= fun continue -> if continue then loop () else Tcp.shutdown_write flow in loop () (* forward ingoing to outgoing *) let b_t remote ~incoming ~outgoing = + let warn pp e = + Log.warn (fun f -> f "Unexpected exeption %a in proxy" pp e); + in let rec loop () = - Lwt.catch (fun () -> - Incoming.C.read_some incoming >>= fun buf -> + (Incoming.C.read_some incoming >>= function + | Ok `Eof -> Lwt.return false + | Error e -> warn Incoming.C.pp_error e; Lwt.return false + | Ok (`Data buf) -> Outgoing.C.write_buffer outgoing buf; - Outgoing.C.flush outgoing >|= fun () -> - true - ) (function - | End_of_file -> Lwt.return false - | e -> - Log.warn (fun f -> - f "Possibly unexpected exeption %a in proxy" Fmt.exn e); - Lwt.return false) - >>= fun continue -> + Outgoing.C.flush outgoing >|= function + | Ok () -> true + | Error `Closed -> false + | Error e -> warn Outgoing.C.pp_write_error e; false + ) >>= fun continue -> if continue then loop () else Socket.Stream.Tcp.shutdown_write remote in loop () diff --git a/src/hostnet/hostnet_http.mli b/src/hostnet/hostnet_http.mli index f8cfae6ee..5f241d9ee 100644 --- a/src/hostnet/hostnet_http.mli +++ b/src/hostnet/hostnet_http.mli @@ -10,9 +10,9 @@ module Exclude: sig end module Make - (Ip: V1_LWT.IPV4 with type prefix = Ipaddr.V4.t) - (Udp: V1_LWT.UDPV4) - (Tcp: Mirage_flow_s.SHUTDOWNABLE) + (Ip: Mirage_protocols_lwt.IPV4) + (Udp: Mirage_protocols_lwt.UDPV4) + (Tcp: Mirage_flow_lwt.SHUTDOWNABLE) (Socket: Sig.SOCKETS) (Dns_resolver: Sig.DNS) : sig diff --git a/src/hostnet/hostnet_udp.ml b/src/hostnet/hostnet_udp.ml index 478ec0995..7c5e849ee 100644 --- a/src/hostnet/hostnet_udp.ml +++ b/src/hostnet/hostnet_udp.ml @@ -17,14 +17,18 @@ type datagram = { payload: Cstruct.t; } -module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct +module Make + (Sockets: Sig.SOCKETS) + (Clock: Mirage_clock_lwt.MCLOCK) + (Time: Mirage_time_lwt.S) = +struct module Udp = Sockets.Datagram.Udp type flow = { description: string; server: Udp.server; - mutable last_use: float; + mutable last_use: int64; } (* For every src, src_port behind the NAT we create one listening socket @@ -32,7 +36,8 @@ module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct but preserve them on the way in. *) type t = { - max_idle_time: float; + clock: Clock.t; + max_idle_time: int64; background_gc_t: unit Lwt.t; table: (address, flow) Hashtbl.t; (* src -> flow *) mutable send_reply: (datagram -> unit Lwt.t) option; @@ -42,14 +47,13 @@ module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct let get_nat_table_size t = Hashtbl.length t.table - let start_background_gc table max_idle_time = + let start_background_gc clock table max_idle_time = let rec loop () = - Time.sleep max_idle_time - >>= fun () -> - let now = Unix.gettimeofday () in + Time.sleep_ns max_idle_time >>= fun () -> + let now_ns = Clock.elapsed_ns clock in let to_shutdown = Hashtbl.fold (fun k flow acc -> - if now -. flow.last_use > max_idle_time then begin + if Int64.(sub now_ns flow.last_use) > max_idle_time then begin Log.debug (fun f -> f "Hostnet_udp %s: expiring UDP NAT rule" flow.description); (k, flow) :: acc @@ -73,11 +77,11 @@ module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct in loop () - let create ?(max_idle_time = 60.) () = + let create ?(max_idle_time = Duration.(of_sec 60)) clock = let table = Hashtbl.create 7 in - let background_gc_t = start_background_gc table max_idle_time in + let background_gc_t = start_background_gc clock table max_idle_time in let send_reply = None in - { max_idle_time; background_gc_t; table; send_reply } + { clock; max_idle_time; background_gc_t; table; send_reply } let description { src = src, src_port; dst = dst, dst_port; _ } = Fmt.strf "udp:%a:%d-%a:%d" Ipaddr.pp_hum src src_port Ipaddr.pp_hum @@ -131,7 +135,7 @@ module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct Lwt.catch (fun () -> Udp.bind ~description:(description datagram) (Ipaddr.(V4 V4.any), 0) >>= fun server -> - let last_use = Unix.gettimeofday () in + let last_use = Clock.elapsed_ns t.clock in let flow = { description = d; server; last_use } in Hashtbl.replace t.table datagram.src flow; (* Start a listener *) @@ -147,10 +151,8 @@ module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME) = struct | None -> Lwt.return () | Some flow -> Lwt.catch (fun () -> - Udp.sendto flow.server datagram.dst datagram.payload - >>= fun () -> - flow.last_use <- Unix.gettimeofday (); - Lwt.return () + Udp.sendto flow.server datagram.dst datagram.payload >|= fun () -> + flow.last_use <- Clock.elapsed_ns t.clock; ) (fun e -> Log.err (fun f -> f "Hostnet_udp %s: Lwt_bytes.send caught %a" diff --git a/src/hostnet/hostnet_udp.mli b/src/hostnet/hostnet_udp.mli index 8b49faf4d..ef15ee0c7 100644 --- a/src/hostnet/hostnet_udp.mli +++ b/src/hostnet/hostnet_udp.mli @@ -14,12 +14,16 @@ type datagram = { type reply = Cstruct.t -> unit Lwt.t -module Make(Sockets: Sig.SOCKETS)(Time: V1_LWT.TIME): sig +module Make + (Sockets: Sig.SOCKETS) + (Clock: Mirage_clock_lwt.MCLOCK) + (Time: Mirage_time_lwt.S): +sig type t (** A UDP NAT implementation *) - val create: ?max_idle_time:float -> unit -> t + val create: ?max_idle_time:int64 -> Clock.t -> t (** Create a UDP NAT implementation which will keep "NAT rules" alive until they become idle for the given [?max_idle_time] *) diff --git a/src/hostnet/jbuild b/src/hostnet/jbuild index 0b0ddcce3..81a58b5c8 100644 --- a/src/hostnet/jbuild +++ b/src/hostnet/jbuild @@ -3,12 +3,14 @@ (library ((name hostnet) (libraries ( - cstruct lwt.unix logs mirage-types.lwt ipaddr mirage-flow - ppx_sexp_conv pcap-format mirage-console.unix tcpip.ethif + cstruct lwt.unix logs ipaddr mirage-flow-lwt + ppx_sexp_conv pcap-format mirage-console-unix tcpip.ethif tcpip.arpv4 tcpip.ipv4 tcpip.icmpv4 tcpip.udp tcpip.tcp tcpip.stack-direct - charrua-core.server dns dns.lwt ofs uwt uwt.ext uwt.preemptive - lwt.preemptive threads astring datakit-server.vfs dns-forward tar mirage-vnetif - dnssd uuidm mirage-http cohttp channel ezjsonm + charrua-core.server dns dns-lwt ofs uwt uwt.ext uwt.preemptive + lwt.preemptive threads astring datakit-server dns-forward tar mirage-vnetif + dnssd uuidm mirage-http cohttp-lwt mirage-channel ezjsonm + mirage-protocols-lwt duration mirage-time-lwt mirage-clock-lwt + mirage-time-unix mirage-random tcpip.unix )) (c_names (stubs_utils)) (wrapped false))) diff --git a/src/hostnet/mux.ml b/src/hostnet/mux.ml index 14fa0d2c1..2a5d04b1a 100644 --- a/src/hostnet/mux.ml +++ b/src/hostnet/mux.ml @@ -8,39 +8,26 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) module DontCareAboutStats = struct - type stats = { - mutable rx_bytes: int64; - mutable rx_pkts: int32; - mutable tx_bytes: int64; - mutable tx_pkts: int32; - } - let get_stats_counters _ = { - rx_bytes = 0L; rx_pkts = 0l; - tx_bytes = 0L; tx_pkts = 0l; - } - + let get_stats_counters _ = Mirage_net.Stats.create () let reset_stats_counters _ = () end module ObviouslyCommon = struct - type page_aligned_buffer = Io_page.t - - type buffer = Cstruct.t - - type error = [ - | `Unknown of string - | `Unimplemented - | `Disconnected - ] + type page_aligned_buffer = Io_page.t type macaddr = Macaddr.t - type 'a io = 'a Lwt.t + type buffer = Cstruct.t + type error = [Mirage_device.error | `Unknown of string] + + let pp_error ppf = function + | #Mirage_device.error as e -> Mirage_device.pp_error ppf e + | `Unknown s -> Fmt.pf ppf "unknown: %s" s - type id = unit end -module Make(Netif: V1_LWT.NETWORK) = struct +module Make (Netif: Mirage_net_lwt.S) = struct + include DontCareAboutStats include ObviouslyCommon @@ -61,6 +48,11 @@ module Make(Netif: V1_LWT.NETWORK) = struct mutable default_callback: callback; } + let lift_error: ('a, Netif.error) result -> ('a, error) result = function + | Ok x -> Ok x + | Error (#Mirage_device.error as e) -> Error e + | Error e -> Fmt.kstrf (fun s -> Error (`Unknown s)) "%a" Netif.pp_error e + let filesystem t = let xs = RuleMap.fold @@ -99,20 +91,14 @@ module Make(Netif: V1_LWT.NETWORK) = struct let rules = RuleMap.empty in let default_callback = fun _ -> Lwt.return_unit in let t = { netif; rules; default_callback } in - Netif.listen netif @@ callback t - >>= fun () -> - Lwt.return t - - let write t buffer = - Netif.write t.netif buffer - let writev t buffers = - Netif.writev t.netif buffers - let listen t callback = - t.default_callback <- callback; - Lwt.return_unit - let disconnect t = - Netif.disconnect t.netif - + Netif.listen netif @@ callback t >>= fun r -> + let r = match r with Ok () -> Ok t | Error _ as e -> e in + Lwt.return r >|= lift_error + + let write t buffer = Netif.write t.netif buffer >|= lift_error + let writev t buffers = Netif.writev t.netif buffers >|= lift_error + let listen t callback = t.default_callback <- callback; Lwt.return (Ok ()) + let disconnect t = Netif.disconnect t.netif let mac t = Netif.mac t.netif module Port = struct @@ -125,15 +111,17 @@ module Make(Netif: V1_LWT.NETWORK) = struct rule: rule; } - let write t buffer = Netif.write t.netif buffer - let writev t buffers = Netif.writev t.netif buffers + let write t buffer = Netif.write t.netif buffer >|= lift_error + let writev t buffers = Netif.writev t.netif buffers >|= lift_error + let listen t callback = Log.debug (fun f -> f "activating switch port for %s" (Ipaddr.V4.to_string t.rule)); let last_active_time = Unix.gettimeofday () in let port = { callback; last_active_time } in t.switch.rules <- RuleMap.add t.rule port t.switch.rules; - Lwt.return_unit + Lwt.return (Ok ()) + let disconnect t = Log.debug (fun f -> f "deactivating switch port for %s" (Ipaddr.V4.to_string t.rule)); diff --git a/src/hostnet/mux.mli b/src/hostnet/mux.mli index e0a3e0d8f..6927d508d 100644 --- a/src/hostnet/mux.mli +++ b/src/hostnet/mux.mli @@ -1,5 +1,5 @@ -module Make(Netif: V1_LWT.NETWORK) : sig - include V1_LWT.NETWORK +module Make(Netif: Mirage_net_lwt.S) : sig + include Mirage_net_lwt.S (** A simple ethernet multiplexer/demultiplexer @@ -14,14 +14,14 @@ module Make(Netif: V1_LWT.NETWORK) : sig *) - val connect: Netif.t -> t Lwt.t + val connect: Netif.t -> (t, error) result Lwt.t (** Connect a multiplexer/demultiplexer and return a [t] which behaves like a V1_LWT.NETWORK representing the multiplexed end. *) type rule = Ipaddr.V4.t (** We currently support matching on IPv4 destination addresses only *) - module Port : V1_LWT.NETWORK + module Port : Mirage_net_lwt.S (** A network which receives all the traffic matching a specific rule *) val port: t -> rule -> Port.t diff --git a/src/hostnet/sig.ml b/src/hostnet/sig.ml index 9b6f6c478..8f9f48f04 100644 --- a/src/hostnet/sig.ml +++ b/src/hostnet/sig.ml @@ -3,12 +3,12 @@ module type READ_INTO = sig type error val read_into: flow -> Cstruct.t -> - [ `Eof | `Error of error | `Ok of unit ] Lwt.t + (unit Mirage_flow.or_eof, error) result Lwt.t (** Completely fills the given buffer with data from [fd] *) end module type FLOW_CLIENT = sig - include Mirage_flow_s.SHUTDOWNABLE + include Mirage_flow_lwt.SHUTDOWNABLE type address @@ -19,7 +19,7 @@ module type FLOW_CLIENT = sig end module type CONN = sig - include V1_LWT.FLOW + include Mirage_flow_lwt.S include READ_INTO with type flow := flow @@ -166,8 +166,7 @@ module type HOST = sig include FILES end - module Time: V1_LWT.TIME - module Clock: V1.CLOCK + module Time: Mirage_time_lwt.S module Dns: sig include DNS @@ -199,7 +198,7 @@ end module type VMNET = sig (** A virtual ethernet link to the VM *) - include V1_LWT.NETWORK + include Mirage_net_lwt.S val add_listener: t -> (Cstruct.t -> unit Lwt.t) -> unit (** Add a callback which will be invoked in parallel with all received packets *) diff --git a/src/hostnet/slirp.ml b/src/hostnet/slirp.ml index e65843ada..d6e1ca74b 100644 --- a/src/hostnet/slirp.ml +++ b/src/hostnet/slirp.ml @@ -40,29 +40,12 @@ let log_exception_continue description f = Lwt.return () ) -module Infix = struct - let ( >>= ) m f = m >>= function - | `Ok x -> f x - | `Error x -> Lwt.return (`Error x) -end - let failf fmt = Fmt.kstrf Lwt.fail_with fmt -let errorf fmt = Fmt.kstrf (fun e -> Lwt.return (`Error (`Msg e))) fmt let or_failwith name m = - m >>= function - | `Error _ -> failf "Failed to connect %s device" name - | `Ok x -> Lwt.return x - -let or_failwith_result name m = m >>= function | Error _ -> failf "Failed to connect %s device" name - | Ok x -> Lwt.return x - -let or_error name m = - m >>= function - | `Error _ -> errorf "Failed to connect %s device" name - | `Ok x -> Lwt.return (`Ok x) + | Ok x -> Lwt.return x let restart_on_change name to_string values = Active_config.tl values @@ -90,7 +73,7 @@ type uuid_table = { table: (Uuidm.t, Ipaddr.V4.t * int) Hashtbl.t; } -type config = { +type 'a config = { server_macaddr: Macaddr.t; peer_ip: Ipaddr.V4.t; local_ip: Ipaddr.V4.t; @@ -103,22 +86,28 @@ type config = { bridge_connections: bool; mtu: int; host_names: Dns.Name.t list; + clock: 'a; } module Make (Config: Active_config.S) (Vmnet: Sig.VMNET) (Dns_policy: Sig.DNS_POLICY) + (Clock: sig + include Mirage_clock_lwt.MCLOCK + val connect: unit -> t Lwt.t + end) + (Random: Mirage_random.C) (Host: Sig.HOST) (Vnet : Vnetif.BACKEND with type macaddr = Macaddr.t) = struct (* module Tcpip_stack = Tcpip_stack.Make(Vmnet)(Host.Time) *) - module Filteredif = Filter.Make(Vmnet) +module Filteredif = Filter.Make(Vmnet) module Netif = Capture.Make(Filteredif) module Recorder = (Netif: Sig.RECORDER with type t = Netif.t) module Switch = Mux.Make(Netif) - module Dhcp = Dhcp.Make(Switch) + module Dhcp = Dhcp.Make(Clock)(Switch) (* This ARP implementation will respond to the VM: *) module Global_arp_ethif = Ethif.Make(Switch) @@ -127,10 +116,10 @@ struct (* This stack will attach to a switch port and represent a single remote IP *) module Stack_ethif = Ethif.Make(Switch.Port) module Stack_arpv4 = Arp.Make(Stack_ethif) - module Stack_ipv4 = Ipv4.Make(Stack_ethif)(Stack_arpv4) + module Stack_ipv4 = Static_ipv4.Make(Stack_ethif)(Stack_arpv4) module Stack_icmpv4 = Icmpv4.Make(Stack_ipv4) module Stack_tcp_wire = Tcp.Wire.Make(Stack_ipv4) - module Stack_udp = Udp.Make(Stack_ipv4) + module Stack_udp = Udp.Make(Stack_ipv4)(Random) module Stack_tcp = struct include Tcp.Flow.Make(Stack_ipv4)(Host.Time)(Clock)(Random) let shutdown_read _flow = @@ -143,18 +132,23 @@ struct module Dns_forwarder = Hostnet_dns.Make(Stack_ipv4)(Stack_udp)(Stack_tcp)(Host.Sockets)(Host.Dns) - (Host.Time)(Host.Clock)(Recorder) + (Host.Time)(Clock)(Recorder) module Http_forwarder = Hostnet_http.Make(Stack_ipv4)(Stack_udp)(Stack_tcp)(Host.Sockets)(Host.Dns) - module Udp_nat = Hostnet_udp.Make(Host.Sockets)(Host.Time) + module Udp_nat = Hostnet_udp.Make(Host.Sockets)(Clock)(Host.Time) + + let dns_forwarder ~local_address ~host_names clock = + Dns_forwarder.create ~local_address ~host_names clock (Dns_policy.config ()) (* Global variable containing the global DNS configuration *) let dns = let ip = Ipaddr.V4 (Ipaddr.V4.of_string_exn default_host) in let local_address = { Dns_forward.Config.Address.ip; port = 0 } in - ref (Dns_forwarder.create ~local_address ~host_names:[] @@ - Dns_policy.config ()) + ref ( + Clock.connect () >>= fun clock -> + dns_forwarder ~local_address ~host_names:[] clock + ) (* Global variable containing the global HTTP proxy configuration *) let http = @@ -175,35 +169,39 @@ struct | _ -> false let string_of_id id = + let src = Stack_tcp_wire.src id in + let src_port = Stack_tcp_wire.src_port id in + let dst = Stack_tcp_wire.dst id in + let dst_port = Stack_tcp_wire.dst_port id in Fmt.strf "TCP %a:%d > %a:%d" - Ipaddr.V4.pp_hum id.Stack_tcp_wire.dest_ip id.Stack_tcp_wire.dest_port - Ipaddr.V4.pp_hum id.Stack_tcp_wire.local_ip id.Stack_tcp_wire.local_port + Ipaddr.V4.pp_hum dst dst_port + Ipaddr.V4.pp_hum src src_port module Tcp = struct module Id = struct module M = struct - type t = Stack_tcp_wire.id - let compare - { Stack_tcp_wire.local_ip = local_ip1; - local_port = local_port1; - dest_ip = dest_ip1; - dest_port = dest_port1 } - { Stack_tcp_wire.local_ip = local_ip2; - local_port = local_port2; - dest_ip = dest_ip2; - dest_port = dest_port2 } = - let dest_ip' = Ipaddr.V4.compare dest_ip1 dest_ip2 in - let local_ip' = Ipaddr.V4.compare local_ip1 local_ip2 in - let dest_port' = compare dest_port1 dest_port2 in - let local_port' = compare local_port1 local_port2 in - if dest_port' <> 0 - then dest_port' - else if dest_ip' <> 0 - then dest_ip' - else if local_ip' <> 0 - then local_ip' - else local_port' + type t = Stack_tcp_wire.t + let compare id1 id2 = + let dst_ip1 = Stack_tcp_wire.dst id1 in + let dst_port1 = Stack_tcp_wire.dst_port id1 in + let dst_ip2 = Stack_tcp_wire.dst id2 in + let dst_port2 = Stack_tcp_wire.dst_port id2 in + let src_ip1 = Stack_tcp_wire.src id1 in + let src_port1 = Stack_tcp_wire.src_port id1 in + let src_ip2 = Stack_tcp_wire.src id2 in + let src_port2 = Stack_tcp_wire.src_port id2 in + let dst_ip' = Ipaddr.V4.compare dst_ip1 dst_ip2 in + let dst_port' = compare dst_port1 dst_port2 in + let src_ip' = Ipaddr.V4.compare src_ip1 src_ip2 in + let src_port' = compare src_port1 src_port2 in + if dst_port' <> 0 + then dst_port' + else if dst_ip' <> 0 + then dst_ip' + else if src_port' <> 0 + then src_port' + else src_ip' end include M module Set = Set.Make(M) @@ -214,7 +212,7 @@ struct (** An established flow *) type t = { - id: Stack_tcp_wire.id; + id: Stack_tcp_wire.t; mutable socket: Host.Sockets.Stream.Tcp.flow option; mutable last_active_time: float; } @@ -257,6 +255,7 @@ struct icmpv4: Stack_icmpv4.t; udp4: Stack_udp.t; tcp4: Stack_tcp.t; + clock: Clock.t; mutable pending: Tcp.Id.Set.t; mutable last_active_time: float; } @@ -265,37 +264,29 @@ struct let touch t = t.last_active_time <- Unix.gettimeofday () - let create recorder switch arp_table ip mtu = + let create recorder switch arp_table ip mtu clock = let netif = Switch.port switch ip in - let open Infix in - or_error "Stack_ethif.connect" @@ Stack_ethif.connect ~mtu netif - >>= fun ethif -> - or_error "Stack_arpv4.connect" @@ Stack_arpv4.connect ~table:arp_table ethif - >>= fun arp -> - or_error "Stack_ipv4.connect" @@ Stack_ipv4.connect ethif arp + Stack_ethif.connect ~mtu netif >>= fun ethif -> + Stack_arpv4.connect ~table:arp_table ethif |>fun arp -> + Stack_ipv4.connect + ~gateway:None + ~network:Ipaddr.V4.(Prefix.of_addr unspecified) + ethif arp >>= fun ipv4 -> - or_error "Stack_icmpv4.connect" @@ Stack_icmpv4.connect ipv4 - >>= fun icmpv4 -> - or_error "Stack_udp.connect" @@ Stack_udp.connect ipv4 - >>= fun udp4 -> - or_error "Stack_tcp.connect" @@ Stack_tcp.connect ipv4 - >>= fun tcp4 -> + Stack_icmpv4.connect ipv4 >>= fun icmpv4 -> + Stack_udp.connect ipv4 >>= fun udp4 -> + Stack_tcp.connect ipv4 clock >>= fun tcp4 -> - let open Lwt.Infix in Stack_ipv4.set_ip ipv4 ip (* I am the destination *) >>= fun () -> - Stack_ipv4.set_ip_netmask ipv4 Ipaddr.V4.unspecified (* 0.0.0.0 *) - >>= fun () -> - Stack_ipv4.set_ip_gateways ipv4 [ ] - >>= fun () -> let pending = Tcp.Id.Set.empty in let last_active_time = Unix.gettimeofday () in let tcp_stack = { recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending; - last_active_time } + last_active_time; clock } in - Lwt.return (`Ok tcp_stack) + Lwt.return tcp_stack let intercept_tcp_syn t ~id ~syn on_syn_callback (buf: Cstruct.t) = if syn then begin @@ -311,8 +302,9 @@ struct (fun () -> on_syn_callback () >>= fun listeners -> - Stack_tcp.input t.tcp4 ~listeners ~src:id.Stack_tcp_wire.dest_ip - ~dst:id.Stack_tcp_wire.local_ip buf + let src = Stack_tcp_wire.dst id in + let dst = Stack_tcp_wire.src id in + Stack_tcp.input t.tcp4 ~listeners ~src ~dst buf ) (fun () -> t.pending <- Tcp.Id.Set.remove id t.pending; Lwt.return_unit; @@ -321,10 +313,14 @@ struct end else begin Tcp.Flow.touch id; (* non-SYN packets are injected into the stack as normal *) - Stack_tcp.input t.tcp4 ~listeners:(fun _ -> None) - ~src:id.Stack_tcp_wire.dest_ip ~dst:id.Stack_tcp_wire.local_ip buf + let src = Stack_tcp_wire.dst id in + let dst = Stack_tcp_wire.src id in + Stack_tcp.input t.tcp4 ~listeners:(fun _ -> None) ~src ~dst buf end + module Proxy = + Mirage_flow_lwt.Proxy(Clock)(Stack_tcp)(Host.Sockets.Stream.Tcp) + let input_tcp t ~id ~syn (ip, port) (buf: Cstruct.t) = intercept_tcp_syn t ~id ~syn (fun () -> Host.Sockets.Stream.Tcp.connect (ip, port) @@ -335,36 +331,33 @@ struct Ipaddr.pp_hum ip port m); Lwt.return (fun _ -> None) | Ok socket -> - let t = Tcp.Flow.create id socket in + let tcp = Tcp.Flow.create id socket in let listeners port = Log.debug (fun f -> f "%a:%d handshake complete" Ipaddr.pp_hum ip port); let f flow = - match t.Tcp.Flow.socket with + match tcp.Tcp.Flow.socket with | None -> Log.err (fun f -> f "%s callback called on closed socket" - (Tcp.Flow.to_string t)); + (Tcp.Flow.to_string tcp)); Lwt.return_unit | Some socket -> Lwt.finalize (fun () -> - Mirage_flow.proxy - (module Clock) - (module Stack_tcp) flow - (module Host.Sockets.Stream.Tcp) socket () + Proxy.proxy t.clock flow socket >>= function - | `Error (`Msg m) -> + | Error e -> Log.debug (fun f -> - f "%s proxy failed with %s" - (Tcp.Flow.to_string t) m); + f "%s proxy failed with %a" + (Tcp.Flow.to_string tcp) Proxy.pp_error e); Lwt.return_unit - | `Ok (_l_stats, _r_stats) -> + | Ok (_l_stats, _r_stats) -> Lwt.return_unit ) (fun () -> Log.debug (fun f -> - f "closing flow %s" (string_of_id t.Tcp.Flow.id)); - t.Tcp.Flow.socket <- None; - Tcp.Flow.remove t.Tcp.Flow.id; + f "closing flow %s" (string_of_id tcp.Tcp.Flow.id)); + tcp.Tcp.Flow.socket <- None; + Tcp.Flow.remove tcp.Tcp.Flow.id; Host.Sockets.Stream.Tcp.close socket ) in @@ -436,8 +429,23 @@ struct open Frame + let pp_error ppf = function + | `Udp e -> Stack_udp.pp_error ppf e + | `Ipv4 e -> Stack_ipv4.pp_error ppf e + + let lift_ipv4_error = function + | Ok _ as x -> x + | Error e -> Error (`Ipv4 e) + + let lift_udp_error = function + | Ok _ as x -> x + | Error e -> Error (`Udp e) + + let ok () = Ok () + module Local = struct type t = { + clock: Clock.t; endpoint: Endpoint.t; udp_nat: Udp_nat.t; dns_ips: Ipaddr.V4.t list; @@ -456,6 +464,7 @@ struct | _ -> Lwt.return_unit in Stack_ipv4.input t.endpoint.Endpoint.ipv4 ~tcp:none ~udp:none ~default raw + >|= ok (* UDP on port 53 -> DNS forwarder *) | Ipv4 { src; dst; @@ -464,20 +473,20 @@ struct let udp = t.endpoint.Endpoint.udp4 in !dns >>= fun t -> Dns_forwarder.handle_udp ~t ~udp ~src ~dst ~src_port payload + >|= lift_udp_error (* TCP to port 53 -> DNS forwarder *) | Ipv4 { src; dst; payload = Tcp { src = src_port; dst = 53; syn; raw; payload = Payload _; _ }; _ } -> let id = - { Stack_tcp_wire.local_port = 53; dest_ip = src; local_ip = dst; - dest_port = src_port } + Stack_tcp_wire.v ~src_port:53 ~dst:src ~src:dst ~dst_port:src_port in - Endpoint.intercept_tcp_syn t.endpoint ~id ~syn - (fun () -> - !dns >>= fun t -> - Dns_forwarder.handle_tcp ~t + Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun () -> + !dns >>= fun t -> + Dns_forwarder.handle_tcp ~t ) raw + >|= ok (* UDP to port 123: localhost NTP *) | Ipv4 { src; @@ -492,6 +501,7 @@ struct dst = Ipaddr.V4 localhost, 123; payload } in Udp_nat.input ~t:t.udp_nat ~datagram () + >|= ok (* UDP to any other port: localhost *) | Ipv4 { src; dst; ihl; dnf; raw; @@ -504,10 +514,11 @@ struct if Cstruct.len payload < len then begin Log.err (fun f -> f "%s: dropping because reported len %d actual len %d" description len (Cstruct.len payload)); - Lwt.return_unit + Lwt.return (Ok ()) end else if dnf && (Cstruct.len payload > safe_outgoing_mtu) then begin Endpoint.send_icmp_dst_unreachable t.endpoint ~src ~dst ~src_port ~dst_port ~ihl raw + >|= lift_ipv4_error end else begin (* [1] For UDP to our local address, rewrite the destination to localhost. This is the inverse of the rewrite @@ -517,6 +528,7 @@ struct dst = Ipaddr.(V4 V4.localhost), dst_port; payload } in Udp_nat.input ~t:t.udp_nat ~datagram () + >|= ok end (* TCP to local ports *) @@ -524,16 +536,16 @@ struct payload = Tcp { src = src_port; dst = dst_port; syn; raw; payload = Payload _; _ }; _ } -> let id = - { Stack_tcp_wire.local_port = dst_port; dest_ip = src; local_ip = dst; - dest_port = src_port } + Stack_tcp_wire.v ~src_port:dst_port ~dst:src ~src:dst ~dst_port:src_port in Endpoint.input_tcp t.endpoint ~id ~syn (Ipaddr.V4 Ipaddr.V4.localhost, dst_port) raw + >|= ok | _ -> - Lwt.return_unit + Lwt.return (Ok ()) - let create endpoint udp_nat dns_ips = - let tcp_stack = { endpoint; udp_nat; dns_ips } in + let create clock endpoint udp_nat dns_ips = + let tcp_stack = { clock; endpoint; udp_nat; dns_ips } in let open Lwt.Infix in (* Wire up the listeners to receive future packets: *) Switch.Port.listen endpoint.Endpoint.netif @@ -542,13 +554,17 @@ struct match parse [ buf ] with | Ok (Ethernet { payload = Ipv4 ipv4; _ }) -> Endpoint.touch endpoint; - input_ipv4 tcp_stack (Ipv4 ipv4) + (input_ipv4 tcp_stack (Ipv4 ipv4) >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while reading IPv4 input: %a" pp_error e)) | _ -> Lwt.return_unit ) - >>= fun () -> - - Lwt.return (`Ok tcp_stack) + >|= function + | Ok () -> Ok tcp_stack + | Error _ as e -> e end @@ -572,12 +588,16 @@ struct | _ -> Lwt.return_unit in Stack_ipv4.input t.endpoint.Endpoint.ipv4 ~tcp:none ~udp:none ~default raw + >|= ok (* Transparent HTTP intercept? *) - | Ipv4 { src = dest_ip; dst = local_ip; + | Ipv4 { src = dest_ip ; dst = local_ip; payload = Tcp { src = dest_port; dst = local_port; syn; raw; _ }; _ } -> - let id = { Stack_tcp_wire.local_port; dest_ip; local_ip; dest_port } in + let id = + Stack_tcp_wire.v + ~src_port:local_port ~dst:dest_ip ~src:local_ip ~dst_port:dest_port + in let callback = match !http with | None -> None | Some http -> Http_forwarder.handle ~dst:(local_ip, local_port) ~t:http @@ -586,8 +606,10 @@ struct | None -> Endpoint.input_tcp t.endpoint ~id ~syn (Ipaddr.V4 local_ip, local_port) raw (* common case *) + >|= ok | Some cb -> Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun _ -> cb) raw + >|= ok end | Ipv4 { src; dst; ihl; dnf; raw; payload = Udp { src = src_port; dst = dst_port; len; @@ -598,7 +620,7 @@ struct Log.err (fun f -> f "%s: dropping because reported len %d actual len %d" description len (Cstruct.len payload)); - Lwt.return_unit + Lwt_result.return () end else if dnf && (Cstruct.len payload > safe_outgoing_mtu) then begin Endpoint.send_icmp_dst_unreachable t.endpoint ~src ~dst ~src_port ~dst_port ~ihl raw @@ -608,10 +630,10 @@ struct dst = Ipaddr.V4 dst, dst_port; payload } in Udp_nat.input ~t:t.udp_nat ~datagram () + >|= ok end - | _ -> - Lwt.return_unit + | _ -> Lwt_result.return () let create endpoint udp_nat = let tcp_stack = { endpoint; udp_nat } in @@ -623,12 +645,18 @@ struct match parse [ buf ] with | Ok (Ethernet { payload = Ipv4 ipv4; _ }) -> Endpoint.touch endpoint; - input_ipv4 tcp_stack (Ipv4 ipv4) + (input_ipv4 tcp_stack (Ipv4 ipv4) >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while reading IPv4 input: %a" + Stack_ipv4.pp_error e)) | _ -> Lwt.return_unit ) - >>= fun () -> - Lwt.return (`Ok tcp_stack) + >|= function + | Ok () -> Ok tcp_stack + | Error _ as e -> e end let filesystem t = @@ -654,13 +682,17 @@ struct ) let diagnostics t flow = - let module C = Channel.Make(Host.Sockets.Stream.Unix) in + let module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Unix) in let module Writer = Tar.HeaderWriter(Lwt)(struct type out_channel = C.t type 'a t = 'a Lwt.t let really_write oc buf = C.write_buffer oc buf; - C.flush oc + C.flush oc >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while flushing tar channel: %a" C.pp_write_error e) end) in let c = C.create flow in @@ -672,7 +704,7 @@ struct Lwt.return_unit in let rec tar pwd dir = - let mod_time = Int64.of_float @@ Host.Clock.time () in + let mod_time = Int64.of_float @@ Unix.gettimeofday () in Vfs.Dir.ls dir >>?= fun inodes -> Lwt_list.iter_s @@ -701,7 +733,8 @@ struct else aux (Int64.add offset len) in aux 0L >>= fun () -> - Lwt.return (List.rev !fragments) in + Lwt.return (List.rev !fragments) + in copy () >>= fun fragments -> let length = @@ -716,26 +749,33 @@ struct >>= fun () -> List.iter (C.write_buffer c) fragments; C.write_buffer c (Tar.Header.zero_padding header); - C.flush c + C.flush c >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "flushing of tar block failed: %a" C.pp_write_error e); ) inodes in tar "" (filesystem t) >>= fun () -> C.write_buffer c Tar.Header.zero_block; C.write_buffer c Tar.Header.zero_block; - C.flush c + C.flush c >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while flushing the diagnostic: %a" C.pp_write_error e) module Debug = struct let get_nat_table_size t = Udp_nat.get_nat_table_size t.udp_nat let update_dns - ?(local_ip = Ipaddr.V4 Ipaddr.V4.localhost) ?(host_names = []) () + ?(local_ip = Ipaddr.V4 Ipaddr.V4.localhost) ?(host_names = []) clock = let local_address = { Dns_forward.Config.Address.ip = local_ip; port = 0 } in - dns := Dns_forwarder.create ~local_address ~host_names - (Dns_policy.config ()) + dns := dns_forwarder ~local_address ~host_names clock let update_http ?http:http_config ?https ?exclude () = Http_forwarder.create ?http:http_config ?https ?exclude () @@ -757,7 +797,7 @@ struct (* If no traffic is received for 5 minutes, delete the endpoint and the switch port. *) let rec delete_unused_endpoints t () = - Host.Time.sleep 30. + Host.Time.sleep_ns (Duration.of_sec 30) >>= fun () -> Lwt_mutex.with_lock t.endpoints_m (fun () -> @@ -777,16 +817,16 @@ struct let connect x l2_switch l2_client_id client_macaddr server_macaddr peer_ip local_ip highest_ip extra_dns_ip mtu get_domain_search get_domain_name - (global_arp_table:arp_table) use_bridge + (global_arp_table:arp_table) use_bridge clock = let valid_subnets = [ Ipaddr.V4.Prefix.global ] in let valid_sources = [ Ipaddr.V4.of_string_exn "0.0.0.0" ] in - or_failwith "filter" @@ Filteredif.connect ~valid_subnets ~valid_sources x - >>= fun (filteredif: Filteredif.t) -> - or_failwith "capture" @@ Netif.connect filteredif - >>= fun interface -> + Filteredif.connect ~valid_subnets ~valid_sources x + |> fun (filteredif: Filteredif.t) -> + Netif.connect filteredif + |> fun interface -> Dns_forwarder.set_recorder interface; let kib = 1024 in @@ -796,7 +836,7 @@ struct (* Capture 64KiB of NTP traffic *) Netif.add_match ~t:interface ~name:"ntp.pcap" ~limit:(64 * kib) ~snaplen:1500 ~predicate:is_ntp; - Switch.connect interface + or_failwith "Switch.connect" (Switch.connect interface) >>= fun switch -> (* Serve a static ARP table *) @@ -804,7 +844,7 @@ struct peer_ip, client_macaddr; local_ip, server_macaddr; ] @ (List.map (fun ip -> ip, server_macaddr) extra_dns_ip) in - or_failwith "arp_ethif" @@ Global_arp_ethif.connect switch + Global_arp_ethif.connect switch >>= fun global_arp_ethif -> (* Listen on local IPs *) @@ -818,11 +858,11 @@ struct end in let dhcp = Dhcp.make ~server_macaddr ~peer_ip ~highest_peer_ip ~local_ip - ~extra_dns_ip ~get_domain_search ~get_domain_name switch in + ~extra_dns_ip ~get_domain_search ~get_domain_name clock switch in let endpoints = IPMap.empty in let endpoints_m = Lwt_mutex.create () in - let udp_nat = Udp_nat.create () in + let udp_nat = Udp_nat.create clock in let t = { l2_switch; l2_client_id; @@ -839,13 +879,12 @@ struct Lwt_mutex.with_lock t.endpoints_m (fun () -> if IPMap.mem ip t.endpoints - then Lwt.return (`Ok (IPMap.find ip t.endpoints)) + then Lwt.return (Ok (IPMap.find ip t.endpoints)) else begin - let open Infix in - Endpoint.create interface switch local_arp_table ip mtu - >>= fun endpoint -> + Endpoint.create interface switch local_arp_table ip mtu clock + >|= fun endpoint -> t.endpoints <- IPMap.add ip endpoint t.endpoints; - Lwt.return (`Ok endpoint) + Ok endpoint end ) in @@ -860,15 +899,18 @@ struct then local_ip else src in begin - find_endpoint src - >>= function - | `Error (`Msg m) -> + find_endpoint src >>= function + | Error (`Msg m) -> Log.err (fun f -> f "Failed to create an endpoint for %a: %s" Ipaddr.V4.pp_hum dst m); Lwt.return_unit - | `Ok endpoint -> - Stack_udp.write ~source_port:src_port ~dest_ip:dst ~dest_port:dst_port - endpoint.Endpoint.udp4 payload + | Ok endpoint -> + Stack_udp.write ~src_port ~dst ~dst_port endpoint.Endpoint.udp4 payload + >|= function + | Error e -> + Log.err (fun f -> + f "Failed to write a UDP packet: %a" Stack_udp.pp_error e); + | Ok () -> () end | { Hostnet_udp.src = src, src_port; dst = dst, dst_port; _ } -> Log.err (fun f -> @@ -888,7 +930,10 @@ struct l2_client_id (Macaddr.to_string eth_src) (Macaddr.to_string eth_dst)); - Switch.write switch buf + (Switch.write switch buf >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> l "switch write failed: %a" Switch.pp_error e)) (* write packets from virtual network directly to client *) | _ -> Lwt.return_unit ); end; @@ -910,7 +955,10 @@ struct f "%d: forwarded to bridge for %s->%s" l2_client_id (Macaddr.to_string eth_src) (Macaddr.to_string eth_dst)); (* pass to virtual network *) - Vnet.write t.l2_switch t.l2_client_id buf + Vnet.write t.l2_switch t.l2_client_id buf >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> l "Vnet write failed: %a" Mirage_device.pp_error e) end else begin Lwt.return_unit (* drop if bridge is not used *) end @@ -935,59 +983,72 @@ struct if use_bridge then begin (* reply with global table if bridge is in use *) Lwt_mutex.with_lock global_arp_table.mutex (fun _ -> - or_failwith "arp" @@ Global_arp.connect ~table:global_arp_table.table - global_arp_ethif) + global_arp_ethif + |> Lwt.return) end else begin (* if not, use local table *) - or_failwith "arp" @@ Global_arp.connect ~table:local_arp_table global_arp_ethif + |> Lwt.return end end >>= fun arp -> - Global_arp.input arp (Cstruct.shift buf Wire_structs.sizeof_ethernet) + Global_arp.input arp (Cstruct.shift buf Ethif_wire.sizeof_ethernet) | Ok (Ethernet { payload = Ipv4 ({ dst; _ } as ipv4 ); _ }) -> (* For any new IP destination, create a stack to proxy for the remote system *) if List.mem dst local_ips then begin begin - let open Infix in - find_endpoint dst - >>= fun endpoint -> + let open Lwt_result.Infix in + find_endpoint dst >>= fun endpoint -> Log.debug (fun f -> f "creating local TCP/IP proxy for %a" Ipaddr.V4.pp_hum dst); - Local.create endpoint udp_nat local_ips + Local.create clock endpoint udp_nat local_ips end >>= function - | `Error (`Msg m) -> - Log.err (fun f -> f "Failed to create a TCP/IP stack: %s" m); + | Error e -> + Log.err (fun f -> + f "Failed to create a TCP/IP stack: %a" Switch.Port.pp_error e); Lwt.return_unit - | `Ok tcp_stack -> + | Ok tcp_stack -> (* inject the ethernet frame into the new stack *) - Local.input_ipv4 tcp_stack (Ipv4 ipv4) + Local.input_ipv4 tcp_stack (Ipv4 ipv4) >|= function + | Ok () -> () + | Error e -> + Log.err (fun f -> f "failed to read TCP/IP input: %a" pp_error e); end else begin begin - let open Infix in - find_endpoint dst - >>= fun endpoint -> + let open Lwt_result.Infix in + find_endpoint dst >>= fun endpoint -> Log.debug (fun f -> f "create remote TCP/IP proxy for %a" Ipaddr.V4.pp_hum dst); Remote.create endpoint udp_nat end >>= function - | `Error (`Msg m) -> - Log.err (fun f -> f "Failed to create a TCP/IP stack: %s" m); + | Error e -> + Log.err (fun f -> + f "Failed to create a TCP/IP stack: %a" + Switch.Port.pp_error e); Lwt.return_unit - | `Ok tcp_stack -> + | Ok tcp_stack -> (* inject the ethernet frame into the new stack *) - Remote.input_ipv4 tcp_stack (Ipv4 ipv4) + Remote.input_ipv4 tcp_stack (Ipv4 ipv4) >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while reading remote IPv4 input: %a" + Stack_ipv4.pp_error e) end | _ -> Lwt.return_unit ) - >>= fun () -> - Log.info (fun f -> f "TCP/IP ready"); - Lwt.return t - - let create ?(host_names = [ Dns.Name.of_string "vpnkit.host" ]) config = + >>= function + | Error e -> + Log.err (fun f -> f "TCP/IP not ready: %a" Switch.pp_error e); + Lwt.fail_with "not ready" + | Ok () -> + Log.info (fun f -> f "TCP/IP ready"); + Lwt.return t + + let create ?(host_names = [ Dns.Name.of_string "vpnkit.host" ]) clock config = let driver = [ "com.docker.driver.amd64-linux" ] in let max_connections_path = driver @ [ "slirp"; "max-connections" ] in @@ -1190,17 +1251,13 @@ struct in Log.info (fun f -> f "updating resolvers to %s" (Hostnet_dns.Config.to_string config)); - !dns >>= fun t -> - Dns_forwarder.destroy t - >>= fun () -> + !dns >>= Dns_forwarder.destroy >|= fun () -> Dns_policy.remove ~priority:3; Dns_policy.add ~priority:3 ~config; let local_address = { Dns_forward.Config.Address.ip = Ipaddr.V4 local_ip; port = 0 } in - dns := Dns_forwarder.create ~local_address ~host_names - (Dns_policy.config ()); - Lwt.return_unit + dns := dns_forwarder ~local_address ~host_names clock in let rec monitor_dns_settings settings = @@ -1312,6 +1369,7 @@ struct bridge_connections; mtu; host_names; + clock; } in Lwt.return t @@ -1327,7 +1385,11 @@ struct Lwt.return mac (* may raise Not_found if id is unknown to the bridge *) end else begin (* new uuid, register in bridge *) (* register new client on bridge *) - or_failwith "l2_switch" @@ Lwt.return @@ Vnet.register l2_switch + let l2_client_id = match Vnet.register l2_switch with + | `Ok x -> Ok x + | `Error e -> Error e + in + or_failwith "l2_switch" @@ Lwt.return l2_client_id >>= fun l2_client_id -> let client_macaddr = (Vnet.mac l2_switch l2_client_id) in @@ -1411,7 +1473,7 @@ struct begin (* If bridge is in use, create unique IP and update global ARP *) if t.bridge_connections then begin - or_failwith_result "vmnet" @@ + or_failwith "vmnet" @@ Vmnet.of_fd ~client_macaddr_of_uuid:(client_macaddr_of_uuid t t.peer_ip l2_switch) ~server_macaddr:t.server_macaddr ~mtu:t.mtu client @@ -1423,10 +1485,10 @@ struct connect x l2_switch l2_client_id client_macaddr t.server_macaddr client_ip t.local_ip t.highest_ip t.extra_dns_ip t.mtu t.get_domain_search t.get_domain_name t.global_arp_table - t.bridge_connections + t.bridge_connections t.clock end else begin (* When bridge is disabled, just use fixed uuid and peer_ip from t *) - or_failwith_result "vmnet" @@ + or_failwith "vmnet" @@ Vmnet.of_fd ~client_macaddr_of_uuid:(fun _ -> Lwt.return default_client_macaddr) ~server_macaddr:t.server_macaddr ~mtu:t.mtu client @@ -1434,7 +1496,7 @@ struct let client_macaddr = Vmnet.get_client_macaddr x in connect x l2_switch (-1) client_macaddr t.server_macaddr t.peer_ip t.local_ip t.highest_ip t.extra_dns_ip t.mtu t.get_domain_search - t.get_domain_name t.global_arp_table t.bridge_connections + t.get_domain_name t.global_arp_table t.bridge_connections t.clock end end diff --git a/src/hostnet/slirp.mli b/src/hostnet/slirp.mli index 743c77b56..f46f9c9f7 100644 --- a/src/hostnet/slirp.mli +++ b/src/hostnet/slirp.mli @@ -14,7 +14,10 @@ type uuid_table = { table: (Uuidm.t, Ipaddr.V4.t * int) Hashtbl.t; } -type config = { +(** A slirp TCP/IP stack ready to accept connections *) + + +type 'clock config = { server_macaddr: Macaddr.t; peer_ip: Ipaddr.V4.t; local_ip: Ipaddr.V4.t; @@ -27,24 +30,29 @@ type config = { bridge_connections: bool; mtu: int; host_names: Dns.Name.t list; + clock: 'clock; } -(** A slirp TCP/IP stack ready to accept connections *) - module Make (Config: Active_config.S) (Vmnet: Sig.VMNET) (Dns_policy: Sig.DNS_POLICY) + (Clock: sig + include Mirage_clock_lwt.MCLOCK + val connect: unit -> t Lwt.t + end) + (Random: Mirage_random.C) (Host: Sig.HOST) - (Vnet : Vnetif.BACKEND) : + (Vnet : Vnetif.BACKEND with type macaddr = Macaddr.t) : sig - val create: ?host_names:Dns.Name.t list -> Config.t -> config Lwt.t + val create: ?host_names:Dns.Name.t list -> Clock.t -> Config.t -> + Clock.t config Lwt.t (** Initialise a TCP/IP stack, taking configuration from the Config.t *) type t - val connect: config -> Vmnet.fd -> Vnet.t -> t Lwt.t + val connect: Clock.t config -> Vmnet.fd -> Vnet.t -> t Lwt.t (** Read and write ethernet frames on the given fd, connected to the specified Vnetif backend *) @@ -62,7 +70,7 @@ sig (** Return the number of active NAT table entries *) val update_dns: ?local_ip:Ipaddr.t -> ?host_names:Dns.Name.t list -> - unit -> unit + Clock.t -> unit (** Update the DNS forwarder following a configuration change *) val update_http: ?http:string -> ?https:string -> ?exclude:string -> diff --git a/src/hostnet/vmnet.ml b/src/hostnet/vmnet.ml index 30bcd3be2..af5cde379 100644 --- a/src/hostnet/vmnet.ml +++ b/src/hostnet/vmnet.ml @@ -40,7 +40,7 @@ module Init = struct let version = Cstruct.LE.get_uint32 rest 5 in let commit = Cstruct.(to_string @@ sub rest 9 40) in let rest = Cstruct.shift rest sizeof in - Ok ({ magic; version; commit }, rest) + { magic; version; commit }, rest end module Command = struct @@ -147,18 +147,23 @@ end module Make(C: Sig.CONN) = struct - module Channel = Channel.Make(C) + module Channel = Mirage_channel_lwt.Make(C) - type stats = { - mutable rx_bytes: int64; - mutable rx_pkts: int32; - mutable tx_bytes: int64; - mutable tx_pkts: int32; - } + type page_aligned_buffer = Io_page.t + type macaddr = Macaddr.t + type 'a io = 'a Lwt.t + type buffer = Cstruct.t + type error = [Mirage_device.error | `Channel of Channel.write_error] + + let pp_error ppf = function + | #Mirage_device.error as e -> Mirage_device.pp_error ppf e + | `Channel e -> Channel.pp_write_error ppf e + + let failf fmt = Fmt.kstrf (fun e -> Lwt_result.fail (`Msg e)) fmt type t = { mutable fd: Channel.t option; - stats: stats; + stats: Mirage_net.stats; client_uuid: Uuidm.t; client_macaddr: Macaddr.t; server_macaddr: Macaddr.t; @@ -174,89 +179,76 @@ module Make(C: Sig.CONN) = struct after_disconnect_u: unit Lwt.u; } - let error_of_failure f = - Lwt.catch f (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e))) - - exception Disconnected - - let get_fd t = match t.fd with - | Some fd -> fd - | None -> raise Disconnected - let get_client_uuid t = t.client_uuid let get_client_macaddr t = t.client_macaddr + let err_eof = Lwt_result.fail (`Msg "error: got EOF") + let err_read e = failf "error while reading: %a" Channel.pp_error e + let err_flush e = failf "error while flushing: %a" Channel.pp_write_error e + + let with_read x f = + x >>= function + | Error e -> err_read e + | Ok `Eof -> err_eof + | Ok (`Data x) -> f x + + let with_flush x f = + x >>= function + | Error e -> err_flush e + | Ok () -> f () + + let with_msg x f = + match x with + | Ok x -> f x + | Error _ as e -> Lwt.return e + let server_negotiate ~fd ~client_macaddr_of_uuid ~mtu = - error_of_failure (fun () -> - Channel.read_exactly ~len:Init.sizeof fd - >>= fun bufs -> - let buf = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Init.unmarshal buf) - >>= fun (init, _) -> - Log.info (fun f -> - f "PPP.negotiate: received %s" (Init.to_string init)); - let (_: Cstruct.t) = Init.marshal Init.default buf in - let open Lwt.Infix in - Channel.write_buffer fd buf; - Channel.flush fd - >>= fun () -> - Channel.read_exactly ~len:Command.sizeof fd - >>= fun bufs -> - let buf = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Command.unmarshal buf) - >>= fun (command, _) -> - Log.info (fun f -> - f "PPP.negotiate: received %s" (Command.to_string command)); - match command with - | Command.Ethernet uuid -> begin - let open Lwt.Infix in - client_macaddr_of_uuid uuid - >>= fun client_macaddr -> - let vif = Vif.create client_macaddr mtu () in - let buf = Cstruct.create Vif.sizeof in - let (_: Cstruct.t) = Vif.marshal vif buf in - let open Lwt.Infix in - Log.info (fun f -> f "PPP.negotiate: sending %s" (Vif.to_string vif)); - Channel.write_buffer fd buf; - Channel.flush fd - >>= fun () -> - Lwt_result.return (uuid, client_macaddr) - end - | Command.Bind_ipv4 _ -> - raise (Failure "PPP.negotiate: unsupported command Bind_ipv4") - ) + with_read (Channel.read_exactly ~len:Init.sizeof fd) @@ fun bufs -> + let buf = Cstruct.concat bufs in + let init, _ = Init.unmarshal buf in + Log.info (fun f -> f "PPP.negotiate: received %s" (Init.to_string init)); + let (_: Cstruct.t) = Init.marshal Init.default buf in + Channel.write_buffer fd buf; + with_flush (Channel.flush fd) @@ fun () -> + with_read (Channel.read_exactly ~len:Command.sizeof fd) @@ fun bufs -> + let buf = Cstruct.concat bufs in + with_msg (Command.unmarshal buf) @@ fun (command, _) -> + Log.info (fun f -> + f "PPP.negotiate: received %s" (Command.to_string command)); + match command with + | Command.Bind_ipv4 _ -> failf "PPP.negotiate: unsupported command Bind_ipv4" + | Command.Ethernet uuid -> + client_macaddr_of_uuid uuid >>= fun client_macaddr -> + let vif = Vif.create client_macaddr mtu () in + let buf = Cstruct.create Vif.sizeof in + let (_: Cstruct.t) = Vif.marshal vif buf in + Log.info (fun f -> f "PPP.negotiate: sending %s" (Vif.to_string vif)); + Channel.write_buffer fd buf; + with_flush (Channel.flush fd) @@ fun () -> + Lwt_result.return (uuid, client_macaddr) let client_negotiate ~uuid ~fd = - error_of_failure - (fun () -> - let open Lwt.Infix in - let buf = Cstruct.create Init.sizeof in - let (_: Cstruct.t) = Init.marshal Init.default buf in - Channel.write_buffer fd buf; - Channel.flush fd >>= fun () -> - Channel.read_exactly ~len:Init.sizeof fd >>= fun bufs -> - let buf = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Init.unmarshal buf) >>= fun (init, _) -> - Log.info (fun f -> - f "Client.negotiate: received %s" (Init.to_string init)); - let buf = Cstruct.create Command.sizeof in - let (_: Cstruct.t) = Command.marshal (Command.Ethernet uuid) buf in - let open Lwt.Infix in - Channel.write_buffer fd buf; - Channel.flush fd >>= fun () -> - Channel.read_exactly ~len:Vif.sizeof fd >>= fun bufs -> - let buf = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Vif.unmarshal buf)>>= fun (vif, _) -> - Log.debug (fun f -> f "Client.negotiate: vif %s" (Vif.to_string vif)); - Lwt_result.return (vif) - ) + let buf = Cstruct.create Init.sizeof in + let (_: Cstruct.t) = Init.marshal Init.default buf in + Channel.write_buffer fd buf; + with_flush (Channel.flush fd) @@ fun () -> + with_read (Channel.read_exactly ~len:Init.sizeof fd) @@ fun bufs -> + let buf = Cstruct.concat bufs in + let init, _ = Init.unmarshal buf in + Log.info (fun f -> f "Client.negotiate: received %s" (Init.to_string init)); + let buf = Cstruct.create Command.sizeof in + let (_: Cstruct.t) = Command.marshal (Command.Ethernet uuid) buf in + Channel.write_buffer fd buf; + with_flush (Channel.flush fd) @@ fun () -> + with_read (Channel.read_exactly ~len:Vif.sizeof fd) @@ fun bufs -> + let buf = Cstruct.concat bufs in + let open Lwt_result.Infix in + Lwt.return (Vif.unmarshal buf) >>= fun (vif, _) -> + Log.debug (fun f -> f "Client.negotiate: vif %s" (Vif.to_string vif)); + Lwt_result.return (vif) (* Use blocking I/O here so we can avoid Using Lwt_unix or Uwt. Ideally we would use a FLOW handle referencing a file/stream. *) @@ -308,7 +300,7 @@ module Make(C: Sig.CONN) = struct let make ~client_macaddr ~server_macaddr ~mtu ~client_uuid fd = let fd = Some fd in - let stats = { rx_bytes = 0L; rx_pkts = 0l; tx_bytes = 0L; tx_pkts = 0l } in + let stats = Mirage_net.Stats.create () in let write_header = Cstruct.create (1024 * Packet.sizeof) in let write_m = Lwt_mutex.create () in let pcap = None in @@ -347,9 +339,14 @@ module Make(C: Sig.CONN) = struct | Some fd -> t.fd <- None; Log.debug (fun f -> f "Vmnet.disconnect flushing channel"); - Channel.flush fd >>= fun () -> - Lwt.wakeup_later t.after_disconnect_u (); - Lwt.return () + (Channel.flush fd >|= function + | Ok () -> () + | Error e -> + Log.err (fun l -> + l "error while disconnecting the vmtnet connection: %a" + Channel.pp_write_error e); + ) >|= fun () -> + Lwt.wakeup_later t.after_disconnect_u () let after_disconnect t = t.after_disconnect @@ -379,163 +376,146 @@ module Make(C: Sig.CONN) = struct Lwt.return_unit ) - let drop_on_error f = Lwt.catch f (fun _ -> Lwt.return ()) - let writev t bufs = Lwt_mutex.with_lock t.write_m (fun () -> - drop_on_error (fun () -> - capture t bufs >>= fun () -> - let len = List.(fold_left (+) 0 (map Cstruct.len bufs)) in - if len > (t.mtu + ethernet_header_length) then begin - Log.err (fun f -> - f "Dropping over-large ethernet frame, length = %d, mtu = \ - %d" len t.mtu - ); - Lwt.return_unit - end else begin - let buf = Cstruct.create Packet.sizeof in - Packet.marshal len buf; - let fd = get_fd t in - Channel.write_buffer fd buf; - Log.debug (fun f -> - let b = Buffer.create 128 in - List.iter (Cstruct.hexdump_to_buffer b) bufs; - f "sending\n%s" (Buffer.contents b) - ); - List.iter (Channel.write_buffer fd) bufs; - Channel.flush fd - end - ) + capture t bufs >>= fun () -> + let len = List.(fold_left (+) 0 (map Cstruct.len bufs)) in + if len > (t.mtu + ethernet_header_length) then begin + Log.err (fun f -> + f "Dropping over-large ethernet frame, length = %d, mtu = \ + %d" len t.mtu + ); + Lwt_result.return () + end else begin + let buf = Cstruct.create Packet.sizeof in + Packet.marshal len buf; + match t.fd with + | None -> Lwt_result.fail `Disconnected + | Some fd -> + Channel.write_buffer fd buf; + Log.debug (fun f -> + let b = Buffer.create 128 in + List.iter (Cstruct.hexdump_to_buffer b) bufs; + f "sending\n%s" (Buffer.contents b) + ); + List.iter (Channel.write_buffer fd) bufs; + Channel.flush fd >|= function + | Ok () -> Ok () + | Error e -> Error (`Channel e) + end ) + let err_eof = + Log.debug (fun f -> f "PPP.listen: closing connection"); + Lwt.return false + + let err_unexpected t pp e = + Log.err (fun f -> + f "PPP.listen: caught unexpected %a: disconnecting" pp e); + disconnect t >>= fun () -> + Lwt.return false + + let with_fd t f = match t.fd with + | None -> Lwt.return false + | Some fd -> f fd + + let with_read t x f = + x >>= function + | Error e -> err_unexpected t Channel.pp_error e + | Ok `Eof -> err_eof + | Ok (`Data x) -> f x + + let with_msg t x f = + match x with + | Error (`Msg e) -> err_unexpected t Fmt.string e + | Ok x -> f x + let listen t callback = if t.listening then begin Log.debug (fun f -> f "PPP.listen: called a second time: doing nothing"); - Lwt.return (); + Lwt.return (Ok ()); end else begin t.listening <- true; let last_error_log = ref 0. in let rec loop () = - let open Lwt_result.Infix in - Lwt.catch - (fun () -> - let open Lwt.Infix in - let fd = get_fd t in - Channel.read_exactly ~len:Packet.sizeof fd - >>= fun bufs -> - let read_header = Cstruct.concat bufs in - let open Lwt_result.Infix in - Lwt.return (Packet.unmarshal read_header) - >>= fun (len, _) -> - let open Lwt.Infix in - Channel.read_exactly ~len fd - >>= fun bufs -> - capture t bufs - >>= fun () -> - Log.debug (fun f -> - let b = Buffer.create 128 in - List.iter (Cstruct.hexdump_to_buffer b) bufs; - f "received\n%s" (Buffer.contents b) - ); - let buf = Cstruct.concat bufs in - let callback buf = - Lwt.catch (fun () -> callback buf) - (function - | Host_uwt.Sockets.Too_many_connections - | Host_lwt_unix.Sockets.Too_many_connections -> - (* No need to log this again *) - Lwt.return_unit - | e -> - let now = Unix.gettimeofday () in - if (now -. !last_error_log) > 30. then begin - Log.err (fun f -> - f "PPP.listen callback caught %a" Fmt.exn e); - last_error_log := now; - end; - Lwt.return_unit - ) in - Lwt.async (fun () -> callback buf); - List.iter (fun callback -> - Lwt.async (fun () -> callback buf) - ) t.listeners; - Lwt_result.return true - ) (function - | End_of_file -> - Log.debug (fun f -> f "PPP.listen: closing connection"); - Lwt_result.return false - | Disconnected -> - Lwt_result.return false - | e -> - Log.err (fun f -> - f "PPP.listen: caught unexpected %a: disconnecting" Fmt.exn e); - let open Lwt.Infix in - disconnect t - >>= fun () -> - Lwt_result.return false - ) - >>= fun continue -> - if continue then loop () else Lwt_result.return () in + (with_fd t @@ fun fd -> + with_read t (Channel.read_exactly ~len:Packet.sizeof fd) @@ fun bufs -> + let read_header = Cstruct.concat bufs in + with_msg t (Packet.unmarshal read_header) @@ fun (len, _) -> + with_read t (Channel.read_exactly ~len fd) @@ fun bufs -> + capture t bufs >>= fun () -> + Log.debug (fun f -> + let b = Buffer.create 128 in + List.iter (Cstruct.hexdump_to_buffer b) bufs; + f "received\n%s" (Buffer.contents b) + ); + let buf = Cstruct.concat bufs in + let callback buf = + Lwt.catch (fun () -> callback buf) + (function + | Host_uwt.Sockets.Too_many_connections + | Host_lwt_unix.Sockets.Too_many_connections -> + (* No need to log this again *) + Lwt.return_unit + | e -> + let now = Unix.gettimeofday () in + if (now -. !last_error_log) > 30. then begin + Log.err (fun f -> + f "PPP.listen callback caught %a" Fmt.exn e); + last_error_log := now; + end; + Lwt.return_unit + ) + in + Lwt.async (fun () -> callback buf); + List.iter (fun callback -> + Lwt.async (fun () -> callback buf) + ) t.listeners; + Lwt.return true + ) >>= function + | true -> loop () + | false -> Lwt.return () + in Lwt.async @@ loop; - Lwt.return (); + Lwt.return (Ok ()); end let write t buf = Lwt_mutex.with_lock t.write_m (fun () -> - drop_on_error (fun () -> - capture t [ buf ] >>= fun () -> - let len = Cstruct.len buf in - if len > (t.mtu + ethernet_header_length) then begin - Log.err (fun f -> - f "Dropping over-large ethernet frame, length = %d, mtu = \ - %d" len t.mtu - ); - Lwt.return_unit - end else begin - if Cstruct.len t.write_header < Packet.sizeof then begin - t.write_header <- Cstruct.create (1024 * Packet.sizeof) - end; - Packet.marshal len t.write_header; - let fd = get_fd t in - Channel.write_buffer fd - (Cstruct.sub t.write_header 0 Packet.sizeof); - t.write_header <- Cstruct.shift t.write_header Packet.sizeof; - Log.debug (fun f -> - let b = Buffer.create 128 in - Cstruct.hexdump_to_buffer b buf; - f "sending\n%s" (Buffer.contents b) - ); - Channel.write_buffer fd buf; - Channel.flush fd - end)) - - let add_listener t callback = - t.listeners <- callback :: t.listeners + capture t [ buf ] >>= fun () -> + let len = Cstruct.len buf in + if len > (t.mtu + ethernet_header_length) then begin + Log.err (fun f -> + f "Dropping over-large ethernet frame, length = %d, mtu = \ + %d" len t.mtu + ); + Lwt.return (Ok ()) + end else begin + if Cstruct.len t.write_header < Packet.sizeof then begin + t.write_header <- Cstruct.create (1024 * Packet.sizeof) + end; + Packet.marshal len t.write_header; + match t.fd with + | None -> Lwt.return (Error `Disconnected) + | Some fd -> + Channel.write_buffer fd + (Cstruct.sub t.write_header 0 Packet.sizeof); + t.write_header <- Cstruct.shift t.write_header Packet.sizeof; + Log.debug (fun f -> + let b = Buffer.create 128 in + Cstruct.hexdump_to_buffer b buf; + f "sending\n%s" (Buffer.contents b) + ); + Channel.write_buffer fd buf; + Channel.flush fd >|= function + | Ok () -> Ok () + | Error e -> Error (`Channel e) + end) + let add_listener t callback = t.listeners <- callback :: t.listeners let mac t = t.server_macaddr - - type page_aligned_buffer = Io_page.t - - type buffer = Cstruct.t - - type error = [ - | `Unknown of string - | `Unimplemented - | `Disconnected - ] - - type macaddr = Macaddr.t - - type 'a io = 'a Lwt.t - - type id = unit - let get_stats_counters t = t.stats - - let reset_stats_counters t = - t.stats.rx_bytes <- 0L; - t.stats.tx_bytes <- 0L; - t.stats.rx_pkts <- 0l; - t.stats.tx_pkts <- 0l + let reset_stats_counters t = Mirage_net.Stats.reset t.stats end diff --git a/src/hostnet/vmnet.mli b/src/hostnet/vmnet.mli index e602666d0..62e54b026 100644 --- a/src/hostnet/vmnet.mli +++ b/src/hostnet/vmnet.mli @@ -4,8 +4,7 @@ module Make(C: Sig.CONN): sig type fd = C.flow - include V1_LWT.NETWORK - with type buffer = Cstruct.t + include Mirage_net_lwt.S with type buffer = Cstruct.t val after_disconnect: t -> unit Lwt.t (** [after_disconnect connection] resolves after [connection] has @@ -52,7 +51,7 @@ module Init : sig val default: t val marshal: t -> Cstruct.t -> Cstruct.t - val unmarshal: Cstruct.t -> (t * Cstruct.t, [ `Msg of string ]) result + val unmarshal: Cstruct.t -> t * Cstruct.t end module Command : sig diff --git a/src/ofs/active_config.ml b/src/ofs/active_config.ml index f8986eeef..be44aaaea 100644 --- a/src/ofs/active_config.ml +++ b/src/ofs/active_config.ml @@ -64,7 +64,7 @@ module type S = sig val bool: t -> default:bool -> path -> bool values Lwt.t end -module Make(Time: V1_LWT.TIME)(FLOW: V1_LWT.FLOW) = struct +module Make(Time: Mirage_time_lwt.S)(FLOW: Mirage_flow_lwt.S) = struct module Client = Protocol_9p.Client.Make(Log)(FLOW) @@ -132,13 +132,13 @@ module Make(Time: V1_LWT.TIME)(FLOW: V1_LWT.FLOW) = struct Client.connect flow ?username () ) >>= function | Error (`Msg x) -> - Time.sleep 0.1 + Time.sleep_ns (Duration.of_ms 100) >>= fun () -> loop ~x (n - 1) | Ok conn -> Lwt.return conn ) (fun e -> - Time.sleep 0.1 + Time.sleep_ns (Duration.of_ms 100) >>= fun () -> loop ~x:(Printexc.to_string e) (n - 1) ) in diff --git a/src/ofs/active_config.mli b/src/ofs/active_config.mli index 024ef17e8..c40aecde3 100644 --- a/src/ofs/active_config.mli +++ b/src/ofs/active_config.mli @@ -32,7 +32,7 @@ module type S = sig (** The stream of bool values at [path] *) end -module Make(Time: V1_LWT.TIME)(FLOW: V1_LWT.FLOW): sig +module Make(Time: Mirage_time_lwt.S)(FLOW: Mirage_flow_lwt.S): sig include S val create: ?username:string -> branch:string diff --git a/src/ofs/active_list.ml b/src/ofs/active_list.ml index cda42f9ba..5bde48cb8 100644 --- a/src/ofs/active_list.ml +++ b/src/ofs/active_list.ml @@ -33,6 +33,7 @@ end module type Instance = sig type t + type clock val to_string: t -> string val of_string: string -> (t, [ `Msg of string ]) result @@ -41,7 +42,7 @@ module type Instance = sig type context (** The context in which a [t] is [start]ed, for example a TCP/IP stack *) - val start: context Var.t -> t -> (t, [ `Msg of string ]) result Lwt.t + val start: clock -> context Var.t -> t -> (t, [ `Msg of string ]) result Lwt.t val stop: t -> unit Lwt.t @@ -49,28 +50,21 @@ module type Instance = sig val get_key: t -> key end -module Transaction = struct - type t = { - name: string; (* directory name *) - mutable source: string; - mutable destination: string; - } -end - module StringMap = Map.Make(String) -module Make(Instance: Instance) = struct +module Make (Instance: Instance) = struct open Protocol_9p type t = { mutable context: Instance.context Var.t; + clock: Instance.clock; } - let make () = + let make clock = let context = Var.create () in - { context } + { context; clock } - let set_context { context } x = Var.fill context x + let set_context { context; _ } x = Var.fill context x (* We manage a list of named entries *) type entry = { @@ -105,8 +99,6 @@ module Make(Instance: Instance) = struct module Error = struct let badfid = Lwt.return (Response.error "fid not found") - let badwalk = Lwt.return (Response.error "bad walk") (* TODO: ? *) - let enoent = Lwt.return (Response.error "file not found") let eperm = Lwt.return (Response.error "permission denied") end @@ -209,14 +201,6 @@ The directory will be deleted and replaced with a file of the same name. connection.fids := Types.Fid.Map.remove fid !(connection.fids); return () - let disconnect connection _ = - Log.debug (fun f -> f "disconnecting 9P client"); - let resources = - Types.Fid.Map.fold (fun _ resource acc -> resource :: acc) - !(connection.fids) [] - in - Lwt_list.iter_s free_resource resources - let open_ _connection ~cancel:_ _ = try let qid = next_qid [] in @@ -245,11 +229,6 @@ The directory will be deleted and replaced with a file of the same name. u = None; }) - let errors_to_client = (function - | Error (`Msg msg) -> Error { Response.Err.ename = msg; errno = None } - | Ok _ as ok -> ok - ) - let read_children count offset children = let buf = Cstruct.create count in let rec write off rest = function @@ -366,7 +345,8 @@ The directory will be deleted and replaced with a file of the same name. end else begin match Instance.of_string @@ Cstruct.to_string data with | Ok f -> let open Lwt.Infix in - begin Instance.start connection.t.context f >>= function + begin Instance.start connection.t.clock connection.t.context f >>= + function | Ok f' -> (* local_port is resolved *) entry.instance <- Some f'; entry.result <- Some ("OK " ^ (Instance.to_string f') ^ "\n"); diff --git a/src/ofs/jbuild b/src/ofs/jbuild index fc282f939..47f9784c0 100644 --- a/src/ofs/jbuild +++ b/src/ofs/jbuild @@ -2,6 +2,7 @@ (library ((name ofs) - (libraries (protocol-9p protocol-9p.unix cstruct cstruct.lwt lwt.unix - io-page.unix logs mirage-types.lwt astring named-pipe.lwt)) + (libraries (protocol-9p protocol-9p-unix cstruct cstruct-lwt lwt.unix + io-page-unix logs astring named-pipe.lwt + mirage-time-lwt mirage-flow-lwt duration)) (wrapped false))) diff --git a/vpnkit.opam b/vpnkit.opam index 9e15b3185..07bf72d06 100644 --- a/vpnkit.opam +++ b/vpnkit.opam @@ -56,6 +56,7 @@ depends: [ "astring" "mirage-flow-lwt" {>= "1.4.0"} "mirage-time-lwt" {>= "1.1.0"} + "mirage-time-unix" "mirage-protocols" {>= "1.1.0"} "mirage-channel" {>= "3.0.1"} "mirage-console-unix"