Skip to content

Commit

Permalink
Merge branch 'FixPrebuiltKernels' into release/v0.54.1
Browse files Browse the repository at this point in the history
  • Loading branch information
szymonlopaciuk committed Mar 4, 2024
2 parents 8d7a2e7 + 9a3b98f commit 90d2ab4
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 55 deletions.
67 changes: 65 additions & 2 deletions tests/test_prebuild_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import cffi

import xobjects as xo
import xpart as xp
import xtrack as xt
from xtrack.prebuild_kernels import regenerate_kernels

Expand Down Expand Up @@ -61,9 +60,73 @@ def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func):
# Build the tracker on a fresh context, so that the kernel comes from a file
line.build_tracker(_context=xo.ContextCpu())

p = xp.Particles(p0c=1e9, px=3e-6)
p = xt.Particles(p0c=1e9, px=3e-6)
line.track(p)

assert p.x == 6e-6
assert p.y == 0.0
cffi_compile.assert_not_called()


def test_per_element_prebuild_kernels(mocker, tmp_path, temp_context_default_func):
# Set up the temporary kernels directory
kernel_definitions = {
"test_module": {
"config": {},
"classes": [
xt.Drift,
xt.Cavity,
xt.XYShift,
]
},
"test_module_rand": {
"config": {},
"classes": [],
"extra_classes": [
xt.RandomNormal,
]
},
}

patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions'
mocker.patch(patch_defs, kernel_definitions)

mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.base_element.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.particles.particles.XT_PREBUILT_KERNELS_LOCATION', tmp_path)

# Try regenerating the kernels
regenerate_kernels()

# Check if the expected files were created
so_file_exists = False
for path in tmp_path.iterdir():
if not path.name.startswith('test_module.'):
continue
if path.suffix not in ('.so', '.dll', '.dylib', '.pyd'):
continue
so_file_exists = True
assert so_file_exists

assert (tmp_path / 'test_module.c').exists()
assert (tmp_path / 'test_module.json').exists()

# Test that reloading the kernel works
cffi_compile = mocker.patch.object(cffi.FFI, 'compile')

drift = xt.Drift(length=2.0)

p = xt.Particles(p0c=1e9, px=3e-6)
drift.track(p)

assert p.x == 6e-6
assert p.y == 0.0

rng = xt.RandomNormal()
n_samples = 100
samples = rng.generate(n_samples=n_samples, n_seeds=n_samples)
assert len(samples) == n_samples

cffi_compile.assert_not_called()
43 changes: 17 additions & 26 deletions xtrack/base_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ class BeamElement(xo.HybridClass, metaclass=MetaBeamElement):
has_backtrack = False
allow_backtrack = False
skip_in_loss_location_refinement = False
prebuilt_kernels_path = XT_PREBUILT_KERNELS_LOCATION
needs_rng = False


def __init__(self, *args, **kwargs):
xo.HybridClass.__init__(self, *args, **kwargs)
Expand All @@ -290,23 +291,20 @@ def compile_kernels(self, extra_classes=(), *args, **kwargs):
local_particle_src=Particles.gen_local_particle_api()))
context = self._context
cls = type(self)

if context.allow_prebuilt_kernels:
import xtrack as xt
from xtrack.prebuild_kernels import (
get_suitable_kernel,
XT_PREBUILT_KERNELS_LOCATION
)
from xtrack.prebuild_kernels import get_suitable_kernel
# Default config is empty (all flags default to not defined, which
# enables most behaviours). In the future this has to be looked at
# whenever a new flag is needed.
_default_config = {}
_print_state = Print.suppress
# Print.suppress = True
Print.suppress = True
classes = (cls._XoStruct,) + tuple(extra_classes)
kernel_info = get_suitable_kernel(
_default_config, classes
)
# Print.suppress = _print_state
Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
Expand All @@ -330,6 +328,9 @@ def track(self, particles=None, increment_at_element=False):
elif particles is None:
raise RuntimeError("Please provide particles to track!")

