Showing 27 changed files with 2,637 additions and 5 deletions.
@@ -1,5 +1,7 @@

```gitignore
# Stuff

articulate_module/models/*.pth

checkpoints/*
!checkpoints/place_checkpoints_here.txt
```
@@ -0,0 +1,200 @@

```python
import numpy as np
import torch
import tqdm
import yaml
from comfy.utils import ProgressBar
from scipy.spatial import ConvexHull

from .articulate_module.avd_network import AVDNetwork
from .articulate_module.bg_motion_predictor import BGMotionPredictor
from .articulate_module.generator import Generator
from .articulate_module.region_predictor import RegionPredictor
from .sync_batchnorm.replicate import DataParallelWithCallback

def articulate_inference(
    source_image,
    driving_video: torch.Tensor,
    config_path: str,
    checkpoint_path: str,
    estimate_affine=True,
    pca_based=True,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(config_path) as f:
        config = yaml.full_load(f)

    generator, region_predictor, bg_predictor, avd_network = init_models(
        config, estimate_affine, pca_based
    )
    generator = generator.to(device)
    region_predictor = region_predictor.to(device)
    bg_predictor = bg_predictor.to(device)
    avd_network = avd_network.to(device)

    animate_params = config["animate_params"]

    load_cpk(checkpoint_path, generator, region_predictor, bg_predictor, avd_network)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)

    generator.eval()
    region_predictor.eval()
    avd_network.eval()

    source_frame = source_image
    driving = driving_video
    predictions = []

    # driving_video is laid out as (batch, channels, frames, height, width).
    num_frames = driving.shape[2]
    pbar = ProgressBar(num_frames)

    with torch.no_grad():
        # Region parameters of the source image and of the first driving frame
        # anchor the motion transfer for all subsequent frames.
        source_region_params = region_predictor(source_frame)
        driving_region_params_initial = region_predictor(driving_video[:, :, 0])

        for frame_idx in tqdm.tqdm(range(num_frames)):
            driving_frame = driving[:, :, frame_idx]
            driving_region_params = region_predictor(driving_frame)
            new_region_params = get_animation_region_params(
                source_region_params,
                driving_region_params,
                driving_region_params_initial,
                mode=animate_params["mode"],
                avd_network=avd_network,
            )
            out = generator(
                source_frame,
                source_region_params=source_region_params,
                driving_region_params=new_region_params,
            )

            out["driving_region_params"] = driving_region_params
            out["source_region_params"] = source_region_params
            out["new_region_params"] = new_region_params

            # visualization = Visualizer(**config["visualizer_params"]).visualize(
            #     source=source_frame, driving=driving_frame, out=out
            # ) / 255.0

            # Convert the generated frame from (1, C, H, W) to (H, W, C).
            prediction = out["prediction"].data.cpu().numpy()
            prediction = np.transpose(prediction, [0, 2, 3, 1]).squeeze(0)

            # visualizations.append(visualization)
            predictions.append(prediction)

            pbar.update_absolute(frame_idx, num_frames)

    # print(f"{predictions[0].shape=}")
    # print(f"{visualizations[0].shape=}")

    # return predictions, visualizations
    return predictions

def init_models(config, estimate_affine, pca_based):
    generator = Generator(
        num_regions=config["model_params"]["num_regions"],
        num_channels=config["model_params"]["num_channels"],
        revert_axis_swap=config["model_params"]["revert_axis_swap"],
        **config["model_params"]["generator_params"],
    )

    config["model_params"]["region_predictor_params"]["pca_based"] = pca_based

    region_predictor = RegionPredictor(
        num_regions=config["model_params"]["num_regions"],
        num_channels=config["model_params"]["num_channels"],
        estimate_affine=estimate_affine,
        **config["model_params"]["region_predictor_params"],
    )

    bg_predictor = BGMotionPredictor(
        num_channels=config["model_params"]["num_channels"],
        **config["model_params"]["bg_predictor_params"],
    )

    avd_network = AVDNetwork(
        num_regions=config["model_params"]["num_regions"],
        **config["model_params"]["avd_network_params"],
    )

    return generator, region_predictor, bg_predictor, avd_network

def load_cpk(
    checkpoint_path,
    generator=None,
    region_predictor=None,
    bg_predictor=None,
    avd_network=None,
    optimizer_reconstruction=None,
    optimizer_avd=None,
):
    # map_location="cpu" lets GPU-trained checkpoints load on CPU-only machines;
    # load_state_dict then copies weights onto whatever device each module is on.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # print(checkpoint.keys())
    if generator is not None:
        generator.load_state_dict(checkpoint["generator"])
    if region_predictor is not None:
        region_predictor.load_state_dict(checkpoint["region_predictor"])
    if bg_predictor is not None:
        bg_predictor.load_state_dict(checkpoint["bg_predictor"])
    if avd_network is not None:
        if "avd_network" in checkpoint:
            avd_network.load_state_dict(checkpoint["avd_network"])

    if optimizer_reconstruction is not None:
        optimizer_reconstruction.load_state_dict(checkpoint["optimizer_reconstruction"])
        return checkpoint["epoch_reconstruction"]

    if optimizer_avd is not None:
        if "optimizer_avd" in checkpoint:
            optimizer_avd.load_state_dict(checkpoint["optimizer_avd"])
            return checkpoint["epoch_avd"]
        return 0

    return 0

def get_animation_region_params(
    source_region_params,
    driving_region_params,
    driving_region_params_initial,
    mode="standard",
    avd_network=None,
    adapt_movement_scale=True,
):
    assert mode in ["standard", "relative", "avd"]
    new_region_params = {k: v for k, v in driving_region_params.items()}
    if mode == "standard":
        return new_region_params
    elif mode == "relative":
        # Rescale the driving motion by the ratio of the convex-hull areas of the
        # source regions and the initial driving regions, so the motion magnitude
        # matches the proportions of the source image.
        source_area = ConvexHull(
            source_region_params["shift"][0].data.cpu().numpy()
        ).volume
        driving_area = ConvexHull(
            driving_region_params_initial["shift"][0].data.cpu().numpy()
        ).volume
        movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)

        # Apply each frame's (scaled) shift relative to the first driving frame
        # on top of the source region positions.
        shift_diff = (
            driving_region_params["shift"] - driving_region_params_initial["shift"]
        )
        shift_diff *= movement_scale
        new_region_params["shift"] = shift_diff + source_region_params["shift"]

        # Compose the relative affine change with the source affine.
        affine_diff = torch.matmul(
            driving_region_params["affine"],
            torch.inverse(driving_region_params_initial["affine"]),
        )
        new_region_params["affine"] = torch.matmul(
            affine_diff, source_region_params["affine"]
        )
        return new_region_params
    elif mode == "avd":
        new_region_params = avd_network(source_region_params, driving_region_params)
        return new_region_params
```
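
A minimal usage sketch for `articulate_inference`, assuming the source image arrives as a `(1, C, H, W)` float tensor in `[0, 1]` and the driving clip as `(1, C, T, H, W)` (the layout the function slices with `[:, :, t]`). The file names, paths, and `fps` value are illustrative, not fixed by this commit, and frames are assumed to already match the checkpoint's training resolution; reading video this way also needs the `imageio-ffmpeg` backend installed.

```python
import imageio.v2 as imageio
import numpy as np
import torch

# Hypothetical paths; the commit does not pin these names.
CONFIG_PATH = "articulate_module/config.yaml"
CHECKPOINT_PATH = "checkpoints/articulate.pth"

# Source image -> (1, C, H, W), float32 in [0, 1].
source = imageio.imread("source.png")[..., :3] / 255.0
source = torch.from_numpy(source).float().permute(2, 0, 1).unsqueeze(0)

# Driving clip -> (1, C, T, H, W); articulate_inference indexes frames as [:, :, t].
frames = [f[..., :3] / 255.0 for f in imageio.mimread("driving.mp4", memtest=False)]
driving = torch.from_numpy(np.stack(frames)).float().permute(3, 0, 1, 2).unsqueeze(0)

predictions = articulate_inference(source, driving, CONFIG_PATH, CHECKPOINT_PATH)

# Each prediction comes back as an (H, W, C) float array, one per driving frame.
imageio.mimsave("result.mp4", [(p * 255).astype(np.uint8) for p in predictions], fps=25)
```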
@@ -0,0 +1,90 @@

```python
""" | ||
Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. | ||
No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, | ||
publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. | ||
Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, | ||
title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. | ||
In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. | ||
""" | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class AVDNetwork(nn.Module):
    """
    Animation via Disentanglement network
    """

    def __init__(self, num_regions, id_bottle_size=64, pose_bottle_size=64, revert_axis_swap=True):
        super(AVDNetwork, self).__init__()
        # Each region is described by a 2-D shift and a 2x2 affine matrix,
        # i.e. (2 + 4) values per region once flattened.
        input_size = (2 + 4) * num_regions
        self.num_regions = num_regions
        self.revert_axis_swap = revert_axis_swap

        # Encodes the region parameters of the identity (source) frame.
        self.id_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, id_bottle_size)
        )

        # Encodes the region parameters of the pose (driving) frame.
        self.pose_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, pose_bottle_size)
        )

        # Reconstructs region parameters from the concatenated bottlenecks.
        self.decoder = nn.Sequential(
            nn.Linear(pose_bottle_size + id_bottle_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, input_size)
        )

    @staticmethod
    def region_params_to_emb(x):
        # Flatten {'shift': (B, R, 2), 'affine': (B, R, 2, 2)} into a (B, R * 6) vector.
        mean = x['shift']
        jac = x['affine']
        emb = torch.cat([mean, jac.view(jac.shape[0], jac.shape[1], -1)], dim=-1)
        emb = emb.view(emb.shape[0], -1)
        return emb

    def emb_to_region_params(self, emb):
        # Inverse of region_params_to_emb: split the flat vector back into shift and affine.
        emb = emb.view(emb.shape[0], self.num_regions, 6)
        mean = emb[:, :, :2]
        jac = emb[:, :, 2:].view(emb.shape[0], emb.shape[1], 2, 2)
        return {'shift': mean, 'affine': jac}

    def forward(self, x_id, x_pose, alpha=0.2):
        if self.revert_axis_swap:
            # Match the sign of the identity affines to the pose affines to undo
            # a possible axis flip between the two frames.
            affine = torch.matmul(x_id['affine'], torch.inverse(x_pose['affine']))
            sign = torch.sign(affine[:, :, 0:1, 0:1])
            x_id = {'affine': x_id['affine'] * sign, 'shift': x_id['shift']}

        pose_emb = self.pose_encoder(self.region_params_to_emb(x_pose))
        id_emb = self.id_encoder(self.region_params_to_emb(x_id))

        rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))

        rec = self.emb_to_region_params(rec)
        # Per-region covariance derived from the reconstructed affine.
        rec['covar'] = torch.matmul(rec['affine'], rec['affine'].permute(0, 1, 3, 2))
        return rec
```
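
To make the expected shapes concrete, here is a small smoke test; the batch size and region count are illustrative, nothing in this commit fixes them. `eval()` switches the `BatchNorm1d` layers to running statistics so a tiny batch passes through cleanly.

```python
import torch

batch, num_regions = 2, 10  # illustrative sizes only
net = AVDNetwork(num_regions=num_regions).eval()

def random_region_params(b, r):
    # Region parameters in the format AVDNetwork expects:
    # 'shift' is (B, R, 2), 'affine' is (B, R, 2, 2).
    # Identity plus small noise keeps the affines safely invertible
    # for the torch.inverse call in forward().
    eye = torch.eye(2).expand(b, r, 2, 2)
    return {
        "shift": torch.randn(b, r, 2) * 0.1,
        "affine": eye + torch.randn(b, r, 2, 2) * 0.01,
    }

with torch.no_grad():
    rec = net(random_region_params(batch, num_regions),
              random_region_params(batch, num_regions))

print(rec["shift"].shape)   # torch.Size([2, 10, 2])
print(rec["affine"].shape)  # torch.Size([2, 10, 2, 2])
print(rec["covar"].shape)   # torch.Size([2, 10, 2, 2])
```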