diff --git a/src/core/lwt_seq.ml b/src/core/lwt_seq.ml index 543dac1d1..ddfef7d37 100644 --- a/src/core/lwt_seq.ml +++ b/src/core/lwt_seq.ml @@ -267,10 +267,16 @@ let iter_n ?(max_concurrency = 1) f seq = loop [] max_concurrency (fun () -> Lwt.apply seq ()) let rec unfold f u () = + match f u with + | None -> return_nil + | Some (x, u') -> Lwt.return (Cons (x, unfold f u')) + | exception exc -> Lwt.fail exc + +let rec unfold_lwt f u () = let* x = f u in match x with | None -> return_nil - | Some (x, u') -> Lwt.return (Cons (x, unfold f u')) + | Some (x, u') -> Lwt.return (Cons (x, unfold_lwt f u')) let rec of_list = function | [] -> empty diff --git a/src/core/lwt_seq.mli b/src/core/lwt_seq.mli index 6add5f4c6..3b7e30165 100644 --- a/src/core/lwt_seq.mli +++ b/src/core/lwt_seq.mli @@ -107,7 +107,8 @@ val iter_n : ?max_concurrency:int -> ('a -> unit Lwt.t) -> 'a t -> unit Lwt.t @param max_concurrency defaults to [1]. @raise Invalid_argument if [max_concurrency < 1]. *) -val unfold : ('b -> ('a * 'b) option Lwt.t) -> 'b -> 'a t +val unfold : ('b -> ('a * 'b) option) -> 'b -> 'a t +val unfold_lwt : ('b -> ('a * 'b) option Lwt.t) -> 'b -> 'a t (** Build a sequence from a step function and an initial value. [unfold f u] returns [empty] if the promise [f u] resolves to [None], or [fun () -> Lwt.return (Cons (x, unfold f y))] if the promise [f u] resolves diff --git a/test/core/test_lwt_seq.ml b/test/core/test_lwt_seq.ml index 5f1b6a616..20b787cfb 100644 --- a/test/core/test_lwt_seq.ml +++ b/test/core/test_lwt_seq.ml @@ -9,12 +9,6 @@ open Test let l = [1; 2; 3; 4; 5] let a = Lwt_seq.of_list l -let b = - Lwt_seq.unfold - (function - | [] -> let+ () = Lwt.pause () in None - | x::xs -> let+ () = Lwt.pause () in Some (x, xs)) - l let rec pause n = if n <= 0 then Lwt.return_unit @@ -22,6 +16,12 @@ let rec pause n = let* () = Lwt.pause () in pause (n - 1) let pause n = pause (n mod 5) +let b = + Lwt_seq.unfold_lwt + (function + | [] -> let+ () = pause 2 in None + | x::xs -> let+ () = pause (x+2) in Some (x, xs)) + l let suite_base = suite "lwt_seq" [ test "fold_left" begin fun () -> @@ -134,8 +134,7 @@ let suite_base = suite "lwt_seq" [ test "unfold" begin fun () -> let range first last = - let step i = if i > last then Lwt.return_none - else Lwt.return_some (i, succ i) in + let step i = if i > last then None else Some (i, succ i) in Lwt_seq.unfold step first in let* a = Lwt_seq.to_list (range 1 3) in @@ -144,6 +143,20 @@ let suite_base = suite "lwt_seq" [ ([] = b) end; + test "unfold_lwt" begin fun () -> + let range first last = + let step i = + if i > last then Lwt.return_none else Lwt.return_some (i, succ i) + in + Lwt_seq.unfold_lwt step first + in + let* a = Lwt_seq.to_list (range 1 3) in + let+ b = Lwt_seq.to_list (range 1 0) in + ([1;2;3] = a) && + ([] = b) + end; + + test "fold-into-exception-from-of-seq" begin fun () -> let fail = fun () -> failwith "XXX" in let seq = fun () -> Seq.Cons (1, (fun () -> Seq.Cons (2, fail))) in