Skip to content

Commit

Permalink
Update to ML 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed May 8, 2021
1 parent 7701aef commit f1ecc6f
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
composer.lock
progress.csv
report.json
*.model
*.rbx
*.old
.vscode
.vs
16 changes: 9 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Rubix ML - MNIST Handwritten Digit Recognizer
# MNIST Handwritten Digit Recognizer
The [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset is a set of 70,000 human-labeled 28 x 28 greyscale images of individual handwritten digits. It is a subset of a larger dataset available from NIST - The National Institute of Standards and Technology. In this tutorial, you'll create your own handwritten digit recognizer using a multilayer neural network trained on the MNIST dataset.

- **Difficulty:** Hard
Expand All @@ -13,7 +13,7 @@ $ composer create-project rubix/mnist
> **Note:** Installation may take longer than usual due to the large dataset.
## Requirements
- [PHP](https://php.net) 7.2 or above
- [PHP](https://php.net) 7.4 or above
- [GD extension](https://www.php.net/manual/en/book.image.php)

#### Recommended
Expand Down Expand Up @@ -89,7 +89,7 @@ $estimator = new PersistentModel(
new Activation(new LeakyReLU()),
new Dropout(0.2),
], 256, new Adam(0.0001))),
new Filesystem('mnist.model', true)
new Filesystem('mnist.rbx', true)
);
```

Expand All @@ -102,14 +102,16 @@ $estimator->train($dataset);
```

### Validation Score and Loss
We can visualize the training progress at each stage by dumping the values of the loss function and validation metric after training. The `steps()` method will output an array containing the values of the default [Cross Entropy](https://docs.rubixml.com/latest/neural-network/cost-functions/cross-entropy.html) cost function and the `scores()` method will return an array of scores from the [F Beta](https://docs.rubixml.com/latest/cross-validation/metrics/f-beta.html) metric.
We can visualize the training progress at each stage by dumping the values of the loss function and validation metric after training. The `steps()` method will output an iterator containing the values of the default [Cross Entropy](https://docs.rubixml.com/latest/neural-network/cost-functions/cross-entropy.html) cost function and the `scores()` method will return an array of scores from the [F Beta](https://docs.rubixml.com/latest/cross-validation/metrics/f-beta.html) metric.

> **Note:** You can change the cost function and validation metric by setting them as hyper-parameters of the learner.
```php
$steps = $estimator->steps();
use Rubix\ML\Extractors\CSV;

$scores = $estimator->scores();
$extractor = new CSV('progress.csv', true);

$extractor->export($estimator->steps());
```

Then, we can plot the values using our favorite plotting software such as [Tableu](https://public.tableau.com/en-us/s/) or [Excel](https://products.office.com/en-us/excel-a). If all goes well, the value of the loss should go down as the value of the validation score goes up. Due to snapshotting, the epoch at which the validation score is highest and the loss is lowest is the point at which the values of the network parameters are taken for the final model. This prevents the network from overfitting the training data by effectively *unlearning* some of the noise in the dataset.
Expand Down Expand Up @@ -158,7 +160,7 @@ In our training script we made sure to save the model before we exited. In our v
use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;

$estimator = PersistentModel::load(new Filesystem('mnist.model'));
$estimator = PersistentModel::load(new Filesystem('mnist.rbx'));
```

### Make Predictions
Expand Down
4 changes: 2 additions & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
}
],
"require": {
"php": ">=7.2",
"php": ">=7.4",
"ext-gd": "*",
"rubix/ml": "^0.3.0"
"rubix/ml": "^1.0"
},
"scripts": {
"train": "@php train.php",
Expand Down
15 changes: 5 additions & 10 deletions train.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

include __DIR__ . '/vendor/autoload.php';

use Rubix\ML\Other\Loggers\Screen;
use Rubix\ML\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\PersistentModel;
use Rubix\ML\Pipeline;
Expand All @@ -16,9 +16,7 @@
use Rubix\ML\NeuralNet\ActivationFunctions\LeakyReLU;
use Rubix\ML\NeuralNet\Optimizers\Adam;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Datasets\Unlabeled;

use function Rubix\ML\array_transpose;
use Rubix\ML\Extractors\CSV;

ini_set('memory_limit', '-1');

Expand Down Expand Up @@ -53,19 +51,16 @@
new Activation(new LeakyReLU()),
new Dropout(0.2),
], 256, new Adam(0.0001))),
new Filesystem('mnist.model', true)
new Filesystem('mnist.rbx', true)
);

$estimator->setLogger($logger);

$estimator->train($dataset);

$scores = $estimator->scores();
$losses = $estimator->steps();
$extractor = new CSV('progress.csv', true);

Unlabeled::build(array_transpose([$scores, $losses]))
->toCSV(['scores', 'losses'])
->write('progress.csv');
$extractor->export($estimator->steps());

$logger->info('Progress saved to progress.csv');

Expand Down
6 changes: 3 additions & 3 deletions validate.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

include __DIR__ . '/vendor/autoload.php';

use Rubix\ML\Other\Loggers\Screen;
use Rubix\ML\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;
Expand All @@ -27,7 +27,7 @@

$dataset = new Labeled($samples, $labels);

$estimator = PersistentModel::load(new Filesystem('mnist.model'));
$estimator = PersistentModel::load(new Filesystem('mnist.rbx'));

$logger->info('Making predictions');

Expand All @@ -42,6 +42,6 @@

echo $results;

$results->toJSON()->write('report.json');
$results->toJSON()->saveTo(new Filesystem('report.json'));

$logger->info('Report saved to report.json');

0 comments on commit f1ecc6f

Please sign in to comment.