Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code suggestions for #1210 #1213

Merged
Changes from 8 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
8b235b9
move `ex.args[2] isa Integer`
hyrodium Oct 29, 2023
477ab09
split `if` block
hyrodium Oct 29, 2023
684bd0a
simplify :zeros and :ones
hyrodium Oct 29, 2023
4885ad9
refactor :rand
hyrodium Oct 29, 2023
d048d16
refactor :randn and :randexp
hyrodium Oct 29, 2023
835ad76
update comments
hyrodium Oct 29, 2023
5cbc347
add _isnonnegvec
hyrodium Oct 30, 2023
d279187
update with `_isnonnegvec`
hyrodium Oct 30, 2023
2ed9633
add `_isnonnegvec(args, n)` method to check the size of `args`
hyrodium Oct 31, 2023
16dbbc7
fix `@SArray` for `@SArray rand(rng,T,dim)` etc.
hyrodium Oct 31, 2023
563ea0c
update comments
hyrodium Oct 31, 2023
d1a2d32
update `@SVector` macro
hyrodium Oct 31, 2023
b3cf411
update `@SMatrix`
hyrodium Nov 1, 2023
ff9b32a
update `@SVector`
hyrodium Nov 1, 2023
6eb7829
update `@SArray`
hyrodium Nov 1, 2023
0a9fc2b
introduce `fargs` variable
hyrodium Nov 12, 2023
337a2e0
avoid `_isnonnegvec` in `static_matrix_gen`
hyrodium Nov 12, 2023
5021855
avoid `_isnonnegvec` in `static_vector_gen`
hyrodium Nov 12, 2023
7352e26
remove unnecessary `_isnonnegvec`
hyrodium Nov 18, 2023
623cb0c
add `_rng()` function
hyrodium Nov 18, 2023
d1cf08d
update tests on `@SVector` macro
hyrodium Nov 18, 2023
0403edf
update tests on `@MVector` macro
hyrodium Dec 1, 2023
9f0753b
organize test/MMatrix.jl and test/SMatrix.jl
hyrodium Dec 2, 2023
12b7634
organize test/MMatrix.jl and test/SMatrix.jl
hyrodium Dec 2, 2023
cb07fcf
update with broken tests
hyrodium Dec 2, 2023
b4ace0d
organize test/MMatrix.jl and test/SMatrix.jl for `rand*` functions
hyrodium Dec 2, 2023
6c5a02b
fix around `broken` key for `@test` macro
hyrodium Dec 2, 2023
9217564
fix zero-length tests
hyrodium Dec 2, 2023
1b0c4f4
update `test/SArray.jl` to match `test/MArray.jl`
hyrodium Jan 2, 2024
24265b9
update tests for `@SArray ones` etc.
hyrodium Jan 2, 2024
c9df6a8
add supports for `@SArray ones(3-1,2)` etc.
hyrodium Jan 2, 2024
929e692
move block for `fill`
hyrodium Jan 2, 2024
4feb7af
update macro `@SArray rand(rng,2,3)` to use ordinary dispatches
hyrodium Jan 2, 2024
d17ee6b
update around `@SArray randn` etc.
hyrodium Jan 2, 2024
af095e3
remove unnecessary dollars
hyrodium Jan 2, 2024
26c922e
simplify `@SArray fill`
hyrodium Jan 2, 2024
2cbb67a
add `@testset "expand_error"`
hyrodium Jan 2, 2024
6a03236
update tests for `@SArray rand(...)` etc.
hyrodium Jan 2, 2024
0f1c559
fix bug in `rand*_with_Val`
hyrodium Jan 2, 2024
89fc5e5
cleanup tests
hyrodium Jan 2, 2024
c55b170
update macro `@SMatrix rand(rng,2,3)` to use ordinary dispatches
hyrodium Jan 2, 2024
0bb47e7
update macro `@SVector rand(rng,3)` to use ordinary dispatches
hyrodium Jan 2, 2024
c8cb1f5
move block for `fill`
hyrodium Jan 2, 2024
2e40f73
simplify `_randexp_with_Val`
hyrodium Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 97 additions & 21 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ function parse_cat_ast(ex::Expr)
end

