diff --git a/forge/test/mlir/mnist/training/test_training.py b/forge/test/mlir/mnist/training/test_training.py
index 17e635892..d46c5e472 100644
--- a/forge/test/mlir/mnist/training/test_training.py
+++ b/forge/test/mlir/mnist/training/test_training.py
@@ -21,6 +21,87 @@ def test_mnist_training():
     torch.manual_seed(0)
 
+    # Model and data type.
+    # For bfloat16, the following line should be added to the test_forge_vs_torch function:
+    # In file forge/forge/op/eval/forge/eltwise_unary.py:418 should be replaced with: threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
+    # That sets relu threshold to bfloat16 tensor.
+    # And in file forge/forge/compile.py::compile_main forced bfloat 16 should be added compiler_cfg.default_df_override = DataFormat.Float16_b
+    dtype = torch.float32
+
+    # Set training hyperparameters
+    num_epochs = 3
+    batch_size = 2048
+    learning_rate = 0.001
+
+    # Load dataset
+    test_loader, train_loader = load_dataset(batch_size, dtype=dtype)
+
+    # Define model and instruct it to compile and run on TT device
+    framework_model = MNISTLinear(
+        bias=False, dtype=dtype
+    )  # bias=False because batch_size=1 with bias=True is not supported
+
+    # Create a torch loss and leave on CPU
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    # Define optimizer and instruct it to compile and run on TT device
+    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
+    tt_model = forge.compile(
+        framework_model,
+        sample_inputs=[torch.rand(batch_size, 784, dtype=dtype)],
+        optimizer=framework_optimizer,
+        training=True,
+    )
+
+    logger.info("Starting training loop... (logger will be disabled)")
+    logger.disable("")
+    for epoch_idx in range(num_epochs):
+
+        total_loss = 0
+        for batch_idx, (data, target) in enumerate(train_loader):
+            # Reset gradients (every batch)
+            framework_optimizer.zero_grad()
+
+            # Create target tensor and leave on CPU
+            target = nn.functional.one_hot(target, num_classes=10).to(dtype)
+
+            # Forward pass (prediction) on device
+            pred = tt_model(data)[0]
+            golden_pred = framework_model(data)
+            assert golden_pred.dtype == dtype
+            assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=0.95))
+
+            # Compute loss on CPU
+            loss = loss_fn(pred, target)
+            total_loss += loss.item()
+
+            golden_loss = loss_fn(golden_pred, target)
+            assert torch.allclose(loss, golden_loss, rtol=1e-1)  # 10% tolerance
+
+            # Run backward pass on device
+            loss.backward()
+
+            tt_model.backward()
+
+            # Adjust weights (on CPU)
+            framework_optimizer.step()
+
+        print(f"epoch: {epoch_idx} loss: {total_loss}")
+
+    test_loss = 0
+    for batch_idx, (data, target) in enumerate(test_loader):
+        pred = tt_model(data)[0]
+        target = nn.functional.one_hot(target, num_classes=10).to(dtype)
+
+        test_loss += loss_fn(pred, target)
+
+    print(f"Test (total) loss: {test_loss}")
+
+
+@pytest.mark.push
+def test_mnist_training_with_grad_accumulation():
+    torch.manual_seed(0)
+
     # Config
     num_epochs = 3
     batch_size = 1
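
For reference, the device-training pattern exercised by the new test can be distilled as the sketch below. It is not part of the patch: it assumes MNISTLinear and load_dataset are the helpers already imported by this test file, uses only the forge calls as they appear above (forge.compile with optimizer/training, tt_model(...), tt_model.backward()), and omits the golden-model comparison, logging, and the evaluation loop.

# Minimal sketch of the compiled training step, assuming the helpers used by
# the test above (MNISTLinear, load_dataset) are importable from its module.
import torch
from torch import nn
import forge

dtype = torch.float32
batch_size = 2048

test_loader, train_loader = load_dataset(batch_size, dtype=dtype)  # helper from the test's module (assumed)
framework_model = MNISTLinear(bias=False, dtype=dtype)             # helper from the test's module (assumed)
loss_fn = torch.nn.CrossEntropyLoss()                              # loss stays on CPU
optimizer = torch.optim.SGD(framework_model.parameters(), lr=0.001)

# Compile once for training, passing the torch optimizer as in the test above.
tt_model = forge.compile(
    framework_model,
    sample_inputs=[torch.rand(batch_size, 784, dtype=dtype)],
    optimizer=optimizer,
    training=True,
)

for data, target in train_loader:
    optimizer.zero_grad()
    target = nn.functional.one_hot(target, num_classes=10).to(dtype)
    pred = tt_model(data)[0]      # forward pass on device
    loss = loss_fn(pred, target)  # loss computed on CPU
    loss.backward()               # backward through the CPU loss
    tt_model.backward()           # backward through the compiled device graph
    optimizer.step()              # weight update on CPU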