Commit
Add articulated animation (#5)
FuouM authored Aug 4, 2024
1 parent d7386f0 commit cd43aad
Showing 27 changed files with 2,637 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
# Stuff

articulate_module/models/*.pth

checkpoints/*
!checkpoints/place_checkpoints_here.txt

24 changes: 20 additions & 4 deletions README.md
@@ -7,6 +7,10 @@ Now supports Face Swapping using Motion Supervised co-part Segmentation

https://github.com/AliaksandrSiarohin/motion-cosegmentation

Now also supports Motion Representations for Articulated Animation

https://github.com/snap-research/articulated-animation

## Workflow:


@@ -15,15 +19,21 @@ https://github.com/FuouM/ComfyUI-FirstOrderMM/assets/57849916/e8adca97-b53c-48ce

### FOMM

[FOMM.json](FOMM.json)
[FOMM.json](workflows/FOMM.json)

![FOMM Workflow](workflow.png)
![FOMM Workflow](workflows/workflow.png)

### Part Swap

[FOMM_PARTSWAP.json](FOMM_PARTSWAP.json)
[FOMM_PARTSWAP.json](workflows/FOMM_PARTSWAP.json)

![Partswap Workflow](workflows/workflow_fomm_partswap.png)

### Articulate

![Partswap Workflow](workflow_fomm_partswap.png)
[ARTICULATE.json](workflows/ARTICULATE.json)

![Workflow Articulate](workflows/workflow_articulate.png)

## Arguments

@@ -42,6 +52,10 @@ https://github.com/FuouM/ComfyUI-FirstOrderMM/assets/57849916/e8adca97-b53c-48ce
* `use_face_parser`: For Seg-based models, may help clean up residual background (should only be used with `15seg`). TODO: additional cleanup of the face_parser masks. Should definitely be used for FOMM models
* `viz_alpha`: Opacity of the segments in the visualization

### Articulate

No additional arguments are required.

## Installation

1. Clone the repo to `ComfyUI/custom_nodes/`
@@ -98,3 +112,5 @@ face_parsing_model.py
resnet18-5c106cde.pth
79999_iter.pth
```

For **Articulate**, download the model from the [Pre-trained checkpoints](https://github.com/snap-research/articulated-animation?tab=readme-ov-file#pre-trained-checkpoints) section and place it at `articulate_module/models/vox256.pth`.
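
As a quick sanity check that the checkpoint is in place and readable, the following sketch loads it on CPU and lists its keys (the path mirrors the installation step above; the expected key names are the ones `load_cpk` in `articulate_inference.py` looks up):

```python
import torch

# Run from the repository root; path assumed from the installation step above.
ckpt = torch.load("articulate_module/models/vox256.pth", map_location="cpu")
print(sorted(ckpt.keys()))
# load_cpk() expects 'generator', 'region_predictor', 'bg_predictor',
# and optionally 'avd_network' among these keys.
```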
3 changes: 3 additions & 0 deletions __init__.py
@@ -4,6 +4,7 @@
FOMM_Seg5Chooser,
FOMM_Seg10Chooser,
FOMM_Seg15Chooser,
Articulate_Runner,
)

NODE_CLASS_MAPPINGS = {
@@ -12,6 +13,7 @@
"FOMM_Seg5Chooser": FOMM_Seg5Chooser,
"FOMM_Seg10Chooser": FOMM_Seg10Chooser,
"FOMM_Seg15Chooser": FOMM_Seg15Chooser,
"Articulate_Runner": Articulate_Runner,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -20,6 +22,7 @@
"FOMM_Seg5Chooser": "FOMM Seg5 Chooser",
"FOMM_Seg10Chooser": "FOMM Seg10 Chooser",
"FOMM_Seg15Chooser": "FOMM Seg15 Chooser",
"Articulate_Runner": "Articulate Runner",
}


200 changes: 200 additions & 0 deletions articulate_inference.py
@@ -0,0 +1,200 @@
import numpy as np
import torch
import tqdm
import yaml
from comfy.utils import ProgressBar
from scipy.spatial import ConvexHull

from .articulate_module.avd_network import AVDNetwork
from .articulate_module.bg_motion_predictor import BGMotionPredictor
from .articulate_module.generator import Generator
from .articulate_module.region_predictor import RegionPredictor
from .sync_batchnorm.replicate import DataParallelWithCallback


def articulate_inference(
    source_image: torch.Tensor,  # (B, C, H, W)
    driving_video: torch.Tensor,  # (B, C, num_frames, H, W); indexed along dim 2 below
config_path: str,
checkpoint_path: str,
estimate_affine=True,
pca_based=True,
):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(config_path) as f:
config = yaml.full_load(f)

generator, region_predictor, bg_predictor, avd_network = init_models(
config, estimate_affine, pca_based
)
generator = generator.to(device)
region_predictor = region_predictor.to(device)
bg_predictor = bg_predictor.to(device)
avd_network = avd_network.to(device)

animate_params = config["animate_params"]

load_cpk(checkpoint_path, generator, region_predictor, bg_predictor, avd_network)

if torch.cuda.is_available():
generator = DataParallelWithCallback(generator)
region_predictor = DataParallelWithCallback(region_predictor)
avd_network = DataParallelWithCallback(avd_network)

generator.eval()
region_predictor.eval()
avd_network.eval()

source_frame = source_image
driving = driving_video
predictions = []

num_frames = driving.shape[2]
pbar = ProgressBar(num_frames)

with torch.no_grad():
source_region_params = region_predictor(source_frame)
driving_region_params_initial = region_predictor(driving_video[:, :, 0])

for frame_idx in tqdm.tqdm(range(num_frames)):
driving_frame = driving[:, :, frame_idx]
driving_region_params = region_predictor(driving_frame)
new_region_params = get_animation_region_params(
source_region_params,
driving_region_params,
driving_region_params_initial,
mode=animate_params["mode"],
avd_network=avd_network,
)
out = generator(
source_frame,
source_region_params=source_region_params,
driving_region_params=new_region_params,
)

out["driving_region_params"] = driving_region_params
out["source_region_params"] = source_region_params
out["new_region_params"] = new_region_params

# visualization = Visualizer(**config["visualizer_params"]).visualize(
# source=source_frame, driving=driving_frame, out=out
# ) / 255.0

prediction = out["prediction"].data.cpu().numpy()
prediction = np.transpose(prediction, [0, 2, 3, 1]).squeeze(0)

# visualizations.append(visualization)
predictions.append(prediction)

pbar.update_absolute(frame_idx, num_frames)

# print(f"{predictions[0].shape=}")
# print(f"{visualizations[0].shape=}")

# return predictions, visualizations
return predictions


def init_models(config, estimate_affine, pca_based):
generator = Generator(
num_regions=config["model_params"]["num_regions"],
num_channels=config["model_params"]["num_channels"],
revert_axis_swap=config["model_params"]["revert_axis_swap"],
**config["model_params"]["generator_params"],
)

config["model_params"]["region_predictor_params"]["pca_based"] = pca_based

region_predictor = RegionPredictor(
num_regions=config["model_params"]["num_regions"],
num_channels=config["model_params"]["num_channels"],
estimate_affine=estimate_affine,
**config["model_params"]["region_predictor_params"],
)

bg_predictor = BGMotionPredictor(
num_channels=config["model_params"]["num_channels"],
**config["model_params"]["bg_predictor_params"],
)

avd_network = AVDNetwork(
num_regions=config["model_params"]["num_regions"],
**config["model_params"]["avd_network_params"],
)

return generator, region_predictor, bg_predictor, avd_network


def load_cpk(
checkpoint_path,
generator=None,
region_predictor=None,
bg_predictor=None,
avd_network=None,
optimizer_reconstruction=None,
optimizer_avd=None,
):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")  # load on CPU so CPU-only setups can read GPU-trained checkpoints
# print(checkpoint.keys())
if generator is not None:
generator.load_state_dict(checkpoint["generator"])
if region_predictor is not None:
region_predictor.load_state_dict(checkpoint["region_predictor"])
if bg_predictor is not None:
bg_predictor.load_state_dict(checkpoint["bg_predictor"])
if avd_network is not None:
if "avd_network" in checkpoint:
avd_network.load_state_dict(checkpoint["avd_network"])

if optimizer_reconstruction is not None:
optimizer_reconstruction.load_state_dict(checkpoint["optimizer_reconstruction"])
return checkpoint["epoch_reconstruction"]

if optimizer_avd is not None:
if "optimizer_avd" in checkpoint:
optimizer_avd.load_state_dict(checkpoint["optimizer_avd"])
return checkpoint["epoch_avd"]
return 0

return 0


def get_animation_region_params(
source_region_params,
driving_region_params,
driving_region_params_initial,
mode="standard",
avd_network=None,
adapt_movement_scale=True,
):
assert mode in ["standard", "relative", "avd"]
new_region_params = {k: v for k, v in driving_region_params.items()}
if mode == "standard":
return new_region_params
elif mode == "relative":
source_area = ConvexHull(
source_region_params["shift"][0].data.cpu().numpy()
).volume
driving_area = ConvexHull(
driving_region_params_initial["shift"][0].data.cpu().numpy()
).volume
movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)

shift_diff = (
driving_region_params["shift"] - driving_region_params_initial["shift"]
)
shift_diff *= movement_scale
new_region_params["shift"] = shift_diff + source_region_params["shift"]

affine_diff = torch.matmul(
driving_region_params["affine"],
torch.inverse(driving_region_params_initial["affine"]),
)
new_region_params["affine"] = torch.matmul(
affine_diff, source_region_params["affine"]
)
return new_region_params
elif mode == "avd":
new_region_params = avd_network(source_region_params, driving_region_params)
return new_region_params
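
For orientation, here is a minimal sketch of how `articulate_inference` might be invoked. The shapes are inferred from the code above (`source_image` as `(B, C, H, W)`, `driving_video` as `(B, C, num_frames, H, W)`); the config path is an assumption, and the `comfy.utils.ProgressBar` import means this only runs inside a ComfyUI environment, where this repository is loaded as a custom-node package:

```python
import torch

# articulate_inference is assumed to be imported from this package
# (it uses relative imports, so it must run as part of the custom node).

# Illustrative paths; vox256.yaml ships with snap-research/articulated-animation,
# and the checkpoint location follows the README installation step.
CONFIG = "articulate_module/vox256.yaml"
CHECKPOINT = "articulate_module/models/vox256.pth"

source = torch.rand(1, 3, 256, 256)       # one source frame, values in [0, 1]
driving = torch.rand(1, 3, 12, 256, 256)  # 12 driving frames stacked along dim 2

frames = articulate_inference(source, driving, CONFIG, CHECKPOINT)
# frames is a list with one (H, W, C) float numpy image per driving frame
print(len(frames), frames[0].shape)
```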
90 changes: 90 additions & 0 deletions articulate_module/avd_network.py
@@ -0,0 +1,90 @@
"""
Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.
No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,
publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.
Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,
title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.
In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.
"""
import torch
from torch import nn


class AVDNetwork(nn.Module):
"""
Animation via Disentanglement network
"""

def __init__(self, num_regions, id_bottle_size=64, pose_bottle_size=64, revert_axis_swap=True):
super(AVDNetwork, self).__init__()
input_size = (2 + 4) * num_regions
self.num_regions = num_regions
self.revert_axis_swap = revert_axis_swap

self.id_encoder = nn.Sequential(
nn.Linear(input_size, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(inplace=True),
nn.Linear(1024, id_bottle_size)
)

self.pose_encoder = nn.Sequential(
nn.Linear(input_size, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(inplace=True),
nn.Linear(1024, pose_bottle_size)
)

self.decoder = nn.Sequential(
nn.Linear(pose_bottle_size + id_bottle_size, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, input_size)
)

@staticmethod
def region_params_to_emb(x):
mean = x['shift']
jac = x['affine']
emb = torch.cat([mean, jac.view(jac.shape[0], jac.shape[1], -1)], dim=-1)
emb = emb.view(emb.shape[0], -1)
return emb

def emb_to_region_params(self, emb):
emb = emb.view(emb.shape[0], self.num_regions, 6)
mean = emb[:, :, :2]
jac = emb[:, :, 2:].view(emb.shape[0], emb.shape[1], 2, 2)
return {'shift': mean, 'affine': jac}

def forward(self, x_id, x_pose, alpha=0.2):
if self.revert_axis_swap:
affine = torch.matmul(x_id['affine'], torch.inverse(x_pose['affine']))
sign = torch.sign(affine[:, :, 0:1, 0:1])
x_id = {'affine': x_id['affine'] * sign, 'shift': x_id['shift']}

pose_emb = self.pose_encoder(self.region_params_to_emb(x_pose))
id_emb = self.id_encoder(self.region_params_to_emb(x_id))

rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))

rec = self.emb_to_region_params(rec)
rec['covar'] = torch.matmul(rec['affine'], rec['affine'].permute(0, 1, 3, 2))
return rec
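
A quick shape check for `AVDNetwork`, sketched under the assumption that `shift` is `(B, num_regions, 2)` and `affine` is `(B, num_regions, 2, 2)` (the shapes `emb_to_region_params` reshapes to):

```python
import torch

from articulate_module.avd_network import AVDNetwork  # assuming the repo root is on sys.path

net = AVDNetwork(num_regions=10)
net.eval()  # the encoders/decoder use BatchNorm1d, so eval() lets batch size 1 pass

def random_region_params(batch=1, regions=10):
    return {
        "shift": torch.rand(batch, regions, 2),
        "affine": torch.rand(batch, regions, 2, 2) + torch.eye(2),  # keep affines invertible
    }

out = net(random_region_params(), random_region_params())
print(out["shift"].shape, out["affine"].shape, out["covar"].shape)
# torch.Size([1, 10, 2]) torch.Size([1, 10, 2, 2]) torch.Size([1, 10, 2, 2])
```

In the inference path above, `mode="avd"` in `get_animation_region_params` is what routes the source and driving region parameters through this network.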