Skip to content

Commit

Permalink
use tree_map
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 4, 2024
1 parent 41fb87b commit 1cbb82f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
20 changes: 8 additions & 12 deletions frame_averaging_pytorch/frame_averaging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from random import randrange
from optree import tree_flatten, tree_unflatten
from optree import tree_map

import torch
from torch.nn import Module
Expand Down Expand Up @@ -167,17 +167,13 @@ def frame_average(out):
# automatically take care of this

if not self.stochastic:
flattened_args_kwargs, tree_spec = tree_flatten([args, kwargs])

mapped_args_kwargs = []

for el in flattened_args_kwargs:
if torch.is_tensor(el):
el = repeat(el, 'b ... -> (b f) ...', f = num_frames)

mapped_args_kwargs.append(el)

args, kwargs = tree_unflatten(tree_spec, mapped_args_kwargs)
args, kwargs = tree_map(
lambda el: (
rearrange(el, 'b ... -> (b f) ...', f = num_frames)
if torch.is_tensor(el)
else el
)
, (args, kwargs))

# main network forward

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frame-averaging-pytorch"
version = "0.0.18"
version = "0.0.19"
description = "Frame Averaging"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 1cbb82f

Please sign in to comment.