Merge pull request #129 from GDSC-Delft-Dev/111-detect-pests-and-diseases-from-close-up-images-of-corn

Detect pests and diseases from close up images of potato
Showing 15 changed files with 516 additions and 45 deletions.
@@ -0,0 +1,18 @@
from .potato import potato
from .tomato import tomato


# Map each supported crop to its GCS bucket, class labels, and model constructor.
mappings = {
    "tomato": {
        "bucket": "terrafarm-plantvillage-tomato",
        "labels": ['bacterial_spot', 'early_blight', 'healthy',
                   'late_blight', 'leaf_mold', 'septoria_leaf_spot',
                   'spider_mites', 'target_spot', 'mosaic_virus',
                   'yellow_leaf_curl_virus'],
        "model": tomato.ResNetTomato
    },
    "potato": {
        "bucket": "terrafarm-plantvillage-potato",
        "labels": ['early_blight', 'healthy', 'late_blight'],
        "model": potato.ResNetPotato
    }
}
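For orientation, a hedged usage sketch of this lookup table: the trainer below selects a crop's configuration by name, with "potato" being one of the two keys defined above.

config = mappings["potato"]
print(config["bucket"])       # terrafarm-plantvillage-potato
print(len(config["labels"]))  # 3 disease/health classes
model = config["model"]()     # instantiate the crop's ResNet wrapper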
@@ -0,0 +1,58 @@
import tensorflow as tf


def keras_estimator(model, model_dir, config) -> tf.estimator.Estimator:
    """
    Create a Keras Estimator from a Keras Model.
    Args:
        model: Keras Model
        model_dir: directory in which to save model parameters, graph, etc.
        config: configuration for the Estimator
    Returns:
        tf.estimator.Estimator
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
    loss = tf.keras.losses.CategoricalCrossentropy()
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=["accuracy"],
    )
    return tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=model_dir, config=config)


def input_fn(features, labels, batch_size, mode):
    """
    Input function for the Estimator.
    Args:
        features: np.ndarray of input shape
        labels: np.ndarray of output shape, or None at prediction time
        batch_size: batch size
        mode: tf.estimator.ModeKeys
    Returns:
        a batched tf.data.Dataset for the given mode
    """
    if labels is None:
        inputs = features
    else:
        inputs = (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
        dataset = dataset.batch(batch_size)
    # In TF2 an Estimator input_fn can return the dataset directly;
    # Dataset.make_one_shot_iterator() no longer exists outside compat.v1.
    return dataset


def serving_input_fn(input_shape: tuple) -> tf.estimator.export.TensorServingInputReceiver:
    """
    Input function for serving.
    Args:
        input_shape: shape of the input tensor
    Returns:
        tf.estimator.export.TensorServingInputReceiver
    """
    # tf.placeholder was removed in TF2; use the compat.v1 endpoint.
    feature_placeholder = tf.compat.v1.placeholder(tf.float32, input_shape)
    return tf.estimator.export.TensorServingInputReceiver(feature_placeholder, feature_placeholder)
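A minimal wiring sketch for these helpers, assuming a toy Keras model; the model_dir is hypothetical, and the RunConfig mirrors how task.py calls keras_estimator further down.

import tensorflow as tf

toy = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
run_config = tf.estimator.RunConfig(save_checkpoints_steps=500)
estimator = keras_estimator(toy, model_dir="/tmp/toy_model", config=run_config)  # hypothetical dir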
@@ -0,0 +1,77 @@
import numpy as np
import tensorflow as tf
from ...utils import performance_visualization


def ResNetPotato():
    """
    Create a ResNet50 model for the potato disease classification task.
    """
    # create a ResNet50 backbone pretrained on ImageNet
    resnet = tf.keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(256, 256, 3),
        pooling="avg",
    )
    # freeze the weights of the backbone
    resnet.trainable = False
    # add a dense layer with 3 output units, each corresponding to a class
    # listed in mappings.py
    output = tf.keras.layers.Dense(3, activation="softmax")(resnet.output)
    model = tf.keras.Model(inputs=resnet.input, outputs=output)
    return model


def train(model: tf.keras.Model, dataset: tuple[np.ndarray, np.ndarray],
          optimizer: tf.keras.optimizers.Optimizer,
          loss: tf.keras.losses.Loss):
    """
    Train the model on the PlantVillage dataset.
    Assumes the model is not compiled yet.
    Args:
        model - tf.keras.Model to finetune
        dataset - tuple of (images, labels) arrays, with images of shape
                  (num_train_imgs x width x height x channels)
        optimizer - tf.keras.optimizers.Optimizer to use
        loss - tf.keras.losses.Loss to use
    """
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=["accuracy"],
    )
    # shuffle the data
    perm = np.random.permutation(dataset[0].shape[0])
    dataset = (dataset[0][perm], dataset[1][perm])

    # 80/20 train/validation split
    split_idx = int(0.8 * dataset[0].shape[0])
    train_data = (dataset[0][:split_idx], dataset[1][:split_idx])
    val_data = (dataset[0][split_idx:], dataset[1][split_idx:])

    # build tf.data pipelines
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data[0], train_data[1]))
    val_dataset = tf.data.Dataset.from_tensor_slices((val_data[0], val_data[1]))

    # shuffle and batch the datasets
    SHUFFLE_SIZE = 10000
    BATCH_SIZE = 32
    train_dataset = train_dataset.shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE)
    val_dataset = val_dataset.shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE)

    print(f"Number of samples in the training set: {train_data[0].shape[0]}")
    print(f"Number of samples in the validation set: {val_data[0].shape[0]}")

    # prefetch data for performance
    AUTOTUNE = tf.data.AUTOTUNE
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
    val_dataset = val_dataset.prefetch(buffer_size=AUTOTUNE)

    hist = model.fit(train_dataset,
                     epochs=15,
                     validation_data=val_dataset)

    performance_visualization(hist)
    return model
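performance_visualization is imported from utils but is not part of this diff; below is a hedged sketch of what such a helper might look like, assuming it plots the accuracy curves from the Keras History object returned by model.fit.

import matplotlib.pyplot as plt

def performance_visualization_sketch(hist):
    """Plot training vs. validation accuracy from a Keras History (illustrative only)."""
    plt.plot(hist.history["accuracy"], label="train accuracy")
    plt.plot(hist.history["val_accuracy"], label="val accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()
    plt.show()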
@@ -0,0 +1,137 @@
import argparse
from .model import keras_estimator, input_fn, serving_input_fn
from .mappings import mappings
from ..utils import create_dataset
import numpy as np
import tensorflow as tf
import os
import subprocess

# working directory
WORKING_DIR = os.getcwd()
# crops with a supported disease model
SUPPORTED = ['potato', 'tomato']


def download_files_from_gcs(sources: list[str], destinations: list[str]):
    """
    Download files from Google Cloud Storage from the specified sources to the specified destinations.
    Args:
        sources: list of GCS paths
        destinations: list of local paths
    """
    for source, dest in zip(sources, destinations):
        # gsutil does not create local directories, so make sure each exists
        os.makedirs(dest, exist_ok=True)
        subprocess.check_call(['gsutil', 'cp', source, dest])


def load_data(bucket: str, labels: list[str], val_split: float = 0.2) -> tuple[tuple, tuple]:
    """
    Load and preprocess the dataset from the local file system.
    Args:
        bucket: GCS bucket holding the images
        labels: class labels, one folder per label in the bucket
        val_split: validation split
    Returns:
        tuple of tuples containing the training and validation data and labels
    """
    assert 0 <= val_split < 1

    # download data from GCS
    sources = ["gs://" + bucket + "/" + label + "/*" for label in labels]
    destinations = [WORKING_DIR + "/data/" + label + "/" for label in labels]
    download_files_from_gcs(sources, destinations)

    # TODO: give the paths as parameters instead of hardcoding them in utils.py
    data, dlabels = create_dataset()

    # shuffle data
    perm = np.random.permutation(data.shape[0])
    dataset = (data[perm], dlabels[perm])

    # split into training and validation data
    split_idx = int((1 - val_split) * dataset[0].shape[0])
    train_data = (dataset[0][:split_idx], dataset[1][:split_idx])
    val_data = (dataset[0][split_idx:], dataset[1][split_idx:])
    return train_data, val_data


def get_args():
    """
    Parse command-line arguments for the training job.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--job-dir',
        type=str,
        help='GCS location to write checkpoints and export models')
    parser.add_argument(
        '--test-split',
        type=float,
        default=0.2,
        help='Split between training and test, default=0.2')
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=500,
        help='number of times to go through the data, default=500')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=128,
        help='number of records to read during each training step, default=128')
    parser.add_argument(
        '--verbosity',
        choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
        default='INFO')
    parser.add_argument(
        '--crop',
        type=str,
        help="type of crop from the supported crops")
    args, _ = parser.parse_known_args()
    assert args.crop in SUPPORTED
    return args


def train_and_evaluate(args: argparse.Namespace):
    """
    Train and evaluate the model on GCP.
    Args:
        args: command line arguments
    """
    train_dict: dict = mappings[args.crop]
    # load data
    (train_data, train_labels), (test_data, test_labels) = load_data(train_dict["bucket"],
                                                                     train_dict["labels"])
    # save checkpoints every this many steps
    run_config = tf.estimator.RunConfig(save_checkpoints_steps=500)
    # number of training steps (max_steps expects an integer)
    train_steps = int(args.num_epochs * len(train_data) / args.batch_size)

    # specifications for training
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: input_fn(
            train_data,
            train_labels,
            args.batch_size,
            mode=tf.estimator.ModeKeys.TRAIN),
        max_steps=train_steps)
    # LatestExporter calls its serving input function with no arguments, so
    # bind the input shape here; (None, 256, 256, 3) matches the models'
    # 256x256 RGB input with a variable batch dimension
    exporter = tf.estimator.LatestExporter(
        'exporter', lambda: serving_input_fn((None, 256, 256, 3)))
    # specifications for evaluation
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: input_fn(
            test_data,
            test_labels,
            args.batch_size,
            mode=tf.estimator.ModeKeys.EVAL),
        steps=None,
        exporters=[exporter],
        start_delay_secs=10,
        throttle_secs=10)
    # define estimator
    estimator = keras_estimator(
        train_dict["model"](),  # initialize model
        model_dir=args.job_dir,
        config=run_config)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == '__main__':
    args = get_args()
    # tf.logging was removed in TF2; use the compat.v1 endpoint
    tf.compat.v1.logging.set_verbosity(args.verbosity)
    train_and_evaluate(args)
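A hedged launch sketch for this entry point; the module path "trainer.task" and the GCS job directory are assumptions for illustration, not part of this diff.

import subprocess

subprocess.check_call([
    "python", "-m", "trainer.task",               # hypothetical module path
    "--job-dir", "gs://terrafarm-models/potato",  # hypothetical bucket
    "--crop", "potato",
    "--num-epochs", "15",
    "--batch-size", "32",
])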
@@ -0,0 +1,15 @@
import tensorflow as tf
from .utils import create_dataset
from .disease.potato.potato import ResNetPotato, train


def main():
    LEARNING_RATE = 3e-4  # Adam learning rate
    dataset = create_dataset()
    model = ResNetPotato()
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    loss = tf.keras.losses.CategoricalCrossentropy()
    model = train(model, dataset, optimizer, loss)
    model.save("potato_model/resnet_potato")
    print("Finished")


if __name__ == "__main__":
    main()
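A minimal inference sketch against the SavedModel written above, assuming a 256x256 RGB input; the class order follows mappings["potato"]["labels"].

import numpy as np
import tensorflow as tf

model = tf.keras.models.load_model("potato_model/resnet_potato")
image = np.random.rand(1, 256, 256, 3).astype("float32")  # placeholder image batch
probs = model.predict(image)
print(probs.argmax(axis=-1))  # index into ['early_blight', 'healthy', 'late_blight']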
@@ -0,0 +1,32 @@
from setuptools import find_packages
from setuptools import setup

# argparse is omitted (it ships with the standard library) and the duplicate
# matplotlib pin is collapsed to a single entry
REQUIRED_PACKAGES = [
    'notebook==6.5.2',
    'pandas==1.5.2',
    'numpy==1.24.1',
    'opencv-python==4.7.0.68',
    'matplotlib==3.6.2',
    'screeninfo==0.8.1',
    'tensorflow==2.10.0',
    'mypy==1.0.1',
    'pylint==2.16.2',
    'pytest==7.2.1',
    'google-cloud-storage==2.7.0',
    'google-cloud-firestore==2.10.0',
    'firebase-admin==6.1.0',
    'pydash==6.0.2',
    'pytest-asyncio==0.20.3',
    'tqdm==4.65.0',
]

setup(
    name='disease',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
)
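A packaging sketch, under the assumption that this setup.py sits at the trainer package root; the source distribution lands in setuptools' default dist/ directory.

import subprocess

subprocess.check_call(["python", "setup.py", "sdist", "--formats=gztar"])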