if self.needs_rng and not particles._has_valid_rng_state():
particles._init_random_number_generator()

context = self._buffer.context

if self._track_kernel_name not in context.kernels:
Expand Down Expand Up @@ -378,19 +379,16 @@ def __init__(self, kernel_name, element, additional_arg_names):
self.additional_arg_names = additional_arg_names

def __call__(self, particles, increment_at_element=False, **kwargs):
print(f'===> Calling PerParticlePyMethod {self.kernel_name}')
instance = self.element
context = instance.context

if self.kernel_name not in context.kernels:
instance.compile_kernels()
context = instance._context

only_if_needed = kwargs.pop('only_if_needed', True)
BeamElement.compile_kernels(instance, only_if_needed=only_if_needed)
kernel = context.kernels[self.kernel_name]

if hasattr(self.element, 'io_buffer') and self.element.io_buffer is not None:
io_buffer_arr = self.element.io_buffer.buffer
else:
context = kernel.context
io_buffer_arr = context.zeros(1, dtype=np.int8) # dummy

kernel.description.n_threads = particles._capacity
Expand All @@ -414,24 +412,18 @@ def __get__(self, instance, owner):

class PyMethod:

def __init__(self, kernel_name, element, additional_arg_names, prebuilt_kernels_path=None):
def __init__(self, kernel_name, element, additional_arg_names):
self.kernel_name = kernel_name
self.element = element
self.additional_arg_names = additional_arg_names
self.prebuilt_kernels_path = prebuilt_kernels_path

def __call__(self, **kwargs):
instance = self.element
context = instance._context

only_if_needed = kwargs.pop('only_if_needed', True)
BeamElement.compile_kernels(
instance,
prebuilt_kernels_path=self.prebuilt_kernels_path,
only_if_needed=only_if_needed,

)
kernel = getattr(context.kernels, self.kernel_name)
BeamElement.compile_kernels(instance, only_if_needed=only_if_needed)
kernel = context.kernels[self.kernel_name]

el_var_name = None
for arg in instance._kernels[self.kernel_name].args:
Expand All @@ -452,8 +444,7 @@ def __init__(self, kernel_name, additional_arg_names):
self.additional_arg_names = additional_arg_names

def __get__(self, instance, owner):
kernels_path = getattr(owner, 'prebuilt_kernels_path', None)
return PyMethod(kernel_name=self.kernel_name,
element=instance,
additional_arg_names=self.additional_arg_names,
prebuilt_kernels_path=kernels_path)
additional_arg_names=self.additional_arg_names)

26 changes: 21 additions & 5 deletions xtrack/particles/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
# ######################################### #

import numpy as np
import xobjects as xo

from pathlib import Path

from .constants import PROTON_MASS_EV

from scipy.constants import e as qe
from scipy.constants import c as clight
from scipy.constants import epsilon_0

import xobjects as xo
from xobjects.general import Print
from xobjects import BypassLinked
from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION

from .constants import PROTON_MASS_EV


LAST_INVALID_STATE = -999999999
Expand Down Expand Up @@ -98,6 +99,7 @@ class Particles(xo.HybridClass):

