This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

fix keras examples #72

Merged: 3 commits, May 4, 2018
2 changes: 1 addition & 1 deletion examples/conv_filter_visualization.py
@@ -41,7 +41,7 @@ def deprocess_image(x):
return x

# build the VGG16 network with ImageNet weights
-model = vgg16.VGG16(weights='imagenet', include_top=False)
+model = vgg16.VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))
print('Model loaded.')

model.summary()
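The explicit input_shape above is presumably needed because the MXNet backend requires a static input shape when it builds its graph (the same pattern recurs in the files below). A minimal illustrative sketch, separate from the diff itself, of a backend-conditional variant; the 128x128 size is an assumption matching the dimensions the example defines earlier:

from keras import backend as K
from keras.applications import vgg16

img_width, img_height = 128, 128  # assumed to match the example's earlier definitions

if K.backend() == 'mxnet':
    # MXNet needs the spatial dimensions up front to construct its symbol graph
    model = vgg16.VGG16(weights='imagenet', include_top=False,
                        input_shape=(img_width, img_height, 3))
else:
    # other backends can leave the spatial dimensions unspecified
    model = vgg16.VGG16(weights='imagenet', include_top=False)
print('Model loaded.')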
4 changes: 4 additions & 0 deletions examples/conv_lstm.py
@@ -9,6 +9,10 @@
from keras.layers.normalization import BatchNormalization
import numpy as np
import pylab as plt
+from keras import backend as K

+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: ConvLSTM2D Layer is not supported yet.")

# We create a layer which take as input movies of shape
# (n_frames, width, height, channels) and returns a movie
2 changes: 2 additions & 0 deletions examples/deep_dream.py
@@ -19,6 +19,8 @@
from keras.applications import inception_v3
from keras import backend as K

+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: Symbolic Gradients is not supported yet.")
parser = argparse.ArgumentParser(description='Deep Dreams with Keras.')
parser.add_argument('base_image_path', metavar='base', type=str,
help='Path to the image to transform.')
12 changes: 10 additions & 2 deletions examples/imdb_lstm.py
@@ -19,6 +19,7 @@
from keras.layers import Dense, Embedding
from keras.layers import LSTM
from keras.datasets import imdb
+from keras import backend as K

max_features = 20000
maxlen = 80 # cut texts after this number of words (among top max_features most common words)
@@ -37,8 +38,15 @@

print('Build model...')
model = Sequential()
-model.add(Embedding(max_features, 128))
-model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
+
+# MXNet backend does not support dropout in LSTM and cannot automatically infer shape
+if K.backend() == 'mxnet':
+    # specifying input_length and removed dropout params
+    model.add(Embedding(max_features, 128, input_length=maxlen))
+    model.add(LSTM(128, unroll=True))
+else:
+    model.add(Embedding(max_features, 128))
+    model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))

# try using different optimizers and different optimizer configs
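The MXNet branch above pairs input_length=maxlen with unroll=True: unrolling replicates the LSTM cell once per timestep, so the timestep count must be known when the graph is built, and the Embedding's input_length supplies it. A short sketch of the data preparation this relies on, restating the example's own padding step (constants follow the values the example defines):

from keras.preprocessing import sequence
from keras.datasets import imdb

max_features = 20000
maxlen = 80  # every sample is padded/truncated to exactly this many timesteps

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
# pad_sequences guarantees the fixed length that unroll=True and
# input_length=maxlen require on the MXNet backend
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)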
8 changes: 6 additions & 2 deletions examples/lstm_seq2seq.py
@@ -122,14 +122,18 @@
decoder_target_data[i, t - 1, target_token_index[char]] = 1.

# Define an input sequence and process it.
-encoder_inputs = Input(shape=(None, num_encoder_tokens))
+# MXNet backend RNN requires a fixed input shape; for the TensorFlow backend, you can provide the shape as:
+# encoder_inputs = Input(shape=(None, num_encoder_tokens))
+encoder_inputs = Input(shape=(max_encoder_seq_length, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
-decoder_inputs = Input(shape=(None, num_decoder_tokens))
+# MXNet backend RNN requires a fixed input shape; for the TensorFlow backend, you can provide the shape as:
+# decoder_inputs = Input(shape=(None, num_decoder_tokens))
+decoder_inputs = Input(shape=(max_decoder_seq_length, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
11 changes: 8 additions & 3 deletions examples/mnist_denoising_autoencoder.py
@@ -36,8 +36,14 @@
(x_train, _), (x_test, _) = mnist.load_data()

image_size = x_train.shape[1]
-x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
-x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
+if K.image_data_format() == 'channels_first':
+    x_train = np.reshape(x_train, [-1, 1, image_size, image_size])
+    x_test = np.reshape(x_test, [-1, 1, image_size, image_size])
+    input_shape = (1, image_size, image_size)
+else:
+    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
+    x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
+    input_shape = (image_size, image_size, 1)
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

@@ -52,7 +58,6 @@
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

# Network parameters
-input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
latent_dim = 16
2 changes: 1 addition & 1 deletion examples/mnist_swwae.py
@@ -94,7 +94,7 @@ def getwhere(x):
y_prepool, y_postpool = x
return K.gradients(K.sum(y_postpool), y_prepool)

-if K.backend() == 'tensorflow':
+if K.backend() != 'theano':
raise RuntimeError('This example can only run with the '
'Theano backend for the time being, '
'because it requires taking the gradient '
3 changes: 3 additions & 0 deletions examples/neural_doodle.py
@@ -55,6 +55,9 @@
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import vgg19

+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: Symbolic Gradients is not supported yet.")

# Command line arguments
parser = argparse.ArgumentParser(description='Keras neural doodle example')
parser.add_argument('--nlabels', type=int,
3 changes: 3 additions & 0 deletions examples/neural_style_transfer.py
@@ -60,6 +60,9 @@
from keras.applications import vgg19
from keras import backend as K

+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: Symbolic Gradients is not supported yet.")

parser = argparse.ArgumentParser(description='Neural style transfer with Keras.')
parser.add_argument('base_image_path', metavar='base', type=str,
help='Path to the image to transform.')
3 changes: 2 additions & 1 deletion examples/pretrained_word_embeddings.py
@@ -38,7 +38,8 @@
print('Indexing word vectors.')

embeddings_index = {}
-with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt')) as f:
+# change encoding to utf-8 for glove.6B dataset
+with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
for line in f:
values = line.split()
word = values[0]
3 changes: 2 additions & 1 deletion examples/variational_autoencoder.py
@@ -53,7 +53,8 @@ def sampling(args):
xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)
-
+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: Custom loss is not supported yet.")
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()
2 changes: 2 additions & 0 deletions examples/variational_autoencoder_deconv.py
@@ -120,6 +120,8 @@ def sampling(args):
K.flatten(x_decoded_mean_squash))
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)
+if K.backend() == 'mxnet':
+    raise NotImplementedError("MXNet Backend: Custom loss is not supported yet.")
vae.add_loss(vae_loss)

vae.compile(optimizer='rmsprop')
22 changes: 22 additions & 0 deletions keras/backend/mxnet_backend.py
@@ -3501,6 +3501,28 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
return KerasSymbol(sym)


+# CTC (Connectionist Temporal Classification)
+
+def ctc_batch_cost(y_true, y_pred, input_length, label_length):
+    """Runs CTC loss algorithm on each batch element.
+
+    # Arguments
+        y_true: tensor `(samples, max_string_length)`
+            containing the truth labels.
+        y_pred: tensor `(samples, time_steps, num_categories)`
+            containing the prediction, or output of the softmax.
+        input_length: tensor `(samples, 1)` containing the sequence length for
+            each batch item in `y_pred`.
+        label_length: tensor `(samples, 1)` containing the sequence length for
+            each batch item in `y_true`.
+
+    # Returns
+        Tensor with shape (samples, 1) containing the
+        CTC loss of each element.
+    """
+    raise NotImplementedError("MXNet Backend: CTC is not supported yet.")


# HIGH ORDER FUNCTIONS

def map_fn(fn, elems, name=None, dtype=None):
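For reference, a minimal sketch of how ctc_batch_cost is typically wired into a model on a backend that does implement it (e.g. TensorFlow); the shapes follow the docstring above, while the concrete dimensions and layer names are illustrative assumptions. On the MXNet backend the same call raises the NotImplementedError stubbed above.

from keras import backend as K
from keras.layers import Input, Lambda
from keras.models import Model

time_steps, num_categories, max_string_length = 100, 28, 16  # illustrative sizes

y_pred = Input(shape=(time_steps, num_categories), name='y_pred')   # softmax output of some upstream model
labels = Input(shape=(max_string_length,), name='y_true')
input_length = Input(shape=(1,), name='input_length')               # valid length of each y_pred sequence
label_length = Input(shape=(1,), name='label_length')               # valid length of each label sequence

# Keras losses expect f(y_true, y_pred), so the CTC loss is usually computed
# inside a Lambda layer and the model is compiled with a pass-through loss.
ctc_loss = Lambda(lambda args: K.ctc_batch_cost(*args), output_shape=(1,),
                  name='ctc')([labels, y_pred, input_length, label_length])
model = Model(inputs=[y_pred, labels, input_length, label_length], outputs=ctc_loss)
model.compile(optimizer='rmsprop', loss={'ctc': lambda y_true, y_pred: y_pred})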