Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added wavelets of oder 1,2,3,4 #2

Merged
merged 2 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/ANOVAapprox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ module ANOVAapprox
using GroupedTransforms,
LinearAlgebra, IterativeSolvers, LinearMaps, Distributed, SpecialFunctions

bases = ["per", "cos", "cheb", "std"]
types = Dict("per" => ComplexF64, "cos" => Float64, "cheb" => Float64, "std" => Float64)
bases = ["per", "cos", "cheb", "std" ,"wav1", "wav2", "wav3", "wav4"]
types = Dict("per" => ComplexF64, "cos" => Float64, "cheb" => Float64, "std" => Float64, "wav1" => Float64, "wav2" => Float64,"wav3" => Float64,"wav4" => Float64)
Comment on lines +6 to +7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
bases = ["per", "cos", "cheb", "std" ,"wav1", "wav2", "wav3", "wav4"]
types = Dict("per" => ComplexF64, "cos" => Float64, "cheb" => Float64, "std" => Float64, "wav1" => Float64, "wav2" => Float64,"wav3" => Float64,"wav4" => Float64)
bases = ["per", "cos", "cheb", "std", "wav1", "wav2", "wav3", "wav4"]
types = Dict(
"per" => ComplexF64,
"cos" => Float64,
"cheb" => Float64,
"std" => Float64,
"wav1" => Float64,
"wav2" => Float64,
"wav3" => Float64,
"wav4" => Float64,
)

vtypes = Dict(
"per" => Vector{ComplexF64},
"cos" => Vector{Float64},
"cheb" => Vector{Float64},
"std" => Vector{Float64},
"wav1" => Vector{Float64},
"wav2" => Vector{Float64},
"wav3" => Vector{Float64},
"wav4" => Vector{Float64},
)
gt_systems = Dict("per" => "exp", "cos" => "cos", "cheb" => "cos", "std" => "cos")
gt_systems = Dict("per" => "exp", "cos" => "cos", "cheb" => "cos", "std" => "cos", "wav1" => "wav1", "wav2" => "wav2", "wav3" => "wav3", "wav4" => "wav4")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
gt_systems = Dict("per" => "exp", "cos" => "cos", "cheb" => "cos", "std" => "cos", "wav1" => "wav1", "wav2" => "wav2", "wav3" => "wav3", "wav4" => "wav4")
gt_systems = Dict(
"per" => "exp",
"cos" => "cos",
"cheb" => "cos",
"std" => "cos",
"wav1" => "wav1",
"wav2" => "wav2",
"wav3" => "wav3",
"wav4" => "wav4",
)


function get_orderDependentBW(U::Vector{Vector{Int}}, N::Vector{Int})::Vector{Int}
N_bw = zeros(Int64, length(U))
Expand Down
42 changes: 29 additions & 13 deletions src/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,36 @@ function get_GSI(
λ::Float64;
dict::Bool = false,
)::Union{Vector{Float64},Dict{Vector{Int},Float64}}
variances = norms(a.fc[λ]) .^ 2
if a.basis == "wav1"
variances = norms(a.fc[λ],1,dict=false) .^ 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],1,dict=false) .^ 2
variances = norms(a.fc[λ], 1, dict = false) .^ 2

elseif a.basis == "wav2"
variances = norms(a.fc[λ],2,dict=false) .^ 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],2,dict=false) .^ 2
variances = norms(a.fc[λ], 2, dict = false) .^ 2

elseif a.basis == "wav3"
variances = norms(a.fc[λ],3,dict=false) .^ 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],3,dict=false) .^ 2
variances = norms(a.fc[λ], 3, dict = false) .^ 2

elseif a.basis == "wav4"
variances = norms(a.fc[λ],4,dict=false) .^ 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],4,dict=false) .^ 2
variances = norms(a.fc[λ], 4, dict = false) .^ 2

else
variances = norms(a.fc[λ],dict=false) .^ 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],dict=false) .^ 2
variances = norms(a.fc[λ], dict = false) .^ 2

end
variances = variances[2:end]
variance_f = sum(variances)

if dict
gsis = Dict{Vector{Int},Float64}()
for i = 1:length(a.fc[λ].setting)
s = a.fc[λ].setting[i]
u = s[:u]
if u != []
gsis[u] = norm(a.fc[λ][u])^2 / variance_f
end
if a.basis == "wav1"
variances = norms(a.fc[λ],1,dict=true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],1,dict=true)
variances = norms(a.fc[λ], 1, dict = true)

