diff --git a/src/Psl/Iter/Iterator.php b/src/Psl/Iter/Iterator.php index 265f8e755..ad3f0146d 100644 --- a/src/Psl/Iter/Iterator.php +++ b/src/Psl/Iter/Iterator.php @@ -28,9 +28,9 @@ final class Iterator implements Countable, SeekableIterator private array $entries = []; /** - * Whether or not the current value/key pair has been added to the local entries. + * Whether or not the current value/key pair has been added to the local entries. */ - private bool $saved = false; + private bool $saved = true; /** * Current cursor position for the local entries. @@ -79,7 +79,7 @@ public static function create(iterable $iterable): Iterator /** * @var (callable(): Generator) $factory */ - $factory = static fn (): Generator => yield from $iterable; + $factory = static fn(): Generator => yield from $iterable; return new self($factory()); } @@ -94,34 +94,34 @@ public static function create(iterable $iterable): Iterator public function current(): mixed { Psl\invariant($this->valid(), 'The Iterator is invalid.'); - if (!contains_key($this->entries, $this->position)) { - $this->progress(); - } + $this->save(); return $this->entries[$this->position][1]; } /** - * Move forward to the next element. + * Checks if current position is valid. */ - public function next(): void + public function valid(): bool { - $this->position++; - if (null === $this->generator || !$this->generator->valid()) { - return; + if (isset($this->entries[$this->position])) { + return true; } - if (contains_key($this->entries, $this->position + 1)) { - return; + if (null !== $this->generator && $this->generator->valid()) { + return true; } - if (!$this->saved) { - $this->progress(); - } - $this->saved = false; - $this->generator?->next(); + $this->generator = null; + return false; + } - $this->progress(); + private function save(): void + { + if (!$this->saved && $this->generator !== null) { + $this->saved = true; + $this->entries[] = [$this->generator->key(), $this->generator->current()]; + } } /** @@ -134,34 +134,11 @@ public function next(): void public function key(): mixed { Psl\invariant($this->valid(), 'The Iterator is invalid.'); - if (!contains_key($this->entries, $this->position)) { - $this->progress(); - } + $this->save(); return $this->entries[$this->position][0]; } - /** - * Checks if current position is valid. - */ - public function valid(): bool - { - if (contains_key($this->entries, $this->position)) { - return true; - } - - if (null === $this->generator) { - return false; - } - - if ($this->generator->valid()) { - return true; - } - - $this->generator = null; - return false; - } - /** * Rewind the Iterator to the first element. */ @@ -179,16 +156,18 @@ public function rewind(): void */ public function seek(int $position): void { - if (0 === $position || $position <= $this->position) { + if ($position <= $this->position) { $this->position = $position; return; } if ($this->generator) { - while ($this->position !== $position) { + do { + $this->save(); $this->next(); - Psl\invariant($this->valid(), 'Position is out-of-bounds.'); - } + /** @psalm-suppress PossiblyNullReference - ->next() and ->save() don't mutate ->generator. */ + Psl\invariant($this->generator->valid(), 'Position is out-of-bounds.'); + } while ($this->position < $position); return; } @@ -199,43 +178,39 @@ public function seek(int $position): void } /** - * @return 0|positive-int - * - * @psalm-suppress MoreSpecificReturnType + * Move forward to the next element. */ - public function count(): int + public function next(): void { - if ($this->generator) { - $this->exhaust(); + $this->position++; + + if (isset($this->entries[$this->position]) || null === $this->generator || !$this->generator->valid()) { + return; } - /** - * @psalm-suppress LessSpecificReturnStatement - */ - return count($this->entries); + $this->generator->next(); + $this->saved = false; } - private function exhaust(): void + /** + * @return int<0, max> + * + * @psalm-suppress PossiblyNullReference + */ + public function count(): int { if ($this->generator) { - if ($this->generator->valid()) { - foreach ($this->generator as $key => $value) { - $this->entries[] = [$key, $value]; - } - } + $previous = $this->position; + do { + $this->save(); + $this->next(); + } while ($this->generator->valid()); + $this->position = $previous; $this->generator = null; } - } - /** - * Save the current key and value to the local entries if the generator is still valid. - */ - private function progress(): void - { - if ($this->generator && $this->generator->valid() && !$this->saved) { - $this->entries[] = [$this->generator->key(), $this->generator->current()]; - $this->saved = true; - } + /** @var int<0, max> */ + return count($this->entries); } } diff --git a/tests/unit/Iter/IteratorTest.php b/tests/unit/Iter/IteratorTest.php index cf33355a7..3688f6fe8 100644 --- a/tests/unit/Iter/IteratorTest.php +++ b/tests/unit/Iter/IteratorTest.php @@ -129,6 +129,39 @@ public function testIterating(): void ], $spy->toArray()); } + public function testCountWhileIterating(): void + { + $spy = new MutableVector([]); + + $generator = (static function () use ($spy): iterable { + for ($i = 0; $i < 3; $i++) { + $spy->add('sending (' . $i . ')'); + + yield ['foo', 'bar'] => $i; + } + })(); + + $rewindable = Iter\rewindable($generator); + foreach ($rewindable as $key => $value) { + $spy->add('count (' . $rewindable->count() . ')'); + $spy->add('received (' . $value . ')'); + + static::assertSame(['foo', 'bar'], $key); + } + + static::assertSame([ + 'sending (0)', + 'sending (1)', + 'sending (2)', + 'count (3)', + 'received (0)', + 'count (3)', + 'received (1)', + 'count (3)', + 'received (2)', + ], $spy->toArray()); + } + public function testRewindingValidGenerator(): void { $spy = new MutableVector([]);