From 856aebe485281182485dfdde3e30127c7b539b44 Mon Sep 17 00:00:00 2001 From: patrickreiser Date: Thu, 19 Oct 2023 16:09:29 +0200 Subject: [PATCH] Update for keras 3.0 --- kgcnn/utils/devices.py | 21 ++++++++++++++++----- training/train_force.py | 4 ++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/kgcnn/utils/devices.py b/kgcnn/utils/devices.py index 74dedd03..156b859a 100644 --- a/kgcnn/utils/devices.py +++ b/kgcnn/utils/devices.py @@ -1,4 +1,5 @@ from keras_core.backend import backend +from typing import Union import logging logging.basicConfig() # Module logger @@ -54,17 +55,27 @@ def check_device(): return out_info -def set_gpu_device(device_id: int): - """Set the cuda device by ID. Better use cuda environment variable to do this. +def set_cuda_device(device_id: Union[int, list]): + """Set the cuda device by ID. + + Better use cuda environment variable to do this instead of this function: + + .. code-block:: python + + import os + os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"]="1" # specify which GPU(s) to be used Args: device_id (int): ID of the GPU to set. """ - if backend() == "tensorflow": import tensorflow as tf gpus = tf.config.list_physical_devices('GPU') - gpus_use = gpus[device_id] + if isinstance(device_id, int): + gpus_use = gpus[device_id] + else: + gpus_use = [gpus[i] for i in device_id] tf.config.set_visible_devices(gpus_use, 'GPU') tf.config.experimental.set_memory_growth(gpus_use, True) @@ -74,7 +85,7 @@ def set_gpu_device(device_id: int): elif backend() == "jax": import jax - raise NotImplementedError() + jax.default_device = jax.devices('gpu')[device_id] else: raise NotImplementedError("Backend %s is not supported for `check_device_cuda` .") \ No newline at end of file diff --git a/training/train_force.py b/training/train_force.py index 717c3b38..eb066f77 100644 --- a/training/train_force.py +++ b/training/train_force.py @@ -8,7 +8,7 @@ import kgcnn.training.scheduler from kgcnn.data.utils import save_pickle_file from kgcnn.data.transform.scaler.serial import deserialize as deserialize_scaler -from kgcnn.utils.devices import check_device, set_gpu_device +from kgcnn.utils.devices import check_device, set_cuda_device from kgcnn.training.history import save_history_score, load_history_list, load_time_list from kgcnn.utils.plots import plot_train_test_loss, plot_predict_true from kgcnn.models.serial import deserialize as deserialize_model @@ -35,7 +35,7 @@ # Check and set device if args["gpu"] is not None: - set_gpu_device(args["gpu"]) + set_cuda_device(args["gpu"]) check_device() # Set seed.