diff --git a/doc/python/io.md b/doc/python/io.md
index f6464a879d68..0dbda7b3b901 100644
--- a/doc/python/io.md
+++ b/doc/python/io.md
@@ -40,7 +40,7 @@ The following code gives an example of creating a Cifar data iterator.
 >>> # Dataset Paramter
 >>> # Impulsary
 >>> # indicating the image size after preprocessing
->>> input_shape=(3,28,28),
+>>> data_shape=(3,28,28),
 >>> # Batch Paramter
 >>> # Impulsary
 >>> # tells how many images in a batch
@@ -51,7 +51,7 @@ The following code gives an example of creating a Cifar data iterator.
 >>> mean_img="data/cifar/cifar10_mean.bin",
 >>> # Augmentation Parameter
 >>> # Optional
->>> # randomly crop a patch of the input_shape from the original image
+>>> # randomly crop a patch of the data_shape from the original image
 >>> rand_crop=True,
 >>> # Augmentation Parameter
 >>> # Optional
diff --git a/example/python-howto/data_iter.py b/example/python-howto/data_iter.py
index b25c3fde0d0a..d1cebc0a470d 100644
--- a/example/python-howto/data_iter.py
+++ b/example/python-howto/data_iter.py
@@ -14,18 +14,18 @@
         # Dataset/Augment Paramter
         # Impulsary
         # indicating the image size after preprocessing
-        input_shape=(3,28,28),
+        data_shape=(3,28,28),
         # Batch Paramter
         # Impulsary
         # tells how many images in a batch
-        batch_size=100, 
+        batch_size=100,
         # Augmentation Parameter
         # Optional
         # when offers mean_img, each image will substract the mean value at each pixel
         mean_img="data/cifar/cifar10_mean.bin",
         # Augmentation Parameter
         # Optional
-        # randomly crop a patch of the input_shape from the original image
+        # randomly crop a patch of the data_shape from the original image
         rand_crop=True,
         # Augmentation Parameter
         # Optional
@@ -42,7 +42,7 @@
         # Backend Parameter
         # Optional
         # Prefetch buffer size
-        prefetch_buffer=1)
+        prefetch_buffer=4)
 
 batchidx = 0
 for data, label in dataiter:
diff --git a/include/mxnet/io.h b/include/mxnet/io.h
index 7e2cf8180fd5..1c9a6bc8d61a 100644
--- a/include/mxnet/io.h
+++ b/include/mxnet/io.h
@@ -55,22 +55,13 @@ struct DataInst {
 };  // struct DataInst
 
 /*!
- * \brief a standard batch of data commonly used by iterator
- *  a databatch contains multiple TBlobs. Each Tblobs has
- *  a name stored in a map. There's no different between
- *  data and label, how we use them is to see the DNN implementation.
+ * \brief DataBatch of NDArray, returned by Iterator
  */
 struct DataBatch {
- public:
   /*! \brief content of dense data, if this DataBatch is dense */
   std::vector<NDArray> data;
   /*! \brief extra data to be fed to the network */
   std::string extra_data;
- public:
-  /*! \brief constructor */
-  DataBatch(void) {}
-  /*! \brief destructor */
-  ~DataBatch() {}
 };  // struct DataBatch
 
 /*! \brief typedef the factory function of data iterator */
@@ -122,7 +113,7 @@ struct DataIteratorReg
  *
  * \code
  * // example of registering a imagerec iterator
- * MXNET_REGISTER_IO_CHAINED_ITERATOR(ImageRecordIter, 
+ * MXNET_REGISTER_IO_CHAINED_ITERATOR(ImageRecordIter,
  *   ImageRecordIter, ImageRecBatchLoader, Prefetcher)
  * .describe("batched image record data iterator");
  *
diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h
index 6a2d3c34ee92..f3b4154425da 100644
--- a/src/io/image_augmenter.h
+++ b/src/io/image_augmenter.h
@@ -69,7 +69,7 @@ struct ImageAugmentParam : public dmlc::Parameter<ImageAugmentParam> {
   /*! \brief whether to print augment info */
   bool silent;
   /*! \brief shape of the image data*/
-  TShape input_shape;
+  TShape data_shape;
   // declare parameters
   DMLC_DECLARE_PARAMETER(ImageAugmentParam) {
     DMLC_DECLARE_FIELD(rand_crop).set_default(true)
@@ -120,9 +120,9 @@ struct ImageAugmentParam : public dmlc::Parameter<ImageAugmentParam> {
         .describe("Augmentation Param: Maximum value of illumination variation.");
     DMLC_DECLARE_FIELD(silent).set_default(true)
         .describe("Augmentation Param: Whether to print augmentor info.");
-    index_t input_shape_default[] = {3, 224, 224};
-    DMLC_DECLARE_FIELD(input_shape)
-        .set_default(TShape(input_shape_default, input_shape_default + 3))
+    index_t data_shape_default[] = {3, 224, 224};
+    DMLC_DECLARE_FIELD(data_shape)
+        .set_default(TShape(data_shape_default, data_shape_default + 3))
         .set_expect_ndim(3).enforce_nonzero()
         .describe("Dataset Param: Input shape of the neural net.");
   }
@@ -234,20 +234,20 @@ class ImageAugmenter {
         y /= 2; x /= 2;
       }
       cv::Rect roi(x, y, rand_crop_size, rand_crop_size);
-      cv::resize(res(roi), res, cv::Size(param_.input_shape[1], param_.input_shape[2]));
+      cv::resize(res(roi), res, cv::Size(param_.data_shape[1], param_.data_shape[2]));
     } else {
-      CHECK(static_cast<mshadow::index_t>(res.cols) >= param_.input_shape[1] \
-            && static_cast<mshadow::index_t>(res.rows) >= param_.input_shape[2])
+      CHECK(static_cast<mshadow::index_t>(res.cols) >= param_.data_shape[1] \
+            && static_cast<mshadow::index_t>(res.rows) >= param_.data_shape[2])
           << "input image size smaller than input shape";
-      mshadow::index_t y = res.rows - param_.input_shape[2];
-      mshadow::index_t x = res.cols - param_.input_shape[1];
+      mshadow::index_t y = res.rows - param_.data_shape[2];
+      mshadow::index_t x = res.cols - param_.data_shape[1];
       if (param_.rand_crop != 0) {
         y = NextUInt32(y + 1, prnd);
         x = NextUInt32(x + 1, prnd);
       } else {
         y /= 2; x /= 2;
       }
-      cv::Rect roi(x, y, param_.input_shape[1], param_.input_shape[2]);
+      cv::Rect roi(x, y, param_.data_shape[1], param_.data_shape[2]);
       res = res(roi);
     }
     return res;
@@ -300,24 +300,24 @@ class ImageAugmenter {
       }
     }
     img_.Resize(mshadow::Shape3((*p_data).shape_[0],
-                param_.input_shape[1], param_.input_shape[2]));
-    if (param_.input_shape[1] == 1) {
+                param_.data_shape[1], param_.data_shape[2]));
+    if (param_.data_shape[1] == 1) {
       img_ = (*p_data) * param_.scale;
     } else {
-      CHECK(p_data->size(1) >= param_.input_shape[1] && p_data->size(2) >= param_.input_shape[2])
+      CHECK(p_data->size(1) >= param_.data_shape[1] && p_data->size(2) >= param_.data_shape[2])
           << "Data size must be bigger than the input size to net.";
-      mshadow::index_t yy = p_data->size(1) - param_.input_shape[1];
-      mshadow::index_t xx = p_data->size(2) - param_.input_shape[2];
+      mshadow::index_t yy = p_data->size(1) - param_.data_shape[1];
+      mshadow::index_t xx = p_data->size(2) - param_.data_shape[2];
       if (param_.rand_crop != 0 && (yy != 0 || xx != 0)) {
         yy = NextUInt32(yy + 1, prnd);
         xx = NextUInt32(xx + 1, prnd);
       } else {
         yy /= 2; xx /= 2;
       }
-      if (p_data->size(1) != param_.input_shape[1] && param_.crop_y_start != -1) {
+      if (p_data->size(1) != param_.data_shape[1] && param_.crop_y_start != -1) {
         yy = param_.crop_y_start;
       }
-      if (p_data->size(2) != param_.input_shape[2] && param_.crop_x_start != -1) {
+      if (p_data->size(2) != param_.data_shape[2] && param_.crop_x_start != -1) {
         xx = param_.crop_x_start;
       }
       float contrast = NextDouble(prnd) * param_.max_random_contrast \
diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h
index 68f0e0fd7303..6752db02d658 100644
--- a/src/io/iter_batchloader.h
+++ b/src/io/iter_batchloader.h
@@ -1,7 +1,7 @@
 /*!
  * Copyright (c) 2015 by Contributors
  * \file iter_batchloader.h
- * \brief define a batch adapter to create tblob batch 
+ * \brief define a batch adapter to create tblob batch
  */
 #ifndef MXNET_IO_ITER_BATCHLOADER_H_
 #define MXNET_IO_ITER_BATCHLOADER_H_
@@ -22,30 +22,26 @@ struct BatchParam : public dmlc::Parameter<BatchParam> {
   /*! \brief label width */
   index_t batch_size;
   /*! \brief input shape */
-  TShape input_shape;
+  TShape data_shape;
   /*! \brief label width */
   index_t label_width;
   /*! \brief use round roubin to handle overflow batch */
   bool round_batch;
-  /*! \brief skip read */
-  bool test_skipread;
   /*! \brief silent */
   bool silent;
   // declare parameters
   DMLC_DECLARE_PARAMETER(BatchParam) {
     DMLC_DECLARE_FIELD(batch_size)
         .describe("Batch Param: Batch size.");
-    index_t input_shape_default[] = {3, 224, 224};
-    DMLC_DECLARE_FIELD(input_shape)
-        .set_default(TShape(input_shape_default, input_shape_default + 3))
+    index_t data_shape_default[] = {3, 224, 224};
+    DMLC_DECLARE_FIELD(data_shape)
+        .set_default(TShape(data_shape_default, data_shape_default + 3))
         .set_expect_ndim(3).enforce_nonzero()
-        .describe("Dataset Param: Input shape of the neural net.");
+        .describe("Dataset Param: Shape of each instance generated by the DataIter.");
     DMLC_DECLARE_FIELD(label_width).set_default(1)
         .describe("Dataset Param: Label width.");
     DMLC_DECLARE_FIELD(round_batch).set_default(true)
         .describe("Batch Param: Use round robin to handle overflow batch.");
-    DMLC_DECLARE_FIELD(test_skipread).set_default(false)
-        .describe("Batch Param: Skip read for testing.");
     DMLC_DECLARE_FIELD(silent).set_default(false)
         .describe("Batch Param: Whether to print batch information.");
   }
@@ -70,8 +66,9 @@ class BatchLoader : public IIterator<TBlobBatch> {
     base_->Init(kwargs);
     std::vector<index_t> data_shape_vec;
     data_shape_vec.push_back(param_.batch_size);
-    for (size_t shape_dim = 0; shape_dim < param_.input_shape.ndim(); shape_dim++)
-      data_shape_vec.push_back(param_.input_shape[shape_dim]);
+    for (size_t shape_dim = 0; shape_dim < param_.data_shape.ndim(); ++shape_dim) {
+      data_shape_vec.push_back(param_.data_shape[shape_dim]);
+    }
     data_shape_ = TShape(data_shape_vec.begin(), data_shape_vec.end());
     std::vector<index_t> label_shape_vec;
     label_shape_vec.push_back(param_.batch_size);
@@ -96,12 +93,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
   }
   inline bool Next(void) {
     out_.num_batch_padd = 0;
-
-    // skip read if in head version
-    if (param_.test_skipread != 0 && head_ == 0)
-      return true;
-    else
-      this->head_ = 0;
+    this->head_ = 0;
     // if overflow from previous round, directly return false, until before first is called
     if (num_overflow_ != 0) return false;
diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc
index 17852a42af6a..f8aded4defa7 100644
--- a/src/io/iter_image_recordio.cc
+++ b/src/io/iter_image_recordio.cc
@@ -93,14 +93,10 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
   std::string path_imglist;
   /*! \brief path to image recordio */
   std::string path_imgrec;
-  /*! \brief virtually split the data into n parts */
-  int num_parts;
-  /*! \brief only read the i-th part */
-  int part_index;
   /*! \brief label-width */
   int label_width;
   /*! \brief input shape */
-  TShape input_shape;
+  TShape data_shape;
   /*! \brief number of threads */
   int preprocess_threads;
   /*! \brief whether to remain silent */
@@ -113,15 +109,11 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
         .describe("Dataset Param: Path to image record file.");
     DMLC_DECLARE_FIELD(label_width).set_lower_bound(1).set_default(1)
         .describe("Dataset Param: How many labels for an image.");
-    DMLC_DECLARE_FIELD(num_parts).set_lower_bound(1).set_default(1)
-        .describe("Dataset Param: virtually split the data into n parts");
-    DMLC_DECLARE_FIELD(part_index).set_default(0)
-        .describe("Dataset Param: only read the i-th part");
-    index_t input_shape_default[] = {3, 224, 224};
-    DMLC_DECLARE_FIELD(input_shape)
-        .set_default(TShape(input_shape_default, input_shape_default + 3))
+    index_t data_shape_default[] = {3, 224, 224};
+    DMLC_DECLARE_FIELD(data_shape)
+        .set_default(TShape(data_shape_default, data_shape_default + 3))
         .enforce_nonzero()
-        .describe("Dataset Param: Input shape of the neural net");
+        .describe("Dataset Param: Shape of each instance generated by the DataIter.");
     DMLC_DECLARE_FIELD(preprocess_threads).set_lower_bound(1).set_default(4)
         .describe("Backend Param: Number of thread to do preprocessing.");
     DMLC_DECLARE_FIELD(silent).set_default(false)
@@ -184,8 +176,8 @@ inline void ImageRecordIOParser::Init(
   int maxthread, threadget;
   #pragma omp parallel
   {
-    // why ? (muli)
-    maxthread = std::max(omp_get_num_procs() / 2 - 1, 1);
+    // be conservative, set number of real cores
+    maxthread = std::max(omp_get_num_procs() / 2, 1);
   }
   param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads);
   #pragma omp parallel num_threads(param_.preprocess_threads)
@@ -208,9 +200,12 @@ inline void ImageRecordIOParser::Init(
   CHECK(param_.path_imgrec.length() != 0)
       << "ImageRecordIOIterator: must specify image_rec";
+  // TODO(mu, tianjun) add DMLC env variable to detect partition
+  const int part_index = 0;
+  const int num_parts = 1;
   source_ = dmlc::InputSplit::Create(
-      param_.path_imgrec.c_str(), param_.part_index,
-      param_.num_parts, "recordio");
+      param_.path_imgrec.c_str(), part_index,
+      num_parts, "recordio");
   // use 64 MB chunk when possible
   source_->HintChunkSize(8 << 20UL);
 
 #else
@@ -279,7 +274,7 @@ ParseNext(std::vector<InstVector> *out_vec) {
     out.Clear();
     for (size_t j = 0; j < opencv_out.Size(); j++) {
       out.Push(opencv_out.Index(j),
-               mshadow::Shape3(param_.input_shape[0], param_.input_shape[1], param_.input_shape[2]),
+               mshadow::Shape3(param_.data_shape[0], param_.data_shape[1], param_.data_shape[2]),
                mshadow::Shape1(param_.label_width));
       DataInst inst = out.Back();
       DataInst opencv_inst = opencv_out[j];
diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h
index ddf0f053f34b..2449d4a38bc5 100644
--- a/src/io/iter_prefetcher.h
+++ b/src/io/iter_prefetcher.h
@@ -24,145 +24,100 @@ namespace io {
 struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
   /*! \brief number of prefetched batches */
   size_t prefetch_buffer;
-  /*! \brief label width */
-  index_t batch_size;
-  /*! \brief input shape */
-  TShape input_shape;
-  /*! \brief label width */
-  index_t label_width;
   // declare parameters
   DMLC_DECLARE_PARAMETER(PrefetcherParam) {
-    DMLC_DECLARE_FIELD(prefetch_buffer).set_default(1)
-        .describe("Backend Param: Number of prefetched batches.");
-    DMLC_DECLARE_FIELD(batch_size)
-        .describe("Batch Param: Batch size.");
-    index_t input_shape_default[] = {3, 224, 224};
-    DMLC_DECLARE_FIELD(input_shape)
-        .set_default(TShape(input_shape_default, input_shape_default + 3))
-        .enforce_nonzero()
-        .describe("Dataset Param: Input shape of the neural net.");
-    DMLC_DECLARE_FIELD(label_width).set_default(1)
-        .describe("Dataset Param: Label width.");
+    DMLC_DECLARE_FIELD(prefetch_buffer).set_default(4)
+        .describe("Backend Param: Number of prefetched batches.");
   }
 };
 
 // iterator on image recordio
 class PrefetcherIter : public IIterator<DataBatch> {
  public:
-  explicit PrefetcherIter(IIterator<TBlobBatch>* base) : loader_(base) {
-    pdata_vec.clear();
-    plabel_vec.clear();
+  explicit PrefetcherIter(IIterator<TBlobBatch>* base)
+      : out_(nullptr), loader_(base) {
   }
-  virtual ~PrefetcherIter(void) {
-    iter_.Destroy();
-    for (size_t i = 0; i < pdata_vec.size(); i++) {
-      delete[] pdata_vec[i];
-      delete[] plabel_vec[i];
+
+  ~PrefetcherIter() {
+    while (recycle_queue_.size() != 0) {
+      DataBatch *batch = recycle_queue_.front();
+      recycle_queue_.pop();
+      delete batch;
     }
-    delete loader_;
+    delete out_;
+    iter_.Destroy();
   }
+
   virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
     std::vector<std::pair<std::string, std::string> > kwargs_left;
     // init image rec param
     kwargs_left = param_.InitAllowUnknown(kwargs);
     // use the kwarg to init batch loader
     loader_->Init(kwargs);
-    // create the shape
-    std::vector<index_t> data_shape_vec;
-    data_shape_vec.push_back(param_.batch_size);
-    for (size_t shape_dim = 0; shape_dim < param_.input_shape.ndim(); shape_dim++)
-      data_shape_vec.push_back(param_.input_shape[shape_dim]);
-    data_shape_ = TShape(data_shape_vec.begin(), data_shape_vec.end());
-    std::vector<index_t> label_shape_vec;
-    label_shape_vec.push_back(param_.batch_size);
-    label_shape_vec.push_back(param_.label_width);
-    label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end());
+    // maximum prefetch threaded iter internal size
+    const int kMaxPrefetchBuffer = 16;
     // init thread iter
-    iter_.set_max_capacity(param_.prefetch_buffer);
-    iter_.Init([this](TBlobBatch **dptr) {
-        bool load_success = loader_->Next();
-        if (load_success == false)
-          return false;
-        if (*dptr == NULL) {
-          *dptr = new TBlobBatch();
-          // create the spaces and record the pointers
-          real_t* pdata = new real_t[data_shape_.Size()];
-          pdata_vec.push_back(pdata);
-          real_t* plabel = new real_t[label_shape_.Size()];
-          plabel_vec.push_back(plabel);
-          (*dptr)->data.push_back(TBlob(pdata, data_shape_, mshadow::cpu::kDevMask));
-          (*dptr)->data.push_back(TBlob(plabel, label_shape_, mshadow::cpu::kDevMask));
-        }
+    iter_.set_max_capacity(kMaxPrefetchBuffer);
+
+    iter_.Init([this](DataBatch **dptr) {
+        if (!loader_->Next()) return false;
         const TBlobBatch& batch = loader_->Value();
-        if (data_shape_.ndim() == 4) {
-          mshadow::Copy((*dptr)->data[0].get<mshadow::cpu, 4, real_t>(),
-                        batch.data[0].get<mshadow::cpu, 4, real_t>());
-        } else if (data_shape_.ndim() == 2) {
-          mshadow::Copy((*dptr)->data[0].get<mshadow::cpu, 2, real_t>(),
-                        batch.data[0].get<mshadow::cpu, 2, real_t>());
-        } else {
-          // TODO(tianjun): ?
-          LOG(FATAL) << "fail";
+
+        if (*dptr == nullptr) {
+          // allocate databatch
+          *dptr = new DataBatch();
+          (*dptr)->data.resize(batch.data.size());
+          for (size_t i = 0; i < batch.data.size(); ++i) {
+            (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU());
+          }
         }
-        mshadow::Copy((*dptr)->data[1].get<mshadow::cpu, 2, real_t>(),
-                      batch.data[1].get<mshadow::cpu, 2, real_t>());
-        return load_success;
+        CHECK(batch.data.size() == (*dptr)->data.size());
+        // copy data over
+        for (size_t i = 0; i < batch.data.size(); ++i) {
+          CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_);
+          mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D<cpu, real_t>(),
+                        batch.data[i].FlatTo2D<cpu, real_t>());
+        }
+        return true;
       },
       [this]() { loader_->BeforeFirst(); });
   }
+
   virtual void BeforeFirst(void) {
     iter_.BeforeFirst();
   }
+
   virtual bool Next(void) {
-    if (ready_batches_.size() == param_.prefetch_buffer) {
-      TBlobBatch* old_batch = ready_batches_.front();
-      for (size_t i = 0; i < old_batch->data.size(); i++) {
-        NDArray old_ndarray = ready_ndarrays_.front();
-        old_ndarray.WaitToWrite();
-        ready_ndarrays_.pop();
-      }
-      iter_.Recycle(&old_batch);
-      ready_batches_.pop();
-    }
-    TBlobBatch* next_batch = NULL;
-    if (!iter_.Next(&next_batch)) return false;
-    out_.data.clear();
-    // copy the batch
-    for (size_t i = 0; i < next_batch->data.size(); ++i) {
-      out_.data.push_back(NDArray(next_batch->data[i], 0));
-      ready_ndarrays_.push(out_.data[i]);
-    }
-    // push the narrays and batch into the queue
-    ready_batches_.push(next_batch);
-    return true;
+    if (out_ != nullptr) {
+      recycle_queue_.push(out_); out_ = nullptr;
+    }
+    // do recycle
+    if (recycle_queue_.size() == param_.prefetch_buffer) {
+      DataBatch *old_batch = recycle_queue_.front();
+      // can be more efficient on engine
+      for (NDArray& arr : old_batch->data) {
+        arr.WaitToWrite();
+      }
+      recycle_queue_.pop();
+      iter_.Recycle(&old_batch);
+    }
+    return iter_.Next(&out_);
   }
   virtual const DataBatch &Value(void) const {
-    return out_;
+    return *out_;
   }
  private:
   /*! \brief prefetcher parameters */
   PrefetcherParam param_;
-  /*! \brief output data */
-  DataBatch out_;
-  /*! \brief batch holder */
-  TBlobBatch out_holder_;
-  /*! \brief queue to hold the NDArrays for check whether writable */
-  std::queue<TBlobBatch*> ready_batches_;
-  /*! \breif ndarrays to wait to write */
-  std::queue<NDArray> ready_ndarrays_;
-  // internal batch loader
-  IIterator<TBlobBatch>* loader_;
+  // output data
+  DataBatch *out_;
+  // queue to be recycled
+  std::queue<DataBatch*> recycle_queue_;
   // backend thread
-  dmlc::ThreadedIter<TBlobBatch> iter_;
-  /*! \brief data shape */
-  TShape data_shape_;
-  /*! \brief label shape */
-  TShape label_shape_;
-  /*! \brief log the pointers of the space created for data*/
-  std::vector<real_t*> pdata_vec;
-  /*! \brief log the pointers of the space created for label*/
-  std::vector<real_t*> plabel_vec;
+  dmlc::ThreadedIter<DataBatch> iter_;
+  // internal batch loader
+  std::unique_ptr<IIterator<TBlobBatch> > loader_;
 };
 }  // namespace io
 }  // namespace mxnet
diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py
index a27222c1227a..e411c9e9c8f6 100644
--- a/tests/python/train/test_conv.py
+++ b/tests/python/train/test_conv.py
@@ -73,12 +73,12 @@ def Update(grad, weight):
 train_dataiter = mx.io.MNISTIter(
         image="data/train-images-idx3-ubyte",
         label="data/train-labels-idx1-ubyte",
-        input_shape=(1, 28, 28),
+        data_shape=(1, 28, 28),
         batch_size=batch_size, shuffle=True, flat=False, silent=False, seed=10)
 val_dataiter = mx.io.MNISTIter(
         image="data/t10k-images-idx3-ubyte",
         label="data/t10k-labels-idx1-ubyte",
-        input_shape=(1, 28, 28),
+        data_shape=(1, 28, 28),
         batch_size=batch_size, shuffle=True, flat=False, silent=False)
 
 def test_mnist():
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index e863b1164258..350adfde274e 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -28,12 +28,12 @@
 train_dataiter = mx.io.MNISTIter(
         image="data/train-images-idx3-ubyte",
         label="data/train-labels-idx1-ubyte",
-        input_shape=(784,),
+        data_shape=(784,),
         batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
 val_dataiter = mx.io.MNISTIter(
         image="data/t10k-images-idx3-ubyte",
         label="data/t10k-labels-idx1-ubyte",
-        input_shape=(784,),
+        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)
 
 def test_mlp():
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 8814e758dbd6..501b09df54a4 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -15,7 +15,7 @@ def test_MNISTIter():
     train_dataiter = mx.io.MNISTIter(
             image="data/train-images-idx3-ubyte",
             label="data/train-labels-idx1-ubyte",
-            input_shape=(784,),
+            data_shape=(784,),
             batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10)
     # test_loop
     nbatch = 60000 / batch_size
@@ -44,11 +44,11 @@ def test_Cifar10Rec():
         rand_crop=False,
         and_mirror=False,
         shuffle=False,
-        input_shape=(3,28,28),
+        data_shape=(3,28,28),
         batch_size=100,
         preprocess_threads=4,
         prefetch_buffer=1)
-    labelcount = [0 for i in range(10)] 
+    labelcount = [0 for i in range(10)]
     batchcount = 0
     for data, label in dataiter:
        npdata = data.asnumpy().flatten().sum()
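
Usage note (editorial addition, not part of the patch): after this change every Python iterator takes "data_shape" in place of "input_shape", and "prefetch_buffer" now defaults to 4 prefetched batches. The snippet below is a minimal, illustrative sketch of the updated call, assuming MXNet is installed and that the CIFAR-10 files referenced in example/python-howto/data_iter.py have been generated locally; the record path "data/cifar/train.rec" is an assumed placeholder, while the other parameter names mirror those shown in the diff above.

import mxnet as mx

# Hypothetical paths; the record and mean files must exist beforehand.
dataiter = mx.io.ImageRecordIter(
    path_imgrec="data/cifar/train.rec",
    mean_img="data/cifar/cifar10_mean.bin",
    rand_crop=True,
    data_shape=(3, 28, 28),     # renamed from input_shape
    batch_size=100,
    preprocess_threads=4,
    prefetch_buffer=4)          # number of prefetched batches

batchidx = 0
for data, label in dataiter:   # each item is now an NDArray-backed DataBatch
    batchidx += 1
print(batchidx)

Existing scripts that still pass input_shape only need the one-line rename shown in the test files above.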