Add test without gradient accumulation #841

Merged (8 commits) on Dec 11, 2024
forge/test/mlir/mnist/training/test_training.py (90 additions, 0 deletions)
@@ -21,6 +21,96 @@
def test_mnist_training():
    torch.manual_seed(0)

    # Model and data type.
    # For bfloat16, the following changes should be made (in the test_forge_vs_torch function
    # and in the compiler):
    #   - In forge/forge/op/eval/forge/eltwise_unary.py:418, the line should be replaced with
    #       threshold_tensor = ac.tensor(torch.zeros(shape, dtype=torch.bfloat16) + threshold)
    #     which sets the relu threshold to a bfloat16 tensor.
    #   - In forge/forge/compile.py::compile_main, forced bfloat16 should be added:
    #       compiler_cfg.default_df_override = DataFormat.Float16_b
    dtype = torch.float32

    # Set training hyperparameters
    num_epochs = 3
    batch_size = 64
    learning_rate = 0.001

    # Limit number of batches to run - quicker test
    limit_num_batches = 1000
Contributor:

How about removing this "parameter" and just running this with a higher batch size, e.g. 2048?

Contributor:

MNIST has 60k training inputs, so with a large batch size we will never reach limit_num_batches. I agree that limit_num_batches can be removed.

Contributor (Author):

Done here: ae7bee0

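As a quick sanity check of the batch-count reasoning in the thread above, here is an illustrative sketch. It assumes the standard 60k-sample MNIST training split and counts only full batches; the names train_samples and full_batches are introduced here and are not part of the test.

# Illustrative batch-count check (assumes the standard 60k-sample MNIST training split).
train_samples = 60_000
for batch_size in (64, 2048):
    full_batches = train_samples // batch_size  # ignores the final partial batch
    print(f"batch_size={batch_size}: {full_batches} full batches per epoch")
# batch_size=64: 937 full batches per epoch
# batch_size=2048: 29 full batches per epoch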

    # Load dataset
    test_loader, train_loader = load_dataset(batch_size, dtype=dtype)

    # Define model and instruct it to compile and run on TT device
    framework_model = MNISTLinear(
        bias=False, dtype=dtype
    )  # bias=False because batch_size=1 with bias=True is not supported

    # Create a torch loss and leave on CPU
    loss_fn = torch.nn.CrossEntropyLoss()

    # Define optimizer and instruct it to compile and run on TT device
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
    tt_model = forge.compile(
        framework_model,
        sample_inputs=[torch.rand(batch_size, 784, dtype=dtype)],
        optimizer=framework_optimizer,
        training=True,
    )

    logger.info("Starting training loop... (logger will be disabled)")
    logger.disable("")
    for epoch_idx in range(num_epochs):

        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # Reset gradients (every batch)
            framework_optimizer.zero_grad()

            # Create target tensor and leave on CPU
            target = nn.functional.one_hot(target, num_classes=10).to(dtype)

            # Forward pass (prediction) on device
            pred = tt_model(data)[0]
            golden_pred = framework_model(data)
            assert golden_pred.dtype == dtype
            assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=0.95))
Contributor:

After this PR, there is a unified method for comparing the golden output with the output of the compiled model: link to verify. It calls compare_with_golden but also performs dataformat and some other sanity checks.

In the PR above, you can find examples of its use in tests. It makes things more concise.

Contributor (Author):

I need the predictions from the models since I am calculating the loss after this. Should I change the verify method to make it return the predictions?

Contributor:

verify is under development at the moment (by @vkovinicTT), so I wouldn't change it. Let's leave compare_with_golden then.
@vkovinicTT, did you plan on adding a return for the predictions any time soon?

Contributor:

Let's not diverge and extend the scope of the planned changes for the verify refactor yet. There are not that many training tests, so it's not a big problem at the moment IMO.

Member:

I didn't plan on extending it, but I can do that if needed.

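To illustrate the "return the predictions" idea discussed in the thread above, here is a hypothetical helper. It is only a sketch built from names already used in this test (compare_with_golden, VerifyConfig, framework_model, tt_model); the helper name compare_and_return is invented here and is not the verify API, which is still under development.

# Hypothetical helper sketching the "return predictions" idea; it only wraps
# compare_with_golden, which this test already uses. The real verify utility
# may end up with a different signature.
def compare_and_return(framework_model, tt_model, data, pcc=0.95):
    golden_pred = framework_model(data)  # golden forward pass on CPU
    pred = tt_model(data)[0]  # forward pass on TT device
    assert compare_with_golden(golden_pred, pred, verify_cfg=VerifyConfig(pcc=pcc))
    return golden_pred, pred  # hand both predictions back for the loss computation

The training loop could then call, for example, golden_pred, pred = compare_and_return(framework_model, tt_model, data) before computing the loss.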

            # Compute loss on CPU
            loss = loss_fn(pred, target)
            total_loss += loss.item()

            golden_loss = loss_fn(golden_pred, target)
            assert torch.allclose(loss, golden_loss, rtol=1e-1)  # 10% tolerance

            # Run backward pass on device
            loss.backward()

            tt_model.backward()

            # Adjust weights (on CPU)
            framework_optimizer.step()

            if batch_idx >= limit_num_batches:
                break

        print(f"epoch: {epoch_idx} loss: {total_loss}")

    test_loss = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        pred = tt_model(data)[0]
        target = nn.functional.one_hot(target, num_classes=10).to(dtype)

        test_loss += loss_fn(pred, target)

        if batch_idx == limit_num_batches:
            break

    print(f"Test (total) loss: {test_loss}")


@pytest.mark.push
def test_mnist_training_with_grad_accumulation():
    torch.manual_seed(0)

    # Config
    num_epochs = 3
    batch_size = 1