elseif a.basis == "wav2"
variances = norms(a.fc[λ],2,dict=true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],2,dict=true)
variances = norms(a.fc[λ], 2, dict = true)

elseif a.basis == "wav3"
variances = norms(a.fc[λ],3,dict=true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],3,dict=true)
variances = norms(a.fc[λ], 3, dict = true)

elseif a.basis == "wav4"
variances = norms(a.fc[λ],4,dict=true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],4,dict=true)
variances = norms(a.fc[λ], 4, dict = true)

else
variances = norms(a.fc[λ],dict=true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
variances = norms(a.fc[λ],dict=true)
variances = norms(a.fc[λ], dict = true)

end
return gsis

return Dict((u,variances[u]^2/variance_f) for u in keys(variances))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return Dict((u,variances[u]^2/variance_f) for u in keys(variances))
return Dict((u, variances[u]^2 / variance_f) for u in keys(variances))


else
return variances ./ variance_f
end
Expand Down Expand Up @@ -53,8 +69,8 @@ function get_AttributeRanking(a::approx, λ::Float64)::Vector{Float64}

factors = zeros(Int64, d, ds)

for i = 1:d
for j = 1:ds
for i = 1:d
for j = 1:ds
for v in U
if (i in v) && (length(v) == j)
factors[i,j] += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
factors[i,j] += 1
factors[i, j] += 1

Expand Down Expand Up @@ -92,7 +108,7 @@ function get_ActiveSet( a::approx, eps::Vector{Float64}, λ::Float64 )::Vector{V
lengths = [ length(u) for u in U ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
lengths = [ length(u) for u in U ]
lengths = [length(u) for u in U]

ds = maximum(lengths)

if length(eps) != ds
if length(eps) != ds
error( "Entries in vector eps have to be ds.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
error( "Entries in vector eps have to be ds.")
error("Entries in vector eps have to be ds.")

end

Expand All @@ -107,7 +123,7 @@ function get_ActiveSet( a::approx, eps::Vector{Float64}, λ::Float64 )::Vector{V
end

U_active = Vector{Vector{Int}}(undef, n+1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
U_active = Vector{Vector{Int}}(undef, n+1)
U_active = Vector{Vector{Int}}(undef, n + 1)

U_active[1] = []
U_active[1] = []
idx = 2

for i = 1:length(gsi)
Expand All @@ -122,4 +138,4 @@ end

function get_ActiveSet(a::approx, eps::Vector{Float64})::Dict{Float64,Vector{Vector{Int}}}
return Dict(λ => get_ActiveSet(a, eps, λ) for λ in collect(keys(a.fc)))
end
end
6 changes: 3 additions & 3 deletions src/approx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
A struct to hold the scattered data function approximation.

# Fields
* `basis::String` - basis of the function space; currently choice of `"per"`, `"cos"`, `"cheb"`, and `"std"`
* `basis::String` - basis of the function space; currently choice of `"per"`, `"cos"`, `"cheb"`,`"std"`, `"wav1"`, `"wav2"`,`"wav3"`,`"wav4"`
* `X::Matrix{Float64}` - scattered data nodes with d rows and M columns
* `y::Union{Vector{ComplexF64},Vector{Float64}}` - M function values (complex for `basis = "per"`, real ortherwise)
* `U::Vector{Vector{Int}}` - a vector containing susbets of coordinate indices
* `U::Vector{Vector{Int}}` - a vector containing susbets of coordinate indices
* `N::Vector{Int}` - bandwdiths for each ANOVA term
* `trafo::GroupedTransform` - holds the grouped transformation
* `fc::Dict{Float64,GroupedCoefficients}` - holds the GroupedCoefficients after approximation for every different regularization parameters
Expand Down Expand Up @@ -58,7 +58,7 @@ mutable struct approx
bw = N
end

if (basis == "per") && ((minimum(X) < -0.5) || (maximum(X) >= 0.5))
if (basis == "per" || basis == "wav1" || basis == "wav2" || basis == "wav3" || basis == "wav4" ) && ((minimum(X) < -0.5) || (maximum(X) >= 0.5))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if (basis == "per" || basis == "wav1" || basis == "wav2" || basis == "wav3" || basis == "wav4" ) && ((minimum(X) < -0.5) || (maximum(X) >= 0.5))
if (
basis == "per" ||
basis == "wav1" ||
basis == "wav2" ||
basis == "wav3" ||
basis == "wav4"
) && ((minimum(X) < -0.5) || (maximum(X) >= 0.5))

error("Nodes need to be between -0.5 and 0.5.")
elseif (basis == "cos") && ((minimum(X) < 0) || (maximum(X) > 1))
error("Nodes need to be between 0 and 1.")
Expand Down
20 changes: 12 additions & 8 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,19 @@ end
This function computes the relative ``L_2`` error of the function given the norm `norm` and a function that returns the basis coefficients `bc_fun` for regularization parameter `λ`.
"""
function get_L2error(a::approx, norm::Float64, bc_fun::Function, λ::Float64)::Float64
error = norm^2
index_set = get_IndexSet(a.trafo.setting, size(a.X, 1))

for i = 1:size(index_set, 2)
k = index_set[:, i]
error += abs(bc_fun(k) - a.fc[λ][i])^2 - abs(bc_fun(k))^2
if a.basis=="per" || a.basis == "cos" || a.basis =="cheb"|| a.basis == "std"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if a.basis=="per" || a.basis == "cos" || a.basis =="cheb"|| a.basis == "std"
if a.basis == "per" || a.basis == "cos" || a.basis == "cheb" || a.basis == "std"

error = norm^2
index_set = get_IndexSet(a.trafo.setting, size(a.X, 1))

for i = 1:size(index_set, 2)
k = index_set[:, i]
error += abs(bc_fun(k) - a.fc[λ][i])^2 - abs(bc_fun(k))^2
end

return sqrt(error) / norm
else
error("The L2-error is not implemented for this basis")
end

return sqrt(error) / norm
end

@doc raw"""
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using .TestFunctionCheb

rng = MersenneTwister(1234)

tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista"]
tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista","wav_lsqr"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista","wav_lsqr"]
tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista", "wav_lsqr"]

#tests = ["misc", "cheb_lsqr", "per_lsqr", "per_fista"]

for t in tests
Expand Down
48 changes: 48 additions & 0 deletions test/wav_lsqr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#### PERIODIC TEST SOLVER LSQR ####
using ANOVAapprox
include("TestFunctionPeriodic.jl")
using Test
using Random
using Aqua

d = 6
ds = 2
M = 10_000
max_iter = 50
bw = [4,4]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
bw = [4,4]
bw = [4, 4]

λs = [0.0, 1.0]


X = rand( d, M) .- 0.5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
X = rand( d, M) .- 0.5
X = rand(d, M) .- 0.5

y = [TestFunctionPeriodic.f(X[:, i]) for i = 1:M]
X_test = rand(d, M) .- 0.5
y_test = [TestFunctionPeriodic.f(X_test[:, i]) for i = 1:M]

#### ####

ads = ANOVAapprox.approx(X, y, ds, bw, "wav2")
ANOVAapprox.approximate(ads, lambda = λs)

println( "AR: ", sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) )
@test abs( sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1 ) < 0.0001
Comment on lines +26 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
println( "AR: ", sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) )
@test abs( sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1 ) < 0.0001
println("AR: ", sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)))
@test abs(sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1) < 0.0001


bw = ANOVAapprox.get_orderDependentBW(TestFunctionPeriodic.AS, [4,4])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
bw = ANOVAapprox.get_orderDependentBW(TestFunctionPeriodic.AS, [4,4])
bw = ANOVAapprox.get_orderDependentBW(TestFunctionPeriodic.AS, [4, 4])


aU = ANOVAapprox.approx(X, y, TestFunctionPeriodic.AS, bw, "wav2")
ANOVAapprox.approximate(aU, lambda = λs)

err_l2_ds = ANOVAapprox.get_l2error(ads)[0.0]
err_l2_U = ANOVAapprox.get_l2error(aU)[0.0]
err_l2_rand_ds = ANOVAapprox.get_l2error(ads, X_test, y_test)[0.0]
err_l2_rand_U = ANOVAapprox.get_l2error(aU, X_test, y_test)[0.0]

println("== PERIODIC LSQR ==")
println("l2 ds: ", err_l2_ds)
println("l2 U: ", err_l2_U)
println("l2 rand ds: ", err_l2_rand_ds)
println("l2 rand U: ", err_l2_rand_U)

@test err_l2_ds < 0.01
@test err_l2_U < 0.005
@test err_l2_rand_ds < 0.01
@test err_l2_rand_U < 0.005