Skip to content

Commit

Permalink
Merge pull request #80 from grok-ai/develop
Browse files Browse the repository at this point in the history
Version 0.2.3
  • Loading branch information
Flegyas authored Dec 15, 2022
2 parents 32037ae + 2c88a79 commit 02e3d59
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
"repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"python_version": "3.9",
"__version": "0.2.2"
"__version": "0.2.3"
}
8 changes: 6 additions & 2 deletions {{ cookiecutter.repository_name }}/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ name: {{ cookiecutter.conda_env_name }}
channels:
- defaults
- pytorch
- nvidia

dependencies:
- python={{ cookiecutter.python_version }}
- pytorch=1.10.*
- torchvision=0.11.*
- pytorch==1.13.*
- pytorch-cuda=11.6
- torchvision
- torchaudio
- pip
- pip:
- -e .[dev]
7 changes: 4 additions & 3 deletions {{ cookiecutter.repository_name }}/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ package_dir=
=src
packages=find:
install_requires =
nn-template-core>=0.1.0,<0.2
nn-template-core==0.2.*

# Add project specific dependencies
# Stuff easy to break with updates
pytorch-lightning>=1.5.8,<1.6
hydra-core
pytorch-lightning==1.7.*
torchmetrics==0.10.*
hydra-core==1.2.*
wandb
streamlit
# hydra-joblib-launcher
Expand Down
3 changes: 2 additions & 1 deletion {{ cookiecutter.repository_name }}/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule
from pytorch_lightning.core.saving import _load_state

from nn_core.serialization import NNCheckpointIO
from tests.conftest import load_checkpoint
Expand All @@ -24,7 +25,7 @@ def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig

checkpoint = NNCheckpointIO.load(path=checkpoint_path)

module = module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint["metadata"])
module = _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint["metadata"])
assert module is not None
assert sum(p.numel() for p in module.parameters())

Expand Down
1 change: 0 additions & 1 deletion {{ cookiecutter.repository_name }}/tests/test_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,3 @@ def test_resume(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig, tmp_pat

assert old_checkpoint["run_path"] != new_checkpoint["run_path"]
assert old_checkpoint["global_step"] * 2 == new_checkpoint["global_step"]
assert new_checkpoint["epoch"] == 2

0 comments on commit 02e3d59

Please sign in to comment.