Skip to content

Commit

Permalink
Unbundle inputs generated by DummyTimestepInputGenerator (#2107)
Browse files Browse the repository at this point in the history
unbundle
  • Loading branch information
JingyaHuang authored Nov 28, 2024
1 parent a6c696c commit bd08f12
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,23 +897,31 @@ def __init__(
):
self.task = task
self.vocab_size = normalized_config.vocab_size
self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
self.text_encoder_projection_dim = getattr(normalized_config, "text_encoder_projection_dim", None)
self.time_ids = 5 if getattr(normalized_config, "requires_aesthetics_score", False) else 6
if random_batch_size_range:
low, high = random_batch_size_range
self.batch_size = random.randint(low, high)
else:
self.batch_size = batch_size
self.time_cond_proj_dim = normalized_config.config.time_cond_proj_dim
self.time_cond_proj_dim = getattr(normalized_config.config, "time_cond_proj_dim", None)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "timestep":
shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture)
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)

if input_name == "text_embeds":
if self.text_encoder_projection_dim is None:
raise ValueError(
"Unable to infer the value of `text_encoder_projection_dim` for generating `text_embeds`, please double check the config of your model."
)
dim = self.text_encoder_projection_dim
elif input_name == "timestep_cond":
if self.time_cond_proj_dim is None:
raise ValueError(
"Unable to infer the value of `time_cond_proj_dim` for generating `timestep_cond`, please double check the config of your model."
)
dim = self.time_cond_proj_dim
else:
dim = self.time_ids
Expand Down

0 comments on commit bd08f12

Please sign in to comment.