_kernels = {
'Particles_initialize_rand_gen': xo.Kernel(
c_name="Particles_initialize_rand_gen",
args=[
xo.Arg(xo.ThisClass, name='particles'),
xo.Arg(xo.UInt32, pointer=True, name='seeds'),
Expand Down Expand Up @@ -995,6 +997,21 @@ def _init_random_number_generator(self, seeds=None):
Initialize state of the random number generator (possibility to providing
a seed for each particle).
"""
context = self._buffer.context
if context.allow_prebuilt_kernels:
from xtrack.prebuild_kernels import get_suitable_kernel
_print_state = Print.suppress
Print.suppress = True
kernel_info = get_suitable_kernel({}, (self.__class__._XoStruct,))
Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XT_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)
self.compile_kernels(only_if_needed=True)

if seeds is None:
Expand All @@ -1005,7 +1022,6 @@ def _init_random_number_generator(self, seeds=None):
if not hasattr(seeds, 'dtype') or seeds.dtype != np.uint32:
seeds = np.array(seeds, dtype=np.uint32)

context = self._buffer.context
seeds_dev = context.nparray_to_context_array(seeds)
kernel = context.kernels['Particles_initialize_rand_gen']
kernel(particles=self, seeds=seeds_dev, n_init=self._capacity)
Expand Down
50 changes: 31 additions & 19 deletions xtrack/prebuilt_kernels/kernel_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@
# These will be enumerated in order of appearance in the dict, so in this case
# (for optimization purposes) the order is important.
kernel_definitions = {
'default_only_xtrack_no_config': {
'config': {},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS,
},
'default_only_xtrack': {
'config': BASE_CONFIG,
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS,
},
'only_xtrack_non_tracking_kernels': {
'config': BASE_CONFIG,
'config': {},
'classes': [],
'extra_classes': NON_TRACKING_ELEMENTS
},
Expand Down Expand Up @@ -122,6 +126,7 @@
}
}


try:
import xfields as xf
DEFAULT_BB3D_ELEMENTS = [
Expand All @@ -134,48 +139,55 @@
'config': BASE_CONFIG,
'classes': [*DEFAULT_BB3D_ELEMENTS, LineSegmentMap],
}
kernel_definitions['default_bb3d_no_config'] = {
'config': {},
'classes': [*DEFAULT_BB3D_ELEMENTS, LineSegmentMap],
}
except ImportError:
LOGGER.warning('Xfields not installed, skipping BB3D elements')


try:
import xcoll as xc
DEFAULT_XCOLL_ELEMENTS = [
*ONLY_XTRACK_ELEMENTS,
*NO_SYNRAD_ELEMENTS,
xc.BlackAbsorber,
xc.EverestBlock,
xc.EverestCollimator,
xc.EverestCrystal
]

kernel_definitions['default_xcoll'] = {
'config': BASE_CONFIG,
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_no_config'] = {
'config': {},
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_frozen_longitudinal'] = {
'config': {**BASE_CONFIG, **FREEZE_LONGITUDINAL},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_frozen_energy'] = {
'config': {**BASE_CONFIG, **FREEZE_ENERGY},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_backtrack'] = {
'config': {**BASE_CONFIG, 'XSUITE_BACKTRACK': True},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_backtrack_no_limit'] = {
'config': {**{k: v for k,v in BASE_CONFIG.items()
if k != 'XTRACK_GLOBAL_XY_LIMIT'},
'XSUITE_BACKTRACK': True},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
kernel_definitions['default_xcoll_backtrack_frozen_energy'] = {
'config': {**BASE_CONFIG, **FREEZE_ENERGY, 'XSUITE_BACKTRACK': True},
'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS \
+ [xc.BlackAbsorber, xc.EverestBlock, \
xc.EverestCollimator, xc.EverestCrystal]
'classes': DEFAULT_XCOLL_ELEMENTS
}
except ImportError:
LOGGER.warning('Xcoll not installed, skipping collimator elements')

3 changes: 0 additions & 3 deletions xtrack/random/random_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class RandomUniform(BeamElement):
allow_track = False

_extra_c_sources = [
# The base (bitwise) rng is in xtrack, as this is where the
# seeds are stored. This is needed to avoid circular imports
# in xtrack.Particles
_pkg_root.joinpath('particles', 'rng_src', 'base_rng.h'),
_pkg_root.joinpath('random', 'random_src', 'uniform.h')
]
Expand Down
3 changes: 3 additions & 0 deletions xtrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def __init__(
_buffer=_buffer)
line._freeze()

if np.any([hasattr(ee, 'needs_rng') and ee.needs_rng for ee in line.elements]):
line._needs_rng = True

_buffer = tracker_data_base._buffer

# Make a "marker" element to increase at_element
Expand Down

0 comments on commit 90d2ab4

Please sign in to comment.