Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensor devices for DARTS Trial #2273

Merged
merged 4 commits into from
Mar 10, 2024
Merged

Conversation

sifa1024
Copy link
Contributor

@sifa1024 sifa1024 commented Mar 5, 2024

What this PR does / why we need it:
If I use the original program, I will get this error when running darts-gpu,

<architect.Architect object at 0x7fe597aad780>
Traceback (most recent call last):
  File "/home/sifa/docker/katib/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py", line 259, in <module>
    main()
  File "/home/sifa/docker/katib/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py", line 155, in main
    train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
  File "/home/sifa/docker/katib/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py", line 194, in train
    architect.unrolled_backward(train_x, train_y, valid_x, valid_y, lr, w_optim)
  File "/home/sifa/docker/katib/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py", line 69, in unrolled_backward
    self.virtual_step(train_x, train_y, xi, w_optim)
  File "/home/sifa/docker/katib/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py", line 56, in virtual_step
    vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Which issue(s) this PR fixes

None. I've create pull request directly.

Checklist:

  • Docs included if any changes are user facing

@tenzen-y
Copy link
Member

tenzen-y commented Mar 5, 2024

@sifa1024 You need to sign to commit with the email used during sign the CLA.

Copy link
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix! I left a small comment.

Comment on lines 49 to 50
# Check device use cuda or cpu
use_cuda = list(range(torch.cuda.device_count()))
if use_cuda:
print("Using CUDA")
device = torch.device("cuda" if use_cuda else "cpu")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We identify device here:

if len(all_gpus) > 0:
device = torch.device("cuda")
torch.cuda.set_device(all_gpus[0])
np.random.seed(2)
torch.manual_seed(2)
torch.cuda.manual_seed_all(2)
torch.backends.cudnn.benchmark = True
print(">>> Use GPU for Training <<<")
print("Device ID: {}".format(torch.cuda.current_device()))
print("Device name: {}".format(torch.cuda.get_device_name(0)))
print("Device availability: {}\n".format(torch.cuda.is_available()))
else:
device = torch.device("cpu")
print(">>> Use CPU for Training <<<")
.
Can we just pass the device to the Architect class ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can. But is it a good idea to send the device name?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, it's fine since we don't need to invoke torch API again to understand if we have GPU available.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I will change it.

Copy link
Contributor Author

@sifa1024 sifa1024 Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreyvelich Please check it and thank you for your help.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I restarted tests.

sifa1024 added 3 commits March 6, 2024 22:04
72907153+sifa1024@users.noreply.github.com

Signed-off-by: Chen Pin-Han <72907153+sifa1024​@users.noreply.github.com>
72907153+sifa1024@users.noreply.github.com

Signed-off-by: Chen Pin-Han <72907153+sifa1024​@users.noreply.github.com>
72907153+sifa1024@users.noreply.github.com

Signed-off-by: Chen Pin-Han <72907153+sifa1024​@users.noreply.github.com>
Copy link
Member

@tenzen-y tenzen-y left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/lgtm
/approve

/hold
for restarting failed Go Test / Unit Test (1.26.1) (pull_request).

@kubeflow/wg-automl-leads Could you restart CI?

Copy link

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: sifa1024, tenzen-y

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@google-oss-prow google-oss-prow bot removed the lgtm label Mar 8, 2024
@tenzen-y
Copy link
Member

tenzen-y commented Mar 8, 2024

@kubeflow/wg-automl-leads Could you approve CI, again?

@tenzen-y
Copy link
Member

tenzen-y commented Mar 8, 2024

@sifa1024
This is a future reference.
We should do an actual rebase instead of a merge.

@sifa1024
Copy link
Contributor Author

sifa1024 commented Mar 8, 2024

@tenzen-y OK. I'm sorry. But I found a commit in my branch was not updated, so I just updated this.

@tenzen-y
Copy link
Member

tenzen-y commented Mar 9, 2024

@kubeflow/wg-automl-leads Could you restart Go Test / Unit Test (1.25.0) (pull_request) ?

/lgtm

@google-oss-prow google-oss-prow bot added the lgtm label Mar 9, 2024
@andreyvelich andreyvelich changed the title Update architect.py Fix tensor devices for DARTS Trial Mar 10, 2024
@andreyvelich
Copy link
Member

Thank you for your contribution @sifa1024!
/hold cancel

@google-oss-prow google-oss-prow bot merged commit 61406a5 into kubeflow:master Mar 10, 2024
59 checks passed
@sifa1024 sifa1024 deleted the patch-1 branch March 11, 2024 09:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants