Freeze layers for transfer learning #3247
train.py

@@ -29,7 +29,7 @@
 from .evaluate import evaluate
 from six.moves import zip, range
 from .util.config import Config, initialize_globals
-from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
+from .util.checkpoints import drop_freeze_number_to_layers, load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
 from .util.evaluate_tools import save_samples_json
 from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
 from .util.flags import create_flags, FLAGS
@@ -322,8 +322,24 @@ def get_tower_results(iterator, optimizer, dropout_rates):
         # Retain tower's avg losses
         tower_avg_losses.append(avg_loss)
 
+        train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+        # Filter out layers if we want to freeze some
+        if FLAGS.freeze_source_layers > 0:
+            filter_vars = drop_freeze_number_to_layers(FLAGS.freeze_source_layers, "freeze")
+            new_train_vars = list(train_vars)
+            for fv in filter_vars:
+                for tv in train_vars:
+                    if fv in tv.name:
+                        new_train_vars.remove(tv)
+            train_vars = new_train_vars
+            msg = "Tower {} - Training only variables: {}"
+            print(msg.format(i, [v.name for v in train_vars]))
+        else:
+            print("Tower {} - Training all layers".format(i))
+
         # Compute gradients for model parameters using tower's mini-batch
-        gradients = optimizer.compute_gradients(avg_loss)
+        gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
Comment (on the compute_gradients change): I'd like someone else to take a look at this.
 
         # Retain tower's gradients
         tower_gradients.append(gradients)
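Not part of the diff: as a reference for how this freezing mechanism works, here is a minimal, self-contained TF1-style sketch. The layer names and values are invented; the point is that any variable left out of `var_list` simply receives no gradients, which is what freezes it.

```python
import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

# Two toy "layers"; we pretend layer_1 comes from the source model and freeze it.
frozen_w = tfv1.get_variable("layer_1/weights", initializer=[1.0])
new_w = tfv1.get_variable("layer_6/weights", initializer=[1.0])

loss = tfv1.reduce_sum(frozen_w * new_w)

all_vars = tfv1.get_collection(tfv1.GraphKeys.TRAINABLE_VARIABLES)
# Same substring test as in the PR: drop every variable belonging to a frozen layer.
train_vars = [v for v in all_vars if "layer_1" not in v.name]

optimizer = tfv1.train.AdamOptimizer(learning_rate=0.001)
# Gradients are computed only for train_vars, so layer_1 is never updated.
gradients = optimizer.compute_gradients(loss, var_list=train_vars)
train_op = optimizer.apply_gradients(gradients)
```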
@@ -671,7 +687,6 @@ def __call__(self, progress, data, **kwargs):
                 print('-' * 80)
-
 
     except KeyboardInterrupt:
         pass
     log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
util/checkpoints.py

@@ -1,9 +1,9 @@
 import sys
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
 
 from .flags import FLAGS
-from .logging import log_info, log_error, log_warn
+from .logging import log_error, log_info, log_warn
Comment: Why change the order here?
Reply: Autosort?
Reply: I don't know what the DS policy is for that, I'd have to ask.
 
 
 def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -19,32 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
     if lr_var and ('learning_rate' not in vars_in_ckpt or
-                       (FLAGS.force_initialize_learning_rate and allow_lr_init)):
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
Comment: Spacing only change...
Reply: Fixing PEP 8 E127: continuation line over-indented for visual indent.
 
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
 
-    if FLAGS.load_cudnn:
-        # Initialize training from a CuDNN RNN checkpoint
-        # Identify the variables which we cannot load, and set them
-        # for initialization
-        missing_vars = set()
-        for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                log_warn('CUDNN variable not found: %s' % (v.op.name))
-                missing_vars.add(v)
+    # After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
+    # are missing because they were not used. This might also occur when loading a cudnn checkpoint
+    # Therefore we have to initialize them again to continue training on such checkpoints
+    print_msg = False
+    for v in load_vars:
+        if v.op.name not in vars_in_ckpt:
+            if 'Adam' in v.name:
+                init_vars.add(v)
+                print_msg = True
+    if print_msg:
+        msg = "Some Adam tensors are missing, they will be initialized automatically."
+        log_info(msg)
Comment: Can you not do just
Comment: Maybe something like

    missing = []
    if ...
        missing.append(v)
    if missing:
        for v in missing:
            log_info("Missing... {}".format(v))

Reply: The split into two parts was an autoformatting issue; black would have split this into three lines. I didn't print every missing layer because there are messages later that state exactly which layers were reinitialized.
 
-        load_vars -= init_vars
+    load_vars -= init_vars
 
-        # Check that the only missing variables (i.e. those to be initialised)
-        # are the Adam moment tensors, if they aren't then we have an issue
-        missing_var_names = [v.op.name for v in missing_vars]
-        if any('Adam' not in v for v in missing_var_names):
-            log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                      'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(missing_var_names))
-            sys.exit(1)
+    if FLAGS.load_cudnn:
+        # Check all required tensors are included in the cudnn checkpoint we want to load
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt and 'Adam' not in v.op.name:
+                msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
+                      ' variable other than an Adam moment tensor: {}'
+                log_error(msg.format(v.op.name))
+                sys.exit(1)
 
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
@@ -59,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
                      'dropping only 5 layers.')
             FLAGS.drop_source_layers = 5
 
-        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
         # Initialize all variables needed for DS, but not loaded from ckpt
         for v in load_vars:
             if any(layer in v.op.name for layer in dropped_layers):
@@ -75,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
             session.run(v.initializer)
 
 
+def drop_freeze_number_to_layers(drop_freeze_number, mode):
+    """ Convert number of layers to drop or freeze into layer names """
+
+    if drop_freeze_number >= 6:
+        log_warn('The checkpoint only has 6 layers, but you are trying '
+                 'to drop or freeze all of them or more. Continuing with 5 layers.')
+        drop_freeze_number = 5
+
+    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
+    if mode == "drop":
+        layer_keys = layer_keys[-1 * int(drop_freeze_number):]
+    elif mode == "freeze":
+        layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
+    else:
+        raise ValueError
+    return layer_keys
+
+
 def _checkpoint_path_or_none(checkpoint_filename):
     checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
     if not checkpoint:
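To make the two modes concrete, here is a standalone rehearsal of the slicing above (plain Python, not using the repo code): for the same N, "drop" returns the last N keys, which get re-initialized, while "freeze" returns the complementary keys, which are then removed from the trainable set.

```python
layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]

n = 2  # e.g. --drop_source_layers 2 / --freeze_source_layers 2
dropped = layer_keys[-1 * n:]   # "drop" mode
frozen = layer_keys[:-1 * n]    # "freeze" mode

print(dropped)  # ['layer_5', 'layer_6']
print(frozen)   # ['layer_1', 'layer_2', 'layer_3', 'lstm']
```

With the same N, the two flags are complementary: drop re-initializes the last N layers, while freeze keeps everything before them fixed.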
Comment: Is there a reason not to build up new_train_vars from empty, something like:
Reply: Seemed more intuitive, we want to train all except the filtered layers. Your example doesn't work by the way, because the filter_vars contain names like layer_1 and train_vars have the full layer name layer_1:dense:0.
Reply: Hence the "something like" :) -- You could of course use find() or something like that. I have no particularly strong feeling about it, but err on the side of simplicity.
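For illustration, a minimal sketch of the "build up from empty" alternative discussed above, with invented variable names. It keeps the same substring test, since the layer keys are only fragments of the full variable names.

```python
# Plain strings stand in for tf.Variable names.
frozen_layer_keys = ["layer_1", "layer_2", "layer_3", "lstm"]  # what "freeze" mode returns
all_var_names = [
    "layer_1/weights:0",
    "layer_2/bias:0",
    "layer_5/weights:0",
    "layer_6/bias:0",
]

# Keep a variable only if none of the frozen keys occur in its name.
train_var_names = [
    name for name in all_var_names
    if not any(key in name for key in frozen_layer_keys)
]

print(train_var_names)  # ['layer_5/weights:0', 'layer_6/bias:0']
```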