-
Notifications
You must be signed in to change notification settings - Fork 88
/
train_mitoses.py
2294 lines (1954 loc) · 96.4 KB
/
train_mitoses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Training - mitosis detection"""
import argparse
from datetime import datetime
import json
import math
import os
import pickle
import shutil
import sys
import numpy as np
import tensorflow as tf
import tensorboard as tb
import resnet
import resnet50
# data
def get_image(filename, patch_size):
    """Read a PNG file and decode it into a float32 RGB image tensor.

    Args:
      filename: String filename of a PNG image.
      patch_size: Integer side length of the square patch.  Currently
        unused: the resize step that consumed it is disabled, so the
        decoded image keeps its on-disk height and width.

    Returns:
      A float32 TensorFlow tensor of shape (h,w,3), obtained by passing
      the decoded uint8 pixels through `tf.image.convert_image_dtype`.
    """
    encoded_png = tf.read_file(filename)
    # decode to uint8 pixels in [0, 255] with shape (h,w,c), forcing 3 channels
    pixels = tf.image.decode_png(encoded_png, channels=3)
    # rescale the uint8 pixels to float32 in the unit range
    return tf.image.convert_image_dtype(pixels, dtype=tf.float32)
def get_label(filename):
    """Derive the binary class label from an image path.

    The label is the name of the directory that directly contains the
    file, i.e. the second-to-last "/"-separated path component.

    Args:
      filename: String in format "**/train|val/mitosis|normal/name.{ext}",
        where the parent directory is either "mitosis" or "normal".

    Returns:
      A TensorFlow float scalar equal to 1 for "mitosis" and 0 for
      "normal".  Evaluating it triggers an assertion failure if the
      parent directory has any other name.
    """
    path_parts = tf.string_split([filename], "/")
    class_dir = path_parts.values[-2]

    # guard against paths whose parent folder is not one of the two known classes
    is_normal = tf.equal(class_dir, 'normal')
    is_mitosis = tf.equal(class_dir, 'mitosis')
    check_class = tf.Assert(tf.logical_or(is_normal, is_mitosis), [class_dir])

    # building the label under the control dependency forces the check to run first
    with tf.control_dependencies([check_class]):
        # float rather than int because the model produces float outputs
        label = tf.to_float(tf.equal(class_dir, 'mitosis'))
    return label
def preprocess(filename, patch_size):
    """Turn an image filename into an (image, label, filename) example.

    Args:
      filename: String filename of an image; its parent directory name
        determines the label (see `get_label`).
      patch_size: Integer side length of the square patch, forwarded to
        `get_image`.

    Returns:
      Tuple of a float32 image Tensor with shape (h,w,c), a float binary
      label, and the unchanged filename.
    """
    # the label comes purely from the path; the pixels come from the file contents
    mitosis_label = get_label(filename)
    patch = get_image(filename, patch_size)
    return patch, mitosis_label, filename
def normalize(image, model_name):
    """Normalize an image tensor for the given model.

    Only the last axis (channels) is indexed and all arithmetic
    broadcasts, so this works for a single image or a batch of images.
    The input is never modified in place.

    Args:
      image: A Tensor of shape (...,h,w,c) with values in [0, 1].
      model_name: String indicating the model to use.

    Returns:
      A normalized image Tensor of shape (...,h,w,c).  For "vgg",
      "vgg19", and "resnet" the channels are reversed (rgb -> bgr),
      rescaled to [0, 255], and centered with the ImageNet channel
      means; for every other model the values are mapped to [-1, 1].
    """
    uses_imagenet_stats = model_name in ("vgg", "vgg19", "resnet")
    if not uses_imagenet_stats:
        # [0, 1] -> [-1, 1]
        return (image - 0.5) * 2

    # ImageNet per-channel means, listed in bgr order
    bgr_means = np.array([103.939, 116.779, 123.68]).astype(np.float32)
    bgr = image[..., ::-1]  # rgb -> bgr
    return bgr * 255 - bgr_means  # scale to [0, 255], then mean-center
def unnormalize(image, model_name):
    """Invert `normalize` for the given model.

    Only the last axis (channels) is indexed and all arithmetic
    broadcasts, so this works for a single image or a batch of images.
    The input is never modified in place.

    Args:
      image: A Tensor of shape (...,h,w,c) with normalized values.
      model_name: String indicating the model to use.

    Returns:
      An unnormalized image Tensor of shape (...,h,w,c) with values in
      [0, 1].
    """
    if model_name in ("vgg", "vgg19", "resnet"):
        # ImageNet per-channel means, listed in bgr order
        bgr_means = np.array([103.939, 116.779, 123.68]).astype(np.float32)
        # add the means back, then rescale [0, 255] -> [0, 1]
        unit_bgr = (image + bgr_means) / 255
        return unit_bgr[..., ::-1]  # bgr -> rgb

    # [-1, 1] -> [0, 1]
    return image / 2 + 0.5
def augment(image, patch_size, seed=None, max_translation=25, pad_margin=30,
            brightness_delta=64/255, contrast_lower=0.25, saturation_lower=0.75,
            hue_delta=0.04):
    """Apply random data augmentation to the given image.

    The pipeline is: reflect-pad (if the input is too small to rotate
    safely), random rotation, center crop to a randomly sized bounding
    box, random crop to `patch_size` (a random translation), random
    vertical and horizontal flips, random brightness / contrast /
    saturation / hue perturbations, and a final clip to [0, 1].

    The default magnitudes come from the Google pathology paper:
      Liu Y, Gadepalli K, Norouzi M, Dahl GE, Kohlberger T, Boyko A, et al.
      Detecting Cancer Metastases on Gigapixel Pathology Images. arXiv.org. 2017.

    Args:
      image: A Tensor of shape (h,w,c) with values in [0, 1].
      patch_size: The patch size to which to randomly crop the image.
      seed: An integer used to create a random seed.
      max_translation: Positive integer `c`; upper bound, in pixels, on
        the Euclidean size of the random translation window (see the
        comments in the body).
      pad_margin: Integer number of pixels added to `patch_size` when
        computing the minimum side length the image must have before it
        is rotated.
      brightness_delta: Maximum delta passed to
        `tf.image.random_brightness`.
      contrast_lower: Lower bound passed to `tf.image.random_contrast`
        (the upper bound is 1).
      saturation_lower: Lower bound passed to
        `tf.image.random_saturation` (the upper bound is 1).
      hue_delta: Maximum delta passed to `tf.image.random_hue`.

    Returns:
      A data-augmented image with values in [0, 1].
    """
    # NOTE: if the seed is None, these ops will be seeded with a completely random seed, rather
    # than a deterministic one based on the graph seed.  This appears to only happen within the
    # map functions of the Dataset API, based on the `test_num_parallel_calls` and
    # `test_image_random_op_seeds` tests.  For now, we pass in a seed from the user and use it at
    # the op level.
    # NOTE: Additionally, if the Dataset.map() function that calls this function is using
    # `num_parallel_calls` > 1, the results will be non-reproducible.
    # TODO: https://github.com/tensorflow/tensorflow/issues/13932
    # NOTE: a reinitializable iterator for a Dataset will cause any ops with random seeds, such as
    # these, to be reset, and thus each epoch will be evaluated exactly the same.  The desired
    # behavior would be to seed these ops once at the very beginning, so that an entire training
    # run can be deterministic, but not with the exact same random augmentation during each epoch.
    shape = tf.shape(image)  # (h, w, c) of the input, captured *before* any padding

    # TODO: possibly re-enable random zooms (random resize up to 25% larger followed by a random
    # crop back to the original size) via a flag

    # mirror padding if needed: a square of side `patch_size + pad_margin` rotated by an arbitrary
    # angle needs a source square of side (cos(pi/4) + sin(pi/4)) times larger to stay inside it
    size = int(math.ceil((patch_size + pad_margin) * (math.cos(math.pi/4) + math.sin(math.pi/4))))
    # NOTE: the padding amount is derived from the height only and is applied to both the height
    # and the width
    pad = tf.to_int32(tf.ceil(tf.maximum(0, size - shape[0]) / 2))
    paddings = tf.reshape(tf.stack([pad, pad, pad, pad, 0, 0], 0), [3, 2])  # h, w, z before/after
    image = tf.pad(image, paddings, mode="REFLECT")

    # random rotation
    angle = tf.random_uniform([], minval=0, maxval=2*np.pi, seed=seed)
    image = tf.contrib.image.rotate(image, angle, "BILINEAR")

    # crop to bounding box to allow for random translation crop, if the input image is large enough
    # note: translation distance: 7 µm = 30 pixels = max allowable euclidean pred distance from
    # actual mitosis
    # note: the allowable region is a circle with radius 30, but we are cropping to a square, so we
    # can't crop from a square with side length of patch_size+30 or we run the risk of moving the
    # mitosis to a spot in the corner of the resulting image, which would be outside of the circle
    # radius, and thus we would be incorrect to label that image as positive.  We also want to
    # impose some amount of buffer into our learned model, so we place an upper bound of `c` pixels
    # on the distance.  We sample a distance along the height axis, compute a valid distance along
    # the width axis that is upper bounded by a Euclidean translation distance of `c` in the worst
    # case, crop the center of the image to this height and width, and then perform a random crop,
    # yielding a patch for which the center is at most `c` pixels from the true center in terms of
    # Euclidean distance.
    # NOTE: In the dataset, all normal samples must be > 60 pixels from the center of a mitotic
    # figure to avoid random crops that end up incorrectly within a mitotic region.
    # with the default, c = 25 = sqrt(a**2 + b**2) = 6.25 µm
    c = max_translation
    a = tf.random_uniform([], minval=0, maxval=c, dtype=tf.int32, seed=seed)
    b = tf.to_int32(tf.floor(tf.sqrt(tf.to_float(c**2 - a**2))))
    # the bounding box never exceeds the original (unpadded) height and width
    crop_h = tf.minimum(shape[0], patch_size + a)
    crop_w = tf.minimum(shape[1], patch_size + b)
    image = tf.image.resize_image_with_crop_or_pad(image, crop_h, crop_w)

    # random central crop == random translation augmentation
    image = tf.random_crop(image, [patch_size, patch_size, 3], seed=seed)

    # random flips
    image = tf.image.random_flip_up_down(image, seed=seed)
    image = tf.image.random_flip_left_right(image, seed=seed)

    # random color perturbations
    image = tf.image.random_brightness(image, brightness_delta, seed=seed)
    image = tf.image.random_contrast(image, contrast_lower, 1, seed=seed)
    image = tf.image.random_saturation(image, saturation_lower, 1, seed=seed)
    image = tf.image.random_hue(image, hue_delta, seed=seed)

    # the color ops can push values outside the valid range
    image = tf.clip_by_value(image, 0, 1)
    return image
def create_augmented_batch(image, batch_size, patch_size):
    """Create a batch of augmented versions of the given image.

    For `batch_size >= 4` the batch is built in groups of four rotations
    (0, 90, 180, 270 degrees).  The first group comes from a plain center
    crop of the image; each of the remaining `batch_size/4 - 1` groups
    comes from `augment` called with the group index as its seed, so the
    batch is deterministic.  For a smaller `batch_size` the image is
    returned as a batch of one, without cropping or augmentation.

    Args:
      image: A Tensor of shape (h,w,c).
      batch_size: Number of augmented versions to generate.  Must be 1
        or divisible by 4.
      patch_size: The patch size to which to crop the image.

    Returns:
      A Tensor of shape (batch_size,h,w,c) containing a batch of
      data-augmented versions of the given image.
    """
    assert batch_size % 4 == 0 or batch_size == 1, "batch_size must be 1 or divisible by 4"

    # TODO rewrite this function to just draw `batch_size` examples from the `augment` function
    def rots_batch(single):
        """Stack the four right-angle rotations of one image along a new batch axis."""
        quarter_turns = [single]
        quarter_turns.append(tf.image.rot90(single))
        quarter_turns.append(tf.image.rot90(single, k=2))
        quarter_turns.append(tf.image.rot90(single, k=3))
        return tf.stack(quarter_turns)

    if batch_size < 4:
        # debugging case: no augmentation, just add the batch axis
        return tf.expand_dims(image, 0)

    center_patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
    images = rots_batch(center_patch)
    num_augmented = round(batch_size/4) - 1
    for group_seed in range(num_augmented):
        augmented = augment(image, patch_size, group_seed)
        images = tf.concat([images, rots_batch(augmented)], axis=0)
    return images
def marginalize(x):
    """Marginalize over injected noise at test time.

    Noise marginalization is implemented as an average over the batch
    dimension.  Typically `x` holds the logits for a batch of augmented
    versions of a single image, or the associated batch of labels.  The
    average is only used when `tf.keras.backend.learning_phase()` is 0
    (test time); during training `x` is passed through unchanged.

    Args:
      x: A Tensor of shape (n,...).

    Returns:
      At test time, a Tensor of shape (1, ...) containing the average
      over the batch dimension; otherwise `x` itself.
    """
    batch_mean = tf.reduce_mean(x, axis=0, keepdims=True, name="avg_x")
    at_test_time = tf.logical_not(tf.keras.backend.learning_phase())

    def use_mean():
        return batch_mean

    def use_input():
        return x

    return tf.cond(at_test_time, true_fn=use_mean, false_fn=use_input)
def process_dataset(dataset, model_name, patch_size, augmentation, marginalization, marg_batch_size,
                    threads, seed=None):
    """Map a Dataset of filenames to a Dataset of model-ready examples.

    Stages, in order: load (image, label, filename); either randomly
    augment or center-crop to `patch_size`; optionally expand every
    example into a marginalization batch; normalize for the model.

    Args:
      dataset: Dataset of filenames.
      model_name: String indicating the model to use.
      patch_size: Integer side length of the square patches.
      augmentation: Boolean for whether or not to apply random
        augmentation to the images.
      marginalization: Boolean for whether or not to use noise
        marginalization.  If True, each image is expanded to a batch of
        `marg_batch_size` augmented versions of that image (with the
        label and filename tiled to match), so that predictions can be
        averaged per image.  In that case `marg_batch_size` must be
        divisible by 4, or equal to 1 for a special debugging case of no
        augmentation.
      marg_batch_size: Integer size of each marginalization batch.
      threads: Integer number of parallel calls for each map stage.
      seed: Integer random seed forwarded to `augment`.

    Returns:
      A labeled Dataset of (image, label, filename) elements with
      normalized images, possibly with marginalization.
    """
    def load_example(filename):
        return preprocess(filename, patch_size)

    def random_augment(image, label, filename):
        return augment(image, patch_size, seed), label, filename

    def center_crop(image, label, filename):
        # stored images are larger than `patch_size` to leave room for random rotations and
        # translations, so without augmentation they must be center-cropped to the right size
        patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
        return patch, label, filename

    def expand_for_marginalization(image, label, filename):
        images = create_augmented_batch(image, marg_batch_size, patch_size)
        labels = tf.tile(tf.expand_dims(label, -1), [marg_batch_size])
        filenames = tf.tile(tf.expand_dims(filename, -1), [marg_batch_size])
        return images, labels, filenames

    def normalize_example(image, label, filename):
        return normalize(image, model_name), label, filename

    dataset = dataset.map(load_example, num_parallel_calls=threads)

    # augment (typically at training time), otherwise just crop to size
    size_stage = random_augment if augmentation else center_crop
    dataset = dataset.map(size_stage, num_parallel_calls=threads)

    # TODO: should marginalization replace, rather than follow, the crop stage above?  in
    # particular, the patch sizes will be messed up
    # marginalize (typically at eval time)
    if marginalization:
        dataset = dataset.map(expand_for_marginalization, num_parallel_calls=threads)

    return dataset.map(normalize_example, num_parallel_calls=threads)
def create_dataset(path, model_name, patch_size, batch_size, shuffle, augmentation, marginalization,
                   oversampling, threads, prefetch_batches, seed=None):
    """Create a batched, prefetched dataset of image patches.

    Args:
      path: String path to the generated image patches.  This should
        contain folders for each class ("mitosis" and "normal").
      model_name: String indicating the model to use.
      patch_size: Integer side length of the square patches.
      batch_size: Integer batch size.
      shuffle: Boolean for whether or not to shuffle filenames.  Only
        consulted when `oversampling` is False; the oversampling path
        always shuffles.
      augmentation: Boolean for whether or not to apply random
        augmentation to the images.
      marginalization: Boolean for whether or not to use noise
        marginalization.  If True, each image is expanded to a batch of
        augmented versions of that image so that predictions can be
        averaged per image, and no further batching is applied.  In that
        case `batch_size` must be divisible by 4, or equal to 1 for a
        special debugging case of no augmentation.  Ignored when
        `oversampling` is True.
      oversampling: Boolean for whether or not to oversample the
        minority mitosis class via class-aware sampling.  Not compatible
        with marginalization.
      threads: Integer number of parallel calls for dataset processing.
      prefetch_batches: Integer number of batches to prefetch.
      seed: Integer random seed, forwarded to `process_dataset`.

    Returns:
      A Dataset object.
    """
    if oversampling:
        # class-aware sampling: read the mitosis and normal samples separately so that they can
        # be interleaved one-for-one into class-balanced mini-batches
        mitosis_files = tf.data.Dataset.list_files(os.path.join(path, "mitosis", "*.png"))
        normal_files = tf.data.Dataset.list_files(os.path.join(path, "normal", "*.png"))

        # the mitosis stream repeats forever, so the zip below stops once the normal stream
        # is exhausted, i.e. the result is limited to the number of normal samples
        mitosis_files = mitosis_files.repeat(-1).shuffle(int(1e6))
        normal_files = normal_files.shuffle(int(1e6))

        mitosis_examples = process_dataset(
            mitosis_files, model_name, patch_size, augmentation, False, batch_size, threads, seed)
        normal_examples = process_dataset(
            normal_files, model_name, patch_size, augmentation, False, batch_size, threads, seed)

        def alternate_classes(mitosis, normal):
            """Flatten one (mitosis, normal) pair into two consecutive examples."""
            first = tf.data.Dataset.from_tensors(mitosis)
            second = tf.data.Dataset.from_tensors(normal)
            return first.concatenate(second)

        paired = tf.data.Dataset.zip((mitosis_examples, normal_examples))
        dataset = paired.flat_map(alternate_classes).batch(batch_size)
        # NOTE: batch norm could be affected by a very small final batch, but dropping the
        # remainder would also affect evaluation tasks, so it is left in for now
    else:
        filenames = tf.data.Dataset.list_files(os.path.join(path, "*", "*.png"))
        if shuffle:
            filenames = filenames.shuffle(int(1e7))
        dataset = process_dataset(
            filenames, model_name, patch_size, augmentation, marginalization, batch_size, threads,
            seed)

        # with marginalization every element is already a batch
        if not marginalization:
            dataset = dataset.batch(batch_size)
            # NOTE: batch norm could be affected by a very small final batch, but dropping the
            # remainder would also affect evaluation tasks, so it is left in for now

    return dataset.prefetch(prefetch_batches)
# model
def create_model(model_name, input_shape, images):
    """Create a model that maps `images` to a single logit per example.

    Args:
      model_name: String indicating the model to use, one of "logreg",
        "vgg", "vgg_new", "vgg19", "resnet", "resnet_new", or
        "resnet_custom".  The "*_new" variants build the same
        architecture with `weights=None`, for training from scratch on
        inputs in [-1, 1].
      input_shape: 3-Tuple containing the shape of a single image.
      images: An image Tensor of shape (n,h,w,c).

    Returns:
      A tuple of an unfrozen Keras Model in which `images` is the input
      tensor, and the Model object of the base network for the VGG and
      ResNet50 variants (None for "logreg" and "resnet_custom").

    Raises:
      Exception: If `model_name` is not one of the names above.
    """
    def add_logit_head(features, model_inputs):
        """Flatten `features` and attach the problem-specific single-logit classifier."""
        flat = tf.keras.layers.Flatten()(features)
        # Dense weights are drawn from a Gaussian scaled by sqrt(2/(fan_in+fan_out))
        logits = tf.keras.layers.Dense(1, kernel_initializer="glorot_normal")(flat)
        return tf.keras.Model(inputs=model_inputs, outputs=logits, name="model")

    # NOTE: there is an issue in keras with using batch norm with model templating, i.e.,
    # defining a model with generic inputs and then calling it on a tensor.  the issue stems
    # from batch norm not being well defined for shared settings.  to "fix" it, every base model
    # is built by directly passing in the `images` tensor.
    # https://github.com/fchollet/keras/issues/2827
    shared_kwargs = dict(include_top=False, input_shape=input_shape, input_tensor=images)
    base_builders = {
        # recommend fine-tuning last 4 layers
        "vgg": lambda: tf.keras.applications.VGG16(**shared_kwargs),
        "vgg_new": lambda: tf.keras.applications.VGG16(weights=None, **shared_kwargs),
        # recommend fine-tuning last 4 layers
        "vgg19": lambda: tf.keras.applications.VGG19(**shared_kwargs),
        # recommend fine-tuning last 11 (stage 5 block c), 21 (stage 5 blocks b & c), or 33
        # (stage 5 blocks a,b,c) layers
        "resnet": lambda: resnet50.ResNet50(**shared_kwargs),
        "resnet_new": lambda: resnet50.ResNet50(weights=None, **shared_kwargs),
    }

    model_base = None
    if model_name == "logreg":
        # logistic regression classifier directly on the flattened pixels
        inputs = tf.keras.layers.Input(shape=input_shape, tensor=images)
        model_tower = add_logit_head(inputs, inputs)
    elif model_name in base_builders:
        # replace the classifier of the base network with one specific to this problem
        model_base = base_builders[model_name]()
        model_tower = add_logit_head(model_base.output, model_base.inputs)
    elif model_name == "resnet_custom":
        model_tower = resnet.ResNet(images, input_shape)
    else:
        raise Exception("model name unknown: {}".format(model_name))

    # TODO: add multi-GPU data parallelism here when it's necessary, in a separate function
    model = model_tower

    # unfreeze all model layers, skipping the input layer
    for layer in model.layers[1:]:
        layer.trainable = True

    return model, model_base
# based on `keras.utils.multi_gpu_model`
def multi_gpu_model(model, gpus):
    """Replicate a model on several GPUs (single-machine data parallelism).

    The batch of every model input is split into one contiguous
    sub-batch per GPU (slicing happens on the CPU), the model is applied
    to each sub-batch on its own GPU, and the per-GPU outputs are
    concatenated on the CPU along the batch axis.  For example, with a
    batch size of 64 and `gpus=2`, each GPU processes 32 samples and the
    full batch of 64 results is returned.

    Args:
      model: A Keras model instance.  To avoid OOM errors, this model
        could have been built on CPU, for instance.
      gpus: Integer >= 2 or list of integers: the number of GPUs, or the
        list of GPU IDs, on which to create model replicas.

    Returns:
      A Keras `Model` instance which can be used just like the initial
      `model` argument, but which distributes its workload on multiple
      GPUs.
    """
    if isinstance(gpus, (list, tuple)):
        target_gpu_ids = gpus
        num_gpus = len(gpus)
    else:
        num_gpus = gpus
        target_gpu_ids = range(num_gpus)

    def get_slice(data, i, parts):
        """Return the `i`-th of `parts` contiguous batch slices of `data`."""
        full_shape = tf.shape(data)
        batch_size, input_shape = full_shape[:1], full_shape[1:]
        step = batch_size // parts
        # the last replica also takes whatever remains after the integer division
        size = batch_size - step * i if i == num_gpus - 1 else step
        size = tf.concat([size, input_shape], axis=0)
        # only the batch axis is offset; every other axis starts at 0
        stride = tf.concat([step, input_shape * 0], axis=0)
        return tf.slice(data, stride * i, size)

    # one list of per-replica tensors for each model output
    all_outputs = [[] for _ in model.outputs]

    # place a copy of the model on each GPU, each getting a slice of the inputs
    for replica_idx, gpu_id in enumerate(target_gpu_ids):
        with tf.device('/cpu:0'):
            # retrieve a slice of every input on the CPU
            inputs = []
            for model_input in model.inputs:
                input_shape = tuple(model_input.get_shape().as_list())[1:]
                slicer = tf.keras.layers.Lambda(
                    get_slice, output_shape=input_shape,
                    arguments={'i': replica_idx, 'parts': num_gpus})
                inputs.append(slicer(model_input))

        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('replica_%d' % gpu_id):
                # apply the model on the slice, creating a replica on the target device
                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # save the outputs for merging back together later
                for out_idx, replica_output in enumerate(outputs):
                    all_outputs[out_idx].append(replica_output)

    # merge the per-replica outputs on the CPU
    with tf.device('/cpu:0'):
        merged = [tf.keras.layers.concatenate(replica_outputs, axis=0, name=name)
                  for name, replica_outputs in zip(model.output_names, all_outputs)]
        return tf.keras.Model(model.inputs, merged)
def compute_data_loss(labels, logits):
    """Compute the mean logistic (sigmoid cross-entropy) loss.

    Args:
      labels: A Tensor containing a batch of n labels; it is reshaped to
        (n, 1) before use.
      logits: A Tensor of shape (n, 1) containing a batch of pre-sigmoid
        prediction values.

    Returns:
      A scalar Tensor representing the mean logistic loss.
    """
    # TODO: this is a binary classification problem so optimizing a loss derived from a
    # Bernoulli distribution is appropriate.  however, would the dynamics of the training
    # algorithm be more stable if we treated this as a multi-class classification problem and
    # derived a loss from a Multinomial distribution with two classes (and a single trial)?  it
    # would be over-parameterized, but then again, the deep net itself is already heavily
    # parameterized.
    column_labels = tf.reshape(labels, [-1, 1])  # match the (n, 1) shape of the logits
    per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=column_labels, logits=logits)
    return tf.reduce_mean(per_example_loss)
def compute_l2_reg_loss(model, include_frozen=False, reg_final=True, reg_biases=False):
  """Compute the L2 regularization loss of the model weights.

  This places a Gaussian prior with mean 0 and std 1 on each regularized
  kernel (and on each bias / batch norm beta if `reg_biases` is True), and
  a Gaussian prior centered at 1 on each batch norm gamma.

  Args:
    model: A Keras Model object.
    include_frozen: Boolean for whether or not to also regularize the
      weights of frozen (i.e., non-trainable) layers.  If False, only
      trainable layers contribute to the loss.
    reg_final: Boolean for whether or not to regularize the final
      logits-producing layer.
    reg_biases: Boolean for whether or not to regularize biases.

  Returns:
    A scalar Tensor containing the L2 regularization loss of all
    trainable (i.e., unfrozen) model weights, unless `include_frozen` is
    True, in which case all weights are used.  If no weights are selected
    for regularization, a constant 0 is returned.
  """
  if reg_final:
    end = None
  else:  # don't regularize the final function that produces logits
    # if the model ends in a flatten layer, the logits layer is the one before it
    end = -2 if model.layers[-1].name.startswith("flatten") else -1

  weights = []
  for layer in model.layers[:end]:
    if not (layer.trainable or include_frozen):
      continue

    # NOTE: Keras layers expose `bias`/`gamma`/`beta` as attributes set to None when the
    # corresponding weight is disabled (e.g. `use_bias=False`, `scale=False`), so check the value
    # rather than just the presence of the attribute.  Use `is not None` since the truth value of
    # a variable is not defined.
    kernel = getattr(layer, 'kernel', None)
    gamma = getattr(layer, 'gamma', None)
    if kernel is not None:  # conv, dense
      weights.append(kernel)
    elif gamma is not None:  # batch norm scale
      weights.append(1.0 - gamma)  # Gaussian prior centered at 1 for batch norm gamma value

    if reg_biases:
      # TODO: generally, we don't regularize the biases, but could we determine a probabilistic
      # motivation to do this?
      bias = getattr(layer, 'bias', None)
      beta = getattr(layer, 'beta', None)
      if bias is not None:
        weights.append(bias)
      elif beta is not None:
        weights.append(beta)

  if not weights:
    # `tf.add_n` requires at least one input, so return an explicit zero loss when there is
    # nothing to regularize (e.g. every layer is frozen and `include_frozen` is False)
    return tf.constant(0.0)

  l2_loss = tf.add_n([tf.nn.l2_loss(w) for w in weights])
  return l2_loss
def compute_metrics(loss, labels, preds, probs, num_thresholds):
  """Compute metrics.

  This creates ops that compute metrics in a streaming fashion.

  Args:
    loss: A Tensor representing the current batch mean loss.
    labels: A Tensor of shape (n, 1) containing a batch of labels.
    preds: A Tensor of shape (n, 1) containing a batch of binary
      prediction values.
    probs: A Tensor of shape (n, 1) containing a batch of probabilistic
      prediction values.
    num_thresholds: An integer indicating the number of thresholds to
      use to compute PR curves.

  Returns:
    A tuple of mean loss, accuracy, positive predictive value
    (precision), sensitivity (recall), F1, PR curve data, F1 list
    based on the PR curve data, a grouped metrics update op, and a
    group metrics reset op.  F1 values are defined to be 0 wherever
    precision + recall is 0, rather than NaN.
  """
  # TODO: think about converting this to a class

  def f1_score(precision, recall):
    """Harmonic mean of precision and recall, with 0/0 defined as 0."""
    total = precision + recall
    undefined = tf.equal(total, 0.0)
    # swap in a dummy denominator where the F1 is undefined so that no NaN is ever produced; the
    # corresponding outputs are then replaced with zeros
    safe_total = tf.where(undefined, tf.ones_like(total), total)
    return tf.where(undefined, tf.zeros_like(total), 2 * (precision * recall) / safe_total)

  mean_loss, mean_loss_update_op, mean_loss_reset_op = create_resettable_metric(
      tf.metrics.mean, 'mean_loss', values=loss)
  acc, acc_update_op, acc_reset_op = create_resettable_metric(
      tf.metrics.accuracy, 'acc', labels=labels, predictions=preds)
  ppv, ppv_update_op, ppv_reset_op = create_resettable_metric(
      tf.metrics.precision, 'ppv', labels=labels, predictions=preds)
  sens, sens_update_op, sens_reset_op = create_resettable_metric(
      tf.metrics.recall, 'sens', labels=labels, predictions=preds)
  f1 = f1_score(ppv, sens)

  # precision & recall at a range of thresholds, and the F1 score at each of those thresholds
  pr, pr_update_op, pr_reset_op = create_resettable_metric(
      tf.contrib.metrics.precision_recall_at_equal_thresholds,
      'pr', labels=tf.cast(labels, dtype=tf.bool), predictions=probs,
      num_thresholds=num_thresholds)
  f1s = f1_score(pr.precision, pr.recall)

  # combine all reset & update ops
  metric_update_ops = tf.group(
      mean_loss_update_op, acc_update_op, ppv_update_op, sens_update_op, pr_update_op)
  metric_reset_ops = tf.group(
      mean_loss_reset_op, acc_reset_op, ppv_reset_op, sens_reset_op, pr_reset_op)

  return mean_loss, acc, ppv, sens, f1, pr, f1s, metric_update_ops, metric_reset_ops
#return mean_loss, acc, ppv, sens, f1, metric_update_ops, metric_reset_ops
# utils
def create_resettable_metric(metric, scope, **metric_kwargs):  # prob safer to only allow kwargs
  """Create a resettable metric.

  Args:
    metric: A tf.metrics metric function.
    scope: A String scope name to enclose the metric variables within.
    metric_kwargs: Kwargs for the metric.

  Returns:
    The metric op, the metric update op, and a metric reset op.  The
    reset op re-initializes the local variables found under the name
    scope in which the metric was created.
  """
  import re  # local import: only needed to build the variable name filter below

  # started with an implementation from https://github.com/tensorflow/tensorflow/issues/4814
  with tf.variable_scope(scope):
    metric_op, update_op = metric(**metric_kwargs)
    scope_name = tf.contrib.framework.get_name_scope()  # in case nested name/variable scopes
    # NOTE: the scope filter is applied to variable names with `re.match`, i.e., as a regex
    # prefix.  Escape the name and terminate it with a "/" so that a scope such as "acc" does not
    # also match the variables of a sibling scope such as "acc_1".
    scope_filter = re.escape(scope_name) + '/'
    local_vars = tf.contrib.framework.get_variables(
        scope_filter, collection=tf.GraphKeys.LOCAL_VARIABLES)  # local variables in this scope
    reset_op = tf.variables_initializer(local_vars)
  return metric_op, update_op, reset_op
def initialize_variables(sess):
  """Initialize the graph's variables without clobbering weights Keras has already set.

  Every global variable that Keras has not yet flagged as initialized is
  initialized and flagged, and all local variables are initialized.

  Args:
    sess: A TensorFlow Session in which to run the initializers.
  """
  # NOTE: Keras tracks which variables it has initialized, and every call to
  # `tf.keras.backend.get_session()` (which Keras also makes internally) initializes whatever is
  # still untracked.  This matters when resuming from a checkpoint: calling `get_session()` at this
  # point is too early to initialize anything, the resume branch performs no initialization of its
  # own, and the later `model.save` code path calls `get_session()`, which then reinitializes part
  # of the model.  The model base is unaffected since it is marked as initialized when the
  # pretrained weights are loaded, but the new dense classifier is not.  The non-resume branch
  # initializes everything Keras has not, and so avoids the problem.
  # `tf.keras.backend.manual_variable_initialization(True)` followed by initializing every variable
  # by hand would throw away the pretrained weights.  Instead, do the equivalent of the
  # `get_session()` logic here first, and only then resume.
  # NOTE: the global variables initializer would erase the pretrained weights, which is why only
  # the remaining variables are initialized.
  # NOTE: reproduced from the old tf.keras.backend._initialize_variables() function
  # EDIT: this was updated in the master branch in commit
  # https://github.com/fchollet/keras/commit/9166733c3c144739868fe0c30d57b861b4947b44
  # TODO: given the change in master, reevaluate whether or not this is actually necessary anymore
  pending = [var for var in tf.global_variables()
             if not getattr(var, '_keras_initialized', False)]
  for var in pending:
    var._keras_initialized = True  # flag it so that Keras will not initialize it again later

  init_ops = [tf.variables_initializer(pending), tf.local_variables_initializer()]
  sess.run(init_ops)
# training
def train(train_path, val_path, exp_path, model_name, model_weights, patch_size, train_batch_size,
val_batch_size, clf_epochs, finetune_epochs, clf_lr, finetune_lr, finetune_momentum,
finetune_layers, l2, reg_biases, reg_final, augmentation, marginalization, oversampling,
num_gpus, threads, prefetch_batches, log_interval, checkpoint, resume, seed):
"""Train a model.
Args:
train_path: String path to the generated training image patches.
This should contain folders for each class.
val_path: String path to the generated validation image patches.
This should contain folders for each class.
exp_path: String path in which to store the model checkpoints, logs,
etc. for this experiment
model_name: String indicating the model to use.
model_weights: Optional string path to an HDF5 file containing the
initial weights of the model. If None, then pretrained imagenet
weights will be used.
patch_size: Integer length to which the square patches will be
resized.
train_batch_size: Integer training batch size.
val_batch_size: Integer validation batch size.
clf_epochs: Integer number of epochs for which to train the new
classifier layers.
finetune_epochs: Integer number of epochs for which to fine-tune the
model.
clf_lr: Float learning rate for training the new classifier layers.
finetune_lr: Float learning rate for fine-tuning the model.
finetune_momentum: Float momentum rate for fine-tuning the model.
finetune_layers: Integer number of layers at the end of the
pretrained portion of the model to fine-tune. The new classifier
layers will still be trained during fine-tuning as well.
l2: Float L2 global regularization value.
reg_biases: Boolean for whether or not to regularize biases.
reg_final: Boolean for whether or not to regularize the final
logits-producing layer.
augmentation: Boolean for whether or not to apply random
augmentation to the images.
marginalization: Boolean for whether or not to use noise
marginalization when evaluating the validation set. If True, then
each image in the validation set will be expanded to a batch of
augmented versions of that image, and predicted probabilities for
each batch will be averaged to yield a single noise-marginalized
prediction for each image. Note: if this is True, then
`val_batch_size` must be divisible by 4, or equal to 1 for a
special debugging case of no augmentation.
oversampling: Boolean for whether or not to oversample the minority
mitosis class via class-aware sampling.
num_gpus: Integer number of GPUs to use for data parallelism.
threads: Integer number of threads for dataset buffering.
prefetch_batches: Integer number of batches to prefetch.
log_interval: Integer number of steps between logging during
training.
checkpoint: Boolean flag for whether or not to save a checkpoint
after each epoch.
resume: Boolean flag for whether or not to resume training from a
checkpoint.
seed: Integer random seed.
"""
# TODO: break this out into:
# * data gen func
# * inference func
# * loss func
# * metrics func
# * logging func
# * train func
# set random seed
# NOTE: At the moment, this is fairly useless because if the augmentation ops are seeded, they will
# be evaluated in the exact same deterministic manner on every epoch, which is not desired.
# Additionally, the multithreading needed to process the data will cause non-deterministic
# results. The one benefit is that the classification layers will be created deterministically.
np.random.seed(seed)
tf.set_random_seed(seed)
# create session, force tf.Keras to use it
config = tf.ConfigProto(allow_soft_placement=True)#, log_device_placement=True)
sess = tf.Session(config=config)
tf.keras.backend.set_session(sess)
# debugger
#from tensorflow.python import debug as tf_debug
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
# data
with tf.name_scope("data"):
# NOTE: seed issues to be fixed in tf
train_dataset = create_dataset(train_path, model_name, patch_size, train_batch_size, True,
augmentation, False, oversampling, threads, prefetch_batches) #, seed)
val_dataset = create_dataset(val_path, model_name, patch_size, val_batch_size, False,
False, marginalization, False, threads, prefetch_batches) #, seed)
# note that batch norm could be affected by very small final batches, but right now the fix,
# which requires this change as well, would also affect evaluation tasks, so we will wait to
# enable that (and this change too) until we move to tf Estimators
#output_shapes = (tf.TensorShape([None, patch_size, patch_size, 3]),
# tf.TensorShape([None]),
# tf.TensorShape([None]))
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
train_dataset.output_shapes)
#output_shapes)
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)
images, labels, filenames = iterator.get_next()
input_shape = (patch_size, patch_size, 3)
# models
with tf.name_scope("model"):
# replicate model on each GPU to allow for data parallelism
if num_gpus > 1:
with tf.device("/cpu:0"):
model_tower, model_base = create_model(model_name, input_shape, images)
#model_tower, model_base = create_model(model_name, input_shape, images)
model = multi_gpu_model(model_tower, num_gpus)
else:
model_tower, model_base = create_model(model_name, input_shape, images)
model = model_tower
if model_weights is not None:
model_tower.load_weights(model_weights)
# compute logits and predictions, possibly with marginalization
# NOTE: tf prefers to feed logits into a combined sigmoid and logistic loss function for
# numerical stability
if marginalization:
logits = marginalize(model.output) # will marginalize at test time
labels = tf.cond(tf.keras.backend.learning_phase(), lambda: labels, lambda: labels[0:1])
else:
logits = model.output
# for Bernoulli-derived loss
probs = tf.nn.sigmoid(logits, name="probs")
preds = tf.round(probs, name="preds") # implicit threshold at 0.5
# for Multinomial-derived loss
#probs = tf.nn.softmax(logits, name="probs") # possible improved numerical stability
#preds = tf.argmax(probs, axis=1, name="preds")
# loss
with tf.name_scope("loss"):
with tf.control_dependencies([tf.assert_equal(tf.shape(labels)[0], tf.shape(logits)[0])]):
data_loss = compute_data_loss(labels, logits)
reg_loss = compute_l2_reg_loss(
model_tower, include_frozen=True, reg_final=reg_final, reg_biases=reg_biases)
loss = data_loss + l2*reg_loss
# TODO: enable this and test it
# use l2 reg during training, but not during validation. Otherwise, more fine-tuning will
# lead to an apparent lower validation loss, even though it may just be due to more layers
# that can be adjusted in order to lower the regularization portion of the loss.
#loss = tf.cond(tf.keras.backend.learning_phase(), lambda: data_loss + l2*reg_loss, lambda: data_loss)