Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

simplify examples #57

Merged
merged 8 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions flash/core/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
from typing import Any, Optional

import pytorch_lightning as pl
Expand Down Expand Up @@ -63,11 +64,11 @@ def __init__(
self.batch_size = batch_size

# TODO: figure out best solution for setting num_workers
# if num_workers is None:
# num_workers = os.cpu_count()
if num_workers is None:
# warnings.warn("Could not infer cpu count automatically, setting it to zero")
num_workers = 0
if platform.system() == "Darwin":
num_workers = 0
else:
num_workers = os.cpu_count()
self.num_workers = num_workers

self._data_pipeline = None
Expand Down
38 changes: 18 additions & 20 deletions flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,27 @@
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)
# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)

# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)
# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
# 6. Test the model
trainer.test()

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
41 changes: 20 additions & 21 deletions flash_examples/finetuning/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from flash import download_data
from flash.text import SummarizationData, SummarizationTask

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)
# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the model
model = SummarizationTask()
# 3. Build the model
model = SummarizationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()
# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("summarization_model_xsum.pt")
# 7. Save it!
trainer.save_checkpoint("summarization_model_xsum.pt")
44 changes: 21 additions & 23 deletions flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,30 @@
from flash.core.data import download_data
from flash.tabular import TabularClassifier, TabularData

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
# 2. Load the data
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
target="Survived",
val_size=0.25,
)

# 2. Load the data
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
target="Survived",
val_size=0.25,
)
# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])
# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)

# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)
# 5. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Train the model
trainer.fit(model, datamodule=datamodule)
# 6. Test model
trainer.test()

# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("tabular_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("tabular_classification_model.pt")
44 changes: 21 additions & 23 deletions flash_examples/finetuning/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,30 @@
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 2. Load the data
datamodule = TextClassificationData.from_files(
train_file="data/imdb/train.csv",
valid_file="data/imdb/valid.csv",
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
batch_size=512
)

# 2. Load the data
datamodule = TextClassificationData.from_files(
train_file="data/imdb/train.csv",
valid_file="data/imdb/valid.csv",
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
batch_size=512
)
# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')
# 6. Test model
trainer.test()

# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("text_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("text_classification_model.pt")
41 changes: 20 additions & 21 deletions flash_examples/finetuning/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from flash import download_data
from flash.text import TranslationData, TranslationTask

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')

# 2. Load the data
datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
)
# 2. Load the data
datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
)

# 3. Build the model
model = TranslationTask()
# 3. Build the model
model = TranslationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()
# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("translation_model_en_ro.pt")
# 7. Save it!
trainer.save_checkpoint("translation_model_en_ro.pt")
32 changes: 15 additions & 17 deletions flash_examples/predict/classify_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@
from flash.core.data import download_data
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
18 changes: 7 additions & 11 deletions flash_examples/predict/classify_tabular.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from flash.core.data import download_data
from flash.tabular import TabularClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")

# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt"
)

# 3. Generate predictions from a sheet file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
# 3. Generate predictions from a sheet file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
42 changes: 20 additions & 22 deletions flash_examples/predict/classify_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,26 @@
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)

# 2b. Or generate predictions from a sheet file!
datamodule = TextClassificationData.from_file(
predict_file="data/imdb/predict.csv",
input="review",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# 2b. Or generate predictions from a sheet file!
datamodule = TextClassificationData.from_file(
predict_file="data/imdb/predict.csv",
input="review",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Loading