Skip to content

Commit

Permalink
sample with multiple text
Browse files Browse the repository at this point in the history
  • Loading branch information
wtomin committed Feb 6, 2025
1 parent 64d449f commit def6ef0
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 43 deletions.
11 changes: 6 additions & 5 deletions examples/hunyuanvideo/hyvideo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def add_inference_args(parser: argparse.ArgumentParser):
"--prompt",
type=str,
default=None,
help="Prompt for sampling during evaluation.",
help="A single prompt string or a path to a .txt file containing multiple prompts. "
"If a .txt file is provided, each line should contain one prompt.",
)
group.add_argument(
"--seed-type",
Expand Down Expand Up @@ -365,11 +366,11 @@ def add_inference_args(parser: argparse.ArgumentParser):
group.add_argument(
"--text-embed-path",
type=str,
help="path to npz containing text embeds, "
"including positive/negative prompt embed of text encoder 1 and 2"
", and the mask for positive and negative prompt",
default=None,
help="A single .npz file path or a path to a .txt file containing multiple .npz file paths. "
"If a .txt file is provided, each line should contain one .npz file path. "
"This argument is required if `prompt` is a .txt file.",
)

# mindspore args
group.add_argument("--ms-mode", type=int, default=0, help="0 graph, 1 pynative")
group.add_argument(
Expand Down
6 changes: 6 additions & 0 deletions examples/hunyuanvideo/hyvideo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ def predict(
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if prompt is None:
if text_embed_path is None:
raise ValueError("Either `prompt` or `text_embed_path` must be specified.")
prompt = ""

if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
prompt = [prompt.strip()]
Expand Down Expand Up @@ -532,6 +537,7 @@ def predict(
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
text_embed_path: {text_embed_path}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
Expand Down
52 changes: 52 additions & 0 deletions examples/hunyuanvideo/hyvideo/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
def process_prompt_and_text_embed(prompt, text_embed_path):
    """
    Normalize `prompt` / `text_embed_path` into two parallel lists of equal length.

    Args:
        prompt (str or None): A single prompt string, or a path to a .txt file
            containing one prompt per line.
        text_embed_path (str or None): A single .npz file path, or a path to a
            .txt file containing one .npz path per line. May be None when
            `prompt` is provided.

    Returns:
        tuple: (prompts, text_embed_paths) — two lists of the same length;
            entries with no counterpart are filled with None.

    Raises:
        TypeError: If `prompt` is neither None nor a string.
        ValueError: If both arguments are None, if the file-type combination is
            invalid, or if the number of prompts and embed paths differ.
    """

    def _read_lines(path):
        # One entry per non-empty line; blank lines are skipped so trailing
        # newlines in the file do not produce empty prompts/paths.
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    if prompt is None and text_embed_path is None:
        raise ValueError("Either `prompt` or `text_embed_path` must be provided.")

    # Validate the type up front: the original checked isinstance only in a later
    # branch, so a non-str prompt crashed with AttributeError on `.endswith`.
    if prompt is not None and not isinstance(prompt, str):
        raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")

    if prompt is not None and prompt.endswith(".txt"):
        prompts = _read_lines(prompt)
        if text_embed_path is not None:
            # Raise instead of assert: asserts are stripped under `python -O`.
            if not text_embed_path.endswith(".txt"):
                raise ValueError("When `prompt` is a txt file, `text_embed_path` should be a txt file too.")
            text_embed_paths = _read_lines(text_embed_path)
        else:
            text_embed_paths = [None] * len(prompts)
    elif prompt is not None:
        # Single literal prompt string.
        prompts = [prompt.strip()]
        if text_embed_path is not None:
            if not text_embed_path.endswith(".npz"):
                raise ValueError("When `prompt` is a string, `text_embed_path` should be a npz file.")
            text_embed_paths = [text_embed_path.strip()]
        else:
            text_embed_paths = [None]
    else:
        # prompt is None, so text_embed_path is guaranteed non-None here
        # (the both-None case was rejected above).
        if text_embed_path.endswith(".txt"):
            text_embed_paths = _read_lines(text_embed_path)
        else:
            text_embed_paths = [text_embed_path.strip()]
        prompts = [None] * len(text_embed_paths)

    if len(prompts) != len(text_embed_paths):
        raise ValueError(
            f"The number of prompts ({len(prompts)}) and text embed paths "
            f"({len(text_embed_paths)}) must match."
        )

    return prompts, text_embed_paths
87 changes: 49 additions & 38 deletions examples/hunyuanvideo/sample_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
sys.path.append(".")
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.utils.file_utils import process_prompt_and_text_embed
from hyvideo.utils.ms_utils import init_env


Expand All @@ -27,9 +28,9 @@ def main():
raise ValueError(f"`models_root` not exists: {models_root_path}")

# Create save folder to save the samples
save_path = args.save_path if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
if not os.path.exists(args.save_path):
os.makedirs(save_path, exist_ok=True)
save_dir = args.save_path if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
if not os.path.exists(args.save_dir):
os.makedirs(save_dir, exist_ok=True)

# ms env init
rank_id, _ = init_env(
Expand All @@ -51,44 +52,54 @@ def main():
args = hunyuan_video_sampler.args

# Start sampling
# TODO: batch inference check
outputs = hunyuan_video_sampler.predict(
prompt=args.prompt,
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale,
output_type=args.output_type,
text_embed_path=args.text_embed_path,
)
samples = outputs["samples"]

# Save samples
if rank_id == 0:
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
save_path = (
f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
if args.prompt is None and args.text_embed_path is None:
raise ValueError("Either `prompt` or `text_embed_path` must be provided.")

prompts, text_embed_paths = process_prompt_and_text_embed(args.prompt, args.text_embed_path)
for prompt, text_embed_path in zip(prompts, text_embed_paths):
if prompt is not None:
logger.info(f"Sampling with prompt: {prompt}")
if text_embed_path is not None:
logger.info(f"Sampling with text embed path: {text_embed_path}")
try:
outputs = hunyuan_video_sampler.predict(
prompt=prompt,
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale,
output_type=args.output_type,
text_embed_path=text_embed_path,
)
samples = outputs["samples"]
except Exception as e:
logger.error(f"Error during prediction: {e}")
continue

# Save samples
if rank_id == 0:
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
save_path = f"{save_dir}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"

if args.output_type != "latent":
# save_videos_grid(sample, save_path, fps=24)
# b c t h w -> b t h w c
sample = sample.permute(0, 2, 3, 4, 1).asnumpy()
save_videos(sample, save_path, fps=24)
else:
save_path = save_path[:-4] + ".npy"
np.save(save_path, sample)
if args.output_type != "latent":
# save_videos_grid(sample, save_path, fps=24)
# b c t h w -> b t h w c
sample = sample.permute(0, 2, 3, 4, 1).asnumpy()
save_videos(sample, save_path, fps=24)
else:
save_path = save_path[:-4] + ".npy"
np.save(save_path, sample)

logger.info(f"Sample save to: {save_path}")
logger.info(f"Sample save to: {save_path}")


if __name__ == "__main__":
Expand Down

0 comments on commit def6ef0

Please sign in to comment.