Add test without gradient accumulation #841
@@ -21,6 +21,96 @@

def test_mnist_training():
    torch.manual_seed(0)

    # Model and data type.
    # For bfloat16 (e.g. to run test_forge_vs_torch in bfloat16), two changes are needed:
    # - In forge/forge/op/eval/forge/eltwise_unary.py:418, replace the threshold line with
    #   threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
    #   so the relu threshold becomes a bfloat16 tensor.
    #   (A standalone PyTorch check of this construction appears after this test.)
    # - In forge/forge/compile.py::compile_main, force bfloat16 by adding
    #   compiler_cfg.default_df_override = DataFormat.Float16_b
    dtype = torch.float32

    # Set training hyperparameters.
    num_epochs = 3
    batch_size = 64
    learning_rate = 0.001

    # Limit the number of batches to run, for a quicker test.
    limit_num_batches = 1000

    # Load the dataset.
    test_loader, train_loader = load_dataset(batch_size, dtype=dtype)

    # Define the model (compiled for the TT device below).
    framework_model = MNISTLinear(
        bias=False, dtype=dtype
    )  # bias=False because batch_size=1 with bias=True is not supported

    # Create a torch loss and leave it on the CPU.
    loss_fn = torch.nn.CrossEntropyLoss()

    # Define the optimizer, then compile the model for training on the TT device.
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
    tt_model = forge.compile(
        framework_model,
        sample_inputs=[torch.rand(batch_size, 784, dtype=dtype)],
        optimizer=framework_optimizer,
        training=True,
    )

    logger.info("Starting training loop... (logger will be disabled)")
    logger.disable("")
    for epoch_idx in range(num_epochs):

        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # Reset gradients (every batch).
            framework_optimizer.zero_grad()

            # Create the target tensor and leave it on the CPU.
            target = nn.functional.one_hot(target, num_classes=10).to(dtype)

            # Forward pass (prediction) on device, plus a golden run on the CPU.
            pred = tt_model(data)[0]
            golden_pred = framework_model(data)
            assert golden_pred.dtype == dtype
            assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=0.95))
Review thread on this golden comparison:

- After this PR, there is a unified method for comparing the golden output with the output of the compiled model (link to verify). It calls this compare_with_golden but also performs dataformat and some other sanity checks. In that PR you can find examples of its use in tests; it makes things more concise.
- I need the predictions from the models, since I am calculating the loss after this.
- Verify is under development at the moment (by @vkovinicTT), so I wouldn't change it. Let's leave it as is.
- Let's not diverge and extend the scope of the planned changes for the verify refactor yet. There are not that many training tests, so it's not a big problem at this moment, IMO.
- I didn't plan on extending it, but I can do that if needed.
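For readers following the thread, a minimal sketch of what such a unified check could look like. The name verify_against_golden and its signature are hypothetical; compare_with_golden and VerifyConfig are the utilities already used in this test, and the real verify helper under development may differ:

def verify_against_golden(golden_pred, pred, expected_dtype, pcc=0.95):
    # Dataformat sanity check before the value comparison.
    assert golden_pred.dtype == expected_dtype, (
        f"expected {expected_dtype}, got {golden_pred.dtype}"
    )
    # Shape sanity check.
    assert golden_pred.shape == pred.shape, "shape mismatch against golden"
    # PCC-based value comparison via the existing helper.
    assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=pcc))

With such a helper, the two asserts above would collapse into a single verify_against_golden(golden_pred, pred, dtype) call.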
            # Compute the loss on the CPU.
            loss = loss_fn(pred, target)
            total_loss += loss.item()

            golden_loss = loss_fn(golden_pred, target)
            assert torch.allclose(loss, golden_loss, rtol=1e-1)  # 10% tolerance

            # Run the backward pass: loss.backward() on the CPU-side loss,
            # then the compiled model's backward pass on the device.
            loss.backward()

            tt_model.backward()

            # Adjust weights (on the CPU).
            framework_optimizer.step()

            if batch_idx >= limit_num_batches:
                break

        print(f"epoch: {epoch_idx} loss: {total_loss}")

    test_loss = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        pred = tt_model(data)[0]
        target = nn.functional.one_hot(target, num_classes=10).to(dtype)

        test_loss += loss_fn(pred, target)

        if batch_idx == limit_num_batches:
            break

    print(f"Test (total) loss: {test_loss}")

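Aside: the bfloat16 threshold construction mentioned in the dtype comment at the top of the first test can be sanity-checked standalone in plain PyTorch. The shape and threshold values here are arbitrary, and ac.tensor is forge-internal, so it is omitted:

import torch

# Arbitrary shape and threshold, for illustration only.
shape = (2, 3)
threshold = 0.5
# Same construction as in the eltwise_unary.py comment, minus ac.tensor.
threshold_tensor = torch.zeros(shape, dtype=torch.bfloat16) + threshold
assert threshold_tensor.dtype == torch.bfloat16  # stays bfloat16 after the add
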
@pytest.mark.push
def test_mnist_training_with_grad_accumulation():
    torch.manual_seed(0)

    # Config
    num_epochs = 3
    batch_size = 1
Review thread on limit_num_batches:

- How about removing this "parameter" and just running with a higher batch size, e.g. 2048?
- MNIST has 60k inputs, so with a large batch size we would never reach limit_num_batches. I agree that limit_num_batches can be removed.
- Done here: ae7bee0
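For context on the two tests' names, here is a minimal, self-contained sketch of the gradient-accumulation pattern in plain PyTorch. The helper name and the accumulation_steps value are illustrative only; the actual test drives the compiled tt_model exactly as in the training loop above:

import torch

def train_with_grad_accumulation(model, loss_fn, optimizer, train_loader,
                                 accumulation_steps=64, dtype=torch.float32):
    # Illustrative helper: accumulate gradients over several small batches
    # before applying a single optimizer update.
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(train_loader):
        target = torch.nn.functional.one_hot(target, num_classes=10).to(dtype)
        pred = model(data)
        # Scale the loss so the accumulated gradient matches one large batch.
        loss = loss_fn(pred, target) / accumulation_steps
        loss.backward()  # gradients accumulate across iterations
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()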