diff --git a/tests/test_prebuild_kernels.py b/tests/test_prebuild_kernels.py index 2fdfc31ba..30b3c4a9f 100644 --- a/tests/test_prebuild_kernels.py +++ b/tests/test_prebuild_kernels.py @@ -1,14 +1,13 @@ # copyright ################################# # -# This file is part of the Xobjects Package. # -# Copyright (c) CERN, 2022. # +# This file is part of the Xtrack Package. # +# Copyright (c) CERN, 2024. # # ########################################### # -import json import cffi import xobjects as xo import xtrack as xt -from xtrack.prebuild_kernels import regenerate_kernels +from xsuite.prebuild_kernels import regenerate_kernels def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys): @@ -43,15 +42,17 @@ def test_prebuild_kernels(mocker, tmp_path, temp_context_default_func, capsys): }), ] - patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions' + # Override the definitions with the temporary ones + patch_defs = 'xsuite.kernel_definitions.kernel_definitions' mocker.patch(patch_defs, kernel_definitions) - - mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION', + # We need to change the default location so that loading the kernels works + mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION', + tmp_path) + mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION', tmp_path) - mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path) # Try regenerating the kernels - regenerate_kernels() + regenerate_kernels(location=tmp_path) # Check if the expected files were created so_file0, = tmp_path.glob('000_test_module.*.so') @@ -103,17 +104,17 @@ def test_per_element_prebuild_kernels(mocker, tmp_path, temp_context_default_fun }), ] - patch_defs = 'xtrack.prebuilt_kernels.kernel_definitions.kernel_definitions' + # Override the definitions with the temporary ones + patch_defs = 'xsuite.kernel_definitions.kernel_definitions' mocker.patch(patch_defs, kernel_definitions) - - mocker.patch('xtrack.prebuild_kernels.XT_PREBUILT_KERNELS_LOCATION', + # We need to change the default location so that loading the kernels works + mocker.patch('xsuite.prebuild_kernels.XSK_PREBUILT_KERNELS_LOCATION', + tmp_path) + mocker.patch('xsuite.XSK_PREBUILT_KERNELS_LOCATION', tmp_path) - mocker.patch('xtrack.tracker.XT_PREBUILT_KERNELS_LOCATION', tmp_path) - mocker.patch('xtrack.base_element.XT_PREBUILT_KERNELS_LOCATION', tmp_path) - mocker.patch('xtrack.particles.particles.XT_PREBUILT_KERNELS_LOCATION', tmp_path) # Try regenerating the kernels - regenerate_kernels() + regenerate_kernels(location=tmp_path) # Check if the expected files were created so_file_exists = False diff --git a/xtrack/base_element.py b/xtrack/base_element.py index 37b3609c4..e274598cc 100644 --- a/xtrack/base_element.py +++ b/xtrack/base_element.py @@ -11,7 +11,6 @@ from xobjects.general import Print from xobjects.hybrid_class import _build_xofields_dict -from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION from .general import _pkg_root from .internal_record import RecordIdentifier, RecordIndex, generate_get_record @@ -455,7 +454,6 @@ def compile_kernels(self, extra_classes=(), *args, **kwargs): cls = type(self) if context.allow_prebuilt_kernels: - from xtrack.prebuild_kernels import get_suitable_kernel # Default config is empty (all flags default to not defined, which # enables most behaviours). In the future this has to be looked at # whenever a new flag is needed. @@ -463,15 +461,24 @@ def compile_kernels(self, extra_classes=(), *args, **kwargs): _print_state = Print.suppress Print.suppress = True classes = (cls._XoStruct,) + tuple(extra_classes) - kernel_info = get_suitable_kernel( - _default_config, classes - ) + try: + from xsuite import ( + get_suitable_kernel, + XSK_PREBUILT_KERNELS_LOCATION, + ) + except ImportError: + kernel_info = None + else: + kernel_info = get_suitable_kernel( + _default_config, classes + ) + Print.suppress = _print_state if kernel_info: module_name, _ = kernel_info kernels = context.kernels_from_file( module_name=module_name, - containing_dir=XT_PREBUILT_KERNELS_LOCATION, + containing_dir=XSK_PREBUILT_KERNELS_LOCATION, kernel_descriptions=self._kernels, ) context.kernels.update(kernels) diff --git a/xtrack/monitors/__init__.py b/xtrack/monitors/__init__.py index f726ab2fb..d8bf6d53d 100644 --- a/xtrack/monitors/__init__.py +++ b/xtrack/monitors/__init__.py @@ -4,3 +4,5 @@ from .beam_position_monitor import * from .beam_size_monitor import * from .beam_profile_monitor import * + +monitor_classes = tuple(v for v in globals().values() if isinstance(v, type) and issubclass(v, BeamElement)) diff --git a/xtrack/multisetter/multisetter.py b/xtrack/multisetter/multisetter.py index a751a8e96..01dfffb1d 100644 --- a/xtrack/multisetter/multisetter.py +++ b/xtrack/multisetter/multisetter.py @@ -221,6 +221,29 @@ def set_values(self, values): self._set_kernel(data=self, buffer=self._tracker_buffer.buffer, input=xt.BeamElement._arr2ctx(self, values)) + def compile_kernels(self, only_if_needed=True): + context = self._buffer.context + if context.allow_prebuilt_kernels and only_if_needed: + try: + from xsuite import ( + get_suitable_kernel, + XSK_PREBUILT_KERNELS_LOCATION, + ) + kernel_info = get_suitable_kernel({}, ()) + except ImportError: + kernel_info = None + + if kernel_info: + module_name, _ = kernel_info + kernels = context.kernels_from_file( + module_name=module_name, + containing_dir=XSK_PREBUILT_KERNELS_LOCATION, + kernel_descriptions=self._kernels, + ) + context.kernels.update(kernels) + + super().compile_kernels(only_if_needed=only_if_needed) + def _extract_offset(obj, field_name, index, dtype, xodtype): diff --git a/xtrack/particles/particles.py b/xtrack/particles/particles.py index 7db62f782..a16289229 100644 --- a/xtrack/particles/particles.py +++ b/xtrack/particles/particles.py @@ -13,7 +13,6 @@ import xobjects as xo from xobjects.general import Print from xobjects import BypassLinked -from xtrack.prebuild_kernels import XT_PREBUILT_KERNELS_LOCATION from .constants import PROTON_MASS_EV @@ -1006,16 +1005,23 @@ def _init_random_number_generator(self, seeds=None): """ context = self._buffer.context if context.allow_prebuilt_kernels: - from xtrack.prebuild_kernels import get_suitable_kernel _print_state = Print.suppress Print.suppress = True - kernel_info = get_suitable_kernel({}, (self.__class__._XoStruct,)) + try: + from xsuite import ( + get_suitable_kernel, + XSK_PREBUILT_KERNELS_LOCATION, + ) + kernel_info = get_suitable_kernel({}, ()) + except ImportError: + kernel_info = None + Print.suppress = _print_state if kernel_info: module_name, _ = kernel_info kernels = context.kernels_from_file( module_name=module_name, - containing_dir=XT_PREBUILT_KERNELS_LOCATION, + containing_dir=XSK_PREBUILT_KERNELS_LOCATION, kernel_descriptions=self._kernels, ) context.kernels.update(kernels) diff --git a/xtrack/prebuild_kernels.py b/xtrack/prebuild_kernels.py deleted file mode 100644 index 2e27c9151..000000000 --- a/xtrack/prebuild_kernels.py +++ /dev/null @@ -1,352 +0,0 @@ -# copyright ################################# # -# This file is part of the Xobjects Package. # -# Copyright (c) CERN, 2023. # -# ########################################### # -import os -import json -import logging -from pathlib import Path -from pprint import pformat -from typing import Iterator, Optional, Tuple - -from .general import _print - -import numpy as np - -import xobjects as xo -import xtrack as xt - - -LOGGER = logging.getLogger(__name__) - -XT_PREBUILT_KERNELS_LOCATION = Path(xt.__file__).parent / 'prebuilt_kernels' - -BEAM_ELEMENTS_INIT_DEFAULTS = { - 'Bend': { - 'length': 1., - }, - 'Quadrupole': { - 'length': 1., - }, - 'Solenoid': { - 'length': 1., - }, - 'BeamBeamBiGaussian2D': { - 'other_beam_Sigma_11': 1., - 'other_beam_Sigma_33': 1., - 'other_beam_num_particles': 0., - 'other_beam_q0': 1., - 'other_beam_beta0': 1., - }, - 'BeamBeamBiGaussian3D': { - 'slices_other_beam_zeta_center': np.array([0]), - 'slices_other_beam_num_particles': np.array([0]), - 'phi': 0., - 'alpha': 0, - 'other_beam_q0': 1., - 'slices_other_beam_Sigma_11': np.array([1]), - 'slices_other_beam_Sigma_12': np.array([0]), - 'slices_other_beam_Sigma_22': np.array([0]), - 'slices_other_beam_Sigma_33': np.array([1]), - 'slices_other_beam_Sigma_34': np.array([0]), - 'slices_other_beam_Sigma_44': np.array([0]), - }, - 'LimitPolygon': { - 'x_vertices': np.array([0, 1, 1, 0]), - 'y_vertices': np.array([0, 0, 1, 1]), - }, -} - - -# SpaceChargeBiGaussian is not included for now (different issues - -# circular import, incompatible compilation flags) -# try: -# from xfields import LongitudinalProfileQGaussian - -# BEAM_ELEMENTS_INIT_DEFAULTS['SpaceChargeBiGaussian'] = { -# 'longitudinal_profile': LongitudinalProfileQGaussian( -# number_of_particles=0, sigma_z=1), -# } -# except ModuleNotFoundError: -# LOGGER.warning('Prebuilding kernels might fail, as xfields is not ' -# 'installed.') - - -def get_element_class_by_name(name: str) -> type: - try: - from xfields import element_classes as xf_element_classes - except ModuleNotFoundError: - xf_element_classes = () - - try: - from xcoll import element_classes as xc_element_classes - except ModuleNotFoundError: - xc_element_classes = () - - xt_rng_classes = tuple([getattr(xt, cls) - for cls in dir(xt.random) - if cls.startswith('Random')]) - xt_multisetter = (xt.MultiSetter, ) - - # from xtrack.monitors import generate_monitor_class - # monitor_cls = generate_monitor_class(xp.Particles) - xt_monitor_classes = (xt.ParticlesMonitor, ) - - element_classes = xt.element_classes + xt_rng_classes \ - + xt_monitor_classes + xt_multisetter \ - + xf_element_classes + xc_element_classes - - for cls in element_classes: - if cls.__name__ == name: - return cls - - raise ValueError(f'No element class with name {name} available.') - - -def save_kernel_metadata( - module_name: str, - config: dict, - kernel_element_classes, -): - out_file = XT_PREBUILT_KERNELS_LOCATION / f'{module_name}.json' - - try: - import xfields - xf_version = xfields.__version__ - except ModuleNotFoundError: - xf_version = None - - try: - import xcoll - xc_version = xcoll.__version__ - except ModuleNotFoundError: - xc_version = None - - kernel_metadata = { - 'config': config.data, - 'classes': [cls._DressingClass.__name__ for cls in kernel_element_classes], - 'versions': { - 'xtrack': xt.__version__, - 'xfields': xf_version, - 'xcoll': xc_version, - 'xobjects': xo.__version__, - } - } - - with out_file.open('w') as fd: - json.dump(kernel_metadata, fd, indent=4) - - -def enumerate_kernels() -> Iterator[Tuple[str, dict]]: - """ - Iterate over the prebuilt kernels compatible with the current version of - xsuite. The first element of the tuple is the name of the kernel module - and the second is a dictionary with the kernel metadata. - """ - from xtrack.prebuilt_kernels.kernel_definitions import kernel_definitions - for kernel_name, _ in kernel_definitions: - metadata_file = XT_PREBUILT_KERNELS_LOCATION / f'{kernel_name}.json' - - if not metadata_file.exists(): - continue - - with metadata_file.open('r') as fd: - kernel_metadata = json.load(fd) - - try: - import xfields - xf_version = xfields.__version__ - except ModuleNotFoundError: - xf_version = None - - try: - import xcoll - xc_version = xcoll.__version__ - except ModuleNotFoundError: - xc_version = None - - if kernel_metadata['versions']['xtrack'] != xt.__version__: - continue - - if kernel_metadata['versions']['xobjects'] != xo.__version__: - continue - - if (kernel_metadata['versions']['xfields'] != xf_version - and xf_version is not None): - continue - - if (kernel_metadata['versions']['xcoll'] != xc_version - and xc_version is not None): - continue - - yield metadata_file.stem, kernel_metadata - - -def get_suitable_kernel( - config: dict, - line_element_classes, - verbose=False, -) -> Optional[Tuple[str, list]]: - """ - Given a configuration and a list of element classes, return a tuple with - the name of a suitable prebuilt kernel module together with the list of - element classes that were used to build it. Set `verbose` to True, to - obtain a justification of the choice (or lack thereof) on standard output. - """ - - env_var = os.environ.get("XSUITE_PREBUILT_KERNELS") - if env_var and env_var == '0': - if verbose: - _print('Skipping the search for a suitable kernel, as the ' - 'environment variable XSUITE_PREBUILT_KERNELS == "0".') - return - - requested_class_names = [ - cls._DressingClass.__name__ for cls in line_element_classes - ] - # Hack: we don't select on particles class as prebuild kernels anyway only - # work for xp.Particles - requested_class_names = [cls for cls in requested_class_names - if cls != 'Particles' and cls != 'ParticlesBase'] - - for module_name, kernel_metadata in enumerate_kernels(): - if verbose: - _print(f"==> Considering the precompiled kernel `{module_name}`...") - - available_classes_names = kernel_metadata['classes'] - if kernel_metadata['config'] != config: - if verbose: - lhs = kernel_metadata['config'] - rhs = config - config_diff = {kk: (lhs.get(kk), rhs.get(kk)) - for kk in set(lhs.keys()) | set(rhs.keys()) - if lhs.get(kk) != rhs.get(kk)} - _print(f'The kernel `{module_name}` is unsuitable. Its config ' - f'(left) and the requested one (right) differ at the ' - f'following keys:\n' - f'{pformat(config_diff)}') - _print(f'Skipping class compatibility check for `{module_name}`.') - - continue - - if verbose: - _print(f'The kernel `{module_name}` has the right config.') - - if set(requested_class_names) <= set(available_classes_names): - available_classes = [ - get_element_class_by_name(class_name) - for class_name in available_classes_names - ] - _print(f'Found suitable prebuilt kernel `{module_name}`.') - return module_name, available_classes - elif verbose: - class_diff = set(requested_class_names) - set(available_classes_names) - _print(f'The kernel `{module_name}` is unsuitable. It does not ' - f'provide the following requested classes: ' - f'{", ".join(class_diff)}.') - - if verbose: - _print('==> No suitable precompiled kernel found.') - - -def regenerate_kernels(kernels=None): - """ - Use the kernel definitions in the `kernel_definitions.py` file to - regenerate kernel shared objects using the current version of xsuite. - """ - if kernels is not None and ( - isinstance(kernels, str) or not hasattr(kernels, '__iter__')): - kernels = [kernels] - - # Delete existing kernels to avoid accidentally loading in existing C code - clear_kernels(kernels) - - import xpart as xp - from xtrack.prebuilt_kernels.kernel_definitions import kernel_definitions - try: - import xcoll as xc - BEAM_ELEMENTS_INIT_DEFAULTS['EverestBlock'] = { - 'material': xc.materials.Silicon, - 'use_prebuilt_kernels': False - } - BEAM_ELEMENTS_INIT_DEFAULTS['EverestCollimator'] = { - 'material': xc.materials.Silicon, - 'use_prebuilt_kernels': False - } - BEAM_ELEMENTS_INIT_DEFAULTS['EverestCrystal'] = { - 'material': xc.materials.SiliconCrystal, - 'use_prebuilt_kernels': False - } - except ImportError: - pass - - for module_name, metadata in kernel_definitions: - if kernels is not None and module_name not in kernels: - continue - - config = metadata['config'] - element_classes = metadata['classes'] - extra_classes = metadata.get('extra_classes', []) - - elements = [] - for cls in element_classes: - if cls.__name__ in BEAM_ELEMENTS_INIT_DEFAULTS: - element = cls(**BEAM_ELEMENTS_INIT_DEFAULTS[cls.__name__]) - else: - element = cls() - elements.append(element) - - line = xt.Line(elements=elements) - tracker = xt.Tracker(line=line, compile=False, _prebuilding_kernels=True) - tracker.config.clear() - tracker.config.update(config) - - # Get all kernels in the elements - extra_kernels = {} - extra_classes.append(xp.Particles) - extra_classes = [getattr(el, '_XoStruct', el) for el in extra_classes] - all_classes = tracker._tracker_data_base.kernel_element_classes + extra_classes - for el in all_classes: - extra_kernels.update(el._kernels) - - # TODO: Add any other kernels that are defined in the context - # Need to add the source etc - # kernel_descriptions.update(tracker._context.kernels) - - tracker._build_kernel( - module_name=module_name, - containing_dir=XT_PREBUILT_KERNELS_LOCATION, - compile='force', - extra_classes=extra_classes, - extra_kernels=extra_kernels, - ) - - all_classes = [cls for cls in all_classes - if cls.__name__ != 'ParticlesData'] - save_kernel_metadata( - module_name=module_name, - config=tracker.config, - kernel_element_classes=all_classes, - ) - - -def clear_kernels(kernels=None, verbose=False): - if kernels is not None and ( - isinstance(kernels, str) or not hasattr(kernels, '__iter__')): - kernels = [kernels] - for file in XT_PREBUILT_KERNELS_LOCATION.iterdir(): - if file.name.startswith('_'): - continue - if file.suffix not in ('.c', '.so', '.json'): - continue - if kernels is not None and file.stem.split('.')[0] not in kernels: - continue - file.unlink() - - if verbose: - print(f'Removed `{file}`.') - - -if __name__ == '__main__': - regenerate_kernels() - diff --git a/xtrack/prebuilt_kernels/__init__.py b/xtrack/prebuilt_kernels/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/xtrack/prebuilt_kernels/kernel_definitions.py b/xtrack/prebuilt_kernels/kernel_definitions.py deleted file mode 100644 index 3441741fe..000000000 --- a/xtrack/prebuilt_kernels/kernel_definitions.py +++ /dev/null @@ -1,238 +0,0 @@ -# copyright ################################# # -# This file is part of the Xobjects Package. # -# Copyright (c) CERN, 2023. # -# ########################################### # -import logging - -from xtrack.beam_elements import * -from xtrack.random import * -from xtrack.multisetter import MultiSetter - -LOGGER = logging.getLogger(__name__) - -BASE_CONFIG = { - 'XTRACK_MULTIPOLE_NO_SYNRAD': True, - 'XFIELDS_BB3D_NO_BEAMSTR': True, - 'XFIELDS_BB3D_NO_BHABHA': True, - 'XTRACK_GLOBAL_XY_LIMIT': 1.0, -} - -FREEZE_ENERGY = { - 'FREEZE_VAR_delta': True, - 'FREEZE_VAR_ptau': True, - 'FREEZE_VAR_rpp': True, - 'FREEZE_VAR_rvv': True, -} - -FREEZE_LONGITUDINAL = { - **FREEZE_ENERGY, - 'FREEZE_VAR_zeta': True, -} - -ONLY_XTRACK_ELEMENTS = [ - Drift, - Multipole, - Marker, - ReferenceEnergyIncrease, - Cavity, - XYShift, - ZetaShift, - Elens, - Wire, - SRotation, - YRotation, - Solenoid, - RFMultipole, - DipoleEdge, - SimpleThinBend, - SimpleThinQuadrupole, - LineSegmentMap, - NonLinearLens, - LimitEllipse, - LimitRectEllipse, - LimitRect, - LimitRacetrack, - LimitPolygon, - DriftSlice, - DriftSliceBend, - DriftSliceOctupole, - DriftSliceQuadrupole, - DriftSliceSextupole, - ThickSliceBend, - ThickSliceOctupole, - ThickSliceQuadrupole, - ThickSliceSextupole, - ThickSliceSolenoid, - ThinSliceBend, - ThinSliceBendEntry, - ThinSliceBendExit, - ThinSliceOctupole, - ThinSliceQuadrupole, - ThinSliceSextupole, -] - -NO_SYNRAD_ELEMENTS = [ - Bend, - Quadrupole, - Sextupole, - Octupole, -] - -NON_TRACKING_ELEMENTS = [ - RandomUniform, - RandomExponential, - RandomNormal, - RandomRutherford, - MultiSetter -] - -# These are enumerated in order specified below: the highest priority at the top -kernel_definitions = [ - ('default_only_xtrack_no_config', { - 'config': {}, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('default_only_xtrack', { - 'config': BASE_CONFIG, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('default_only_xtrack_no_limit', { - 'config': { - **{k: v for k, v in BASE_CONFIG.items() - if k != 'XTRACK_GLOBAL_XY_LIMIT'} - }, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('only_xtrack_non_tracking_kernels', { - 'config': BASE_CONFIG, - 'classes': [], - 'extra_classes': NON_TRACKING_ELEMENTS - }), - ('default_only_xtrack_backtrack', { - 'config': {**BASE_CONFIG, 'XSUITE_BACKTRACK': True}, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('default_only_xtrack_backtrack_no_limit', { - 'config': { - **{k: v for k, v in BASE_CONFIG.items() - if k != 'XTRACK_GLOBAL_XY_LIMIT'}, - 'XSUITE_BACKTRACK': True - }, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('only_xtrack_frozen_longitudinal', { - 'config': {**BASE_CONFIG, **FREEZE_LONGITUDINAL}, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('only_xtrack_frozen_energy', { - 'config': {**BASE_CONFIG, **FREEZE_ENERGY}, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('only_xtrack_backtrack_frozen_energy', { - 'config': {**BASE_CONFIG, **FREEZE_ENERGY, 'XSUITE_BACKTRACK': True}, - 'classes': ONLY_XTRACK_ELEMENTS + NO_SYNRAD_ELEMENTS, - }), - ('only_xtrack_taper', { - 'config': { - **BASE_CONFIG, - 'XTRACK_MULTIPOLE_NO_SYNRAD': False, - 'XTRACK_MULTIPOLE_TAPER': True, - 'XTRACK_DIPOLEEDGE_TAPER': True, - }, - 'classes': ONLY_XTRACK_ELEMENTS, - }), - ('only_xtrack_with_synrad', { - 'config': {**BASE_CONFIG, 'XTRACK_MULTIPOLE_NO_SYNRAD': False}, - 'classes': ONLY_XTRACK_ELEMENTS, - }), - ('only_xtrack_with_synrad_kick_as_co', { - 'config': { - **BASE_CONFIG, 'XTRACK_MULTIPOLE_NO_SYNRAD': False, - 'XTRACK_SYNRAD_KICK_SAME_AS_FIRST': True - }, - 'classes': ONLY_XTRACK_ELEMENTS, - }), -] - - -try: - import xfields as xf - - DEFAULT_BB3D_ELEMENTS = [ - *ONLY_XTRACK_ELEMENTS, - xf.BeamBeamBiGaussian2D, - xf.BeamBeamBiGaussian3D, - ] - - kernel_definitions.append(('default_bb3d', { - 'config': BASE_CONFIG, - 'classes': [*DEFAULT_BB3D_ELEMENTS, LineSegmentMap], - })) - - kernel_definitions.append(('default_bb3d_no_config', { - 'config': {}, - 'classes': [*DEFAULT_BB3D_ELEMENTS, LineSegmentMap], - })) - -except ImportError: - LOGGER.warning('Xfields not installed, skipping BB3D elements') - - -try: - import xcoll as xc - - DEFAULT_XCOLL_ELEMENTS = [ - *ONLY_XTRACK_ELEMENTS, - *NO_SYNRAD_ELEMENTS, - ZetaShift, - xc.BlackAbsorber, - xc.EverestBlock, - xc.EverestCollimator, - xc.EverestCrystal - ] - - kernel_definitions += [ - ('default_xcoll', { - 'config': BASE_CONFIG, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_no_config', { - 'config': {}, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_no_limit', { - 'config': { - **{k: v for k, v in BASE_CONFIG.items() - if k != 'XTRACK_GLOBAL_XY_LIMIT'} - }, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_frozen_longitudinal', { - 'config': {**BASE_CONFIG, **FREEZE_LONGITUDINAL}, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_frozen_energy', { - 'config': {**BASE_CONFIG, **FREEZE_ENERGY}, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_backtrack', { - 'config': {**BASE_CONFIG, 'XSUITE_BACKTRACK': True}, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_backtrack_no_limit', { - 'config': { - **{k: v for k, v in BASE_CONFIG.items() - if k != 'XTRACK_GLOBAL_XY_LIMIT'}, - 'XSUITE_BACKTRACK': True - }, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ('default_xcoll_backtrack_frozen_energy', { - 'config': {**BASE_CONFIG, **FREEZE_ENERGY, 'XSUITE_BACKTRACK': True}, - 'classes': DEFAULT_XCOLL_ELEMENTS, - }), - ] - -except ImportError: - LOGGER.warning('Xcoll not installed, skipping collimator elements') - diff --git a/xtrack/random/__init__.py b/xtrack/random/__init__.py index 14d96855b..d16bbd8ed 100644 --- a/xtrack/random/__init__.py +++ b/xtrack/random/__init__.py @@ -1 +1,3 @@ from .random_generators import RandomUniform, RandomExponential, RandomNormal, RandomRutherford + +rng_classes = (RandomUniform, RandomExponential, RandomNormal, RandomRutherford) diff --git a/xtrack/random/random_generators.py b/xtrack/random/random_generators.py index ac71a9bc6..c62e6947d 100644 --- a/xtrack/random/random_generators.py +++ b/xtrack/random/random_generators.py @@ -181,7 +181,7 @@ def __init__(self, **kwargs): raise ValueError('Rutherford random generator is not currently supported on GPU.') def set_parameters(self, A, B, lower_val, upper_val): - self.compile_kernels(particles_class=xp.Particles, only_if_needed=True) + self.compile_kernels(particles_class=xt.Particles, only_if_needed=True) context = self._buffer.context context.kernels.set_rutherford(rng=self, A=A, B=B, lower_val=lower_val, upper_val=upper_val) diff --git a/xtrack/tracker.py b/xtrack/tracker.py index c0002d176..be7db0955 100644 --- a/xtrack/tracker.py +++ b/xtrack/tracker.py @@ -23,7 +23,6 @@ from .pipeline import PipelineStatus from .progress_indicator import progress from .tracker_data import TrackerData -from .prebuild_kernels import get_suitable_kernel, XT_PREBUILT_KERNELS_LOCATION logger = logging.getLogger(__name__) @@ -429,15 +428,24 @@ def _build_kernel( ): if compile == 'force': use_prebuilt_kernels = False - elif not self._context.allow_prebuilt_kernels: # only CPU serial + elif not self._context.allow_prebuilt_kernels: # only CPU serial use_prebuilt_kernels = False else: use_prebuilt_kernels = self.use_prebuilt_kernels if use_prebuilt_kernels: - kernel_info = get_suitable_kernel( - self.config, self.line_element_classes - ) + try: + from xsuite import ( + get_suitable_kernel, + XSK_PREBUILT_KERNELS_LOCATION, + ) + except ImportError: + kernel_info = None + else: + kernel_info = get_suitable_kernel( + self.config, self.line_element_classes + ) + if kernel_info: module_name, modules_classes = kernel_info @@ -445,7 +453,7 @@ def _build_kernel( modules_classes)['track_line'] kernels = self._context.kernels_from_file( module_name=module_name, - containing_dir=XT_PREBUILT_KERNELS_LOCATION, + containing_dir=XSK_PREBUILT_KERNELS_LOCATION, kernel_descriptions={'track_line': kernel_description}, ) return kernels['track_line'] @@ -708,6 +716,9 @@ def get_kernel_descriptions(self, kernel_element_classes): # Random number generator init kernel kernel_descriptions.update(xt.Particles._kernels) + # Multisetter + kernel_descriptions.update(xt.MultiSetter._kernels) + return kernel_descriptions def _prepare_collective_track_session(self, particles, ele_start, ele_stop, @@ -1458,6 +1469,7 @@ def __getstate__(self): return state def check_compatibility_with_prebuilt_kernels(self): + from xsuite import get_suitable_kernel get_suitable_kernel( config=self.line.config, line_element_classes=self.line_element_classes,