Fix accelerate bug, document, add scripts
Partly resolves #944

There is an issue with using skorch in a multi-GPU setting with accelerate. After some searching, it turned out there were two problems:

1. skorch did not call `accelerator.gather_for_metrics`, which resulted in `y_pred` not having the correct size. For more on this, consult the [accelerate docs](https://huggingface.co/docs/accelerate/quicktour#distributed-evaluation).
2. accelerate has an issue with being deepcopied, which happens, for instance, when using `GridSearchCV`. The problem is that some references get messed up, resulting in the `GradientState` of the accelerator instance and of the dataloader diverging. As a consequence, the accelerator did not "know" when the last batch was encountered and was thus unable to remove the dummy samples added for multi-GPU inference.

The fix for 1. is provided in this PR. For 2., there is no solution in skorch, but a possible (maybe hacky) fix is suggested in the docs. The fix consists of writing a custom `Accelerator` class that overrides `__deepcopy__` to just return `self`. I don't know enough about accelerate internals to determine whether this is a safe solution or whether it can cause more issues down the line, but it resolves the issue.

Since reproducing this bug requires a multi-GPU setup and running the scripts with the accelerate launcher, it cannot be covered by normal unit tests. Instead, this PR adds two scripts to reproduce the issue. With the appropriate hardware, they can be used to check the solution.
1 parent a148fed · commit ecf3a40 · 6 changed files with 234 additions and 3 deletions
@@ -0,0 +1,38 @@
# Testing skorch with accelerate in a multi-GPU setting

The full history of this can be found here: https://github.com/skorch-dev/skorch/issues/944

There was an issue with using skorch in a multi-GPU setting with accelerate. After some searching, it turned out there were two problems:

1. skorch did not call `accelerator.gather_for_metrics`, which resulted in `y_pred` not having the correct size. For more on this, consult the [accelerate docs](https://huggingface.co/docs/accelerate/quicktour#distributed-evaluation).
2. accelerate has an issue with being deepcopied, which happens, for instance, when using `GridSearchCV`. The problem is that some references get messed up, resulting in the `GradientState` of the `accelerator` instance and of the `dataloader` diverging. As a consequence, the `accelerator` did not "know" when the last batch was encountered and was thus unable to remove the dummy samples added for multi-GPU inference.

The fix for 1. is provided in the same PR that added this example (a minimal sketch of the idea follows below). For 2., the scripts contain a custom `Accelerator` class that overrides `__deepcopy__` to just return `self`. I don't know enough about accelerate internals to determine whether this is a safe solution or whether it can cause more issues down the line, but it resolves the issue.

This example contains two scripts, one involving skorch and one with skorch completely removed. The scripts reproduce the issue in a multi-GPU setup (tested on a GCP VM instance with two T4s). Unfortunately, the GitHub Actions runners don't offer such a setup, which is why no unit test is added for the bug.
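For reference, the skorch-side fix for problem 1 amounts to gathering predictions across processes during evaluation. The following is a minimal sketch of the idea, assuming it hooks into the `evaluation_step` of skorch's `AccelerateMixin`; the exact implementation in the PR may differ:

```python
class PatchedAccelerateMixin:
    # Sketch only: the rest of the mixin's accelerate plumbing is omitted.

    def evaluation_step(self, batch, training=False):
        # each process computes predictions for its shard of the batch
        output = super().evaluation_step(batch, training=training)
        # gather_for_metrics collects the shards from all processes and drops
        # the dummy samples added to pad the last batch, so y_pred ends up
        # with the correct overall size
        return self.accelerator.gather_for_metrics(output)
```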
Run the scripts like this:

```sh
accelerate launch <script.py>
```

The accelerate config is:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
@@ -0,0 +1,92 @@
```python
import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.base import BaseEstimator
from torch import nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 2)
        self.nonlin = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.dense0(X)
        X = self.nonlin(X)
        return X


class Net(BaseEstimator):
    def __init__(self, module, accelerator):
        self.module = module
        self.accelerator = accelerator

    def fit(self, X, y, **fit_params):
        X = torch.as_tensor(X)
        y = torch.as_tensor(y)
        dataset = torch.utils.data.TensorDataset(X, y)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
        optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01)

        self.module = self.accelerator.prepare(self.module)
        optimizer = self.accelerator.prepare(optimizer)
        dataloader = self.accelerator.prepare(dataloader)

        # training
        self.module.train()
        for epoch in range(5):
            for source, targets in dataloader:
                optimizer.zero_grad()
                output = self.module(source)
                loss = nn.functional.nll_loss(output, targets)
                self.accelerator.backward(loss)
                optimizer.step()

        return self

    def predict_proba(self, X):
        self.module.eval()
        X = torch.as_tensor(X)
        dataset = torch.utils.data.TensorDataset(X)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
        dataloader = self.accelerator.prepare(dataloader)

        probas = []
        with torch.no_grad():
            for source, *_ in dataloader:
                output = self.module(source)
                # gather predictions from all processes; this also removes the
                # dummy samples accelerate added to pad the last batch
                output = self.accelerator.gather_for_metrics(output)
                output = output.cpu().detach().numpy()
                probas.append(output)

        return np.vstack(probas)

    def predict(self, X):
        y_proba = self.predict_proba(X)
        return y_proba.argmax(1)


class MyAccelerator(Accelerator):
    # workaround for the deepcopy issue: deepcopying an Accelerator breaks the
    # link between its GradientState and the prepared dataloader's, so return
    # the same instance instead of copying
    def __deepcopy__(self, memo):
        return self


def main():
    X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
    X = X.astype(np.float32)

    module = MyModule()
    accelerator = MyAccelerator()
    net = Net(module, accelerator)
    # cross_validate creates a deepcopy of the accelerator attribute
    res = cross_validate(
        net, X, y, cv=2, scoring='accuracy', verbose=3, error_score='raise',
    )
    print(res)


if __name__ == '__main__':
    main()
```
@@ -0,0 +1,55 @@
```python
import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from torch import nn

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 2)
        self.nonlin = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.dense0(X)
        X = self.nonlin(X)
        return X


class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
    pass


class MyAccelerator(Accelerator):
    # same deepcopy workaround as in the non-skorch script
    def __deepcopy__(self, memo):
        return self


def main():
    X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
    X = X.astype(np.float32)

    accelerator = MyAccelerator()
    model = AcceleratedNeuralNetClassifier(
        MyModule,
        accelerator=accelerator,
        max_epochs=3,
        lr=0.001,
    )

    cross_validate(
        model,
        X,
        y,
        cv=2,
        scoring="average_precision",
        error_score="raise",
    )


if __name__ == '__main__':
    main()
```
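The commit message names `GridSearchCV` as another trigger for the deepcopy problem, since sklearn clones the estimator and thereby deepcopies its `accelerator` attribute. As a hypothetical variation on the skorch script above (not part of the added files), the same workaround should let a grid search run; `model`, `X`, and `y` are assumed to be defined as in `main()`:

```python
from sklearn.model_selection import GridSearchCV

# cloning the estimator deepcopies its parameters; MyAccelerator.__deepcopy__
# returns the same instance, so all clones share one consistent GradientState
params = {'lr': [0.001, 0.01]}
search = GridSearchCV(model, params, cv=2, scoring='average_precision', error_score='raise')
search.fit(X, y)
print(search.best_params_)
```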