Skip to content

Commit

Permalink
[Estimator] refactor estimator to allow overriding evaluate/fit of a …
Browse files Browse the repository at this point in the history
…batch (apache#16678)

* refactor estimator to allow overriding evaluate/fit of a batch

* add doc to explain call structure and how to override

* fix and doc
  • Loading branch information
szha authored and yajiedesign committed Nov 6, 2019
1 parent 5497522 commit 07247a3
Showing 1 changed file with 84 additions and 30 deletions.
114 changes: 84 additions & 30 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ def prepare_loss_and_metrics(self):
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
Returns train_metrics and val_metrics
Returns
-------
train_metrics, val_metrics
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
Expand All @@ -190,21 +192,50 @@ def prepare_loss_and_metrics(self):
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics

def evaluate_batch(self,
val_batch,
val_metrics,
batch_axis=0):
"""Evaluate model on a batch of validation data.
Parameters
----------
val_batch : tuple
Data and label of a batch from the validation data loader.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
metric.update(0, loss)
else:
metric.update(label, pred)

def evaluate(self,
val_data,
val_metrics,
batch_axis=0):
"""Evaluate model on validation data
Parameters
----------
val_data : DataLoader
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
"""Evaluate model on validation data.
This function calls :py:func:`evaluate_batch` on each of the batches from the
validation data loader. Thus, for custom use cases, it's possible to inherit the
estimator class and override :py:func:`evaluate_batch`.
Parameters
----------
val_data : DataLoader
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
if not isinstance(val_data, DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
Expand All @@ -214,15 +245,44 @@ def evaluate(self,
metric.reset()

for _, batch in enumerate(val_data):
data, label = self._get_data_and_label(batch, self.context, batch_axis)
self.evaluate_batch(batch, val_metrics, batch_axis)

def fit_batch(self, train_batch,
batch_axis=0):
"""Trains the model on a batch of training data.
Parameters
----------
train_batch : tuple
Data and label of a batch from the training data loader.
batch_axis : int, default 0
Batch axis to split the training data into devices.
Returns
-------
data: List of NDArray
Sharded data from the batch.
label: List of NDArray
Sharded label from the batch.
pred: List of NDArray
Prediction of each of the shareded batch.
loss: List of NDArray
Loss of each of the shareded batch.
"""
data, label = self._get_data_and_label(train_batch, self.context, batch_axis)

batch_size = train_batch[0].shape[batch_axis]

with autograd.record():
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
metric.update(0, loss)
else:
metric.update(label, pred)

for l in loss:
l.backward()

self.trainer.step(batch_size)

return data, label, pred, loss

def fit(self, train_data,
val_data=None,
Expand All @@ -234,6 +294,10 @@ def fit(self, train_data,
number of epochs or batches. The batch size is inferred from the
data loader's batch_size.
This function calls :py:func:`fit_batch` on each of the batches from the
training data loader. Thus, for custom use cases, it's possible to inherit the
estimator class and override :py:func:`fit_batch`.
Parameters
----------
train_data : DataLoader
Expand Down Expand Up @@ -284,22 +348,12 @@ def fit(self, train_data,
handler.epoch_begin(estimator_ref)

for i, batch in enumerate(train_data):
data, label = self._get_data_and_label(batch, self.context, batch_axis)

batch_size = batch[0].shape[0]

# batch begin
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

with autograd.record():
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]

for l in loss:
l.backward()
_, label, pred, loss = self.fit_batch(batch, batch_axis)

self.trainer.step(batch_size)
# batch end

batch_end_result = []
Expand Down

0 comments on commit 07247a3

Please sign in to comment.