From 251789bcfadbc38de9e978fe96115b45bca03a03 Mon Sep 17 00:00:00 2001 From: Tommy Hofmann Date: Sun, 11 Aug 2024 12:43:45 +0200 Subject: [PATCH] feat: Add `length` support Adds a keyword argument `length=ex` to allow for specifying the length of the iterator. Closes #41 --- README.md | 18 +++++++++++++- src/macro.jl | 61 +++++++++++++++++++++++++++++++++++++++++++++-- test/test_main.jl | 18 +++++++++++++- 3 files changed, 93 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e76c1c7..ed5c597 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,22 @@ end for fib in fibonacci(10) println(fib) end + +# Example with specifies the length +using ResumableFunctions + +@resumable length=n^2 function fibonacci(n::Int) :: Int + a = 0 + b = 1 + for i in 1:n^2 + @yield a + a, b = b, a+b + end +end + +for fib in fibonacci(5) + println(fib) +end ``` ## Benchmarks @@ -137,4 +153,4 @@ A [detailed change log is kept](https://github.com/JuliaDynamics/ResumableFuncti * In a `try` block only top level `@yield` statements are allowed. * In a `finally` block a `@yield` statement is not allowed. * An anonymous function can not contain a `@yield` statement. -* Many more restrictions. \ No newline at end of file +* Many more restrictions. diff --git a/src/macro.jl b/src/macro.jl index 4203912..9e37a19 100755 --- a/src/macro.jl +++ b/src/macro.jl @@ -20,7 +20,7 @@ macro nosave(expr=nothing) end """ -Macro that transforms a function definition in a finite-statemachine: +Macro that transforms a function definition in a finite-state machine: - Defines a new `mutable struct` that implements the iterator interface and is used to store the internal state. - Makes this new type callable having following characteristics: @@ -28,8 +28,47 @@ Macro that transforms a function definition in a finite-statemachine: - returns at a `@yield` statement and; - continues after the `@yield` statement when called again. - Defines a constructor function that respects the calling conventions of the initial function definition and returns an object of the new type. + +If the element type and length is known, the resulting iterator can be made +more efficient as follows: +- Use `length=ex` to specify the length (if known) of the iterator, like: + @resumable length=ex function f(x); body; end + Here `ex` can be any expression containing the arguments of `f`. +- Use `function f(x)::T` to specify the element type of the iterator. + +# Extended + +```julia +julia> @resumable length=n^2 function f(n)::Int + for i in 1:n^2 + @yield i + end + end +f (generic function with 2 methods) + +julia> collect(f(3)) +9-element Vector{Int64}: + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 +``` """ -macro resumable(expr::Expr) +macro resumable(ex::Expr...) + length(ex) >= 3 && error("Too many arguments") + for i in 1:length(ex)-1 + a = ex[i] + if !(a isa Expr && a.head === :(=) && a.args[1] in [:length]) + error("only keyword argument 'length' allowed") + end + end + + expr = ex[end] expr.head !== :function && error("Expression is not a function definition!") # The function that executes a step of the finite state machine @@ -128,6 +167,23 @@ macro resumable(expr::Expr) end func_def[:rtype] = nothing func_def[:body] = postwalk(x->transform_slots(x, keys(slots)), func_def[:body]) + + # Capture the length=... + interface_defs = [] + for i in 1:length(ex)-1 + a = ex[i] + if !(a isa Expr && a.head === :(=) && a.args[1] in [:length]) + error("only keyword argument 'length' allowed") + end + if a.args[1] === :length + push!(interface_defs, quote Base.IteratorSize(::Type{<: $type_name}) = Base.HasLength() end) + func_def2 = copy(func_def) + func_def2[:body] = a.args[2] + new_body = postwalk(x->transform_slots(x, keys(slots)), a.args[2]) + push!(interface_defs, quote Base.length(_fsmi::$type_name) = begin $new_body end end) + end + end + func_def[:body] = postwalk(transform_arg, func_def[:body]) func_def[:body] = postwalk(transform_exc, func_def[:body]) |> flatten ui8 = BoxedUInt8(zero(UInt8)) @@ -161,6 +217,7 @@ macro resumable(expr::Expr) esc(quote $type_expr $func_expr + $(interface_defs...) Base.@__doc__($call_expr) end) end diff --git a/test/test_main.jl b/test/test_main.jl index 7f104ba..d169924 100644 --- a/test/test_main.jl +++ b/test/test_main.jl @@ -226,4 +226,20 @@ end @testset "test_unstable" begin @test collect(test_unstable(3)) == ["number 1", "number 2", "number 3"] -end \ No newline at end of file +end + +# test length + +@testset "test_length" begin + @resumable length=n^2*m^2 function test_length(n, m) + for i in 1:n^2 + for j in 1:m^2 + @yield i + j + end + end + end + + @test length(test_length(10, 20)) === 10^2 * 20^2 + @test length(collect(test_length(10, 20))) === 10^2 * 20^2 + @test Base.IteratorSize(typeof(test_length(1, 1))) == Base.HasLength() +end