Skip to content

Commit

Permalink
Add controlnet option "is_loop"
Browse files Browse the repository at this point in the history
  • Loading branch information
s9roll7 committed Aug 30, 2023
1 parent c29ede7 commit 469e761
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
1 change: 1 addition & 0 deletions config/prompts/prompt_travel.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"max_models_on_vram" : 3,
"save_detectmap": true,
"preprocess_on_gpu": true,
"is_loop": true,

"controlnet_tile":{
"enable": true,
Expand Down
1 change: 1 addition & 0 deletions config/prompts/prompt_travel_multi_controlnet.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"max_models_on_vram" : 3,
"save_detectmap": true,
"preprocess_on_gpu": true,
"is_loop": true,

"controlnet_tile":{
"enable": true,
Expand Down
1 change: 1 addition & 0 deletions src/animatediff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def run_inference(
controlnet_image_map=controlnet_image_map,
controlnet_max_samples_on_vram=controlnet_map["max_samples_on_vram"] if "max_samples_on_vram" in controlnet_map else 999,
controlnet_max_models_on_vram=controlnet_map["max_models_on_vram"] if "max_models_on_vram" in controlnet_map else 99,
controlnet_is_loop = controlnet_map["is_loop"] if "is_loop" in controlnet_map else True,
)
logger.info("Generation complete, saving...")

Expand Down
28 changes: 17 additions & 11 deletions src/animatediff/pipelines/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ def __call__(
controlnet_image_map: Dict[int, Dict[str,Any]] = None,
controlnet_max_samples_on_vram: int = 999,
controlnet_max_models_on_vram: int=99,
controlnet_is_loop: bool=True,
**kwargs,
):
controlnet_image_map_org = controlnet_image_map
Expand Down Expand Up @@ -751,7 +752,7 @@ def get_current_prompt_embeds(

# { "0_type_str" : { "scales" = [0.1, 0.3, 0.5, 1.0, 0.5, 0.3, 0.1], "frames"=[125, 126, 127, 0, 1, 2, 3] }}
controlnet_scale_map = {}
controlnet_affected_list = [False for i in range(video_length)]
controlnet_affected_list = np.zeros(video_length,dtype = int)

if controlnet_image_map:
for type_str in controlnet_image_map:
Expand All @@ -760,15 +761,22 @@ def get_current_prompt_embeds(
scale_list = scale_list[0: context_frames]
scale_len = len(scale_list)

frames = [ i if 0 <= i < video_length else (i+video_length if 0 > i else i- video_length) for i in range(key_frame_no-scale_len, key_frame_no+scale_len+1)]
if controlnet_is_loop:
frames = [ i%video_length for i in range(key_frame_no-scale_len, key_frame_no+scale_len+1)]

controlnet_scale_map[str(key_frame_no) + "_" + type_str] = {
"scales" : scale_list[::-1] + [1.0] + scale_list,
"frames" : frames,
}
controlnet_scale_map[str(key_frame_no) + "_" + type_str] = {
"scales" : scale_list[::-1] + [1.0] + scale_list,
"frames" : frames,
}
else:
frames = [ i for i in range(max(0, key_frame_no-scale_len), min(key_frame_no+scale_len+1, video_length))]

for f in frames:
controlnet_affected_list[f] = True
controlnet_scale_map[str(key_frame_no) + "_" + type_str] = {
"scales" : scale_list[:key_frame_no][::-1] + [1.0] + scale_list[:video_length-key_frame_no-1],
"frames" : frames,
}

controlnet_affected_list[frames] = 1

def controlnet_is_affected( frame_index:int):
return controlnet_affected_list[frame_index]
Expand Down Expand Up @@ -1003,10 +1011,8 @@ def sample_to_device( sample ):
for context in context_scheduler(
i, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap
):
#logger.info(f"{context=}")
controlnet_target = list(range(context[0]-context_frames, context[0])) + context + list(range(context[-1]+1, context[-1]+1+context_frames))
controlnet_target = [(f+video_length if f < 0 else f) for f in controlnet_target]
controlnet_target = [(f-video_length if f >= video_length else f) for f in controlnet_target]
controlnet_target = [f%video_length for f in controlnet_target]
controlnet_target = list(set(controlnet_target))

process_controlnet(controlnet_target)
Expand Down

0 comments on commit 469e761

Please sign in to comment.