Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor to compile time constants #2697

Closed
1 change: 1 addition & 0 deletions astropy_helpers
Submodule astropy_helpers added at 9f82aa
30 changes: 3 additions & 27 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,7 @@ def packet(self):

@property
def verysimple_packet_collection(self):
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)
return self.nb_simulation_verysimple.transport.transport_state.packet_collection

@property
def nb_simulation_verysimple(self):
Expand All @@ -258,37 +256,15 @@ def verysimple_opacity_state(self):
return opacity_state_initialize(
self.nb_simulation_verysimple.plasma,
line_interaction_type="macroatom",
disable_line_scattering=self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING,
continuum_processes_enabled=self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)

@property
def verysimple_enable_full_relativity(self):
return self.nb_simulation_verysimple.transport.enable_full_relativity

@property
def verysimple_disable_line_scattering(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
)

@property
def verysimple_continuum_processes_enabled(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED
)

@property
def verysimple_tau_russian(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN

@property
def verysimple_survival_probability(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY

@property
def static_packet(self):
Expand Down
10 changes: 3 additions & 7 deletions benchmarks/transport_montecarlo_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tardis.transport.montecarlo.interaction as interaction
from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo.numba_interface import (
from tardis.transport.montecarlo.numba_config import (
LineInteractionType,
)
from asv_runner.benchmarks.mark import parameterize
Expand All @@ -22,10 +22,10 @@ def time_thomson_scatter(self):
init_nu = packet.nu
init_energy = packet.energy
time_explosion = self.verysimple_time_explosion
enable_full_relativity = self.verysimple_enable_full_relativity

interaction.thomson_scatter(
packet, time_explosion, enable_full_relativity
packet,
time_explosion,
)

@parameterize(
Expand All @@ -42,7 +42,6 @@ def time_line_scatter(self, line_interaction_type):
packet.initialize_line_id(
self.verysimple_opacity_state,
self.verysimple_time_explosion,
self.verysimple_enable_full_relativity,
)
time_explosion = self.verysimple_time_explosion

Expand All @@ -51,7 +50,6 @@ def time_line_scatter(self, line_interaction_type):
time_explosion,
line_interaction_type,
self.verysimple_opacity_state,
self.verysimple_enable_full_relativity,
self.verysimple_continuum_processes_enabled,
)

Expand Down Expand Up @@ -84,7 +82,6 @@ def time_line_emission(self, test_packet):
packet.initialize_line_id(
self.verysimple_opacity_state,
self.verysimple_time_explosion,
self.verysimple_enable_full_relativity,
)

time_explosion = self.verysimple_time_explosion
Expand All @@ -94,5 +91,4 @@ def time_line_emission(self, test_packet):
emission_line_id,
time_explosion,
self.verysimple_opacity_state,
self.verysimple_enable_full_relativity,
)
6 changes: 1 addition & 5 deletions benchmarks/transport_montecarlo_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ class BenchmarkMontecarloMontecarloNumbaPacket(BenchmarkBase):
"electron_density": 1e-5,
"tua_event": 1e10,
},
{
"electron_density": 1.0,
"tua_event": 1e10
},
{"electron_density": 1.0, "tua_event": 1e10},
]
}
)
Expand Down Expand Up @@ -104,7 +101,6 @@ def time_update_line_estimators(self, parameters):
cur_line_id,
distance_trace,
time_explosion,
enable_full_relativity,
)

@parameterize(
Expand Down
29 changes: 7 additions & 22 deletions benchmarks/transport_montecarlo_vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ def r_packet(self):
)

def v_packet_initialize_line_id(
self, v_packet, opacity_state, time_explosion, enable_full_relativity
self,
v_packet,
opacity_state,
time_explosion,
):
inverse_line_list_nu = opacity_state.line_list_nu[::-1]
doppler_factor = get_doppler_factor(
v_packet.r, v_packet.mu, time_explosion, enable_full_relativity
v_packet.r,
v_packet.mu,
time_explosion,
)
comov_nu = v_packet.nu * doppler_factor
next_line_id = len(opacity_state.line_list_nu) - np.searchsorted(
Expand All @@ -61,26 +66,19 @@ def time_trace_vpacket_within_shell(self):
)
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)

# Give the vpacket a reasonable line ID
self.v_packet_initialize_line_id(
v_packet,
verysimple_opacity_state,
verysimple_time_explosion,
enable_full_relativity,
)

vpacket.trace_vpacket_within_shell(
v_packet,
verysimple_numba_radial_1d_geometry,
verysimple_time_explosion,
verysimple_opacity_state,
enable_full_relativity,
continuum_processes_enabled,
)

def time_trace_vpacket(self):
Expand All @@ -90,10 +88,6 @@ def time_trace_vpacket(self):
)
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -105,7 +99,6 @@ def time_trace_vpacket(self):
v_packet,
verysimple_opacity_state,
verysimple_time_explosion,
enable_full_relativity,
)

