diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 214600c39ecea..9cb60feac9d1b 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -2356,6 +2356,7 @@ def backward(ctx, *flat_args): def call_compiled_backward(): if CompiledFunction.compiled_bw is None: + assert all(a is not None for a in all_args) if aot_config.dynamic_shapes: all_args_list = list(all_args) CompiledFunction.compiled_bw = create_aot_dispatcher_function(