
Commit ea7f9df
Merge pull request #3813 from google:nnx-improve-lifted-example
PiperOrigin-RevId: 621208234
Flax Authors committed Apr 2, 2024
2 parents 2b257b4 + 6d638e6 commit ea7f9df
Showing 2 changed files with 15 additions and 15 deletions.
flax/experimental/nnx/examples/toy_examples/01_functional_api.py (10 changes: 6 additions & 4 deletions)
@@ -67,10 +67,11 @@ def train_step(params, counts, batch):
   x, y = batch
 
   def loss_fn(params):
-    y_pred, (_, updates) = static.apply(params, counts)(x)
-    counts_ = updates.extract(Count)
+    model = static.merge(params, counts)
+    y_pred = model(x)
+    new_counts = model.extract(Count)
     loss = jnp.mean((y - y_pred) ** 2)
-    return loss, counts_
+    return loss, new_counts
 
   grad, counts = jax.grad(loss_fn, has_aux=True)(params)
   # |-------- sgd ---------|
@@ -82,7 +83,8 @@ def loss_fn(params):
 @jax.jit
 def test_step(params: nnx.State, counts: nnx.State, batch):
   x, y = batch
-  y_pred, _ = static.apply(params, counts)(x)
+  model = static.merge(params, counts)
+  y_pred = model(x)
   loss = jnp.mean((y - y_pred) ** 2)
   return {'loss': loss}
 

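For reference, here is the patched train/test pair assembled in one place rather than split across the two hunks above. This is a sketch, not a verbatim excerpt: the @jax.jit decorator on train_step, the learning rate, and the sgd line beneath the |-------- sgd ---------| marker are reconstructed assumptions, and static, Count, and the data pipeline are defined earlier in 01_functional_api.py.

import jax
import jax.numpy as jnp
from flax.experimental import nnx

# `static` (the module's graphdef) and `Count` (a custom Variable type)
# come from earlier in the example file and are not shown in the hunks.

@jax.jit  # assumed: the decorator sits just above the first hunk
def train_step(params, counts, batch):
  x, y = batch

  def loss_fn(params):
    # Rebuild a live module from the static graphdef plus its state...
    model = static.merge(params, counts)
    y_pred = model(x)
    # ...and pull the updated Count state back out after the call.
    new_counts = model.extract(Count)
    loss = jnp.mean((y - y_pred) ** 2)
    return loss, new_counts

  grad, counts = jax.grad(loss_fn, has_aux=True)(params)
  # sgd update (reconstructed; the exact line is outside the hunk)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
  return params, counts

@jax.jit
def test_step(params: nnx.State, counts: nnx.State, batch):
  x, y = batch
  model = static.merge(params, counts)
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}

The substance of the change: the old static.apply(params, counts)(x) call returned the output together with an updates bundle that had to be unpacked, while the new merge/extract round trip lets the module be called like a plain object inside the jitted function.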
@@ -63,43 +63,41 @@ def __call__(self, x):
 optimizer = nnx.Optimizer(model, tx)
 
 @nnx.jit
-def train_step(optimizer: nnx.Optimizer, batch):
+def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
   x, y = batch
 
   def loss_fn(model: MLP):
     y_pred = model(x)
     return jnp.mean((y - y_pred) ** 2)
 
-  #                                     |--default--|
-  grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model)
+  #                                     |--default--|
+  grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
   # sgd update
   optimizer.update(grads=grads)
 
   # no return!!!
 
 
 @nnx.jit
-def test_step(optimizer: nnx.Optimizer, batch):
+def test_step(model: MLP, batch):
   x, y = batch
-  y_pred = optimizer.model(x)
+  y_pred = model(x)
   loss = jnp.mean((y - y_pred) ** 2)
   return {'loss': loss}
 
 
 total_steps = 10_000
 for step, batch in enumerate(dataset(32)):
-  train_step(optimizer, batch)
+  train_step(model, optimizer, batch)
 
   if step % 1000 == 0:
-    logs = test_step(optimizer, (X, Y))
+    logs = test_step(model, (X, Y))
     print(f"step: {step}, loss: {logs['loss']}")
 
   if step >= total_steps - 1:
     break
 
-print('times called:', optimizer.model.count.value)
+print('times called:', model.count.value)
 
-y_pred = optimizer.model(X)
+y_pred = model(X)
 
 plt.scatter(X, Y, color='blue')
 plt.plot(X, y_pred, color='black')
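Assembled, the patched lifted-transforms example reads as below. MLP, tx, dataset, X, and Y are defined earlier in the file and are assumed here (the MLP constructor arguments in particular are hypothetical). The point of the change is that the live module is now passed explicitly to the nnx.jit-lifted functions instead of being reached through optimizer.model, and both model and optimizer state are mutated in place.

import jax.numpy as jnp
from flax.experimental import nnx

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))  # hypothetical args
optimizer = nnx.Optimizer(model, tx)  # tx: gradient transformation defined earlier

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  # Differentiate with respect to the model's nnx.Param state (the wrt default).
  grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
  optimizer.update(grads=grads)  # in-place sgd update: no return needed

@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}

total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)  # model threaded explicitly
  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")
  if step >= total_steps - 1:
    break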
