Skip to content

Commit

Permalink
Update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Oct 19, 2023
1 parent 2e7c650 commit 856aebe
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
21 changes: 16 additions & 5 deletions kgcnn/utils/devices.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from keras_core.backend import backend
from typing import Union
import logging

logging.basicConfig() # Module logger
Expand Down Expand Up @@ -54,17 +55,27 @@ def check_device():
return out_info


def set_gpu_device(device_id: int):
"""Set the cuda device by ID. Better use cuda environment variable to do this.
def set_cuda_device(device_id: Union[int, list]):
"""Set the cuda device by ID.
Better use cuda environment variable to do this instead of this function:
.. code-block:: python
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1" # specify which GPU(s) to be used
Args:
device_id (int): ID of the GPU to set.
"""

if backend() == "tensorflow":
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
gpus_use = gpus[device_id]
if isinstance(device_id, int):
gpus_use = gpus[device_id]
else:
gpus_use = [gpus[i] for i in device_id]
tf.config.set_visible_devices(gpus_use, 'GPU')
tf.config.experimental.set_memory_growth(gpus_use, True)

Expand All @@ -74,7 +85,7 @@ def set_gpu_device(device_id: int):

elif backend() == "jax":
import jax
raise NotImplementedError()
jax.default_device = jax.devices('gpu')[device_id]

else:
raise NotImplementedError("Backend %s is not supported for `check_device_cuda` .")
4 changes: 2 additions & 2 deletions training/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import kgcnn.training.scheduler
from kgcnn.data.utils import save_pickle_file
from kgcnn.data.transform.scaler.serial import deserialize as deserialize_scaler
from kgcnn.utils.devices import check_device, set_gpu_device
from kgcnn.utils.devices import check_device, set_cuda_device
from kgcnn.training.history import save_history_score, load_history_list, load_time_list
from kgcnn.utils.plots import plot_train_test_loss, plot_predict_true
from kgcnn.models.serial import deserialize as deserialize_model
Expand All @@ -35,7 +35,7 @@

# Check and set device
if args["gpu"] is not None:
set_gpu_device(args["gpu"])
set_cuda_device(args["gpu"])
check_device()

# Set seed.
Expand Down

0 comments on commit 856aebe

Please sign in to comment.