From c29aa6bd69d75f244b19b04fa51efa3e08249850 Mon Sep 17 00:00:00 2001 From: Samuel Levy Date: Mon, 12 Aug 2024 16:22:30 +1000 Subject: [PATCH] [11.x] Added PreventsCircularRecursion This adds a trait for Eloquent which can be used to prevent recursively serializing circular references. --- .../Eloquent/Concerns/HasAttributes.php | 5 +- .../Concerns/PreventsCircularRecursion.php | 74 ++++++++ src/Illuminate/Database/Eloquent/Model.php | 70 ++++---- ...eConcernsPreventsCircularRecursionTest.php | 162 ++++++++++++++++++ 4 files changed, 279 insertions(+), 32 deletions(-) create mode 100644 src/Illuminate/Database/Eloquent/Concerns/PreventsCircularRecursion.php create mode 100644 tests/Database/DatabaseConcernsPreventsCircularRecursionTest.php diff --git a/src/Illuminate/Database/Eloquent/Concerns/HasAttributes.php b/src/Illuminate/Database/Eloquent/Concerns/HasAttributes.php index 31564c1f56a8..be0998e4252e 100644 --- a/src/Illuminate/Database/Eloquent/Concerns/HasAttributes.php +++ b/src/Illuminate/Database/Eloquent/Concerns/HasAttributes.php @@ -394,7 +394,10 @@ public function relationsToArray() // If the relation value has been set, we will set it on this attributes // list for returning. If it was not arrayable or null, we'll not set // the value on the array because it is some type of invalid value. - if (array_key_exists('relation', get_defined_vars())) { // check if $relation is in scope (could be null) + if (array_key_exists( + 'relation', + get_defined_vars() + )) { // check if $relation is in scope (could be null) $attributes[$key] = $relation ?? null; } diff --git a/src/Illuminate/Database/Eloquent/Concerns/PreventsCircularRecursion.php b/src/Illuminate/Database/Eloquent/Concerns/PreventsCircularRecursion.php new file mode 100644 index 000000000000..8c06c5351579 --- /dev/null +++ b/src/Illuminate/Database/Eloquent/Concerns/PreventsCircularRecursion.php @@ -0,0 +1,74 @@ +> + */ + protected static $recursionCache; + + /** + * Get the current recursion cache being used by the model. + * + * @return \WeakMap + */ + protected static function getRecursionCache() + { + return static::$recursionCache ??= new \WeakMap(); + } + + /** + * Get the current stack of methods being called recursively. + * + * @param object $object + * @return array + */ + protected static function getRecursiveCallStack($object): array + { + return static::getRecursionCache()->offsetExists($object) + ? static::getRecursionCache()->offsetGet($object) + : []; + } + + /** + * Prevent a method from being called multiple times on the same object within the same call stack. + * + * @param callable $callback + * @param mixed $default + * @return mixed + */ + protected function once($callback, $default = null) + { + $trace = debug_backtrace(DEBUG_BACKTRACE_PROVIDE_OBJECT, 2); + + $onceable = Onceable::tryFromTrace($trace, $callback); + + $object = $onceable->object ?? $this; + $stack = static::getRecursiveCallStack($object); + + if (isset($stack[$onceable->hash])) { + return $stack[$onceable->hash]; + } + + try { + // Set the default first to prevent recursion + $stack[$onceable->hash] = $default; + static::getRecursionCache()->offsetSet($object, $stack); + + return call_user_func($onceable->callable); + } finally { + if ($stack = Arr::except($this->getRecursiveCallStack($object), $onceable->hash)) { + static::getRecursionCache()->offsetSet($object, $stack); + } elseif (static::getRecursionCache()->offsetExists($object)) { + static::getRecursionCache()->offsetUnset($object); + } + } + } +} diff --git a/src/Illuminate/Database/Eloquent/Model.php b/src/Illuminate/Database/Eloquent/Model.php index 59b81aeed178..5ad8f7a3ce2f 100644 --- a/src/Illuminate/Database/Eloquent/Model.php +++ b/src/Illuminate/Database/Eloquent/Model.php @@ -35,6 +35,7 @@ abstract class Model implements Arrayable, ArrayAccess, CanBeEscapedWhenCastToSt Concerns\HasUniqueIds, Concerns\HidesAttributes, Concerns\GuardsAttributes, + Concerns\PreventsCircularRecursion, ForwardsCalls; /** @use HasCollection<\Illuminate\Database\Eloquent\Collection> */ use HasCollection; @@ -1083,25 +1084,27 @@ protected function decrementQuietly($column, $amount = 1, array $extra = []) */ public function push() { - if (! $this->save()) { - return false; - } - - // To sync all of the relationships to the database, we will simply spin through - // the relationships and save each model via this "push" method, which allows - // us to recurse into all of these nested relations for the model instance. - foreach ($this->relations as $models) { - $models = $models instanceof Collection - ? $models->all() : [$models]; + return $this->once(function () { + if (! $this->save()) { + return false; + } - foreach (array_filter($models) as $model) { - if (! $model->push()) { - return false; + // To sync all of the relationships to the database, we will simply spin through + // the relationships and save each model via this "push" method, which allows + // us to recurse into all of these nested relations for the model instance. + foreach ($this->relations as $models) { + $models = $models instanceof Collection + ? $models->all() : [$models]; + + foreach (array_filter($models) as $model) { + if (! $model->push()) { + return false; + } } } - } - return true; + return true; + }, true); } /** @@ -1644,7 +1647,10 @@ public function callNamedScope($scope, array $parameters = []) */ public function toArray() { - return array_merge($this->attributesToArray(), $this->relationsToArray()); + return $this->once( + fn () => array_merge($this->attributesToArray(), $this->relationsToArray()), + $this->attributesToArray(), + ); } /** @@ -1991,29 +1997,31 @@ public function getQueueableId() */ public function getQueueableRelations() { - $relations = []; + return $this->once(function () { + $relations = []; - foreach ($this->getRelations() as $key => $relation) { - if (! method_exists($this, $key)) { - continue; - } + foreach ($this->getRelations() as $key => $relation) { + if (! method_exists($this, $key)) { + continue; + } - $relations[] = $key; + $relations[] = $key; - if ($relation instanceof QueueableCollection) { - foreach ($relation->getQueueableRelations() as $collectionValue) { - $relations[] = $key.'.'.$collectionValue; + if ($relation instanceof QueueableCollection) { + foreach ($relation->getQueueableRelations() as $collectionValue) { + $relations[] = $key.'.'.$collectionValue; + } } - } - if ($relation instanceof QueueableEntity) { - foreach ($relation->getQueueableRelations() as $entityValue) { - $relations[] = $key.'.'.$entityValue; + if ($relation instanceof QueueableEntity) { + foreach ($relation->getQueueableRelations() as $entityValue) { + $relations[] = $key.'.'.$entityValue; + } } } - } - return array_unique($relations); + return array_unique($relations); + }, []); } /** diff --git a/tests/Database/DatabaseConcernsPreventsCircularRecursionTest.php b/tests/Database/DatabaseConcernsPreventsCircularRecursionTest.php new file mode 100644 index 000000000000..3c3cc8fa46df --- /dev/null +++ b/tests/Database/DatabaseConcernsPreventsCircularRecursionTest.php @@ -0,0 +1,162 @@ +assertEquals(0, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(0, $instance->instanceStack); + + $this->assertEquals(0, $instance->callStack()); + $this->assertEquals(1, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(1, $instance->instanceStack); + + $this->assertEquals(1, $instance->callStack()); + $this->assertEquals(2, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + } + + public function testRecursiveCallsAreLimitedToIndividualInstances() + { + $instance = new PreventsCircularRecursionWithRecursiveMethod(); + $other = $instance->other; + + $this->assertEquals(0, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(0, $instance->instanceStack); + $this->assertEquals(0, $other->instanceStack); + + $instance->callStack(); + $this->assertEquals(1, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(1, $instance->instanceStack); + $this->assertEquals(0, $other->instanceStack); + + $instance->callStack(); + $this->assertEquals(2, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + $this->assertEquals(0, $other->instanceStack); + + $other->callStack(); + $this->assertEquals(3, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + $this->assertEquals(1, $other->instanceStack); + + $other->callStack(); + $this->assertEquals(4, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + $this->assertEquals(2, $other->instanceStack); + } + + public function testRecursiveCallsToCircularReferenceCallsOtherInstanceOnce() + { + $instance = new PreventsCircularRecursionWithRecursiveMethod(); + $other = $instance->other; + + $this->assertEquals(0, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(0, $instance->instanceStack); + $this->assertEquals(0, $other->instanceStack); + + $instance->callOtherStack(); + $this->assertEquals(2, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(1, $instance->instanceStack); + $this->assertEquals(1, $other->instanceStack); + + $instance->callOtherStack(); + $this->assertEquals(4, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + $this->assertEquals(2, $other->instanceStack); + + $other->callOtherStack(); + $this->assertEquals(6, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(3, $other->instanceStack); + $this->assertEquals(3, $instance->instanceStack); + + $other->callOtherStack(); + $this->assertEquals(8, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(4, $other->instanceStack); + $this->assertEquals(4, $instance->instanceStack); + } + + public function testRecursiveCallsToCircularLinkedListCallsEachInstanceOnce() + { + $instance = new PreventsCircularRecursionWithRecursiveMethod(); + $second = $instance->other; + $third = new PreventsCircularRecursionWithRecursiveMethod($second); + $instance->other = $third; + + $this->assertEquals(0, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(0, $instance->instanceStack); + $this->assertEquals(0, $second->instanceStack); + $this->assertEquals(0, $third->instanceStack); + + $instance->callOtherStack(); + $this->assertEquals(3, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(1, $instance->instanceStack); + $this->assertEquals(1, $second->instanceStack); + $this->assertEquals(1, $third->instanceStack); + + $second->callOtherStack(); + $this->assertEquals(6, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(2, $instance->instanceStack); + $this->assertEquals(2, $second->instanceStack); + $this->assertEquals(2, $third->instanceStack); + + $third->callOtherStack(); + $this->assertEquals(9, PreventsCircularRecursionWithRecursiveMethod::$globalStack); + $this->assertEquals(3, $instance->instanceStack); + $this->assertEquals(3, $second->instanceStack); + $this->assertEquals(3, $third->instanceStack); + } +} + +class PreventsCircularRecursionWithRecursiveMethod +{ + use PreventsCircularRecursion; + + public function __construct( + public ?PreventsCircularRecursionWithRecursiveMethod $other = null, + ) { + $this->other ??= new PreventsCircularRecursionWithRecursiveMethod($this); + } + + public static int $globalStack = 0; + public int $instanceStack = 0; + + public function callStack(): int + { + return $this->once( + function () { + static::$globalStack++; + $this->instanceStack++; + + return $this->callStack(); + }, + $this->instanceStack + ); + } + + public function callOtherStack(): int + { + return $this->once( + function () { + $this->other->callStack(); + + return $this->other->callOtherStack(); + }, + $this->instanceStack + ); + } +}