Skip to content

Commit

Permalink
allow for returning framed inputs, and the appropriate frame averagin…
Browse files Browse the repository at this point in the history
…g function for the network output handled externally
  • Loading branch information
lucidrains committed Jun 3, 2024
1 parent 2231ba0 commit 7de1ea0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions frame_averaging_pytorch/frame_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def forward(
points,
*args,
frame_average_mask = None,
return_framed_inputs_and_averaging_function = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -126,6 +127,28 @@ def forward(

inputs = einsum(frames, centered_points, 'b f d e, b n d -> b f n e')

# define the frame averaging function

def frame_average(out):
if not self.invariant_output:
# apply frames

out = einsum(frames, out, 'b f d e, b f ... e -> b f ... d')

if not self.stochastic:
# averaging across frames, thus "frame averaging"

out = reduce(out, 'b f ... -> b ...', 'mean')
else:
out = rearrange(out, 'b 1 ... -> b ...')

return out

# if one wants to handle the framed inputs externally

if return_framed_inputs_and_averaging_function:
return inputs, frame_average

# merge frames into batch

inputs = rearrange(inputs, 'b f ... -> (b f) ...')
Expand Down Expand Up @@ -162,17 +185,7 @@ def forward(

out = rearrange(out, '(b f) ... -> b f ...', f = num_frames)

if not self.invariant_output:
# apply frames

out = einsum(frames, out, 'b f d e, b f n e -> b f n d')

if not self.stochastic:
# averaging across frames, thus "frame averaging"

out = reduce(out, 'b f ... -> b ...', 'mean')
else:
out = rearrange(out, 'b 1 ... -> b ...')
out = frame_average(out)

if not is_multiple_output:
return out
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frame-averaging-pytorch"
version = "0.0.14"
version = "0.0.15"
description = "Frame Averaging"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 7de1ea0

Please sign in to comment.