Skip to content

Commit

Permalink
Add InstructPix2Pix pipeline support.
Browse files Browse the repository at this point in the history
  • Loading branch information
asntr committed Jun 7, 2024
1 parent 5b311a3 commit 0eea205
Show file tree
Hide file tree
Showing 7 changed files with 551 additions and 1 deletion.
2 changes: 2 additions & 0 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"NeuronStableDiffusionPipeline",
"NeuronStableDiffusionImg2ImgPipeline",
"NeuronStableDiffusionInpaintPipeline",
"NeuronStableDiffusionInstructPix2PixPipeline",
"NeuronLatentConsistencyModelPipeline",
"NeuronStableDiffusionXLPipeline",
"NeuronStableDiffusionXLImg2ImgPipeline",
Expand Down Expand Up @@ -78,6 +79,7 @@
NeuronLatentConsistencyModelPipeline,
NeuronStableDiffusionImg2ImgPipeline,
NeuronStableDiffusionInpaintPipeline,
NeuronStableDiffusionInstructPix2PixPipeline,
NeuronStableDiffusionPipeline,
NeuronStableDiffusionXLImg2ImgPipeline,
NeuronStableDiffusionXLInpaintPipeline,
Expand Down
8 changes: 7 additions & 1 deletion optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
NeuronLatentConsistencyPipelineMixin,
NeuronStableDiffusionImg2ImgPipelineMixin,
NeuronStableDiffusionInpaintPipelineMixin,
NeuronStableDiffusionInstructPix2PixPipelineMixin,
NeuronStableDiffusionPipelineMixin,
NeuronStableDiffusionXLImg2ImgPipelineMixin,
NeuronStableDiffusionXLInpaintPipelineMixin,
Expand Down Expand Up @@ -1003,6 +1004,12 @@ class NeuronStableDiffusionInpaintPipeline(
__call__ = NeuronStableDiffusionInpaintPipelineMixin.__call__


class NeuronStableDiffusionInstructPix2PixPipeline(
NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInstructPix2PixPipelineMixin
):
__call__ = NeuronStableDiffusionInstructPix2PixPipelineMixin.__call__


class NeuronLatentConsistencyModelPipeline(NeuronStableDiffusionPipelineBase, NeuronLatentConsistencyPipelineMixin):
__call__ = NeuronLatentConsistencyPipelineMixin.__call__

Expand Down Expand Up @@ -1081,7 +1088,6 @@ class NeuronStableDiffusionXLInpaintPipeline(
if is_neuronx_available():
# TO REMOVE: This class will be included directly in the DDP API of Neuron SDK 2.20
class WeightSeparatedDataParallel(torch_neuronx.DataParallel):

def _load_modules(self, module):
try:
self.device_ids.sort()
Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"NeuronStableDiffusionPipelineMixin",
"NeuronStableDiffusionImg2ImgPipelineMixin",
"NeuronStableDiffusionInpaintPipelineMixin",
"NeuronStableDiffusionInstructPix2PixPipelineMixin",
"NeuronLatentConsistencyPipelineMixin",
"NeuronStableDiffusionXLPipelineMixin",
"NeuronStableDiffusionXLImg2ImgPipelineMixin",
Expand All @@ -36,6 +37,7 @@
NeuronLatentConsistencyPipelineMixin,
NeuronStableDiffusionImg2ImgPipelineMixin,
NeuronStableDiffusionInpaintPipelineMixin,
NeuronStableDiffusionInstructPix2PixPipelineMixin,
NeuronStableDiffusionPipelineMixin,
NeuronStableDiffusionXLImg2ImgPipelineMixin,
NeuronStableDiffusionXLInpaintPipelineMixin,
Expand Down
1 change: 1 addition & 0 deletions optimum/neuron/pipelines/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin
from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin
from .pipeline_stable_diffusion_inpaint import NeuronStableDiffusionInpaintPipelineMixin
from .pipeline_stable_diffusion_instruct_pix2pix import NeuronStableDiffusionInstructPix2PixPipelineMixin
from .pipeline_stable_diffusion_xl import NeuronStableDiffusionXLPipelineMixin
from .pipeline_stable_diffusion_xl_img2img import NeuronStableDiffusionXLImg2ImgPipelineMixin
from .pipeline_stable_diffusion_xl_inpaint import NeuronStableDiffusionXLInpaintPipelineMixin
Loading

0 comments on commit 0eea205

Please sign in to comment.