Add test without gradient accumulation (#841)
* Add test without gradient accumulation

* Remove unused import and comments

* Fix formatting

* Switch to new forge compile API

* Add bf16 instructions

* Remove num of batches limit and increase batch size
pglusacTT authored Dec 11, 2024
1 parent a09a541 commit 10b90b4
Showing 1 changed file with 81 additions and 0 deletions.
forge/test/mlir/mnist/training/test_training.py: 81 additions & 0 deletions
@@ -21,6 +21,87 @@
def test_mnist_training():
    torch.manual_seed(0)

    # Model and data type.
    # To run this test in bfloat16 (see test_forge_vs_torch), two changes are needed:
    # 1. In forge/forge/op/eval/forge/eltwise_unary.py:418, replace the line with:
    #    threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
    #    which makes the relu threshold a bfloat16 tensor.
    # 2. In forge/forge/compile.py::compile_main, force bfloat16 with:
    #    compiler_cfg.default_df_override = DataFormat.Float16_b
    dtype = torch.float32
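    # A minimal sketch of the two bfloat16 changes described above, shown side by side
    # (the compiler_cfg line belongs in forge/forge/compile.py::compile_main, not in this test):
    #     dtype = torch.bfloat16
    #     compiler_cfg.default_df_override = DataFormat.Float16_b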

    # Set training hyperparameters
    num_epochs = 3
    batch_size = 2048
    learning_rate = 0.001

    # Load dataset
    test_loader, train_loader = load_dataset(batch_size, dtype=dtype)

    # Define model and instruct it to compile and run on TT device
    framework_model = MNISTLinear(
        bias=False, dtype=dtype
    )  # bias=False because batch_size=1 with bias=True is not supported

    # Create a torch loss and leave on CPU
    loss_fn = torch.nn.CrossEntropyLoss()

    # Define optimizer and instruct it to compile and run on TT device
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
    tt_model = forge.compile(
        framework_model,
        sample_inputs=[torch.rand(batch_size, 784, dtype=dtype)],
        optimizer=framework_optimizer,
        training=True,
    )
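    # Note: with training=True, forge.compile also prepares the compiled model for the
    # backward pass (used via tt_model.backward() in the loop below). Calling tt_model(...)
    # returns a list of output tensors, hence the [0] indexing on the predictions.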

    logger.info("Starting training loop... (logger will be disabled)")
    logger.disable("")
    for epoch_idx in range(num_epochs):

        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # Reset gradients (every batch)
            framework_optimizer.zero_grad()

            # Create target tensor and leave on CPU
            target = nn.functional.one_hot(target, num_classes=10).to(dtype)

            # Forward pass (prediction) on device
            pred = tt_model(data)[0]
            golden_pred = framework_model(data)
            assert golden_pred.dtype == dtype
            assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=0.95))
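            # The pcc value above is the Pearson correlation coefficient threshold that the
            # device output must reach against the CPU golden output.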

            # Compute loss on CPU
            loss = loss_fn(pred, target)
            total_loss += loss.item()

            golden_loss = loss_fn(golden_pred, target)
            assert torch.allclose(loss, golden_loss, rtol=1e-1)  # 10% tolerance

            # Run backward pass on device
            loss.backward()

            tt_model.backward()
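            # The backward() call above runs the device-side backward pass of the compiled
            # graph, consuming the loss gradient computed on CPU by loss.backward().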

            # Adjust weights (on CPU)
            framework_optimizer.step()

        print(f"epoch: {epoch_idx} loss: {total_loss}")

    test_loss = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        pred = tt_model(data)[0]
        target = nn.functional.one_hot(target, num_classes=10).to(dtype)

        test_loss += loss_fn(pred, target)

    print(f"Test (total) loss: {test_loss}")


@pytest.mark.push
def test_mnist_training_with_grad_accumulation():
    torch.manual_seed(0)

    # Config
    num_epochs = 3
    batch_size = 1
