Skip to content

Commit

Permalink
use factories in context example
Browse files Browse the repository at this point in the history
Using the same instance multiple times is a bad idea
because PyTorch memorizes things internally. Among
other things this breaks Chain's `__repr__`.
  • Loading branch information
catwell committed Apr 5, 2024
1 parent bbb46e3 commit e033306
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions docs/concepts/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,22 @@ Another use of the context is simplifying complex models, in particular those wi
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:

```py
square = fl.Lambda(lambda x: x ** 2)
sqrt = fl.Lambda(lambda x: x ** 0.5)
square = lambda: fl.Lambda(lambda x: x ** 2)
sqrt = lambda: fl.Lambda(lambda x: x ** 0.5)

m1 = fl.Chain(
fl.Residual(
square,
square(),
fl.Residual(
square,
square(),
fl.Residual(
square,
square(),
),
sqrt,
sqrt(),
),
sqrt,
sqrt(),
),
sqrt,
sqrt(),
)
```

Expand All @@ -86,7 +86,12 @@ class MyModel(fl.Chain):
def init_context(self) -> Contexts:
return {"mymodel": {"residuals": []}}

push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x))
def push_residual():
return fl.SetContext(
"mymodel",
"residuals",
callback=lambda l, x: l.append(x),
)

class ApplyResidual(fl.Sum):
def __init__(self):
Expand All @@ -95,8 +100,8 @@ class ApplyResidual(fl.Sum):
fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
)

squares = fl.Chain(x for _ in range(3) for x in (push_residual, square))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt))
squares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))
m2 = MyModel(squares, sqrts)
```

Expand Down

0 comments on commit e033306

Please sign in to comment.