Skip to content

Commit

Permalink
network forward can be handled externally
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 3, 2024
1 parent 7de1ea0 commit 434ec24
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions frame_averaging_pytorch/frame_averaging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from random import randrange
from optree import tree_flatten, tree_unflatten

Expand All @@ -19,7 +21,7 @@ def default(v, d):
class FrameAverage(Module):
def __init__(
self,
net: Module,
net: Module | None = None,
dim = 3,
stochastic = False,
invariant_output = False
Expand Down Expand Up @@ -146,7 +148,7 @@ def frame_average(out):

# if one wants to handle the framed inputs externally

if return_framed_inputs_and_averaging_function:
if return_framed_inputs_and_averaging_function or not exists(self.net):
return inputs, frame_average

# merge frames into batch
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frame-averaging-pytorch"
version = "0.0.15"
version = "0.0.16"
description = "Frame Averaging"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 434ec24

Please sign in to comment.