diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 86316b92c60a4..0f2e39f0c8826 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -76,26 +76,26 @@ def forward(self, x): def training_step(self, batch, batch_idx): x, y = batch - y_hat = self.backbone(x) + y_hat = self(x) loss = F.cross_entropy(y_hat, y) self.log("train_loss", loss, on_epoch=True) return loss def validation_step(self, batch, batch_idx): x, y = batch - y_hat = self.backbone(x) + y_hat = self(x) loss = F.cross_entropy(y_hat, y) self.log("valid_loss", loss, on_step=True) def test_step(self, batch, batch_idx): x, y = batch - y_hat = self.backbone(x) + y_hat = self(x) loss = F.cross_entropy(y_hat, y) self.log("test_loss", loss) def predict_step(self, batch, batch_idx, dataloader_idx=None): x, y = batch - return self.backbone(x) + return self(x) def configure_optimizers(self): # self.hparams available because we called self.save_hyperparameters()