[torchbench] Check failed: cachedComputation #5967

Open · ysiraichi opened this issue Dec 1, 2023 · 11 comments · Fixed by #6509

ysiraichi commented Dec 1, 2023

🐛 Bug

Running a few torchbench benchmarks with the dynamo+openxla backend ends in an assertion failure:

Traceback (most recent call last):
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/torchbench.py", line 544, in forward_pass
    return mod(*inputs)
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "torchbenchmark/models/cm3leon_generate/model.py", line 1113, in forward
    def forward(self, src_tokens):
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 4963, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/aot_autograd.py", line 2017, in g
    return f(*args)
  File "torch/_functorch/aot_autograd.py", line 3164, in runtime_wrapper
    all_outs = call_func_with_args(
  File "torch/_functorch/aot_autograd.py", line 2041, in call_func_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/aot_autograd.py", line 2145, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "torch/_functorch/aot_autograd.py", line 2017, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 51, in fwd
    return compiled_graph(*args)
  File "torch/fx/graph_module.py", line 736, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "torch/fx/graph_module.py", line 315, in __call__
    raise e
  File "torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.5", line 5, in forward
  File "xla/torch_xla/core/dynamo_bridge.py", line 387, in optimized_mod
    res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
RuntimeError: torch_xla/csrc/xla_graph_executor.cpp:625 : Check failed: cachedComputation 
Stack Trace
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::XLAGraphExecutor::ExecuteComputationWithBarrier(torch::lazy::hash_t, std::vector<c10::IValue, std::allocator<c10::IValue> > const&, torch::lazy::BackendDevice const&)
	
	
	
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyObject_FastCallDict
	_PyObject_Call_Prepend
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyObject_FastCallDict
	_PyObject_Call_Prepend
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyObject_FastCallDict
	_PyObject_Call_Prepend
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyObject_FastCallDict
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCodeEx
*** End stack trace ***

Affected Benchmarks

  • cm3leon_generate
  • hf_T5_generate

Environment

JackCaoG commented Dec 1, 2023

Hmm, it is in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L621-L624

One possible reason is that this model compiles way too many times, so the LRU cache kicks out one of the graphs. You can try increasing the default cache size in:

XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
  static const size_t kMaxCacheSize =
      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
  static ComputationCache* cache = new ComputationCache(kMaxCacheSize);
  return cache;
}

... but IMO 1024 is already pretty high.

If this is not the case, it would be really weird, since for every cache entry we should store a computation during the dynamo compilation phase.
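
For anyone trying this, a minimal sketch of overriding the cache size from Python (the environment variable name comes from the snippet above; the value 4096 and the xla_model import are just illustrative):

import os

# GetComputationCache() reads XLA_COMPILATION_CACHE_SIZE once, when the static
# cache is first created, so set the variable before doing any XLA work.
os.environ["XLA_COMPILATION_CACHE_SIZE"] = "4096"

import torch_xla.core.xla_model as xm

device = xm.xla_device()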

zpcore commented Jan 22, 2024

Just checked with different XLA_COMPILATION_CACHE_SIZE values: 2048 and 4096. The model still complains about the LRU cache. Any pointers on how to debug the issue?

miladm commented Jan 22, 2024

It looks like these models are skipped by torchbench for inductor. Do we know what their error would be had they not been skipped?

cc @zpcore @ysiraichi @frgossen @golechwierowicz

@JackCaoG

You can add print statements to both the dynamo bridge and the LRU cache to print the hashes being inserted into the cache. You should be able to see when the compiled program is inserted into the cache. If you never see a cached computation with that hash being inserted, there is a bug in how the dynamo bridge computes the hash.
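
On the Python side, a rough (hypothetical) sketch of that kind of logging, wrapping the call that appears in the traceback instead of editing dynamo_bridge.py directly:

import torch_xla

# Log every hash handed to the cached-graph runner, so it can be compared
# against the hashes that were inserted into the cache at compile time.
_orig_run_cached_graph = torch_xla._XLAC._run_cached_graph

def _logged_run_cached_graph(graph_hash, graph_input):
    print(f"[debug] _run_cached_graph lookup, hash={graph_hash}")
    return _orig_run_cached_graph(graph_hash, graph_input)

torch_xla._XLAC._run_cached_graph = _logged_run_cached_graph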

zpcore commented Jan 22, 2024

opacus_cifar10 shows the same issue in the latest run.

zpcore commented Jan 22, 2024

Those two got skipped because no install.py was found for them, so we didn't install the models.

@ysiraichi

That's odd. Last I tried (a8b27eb) they were still passing on inductor.

python xla/benchmarks/experiment_runner.py --no-resume --suite-name torchbench --repeat 2 --accelerator cuda --test eval --xla None --dynamo inductor

I will try running them again on master.

@ysiraichi

Oops. I think I misinterpreted your question. So, on torchbench they are skipped only if we try to export those models. Otherwise, they should pass (I think).

zpcore commented Jan 22, 2024

Sorry, I meant: when I run python install.py --continue_on_fail in https://github.com/pytorch/benchmark, it shows that the models cm3leon_generate and hf_T5_generate are skipped because there is no install.py under those two models.

@ysiraichi

I have looked into this issue and, contrary to what @zpcore found, I successfully ran both of these benchmarks with dynamo+openxla by increasing XLA_COMPILATION_CACHE_SIZE to 2048. They can, however, still time out, depending on the run parameters, e.g.:

Running cm3leon_generate 30 times, 5 iterations each time, results in a timeout with dynamo+openxla. Each iteration takes approximately 14s.

I would say we should:

  • Either increase the default value of XLA_COMPILATION_CACHE_SIZE or set it explicitly in the benchmarking script
  • Leave only the non-dynamo execution in the DENY_LIST

@zpcore @miladm @golechwierowicz @cota @frgossen
Could you try reproducing it, running the benchmarking script with --repeat 8 --iterations-per-run 1 (similar to nightly results)? Let me know what you think.
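
For reference, adapting the inductor command from earlier to those parameters (the --xla/--dynamo values for the openxla configuration are an assumption on my part):

python xla/benchmarks/experiment_runner.py --no-resume --suite-name torchbench --repeat 8 --iterations-per-run 1 --accelerator cuda --test eval --xla PJRT --dynamo openxla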

@ysiraichi

opacus_cifar10 training is still failing with this issue.
