Merge pull request #5 from cgarciae/fix-optimizer

0.4.0

cgarciae authored Sep 13, 2021
2 parents 5669195 + 3598b64 commit d58d0d3
Showing 25 changed files with 2,016 additions and 321 deletions.
6 changes: 5 additions & 1 deletion .gitignore

```
cython_debug/

# data folder
data

# test file to try out stuff
/test.py

# rsync
.git/
```
218 changes: 206 additions & 12 deletions README.md

Valid type annotations include:
* Subtypes of `tx.TreePart` e.g. `tx.Parameter`, `tx.BatchStat`, etc.
* Subtypes of `tx.Module` e.g. `tx.Linear`, custom Module types, etc.
* Generic subtypes from the `typing` module containing `TreeObject` subtypes, e.g. `List[tx.Linear]` or `Dict[str, tx.Conv]`.
* Generic types cannot contain `tx.TreePart` subtypes, e.g. `Tuple[int, tx.Parameter[float]]` is not allowed.

Fields whose annotations do not conform to the above rules will either be counted as static or raise an error when the annotation is invalid.

```python
class MLP(tx.Module):
    ...

class CNN(tx.Module):
    def __init__(self):
        ...
        self.dropout2 = tx.Dropout(0.5)
```

Note that this won't work if a field holds e.g. a list or dict of Modules; for those you have to use proper type annotations, as in the sketch below.
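
A minimal sketch of a Module holding a list of submodules (the `Ensemble` name and the layer sizes are made up for illustration):

```python
import typing as tp
import treex as tx

class Ensemble(tx.Module):
    # a plain assignment would not be enough here: the List[tx.Linear]
    # annotation is what makes this field part of the Pytree
    models: tp.List[tx.Linear]

    def __init__(self, n: int):
        self.models = [tx.Linear(3, 4) for _ in range(n)]
```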

### Predefined Layers
The `treex.nn` module contains a number of pre-defined layers which can be used to create more complex models:
```
BatchNorm
Conv
Dropout
FlaxModule
Linear
MLP
Lambda
Sequential
sequence
```
Check them out in the [API Reference](https://cgarciae.github.io/treex/) section of the documentation. All modules and functions from `treex.nn` are also available on the base `treex` module, so you can use e.g. `tx.Conv` and `tx.BatchNorm`.
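
For instance, a quick sketch composing a few of these layers (the exact constructor arguments here are assumptions, see the API Reference for the real signatures):

```python
import jax
import jax.numpy as jnp
import treex as tx

model = tx.Sequential(
    tx.Linear(784, 256),
    tx.Lambda(jax.nn.relu),  # assuming Lambda wraps a plain function
    tx.Dropout(0.2),
    tx.Linear(256, 10),
).init(42)

y = model(jnp.ones((4, 784)))  # shape (4, 10)
```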

### Pytrees
Since Modules are pytrees they can be arguments to JAX functions such as `jit`, `grad`, `vmap`, etc., and to the `jax.tree_*` function family.
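
For example, a minimal sketch (assuming a `tx.Linear(3, 4)` module as in the examples above):

```python
import jax
import jax.numpy as jnp
import treex as tx

linear = tx.Linear(3, 4).init(42)

@jax.jit
def forward(model: tx.Linear, x: jnp.ndarray) -> jnp.ndarray:
    # the Module is traced through jit like any other pytree
    return model(x)

y = forward(linear, jnp.ones((2, 3)))
doubled = jax.tree_map(lambda w: 2 * w, linear)  # tree_* functions work too
```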
```python
...
module  # MyModule(a=array([0.3]), b=array([3.2]))
```
As shown here, field `Initializer`s are always called before `module_init`.

### Basic API
#### Filter
The `filter` method allows you to select a subtree by filtering based on a `TreePart` type: all leaves whose type annotation is a subclass of that type are kept, the rest are set to a special `Nothing` value.
```python
class MyModule(tx.Module):
a: tx.Parameter[np.ndarray] = np.array(1)
    b: tx.BatchStat[np.ndarray] = np.array(2)

module = MyModule(...)
module.filter(tx.Parameter) # MyModule(a=array([1]), b=Nothing)
module.filter(tx.BatchStat) # MyModule(a=Nothing, b=array([2]))
```
Since `Nothing`, much like `None`, is an empty Pytree, it gets ignored by tree operations. This effectively allows you to operate on just a subset of the fields:

```python
jax.tree_leaves(module.filter(tx.Parameter)) # [array([1])]
jax.tree_leaves(module.filter(tx.BatchStat)) # [array([2])]
jax.tree_map(lambda x: -x, module.filter(tx.Parameter)) # MyModule(a=array([-1]), b=Nothing)
jax.tree_map(lambda x: -x, module.filter(tx.BatchStat)) # MyModule(a=Nothing, b=array([-2]))
```

##### Functions
If you need to do more complex filtering, you can pass callables with the signature `FieldInfo -> bool` instead of types:

```python
module.filter(
    lambda field: issubclass(field.annotation, tx.State)
    and not issubclass(field.annotation, tx.OptState)
)
# MyModule(a=Nothing, b=array([2]))
```
##### Multiple Filters
The previous can be abbreviated using multiple filters, since **all** filters must pass for a field to be kept:
```python
# all States that are not OptStates
module.filter(
tx.State,
lambda field: not issubclass(field.annotation, tx.OptState)
)
# MyModule(a=Nothing, b=array([2]))
```

#### Update
The `update` method allows you to merge the values of one or more incoming modules with the current module. This is useful for integrating filtered modules back into the main module.

```python
module = MyModule(...) # MyModule(a=array([1]), b=array([2]))
params = module.filter(tx.Parameter) # MyModule(a=array([1]), b=Nothing)
negative = jax.tree_map(lambda x: -x, params) # MyModule(a=array([-1]), b=Nothing)
module = module.update(negative) # MyModule(a=array([-1]), b=array([2]))
```

#### Map
The `map` method provides a convenient way to map a function over the fields of a module:

```python
module = MyModule(...) # MyModule(a=array([1]), b=array([2]))
params = module.filter(tx.Parameter) # MyModule(a=array([1]), b=Nothing)
negative = params.map(lambda x: -x) # MyModule(a=array([-1]), b=Nothing)
module = module.update(negative) # MyModule(a=array([-1]), b=array([2]))
```

The previous pattern is so common that `map` provides a shortcut for applying `filter -> tree_map -> update` in sequence:

```python
module = MyModule(...) # MyModule(a=array([1]), b=array([2]))
module = module.map(lambda x: -x, tx.Parameter) # MyModule(a=array([-1]), b=array([2]))
```

As shown here, `map` accepts the same varargs as `filter` and calls `update` at the end if filters are given.
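
For example, reusing the multiple-filter query from above (a sketch):

```python
# negate all States that are not OptStates, leave everything else untouched
module = module.map(
    lambda x: -x,
    tx.State,
    lambda field: not issubclass(field.annotation, tx.OptState),
)
```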

#### Functional API
All the previous methods are available as functions that can be applied to arbitrary Pytrees. Here is the full list of functions:
```python
# -----------------
# basic API
# -----------------
tx.filter(obj: A, *filters: Filter) -> A
tx.update(module: A, other: A, *rest: A) -> A
tx.map(f: tp.Callable, obj: A, *filters: Filter) -> A

# -----------------
# other functions
# -----------------
# applies a function to TreeObjects instead of leaves, useful to modify static properties
tx.object_apply(f: tp.Callable, obj: A, *rest: A, inplace: bool) -> A
```
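
For example, a sketch of how the functional versions mirror the methods one-to-one:

```python
params = tx.filter(model, tx.Parameter)   # same as model.filter(tx.Parameter)
negative = tx.map(lambda x: -x, params)   # same as params.map(lambda x: -x)
model = tx.update(model, negative)        # same as model.update(negative)
```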

#### Use Cases
##### grad & optimizers
Another common pattern is synchronizing batch statistics across devices by filtering out the `BatchStat`s, `pmean`-ing them, and merging them back:

```python
batch_stats = model.filter(tx.BatchStat)
batch_stats = jax.lax.pmean(batch_stats, axis_name="device")
model = model.update(batch_stats)
```

The previous is roughly equivalent to:

```python
model = model.map(lambda x: jax.lax.pmean(x, axis_name="device"), tx.BatchStat)
```

However, this applies `pmean` to the leaves instead of the whole tree, which may or may not be desirable.

### Optimizer

Optax is an amazing library; however, its optimizers are not pytrees, which means their state and computation are separate and you cannot `jit` them. To solve this, Treex provides a `tx.Optimizer` class that can wrap any Optax optimizer.
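
A sketch of how the wrapper can be used, assuming `tx.Optimizer` exposes Optax-like `init` and `update` methods where `update` returns the new parameters directly (check the API Reference for the exact signatures):

```python
import optax
import treex as tx

optimizer = tx.Optimizer(optax.adam(1e-3))
optimizer = optimizer.init(model.filter(tx.Parameter))

# inside a training step:
params = optimizer.update(grads, params)  # assumed to return the updated params
model = model.update(params)
```

Because `tx.Optimizer` is itself a pytree, it can be passed through `jit` just like a `Module`, as the `train_step` example below does.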
```python
...
model = model.eval()
# make predictions
y_pred = model(X_test)
```

### Freezing Layers
Similar to `.training`, Treex `Module`s have a `.frozen` property that specifies whether the module is frozen or not. This property is used to condition the behavior of modules such as `Dropout` and `BatchNorm`, which will behave deterministically/statically when `frozen=True`. Freezing layers is useful for tasks such as Transfer Learning, where you want to freeze most of the weights of a model and train only a few of them on a new dataset.

To switch between modes, use the `.freeze()` and `.unfreeze()` methods; they return a new Module whose `frozen` state, and that of all of its submodules (recursively), is set to the desired value.

For example, you can leverage the fact that `Sequential` stores its submodules in a `layers: List[Module]` field to freeze the first layers of a model:

```python
class ConvBlock(tx.Module):
...

model = tx.Sequential(
ConvBlock(3, 32),
ConvBlock(32, 64),
ConvBlock(64, 128),
...
)

# train model
...

# freeze some layers
for layer in model.layers[:-1]:
layer.freeze(inplace=True)

# fine-tune the model
...
```

If you are using a pretrained backbone you can just freeze the backbone entirely:

```python
backbone = get_pretrained_model()
backbone = backbone.freeze()

model = tx.Sequential(
backbone,
tx.Linear(backbone.output_features, 10)
).init(42)

...

@jax.jit
def train_step(model, x, y, optimizer):
# only differentiate w.r.t. parameters whose module is not frozen
params = model.filter(
tx.Parameter,
lambda field: not field.module.frozen,
)
(loss, model), grads = loss_fn(params, model, x, y)

...
```

### Parameter Annotations
The role of each field is defined by its annotation. While any type that inherits from `tx.TreePart` is a valid annotation, the default annotations from Treex are organized into the following hierarchy:

```mermaid
graph TD;
    TreePart-->Parameter;
    TreePart-->State;
    State-->Rng;
    State-->ModelState;
    State-->OptState;
    ModelState-->BatchStat;
    ModelState-->Cache;
    TreePart-->Log;
    Log-->Loss;
    Log-->Metric;
    TreePart-->DiffHyperParam;
```

<!-- Uncomment to test during development -->
<!-- ![types](images/types.png) -->
![types](https://raw.githubusercontent.com/cgarciae/treex/master/images/types.png)

This is useful because you can make specific or more general queries using `filter` depending on what you want to achieve, e.g.

```python
rngs = model.filter(tx.Rng)
batch_stats = model.filter(tx.BatchStat)
all_states = model.filter(tx.State)
```

#### Static Analysis
All `TreePart` types included in Treex, like `Parameter` and `State`, currently behave as a `typing.Union` in the eyes of static analyzers. This means they will think the following types resolve to:

```python
...
class MyModule(tx.Module):
    ...
```
Hopefully a better way will be found in the future; for now, this keeps static analyzers happy, as they will think `cache` is an `ndarray` while Treex gets the correct `_Cache` annotation metadata.

#### Explicit Static Fields
All fields that are not marked with a `TreePart` subclass annotation are considered static, with one exception: auto-annotations. All submodules are considered dynamic by default, so if you explicitly want to exclude one from being dynamic you can use the `tx.Static` annotation.

As an example, we will create a `ResetableLinear` class that stores a copy of itself in the `initial_state` field and goes back to the initial state when `reset` is called:

```python
class ResetableLinear(tx.Linear):
    # tp is the typing module (import typing as tp)
    initial_state: tx.Static[tp.Optional[tx.Linear]] = None

def module_init(self, _key):
super().module_init(_key)
self.initial_state = self.copy()

def reset(self):
self.update(self.initial_state, inplace=True)

```

The key here is that `initial_state` is annotated as `Static` so it won't be treated as part of the Pytree; this guarantees that it won't be modified by the optimizer.
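
A hypothetical usage sketch:

```python
model = ResetableLinear(3, 4).init(42)

# ... train for a while ...

model.reset()  # parameters roll back to their values right after init
```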

### Non-hashable Static Fields
If you want a static field that contains a non-hashable value like a numpy or jax array, you can wrap the value in `tx.Hashable`:

```python
class MyModule(tx.Module):
table: tx.Hashable[np.ndarray]
...

def __init__(self, table: np.ndarray):
self.table = tx.Hashable(table)
...

def __call__(self, x: np.ndarray) -> np.ndarray:
table = self.table.value
...
```
The hash from `Hashable` only depends on object identity, not on the actual `value`, so you should treat it as immutable; if you want to update the value, create a new `Hashable` instance:

```python
table = np.ones((10, 10))
module = MyModule(table)

# use module as an argument for a jit-ed function
...

module.table = tx.Hashable(np.zeros((10, 10)))

# jit-ed function will recompile now
...
```
If you are somehow able to mutate `value` directly, JAX won't know about it and `jit` won't recompile.

**Note:** Currently JAX does not complain when a static field holds a numpy array; however, if you mutate such a field and pass its module through `jit` again, you will get a deprecation warning saying this situation will become an error in the future.

### Full Example

```python
...
```