Skip to content

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim authored Sep 11, 2023
1 parent 7c52128 commit 77a666e
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions torchbenchmark/models/stable_diffusion_xl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class Model(BenchmarkModel, HuggingFaceAuthMixin):
DEFAULT_TRAIN_BSIZE = 1
DEFAULT_EVAL_BSIZE = 1
ALLOW_CUSTOMIZE_BSIZE = False
# Default eval precision on CUDA device is fp16
DEFAULT_EVAL_CUDA_PRECISION = "fp16"

def __init__(self, test, device, batch_size=None, extra_args=[]):
HuggingFaceAuthMixin.__init__(self)
Expand All @@ -32,19 +30,13 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
torch.randn(1, 1, 2048).to(self.device),
{"text_embeds": torch.randn(1, 2560).to(self.device), "time_ids": torch.tensor([1]).to(self.device)}
]



def enable_fp16_half(self):
pass


def get_module(self):
self.pipe.unet, self.list_of_inputs

def train(self):
raise NotImplementedError("Train test is not implemented for the stable diffusion model.")
raise NotImplementedError("Train is not implemented for the stable diffusion model.")

def eval(self):
image = self.pipe(self.example_inputs)
return (image, )
image = self.pipe(*self.list_of_inputs)
return (image, )

0 comments on commit 77a666e

Please sign in to comment.