Skip to content

Commit

Permalink
quick fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 7, 2024
1 parent a62fc65 commit 49f861f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion frame_averaging_pytorch/frame_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def frame_average(out):

# use tree map to handle multiple outputs

out = tree_map(lambda t: rearrange(t, '(b f) ... -> b f ...', f = num_frames), out)
out = tree_map(lambda t: rearrange(t, '(b f) ... -> b f ...', f = num_frames) if torch.is_tensor(t) else t, out)
out = tree_map(lambda t: frame_average(t) if torch.is_tensor(t) else t, out)

return out
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frame-averaging-pytorch"
version = "0.1.0"
version = "0.1.1"
description = "Frame Averaging"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down
4 changes: 2 additions & 2 deletions tests/test_frame_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def __init__(self):
def forward(self, x, mask):
x = x.masked_fill(~mask[..., None], 0.)
hidden = self.net(x)
return self.to_out1(hidden), self.to_out2(hidden)
return 0., self.to_out1(hidden), self.to_out2(hidden)

net = Network()
net = FrameAverage(net)

points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

out1, out2 = net(points, mask, frame_average_mask = mask)
_, out1, out2 = net(points, mask, frame_average_mask = mask)

assert out1.shape == out2.shape == points.shape

0 comments on commit 49f861f

Please sign in to comment.