diff --git a/lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php b/lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php index 9f8ddcf54a..d726712013 100644 --- a/lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php +++ b/lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php @@ -58,7 +58,9 @@ public function __construct(Traversable $iterator) /** @see https://php.net/countable.count */ public function count(): int { + $currentKey = key($this->items); $this->exhaustIterator(); + for (reset($this->items); key($this->items) !== $currentKey; next($this->items)); return count($this->items); } diff --git a/tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/CachingIteratorTest.php b/tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/CachingIteratorTest.php index c077cadcb5..52b90c1dec 100644 --- a/tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/CachingIteratorTest.php +++ b/tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/CachingIteratorTest.php @@ -114,7 +114,7 @@ public function testToArrayAfterPartialIteration(): void public function testCount(): void { $iterator = new CachingIterator($this->getTraversable([1, 2, 3])); - $this->assertCount(3, $iterator); + self::assertCount(3, $iterator); } public function testCountAfterPartialIteration(): void @@ -122,18 +122,24 @@ public function testCountAfterPartialIteration(): void $iterator = new CachingIterator($this->getTraversable([1, 2, 3])); $iterator->rewind(); - $this->assertTrue($iterator->valid()); - $this->assertSame(0, $iterator->key()); - $this->assertSame(1, $iterator->current()); + self::assertTrue($iterator->valid()); + self::assertSame(0, $iterator->key()); + self::assertSame(1, $iterator->current()); $iterator->next(); - $this->assertCount(3, $iterator); + self::assertSame(1, $iterator->key()); + self::assertSame(2, $iterator->current()); + + self::assertCount(3, $iterator); + self::assertTrue($iterator->valid()); + self::assertSame(1, $iterator->key()); + self::assertSame(2, $iterator->current()); } public function testCountWithEmptySet(): void { $iterator = new CachingIterator($this->getTraversable([])); - $this->assertCount(0, $iterator); + self::assertCount(0, $iterator); } /** @@ -172,7 +178,7 @@ public function rewind(): void }; $iterator = new CachingIterator($nestedIterator); - $this->assertCount(1, $iterator); + self::assertCount(1, $iterator); } /**