
Trying to Train MaskRCNN on XLA #1749

Closed
SharanSMenon opened this issue Mar 12, 2020 · 2 comments

❓ Questions and Help

I get the following error when trying to train Mask R-CNN on XLA. I was following this tutorial:

https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

It works on CUDA; all I changed was the device from cuda to xla.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-35-5e7209f83d58> in <module>()
      3 for epoch in range(num_epochs):
      4     # train for one epoch, printing every 10 iterations
----> 5     train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
      6     # update the learning rate
      7     lr_scheduler.step()

<ipython-input-34-60b12dee6390> in train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq)
     18         targets = [{k: torch.tensor(v).to(device) for k, v in t.items()} for t in targets]
     19 
---> 20         loss_dict = model(images, targets)
     21 
     22         losses = sum(loss for loss in loss_dict.values())

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     68         if isinstance(features, torch.Tensor):
     69             features = OrderedDict([('0', features)])
---> 70         proposals, proposal_losses = self.rpn(images, features, targets)
     71         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
     72         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/rpn.py in forward(self, images, features, targets)
    473         # note that we detach the deltas because Faster R-CNN do not backprop through
    474         # the proposals
--> 475         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    476         proposals = proposals.view(num_images, -1, 4)
    477         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/_utils.py in decode(self, rel_codes, boxes)
    185             box_sum += val
    186         pred_boxes = self.decode_single(
--> 187             rel_codes.reshape(box_sum, -1), concat_boxes
    188         )
    189         return pred_boxes.reshape(box_sum, -1, 4)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/_utils.py in decode_single(self, rel_codes, boxes)
    221         pred_h = torch.exp(dh) * heights[:, None]
    222 
--> 223         pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
    224         pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
    225         pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w

RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:69 : Check failed: xtensor 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace[abi:cxx11]()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	
	torch_xla::AtenXlaType::mul(at::Tensor const&, at::Tensor const&)
	c10::detail::wrap_kernel_functor_unboxed_<c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&)
	
	
	
	
	
	
	
	
	PyCFunction_Call
	PyObject_Call
	
	
	PyNumber_Multiply
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	
	_PyObject_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	
	_PyObject_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
*** End stack trace ***
Input tensor is not an XLA tensor: torch.FloatTensor
@JackCaoG (Collaborator)

@SharanSMenon Thanks for reporting! The error in

pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w

is that torch.tensor(0.5, dtype=pred_ctr_x.dtype) is a CPU tensor (while both pred_ctr_x and pred_w are XLA tensors), and currently XLA requires the lhs of a binary operation to be an XLA tensor.

To work around this you can switch the order of the multiplication to make it

pred_boxes1 = pred_ctr_x - pred_w * torch.tensor(0.5, dtype=pred_ctr_x.dtype)

or explicitly pass the device when creating the scalar tensor. We are discussing internally how (or whether) we want to fix this.
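A minimal sketch of the second workaround, assuming plain PyTorch (the helper name decode_halves is hypothetical, not torchvision's API): creating the scalar constant with device=... places it on the same device as the operands, so the operation dispatches correctly regardless of operand order.

```python
import torch

def decode_halves(pred_ctr_x, pred_w):
    # torch.tensor(0.5, ...) without a device argument lands on the CPU;
    # passing device=pred_ctr_x.device keeps it on the operands' device
    # (e.g. an XLA device), so lhs and rhs always match.
    half = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_ctr_x.device)
    return pred_ctr_x - half * pred_w, pred_ctr_x + half * pred_w

# CPU demonstration of the arithmetic; on a TPU the tensors would live
# on an xla device instead.
x = torch.tensor([10.0, 20.0])
w = torch.tensor([4.0, 6.0])
lo, hi = decode_halves(x, w)
# lo = centers minus half-widths, hi = centers plus half-widths
```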

@ailzhang (Contributor)

@SharanSMenon I created the PR above in pytorch/vision to make the device explicit. If you see more cases like this, please feel free to send a PR to pytorch/vision! Thanks!
