Skip to content

Commit

Permalink
unittests and bugfixes for the option trainer.weights_summary
Browse files Browse the repository at this point in the history
  • Loading branch information
MalteEbner committed Jun 17, 2021
1 parent 16a0b07 commit 498d643
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
5 changes: 5 additions & 0 deletions lightly/cli/train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ def _train_cli(cfg, is_cli_call=True):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if cfg["trainer"]["weights_summary"] == "None":
cfg["trainer"]["weights_summary"] = None

if torch.cuda.is_available():
device = 'cuda'
elif cfg['trainer'] and cfg['trainer']['gpus']:
device = 'cpu'
cfg['trainer']['gpus'] = 0
else:
device = 'cpu'

if cfg['loader']['batch_size'] < 64:
msg = 'Training a self-supervised model with a small batch size: {}! '
Expand Down
75 changes: 75 additions & 0 deletions tests/cli/test_cli_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
import re
import sys
import tempfile

import torchvision
from hydra.experimental import compose, initialize

import lightly
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient


class TestCLITrain(MockedApiWorkflowSetup):

@classmethod
def setUpClass(cls) -> None:
sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient

def setUp(self):
MockedApiWorkflowSetup.setUp(self)
self.create_fake_dataset()
with initialize(config_path="../../lightly/cli/config", job_name="test_app"):
self.cfg = compose(config_name="config", overrides=[
"token='123'",
f"input_dir={self.folder_path}",
"trainer.max_epochs=0"
])

def create_fake_dataset(self):
n_data = 5
self.dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 32))

self.folder_path = tempfile.mkdtemp()
sample_names = [f'img_{i}.jpg' for i in range(n_data)]
self.sample_names = sample_names
for sample_idx in range(n_data):
data = self.dataset[sample_idx]
path = os.path.join(self.folder_path, sample_names[sample_idx])
data[0].save(path)

def parse_cli_string(self, cli_words: str):
cli_words = cli_words.replace("lightly-train ", "")
cli_words = re.split("=| ", cli_words)
assert len(cli_words) % 2 == 0
dict_keys = cli_words[0::2]
dict_values = cli_words[1::2]
for key, value in zip(dict_keys, dict_values):
value = value.strip('\"')
value = value.strip('\'')
key_parts = key.split(".")
if len(key_parts) == 1:
self.cfg[key_parts[0]]= value
elif len(key_parts) == 2:
self.cfg[key_parts[0]][key_parts[1]] = value
else:
raise ValueError

def test_parse_cli_string(self):
cli_string = "lightly-train trainer.weights_summary=top"
self.parse_cli_string(cli_string)
assert self.cfg["trainer"]["weights_summary"] == 'top'

def test_train_weights_summary(self):
for weights_summary in ["None", "top", "full"]:
cli_string = f"lightly-train trainer.weights_summary={weights_summary}"
with self.subTest(cli_string):
self.parse_cli_string(cli_string)
lightly.cli.train_cli(self.cfg)

def tearDown(self) -> None:
for filename in ["embeddings.csv", "embeddings_sorted.csv"]:
try:
os.remove(filename)
except FileNotFoundError:
pass

0 comments on commit 498d643

Please sign in to comment.