diff --git a/frame_averaging_pytorch/frame_averaging.py b/frame_averaging_pytorch/frame_averaging.py index 856a618..474b4d6 100644 --- a/frame_averaging_pytorch/frame_averaging.py +++ b/frame_averaging_pytorch/frame_averaging.py @@ -150,6 +150,14 @@ def forward( out = self.net(inputs, *args, **kwargs) + # handle if output is a tuple - just follow convention that first output is the one to be frame averaged + # (todo) - handle multiple outputs that need frame averaging + + is_multiple_output = isinstance(out, tuple) + + if is_multiple_output: + out, *rest = out + # split frames from batch out = rearrange(out, '(b f) ... -> b f ...', f = num_frames) @@ -166,4 +174,8 @@ def forward( else: out = rearrange(out, 'b 1 ... -> b ...') - return out + if not is_multiple_output: + return out + + return (out, *rest) + diff --git a/pyproject.toml b/pyproject.toml index 146b387..bda9195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "frame-averaging-pytorch" -version = "0.0.11" +version = "0.0.14" description = "Frame Averaging" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }