diff --git a/examples/trials/mnist-keras/mnist-keras.py b/examples/trials/mnist-keras/mnist-keras.py index f26dd8c389..2d7dac0004 100644 --- a/examples/trials/mnist-keras/mnist-keras.py +++ b/examples/trials/mnist-keras/mnist-keras.py @@ -63,7 +63,9 @@ def load_mnist_data(args): ''' Load MNIST dataset ''' - (x_train, y_train), (x_test, y_test) = mnist.load_data() + mnist_path = os.path.join(os.environ.get('NNI_OUTPUT_DIR'), 'mnist.npz') + (x_train, y_train), (x_test, y_test) = mnist.load_data(path=mnist_path) + os.remove(mnist_path) x_train = (np.expand_dims(x_train, -1).astype(np.float) / 255.)[:args.num_train] x_test = (np.expand_dims(x_test, -1).astype(np.float) / 255.)[:args.num_test]