From 665a3d1d23ce3ff63d7160070ccc41f947fc8261 Mon Sep 17 00:00:00 2001
From: Andrew DalPino
Date: Tue, 18 Jul 2023 18:10:45 -0500
Subject: [PATCH] Initial commit (#299)

---
 CHANGELOG.md                           |  3 ++
 docs/datasets/generators/blob.md       | 12 ++++++--
 src/Datasets/Generators/Blob.php       | 42 ++++++++++++++++++++++++++
 src/NeuralNet/Snapshotter.php          |  0
 tests/Datasets/Generators/BlobTest.php | 21 +++++++++++++
 5 files changed, 76 insertions(+), 2 deletions(-)
 delete mode 100644 src/NeuralNet/Snapshotter.php

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a3fcbfa5..b5ba4d025 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,6 @@
+- 2.5.0
+    - Blob Generator can now `simulate()` a Dataset object
+
 - 2.4.1
     - Sentence Tokenizer fix Arabic and Farsi language support
 
diff --git a/docs/datasets/generators/blob.md b/docs/datasets/generators/blob.md
index a0a8cbe4e..e99bf65ab 100644
--- a/docs/datasets/generators/blob.md
+++ b/docs/datasets/generators/blob.md
@@ -17,8 +17,16 @@ A normally distributed (Gaussian) n-dimensional blob of samples centered at a gi
 ```php
 use Rubix\ML\Datasets\Generators\Blob;
 
-$generator = new Blob([-1.2, -5., 2.6, 0.8, 10.], 0.25);
+$generator = new Blob([-1.2, -5.0, 2.6, 0.8, 10.0], 0.25);
 ```
 
 ## Additional Methods
-This generator does not have any additional methods.
+Fit a Blob generator to the samples in a dataset.
+```php
+public static simulate(Dataset $dataset) : self
+```
+
+Return the center coordinates of the Blob.
+```php
+public center() : array
+```
diff --git a/src/Datasets/Generators/Blob.php b/src/Datasets/Generators/Blob.php
index 7bf65adbd..0b9ce5250 100644
--- a/src/Datasets/Generators/Blob.php
+++ b/src/Datasets/Generators/Blob.php
@@ -4,10 +4,14 @@
 
 use Tensor\Matrix;
 use Tensor\Vector;
+use Rubix\ML\DataType;
+use Rubix\ML\Helpers\Stats;
+use Rubix\ML\Datasets\Dataset;
 use Rubix\ML\Datasets\Unlabeled;
 use Rubix\ML\Exceptions\InvalidArgumentException;
 
 use function count;
+use function sqrt;
 
 /**
  * Blob
@@ -37,6 +41,34 @@ class Blob implements Generator
      */
     protected $stdDev;
 
+    /**
+     * Fit a Blob generator to the samples in a dataset.
+     *
+     * @param \Rubix\ML\Datasets\Dataset $dataset
+     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
+     * @return self
+     */
+    public static function simulate(Dataset $dataset) : self
+    {
+        $features = $dataset->featuresByType(DataType::continuous());
+
+        if (count($features) !== $dataset->numFeatures()) {
+            throw new InvalidArgumentException('Dataset must only contain'
+                . ' continuous features.');
+        }
+
+        $means = $stdDevs = [];
+
+        foreach ($features as $values) {
+            [$mean, $variance] = Stats::meanVar($values);
+
+            $means[] = $mean;
+            $stdDevs[] = sqrt($variance);
+        }
+
+        return new self($means, $stdDevs);
+    }
+
     /**
      * @param (int|float)[] $center
      * @param int|float|(int|float)[] $stdDev
@@ -74,6 +106,16 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0)
         $this->stdDev = $stdDev;
     }
 
+    /**
+     * Return the center coordinates of the Blob.
+     *
+     * @return list<int|float>
+     */
+    public function center() : array
+    {
+        return $this->center->asArray();
+    }
+
     /**
      * Return the dimensionality of the data this generates.
      *
diff --git a/src/NeuralNet/Snapshotter.php b/src/NeuralNet/Snapshotter.php
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/Datasets/Generators/BlobTest.php b/tests/Datasets/Generators/BlobTest.php
index 1dcd39055..b5ea3631c 100644
--- a/tests/Datasets/Generators/BlobTest.php
+++ b/tests/Datasets/Generators/BlobTest.php
@@ -29,6 +29,19 @@ protected function setUp() : void
         $this->generator = new Blob([0, 0, 0], 1.0);
     }
 
+    /**
+     * @test
+     */
+    public function simulate() : void
+    {
+        $dataset = $this->generator->generate(100);
+
+        $generator = Blob::simulate($dataset);
+
+        $this->assertInstanceOf(Blob::class, $generator);
+        $this->assertInstanceOf(Generator::class, $generator);
+    }
+
     /**
      * @test
      */
@@ -38,6 +51,14 @@ public function build() : void
         $this->assertInstanceOf(Generator::class, $this->generator);
     }
 
+    /**
+     * @test
+     */
+    public function center() : void
+    {
+        $this->assertEquals([0, 0, 0], $this->generator->center());
+    }
+
     /**
      * @test
      */