Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release/v0.58.0 #489

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions tests/test_prebuild_kernels.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# copyright ################################# #
# This file is part of the Xobjects Package. #
# Copyright (c) CERN, 2022. #
# This file is part of the Xtrack Package. #
# Copyright (c) CERN, 2024. #
# ########################################### #
import json

import cffi

import xobjects as xo
import xtrack as xt
from xtrack.prebuild_kernels import regenerate_kernels
from xsuite.prebuild_kernels import regenerate_kernels


def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys):
Expand Down Expand Up @@ -43,15 +42,17 @@ def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys):
}),
]

patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions'
# Override the definitions with the temporary ones
patch_defs = 'xsuite.kernel_definitions.kernel_definitions'
mocker.patch(patch_defs, kernel_definitions)

mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION',
# We need to change the default location so that loading the kernels works
mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path)

# Try regenerating the kernels
regenerate_kernels()
regenerate_kernels(location=tmp_path)

# Check if the expected files were created
so_file0, = tmp_path.glob('000_test_module.*.so')
Expand Down Expand Up @@ -103,17 +104,17 @@ def test_per_element_prebuild_kernels(mocker, tmp_path, temp_context_default_fun
}),
]

patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions'
# Override the definitions with the temporary ones
patch_defs = 'xsuite.kernel_definitions.kernel_definitions'
mocker.patch(patch_defs, kernel_definitions)

mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION',
# We need to change the default location so that loading the kernels works
mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION',
tmp_path)
mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.base_element.XT_PREBUILT_KERNELS_LOCATION', tmp_path)
mocker.patch('xtrack.particles.particles.XT_PREBUILT_KERNELS_LOCATION', tmp_path)

# Try regenerating the kernels
regenerate_kernels()
regenerate_kernels(location=tmp_path)

# Check if the expected files were created
so_file_exists = False
Expand Down
19 changes: 13 additions & 6 deletions xtrack/base_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from xobjects.general import Print

from xobjects.hybrid_class import _build_xofields_dict
from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION

from .general import _pkg_root
from .internal_record import RecordIdentifier, RecordIndex, generate_get_record
Expand Down Expand Up @@ -455,23 +454,31 @@ def compile_kernels(self, extra_classes=(), *args, **kwargs):
cls = type(self)

if context.allow_prebuilt_kernels:
from xtrack.prebuild_kernels import get_suitable_kernel
# Default config is empty (all flags default to not defined, which
# enables most behaviours). In the future this has to be looked at
# whenever a new flag is needed.
_default_config = {}
_print_state = Print.suppress
Print.suppress = True
classes = (cls._XoStruct,) + tuple(extra_classes)
kernel_info = get_suitable_kernel(
_default_config, classes
)
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
except ImportError:
kernel_info = None
else:
kernel_info = get_suitable_kernel(
_default_config, classes
)

Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XT_PREBUILT_KERNELS_LOCATION,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)
Expand Down
2 changes: 2 additions & 0 deletions xtrack/monitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from .beam_position_monitor import *
from .beam_size_monitor import *
from .beam_profile_monitor import *

monitor_classes = tuple(v for v in globals().values() if isinstance(v, type) and issubclass(v, BeamElement))
23 changes: 23 additions & 0 deletions xtrack/multisetter/multisetter.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,29 @@ def set_values(self, values):
self._set_kernel(data=self, buffer=self._tracker_buffer.buffer,
input=xt.BeamElement._arr2ctx(self, values))

def compile_kernels(self, only_if_needed=True):
context = self._buffer.context
if context.allow_prebuilt_kernels and only_if_needed:
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
kernel_info = get_suitable_kernel({}, ())
except ImportError:
kernel_info = None

if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)

super().compile_kernels(only_if_needed=only_if_needed)


def _extract_offset(obj, field_name, index, dtype, xodtype):

Expand Down
14 changes: 10 additions & 4 deletions xtrack/particles/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import xobjects as xo
from xobjects.general import Print
from xobjects import BypassLinked
from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION

from .constants import PROTON_MASS_EV

Expand Down Expand Up @@ -1006,16 +1005,23 @@ def _init_random_number_generator(self, seeds=None):
"""
context = self._buffer.context
if context.allow_prebuilt_kernels:
from xtrack.prebuild_kernels import get_suitable_kernel
_print_state = Print.suppress
Print.suppress = True
kernel_info = get_suitable_kernel({}, (self.__class__._XoStruct,))
try:
from xsuite import (
get_suitable_kernel,
XSK_PREBUILT_KERNELS_LOCATION,
)
kernel_info = get_suitable_kernel({}, ())
except ImportError:
kernel_info = None

Print.suppress = _print_state
if kernel_info:
module_name, _ = kernel_info
kernels = context.kernels_from_file(
module_name=module_name,
containing_dir=XT_PREBUILT_KERNELS_LOCATION,
containing_dir=XSK_PREBUILT_KERNELS_LOCATION,
kernel_descriptions=self._kernels,
)
context.kernels.update(kernels)
Expand Down
Loading