diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index a4f113c367f..d67c221b766 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -87,7 +87,8 @@ case $GPU_ARCH_TYPE in ;; esac PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}" -pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" +# TODO: remove pinning of mpmath when https://github.com/pytorch/vision/issues/8292 is properly fixed. +pip install --progress-bar=off "mpmath<1.4" --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then python -c "import torch; exit(not torch.cuda.is_available())" diff --git a/.github/workflows/build-conda-macos.yml b/.github/workflows/build-conda-macos.yml deleted file mode 100644 index 6f4929e27e3..00000000000 --- a/.github/workflows/build-conda-macos.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: Build Macos Conda - -on: - pull_request: - push: - branches: - - nightly - - main - - release/* - tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - workflow_dispatch: - -jobs: - generate-matrix: - uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main - with: - package-type: conda - os: macos - test-infra-repository: pytorch/test-infra - test-infra-ref: main - build: - needs: generate-matrix - strategy: - fail-fast: false - matrix: - include: - - repository: pytorch/vision - pre-script: "" - post-script: "" - conda-package-directory: packaging/torchvision - smoke-test-script: test/smoke_test.py - package-name: torchvision - name: ${{ matrix.repository }} - uses: pytorch/test-infra/.github/workflows/build_conda_macos.yml@main - with: - conda-package-directory: ${{ matrix.conda-package-directory }} - repository: ${{ matrix.repository }} - ref: "" - test-infra-repository: pytorch/test-infra - 
test-infra-ref: main - build-matrix: ${{ needs.generate-matrix.outputs.matrix }} - pre-script: ${{ matrix.pre-script }} - post-script: ${{ matrix.post-script }} - package-name: ${{ matrix.package-name }} - smoke-test-script: ${{ matrix.smoke-test-script }} - runner-type: macos-12 - trigger-event: ${{ github.event_name }} - secrets: - CONDA_PYTORCHBOT_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - CONDA_PYTORCHBOT_TOKEN_TEST: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} diff --git a/.github/workflows/build-wheels-macos.yml b/.github/workflows/build-wheels-macos.yml deleted file mode 100644 index 4c3820ddf13..00000000000 --- a/.github/workflows/build-wheels-macos.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Build Macos Wheels - -on: - pull_request: - push: - branches: - - nightly - - main - - release/* - tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - workflow_dispatch: - -permissions: - id-token: write - contents: read - -jobs: - generate-matrix: - uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main - with: - package-type: wheel - os: macos - test-infra-repository: pytorch/test-infra - test-infra-ref: main - build: - needs: generate-matrix - strategy: - fail-fast: false - matrix: - include: - - repository: pytorch/vision - pre-script: packaging/pre_build_script.sh - post-script: packaging/post_build_script.sh - smoke-test-script: test/smoke_test.py - package-name: torchvision - name: ${{ matrix.repository }} - uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main - with: - repository: ${{ matrix.repository }} - ref: "" - test-infra-repository: pytorch/test-infra - test-infra-ref: main - build-matrix: ${{ needs.generate-matrix.outputs.matrix }} - pre-script: ${{ matrix.pre-script }} - post-script: ${{ matrix.post-script }} - package-name: ${{ matrix.package-name }} - runner-type: macos-12 - 
smoke-test-script: ${{ matrix.smoke-test-script }} - trigger-event: ${{ github.event_name }} diff --git a/references/classification/train.py b/references/classification/train.py index 978e7cf8acf..d52124fcf33 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -222,7 +222,7 @@ def main(args): num_classes = len(dataset.classes) mixup_cutmix = get_mixup_cutmix( - mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2 + mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2 ) if mixup_cutmix is not None: diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 3d10388c36f..5443437d29d 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -7,21 +7,21 @@ from torchvision.transforms import functional as F -def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2): +def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2): transforms_module = get_module(use_v2) mixup_cutmix = [] if mixup_alpha > 0: mixup_cutmix.append( - transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories) + transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes) if use_v2 - else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha) + else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha) ) if cutmix_alpha > 0: mixup_cutmix.append( - transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories) + transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes) if use_v2 - else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha) + else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha) ) if not mixup_cutmix: return None diff --git a/test/test_utils.py b/test/test_utils.py index 49dc553de3e..ffcad425aeb 100644 --- a/test/test_utils.py +++ 
b/test/test_utils.py @@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default(): assert_equal(result, expected) +def test_draw_keypoints_dtypes(): + image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8) + image_float = to_dtype(image_uint8, torch.float, scale=True) + + out_uint8 = utils.draw_keypoints(image_uint8, keypoints) + out_float = utils.draw_keypoints(image_float, keypoints) + + assert out_uint8.dtype == torch.uint8 + assert out_uint8 is not image_uint8 + + assert out_float.is_floating_point() + assert out_float is not image_float + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + def test_draw_keypoints_errors(): h, w = 10, 10 img = torch.full((3, 100, 100), 0, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index 79e533d4663..734cb127db1 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -336,13 +336,13 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Keypoints can be drawn for multiple instances at a time. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints @@ -363,7 +363,7 @@ def draw_keypoints( For more details, see :ref:`draw_keypoints_with_visibility`. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor with keypoints drawn. 
""" if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -371,8 +371,8 @@ def draw_keypoints( # validate image if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -397,6 +397,12 @@ def draw_keypoints( f"Got {visibility.shape = } and {keypoints.shape = }" ) + original_dtype = image.dtype + if original_dtype.is_floating_point: + from torchvision.transforms.v2.functional import to_dtype # noqa + + image = to_dtype(image, dtype=torch.uint8, scale=True) + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) @@ -428,7 +434,10 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + if original_dtype.is_floating_point: + out = to_dtype(out, dtype=original_dtype, scale=True) + return out # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization