We read every piece of feedback, and take your input very seriously.
For a list of all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
moco
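Running the upstreamed benchmarking scripts on `moco` with the following command fails with an unexpected error: while DDP synchronizes buffers before the forward pass, Dynamo traces the PyTorch/XLA backend's `broadcast` and raises an `AssertionError` ("expected FunctionType found instancemethod").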
```
python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench \
    --accelerator cuda \
    --xla PJRT \
    --dynamo openxla \
    --test eval \
    --repeat 30 --iterations-per-run 5 \
    --print-subprocess \
    --no-resume --filter moco
```
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 994, in <module>
[rank0]:     main()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 990, in main
[rank0]:     runner.run()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 64, in run
[rank0]:     self.run_single_config()
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 261, in run_single_config
[rank0]:     metrics, last_output = self.run_once_and_gather_metrics(
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 357, in run_once_and_gather_metrics
[rank0]:     output, _ = loop(iter_fn=self._default_iter_fn)
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 314, in loop
[rank0]:     output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
[rank0]:   File "xla/benchmarks/experiment_runner.py", line 223, in _default_iter_fn
[rank0]:     output = benchmark_model.model_iter_fn(
[rank0]:   File "torch/_dynamo/eval_frame.py", line 435, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "xla/benchmarks/benchmark_model.py", line 168, in eval
[rank0]:     pred = self.module(*inputs)
[rank0]:   File "torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "torch/nn/parallel/distributed.py", line 1634, in forward
[rank0]:     inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank0]:   File "torch/nn/parallel/distributed.py", line 1530, in _pre_forward
[rank0]:     self._sync_buffers()
[rank0]:   File "torch/nn/parallel/distributed.py", line 2167, in _sync_buffers
[rank0]:     self._sync_module_buffers(authoritative_rank)
[rank0]:   File "torch/nn/parallel/distributed.py", line 2171, in _sync_module_buffers
[rank0]:     self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
[rank0]:   File "torch/nn/parallel/distributed.py", line 2193, in _default_broadcast_coalesced
[rank0]:     self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
[rank0]:   File "torch/nn/parallel/distributed.py", line 2108, in _distributed_broadcast_coalesced
[rank0]:     dist._broadcast_coalesced(
[rank0]:   File "torch/_dynamo/convert_frame.py", line 1121, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:   File "torch/_dynamo/convert_frame.py", line 948, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:   File "torch/_dynamo/convert_frame.py", line 472, in __call__
[rank0]:     return _compile(
[rank0]:   File "torch/_utils_internal.py", line 85, in wrapper_function
[rank0]:     return StrobelightCompileTimeProfiler.profile_compile_time(
[rank0]:   File "torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:   File "torch/_dynamo/convert_frame.py", line 817, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:   File "torch/_dynamo/utils.py", line 233, in time_wrapper
[rank0]:     r = func(*args, **kwargs)
[rank0]:   File "torch/_dynamo/convert_frame.py", line 636, in compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:   File "torch/_dynamo/bytecode_transformation.py", line 1270, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "torch/_dynamo/convert_frame.py", line 178, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "torch/_dynamo/convert_frame.py", line 582, in transform
[rank0]:     tracer.run()
[rank0]:   File "torch/_dynamo/symbolic_convert.py", line 2476, in run
[rank0]:     super().run()
[rank0]:   File "torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "torch/_dynamo/symbolic_convert.py", line 1585, in LOAD_ATTR
[rank0]:     self._load_attr(inst)
[rank0]:   File "torch/_dynamo/symbolic_convert.py", line 1575, in _load_attr
[rank0]:     result = BuiltinVariable(getattr).call_function(
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 963, in call_function
[rank0]:     return handler(tx, args, kwargs)
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 712, in <lambda>
[rank0]:     return lambda tx, args, kwargs: obj.call_function(
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 963, in call_function
[rank0]:     return handler(tx, args, kwargs)
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 847, in builtin_dipatch
[rank0]:     rv = fn(tx, args, kwargs)
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 765, in call_self_handler
[rank0]:     result = self_handler(tx, *args, **kwargs)
[rank0]:   File "torch/_dynamo/variables/builtin.py", line 1607, in call_getattr
[rank0]:     return obj.var_getattr(tx, name)
[rank0]:   File "torch/_dynamo/variables/user_defined.py", line 891, in var_getattr
[rank0]:     return variables.UserMethodVariable(
[rank0]:   File "torch/_dynamo/variables/functions.py", line 309, in __init__
[rank0]:     super().__init__(fn=fn, **kwargs)
[rank0]:   File "torch/_dynamo/variables/functions.py", line 146, in __init__
[rank0]:     assert isinstance(
[rank0]: AssertionError: expected FunctionType found instancemethod <instancemethod at 0x7f01f594df30>
[rank0]: from user code:
[rank0]:   File "xla/torch_xla/distributed/xla_backend.py", line 98, in broadcast
[rank0]:     root_tensor = tensors[opts.rootTensor]
[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:     import torch._dynamo
[rank0]:     torch._dynamo.config.suppress_errors = True
```
cc @miladm @JackCaoG
The text was updated successfully, but these errors were encountered:
Reproducible with: pytorch/benchmark@612b3c8
Sorry, something went wrong.
No branches or pull requests
🐛 Bug
Running the upstreamed benchmarking scripts with the following command results in an unexpected error.
```
python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench \
    --accelerator cuda \
    --xla PJRT \
    --dynamo openxla \
    --test eval \
    --repeat 30 --iterations-per-run 5 \
    --print-subprocess \
    --no-resume --filter moco
```
Environment
cc @miladm @JackCaoG
The text was updated successfully, but these errors were encountered: