Skip to content

Latest commit

 

History

History
124 lines (93 loc) · 4.46 KB

MoXing_API_Flowers.md

File metadata and controls

124 lines (93 loc) · 4.46 KB

1.4. 训练 resnet50-flowers

数据集下载地址:http://download.tensorflow.org/example_images/flower_photos.tgz

将数据集解压到目录/tmp/tensorflow/

1.4.1. 定义主函数和超参

引入MoXing-TensorFlow

import moxing.tensorflow as mox

通过mox.get_flag获取命令行参数num_gpusworker_hosts,从而获取当前使用的GPU数量和节点数量

num_gpus = mox.get_flag('num_gpus')
num_workers = len(mox.get_flag('worker_hosts').split(','))

flowers数据集的格式如下:

/tmp/tensorflow/flower_photos
    |-- daisy
        |-- xxx0.jpg
        ...
    |-- dandelion
        |-- xxx1.jpg
        ...
    |-- roses
        |-- xxx2.jpg
        ...
    |-- sunflowers
        |-- xxx3.jpg
        ...
    |-- tulips
        |-- xxx4.jpg
        ...

每一个子目录代表一个分类,每个分类下有若干张图片,对于这种类型的数据集,可以使用mox.ImageClassificationRawMetadatamox.ImageClassificationRawDataset来读取。MoXing-TensorFlow预置了若干种解析数据集的类,一般会使用数据集元信息类+数据集读取类的模式来读取。数据集元信息类不会创建任何TensorFlow的数据流图,建议在main方法中直接实例化,那么代码的其他地方都能获取数据集的元信息(如样本数量,分类数量)。数据集读取类必须在input_fn中实例化,该类的实例化会在TensorFlow数据流图中创建节点。

创建一个数据集元信息类base_dir即指定flower_photos所在目录

data_meta = mox.ImageClassificationRawMetadata(base_dir=flags.data_url)

input_fn中创建一个数据增强方法(基于resnet50)和一个数据集读取类

def input_fn(mode):
  # 创建一个数据增强方法,该方法基于resnet50论文实现
  augmentation_fn = mox.get_data_augmentation_fn(name='resnet_v1_50',
                                                 run_mode=mode,
                                                 output_height=224,
                                                 output_width=224)

  # 创建`数据集读取类`,并将数据增强方法传入,最多读取20个epoch
  dataset = mox.ImageClassificationRawDataset(data_meta,
                                              batch_size=flags.batch_size,
                                              num_epochs=20,
                                              augmentation_fn=augmentation_fn)
  image, label = dataset.get(['image', 'label'])
  return image, label

定义model_fn

def model_fn(inputs, mode):
  images, labels = inputs

  # 获取一个resnet50的模型,输入images,输入logits和end_points,这里不关心end_points,仅取logits
  logits, _ = mox.get_model_fn(name='resnet_v1_50',
                               run_mode=mode,
                               num_classes=data_meta.num_classes,
                               weight_decay=0.00004)(images)

  # 计算交叉熵损失值
  labels_one_hot = slim.one_hot_encoding(labels, data_meta.num_classes)
  loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels_one_hot)

  # 获取正则项损失值,并加到loss上,这里必须要用mox.get_collection代替tf.get_collection
  regularization_losses = mox.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  regularization_loss = tf.add_n(regularization_losses)
  loss = loss + regularization_loss

  # 计算分类正确率
  accuracy = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32))

  # 返回MoXing-TensorFlow用于定义模型的类ModelSpec
  return mox.ModelSpec(loss=loss, log_info={'loss': loss, 'accuracy': accuracy})

定义optimizer_fn

from moxing.tensorflow.optimizer import learning_rate_scheduler

def optimizer_fn():
  # 使用分段式学习率,0-10个epoch为0.01,10-20个epoch为0.001
  lr = learning_rate_scheduler.piecewise_lr('10:0.01,20:0.001',
                                            num_samples=data_meta.total_num_samples,
                                            global_batch_size=flags.batch_size * num_gpus * num_workers)
  return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)

完整代码请参考:mox_flowers.py

执行训练:

python mox_flowers.py \
--data_url=/tmp/tensorflow/flower_photos \
--train_url=/tmp/flowers \
--num_gpus=4

使用 4 * Nvidia-Tesla-K80 运行时间大约为:698秒,在训练集上的训练精度约为:50%