
[torchbench] dlrm fails to run on training. #6008

Closed
ysiraichi opened this issue Dec 3, 2023 · 6 comments · Fixed by #7584

@ysiraichi (Collaborator)

🐛 Bug

Running the upstreamed benchmarking scripts with the following command results in an unexpected error.

python xla/benchmarks/experiment_runner.py \
       --suite-name torchbench \
       --accelerator cuda \
       --xla PJRT --xla None \
       --dynamo openxla --dynamo None \
       --test train \
       --repeat 30 --iterations-per-run 5 \
       --print-subprocess \
       --no-resume -k dlrm
Traceback (most recent call last):
  File ""xla/benchmarks/experiment_runner.py", line 601, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 597, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 65, in run
    self.run_single_experiment(experiment_config, model_config)
  File "xla/benchmarks/experiment_runner.py", line 161, in run_single_experiment
    run_metrics, output = self.timed_run(benchmark_experiment,
  File "xla/benchmarks/experiment_runner.py", line 328, in timed_run
    output = loop()
  File "xla/benchmarks/experiment_runner.py", line 310, in loop
    output = benchmark_model.model_iter_fn(
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "xla/benchmarks/torchbench_model.py", line 247, in train
    super().train(inputs, collect_full_output=collect_full_output)
  File "xla/benchmarks/benchmark_model.py", line 142, in train
    self._optimizer_zero_grad()
  File "xla/benchmarks/benchmark_model.py", line 145, in resume_in_train
    loss.backward()
  File "torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "torch/_functorch/aot_autograd.py", line 4201, in backward
    out = call_compiled_backward()
  File "torch/_functorch/aot_autograd.py", line 4167, in call_compiled_backward
    out = call_func_with_args(
  File "torch/_functorch/aot_autograd.py", line 2016, in call_func_with_args
    out = normalize_as_list(f(args))
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 1992, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "xla/torch_xla/core/dynamo_bridge.py", line 517, in extract_compiled_graph
    collector.run(*xla_args)
  File "torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 267, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 509, in __call__
    return self._op(*args, **kwargs or {})
NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXLA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XLA, Meta, SparseCPU, SparseCUDA, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXLA, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

XLA: registered at torch_xla/csrc/aten_cpu_fallback.cpp:51 [backend fallback]
Meta: registered at build/aten/src/ATen/RegisterMeta.cpp:26984 [kernel]
SparseCPU: registered at build/aten/src/ATen/RegisterSparseCPU.cpp:1387 [kernel]
SparseCUDA: registered at build/aten/src/ATen/RegisterSparseCUDA.cpp:1573 [kernel]
SparseMeta: registered at build/aten/src/ATen/RegisterSparseMeta.cpp:249 [kernel]
BackendSelect: registered at build/aten/src/ATen/RegisterBackendSelect.cpp:807 [kernel]
Python: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at aten/src/ATen/FunctionalizeFallbackKernel.cpp:302 [backend fallback]
Named: registered at aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCUDA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHIP: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXLA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMPS: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradIPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradVE: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradLazy: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMTIA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse1: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse2: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse3: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMeta: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradNestedTensor: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
Tracer: registered at torch/csrc/autograd/generated/TraceType_2.cpp:17346 [kernel]
AutocastCPU: fallthrough registered at aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastXLA: fallthrough registered at torch_xla/csrc/autocast_mode.cpp:25 [backend fallback]
AutocastCUDA: fallthrough registered at aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]


While executing %_sparse_coo_tensor_with_dims_and_tensors : [num_users=1] = call_function[target=torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors.default](args = (1, 1, [1000000, 64], %view_6, %view_7), kwargs = {dtype: torch.float32, layout: torch.sparse_coo, device: xla:0, pin_memory: None})
Original traceback:
  File "xla/benchmarks/benchmark_model.py", line 143, in resume_in_train
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 337, in forward
    return self.sequential_forward(dense_x, lS_o, lS_i)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 349, in sequential_forward
    ly = self.apply_emb(lS_o, lS_i, self.emb_l)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 294, in apply_emb
    V = E(sparse_index_group_batch, sparse_offset_group_batch)

Environment

@ysiraichi (Collaborator, Author)

I thought we could fall back SparseXLA onto SparseCPU and return a CPU sparse tensor. Even though that solves this particular failure, PyTorch autograd expects the output to be a tensor on XLA, not on CPU.

I don't think there's a way around this error unless we actually support XLA sparse tensors: the EmbeddingBag backward function always returns a sparse tensor when sparse=True, so it's not something we can fall back to CPU and then move back to XLA.
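For reference, here is a minimal sketch of the pattern that fails. This is a reconstruction rather than the benchmark code; it assumes torch_xla is installed and an XLA device is available, and mirrors the sparse EmbeddingBag tables used by dlrm_s_pytorch.py.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# A sparse EmbeddingBag like the DLRM embedding tables.
emb = nn.EmbeddingBag(1000, 64, mode="sum", sparse=True).to(device)
indices = torch.randint(0, 1000, (32,), device=device)
offsets = torch.arange(0, 32, 4, device=device)

loss = emb(indices, offsets).sum()
# The backward of a sparse EmbeddingBag builds its weight gradient through
# aten::_sparse_coo_tensor_with_dims_and_tensors, which has no SparseXLA
# kernel, so this is expected to hit the NotImplementedError above.
loss.backward()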

@JackCaoG @lezcano
Any ideas?

@JackCaoG (Collaborator) commented Jun 3, 2024

Hmm, do you know what supporting SparseXLA would even mean? I have no idea how much work is involved.

@lezcano (Collaborator) commented Jun 3, 2024

Can we be a bit cheeky and lower EmbeddingBag with sparse=true into sparse=false?
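As an illustration only (the helper name is made up and this is not an existing torch_xla hook), one way to do that would be to flip the sparse flag on the model's nn.EmbeddingBag layers before running:

import torch.nn as nn

def densify_embedding_bags(model: nn.Module) -> None:
    # Illustrative helper: make every sparse EmbeddingBag produce dense
    # gradients, so backward never needs to build a sparse COO tensor on XLA.
    for module in model.modules():
        if isinstance(module, nn.EmbeddingBag) and module.sparse:
            module.sparse = False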

@ysiraichi (Collaborator, Author)

Do you know what supporting SparseXLA would even mean?

Good question. I tried calling at::native::new_with_dims_and_tensor_sparse_symint with XLA tensors, hoping everything would just work, but it turns out it doesn't. I got a new error:

Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 963, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 959, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 257, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 352, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 309, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 219, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "xla/benchmarks/torchbench_model.py", line 413, in train
    super().train(inputs, collect_full_output=collect_full_output)
  File "xla/benchmarks/benchmark_model.py", line 183, in train
    loss.backward()
  File "torch/_tensor.py", line 523, in backward
    torch.autograd.backward(
  File "torch/autograd/__init__.py", line 284, in backward
    _engine_run_backward(
  File "torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: layout_or_default(layout_opt) == Layout::Strided INTERNAL ASSERT FAILED at "../aten/src/ATen/EmptyTensor.cpp":448, please report a bug to PyTorch.

I checked with gdb where the backward was erroring, and it turns out to be a Tensor.detach call on the sparse XLA tensor. When it goes through the functionalization implementation, it calls to_meta, which calls empty_strided. That errors because the tensor is not actually strided.

So maybe the problem is that FunctionalTensorWrapper assumes (which I believe it does) that the tensor is strided. That said, it's hard to say whether this is the only issue.
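For context, a sparse COO tensor has no strides at all, which is why a to_meta path built on empty_strided cannot represent it. A CPU-only illustration (unrelated to XLA):

import torch

s = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 2]]),
    values=torch.tensor([1.0, 2.0]),
    size=(3,),
)
print(s.layout)  # torch.sparse_coo
# s.stride() raises a RuntimeError: sparse tensors do not have strides.
# torch.empty_strided only ever produces strided tensors, hence the
# layout_or_default(layout_opt) == Layout::Strided assert when
# functionalization tries to build a meta copy of a sparse tensor.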

Can we be a bit cheeky and lower EmbeddingBag with sparse=true into sparse=false?

I guess so. But then it would be hard to compare the performance results, since we would be running different programs.

@lezcano (Collaborator) commented Jun 4, 2024

Yeah, it'd be difficult to compare results, but at least things would run :D
I don't really see how to fix the sparse issue, though. @JackCaoG any ideas?

@lezcano (Collaborator) commented Jun 4, 2024

We could also just error out with a clean error asking the user to set sparse=False. Either works.
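A minimal sketch of what such a guard could look like (the function name and where it would be called from are hypothetical; it just turns the dispatcher failure into an actionable message):

import torch.nn as nn

def reject_sparse_embedding_bags(model: nn.Module) -> None:
    # Hypothetical pre-flight check: fail fast with a clear message instead of
    # the NotImplementedError raised from the SparseXLA dispatch key.
    for name, module in model.named_modules():
        if isinstance(module, nn.EmbeddingBag) and module.sparse:
            raise RuntimeError(
                f"{name}: nn.EmbeddingBag(sparse=True) is not supported on XLA "
                "because its backward produces sparse COO gradients; "
                "please construct the layer with sparse=False."
            )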
