Skip to content

Commit

Permalink
Merge pull request #136 from astro-group-bristol/fergus/integ-param
Browse files Browse the repository at this point in the history
feat: spacetime concious integration parameters
  • Loading branch information
fjebaker authored Jul 30, 2023
2 parents e3665ad + dd0b74d commit dd23b71
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 74 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.4.12"
[deps]
Buckets = "3235f445-51d8-4100-901d-5b23398ac3ab"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
31 changes: 28 additions & 3 deletions src/Gradus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using LinearAlgebra: ×, ⋅, norm, det, dot, inv

using Parameters

import DiffEqBase

using SciMLBase
using OrdinaryDiffEq
using DiffEqCallbacks
Expand Down Expand Up @@ -158,7 +160,7 @@ number of geodesics. Also used to dispatch different tracing problems.
abstract type AbstractTrace end

"""
AbstractIntegrationParameters
AbstractIntegrationParameters{M}
Parameters that are made available at each step of the integration, that need not be constant.
For example, the turning points or withing-geometry flags.
Expand All @@ -167,13 +169,28 @@ The integration parameters should track which spacetime `M` they are parameters
Integration parameters must implement
- [`set_status_code!`](@ref)
- [`get_status_code`](@ref)
- [`get_metric`](@ref)
For more complex parameters, may also optionally implement
- [`update_integration_parameters!`](@ref)
See the documentation of each of the above functions for details of their operation.
"""
abstract type AbstractIntegrationParameters end
abstract type AbstractIntegrationParameters{M<:AbstractMetric} end

# type alias, since this is often used
const MutStatusCode = MVector{1,StatusCodes.T}

# TODO: temporary fix for https://github.com/SciML/DiffEqBase.jl/issues/918
function DiffEqBase.anyeltypedual(
::AbstractIntegrationParameters{<:AbstractMetric{T}},
) where {T}
if T <: ForwardDiff.Dual
T
else
Any
end
end

"""
update_integration_parameters!(old::AbstractIntegrationParameters, new::AbstractIntegrationParameters)
Expand All @@ -197,7 +214,7 @@ end
Update the status [`StatusCodes`](@ref) in `p` with `status`.
"""
set_status_code!(params::AbstractIntegrationParameters, status::StatusCodes.T) =
set_status_code!(params::AbstractIntegrationParameters, ::StatusCodes.T) =
error("Not implemented for $(typeof(params))")

