Skip to content

Commit

Permalink
Merge pull request #489 from xsuite/release/v0.58.0
Browse files Browse the repository at this point in the history
Release/v0.58.0
  • Loading branch information
szymonlopaciuk authored May 2, 2024
2 parents 0d91a70 + 9267a0c commit c0f17ad
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 623 deletions.
33 changes: 17 additions & 16 deletions tests/test_prebuild_kernels.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# copyright ################################# #
# This file is part of the Xobjects Package. #
# Copyright (c) CERN, 2022. #
# This file is part of the Xtrack Package. #
# Copyright (c) CERN, 2024. #
# ########################################### #
import json

import cffi

import xobjects as xo
import xtrack as xt
from xtrack.prebuild_kernels import regenerate_kernels
from xsuite.prebuild_kernels import regenerate_kernels


def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys):
Expand Down Expand Up @@ -43,15 +42,17 @@ def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys):
}),
]

patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions'
# Override the definitions with the temporary ones
patch_defs = 'xsuite.kernel_definitions.kernel_definitions'
mocker.patch(patch_defs, kernel_definitions)

mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION',
# We need to change the default location so that loading the kernels works
mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path)

# Try regenerating the kernels
regenerate_kernels()
regenerate_kernels(location=tmp_path)

# Check if the expected files were created
so_file0, = tmp_path.glob('000_test_module.*.so')
Expand Down Expand Up @@ -103,17 +104,17 @@ def test_per_element_prebuild_kernels(mocker, tmp_path, temp_context_default_fun
}),
]

patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions'
# Override the definitions with the temporary ones
patch_defs = 'xsuite.kernel_definitions.kernel_definitions'
mocker.patch(patch_defs, kernel_definitions)

mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION',
# We need to change the default location so that loading the kernels works
mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.base_element.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.particles.particles.XT_PREBUILT_KERNELS_LOCATION', tmp_path)

# Try regenerating the kernels
regenerate_kernels()
regenerate_kernels(location=tmp_path)

# Check if the expected files were created
so_file_exists = False
Expand Down
19 changes: 13 additions & 6 deletions xtrack/base_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from xobjects.general import Print

from xobjects.hybrid_class import _build_xofields_dict
from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION

from .general import _pkg_root
from .internal_record import RecordIdentifier, RecordIndex, generate_get_record
Expand Down Expand Up @@ -455,23 +454,31 @@ def compile_kernels(self, extra_classes=(), *args, **kwargs):
cls = type(self)

if context.allow_prebuilt_kernels:
from xtrack.prebuild_kernels import get_suitable_kernel
# Default config is empty (all flags default to not defined, which
# enables most behaviours). In the future this has to be looked at
# whenever a new flag is needed.
_default_config = {}
_print_state = Print.suppress
Print.suppress = True
classes = (cls._XoStruct,) + tuple(extra_classes)
kernel_info = get_suitable_kernel(
_default_config, classes
)
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
except ImportError:
kernel_info = None
else:
kernel_info = get_suitable_kernel(
_default_config, classes
)

Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XT_PREBUILT_KERNELS_LOCATION,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)
Expand Down
2 changes: 2 additions & 0 deletions xtrack/monitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from .beam_position_monitor import *
from .beam_size_monitor import *
from .beam_profile_monitor import *

monitor_classes = tuple(v for v in globals().values() if isinstance(v, type) and issubclass(v, BeamElement))
23 changes: 23 additions & 0 deletions xtrack/multisetter/multisetter.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,29 @@ def set_values(self, values):
self._set_kernel(data=self, buffer=self._tracker_buffer.buffer,
input=xt.BeamElement._arr2ctx(self, values))

def compile_kernels(self, only_if_needed=True):
context = self._buffer.context
if context.allow_prebuilt_kernels and only_if_needed:
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
kernel_info = get_suitable_kernel({}, ())
except ImportError:
kernel_info = None

if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)

super().compile_kernels(only_if_needed=only_if_needed)


def _extract_offset(obj, field_name, index, dtype, xodtype):

Expand Down
14 changes: 10 additions & 4 deletions xtrack/particles/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import xobjects as xo
from xobjects.general import Print
from xobjects import BypassLinked
from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION

from .constants import PROTON_MASS_EV

Expand Down Expand Up @@ -1006,16 +1005,23 @@ def _init_random_number_generator(self, seeds=None):
"""
context = self._buffer.context
if context.allow_prebuilt_kernels:
from xtrack.prebuild_kernels import get_suitable_kernel
_print_state = Print.suppress
Print.suppress = True
kernel_info = get_suitable_kernel({}, (self.__class__._XoStruct,))
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
kernel_info = get_suitable_kernel({}, ())
except ImportError:
kernel_info = None

Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XT_PREBUILT_KERNELS_LOCATION,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)
Expand Down
Loading

0 comments on commit c0f17ad

Please sign in to comment.