vpacket.trace_vpacket(
Expand All @@ -115,8 +108,6 @@ def time_trace_vpacket(self):
verysimple_opacity_state,
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@property
Expand All @@ -136,12 +127,8 @@ def time_trace_bad_vpacket(self):
verysimple_numba_radial_1d_geometry = (
self.verysimple_numba_radial_1d_geometry
)
enable_full_relativity = self.verysimple_enable_full_relativity
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -152,8 +139,6 @@ def time_trace_bad_vpacket(self):
verysimple_opacity_state,
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@parameterize(
Expand Down
4 changes: 1 addition & 3 deletions tardis/io/model/parse_packet_source_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def initialize_packet_source(packet_source, config, geometry):
return packet_source


def parse_packet_source_from_config(config, geometry, legacy_mode_enabled):
def parse_packet_source_from_config(config, geometry):
"""
Parse the packet source based on the given configuration and geometry.

Expand All @@ -66,12 +66,10 @@ def parse_packet_source_from_config(config, geometry, legacy_mode_enabled):
packet_source = BlackBodySimpleSourceRelativistic(
base_seed=config.montecarlo.seed,
time_explosion=config.supernova.time_explosion,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
packet_source = BlackBodySimpleSource(
base_seed=config.montecarlo.seed,
legacy_mode_enabled=legacy_mode_enabled,
)

return initialize_packet_source(packet_source, config, geometry)
20 changes: 12 additions & 8 deletions tardis/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,11 @@ def no_of_raw_shells(self):
return self.geometry.no_of_shells

@classmethod
def from_config(cls, config, atom_data, legacy_mode_enabled=False):
def from_config(
cls,
config,
atom_data,
):
"""
Create a new SimulationState instance from a Configuration object.

Expand All @@ -295,9 +299,7 @@ def from_config(cls, config, atom_data, legacy_mode_enabled=False):
atom_data, config, time_explosion, geometry
)

packet_source = parse_packet_source_from_config(
config, geometry, legacy_mode_enabled
)
packet_source = parse_packet_source_from_config(config, geometry)

radiation_field_state = parse_radiation_field_state_from_config(
config,
Expand All @@ -316,7 +318,11 @@ def from_config(cls, config, atom_data, legacy_mode_enabled=False):
)

@classmethod
def from_csvy(cls, config, atom_data=None, legacy_mode_enabled=False):
def from_csvy(
cls,
config,
atom_data=None,
):
"""
Create a new SimulationState instance from a Configuration object.

Expand Down Expand Up @@ -394,9 +400,7 @@ def from_csvy(cls, config, atom_data=None, legacy_mode_enabled=False):
geometry,
)

packet_source = parse_packet_source_from_config(
config, geometry, legacy_mode_enabled
)
packet_source = parse_packet_source_from_config(config, geometry)

radiation_field_state = parse_radiation_field_state_from_csvy(
config, csvy_model_config, csvy_model_data, geometry, packet_source
Expand Down
3 changes: 1 addition & 2 deletions tardis/model/matter/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def effective_element_masses(self):
def elemental_number_density(self):
"""Elemental Number Density computed using the formula: (elemental_mass_fraction * density) / atomic mass"""
return (
self.elemental_mass_fraction
* self.density.to(u.g / u.cm**3).value
self.elemental_mass_fraction * self.density.to(u.g / u.cm**3).value
).divide(
self.effective_element_masses.reindex(
self.elemental_mass_fraction.index
Expand Down
2 changes: 1 addition & 1 deletion tardis/opacities/opacities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
SIGMA_THOMSON,
)

Expand Down
6 changes: 2 additions & 4 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -199,7 +200,7 @@ def __init__(
self._callbacks = OrderedDict()
self._cb_next_id = 0

self.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = (
not self.plasma.continuum_interaction_species.empty
)

Expand Down Expand Up @@ -615,7 +616,6 @@ def from_config(
virtual_packet_logging=False,
show_convergence_plots=False,
show_progress_bars=True,
legacy_mode_enabled=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -671,13 +671,11 @@ def from_config(
simulation_state = SimulationState.from_csvy(
config,
atom_data=atom_data,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
simulation_state = SimulationState.from_config(
config,
atom_data=atom_data,
legacy_mode_enabled=legacy_mode_enabled,
)
# Override with custom packet source from function argument if present
if packet_source is not None:
Expand Down
17 changes: 12 additions & 5 deletions tardis/transport/frame_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import C_SPEED_OF_LIGHT
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
)
from tardis.transport.montecarlo.configuration import montecarlo_globals


@njit(**njit_dict_no_parallel)
def get_doppler_factor(r, mu, time_explosion, enable_full_relativity):
def get_doppler_factor(r, mu, time_explosion):
inv_c = 1 / C_SPEED_OF_LIGHT
inv_t = 1 / time_explosion
beta = r * inv_t * inv_c
if not enable_full_relativity:
if not montecarlo_globals.ENABLE_FULL_RELATIVITY:
return get_doppler_factor_partial_relativity(mu, beta)
else:
return get_doppler_factor_full_relativity(mu, beta)
Expand All @@ -31,7 +34,11 @@ def get_doppler_factor_full_relativity(mu, beta):


@njit(**njit_dict_no_parallel)
def get_inverse_doppler_factor(r, mu, time_explosion, enable_full_relativity):
def get_inverse_doppler_factor(
r,
mu,
time_explosion,
):
"""
Calculate doppler factor for frame transformation

Expand All @@ -44,7 +51,7 @@ def get_inverse_doppler_factor(r, mu, time_explosion, enable_full_relativity):
inv_c = 1 / C_SPEED_OF_LIGHT
inv_t = 1 / time_explosion
beta = r * inv_t * inv_c
if not enable_full_relativity:
if not montecarlo_globals.ENABLE_FULL_RELATIVITY:
return get_inverse_doppler_factor_partial_relativity(mu, beta)
else:
return get_inverse_doppler_factor_full_relativity(mu, beta)
Expand Down
Loading
Loading