Skip to content

Commit

Permalink
Merge pull request #134 from astro-group-bristol/fergus/get-set-status
Browse files Browse the repository at this point in the history
Get/set status code API for integration parameters
  • Loading branch information
fjebaker authored Jul 27, 2023
2 parents 2d4e693 + 0e17688 commit 2772563
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 47 deletions.
4 changes: 3 additions & 1 deletion src/Gradus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ import .GradusBase:
IntegrationParameters,
update_integration_parameters!,
restrict_ensemble,
_fast_dot
_fast_dot,
set_status_code!,
get_status_code

export AbstractMetric,
AbstractTrace,
Expand Down
48 changes: 25 additions & 23 deletions src/GradusBase/GradusBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,30 @@ export AbstractMetric,
AbstractStaticAxisSymmetric,
metric,
unpack_solution,
unpack_solution_full
GeodesicPoint,
AbstractGeodesicPoint,
vector_to_local_sky,
AbstractMetric,
geodesic_equation,
constrain,
inner_radius,
metric_type,
metric_components,
inverse_metric_components,
dotproduct,
propernorm,
tetradframe,
lnrbasis,
lnrframe,
lowerindices,
raiseindices,
StatusCodes,
AbstractIntegrationParameters,
IntegrationParameters,
update_integration_parameters!,
restrict_ensemble
unpack_solution_full,
GeodesicPoint,
AbstractGeodesicPoint,
vector_to_local_sky,
AbstractMetric,
geodesic_equation,
constrain,
inner_radius,
metric_type,
metric_components,
inverse_metric_components,
dotproduct,
propernorm,
tetradframe,
lnrbasis,
lnrframe,
lowerindices,
raiseindices,
StatusCodes,
AbstractIntegrationParameters,
IntegrationParameters,
update_integration_parameters!,
restrict_ensemble,
set_status_code!,
get_status_code

