Skip to content

Commit

Permalink
Fix/ar-tutorial (#1655)
Browse files Browse the repository at this point in the history
This PR fixes a device issue for the 'ar' tutorial where tensors were on different devices if a "mps" device is present leading to an error.
  • Loading branch information
jdb78 committed Sep 7, 2024
1 parent fa3dc24 commit 6c66358
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions docs/source/tutorials/ar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,8 @@
],
"source": [
"# calculate baseline absolute error\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n",
"baseline_predictions = Baseline().predict(val_dataloader)\n",
"SMAPE()(baseline_predictions, actuals)"
"baseline_predictions = Baseline().predict(val_dataloader, return_y=True)\n",
"SMAPE()(baseline_predictions.output, baseline_predictions.y)"
]
},
{
Expand Down Expand Up @@ -517,8 +515,8 @@
}
],
"source": [
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n",
"predictions = best_model.predict(val_dataloader)\n",
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(\"cpu\")\n",
"predictions = best_model.predict(val_dataloader, trainer_kwargs=dict(accelerator=\"cpu\"))\n",
"(actuals - predictions).abs().mean()"
]
},
Expand Down

0 comments on commit 6c66358

Please sign in to comment.