-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added wavelets of oder 1,2,3,4 #2
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,15 +3,19 @@ module ANOVAapprox | |||||||||||||||||||||||
using GroupedTransforms, | ||||||||||||||||||||||||
LinearAlgebra, IterativeSolvers, LinearMaps, Distributed, SpecialFunctions | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
bases = ["per", "cos", "cheb", "std"] | ||||||||||||||||||||||||
types = Dict("per" => ComplexF64, "cos" => Float64, "cheb" => Float64, "std" => Float64) | ||||||||||||||||||||||||
bases = ["per", "cos", "cheb", "std" ,"wav1", "wav2", "wav3", "wav4"] | ||||||||||||||||||||||||
types = Dict("per" => ComplexF64, "cos" => Float64, "cheb" => Float64, "std" => Float64, "wav1" => Float64, "wav2" => Float64,"wav3" => Float64,"wav4" => Float64) | ||||||||||||||||||||||||
vtypes = Dict( | ||||||||||||||||||||||||
"per" => Vector{ComplexF64}, | ||||||||||||||||||||||||
"cos" => Vector{Float64}, | ||||||||||||||||||||||||
"cheb" => Vector{Float64}, | ||||||||||||||||||||||||
"std" => Vector{Float64}, | ||||||||||||||||||||||||
"wav1" => Vector{Float64}, | ||||||||||||||||||||||||
"wav2" => Vector{Float64}, | ||||||||||||||||||||||||
"wav3" => Vector{Float64}, | ||||||||||||||||||||||||
"wav4" => Vector{Float64}, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
gt_systems = Dict("per" => "exp", "cos" => "cos", "cheb" => "cos", "std" => "cos") | ||||||||||||||||||||||||
gt_systems = Dict("per" => "exp", "cos" => "cos", "cheb" => "cos", "std" => "cos", "wav1" => "wav1", "wav2" => "wav2", "wav3" => "wav3", "wav4" => "wav4") | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
function get_orderDependentBW(U::Vector{Vector{Int}}, N::Vector{Int})::Vector{Int} | ||||||||||||||||||||||||
N_bw = zeros(Int64, length(U)) | ||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -8,20 +8,36 @@ function get_GSI( | |||||
λ::Float64; | ||||||
dict::Bool = false, | ||||||
)::Union{Vector{Float64},Dict{Vector{Int},Float64}} | ||||||
variances = norms(a.fc[λ]) .^ 2 | ||||||
if a.basis == "wav1" | ||||||
variances = norms(a.fc[λ],1,dict=false) .^ 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav2" | ||||||
variances = norms(a.fc[λ],2,dict=false) .^ 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav3" | ||||||
variances = norms(a.fc[λ],3,dict=false) .^ 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav4" | ||||||
variances = norms(a.fc[λ],4,dict=false) .^ 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
else | ||||||
variances = norms(a.fc[λ],dict=false) .^ 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
variances = variances[2:end] | ||||||
variance_f = sum(variances) | ||||||
|
||||||
if dict | ||||||
gsis = Dict{Vector{Int},Float64}() | ||||||
for i = 1:length(a.fc[λ].setting) | ||||||
s = a.fc[λ].setting[i] | ||||||
u = s[:u] | ||||||
if u != [] | ||||||
gsis[u] = norm(a.fc[λ][u])^2 / variance_f | ||||||
end | ||||||
if a.basis == "wav1" | ||||||
variances = norms(a.fc[λ],1,dict=true) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav2" | ||||||
variances = norms(a.fc[λ],2,dict=true) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav3" | ||||||
variances = norms(a.fc[λ],3,dict=true) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
elseif a.basis == "wav4" | ||||||
variances = norms(a.fc[λ],4,dict=true) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
else | ||||||
variances = norms(a.fc[λ],dict=true) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
return gsis | ||||||
|
||||||
return Dict((u,variances[u]^2/variance_f) for u in keys(variances)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
|
||||||
else | ||||||
return variances ./ variance_f | ||||||
end | ||||||
|
@@ -53,8 +69,8 @@ function get_AttributeRanking(a::approx, λ::Float64)::Vector{Float64} | |||||
|
||||||
factors = zeros(Int64, d, ds) | ||||||
|
||||||
for i = 1:d | ||||||
for j = 1:ds | ||||||
for i = 1:d | ||||||
for j = 1:ds | ||||||
for v in U | ||||||
if (i in v) && (length(v) == j) | ||||||
factors[i,j] += 1 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
|
@@ -92,7 +108,7 @@ function get_ActiveSet( a::approx, eps::Vector{Float64}, λ::Float64 )::Vector{V | |||||
lengths = [ length(u) for u in U ] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
ds = maximum(lengths) | ||||||
|
||||||
if length(eps) != ds | ||||||
if length(eps) != ds | ||||||
error( "Entries in vector eps have to be ds.") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
|
||||||
|
@@ -107,7 +123,7 @@ function get_ActiveSet( a::approx, eps::Vector{Float64}, λ::Float64 )::Vector{V | |||||
end | ||||||
|
||||||
U_active = Vector{Vector{Int}}(undef, n+1) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
U_active[1] = [] | ||||||
U_active[1] = [] | ||||||
idx = 2 | ||||||
|
||||||
for i = 1:length(gsi) | ||||||
|
@@ -122,4 +138,4 @@ end | |||||
|
||||||
function get_ActiveSet(a::approx, eps::Vector{Float64})::Dict{Float64,Vector{Vector{Int}}} | ||||||
return Dict(λ => get_ActiveSet(a, eps, λ) for λ in collect(keys(a.fc))) | ||||||
end | ||||||
end |
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,10 +4,10 @@ | |||||||||||||||||
A struct to hold the scattered data function approximation. | ||||||||||||||||||
|
||||||||||||||||||
# Fields | ||||||||||||||||||
* `basis::String` - basis of the function space; currently choice of `"per"`, `"cos"`, `"cheb"`, and `"std"` | ||||||||||||||||||
* `basis::String` - basis of the function space; currently choice of `"per"`, `"cos"`, `"cheb"`,`"std"`, `"wav1"`, `"wav2"`,`"wav3"`,`"wav4"` | ||||||||||||||||||
* `X::Matrix{Float64}` - scattered data nodes with d rows and M columns | ||||||||||||||||||
* `y::Union{Vector{ComplexF64},Vector{Float64}}` - M function values (complex for `basis = "per"`, real ortherwise) | ||||||||||||||||||
* `U::Vector{Vector{Int}}` - a vector containing susbets of coordinate indices | ||||||||||||||||||
* `U::Vector{Vector{Int}}` - a vector containing susbets of coordinate indices | ||||||||||||||||||
* `N::Vector{Int}` - bandwdiths for each ANOVA term | ||||||||||||||||||
* `trafo::GroupedTransform` - holds the grouped transformation | ||||||||||||||||||
* `fc::Dict{Float64,GroupedCoefficients}` - holds the GroupedCoefficients after approximation for every different regularization parameters | ||||||||||||||||||
|
@@ -58,7 +58,7 @@ mutable struct approx | |||||||||||||||||
bw = N | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
if (basis == "per") && ((minimum(X) < -0.5) || (maximum(X) >= 0.5)) | ||||||||||||||||||
if (basis == "per" || basis == "wav1" || basis == "wav2" || basis == "wav3" || basis == "wav4" ) && ((minimum(X) < -0.5) || (maximum(X) >= 0.5)) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
error("Nodes need to be between -0.5 and 0.5.") | ||||||||||||||||||
elseif (basis == "cos") && ((minimum(X) < 0) || (maximum(X) > 1)) | ||||||||||||||||||
error("Nodes need to be between 0 and 1.") | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -145,15 +145,19 @@ end | |||||
This function computes the relative ``L_2`` error of the function given the norm `norm` and a function that returns the basis coefficients `bc_fun` for regularization parameter `λ`. | ||||||
""" | ||||||
function get_L2error(a::approx, norm::Float64, bc_fun::Function, λ::Float64)::Float64 | ||||||
error = norm^2 | ||||||
index_set = get_IndexSet(a.trafo.setting, size(a.X, 1)) | ||||||
|
||||||
for i = 1:size(index_set, 2) | ||||||
k = index_set[:, i] | ||||||
error += abs(bc_fun(k) - a.fc[λ][i])^2 - abs(bc_fun(k))^2 | ||||||
if a.basis=="per" || a.basis == "cos" || a.basis =="cheb"|| a.basis == "std" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
error = norm^2 | ||||||
index_set = get_IndexSet(a.trafo.setting, size(a.X, 1)) | ||||||
|
||||||
for i = 1:size(index_set, 2) | ||||||
k = index_set[:, i] | ||||||
error += abs(bc_fun(k) - a.fc[λ][i])^2 - abs(bc_fun(k))^2 | ||||||
end | ||||||
|
||||||
return sqrt(error) / norm | ||||||
else | ||||||
error("The L2-error is not implemented for this basis") | ||||||
end | ||||||
|
||||||
return sqrt(error) / norm | ||||||
end | ||||||
|
||||||
@doc raw""" | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -16,7 +16,7 @@ using .TestFunctionCheb | |||||
|
||||||
rng = MersenneTwister(1234) | ||||||
|
||||||
tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista"] | ||||||
tests = ["misc", "cheb_fista", "cheb_lsqr", "per_lsqr", "per_fista","wav_lsqr"] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
#tests = ["misc", "cheb_lsqr", "per_lsqr", "per_fista"] | ||||||
|
||||||
for t in tests | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,48 @@ | ||||||||||
#### PERIODIC TEST SOLVER LSQR #### | ||||||||||
using ANOVAapprox | ||||||||||
include("TestFunctionPeriodic.jl") | ||||||||||
using Test | ||||||||||
using Random | ||||||||||
using Aqua | ||||||||||
|
||||||||||
d = 6 | ||||||||||
ds = 2 | ||||||||||
M = 10_000 | ||||||||||
max_iter = 50 | ||||||||||
bw = [4,4] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
λs = [0.0, 1.0] | ||||||||||
|
||||||||||
|
||||||||||
X = rand( d, M) .- 0.5 | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
y = [TestFunctionPeriodic.f(X[:, i]) for i = 1:M] | ||||||||||
X_test = rand(d, M) .- 0.5 | ||||||||||
y_test = [TestFunctionPeriodic.f(X_test[:, i]) for i = 1:M] | ||||||||||
|
||||||||||
#### #### | ||||||||||
|
||||||||||
ads = ANOVAapprox.approx(X, y, ds, bw, "wav2") | ||||||||||
ANOVAapprox.approximate(ads, lambda = λs) | ||||||||||
|
||||||||||
println( "AR: ", sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) ) | ||||||||||
@test abs( sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1 ) < 0.0001 | ||||||||||
Comment on lines
+26
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
|
||||||||||
bw = ANOVAapprox.get_orderDependentBW(TestFunctionPeriodic.AS, [4,4]) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
|
||||||||||
aU = ANOVAapprox.approx(X, y, TestFunctionPeriodic.AS, bw, "wav2") | ||||||||||
ANOVAapprox.approximate(aU, lambda = λs) | ||||||||||
|
||||||||||
err_l2_ds = ANOVAapprox.get_l2error(ads)[0.0] | ||||||||||
err_l2_U = ANOVAapprox.get_l2error(aU)[0.0] | ||||||||||
err_l2_rand_ds = ANOVAapprox.get_l2error(ads, X_test, y_test)[0.0] | ||||||||||
err_l2_rand_U = ANOVAapprox.get_l2error(aU, X_test, y_test)[0.0] | ||||||||||
|
||||||||||
println("== PERIODIC LSQR ==") | ||||||||||
println("l2 ds: ", err_l2_ds) | ||||||||||
println("l2 U: ", err_l2_U) | ||||||||||
println("l2 rand ds: ", err_l2_rand_ds) | ||||||||||
println("l2 rand U: ", err_l2_rand_U) | ||||||||||
|
||||||||||
@test err_l2_ds < 0.01 | ||||||||||
@test err_l2_U < 0.005 | ||||||||||
@test err_l2_rand_ds < 0.01 | ||||||||||
@test err_l2_rand_U < 0.005 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