Skip to content

Commit

Permalink
feat: Add length support
Browse files Browse the repository at this point in the history
Adds a keyword argument `length=ex` to allow for specifying the
length of the iterator.

Closes #41
  • Loading branch information
thofma committed Aug 11, 2024
1 parent 3f7d845 commit 251789b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 4 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ end
for fib in fibonacci(10)
println(fib)
end

# Example with specifies the length
using ResumableFunctions

@resumable length=n^2 function fibonacci(n::Int) :: Int
a = 0
b = 1
for i in 1:n^2
@yield a
a, b = b, a+b
end
end

for fib in fibonacci(5)
println(fib)
end
```

## Benchmarks
Expand Down Expand Up @@ -137,4 +153,4 @@ A [detailed change log is kept](https://github.com/JuliaDynamics/ResumableFuncti
* In a `try` block only top level `@yield` statements are allowed.
* In a `finally` block a `@yield` statement is not allowed.
* An anonymous function can not contain a `@yield` statement.
* Many more restrictions.
* Many more restrictions.
61 changes: 59 additions & 2 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,55 @@ macro nosave(expr=nothing)
end

"""
Macro that transforms a function definition in a finite-statemachine:
Macro that transforms a function definition in a finite-state machine:
- Defines a new `mutable struct` that implements the iterator interface and is used to store the internal state.
- Makes this new type callable having following characteristics:
- implementents the statements from the initial function definition but;
- returns at a `@yield` statement and;
- continues after the `@yield` statement when called again.
- Defines a constructor function that respects the calling conventions of the initial function definition and returns an object of the new type.
If the element type and length is known, the resulting iterator can be made
more efficient as follows:
- Use `length=ex` to specify the length (if known) of the iterator, like:
@resumable length=ex function f(x); body; end
Here `ex` can be any expression containing the arguments of `f`.
- Use `function f(x)::T` to specify the element type of the iterator.
# Extended
```julia
julia> @resumable length=n^2 function f(n)::Int
for i in 1:n^2
@yield i
end
end
f (generic function with 2 methods)
julia> collect(f(3))
9-element Vector{Int64}:
1
2
3
4
5
6
7
8
9
```
"""
macro resumable(expr::Expr)
macro resumable(ex::Expr...)
length(ex) >= 3 && error("Too many arguments")
for i in 1:length(ex)-1
a = ex[i]
if !(a isa Expr && a.head === :(=) && a.args[1] in [:length])
error("only keyword argument 'length' allowed")
end
end

expr = ex[end]
expr.head !== :function && error("Expression is not a function definition!")

# The function that executes a step of the finite state machine
Expand Down Expand Up @@ -128,6 +167,23 @@ macro resumable(expr::Expr)
end
func_def[:rtype] = nothing
func_def[:body] = postwalk(x->transform_slots(x, keys(slots)), func_def[:body])

# Capture the length=...
interface_defs = []
for i in 1:length(ex)-1
a = ex[i]
if !(a isa Expr && a.head === :(=) && a.args[1] in [:length])
error("only keyword argument 'length' allowed")
end
if a.args[1] === :length
push!(interface_defs, quote Base.IteratorSize(::Type{<: $type_name}) = Base.HasLength() end)
func_def2 = copy(func_def)
func_def2[:body] = a.args[2]
new_body = postwalk(x->transform_slots(x, keys(slots)), a.args[2])
push!(interface_defs, quote Base.length(_fsmi::$type_name) = begin $new_body end end)
end
end

func_def[:body] = postwalk(transform_arg, func_def[:body])
func_def[:body] = postwalk(transform_exc, func_def[:body]) |> flatten
ui8 = BoxedUInt8(zero(UInt8))
Expand Down Expand Up @@ -161,6 +217,7 @@ macro resumable(expr::Expr)
esc(quote
$type_expr
$func_expr
$(interface_defs...)
Base.@__doc__($call_expr)
end)
end
18 changes: 17 additions & 1 deletion test/test_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,20 @@ end

@testset "test_unstable" begin
@test collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
end
end

# test length

@testset "test_length" begin
@resumable length=n^2*m^2 function test_length(n, m)
for i in 1:n^2
for j in 1:m^2
@yield i + j
end
end
end

@test length(test_length(10, 20)) === 10^2 * 20^2
@test length(collect(test_length(10, 20))) === 10^2 * 20^2
@test Base.IteratorSize(typeof(test_length(1, 1))) == Base.HasLength()
end

0 comments on commit 251789b

Please sign in to comment.