diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile index b0646084964e..a41be50f9d58 100644 --- a/docker/diffusers-pytorch-compile-cuda/Dockerfile +++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile @@ -14,22 +14,23 @@ RUN apt update && \ libsndfile1-dev \ libgl1 \ python3.9 \ + python3.9-dev \ python3-pip \ python3.9-venv && \ rm -rf /var/lib/apt/lists # make sure to use venv -RUN python3 -m venv /opt/venv +RUN python3.9 -m venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) -RUN python3 -m pip install --no-cache-dir --upgrade pip && \ - python3 -m pip install --no-cache-dir \ +RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \ + python3.9 -m pip install --no-cache-dir \ torch \ torchvision \ torchaudio \ invisible_watermark && \ - python3 -m pip install --no-cache-dir \ + python3.9 -m pip install --no-cache-dir \ accelerate \ datasets \ hf-doc-builder \ diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 32e3729828fa..d11b59cc1510 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -85,16 +85,17 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): prompt = "bird" image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) + ).resize((512, 512)) - output = pipe(prompt, image, generator=generator, output_type="np") + output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np") image = output.images[0] - assert image.shape == (768, 512, 3) + assert image.shape == (512, 512, 3) expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" ) + expected_image = np.resize(expected_image, (512, 512, 3)) assert np.abs(expected_image - image).max() < 1.0