Skip to content

Commit

Permalink
handle multiple outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 7, 2024
1 parent 1cbb82f commit 72aab16
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 19 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Pytest
on: [push, pull_request]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[test]
- name: Test with pytest
run: |
python -m pytest tests/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<img src="./frame-averaging.png" width="350px"></img>

## Frame Averaging - Pytorch (wip)
## Frame Averaging - Pytorch

Pytorch implementation of a simple way to enable <a href="https://arxiv.org/abs/2305.05577">(Stochastic)</a> <a href="https://arxiv.org/abs/2110.03336">Frame Averaging</a> for any network. This technique was recently adopted by Prescient Design in <a href="https://arxiv.org/abs/2308.05027">AbDiffuser</a>

Expand Down
22 changes: 5 additions & 17 deletions frame_averaging_pytorch/frame_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def frame_average(out):
if not self.stochastic:
args, kwargs = tree_map(
lambda el: (
rearrange(el, 'b ... -> (b f) ...', f = num_frames)
repeat(el, 'b ... -> (b f) ...', f = num_frames)
if torch.is_tensor(el)
else el
)
Expand All @@ -179,21 +179,9 @@ def frame_average(out):

out = self.net(inputs, *args, **kwargs)

# handle if output is a tuple - just follow convention that first output is the one to be frame averaged
# (todo) - handle multiple outputs that need frame averaging
# use tree map to handle multiple outputs

is_multiple_output = isinstance(out, tuple)
out = tree_map(lambda t: rearrange(t, '(b f) ... -> b f ...', f = num_frames), out)
out = tree_map(lambda t: frame_average(t) if torch.is_tensor(t) else t, out)

if is_multiple_output:
out, *rest = out

# split frames from batch

out = rearrange(out, '(b f) ... -> b f ...', f = num_frames)

out = frame_average(out)

if not is_multiple_output:
return out

return (out, *rest)
return out
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frame-averaging-pytorch"
version = "0.0.19"
version = "0.1.0"
description = "Frame Averaging"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down Expand Up @@ -34,6 +34,14 @@ Repository = "https://github.com/lucidrains/frame-averaging-pytorch"

[project.optional-dependencies]
examples = []
test = [
"pytest"
]

[tool.pytest.ini_options]
pythonpath = [
"."
]

[build-system]
requires = ["hatchling"]
Expand Down
71 changes: 71 additions & 0 deletions tests/test_frame_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest

import torch
from torch import nn
from torch.nn import Module
from frame_averaging_pytorch import FrameAverage

@pytest.mark.parametrize('stochastic', (True, False))
@pytest.mark.parametrize('dim', (2, 3, 4))
@pytest.mark.parametrize('has_mask', (True, False))
def test_frame_average(
stochastic: bool,
dim: int,
has_mask: bool
):

net = torch.nn.Linear(dim, dim)

net = FrameAverage(
net,
dim = dim,
stochastic = stochastic
)

points = torch.randn(4, 1024, dim)

mask = None
if has_mask:
mask = torch.ones(4, 1024).bool()

out = net(points, frame_average_mask = mask)
assert out.shape == points.shape

def test_frame_average_manual():

net = torch.nn.Linear(3, 3)

fa = FrameAverage()
points = torch.randn(4, 1024, 3)

framed_inputs, frame_average_fn = fa(points)

net_out = net(framed_inputs)

frame_averaged = frame_average_fn(net_out)

assert frame_averaged.shape == points.shape

def test_frame_average_multiple_inputs_and_outputs():

class Network(Module):
def __init__(self):
super().__init__()
self.net = nn.Linear(3, 3)
self.to_out1 = nn.Linear(3, 3)
self.to_out2 = nn.Linear(3, 3)

def forward(self, x, mask):
x = x.masked_fill(~mask[..., None], 0.)
hidden = self.net(x)
return self.to_out1(hidden), self.to_out2(hidden)

net = Network()
net = FrameAverage(net)

points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

out1, out2 = net(points, mask)

assert out1.shape == out2.shape == points.shape

0 comments on commit 72aab16

Please sign in to comment.