Skip to content

Commit

Permalink
add model.save test (#147)
Browse files Browse the repository at this point in the history
* add model.save test

* add model.load test

* fixes

* fixes
  • Loading branch information
aniketmaurya authored Dec 29, 2021
1 parent ed9e5c5 commit ef5b1ed
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 19 deletions.
4 changes: 1 addition & 3 deletions gradsflow/callbacks/raytune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
from typing import Optional

import torch
from ray import tune

from gradsflow.core.callbacks import Callback
Expand All @@ -40,11 +39,10 @@ class TorchTuneCheckpointCallback(Callback):

def on_epoch_end(self):
epoch = self.model.tracker.current_epoch
model = self.model.learner

with tune.checkpoint_dir(epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "filename")
torch.save((model.state_dict()), path)
self.model.save(path)


class TorchTuneReport(Callback):
Expand Down
17 changes: 12 additions & 5 deletions gradsflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import smart_open
import torch
from accelerate import Accelerator
from torch import nn
Expand Down Expand Up @@ -75,9 +76,6 @@ def assert_compiled(self):
if not self._compiled:
raise UserWarning("Model not compiled yet! Please call `model.compile(...)` first.")

def load_from_checkpoint(self, checkpoint):
self.learner = torch.load(checkpoint)

@torch.no_grad()
def predict(self, x):
return self.learner(x)
Expand Down Expand Up @@ -172,9 +170,18 @@ def train(self):
self.learner.requires_grad_(True)
self.learner.train()

def save(self, path: str, save_extra: bool = True):
def load_from_checkpoint(self, checkpoint):
data = torch.load(checkpoint)
if isinstance(data, dict):
self.learner = data["model"]
self.tracker = data["tracker"]
else:
self.learner = data

def save(self, path: str, save_extra: bool = False):
"""save model"""
model = self.learner
if save_extra:
model = {"model": self.learner, "tracker": self.tracker}
torch.save(model, path)
with smart_open.open(path, "wb") as f:
torch.save(model, f)
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Arrange
import pytest
import timm

from gradsflow import Model


@pytest.fixture
def resnet18():
cnn = timm.create_model("ssl_resnet18", pretrained=False, num_classes=10).eval()

return cnn


@pytest.fixture
def cnn_model(resnet18):
model = Model(resnet18)
model.TEST = True

return model
41 changes: 30 additions & 11 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,26 @@
model.TEST = True


def test_predict():
def test_predict(cnn_model):
x = torch.randn(1, 3, 64, 64)
r1 = model.forward(x)
r2 = model(x)
r3 = model.predict(x)
r1 = cnn_model.forward(x)
r2 = cnn_model(x)
r3 = cnn_model.predict(x)
assert torch.all(torch.isclose(r1, r2))
assert torch.all(torch.isclose(r2, r3))
assert isinstance(model.predict(torch.randn(1, 3, 64, 64)), torch.Tensor)


def test_fit():
model.TEST = True
def test_fit(cnn_model):
cnn_model.compile()
assert autodataset
tracker = model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
tracker = cnn_model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
assert isinstance(tracker, Tracker)

autodataset2 = AutoDataset(train_data.dataloader, num_classes=num_classes)
model.TEST = False
cnn_model.TEST = False
ckpt_cb = ModelCheckpoint(save_extra=False)
tracker2 = model.fit(
tracker2 = cnn_model.fit(
autodataset2,
max_epochs=1,
steps_per_epoch=1,
Expand Down Expand Up @@ -84,7 +84,26 @@ def compute_accuracy(*_, **__):
assert model2.optimizer.param_groups[0]["lr"] == 0.01


def test_set_accelerator():
model2 = Model(cnn, accelerator_config={"fp16": True})
def test_set_accelerator(resnet18):
model2 = Model(resnet18, accelerator_config={"fp16": True})
model2.compile()
assert model2.accelerator


def test_save_model(tmp_path, resnet18, cnn_model):
path = f"{tmp_path}/dummy_model.pth"
cnn_model.save(path, save_extra=True)
assert isinstance(torch.load(path), dict)

cnn_model.save(path, save_extra=False)
assert isinstance(torch.load(path), type(resnet18))


def test_load_from_checkpoint(tmp_path, cnn_model):
path = f"{tmp_path}/dummy_model.pth"
cnn_model.save(path, save_extra=True)
assert isinstance(torch.load(path), dict)

cnn_model.tracker.train.metrics["CHECK"] = True
cnn_model.load_from_checkpoint(path)
assert cnn_model.tracker.train.metrics["CHECK"]

0 comments on commit ef5b1ed

Please sign in to comment.