Skip to content

Commit

Permalink
Merge pull request #123 from fjebaker/fergus/copy
Browse files Browse the repository at this point in the history
Copy
  • Loading branch information
fjebaker committed Jun 30, 2024
2 parents dc082cf + e84a206 commit ea9ddec
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ function allocate_model_output(
construct_objective_cache(T, model, domain)
end

function Base.copy(m::AbstractSpectralModel)
typeof(m)((copy(getproperty(m, f)) for f in fieldnames(typeof(m)))...)
end

# printing

function _printinfo(io::IO, m::M) where {M<:AbstractSpectralModel}
Expand Down
6 changes: 5 additions & 1 deletion src/composite-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ struct CompositeModel{M1,M2,O,T,K} <: AbstractSpectralModel{T,K}
m2::M2,
op::O,
) where {M1<:AbstractSpectralModel{T},M2<:AbstractSpectralModel{T,K},O} where {T,K} =
new{M1,M2,O,T,K}(deepcopy(m1), deepcopy(m2), op)
new{M1,M2,O,T,K}(copy(m1), copy(m2), op)
end

function Base.copy(m::CompositeModel)
CompositeModel(getfield(m, :left), getfield(m, :right), getfield(m, :op))
end

function implementation(::Type{<:CompositeModel{M1,M2}}) where {M1,M2}
Expand Down
10 changes: 9 additions & 1 deletion src/fitparam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function fit_param_default_error(val)
end

# concrete types
mutable struct FitParam{T}
mutable struct FitParam{T<:Number}
value::T
error::T

Expand Down Expand Up @@ -47,6 +47,14 @@ Base.isapprox(f1::FitParam, f2::FitParam; kwargs...) =
Base.:(==)(f1::FitParam, f2::FitParam) = f1.value == f2.value
Base.convert(T::Type{<:Number}, f::FitParam) = convert(T, f.value)

Base.copy(f::FitParam) = FitParam(
f.value;
error = f.error,
lower_limit = f.lower_limit,
upper_limit = f.upper_limit,
frozen = f.frozen,
)

paramtype(::Type{FitParam{T}}) where {T} = T
paramtype(::T) where {T<:FitParam} = paramtype(T)

Expand Down
4 changes: 4 additions & 0 deletions src/meta-models/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ struct AutoCache{M,T,K,C<:CacheEntry} <: AbstractModelWrapper{M,T,K}
end
end

function Base.copy(m::AutoCache)
AutoCache(copy(m.model), deepcopy(m.cache), m.abstol)
end

function AutoCache(model::AbstractSpectralModel{T,K}; abstol = 1e-3) where {T,K}
params = [get_value.(parameter_tuple(model))...]
cache = CacheEntry(params)
Expand Down
9 changes: 9 additions & 0 deletions src/meta-models/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ struct AsConvolution{M,T,V,P} <: AbstractModelWrapper{M,T,Convolutional}
end
end

"""
Base.copy(m::AsConvolution)
Creates a copy of an [`AsConvolution`](@ref) wrapped model. Will make a
`deepcopy` of the cache to elimiate possible thread contention, but does not
copy the domain.
"""
Base.copy(m::AsConvolution) = AsConvolution(copy(m.model), m.domain, deepcopy(m.cache))

function AsConvolution(
model::AbstractSpectralModel{T};
domain = collect(range(0, 2, 100)),
Expand Down
2 changes: 2 additions & 0 deletions src/meta-models/surrogate-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ struct SurrogateSpectralModel{T,K,N,S,Symbols} <: AbstractSpectralModel{T,K}
params::NTuple{N,T}
end

Base.copy(s::SurrogateSpectralModel) = typeof(s)(deepcopy(s.surrogate), deepcopy(s.params))

function SurrogateSpectralModel(
::K,
surrogate::S,
Expand Down
31 changes: 26 additions & 5 deletions src/meta-models/table-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,33 @@ First field in the struct **must be** `table`. See
abstract type AbstractTableModel{T,K} <: AbstractSpectralModel{T,K} end

# reflection tie-ins
function Reflection.get_closure_symbols(::Type{<:AbstractTableModel})
(:table,)
end
function Reflection.get_parameter_symbols(model::Type{<:AbstractTableModel})
# `table` field is not a model parameter
Reflection.get_closure_symbols(::Type{<:AbstractTableModel}) = (:table,)

# `table` field is not a model parameter
Reflection.get_parameter_symbols(model::Type{<:AbstractTableModel}) =
fieldnames(model)[2:end]

"""
Base.copy(m::AbstractTableModel)
Create a copy of an [`AbstractTableModel`](@ref). This will copy all fields except
the `table` field, which is assumed to be a constant set of values that can be
shared by multiple copies.
When this is not the case, the user should redefine `Base.copy` for their particular
table model to copy the table as needed.
"""
function Base.copy(m::AbstractTableModel)
typeof(m)(m.table, (copy(getproperty(m, f)) for f in fieldnames(typeof(m))[2:end])...)
end

abstract type AbstractCachedModel{T,K} <: AbstractSpectralModel{T,K} end

# reflection tie-ins
function Reflection.get_closure_symbols(::Type{<:AbstractCachedModel})
(:cache,)
end
Reflection.get_parameter_symbols(model::Type{<:AbstractCachedModel}) =
fieldnames(model)[2:end]

export AbstractTableModel
3 changes: 0 additions & 3 deletions test/models/test-as-convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ model = conv(lines)

domain = collect(range(0.0, 10.0, 150))

plot(domain[1:end-1], invokemodel(domain, lines))
plot(domain[1:end-1], invokemodel(domain, model))

output = invokemodel(domain, model)

@test sum(output) 3.2570820013702395 atol = 1e-4
Expand Down
27 changes: 27 additions & 0 deletions test/models/test-copy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using SpectralFitting
using Test

include("../dummies.jl")

# single model
m = PowerLaw()
m2 = copy(m)
@inferred copy(m)
m.K.value = 2.0
@test m2.K.value == 1.0

# table model
m = DummyMultiplicativeTableModel()
m2 = copy(m)
@inferred copy(m)
m.a.value = 2.0
@test m2.a.value == 1.0
@test m.table === m2.table

# composite model
model = PowerLaw() + PowerLaw()
model2 = copy(model)
model.K_1.value = 2.0
@test model2.K_2.value == 1.0

@inferred copy(model)
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ end
include("models/test-table-models.jl")
include("models/test-surrogate-models.jl")
include("models/test-auto-cache.jl")
include("models/test-as-convolution.jl")
include("models/test-copy.jl")

# only test XSPEC models when not using CI
# since model data access is annoying
Expand Down

0 comments on commit ea9ddec

Please sign in to comment.