Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Estimator] refactor estimator to allow overriding evaluate/fit of a batch #16678

Merged
merged 3 commits into from
Oct 31, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 84 additions & 30 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ def prepare_loss_and_metrics(self):
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
Returns train_metrics and val_metrics

Returns
-------
train_metrics, val_metrics
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
Expand All @@ -199,21 +201,50 @@ def prepare_loss_and_metrics(self):
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics

def evaluate_batch(self,
                   val_batch,
                   val_metrics,
                   batch_axis=0):
    """Run the network on a single batch of validation data and update metrics.

    Parameters
    ----------
    val_batch : tuple
        Data and label of a batch from the validation data loader.
    val_metrics : EvalMetric or list of EvalMetrics
        Metrics to update validation result.
    batch_axis : int, default 0
        Batch axis to split the validation data into devices.
    """
    # Shard the batch across the configured devices, then run a forward
    # pass per shard (no autograd recording: evaluation only).
    inputs, labels = self._get_data_and_label(val_batch, self.context, batch_axis)
    predictions = [self.net(x) for x in inputs]
    losses = [self.loss[0](p, y) for p, y in zip(predictions, labels)]
    # Loss-wrapper metrics consume the raw loss values; every other
    # metric compares predictions against labels.
    for metric in val_metrics:
        if isinstance(metric, metric_loss):
            metric.update(0, losses)
        else:
            metric.update(labels, predictions)

def evaluate(self,
val_data,
val_metrics,
batch_axis=0):
"""Evaluate model on validation data

Parameters
----------
val_data : DataLoader
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
"""Evaluate model on validation data.

This function calls :py:func:`evaluate_batch` on each of the batches from the
validation data loader. Thus, for custom use cases, it's possible to inherit the
estimator class and override :py:func:`evaluate_batch`.

Parameters
----------
val_data : DataLoader
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
if not isinstance(val_data, DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
Expand All @@ -223,15 +254,44 @@ def evaluate(self,
metric.reset()

for _, batch in enumerate(val_data):
data, label = self._get_data_and_label(batch, self.context, batch_axis)
self.evaluate_batch(batch, val_metrics, batch_axis)

def fit_batch(self, train_batch,
              batch_axis=0):
    """Trains the model on a batch of training data.

    Parameters
    ----------
    train_batch : tuple
        Data and label of a batch from the training data loader.
    batch_axis : int, default 0
        Batch axis to split the training data into devices.

    Returns
    -------
    data : list of NDArray
        Sharded data from the batch.
    label : list of NDArray
        Sharded label from the batch.
    pred : list of NDArray
        Prediction of each of the sharded inputs.
    loss : list of NDArray
        Loss of each of the sharded inputs.
    """
    data, label = self._get_data_and_label(train_batch, self.context, batch_axis)

    batch_size = train_batch[0].shape[batch_axis]

    # Record the forward pass so gradients can be computed; metric updates
    # belong to the caller (fit/evaluate), not here.  The stray
    # `for metric in val_metrics:` loop in the original was diff residue
    # referencing a name that is not in this method's scope.
    with autograd.record():
        pred = [self.net(x) for x in data]
        loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]

    for l in loss:
        l.backward()

    # One optimizer step per batch, scaled by the batch size.
    self.trainer.step(batch_size)

    return data, label, pred, loss

def fit(self, train_data,
val_data=None,
Expand All @@ -243,6 +303,10 @@ def fit(self, train_data,
number of epochs or batches. The batch size is inferred from the
data loader's batch_size.

This function calls :py:func:`fit_batch` on each of the batches from the
training data loader. Thus, for custom use cases, it's possible to inherit the
estimator class and override :py:func:`fit_batch`.

Parameters
----------
train_data : DataLoader
Expand Down Expand Up @@ -293,22 +357,12 @@ def fit(self, train_data,
handler.epoch_begin(estimator_ref)

for i, batch in enumerate(train_data):
data, label = self._get_data_and_label(batch, self.context, batch_axis)

batch_size = batch[0].shape[0]

# batch begin
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

with autograd.record():
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]

for l in loss:
l.backward()
_, label, pred, loss = self.fit_batch(batch, batch_axis)

self.trainer.step(batch_size)
# batch end

batch_end_result = []
Expand Down