
model.eval() #2083

Closed
JaejinCho opened this issue Oct 15, 2022 · 4 comments · Fixed by #2400
Labels: docathon-h1-2023 (A label for the docathon in H1 2023), easy, intro

Comments

JaejinCho commented Oct 15, 2022

In the tutorial below, isn't it better to have model.eval() for more general cases in addition to the context manager torch.no_grad(), or at least to have a brief explanation of the difference between the two? I think torch.no_grad() does not take care of dropout or batch norm. Although omitting model.eval() is fine in this tutorial, it generally seems necessary for evaluation.

cc @suraj813 @jerryzh168 @z-a-f @vkuzo

svekars added the arch-optimization (quantization, sparsity, ns) label Oct 17, 2022
z-a-f (Contributor) commented Jan 10, 2023

@svekars I believe the arch-optimization label is for quantization/sparsity related topics, so it might not be applicable to this issue.

svekars added the intro label and removed the arch-optimization (quantization, sparsity, ns) label Mar 1, 2023
svekars added the easy and docathon-h1-2023 (A label for the docathon in H1 2023) labels May 31, 2023
zabboud (Contributor) commented May 31, 2023

/assigntome

zabboud (Contributor) commented May 31, 2023

@JaejinCho As you mentioned, model.eval() is important to ensure that dropout and batch normalization layers are set to evaluation mode. As per the documentation:

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

The role of torch.no_grad(), on the other hand, is to disable gradient calculation at inference time, when you are sure you will not call Tensor.backward(). torch.no_grad() also serves a second purpose: it reduces memory consumption for computations involving tensors that have requires_grad=True, because in effect every computation inside the context is performed as if requires_grad=False. For more details, see the documentation.

However, I would tend to agree that, for beginners learning best practices, it is better to use both model.eval() and with torch.no_grad(): the former ensures that batch norm and dropout layers are correctly set to eval mode, and the latter avoids unnecessary gradient computation and the associated memory consumption.
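For illustration, here is a minimal sketch of the difference (not part of the tutorial; the toy dropout model and tensor names below are made up for this example):

import torch
import torch.nn as nn

# Toy model with a dropout layer, used only to illustrate the two modes
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5))
x = torch.randn(1, 10)

# torch.no_grad() alone: gradients are not tracked, but dropout is still active,
# so two forward passes on the same input will generally differ
with torch.no_grad():
    out1, out2 = model(x), model(x)
print(torch.allclose(out1, out2))  # usually False

# model.eval() switches dropout (and batch norm) to evaluation behavior,
# making inference deterministic; torch.no_grad() still skips gradient bookkeeping
model.eval()
with torch.no_grad():
    out1, out2 = model(x), model(x)
print(torch.allclose(out1, out2))  # True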

Do you think what is needed is updating the example with comments that clarify the use case of these two modes, and updating the example to include model.eval()?

zabboud (Contributor) commented May 31, 2023

@JaejinCho Would the following additions to the tutorial be sufficient?

Current test loop:

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Addition of comments and model.eval():

def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

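For completeness, the updated test_loop would be called exactly like the current one; a rough sketch, assuming the train_loop, dataloaders, model, loss_fn, and optimizer already defined earlier in the tutorial:

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")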
zabboud added a commit to zabboud/tutorials that referenced this issue Jun 1, 2023
zabboud mentioned this issue Jun 1, 2023
svekars pushed a commit that referenced this issue Jun 1, 2023
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
svekars pushed a commit that referenced this issue Jun 2, 2023
…ansforms.Normalize (#2405)

* Fixes #2083 - explain model.eval, torch.no_grad
* set norm to mean & std of CIFAR10 (#1818)

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>