Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Latent Consistency models OpenVINO export and inference #463

Merged
merged 6 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
else:
_import_structure["openvino"].extend(
Expand All @@ -71,6 +72,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
)

Expand Down Expand Up @@ -158,6 +160,7 @@
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_diffusers_objects import (
OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
Expand All @@ -166,6 +169,7 @@
)
else:
from .openvino import (
OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

if is_diffusers_available():
from .modeling_diffusion import (
OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
Expand Down
92 changes: 72 additions & 20 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from openvino.runtime import Core
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from optimum.pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
Expand Down Expand Up @@ -69,16 +70,16 @@ class OVStableDiffusionPipelineBase(OVBaseModel, OVTextualInversionLoaderMixin):

def __init__(
self,
vae_decoder: openvino.runtime.Model,
text_encoder: openvino.runtime.Model,
unet: openvino.runtime.Model,
config: Dict[str, Any],
tokenizer: "CLIPTokenizer",
scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"],
feature_extractor: Optional["CLIPFeatureExtractor"] = None,
vae_decoder: Optional[openvino.runtime.Model] = None,
vae_encoder: Optional[openvino.runtime.Model] = None,
text_encoder: Optional[openvino.runtime.Model] = None,
text_encoder_2: Optional[openvino.runtime.Model] = None,
tokenizer: Optional["CLIPTokenizer"] = None,
tokenizer_2: Optional["CLIPTokenizer"] = None,
feature_extractor: Optional["CLIPFeatureExtractor"] = None,
device: str = "CPU",
dynamic_shapes: bool = True,
compile: bool = True,
Expand Down Expand Up @@ -270,20 +271,7 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

return cls(
vae_decoder=components["vae_decoder"],
text_encoder=components["text_encoder"],
unet=unet,
config=config,
tokenizer=kwargs.pop("tokenizer", None),
scheduler=kwargs.pop("scheduler"),
feature_extractor=kwargs.pop("feature_extractor", None),
vae_encoder=components["vae_encoder"],
text_encoder_2=components["text_encoder_2"],
tokenizer_2=kwargs.pop("tokenizer_2", None),
model_save_dir=model_save_dir,
**kwargs,
)
return cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

@classmethod
def _from_transformers(
Expand All @@ -295,10 +283,11 @@ def _from_transformers(
force_download: bool = False,
cache_dir: Optional[str] = None,
local_files_only: bool = False,
tokenizer: "CLIPTokenizer" = None,
tokenizer: Optional["CLIPTokenizer"] = None,
scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"] = None,
feature_extractor: Optional["CLIPFeatureExtractor"] = None,
load_in_8bit: bool = False,
tokenizer_2: Optional["CLIPTokenizer"] = None,
**kwargs,
):
save_dir = TemporaryDirectory()
Expand Down Expand Up @@ -329,6 +318,7 @@ def _from_transformers(
local_files_only=local_files_only,
model_save_dir=save_dir,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
feature_extractor=feature_extractor,
load_in_8bit=load_in_8bit,
Expand Down Expand Up @@ -377,8 +367,10 @@ def _reshape_unet(
if batch_size == -1 or num_images_per_prompt == -1:
batch_size = -1
else:
batch_size *= num_images_per_prompt
# The factor of 2 comes from the guidance scale > 1
batch_size = 2 * batch_size * num_images_per_prompt
if "timestep_cond" not in {inputs.get_any_name() for inputs in model.inputs}:
batch_size *= 2

height = height // self.vae_scale_factor if height > 0 else height
width = width // self.vae_scale_factor if width > 0 else width
Expand All @@ -402,6 +394,8 @@ def _reshape_unet(
shapes[inputs] = [batch_size, self.text_encoder_2.config["projection_dim"]]
elif inputs.get_any_name() == "time_ids":
shapes[inputs] = [batch_size, inputs.get_partial_shape()[1]]
elif inputs.get_any_name() == "timestep_cond":
shapes[inputs] = [batch_size, self.unet.config["time_cond_proj_dim"]]
else:
shapes[inputs][0] = batch_size
shapes[inputs][1] = tokenizer_max_length
Expand Down Expand Up @@ -585,6 +579,7 @@ def __call__(
encoder_hidden_states: np.ndarray,
text_embeds: Optional[np.ndarray] = None,
time_ids: Optional[np.ndarray] = None,
timestep_cond: Optional[np.ndarray] = None,
):
self._compile()

Expand All @@ -598,6 +593,8 @@ def __call__(
inputs["text_embeds"] = text_embeds
if time_ids is not None:
inputs["time_ids"] = time_ids
if timestep_cond is not None:
inputs["timestep_cond"] = timestep_cond

outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())
Expand Down Expand Up @@ -930,6 +927,61 @@ def __call__(
)


class OVLatentConsistencyModelPipeline(OVStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
    """OpenVINO text-to-image pipeline for Latent Consistency Models.

    `__call__` reconciles the requested resolution and batch size with the shapes
    the loaded model reports (`self.height`, `self.width`, `self._batch_size`) and
    then delegates the actual generation to `LatentConsistencyPipelineMixin.__call__`.
    """

    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 4,
        original_inference_steps: Optional[int] = None,
        guidance_scale: float = 8.5,
        num_images_per_prompt: int = 1,
        **kwargs,
    ):
        """Generate images for `prompt`.

        Args:
            prompt: A prompt or list of prompts. May be omitted when `prompt_embeds`
                is passed through `kwargs`.
            height: Requested image height. Falls back to
                `unet.config["sample_size"] * vae_scale_factor` when falsy.
            width: Requested image width, with the same fallback as `height`.
            num_inference_steps: Forwarded unchanged to the mixin.
            original_inference_steps: Forwarded unchanged to the mixin.
            guidance_scale: Forwarded unchanged to the mixin.
            num_images_per_prompt: Number of images generated per prompt.
            **kwargs: Extra arguments forwarded to the mixin (e.g. `prompt_embeds`,
                `latents`, `output_type`).

        Returns:
            Whatever `LatentConsistencyPipelineMixin.__call__` returns.

        Raises:
            ValueError: If the model has a fixed batch size and neither a `str`/`list`
                prompt nor `prompt_embeds` was supplied, so the effective batch size
                cannot be determined.
        """
        # Keep the fallback lazy: the UNet config is only consulted when no size was given.
        height = height or self.unet.config["sample_size"] * self.vae_scale_factor
        width = width or self.unet.config["sample_size"] * self.vae_scale_factor
        _height = self.height
        _width = self.width
        expected_batch_size = self._batch_size

        # A model dimension of -1 is treated as "not fixed"; any other value wins over
        # the caller's request, and the caller is warned about the override.
        if _height != -1 and height != _height:
            logger.warning(
                f"`height` was set to {height} but the static model will output images of height {_height}. "
                "To fix the height, please reshape your model accordingly using the `.reshape()` method."
            )
            height = _height

        if _width != -1 and width != _width:
            logger.warning(
                f"`width` was set to {width} but the static model will output images of width {_width}. "
                "To fix the width, please reshape your model accordingly using the `.reshape()` method."
            )
            width = _width

        if expected_batch_size != -1:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                prompt_embeds = kwargs.get("prompt_embeds")
                if prompt_embeds is None:
                    raise ValueError(
                        "`prompt` must be a `str` or a list of `str`; otherwise `prompt_embeds` must be provided."
                    )
                batch_size = prompt_embeds.shape[0]

            # A fixed guidance_scale of 0.0 is passed on purpose instead of the caller's value.
            # NOTE(review): presumably this keeps the helper from doubling the expected batch
            # for guidance; confirm against `_raise_invalid_batch_size`.
            _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale=0.0)

        return LatentConsistencyPipelineMixin.__call__(
            self,
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            original_inference_steps=original_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            **kwargs,
        )


def _raise_invalid_batch_size(
expected_batch_size: int, batch_size: int, num_images_per_prompt: int, guidance_scale: float
):
Expand Down
11 changes: 11 additions & 0 deletions optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino", "diffusers"])


class OVLatentConsistencyModelPipeline(metaclass=DummyObject):
    """Placeholder for `OVLatentConsistencyModelPipeline`.

    Both the constructor and `from_pretrained` only call `requires_backends` with
    the "openvino" and "diffusers" backends; no pipeline logic lives here.
    """

    # Backend names associated with this placeholder.
    _backends = ["openvino", "diffusers"]

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; only the backend check runs.
        # NOTE(review): presumably `requires_backends` raises when a backend is missing — confirm.
        requires_backends(self, ["openvino", "diffusers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check as the constructor, performed on the class itself.
        requires_backends(cls, ["openvino", "diffusers"])
68 changes: 68 additions & 0 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import floats_tensor
from openvino.runtime.ie_api import CompiledModel
from packaging.version import Version, parse
from parameterized import parameterized
from utils_tests import MODEL_NAMES, SEED

from optimum.intel import (
OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
Expand All @@ -50,6 +52,7 @@
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.utils.import_utils import _diffusers_version


def _generate_inputs(batch_size=1):
Expand Down Expand Up @@ -475,3 +478,68 @@ def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type)
inputs["strength"] = 0.75
return inputs


class OVLatentConsistencyModelPipelineTest(unittest.TestCase):
    """Tests for `OVLatentConsistencyModelPipeline` export and inference."""

    # Keys into MODEL_NAMES; each entry parameterizes every test below.
    SUPPORTED_ARCHITECTURES = ("latent-consistency",)
    MODEL_CLASS = OVLatentConsistencyModelPipeline
    TASK = "text-to-image"

    # Both tests are skipped for diffusers versions up to and including 0.21.4.
    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
    def test_compare_to_diffusers(self, model_arch: str):
        # Export the checkpoint on the fly and check each component has the expected wrapper type.
        ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True)
        self.assertIsInstance(ov_pipeline.text_encoder, OVModelTextEncoder)
        self.assertIsInstance(ov_pipeline.vae_encoder, OVModelVaeEncoder)
        self.assertIsInstance(ov_pipeline.vae_decoder, OVModelVaeDecoder)
        self.assertIsInstance(ov_pipeline.unet, OVModelUnet)
        self.assertIsInstance(ov_pipeline.config, Dict)

        # Imported locally: this pipeline class is only used when the test is not skipped.
        from diffusers import LatentConsistencyModelPipeline

        pipeline = LatentConsistencyModelPipeline.from_pretrained(MODEL_NAMES[model_arch])
        batch_size, num_images_per_prompt, height, width = 2, 3, 64, 128
        # Seeded latents are shared by both pipelines so their outputs can be compared.
        latents = ov_pipeline.prepare_latents(
            batch_size * num_images_per_prompt,
            ov_pipeline.unet.config["in_channels"],
            height,
            width,
            dtype=np.float32,
            generator=np.random.RandomState(0),
        )

        kwargs = {
            "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size,
            "num_inference_steps": 1,
            "num_images_per_prompt": num_images_per_prompt,
            "height": height,
            "width": width,
            "guidance_scale": 8.5,
        }

        # Compare both the raw latents and the decoded numpy images.
        for output_type in ["latent", "np"]:
            ov_outputs = ov_pipeline(latents=latents, output_type=output_type, **kwargs).images
            self.assertIsInstance(ov_outputs, np.ndarray)
            with torch.no_grad():
                outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images

            # Compare model outputs
            self.assertTrue(np.allclose(ov_outputs, outputs, atol=1e-4))
            # Compare model devices
            self.assertEqual(pipeline.device.type, ov_pipeline.device)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
    def test_num_images_per_prompt_static_model(self, model_arch: str):
        model_id = MODEL_NAMES[model_arch]
        # Load uncompiled with static shapes so the model can be reshaped before compilation.
        pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
        batch_size, num_images, height, width = 3, 4, 128, 64
        pipeline.half()
        pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
        self.assertFalse(pipeline.is_dynamic)
        pipeline.compile()

        # The second iteration requests a height that differs from the reshaped one;
        # the expected output shape still uses the static `height`.
        for _height in [height, height + 16]:
            inputs = _generate_inputs(batch_size)
            outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images
            self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner",
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "hf-internal-testing/tiny-random-SEWDModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
Expand Down
Loading