[1/2] Added backward pass on CPU for interpolation with anti-alias option #4208
Conversation
Implementation looks good to me, thanks!
Some parts are not yet bound (like autograd for the ops), but I suppose this can be left for the future.
```python
# skip float16 on CPU case
return

torch.manual_seed(12)
```
In general, should we set the seed of the tests in a more automated way via pytest?
It's fine for now, as this is what we're doing almost everywhere else.

Once we're finished with the pytest porting I'll look into ways to improve the RNG handling in our tests.

One thing I'm wondering: does `torch.manual_seed(12)` leak the RNG for the rest of the tests?
> does torch.manual_seed(12) leak the rng for the rest of the tests?

I think yes
that's pretty bad :)
but yeah, it's OK for now to use the old pattern.
It certainly does leak it, and it's bitten us quite a few times in the past.
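The "leak" is just global generator state: `torch.manual_seed` reseeds the process-wide RNG, so every later test observes the reseeded stream. The same issue can be illustrated with Python's stdlib `random` module (a hedged analogy, not PyTorch code; the torch-side equivalents would be a dedicated `torch.Generator` or `torch.random.fork_rng`):

```python
import random

# Seeding the module-level generator "leaks": it changes what every
# subsequent caller of random.random() sees, exactly like
# torch.manual_seed mutating the global torch RNG.
random.seed(12)
first = random.random()
random.seed(12)
assert random.random() == first  # the global stream was reset

# A private generator instance leaves the global stream untouched,
# analogous to passing generator=torch.Generator().manual_seed(12).
state_before = random.getstate()
private = random.Random(12)
private.random()
assert random.getstate() == state_before  # no leak into the global RNG
```

The second pattern is what a pytest fixture could hand each test, so no test perturbs its neighbors.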
```python
class F(torch.autograd.Function):

    @staticmethod
    def forward(ctx, i):
        result = forward_op(i, size, False)
        ctx.save_for_backward(i, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i, result = ctx.saved_tensors
        ishape = i.shape
        oshape = result.shape[2:]
        return backward_op(grad_output, oshape, ishape, False)
```
This is fine for now. I suppose the next step would be to move those functions to PyTorch so that we have native autograd support?
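Until the ops are bound to autograd, the invariant the custom `Function` above must satisfy is that `backward` applies the transpose of the (linear) forward resampling operator. A pure-Python sketch of that check, using a toy pair-averaging operator as a hypothetical stand-in for `forward_op`/`backward_op` (the real ops are bilinear interpolation; these names and shapes are illustrative only):

```python
def forward_op(x):
    # Toy linear "downsample": average adjacent pairs (length halves).
    return [0.5 * (x[2 * i] + x[2 * i + 1]) for i in range(len(x) // 2)]

def backward_op(grad_out, in_len):
    # Transpose of forward_op: each output gradient is scattered back,
    # with the same 0.5 weights, to the inputs that produced it.
    grad_in = [0.0] * in_len
    for i, g in enumerate(grad_out):
        grad_in[2 * i] += 0.5 * g
        grad_in[2 * i + 1] += 0.5 * g
    return grad_in

# Adjoint check: <A x, y> == <x, A^T y> for arbitrary x, y.
x = [1.0, 2.0, 3.0, 4.0]
y = [5.0, 7.0]
lhs = sum(a * b for a, b in zip(forward_op(x), y))
rhs = sum(a * b for a, b in zip(x, backward_op(y, len(x))))
assert abs(lhs - rhs) < 1e-12
```

`torch.autograd.gradcheck` on the wrapped `Function` performs essentially this test numerically, which is what native binding would give for free.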
```cpp
for (int64_t oh = 0; oh < output_height; oh++) {
  F<int64_t, scalar_t>::_compute_weights_aa(
      oh,
      input_height,
      height_scale,
      support_h,
      wy.data(),
      interp_height,
      filter_fn,
      ymin,
      ysize);

  for (int64_t ow = 0; ow < output_width; ow++) {
    F<int64_t, scalar_t>::_compute_weights_aa(
        ow,
        input_width,
        width_scale,
        support_w,
        wx.data(),
        interp_width,
        filter_fn,
        xmin,
        xsize);

    for (int64_t c = begin; c < end; c++) {
      scalar_t grad_output_value =
          grad_output_data[c * output_slice_size + oh * output_width + ow];

      for (size_t y = 0; y < ysize; y++) {
        for (size_t x = 0; x < xsize; x++) {
          *input_indexr(c, ymin + y, xmin + x) +=
              wx[x] * wy[y] * grad_output_value;
        }
      }
    }
  }
}
```
This implementation is good and follows what we discussed previously (i.e., it follows the backward op in PyTorch's interpolate).

I would like you to think about whether there is a way of generalizing this for the nd cases that doesn't involve a lot of copy-pasting. This will be especially important when we move the function to PyTorch, as we will also support the 3d and 5d cases.
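One way to avoid per-dimensionality copy-paste is to exploit that anti-aliased interpolation is separable: apply a single 1D pass along each spatial axis, so the nested `oh`/`ow` loops collapse into one routine reused per dimension. A hypothetical pure-Python sketch of the idea (the tent filter, normalization, and helper names are illustrative, not PyTorch's actual kernel API):

```python
def tent(t):
    # Bilinear (triangle) filter; anti-aliasing widens it when downsampling.
    return max(0.0, 1.0 - abs(t))

def resample_1d(data, out_size, filter_fn=tent, support=1.0):
    # One anti-aliased 1D pass; the backward kernel reuses the same
    # weights (cf. wx/wy above) and only flips the accumulation direction.
    in_size = len(data)
    scale = in_size / out_size
    s = max(scale, 1.0)  # stretch the filter support when downsampling
    out = []
    for o in range(out_size):
        center = (o + 0.5) * scale
        lo = max(int(center - support * s), 0)
        hi = min(int(center + support * s) + 1, in_size)
        w = [filter_fn((i + 0.5 - center) / s) for i in range(lo, hi)]
        norm = sum(w) or 1.0
        out.append(sum(wi * data[lo + i] for i, wi in enumerate(w)) / norm)
    return out

def resample_2d(img, out_h, out_w):
    # Separable: resize every row, then every column of the result.
    # A 3d (volumetric) version would simply add one more 1D pass.
    rows = [resample_1d(r, out_w) for r in img]
    cols = [resample_1d([rows[y][x] for y in range(len(rows))], out_h)
            for x in range(out_w)]
    return [[cols[x][y] for x in range(out_w)] for y in range(out_h)]
```

Because the weights are normalized, a constant image stays constant under `resample_2d`, which is a quick sanity check for any nd generalization.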
Hey @fmassa! You merged this PR, but no labels were added.
…-alias option (#4208)

Summary:
* WIP on backward op interpolation with AA
* Removed cuda tests and reformat cpp code
* Fixed clang wrong formatting
* Added channels last test case

Reviewed By: NicolasHug

Differential Revision: D30069956

fbshipit-source-id: cbd7163b7407f653f50aaba09f27ae2b7cb8e094

Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Summary:

Description:
- Added antialias flag to interpolate (CPU only) - forward and backward for bilinear mode
- added tests

### Benchmarks

<details>
<summary> Forward pass, CPU. PTH interpolation vs PIL </summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apples vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
# OMP_NUM_THREADS=1 python bench_interp_aa_vs_pillow.py

Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_75,code=sm_75
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON,

Num threads: 1
```

Downsampling `torch.Size([1, 3, 906, 438])`, RGB, 1 thread, times in ms:

| Output size | Input layout | Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91 |
| --- | --- | --- | --- |
| (320, 196) | channels_first contiguous torch.float32 | 2.9 | 3.1 |
| (320, 196) | channels_last non-contiguous torch.float32 | 2.6 | 3.6 |
| (460, 220) | channels_first contiguous torch.float32 | 3.4 | 4.0 |
| (460, 220) | channels_last non-contiguous torch.float32 | 3.4 | 4.8 |
| (120, 96) | channels_first contiguous torch.float32 | 1.6 | 1.8 |
| (120, 96) | channels_last non-contiguous torch.float32 | 1.6 | 1.9 |
| (1200, 196) | channels_first contiguous torch.float32 | 9.0 | 11.3 |
| (1200, 196) | channels_last non-contiguous torch.float32 | 8.9 | 12.5 |
| (120, 1200) | channels_first contiguous torch.float32 | 2.1 | 1.8 |
| (120, 1200) | channels_last non-contiguous torch.float32 | 2.1 | 3.4 |

Downsampling `torch.Size([1, 1, 906, 438])`, contiguous torch.float32, 1 thread:

| Output size | Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91 |
| --- | --- | --- |
| (320, 196) | 1.2 ms | 1.0 ms |
| (460, 220) | 1.4 ms | 1.3 ms |
| (120, 96) | 719.9 us | 599.9 us |
| (1200, 196) | 3.7 ms | 3.5 ms |
| (120, 1200) | 834.4 us | 605.7 us |

</details>

Code is moved from torchvision: pytorch/vision#4208

Pull Request resolved: #65142

Reviewed By: mrshenli

Differential Revision: D32432405

Pulled By: jbschlosser

fbshipit-source-id: b66c548347f257c522c36105868532e8bc1d4c6d
Summary:

Description:
- Added antialias flag to interpolate (CPU only) - forward and backward for bicubic mode
- added tests

Previous PR for bilinear, #65142

### Benchmarks

<details>
<summary> Forward pass, CPU. PTH interpolation vs PIL </summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apples vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 1
```

Downsampling (bicubic) `torch.Size([1, 3, 906, 438])`, RGB, 1 thread, times in ms:

| Output size | Input layout | Reference, PIL 8.4.0, mode: RGB | 1.11.0a0+gitb0bdf58 |
| --- | --- | --- | --- |
| (320, 196) | channels_first contiguous torch.float32 | 4.5 | 5.2 |
| (320, 196) | channels_last non-contiguous torch.float32 | 4.5 | 5.3 |
| (460, 220) | channels_first contiguous torch.float32 | 5.7 | 6.4 |
| (460, 220) | channels_last non-contiguous torch.float32 | 5.7 | 6.4 |
| (120, 96) | channels_first contiguous torch.float32 | 3.0 | 4.0 |
| (120, 96) | channels_last non-contiguous torch.float32 | 2.9 | 4.1 |
| (1200, 196) | channels_first contiguous torch.float32 | 14.7 | 17.1 |
| (1200, 196) | channels_last non-contiguous torch.float32 | 14.8 | 17.2 |
| (120, 1200) | channels_first contiguous torch.float32 | 3.5 | 3.9 |
| (120, 1200) | channels_last non-contiguous torch.float32 | 3.5 | 3.9 |

Downsampling (bicubic) `torch.Size([1, 1, 906, 438])`, contiguous torch.float32, 1 thread, times in ms:

| Output size | Reference, PIL 8.4.0, mode: F | 1.11.0a0+gitb0bdf58 |
| --- | --- | --- |
| (320, 196) | 2.4 | 1.8 |
| (460, 220) | 3.1 | 2.2 |
| (120, 96) | 1.6 | 1.4 |
| (1200, 196) | 7.9 | 5.7 |
| (120, 1200) | 1.7 | 1.3 |

</details>

Code is moved from torchvision: pytorch/vision#3810 and pytorch/vision#4208

Pull Request resolved: #68819

Reviewed By: mikaylagawarecki

Differential Revision: D33339117

Pulled By: jbschlosser

fbshipit-source-id: 6a0443bbba5439f52c7dbc1be819b75634cf67c4
Description:
cc @fmassa