end # module
68 changes: 55 additions & 13 deletions src/GradusBase/geodesic-solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,60 @@ abstract type AbstractTrace end
Parameters that are made available at each step of the integration, that need not be constant.
For example, the turning points or withing-geometry flags.
The integration parameters should track which spacetime `M` they are parameters for.
Integration parameters must implement
- [`set_status_code!`](@ref)
- [`get_status_code`](@ref)
For more complex parameters, may also optionally implement
- [`update_integration_parameters!`](@ref)
See the documentation of each of the above functions for details of their operation.
"""
abstract type AbstractIntegrationParameters end

update_integration_parameters!(
p::AbstractIntegrationParameters,
::AbstractIntegrationParameters,
) = error("Not implemented for $(typeof(p))")
"""
update_integration_parameters!(old::AbstractIntegrationParameters, new::AbstractIntegrationParameters)
mutable struct IntegrationParameters <: AbstractIntegrationParameters
status::StatusCodes.T
end
Update (mutate) the `old` integration parameters to take the value of the `new`. Function should return
the `old`.
Note this function is practically only used to update any mutable fields in the integration parameters,
such as resetting any changes to an original state.
"""
function update_integration_parameters!(
p::IntegrationParameters,
new::IntegrationParameters,
old::AbstractIntegrationParameters,
new::AbstractIntegrationParameters,
)
p.status = new.status
p
set_status_code!(old, get_status_code(new))
old
end

"""
set_status_code!(p::AbstractIntegrationParameters, status::StatusCodes.T)
Update the status [`StatusCodes.T`](@ref) in `p` with `status`.
"""
set_status_code!(params::AbstractIntegrationParameters, status::StatusCodes.T) =
error("Not implemented for $(typeof(params))")

"""
get_status_code(p::AbstractIntegrationParameters)::StatusCodes.T
Return the status [`StatusCodes.T`](@ref) in `status`.
"""
get_status_code(params::AbstractIntegrationParameters) =
error("Not implemented for $(typeof(params))")

mutable struct IntegrationParameters <: AbstractIntegrationParameters
status::StatusCodes.T
end

set_status_code!(params::IntegrationParameters, status::StatusCodes.T) =
params.status = status
get_status_code(params::IntegrationParameters) = params.status

"""
abstract type AbstractGeodesicPoint
Expand Down Expand Up @@ -118,7 +152,7 @@ function unpack_solution(::AbstractMetric, sol::SciMLBase.AbstractODESolution{T}
# get the auxiliary values if we have any
aux = unpack_auxiliary(us[end])

GeodesicPoint(sol.prob.p.status, t_init, t, x_init, x, v_init, v, aux)
GeodesicPoint(get_status_code(sol.prob.p), t_init, t, x_init, x, v_init, v, aux)
end
end

Expand Down Expand Up @@ -154,7 +188,15 @@ function unpack_solution_full(
ui = SVector{4,T}(us[i][1:4])
vi = SVector{4,T}(us[i][5:8])
ti = ts[i]
GeodesicPoint(sol.prob.p.status, t_start, ti, u_start, ui, v_start, vi)
GeodesicPoint(
get_status_code(sol.prob.p),
t_start,
ti,
u_start,
ui,
v_start,
vi,
)
end
end
end
Expand Down
5 changes: 4 additions & 1 deletion src/corona/profiles/voronoi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ function VoronoiDiscProfile(
VoronoiDiscProfile(
m,
d,
filter(i -> i.prob.p.status == StatusCodes.IntersectedWithGeometry, simsols.u),
filter(
i -> get_status_code(i.prob.p) == StatusCodes.IntersectedWithGeometry,
simsols.u,
),
)
end

Expand Down
2 changes: 1 addition & 1 deletion src/tracing/callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function terminate_with_status!(status::StatusCodes.T)
function _terminate_with_status_closure!(integrator)
integrator.p.status = status
set_status_code!(integrator.p, status)
terminate!(integrator)
end
end
Expand Down
8 changes: 4 additions & 4 deletions src/tracing/charts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ end
function chart_terminate!(integrator)
# terminate with status code depending on whether inner or outer boundary
if integrator.u[2] chart.inner_radius
integrator.p.status = StatusCodes.WithinInnerBoundary
set_status_code!(integrator.p, StatusCodes.WithinInnerBoundary)
terminate!(integrator)
else
integrator.p.status = StatusCodes.OutOfDomain
set_status_code!(integrator.p, StatusCodes.OutOfDomain)
terminate!(integrator)
end
end
Expand All @@ -38,10 +38,10 @@ end
# terminate with status code depending on whether inner or outer boundary
rmin = chart.shapefunc(integrator.u[3])
if integrator.u[2] rmin
integrator.p.status = StatusCodes.WithinInnerBoundary
set_status_code!(integrator.p, StatusCodes.WithinInnerBoundary)
terminate!(integrator)
else
integrator.p.status = StatusCodes.OutOfDomain
set_status_code!(integrator.p, StatusCodes.OutOfDomain)
terminate!(integrator)
end
end
Expand Down
6 changes: 5 additions & 1 deletion src/tracing/method-implementations/first-order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ end
make_parameters(L, Q, sign_θ, ::Type{T}) where {T} =
FirstOrderIntegrationParameters{T}(L, Q, -1, sign_θ, [0.0, 0.0], StatusCodes.NoStatus)

set_status_code!(params::FirstOrderIntegrationParameters, status::StatusCodes.T) =
params.status = status
get_status_code(params::FirstOrderIntegrationParameters) = params.status

function update_integration_parameters!(
p::FirstOrderIntegrationParameters,
new::FirstOrderIntegrationParameters,
Expand All @@ -104,7 +108,7 @@ function update_integration_parameters!(
p.Q = new.Q
p.θ = new.θ
p.changes = new.changes
p.status = new.status
set_status_code!(p, get_status_code(new))
p
end

Expand Down
6 changes: 5 additions & 1 deletion src/tracing/radiative-transfer-problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ function _radiative_transfer_integration_parameters(status::StatusCodes.T, geome
RadiativeTransferIntegrationParameters(status, within_geometry)
end

set_status_code!(params::RadiativeTransferIntegrationParameters, status::StatusCodes.T) =
params.status = status
get_status_code(params::RadiativeTransferIntegrationParameters) = params.status


function update_integration_parameters!(
p::RadiativeTransferIntegrationParameters,
new::RadiativeTransferIntegrationParameters,
)
p.status = new.status
set_status_code!(p, get_status_code(new))
p.within_geometry .= new.within_geometry
p
end
Expand Down
7 changes: 5 additions & 2 deletions test/smoke-tests/disc-profiles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ using StaticArrays
sampler = EvenSampler(domain = LowerHemisphere()),
)

intersected_simsols =
filter(i -> i.prob.p.status == StatusCodes.IntersectedWithGeometry, simsols.u)
intersected_simsols = filter(
i ->
Gradus.get_status_code(i.prob.p) == StatusCodes.IntersectedWithGeometry,
simsols.u,
)
sd_endpoints = map(sol -> unpack_solution(m, sol), intersected_simsols)

# test ensemble solution constructor
Expand Down

0 comments on commit 2772563

Please sign in to comment.