Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
simplify examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Feb 2, 2021
1 parent 101f560 commit da50d9b
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 231 deletions.
4 changes: 3 additions & 1 deletion flash/setup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def _load_readme_description(path_dir: str, homepage: str = __homepage__, ver: s
github_source_url = os.path.join(homepage, "raw", ver)
# replace relative repository path to absolute link to the release
# do not replace all "docs" as in the readme we reger some other sources with particular path to docs
text = text.replace("docs/source/_static/images/", f"{os.path.join(github_source_url, 'docs/source/_static/images/')}")
text = text.replace(
"docs/source/_static/images/", f"{os.path.join(github_source_url, 'docs/source/_static/images/')}"
)

# readthedocs badge
text = text.replace('badge/?version=stable', f'badge/?version={ver}')
Expand Down
38 changes: 18 additions & 20 deletions flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,27 @@
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)
# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)

# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)
# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
# 6. Test the model
trainer.test()

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
41 changes: 20 additions & 21 deletions flash_examples/finetuning/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from flash import download_data
from flash.text import SummarizationData, SummarizationTask

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)
# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the model
model = SummarizationTask()
# 3. Build the model
model = SummarizationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()
# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("summarization_model_xsum.pt")
# 7. Save it!
trainer.save_checkpoint("summarization_model_xsum.pt")
44 changes: 21 additions & 23 deletions flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,30 @@
from flash.core.data import download_data
from flash.tabular import TabularClassifier, TabularData

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
# 2. Load the data
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
target="Survived",
val_size=0.25,
)

# 2. Load the data
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
target="Survived",
val_size=0.25,
)
# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])
# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)

# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)
# 5. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Train the model
trainer.fit(model, datamodule=datamodule)
# 6. Test model
trainer.test()

# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("tabular_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("tabular_classification_model.pt")
44 changes: 21 additions & 23 deletions flash_examples/finetuning/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,30 @@
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 2. Load the data
datamodule = TextClassificationData.from_files(
train_file="data/imdb/train.csv",
valid_file="data/imdb/valid.csv",
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
batch_size=512
)

# 2. Load the data
datamodule = TextClassificationData.from_files(
train_file="data/imdb/train.csv",
valid_file="data/imdb/valid.csv",
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
batch_size=512
)
# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')
# 6. Test model
trainer.test()

# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("text_classification_model.pt")
# 7. Save it!
trainer.save_checkpoint("text_classification_model.pt")
41 changes: 20 additions & 21 deletions flash_examples/finetuning/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from flash import download_data
from flash.text import TranslationData, TranslationTask

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')

# 2. Load the data
datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
)
# 2. Load the data
datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
)

# 3. Build the model
model = TranslationTask()
# 3. Build the model
model = TranslationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1)
# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)
# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()
# 6. Test model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("translation_model_en_ro.pt")
# 7. Save it!
trainer.save_checkpoint("translation_model_en_ro.pt")
32 changes: 15 additions & 17 deletions flash_examples/predict/classify_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@
from flash.core.data import download_data
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
18 changes: 7 additions & 11 deletions flash_examples/predict/classify_tabular.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from flash.core.data import download_data
from flash.tabular import TabularClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")

# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt"
)

# 3. Generate predictions from a sheet file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
# 3. Generate predictions from a sheet file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
42 changes: 20 additions & 22 deletions flash_examples/predict/classify_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,26 @@
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)

# 2b. Or generate predictions from a sheet file!
datamodule = TextClassificationData.from_file(
predict_file="data/imdb/predict.csv",
input="review",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# 2b. Or generate predictions from a sheet file!
datamodule = TextClassificationData.from_file(
predict_file="data/imdb/predict.csv",
input="review",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Loading

0 comments on commit da50d9b

Please sign in to comment.