
can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first #2106

Closed
WestbrookZero opened this issue Feb 1, 2021 · 6 comments
Labels
bug Something isn't working

Comments

@WestbrookZero

When I start training, the following error occurs:
[screenshot of the error]
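The error message itself names the fix: `.numpy()` only works on tensors in host (CPU) memory, so a tensor on `cuda:0` has to be copied back with `.cpu()` first. A minimal sketch of a defensive conversion helper, assuming the input is a `torch.Tensor` on any device, a NumPy array, `None`, or a list of those (such as a per-image list of detections). `to_numpy` is a hypothetical name, not a YOLOv5 or PyTorch function; it duck-types on the `detach()`, `cpu()` and `numpy()` methods of `torch.Tensor` so that it does not need to import torch:

```python
import numpy as np


def to_numpy(x):
    """Convert a tensor (on any device), an array, or a list of them to NumPy.

    Hypothetical helper: it duck-types on the detach()/cpu()/numpy() methods
    of torch.Tensor instead of importing torch, so arrays pass through as-is.
    """
    if x is None or isinstance(x, np.ndarray):
        return x
    if isinstance(x, (list, tuple)):
        # e.g. a per-image list of detections: convert each element separately
        return [to_numpy(item) for item in x]
    if all(hasattr(x, name) for name in ("detach", "cpu", "numpy")):
        # detach() drops autograd history, cpu() copies CUDA memory to the host,
        # and numpy() then exposes that host buffer as an ndarray
        return x.detach().cpu().numpy()
    return np.asarray(x)
```

Calling `detach()` before `numpy()` also avoids the separate error PyTorch raises for tensors that still require grad.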

WestbrookZero added the bug (Something isn't working) label on Feb 1, 2021
@github-actions
Contributor

github-actions bot commented Feb 1, 2021

👋 Hello @WestbrookZero, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install, run:

$ pip install -r requirements.txt
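To check the requirements above from Python before starting a long training run, a minimal sketch follows. `version_tuple` and `meets_requirements` are hypothetical helpers, not part of YOLOv5, and the thresholds simply mirror the Python 3.8 / torch>=1.7 minimums stated above. It reads the torch version from package metadata via `importlib.metadata` (standard library since Python 3.8) rather than importing torch:

```python
import itertools
import sys
from importlib import metadata


def version_tuple(version, parts=2):
    """Turn a version string such as '1.7.1+cu110' into a comparable tuple like (1, 7)."""
    public = version.split("+")[0]  # drop a local build tag such as '+cu110'
    numbers = []
    for piece in public.split(".")[:parts]:
        # keep only the leading digits, so '0rc1' counts as 0
        leading = "".join(itertools.takewhile(str.isdigit, piece))
        numbers.append(int(leading) if leading else 0)
    while len(numbers) < parts:
        numbers.append(0)  # pad short versions such as '2' to (2, 0)
    return tuple(numbers)


def meets_requirements(min_python=(3, 8), min_torch=(1, 7)):
    """Return True if this interpreter and the installed torch meet the minimums."""
    if sys.version_info[:2] < min_python:
        return False
    try:
        torch_version = metadata.version("torch")  # reads package metadata, no import
    except metadata.PackageNotFoundError:
        return False  # torch is not installed at all
    return version_tuple(torch_version) >= min_torch


if __name__ == "__main__":
    print("requirements met:", meets_requirements())
```

This only checks versions; whether the installed build can actually use the GPU can be confirmed separately with `torch.cuda.is_available()`.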

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

@hehe03

hehe03 commented May 20, 2021

Hello @WestbrookZero, I've run into the same problem and was wondering how you resolved it. Thanks!

@oukohou

oukohou commented May 21, 2021

+1 same problem, what's the workaround?

@glenn-jocher
Member

glenn-jocher commented May 21, 2021

@hehe91 @oukohou If you believe you have a reproducible issue, we suggest you close this issue and raise a new one using the 🐛 Bug Report template, providing screenshots and a minimum reproducible example to help us better understand and diagnose your problem. Thank you!

@oukohou

oukohou commented May 22, 2021

@glenn-jocher Well, no bother: it's a small oversight in older versions of the code, and simply updating the repo fixes it.
In case someone is still using the old code, a small modification fixes it:
in the output_to_target function in utils/general.py, just add one more type check:

import numpy as np
import torch


def output_to_target(output, width, height):
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
    if isinstance(output, torch.Tensor):
        output = output.cpu().numpy()

    targets = []
    for i, o in enumerate(output):
        if o is not None:
            # output can also be a list of tensors (one per image), so check the type
            # again for each element and copy it to host memory; this fixes the error
            if isinstance(o, torch.Tensor):
                o = o.cpu().numpy()
            for pred in o:
                box = pred[:4]  # x1, y1, x2, y2 in pixels
                # convert corners to a normalized center x, center y, width, height
                w = (box[2] - box[0]) / width
                h = (box[3] - box[1]) / height
                x = box[0] / width + w / 2
                y = box[1] / height + h / 2
                conf = pred[4]
                cls = int(pred[5])

                targets.append([i, cls, x, y, w, h, conf])

    return np.array(targets)

@hehe91 hope this helps.
And @glenn-jocher, thanks again for your great work!
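The per-box arithmetic in the loop of output_to_target converts corner coordinates (x1, y1, x2, y2) in pixels into a normalized center/size box. A vectorized NumPy sketch of the same conversion, assuming each per-image entry is an (n, 6) array of [x1, y1, x2, y2, conf, cls] rows (the layout that loop indexes) or `None`. `output_to_target_np` is a hypothetical name, not part of YOLOv5. It calls `np.asarray` on each entry, so a CUDA tensor would still raise the same TypeError and must be moved with `.cpu()` first:

```python
import numpy as np


def output_to_target_np(output, width, height):
    """Vectorized [batch_id, class_id, x, y, w, h, conf] conversion (hypothetical sketch).

    `output` is a list with one (n, 6) array per image, or None for images
    without detections; rows are [x1, y1, x2, y2, conf, cls] in pixels.
    """
    targets = []
    for i, o in enumerate(output):
        if o is None or len(o) == 0:
            continue
        o = np.asarray(o, dtype=np.float64)
        w = (o[:, 2] - o[:, 0]) / width
        h = (o[:, 3] - o[:, 1]) / height
        x = o[:, 0] / width + w / 2  # normalized box center
        y = o[:, 1] / height + h / 2
        batch_id = np.full(len(o), i, dtype=np.float64)
        cls = np.trunc(o[:, 5])  # truncate toward zero, like int(pred[5])
        targets.append(np.stack([batch_id, cls, x, y, w, h, o[:, 4]], axis=1))
    # unlike np.array([]), an empty result keeps its (0, 7) shape
    return np.concatenate(targets) if targets else np.zeros((0, 7))
```

Processing each image's detections as one array avoids the Python-level loop over boxes, which matters when an image has many detections.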

robin-maillot pushed a commit to robin-maillot/yolov5 that referenced this issue Jun 22, 2021
alexchwong added a commit to alexchwong/ScaledYOLOv4 that referenced this issue Dec 26, 2021
@jingruhou

Thanks, this resolved my issue.


5 participants