
Commit e678cc8: fix comments
JackCaoG committed Jul 17, 2024 (1 parent: 2eec503)
Showing 1 changed file with 7 additions and 8 deletions: docs/eager.md
@@ -17,7 +17,7 @@
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)

-# model execution, you can also use `xm.mark_step`
+# model execution, same as `xm.mark_step`
torch_xla.sync()
```
The actual model compilation and device execution happen when `torch_xla.sync` is called. This approach has multiple drawbacks.
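For contrast, once eager mode is enabled each operation executes as soon as it is issued, so no explicit sync is needed. A minimal sketch reusing only the APIs shown above:

```python
torch_xla.experimental.eager_mode(True)

# With eager mode on, this runs immediately on the device; no pending
# graph is left behind, so no `torch_xla.sync()` call is required.
res = model(input)
```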
@@ -55,7 +55,7 @@
Note that
The implementation of `torch_xla.experimental.compile` is actually pretty straightforward: it disables eager mode when entering the target function and starts tracing. It calls `torch_xla.sync()` when the target function returns, then re-enables eager mode. You can expect the same performance from the `eager` + `compile` API as from the existing `mark_step/sync` approach.
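As a rough sketch of that behavior (`compile_like` is a hypothetical stand-in, not the actual implementation, which lives in the torch_xla source tree):

```python
import torch_xla

def compile_like(fn):
    """Illustrative stand-in for `torch_xla.experimental.compile`."""
    def wrapper(*args, **kwargs):
        # Entering the target function: stop eager execution and start tracing.
        torch_xla.experimental.eager_mode(False)
        try:
            result = fn(*args, **kwargs)
        finally:
            # Target function returned: compile and run the traced graph,
            # then re-enable eager mode.
            torch_xla.sync()
            torch_xla.experimental.eager_mode(True)
        return result
    return wrapper
```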


-### Infernce
+### Inference
```python
torch_xla.experimental.eager_mode(True)

# ...
```
@@ -67,16 +67,15 @@
It is recommended to use `torch.compile` instead of `torch_xla.experimental.compile` for inference to reduce the tracing overhead.
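A minimal sketch of that recommendation, assuming the `model` and `input` from the first snippet are already on the XLA device (`"openxla"` is the `torch.compile` backend name provided by torch_xla):

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

# Compile the forward pass once; subsequent calls reuse the compiled graph
# instead of re-tracing the model on every request.
compiled_model = torch.compile(model, backend="openxla")

with torch.no_grad():
    res = compiled_model(input)
```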
```python
torch_xla.experimental.eager_mode(True)

-def step_fn(self, data, target):
-    self.optimizer.zero_grad()
-    logits = self.model(data)
-    loss = self.loss_fn(
-        logits.view(-1, self.config.vocab_size), target.view(-1))
+def step_fn(model, data, target, loss_fn, optimizer):
+    optimizer.zero_grad()
+    logits = model(data)
+    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

-self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
+step_fn = torch_xla.experimental.compile(step_fn)
```
In training we ask users to refactor the `step_fn` out because it is usually better to compile the model's forward pass, backward pass, and optimizer step together. The long-term goal is to also use `torch.compile` for training, but right now we recommend `torch_xla.experimental.compile` (for performance reasons).
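For example, a hypothetical training loop could then call the compiled step function directly (`train_loader`, `device`, `model`, `loss_fn`, and `optimizer` are placeholder names, not part of the original snippet):

```python
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    loss = step_fn(model, data, target, loss_fn, optimizer)
    # `compile` calls `torch_xla.sync()` internally on return, so the loss
    # has already been computed on the device by this point.
    print(loss.item())
```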