escall(args) = Iterators.map(esc, args)
function _isnonnegvec(args)
length(args) == 0 && return false
all(isa.(args, Integer)) && return all(args .≥ 0)
return false
end
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
if !isa(ex, Expr)
error("Bad input for @$SA")
Expand Down Expand Up @@ -197,37 +202,45 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if f === :zeros || f === :ones
if length(ex.args) == 1
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
# for calls like `zeros()`
return :($f($SA{$Tuple{},$Float64}))
elseif f !== :rand || length(ex.args) == 2
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
end
end
elseif _isnonnegvec(ex.args[2:end])
# for calls like `zeros(dims...)`
return :($f($SA{$Tuple{$(escall(ex.args[2:end])...)}}))
elseif length(ex.args) == 2
# for calls like `zeros(type)`
return :($f($SA{$Tuple{},$(esc(ex.args[2]))}))
elseif _isnonnegvec(ex.args[3:end])
# for calls like `zeros(type, dims...)`
return :($f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}))
else
error("@$SA got bad expression: $(ex)")
end
elseif f === :rand
if length(ex.args) == 1
# No support for `@SArray rand()`
error("@$SA got bad expression: $(ex)")
elseif _isnonnegvec(ex.args[2:end])
# for calls like `rand(dims...)`
return :($f($SA{$Tuple{$(escall(ex.args[2:end])...)}}))
elseif _isnonnegvec(ex.args[3:end])
# for calls like `rand(rng, dims...)`
# for calls like `rand(type, dims...)`
# for calls like `rand(sampler, dims...)`
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
elseif isa($(esc(ex.args[2])), Integer)
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
elseif isa($(esc(ex.args[2])), Random.AbstractRNG)
# for calls like rand(rng::AbstractRNG, sampler, dims::Integer...)
if isa($(esc(ex.args[2])), Random.AbstractRNG)
StaticArrays._rand(
$(esc(ex.args[2])),
$(esc(ex.args[3])),
Size($(escall(ex.args[4:end])...)),
Float64,
Size($(escall(ex.args[3:end])...)),
$SA{
Tuple{$(escall(ex.args[4:end])...)},
Random.gentype($(esc(ex.args[3]))),
Tuple{$(escall(ex.args[3:end])...)},
Float64,
},
)
else
# for calls like rand(sampler, dims::Integer...)
StaticArrays._rand(
Random.GLOBAL_RNG,
$(esc(ex.args[2])),
Expand All @@ -239,6 +252,69 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
)
end
end
elseif _isnonnegvec(ex.args[4:end])
# for calls like `rand(rng, type, dims...)`
# for calls like `rand(rng, sampler, dims...)`
return quote
StaticArrays._rand(
$(esc(ex.args[2])),
$(esc(ex.args[3])),
Size($(escall(ex.args[4:end])...)),
$SA{
Tuple{$(escall(ex.args[4:end])...)},
Random.gentype($(esc(ex.args[3]))),
},
)
end
else
error("@$SA got bad expression: $(ex)")
end
elseif f === :randn || f === :randexp
_f = Symbol(:_, f)
if length(ex.args) == 1
# No support for `@SArray randn()` etc.
error("@$SA got bad expression: $(ex)")
elseif _isnonnegvec(ex.args[2:end])
# for calls like `randn(dims...)`
return :($f($SA{$Tuple{$(escall(ex.args[2:end])...)}}))
elseif _isnonnegvec(ex.args[3:end])
mateuszbaran marked this conversation as resolved.
Show resolved Hide resolved
# for calls like `randn(rng, dims...)`
# for calls like `randn(type, dims...)`
return quote
if isa($(esc(ex.args[2])), Random.AbstractRNG)
StaticArrays.$_f(
$(esc(ex.args[2])),
Size($(escall(ex.args[3:end])...)),
$SA{
Tuple{$(escall(ex.args[3:end])...)},
Float64,
},
)
else
StaticArrays.$_f(
Random.GLOBAL_RNG,
Size($(escall(ex.args[3:end])...)),
$SA{
Tuple{$(escall(ex.args[3:end])...)},
$(esc(ex.args[2])),
},
)
end
end
elseif _isnonnegvec(ex.args[4:end])
# for calls like `randn(rng, type, dims...)`
return quote
StaticArrays.$_f(
$(esc(ex.args[2])),
Size($(escall(ex.args[4:end])...)),
$SA{
Tuple{$(escall(ex.args[4:end])...)},
$(esc(ex.args[3])),
},
)
end
else
error("@$SA got bad expression: $(ex)")
end
elseif f === :fill
length(ex.args) == 1 && error("@$SA got bad expression: $(ex)")
Expand Down
Loading