Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix and doc
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 31, 2019
1 parent 9abf451 commit 16b02b0
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ def prepare_loss_and_metrics(self):
Based on loss functions and training metrics in estimator
Create metric wrappers to record loss values,
Create copies of train loss/metric objects to record validation values
Returns train_metrics and val_metrics
Returns
-------
train_metrics, val_metrics
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
Expand Down Expand Up @@ -264,6 +266,17 @@ def fit_batch(self, train_batch,
Data and label of a batch from the training data loader.
batch_axis : int, default 0
Batch axis to split the training data into devices.
Returns
-------
data: List of NDArray
Sharded data from the batch.
label: List of NDArray
Sharded label from the batch.
pred: List of NDArray
Prediction of each of the shareded batch.
loss: List of NDArray
Loss of each of the shareded batch.
"""
data, label = self._get_data_and_label(train_batch, self.context, batch_axis)

Expand All @@ -278,6 +291,8 @@ def fit_batch(self, train_batch,

self.trainer.step(batch_size)

return data, label, pred, loss

def fit(self, train_data,
val_data=None,
epochs=None,
Expand Down Expand Up @@ -346,7 +361,7 @@ def fit(self, train_data,
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

self.fit_batch(batch, batch_axis)
_, label, pred, loss = self.fit_batch(batch, batch_axis)

# batch end

Expand Down

0 comments on commit 16b02b0

Please sign in to comment.