Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
examples: set gpus=device_count (#638)
Browse files Browse the repository at this point in the history
* , gpus=-1

* torch.cuda.device_count()

* .

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* .

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Aug 16, 2021
1 parent f733c26 commit 2e3c891
Show file tree
Hide file tree
Showing 28 changed files with 84 additions and 29 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ Flash has a [Summarization task](https://lightning-flash.readthedocs.io/en/lates

```python
import flash
import torch
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

Expand All @@ -244,7 +245,7 @@ datamodule = SummarizationData.from_csv(
model = SummarizationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), precision=16)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/finetuning_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Here's an example of finetuning.
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/training_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Here's an example:
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)
Expand Down
4 changes: 3 additions & 1 deletion docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ supplying the task itself, and the associated data:

model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs)

trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False)
trainer = flash.Trainer(
max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule=datamodule)


Expand Down
4 changes: 3 additions & 1 deletion flash_examples/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
Expand All @@ -30,7 +32,7 @@
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/custom_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ class NumpyDataModule(flash.DataModule):
datamodule = NumpyDataModule.from_numpy(x, y)
model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs)

trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False)
trainer = flash.Trainer(
max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule=datamodule)

predict_data = np.array(
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/graph_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE
from flash.graph import GraphClassificationData, GraphClassifier
Expand All @@ -32,7 +34,7 @@
model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify some graphs!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
Expand All @@ -27,7 +29,7 @@
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/image_classification_multi_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
Expand All @@ -32,7 +34,7 @@
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
Expand Down
3 changes: 3 additions & 0 deletions flash_examples/integrations/fiftyone/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from itertools import chain

import torch

import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
Expand All @@ -39,6 +41,7 @@
)
trainer = flash.Trainer(
max_epochs=1,
gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from itertools import chain

import fiftyone as fo
import torch

import flash
from flash.core.classification import FiftyOneLabels, Labels
Expand Down Expand Up @@ -53,6 +54,7 @@
)
trainer = flash.Trainer(
max_epochs=1,
gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
Expand Down
5 changes: 4 additions & 1 deletion flash_examples/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector
Expand All @@ -23,13 +25,14 @@
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
batch_size=2,
)

# 2. Build the task
model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)

# 4. Detect objects in a few images!
Expand Down
6 changes: 5 additions & 1 deletion flash_examples/pointcloud_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
Expand All @@ -28,7 +30,9 @@
model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
trainer = flash.Trainer(
max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
Expand Down
6 changes: 5 additions & 1 deletion flash_examples/pointcloud_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
Expand All @@ -28,7 +30,9 @@
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
trainer = flash.Trainer(
max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
Expand Down Expand Up @@ -39,7 +41,7 @@
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data
Expand All @@ -29,7 +31,7 @@
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/style_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
Expand All @@ -26,7 +28,7 @@
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier
Expand All @@ -30,7 +32,7 @@
model = TabularClassifier.from_data(datamodule)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
Expand Down
3 changes: 2 additions & 1 deletion flash_examples/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from sklearn import datasets

import flash
Expand All @@ -27,7 +28,7 @@
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
Expand All @@ -30,7 +32,7 @@
model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/text_classification_multi_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
Expand All @@ -36,7 +38,7 @@
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask
Expand All @@ -30,7 +32,7 @@
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)

# 4. Translate something!
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier
Expand All @@ -33,7 +35,7 @@
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Make a prediction
Expand Down
Loading

0 comments on commit 2e3c891

Please sign in to comment.