
♻️ Update model_to() and load_torch_model() methods in ModelABC #733

Merged
merged 14 commits into develop from dev-update-model-abc
Nov 15, 2023

Conversation

AbishekRajVG
Contributor

This PR implements the changes suggested by @mostafajahanifar in PR #635.

Suggestion 1

To follow the usual convention of moving the model first and then applying DataParallel, I would suggest improving this function as below:

import torch


def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
    """Transfer the model to CPU/GPU.

    Args:
        model (torch.nn.Module):
            PyTorch-defined model.
        device (str):
            Transfers the model to the specified device. Default is "cpu".

    Returns:
        torch.nn.Module:
            The model after being moved to CPU/GPU.

    """
    device = torch.device(device)
    model = model.to(device)

    # If the target device is CUDA and more than one GPU is available,
    # wrap the model in DataParallel.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    return model

This will also avoid the unnecessary overhead of DataParallel if there is only one GPU available.
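For example (a minimal usage sketch with a placeholder model):

import torch

# Placeholder model; in practice this would be a TIAToolbox model.
model = torch.nn.Linear(16, 4)

# The model is wrapped in DataParallel only if CUDA is targeted and several
# GPUs are present; on a CPU or single-GPU machine it is returned unwrapped.
model = model_to(model, device="cuda" if torch.cuda.is_available() else "cpu")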

Again, this can be integrated into the ModelABC class as a method. ModelABC should already have the to method inherited from nn.Module; however, if we need torch.nn.DataParallel, we can override that to method with this one. Then users can call: my_model.to(device)
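A minimal sketch of how such an overridden to method could look (the class body here is illustrative only, not the implementation merged in this PR):

import torch


class ModelABC(torch.nn.Module):
    """Simplified stand-in for the toolbox's model base class."""

    def to(self, device: str = "cpu") -> torch.nn.Module:
        """Move the model to the given device, adding DataParallel when useful."""
        torch_device = torch.device(device)
        model = super().to(torch_device)

        # Wrap in DataParallel only for CUDA targets with more than one GPU,
        # so CPU and single-GPU runs skip the wrapper entirely.
        if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)

        return model

With this in place, my_model = my_model.to("cuda") would give multi-GPU support transparently; note that the return value must be reassigned, because DataParallel returns a wrapper rather than modifying the model in place.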

Suggestion 2

Why not move this function into the ModelABC class as a method, so users can load model weights for our models just like they do with normal PyTorch models? Is something like the below possible:

my_model.load_weights_from_path(path) or my_model.load(path)

I assume that because ModelABC inherits from nn.Module, it should already have the load_state_dict method.
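A minimal sketch of such a method, assuming the weights file is a plain state dict saved with torch.save; the name load_weights_from_path and the toy model are illustrative only (the changelog below refers to the merged method as load_model_from_file):

import torch


class ToyModel(torch.nn.Module):
    """Toy model standing in for a ModelABC subclass."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def load_weights_from_path(self, path: str) -> None:
        """Load a saved state dict into this model."""
        # map_location="cpu" keeps loading device-agnostic; the model can be
        # moved to a GPU afterwards with .to(device).
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)


model = ToyModel()
torch.save(model.state_dict(), "toy_weights.pth")  # create a checkpoint to load
model.load_weights_from_path("toy_weights.pth")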


codecov bot commented Nov 8, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (4a041ae) 99.85% compared to head (130ade2) 99.85%.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop     #733   +/-   ##
========================================
  Coverage    99.85%   99.85%           
========================================
  Files           65       65           
  Lines         7508     7517    +9     
  Branches      1460     1460           
========================================
+ Hits          7497     7506    +9     
  Misses           4        4           
  Partials         7        7           


@shaneahmed shaneahmed added this to the Release v2.0.0 milestone Nov 10, 2023
@shaneahmed shaneahmed added the enhancement New feature or request label Nov 10, 2023
@shaneahmed shaneahmed changed the title Update model_to() and load_torch_model() methods in ModelABC in develop ♻️ Update model_to() and load_torch_model() methods in ModelABC Nov 10, 2023
@shaneahmed shaneahmed merged commit cca7443 into develop Nov 15, 2023
14 checks passed
@shaneahmed shaneahmed deleted the dev-update-model-abc branch November 15, 2023 12:17
@shaneahmed shaneahmed mentioned this pull request Dec 15, 2023
shaneahmed added a commit that referenced this pull request Dec 15, 2023
## 1.5.0 (2023-12-15)

### Major Updates and Feature Improvements

- Adds the bokeh visualization tool. #684
  - The tool allows a user to launch a server on their machine to visualise whole slide images, overlay the results of deep learning algorithms, or select a patch from a whole slide image and run TIAToolbox deep learning engines.
  - This tool powers the TIA demos server. For details please see https://tiademos.dcs.warwick.ac.uk/.
- Extends Annotation to Support Init from WKB #639
- Adds `IOConfig` for NuClick in `pretrained_model.yaml` #709
- Adds functions to save the TIAToolbox Engine outputs to Zarr and AnnotationStore files. #724
- Adds Support for QuPath Annotation Imports #721

### Changes to API

- Adds `model.to(device)` and `model.load_model_from_file()` functionality to make it compatible with the PyTorch API. #733
- Replaces `pretrained` with `weights` to make the engines compatible with the new PyTorch API. #621
- Adds support for high-level imports for various utility functions and classes such as `WSIReader`, `PatchPredictor` and `imread` #606, #607
- Adds `tiatoolbox.typing` for type hints. #619
- Fixes incorrect file sizes saved by `save_tiles`, an issue with certain WSIs raised by @TomastpPereira
- TissueMasker transform now returns a mask instead of a list. #748
  - Fixes #732

### Bug Fixes and Other Changes

- Fixes `pixman` incompatibility error on Colab #601
- Removes `shapely.speedups`. The module no longer has any effect in Shapely >=2.0. #622
- Fixes errors in the slidegraph example notebook #608
- Fixes bugs in WSI Registration #645, #670, #693
- Fixes the situation where PatchExtractor.get_coords() can return patch coords which lie fully outside the bounds of a slide. #712
  - Fixes #710
- Fixes #738 raised by @xiachenrui

### Development related changes

- Replaces `flake8` and `isort` with `ruff` #625, #666
- Adds `mypy` checks to `root` and `utils` package. This will be rolled out in phases to other modules. #723
- Adds a module to detect file types using magic number/signatures #616
- Uses `poetry` for version updates instead of `bump2version`. #638
- Removes `setup.cfg` and uses `pyproject.toml` for project configurations.
- Reduces runtime for some unit tests, e.g., #627, #630, #631, #629
- Reuses models and datasets in tests on GitHub actions by utilising cache #641, #644
- Sets up parallel tests locally #671

**Full Changelog:** v1.4.0...v1.5.0

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mostafajahanifar <74412979+mostafajahanifar@users.noreply.github.com>
Co-authored-by: John Pocock <John-P@users.noreply.github.com>
Co-authored-by: DavidBAEpstein <David.Epstein@warwick.ac.uk>
Co-authored-by: David Epstein <22086916+DavidBAEpstein@users.noreply.github.com>
Co-authored-by: Ruqayya Awan <18444369+ruqayya@users.noreply.github.com>
Co-authored-by: Mark Eastwood <20169086+measty@users.noreply.github.com>
Co-authored-by: adamshephard <39619155+adamshephard@users.noreply.github.com>
Co-authored-by: adamshephard <adam.shephard@warwick.ac.uk>
Co-authored-by: Abdol <a@fkrtech.com>
Co-authored-by: Jiaqi-Lv <60471431+Jiaqi-Lv@users.noreply.github.com>
Co-authored-by: Abishek <abishekraj6797@gmail.com>
Co-authored-by: Dmitrii Blaginin <blaginin@mbp.lan>