Skip to content

Commit

Permalink
Merge pull request #54 from neptune-ai/fix/optim_name_not_found
Browse files Browse the repository at this point in the history
Don't fail if optim name not found.
  • Loading branch information
kshitij12345 authored Jul 27, 2023
2 parents 2c35a8e + de5ec04 commit 0cd7691
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## neptune-fastai 1.1.1

### Fixes
- Don't error if `optim.__name__` is not present. (https://github.com/neptune-ai/neptune-fastai/pull/54)

## neptune-fastai 1.1.0

### Changes
Expand Down
12 changes: 11 additions & 1 deletion src/neptune_fastai/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,17 @@ def _batch_size(self) -> int:

@property
def _optimizer_name(self) -> Optional[str]:
return self.opt_func.__name__
NA = "N/A"
optim_name = getattr(self.opt_func, "__name__", NA)
if optim_name == NA:
warning_msg = (
"NeptuneCallback: Couldn't retrieve the optimizer name, "
"so it will be logged as 'N/A'. You can set the optimizer "
"name by assigning it to the __name__ attribute. "
"Eg. >>> optimizer.__name__ = 'NAME'"
)
warnings.warn(warning_msg)
return optim_name

@property
def _device(self) -> str:
Expand Down
6 changes: 6 additions & 0 deletions tests/neptune_fastai/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
from itertools import islice
from pathlib import Path

Expand All @@ -24,6 +25,7 @@
untar_data,
)
from fastai.callback.all import SaveModelCallback
from fastai.optimizer import Adam
from fastai.tabular.all import (
Categorify,
FillMissing,
Expand Down Expand Up @@ -71,12 +73,15 @@ def test_vision_classification_with_handler(self):
device=torch.device("cpu"),
)

opt_func = partial(Adam, lr=3e-3, wd=0.01)

learn = cnn_learner(
dls,
squeezenet1_0,
metrics=error_rate,
cbs=[NeptuneCallback(run, "experiment")],
pretrained=False,
opt_func=opt_func,
)

learn.fit(1)
Expand All @@ -91,6 +96,7 @@ def test_vision_classification_with_handler(self):
exp_config = run["experiment/config"].fetch()
assert exp_config["batch_size"] == 64
assert exp_config["criterion"] == "CrossEntropyLoss()"
assert exp_config["optimizer"]["name"] == "N/A"
assert exp_config["input_shape"] == {"x": "[3, 224, 224]", "y": 1}

# and
Expand Down

0 comments on commit 0cd7691

Please sign in to comment.