"""
Expand All @@ -208,6 +225,14 @@ Return the status [`StatusCodes`](@ref) in `status`.
get_status_code(params::AbstractIntegrationParameters) =
error("Not implemented for $(typeof(params))")

"""
get_metric(p::AbstractIntegrationParameters{M})::M where {M}
Return the [`AbstractMetric`](@ref) `m::M` for which the integration parameters
have been specialised.
"""
get_metric(params::AbstractIntegrationParameters) =
error("Not implemented for $(typeof(params))")

"""
abstract type AbstractGeodesicPoint
Expand Down
7 changes: 4 additions & 3 deletions src/metrics/kerr-newman-ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ function geodesic_ode_problem(
end
function f(u::SVector{8,T}, p, λ) where {T}
@inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8]))
dv = SVector{4,T}(geodesic_equation(m, x, v))
_m = get_metric(p)
dv = SVector{4,T}(geodesic_equation(_m, x, v))
# add maxwell part
dvf = if !(trace.q 0.0)
F = faraday_tensor(m, x)
F = faraday_tensor(_m, x)
q_μ * (F * v)
else
zero(SVector{4,T})
Expand All @@ -95,7 +96,7 @@ function geodesic_ode_problem(
f,
u_init,
time_domain,
IntegrationParameters(StatusCodes.NoStatus);
IntegrationParameters(m, StatusCodes.NoStatus);
callback = callback,
)
end
Expand Down
231 changes: 196 additions & 35 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
Base.precompile(
Tuple{
typeof(_second_order_ode_f),
SVector{8,Float64},
IntegrationParameters{KerrMetric{Float64}},
Float64,
},
) # time: 5.720331
Base.precompile(
Tuple{
typeof(lineprofile),
KerrMetric{Float64},
SVector{4,Float64},
GeometricThinDisc{Float64},
},
) # time: 4.8591228
) # time: 2.5695212
Base.precompile(
Tuple{
typeof(tracegeodesics),
Expand All @@ -14,7 +22,119 @@ Base.precompile(
SVector{4,Float64},
Vararg{Any},
},
) # time: 1.1855857
) # time: 1.4936496
let fbody = try
__lookup_kwbody__(
which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})),
)
catch missing
end
if !ismissing(fbody)
precompile(
fbody,
(
Float64,
Float64,
TraceGeodesic{Float64},
Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}},
typeof(tracegeodesics),
KerrMetric{Float64},
SVector{4,Float64},
Vararg{Any},
),
)
end
end # time: 1.0984381
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:n_samples,),Tuple{Int64}},
typeof(emissivity_profile),
KerrMetric{Float64},
GeometricThinDisc{Float64},
LampPostModel{Float64},
},
) # time: 0.9623503
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:n_samples,),Tuple{Int64}},
typeof(tracecorona),
KerrMetric{Float64},
GeometricThinDisc{Float64},
LampPostModel{Float64},
},
) # time: 0.4676826
let fbody = try
__lookup_kwbody__(
which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})),
)
catch missing
end
if !ismissing(fbody)
precompile(
fbody,
(
Float64,
Float64,
TraceRadiativeTransfer{Float64},
Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}},
typeof(tracegeodesics),
KerrMetric{Float64},
SVector{4,Float64},
Vararg{Any},
),
)
end
end # time: 0.3491159
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:n_samples,),Tuple{Int64}},
typeof(tracegeodesics),
KerrMetric{Float64},
LampPostModel{Float64},
GeometricThinDisc{Float64},
Vararg{Any},
},
) # time: 0.3443357
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:image_width, :image_height),Tuple{Int64,Int64}},
typeof(rendergeodesics),
KerrMetric{Float64},
SVector{4,Float64},
GeometricThinDisc{Float64},
Vararg{Any},
},
) # time: 0.19957814
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}},
typeof(tracegeodesics),
KerrMetric{Float64},
SVector{4,Float64},
SVector{4,Float64},
Vararg{Any},
},
) # time: 0.13094269
Base.precompile(Tuple{Type{RadialDiscProfile},Vector{Float64},Vector{Float64},Vector{Any}}) # time: 0.08506817
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{
(:trace, :image_width, :image_height),
Tuple{TraceRadiativeTransfer{Float64},Int64,Int64},
},
typeof(rendergeodesics),
KerrMetric{Float64},
SVector{4,Float64},
GeometricThinDisc{Float64},
Vararg{Any},
},
) # time: 0.0841068
let fbody = try
__lookup_kwbody__(
which(
Expand Down Expand Up @@ -47,59 +167,54 @@ let fbody = try
),
)
end
end # time: 0.9053782
end # time: 0.036170956
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{
(:trace, :image_width, :image_height),
Tuple{TraceRadiativeTransfer{Float64},Int64,Int64},
},
typeof(rendergeodesics),
typeof(tracegeodesics),
KerrMetric{Float64},
SVector{4,Float64},
GeometricThinDisc{Float64},
Vector{SVector{4,Float64}},
Vector{SVector{4,Float64}},
Vararg{Any},
},
) # time: 0.11555993
) # time: 0.033662505
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}},
typeof(tracegeodesics),
typeof(tracing_configuration),
TraceGeodesic{Float64},
KerrMetric{Float64},
SVector{4,Float64},
SVector{4,Float64},
Vararg{Any},
GeometricThinDisc{Float64},
Float64,
},
) # time: 0.04722405
) # time: 0.030030003
Base.precompile(
Tuple{
typeof(tracegeodesics),
typeof(tracing_configuration),
TraceRadiativeTransfer{Float64},
KerrMetric{Float64},
Vector{SVector{4,Float64}},
Vector{SVector{4,Float64}},
Vararg{Any},
SVector{4,Float64},
SVector{4,Float64},
GeometricThinDisc{Float64},
Float64,
},
) # time: 0.034065828
) # time: 0.01967225
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:n_samples,),Tuple{Int64}},
typeof(tracegeodesics),
NamedTuple{
(:trajectories, :save_on, :ensemble),
Tuple{Int64,Bool,EnsembleEndpointThreads},
},
typeof(tracing_configuration),
TraceGeodesic{Float64},
KerrMetric{Float64},
LampPostModel{Float64},
SVector{4,Float64},
Function,
GeometricThinDisc{Float64},
Vararg{Any},
},
) # time: 0.025394669
Base.precompile(
Tuple{
typeof(update_integration_parameters!),
RadiativeTransferIntegrationParameters{Vector{Bool}},
RadiativeTransferIntegrationParameters{Vector{Bool}},
Float64,
},
) # time: 0.007881916
) # time: 0.010621044
Base.precompile(
Tuple{
typeof(Core.kwcall),
Expand All @@ -115,4 +230,50 @@ Base.precompile(
GeometricThinDisc{Float64},
Float64,
},
) # time: 0.007585501
) # time: 0.008629832
Base.precompile(
Tuple{
typeof(tracing_configuration),
TraceGeodesic{Float64},
KerrMetric{Float64},
Vector{SVector{4,Float64}},
Vector{SVector{4,Float64}},
GeometricThinDisc{Float64},
Float64,
},
) # time: 0.007926913
Base.precompile(
Tuple{
typeof(update_integration_parameters!),
RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}},
RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}},
},
) # time: 0.007825662
Base.precompile(
Tuple{typeof(point_source_equitorial_disc_emissivity),Float64,Any,Float64,Float64},
) # time: 0.006162999
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:gtol,),Tuple{Float64}},
typeof(distance_to_disc),
GeometricThinDisc{Float64},
SVector{9,Float64},
},
) # time: 0.003949623
Base.precompile(
Tuple{
typeof(Core.kwcall),
NamedTuple{(:gtol,),Tuple{Float64}},
typeof(distance_to_disc),
GeometricThinDisc{Float64},
SVector{8,Float64},
},
) # time: 0.001979128
Base.precompile(
Tuple{
typeof(set_status_code!),
IntegrationParameters{KerrMetric{Float64}},
Gradus.StatusCodes.T,
},
) # time: 0.001029539
3 changes: 2 additions & 1 deletion src/solution-processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,6 @@ function unpack_solution_full(
end

unpack_solution(gp::AbstractGeodesicPoint) = gp
unpack_solution(sol::SciMLBase.AbstractODESolution) = unpack_solution(sol.prob.f.f.m, sol)
unpack_solution(sol::SciMLBase.AbstractODESolution) =
unpack_solution(get_metric(sol.prob.p), sol)
unpack_solution(simsol::SciMLBase.AbstractEnsembleSolution) = map(unpack_solution, simsol.u)
Loading

0 comments on commit dd23b71

Please sign in to comment.