Skip to content

Commit

Permalink
switched environment to WSL
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 23, 2023
1 parent 2bf6427 commit 3370ac2
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
XLA_FLAGS="--xla_gpu_cuda_data_dir=/usr/lib/cuda"
LD_LIBRARY_PATH="/usr/lib/cuda/lib64:"
11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
tensorflow==2.10.0
pytest~=7.4.0
numpy~=1.25.1
pillow~=10.0.0
einops~=0.6.1
tensorflow==2.12.1
pytest~=7.3.1
numpy~=1.24.3
pillow~=9.4.0
einops~=0.6.1
python-dotenv~=1.0.0
29 changes: 22 additions & 7 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
# standard imports
import os
import pytest

# third-party imports
import pytest # type: ignore
import tensorflow as tf # type: ignore


@pytest.fixture(autouse=True)
def use_cpu_only():
"""Conv2D and tf.nn.Conv2D are broken on my CUDA/cuDNN version for some reason."""
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')
yield
@pytest.fixture(params=['CPU', 'GPU'], autouse=True)
def device(request):
"""Selects a runtime device for Tensorflow, so when parametrized with both CPU & GPU, tests will be run for both."""
device_type = request.param
if device_type == 'CPU':
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')
else:
os.environ.pop('CUDA_VISIBLE_DEVICES', None)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError:
print("You should really try out this other framework called PyTorch")

yield device_type
8 changes: 6 additions & 2 deletions train_srgan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# standard imports
from dotenv import load_dotenv

# third-party imports
import tensorflow as tf # type: ignore
from tensorflow.keras.losses import Loss # type:ignore
Expand All @@ -8,6 +11,8 @@
from transforms import ImageTransform
from model import Generator, Discriminator, TruncatedVGG19

load_dotenv()

# Data parameters
data_folder = './' # folder with JSON data files
crop_size = 96 # crop size of target HR images
Expand Down Expand Up @@ -43,9 +48,8 @@
content_loss = tf.keras.losses.MeanSquaredError()
adversarial_loss = tf.keras.losses.BinaryCrossentropy()

tf.config.set_visible_devices([], 'GPU')


@tf.function
def train_step(low_res_images, high_res_images, generator, discriminator, adversarial_loss,
truncated_vgg, optimizer_d, optimizer_g, content_loss, transform) -> Loss:
"""
Expand Down
7 changes: 5 additions & 2 deletions train_srresnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# standard imports
from dotenv import load_dotenv

# third-party imports
import tensorflow as tf # type: ignore
from tensorflow.keras.losses import Loss # type:ignore
Expand All @@ -7,6 +10,8 @@
from dataset import create_dataset
from model import SuperResolutionResNet

load_dotenv()

# Data parameters
data_folder = './' # folder with JSON data files
crop_size = 96 # crop size of target HR images
Expand All @@ -31,8 +36,6 @@
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
loss_fn = tf.keras.losses.MeanSquaredError()

tf.config.set_visible_devices([], 'GPU')


@tf.function
def train_step(low_res_images, high_res_images, model, optimizer, loss_fn) -> Loss:
Expand Down

0 comments on commit 3370ac2

Please sign in to comment.