Skip to content

Commit

Permalink
Add CTE support to select in QueryBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
nio-dtp committed Nov 30, 2024
1 parent 052545f commit 3a04999
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 1 deletion.
20 changes: 20 additions & 0 deletions docs/en/reference/query-builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,26 @@ or QueryBuilder instances to one of the following methods:
->orderBy('field', 'DESC')
->setMaxResults(100);
WITH-Clause
~~~~~~~~~~~

To define Common Table Expressions (CTEs) that can be used in select query.

* ``with(string $name, string|QueryBuilder $queryBuilder, array $columns = [])``

.. code-block:: php
<?php
$queryBuilder
->with('cte_a', 'SELECT id FROM table_a')
->with('cte_b', 'SELECT id FROM table_b')
->select('id')
->from('cte_b', 'b')
->join('b', 'cte_a', 'a', 'a.id = b.id');
Multiple CTEs can be defined by calling the addWith method multiple times.

Building Expressions
--------------------

Expand Down
6 changes: 6 additions & 0 deletions src/Platforms/AbstractPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
use Doctrine\DBAL\SQL\Builder\DefaultUnionSQLBuilder;
use Doctrine\DBAL\SQL\Builder\SelectSQLBuilder;
use Doctrine\DBAL\SQL\Builder\UnionSQLBuilder;
use Doctrine\DBAL\SQL\Builder\WithSQLBuilder;
use Doctrine\DBAL\SQL\Parser;
use Doctrine\DBAL\TransactionIsolationLevel;
use Doctrine\DBAL\Types;
Expand Down Expand Up @@ -802,6 +803,11 @@ public function createUnionSQLBuilder(): UnionSQLBuilder
return new DefaultUnionSQLBuilder($this);
}

public function createWithSQLBuilder(): WithSQLBuilder
{
return new WithSQLBuilder();
}

/**
* @internal
*
Expand Down
6 changes: 6 additions & 0 deletions src/Platforms/MySQL80Platform.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use Doctrine\DBAL\Platforms\Keywords\KeywordList;
use Doctrine\DBAL\Platforms\Keywords\MySQL80Keywords;
use Doctrine\DBAL\SQL\Builder\SelectSQLBuilder;
use Doctrine\DBAL\SQL\Builder\WithSQLBuilder;
use Doctrine\Deprecations\Deprecation;

/**
Expand All @@ -32,4 +33,9 @@ public function createSelectSQLBuilder(): SelectSQLBuilder
{
return AbstractPlatform::createSelectSQLBuilder();
}

public function createWithSQLBuilder(): WithSQLBuilder

Check warning on line 37 in src/Platforms/MySQL80Platform.php

View check run for this annotation

Codecov / codecov/patch

src/Platforms/MySQL80Platform.php#L37

Added line #L37 was not covered by tests
{
return AbstractPlatform::createWithSQLBuilder();
}

Check warning on line 40 in src/Platforms/MySQL80Platform.php

View check run for this annotation

Codecov / codecov/patch

src/Platforms/MySQL80Platform.php#L39-L40

Added lines #L39 - L40 were not covered by tests
}
7 changes: 7 additions & 0 deletions src/Platforms/MySQLPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

namespace Doctrine\DBAL\Platforms;

use Doctrine\DBAL\Platforms\Exception\NotSupported;
use Doctrine\DBAL\Platforms\Keywords\KeywordList;
use Doctrine\DBAL\Platforms\Keywords\MySQLKeywords;
use Doctrine\DBAL\Schema\Index;
use Doctrine\DBAL\SQL\Builder\WithSQLBuilder;
use Doctrine\DBAL\Types\BlobType;
use Doctrine\DBAL\Types\TextType;
use Doctrine\Deprecations\Deprecation;
Expand Down Expand Up @@ -35,6 +37,11 @@ public function getDefaultValueDeclarationSQL(array $column): string
return parent::getDefaultValueDeclarationSQL($column);
}

public function createWithSQLBuilder(): WithSQLBuilder
{
throw NotSupported::new(__METHOD__);

Check warning on line 42 in src/Platforms/MySQLPlatform.php

View check run for this annotation

Codecov / codecov/patch

src/Platforms/MySQLPlatform.php#L42

Added line #L42 was not covered by tests
}

/**
* {@inheritDoc}
*/
Expand Down
46 changes: 45 additions & 1 deletion src/Query/QueryBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ class QueryBuilder
*/
private array $unionParts = [];

/**
* The common table expression parts.
*
* @var With[]
*/
private array $withParts = [];

/**
* The query cache profile used for caching results.
*/
Expand Down Expand Up @@ -557,6 +564,33 @@ public function addUnion(string|QueryBuilder $part, UnionType $type = UnionType:
return $this;
}

/**
* Add a Common Table Expression to be used for a select query.
*
* <code>
* // WITH cte_name AS (SELECT 1 AS field1)
* $qb = $conn->createQueryBuilder()
* ->with('cte_name', 'SELECT 1 AS field1');
*
* // WITH cte_name(field1) AS (SELECT 1 AS field1)
* $qb = $conn->createQueryBuilder()
* ->with('cte_name', 'SELECT 1 AS field1', ['field1']);
* </code>
*
* @param string $name The name of the CTE
* @param string[] $columns The optional columns list to select in the CTE.
*
* @return $this This QueryBuilder instance.
*/
public function with(string $name, string|QueryBuilder $part, array $columns = []): self
{
$this->withParts[] = new With($name, $part, $columns);

$this->sql = null;

return $this;
}

/**
* Specifies an item that is to be returned in the query result.
* Replaces any previously specified selections, if any.
Expand Down Expand Up @@ -1266,7 +1300,15 @@ private function getSQLForSelect(): string
throw new QueryException('No SELECT expressions given. Please use select() or addSelect().');
}

return $this->connection->getDatabasePlatform()
$databasePlatform = $this->connection->getDatabasePlatform();
$selectParts = [];
if (count($this->withParts) > 0) {
$selectParts[] = $databasePlatform
->createWithSQLBuilder()
->buildSQL(...$this->withParts);
}

$selectParts[] = $databasePlatform
->createSelectSQLBuilder()
->buildSQL(
new SelectQuery(
Expand All @@ -1281,6 +1323,8 @@ private function getSQLForSelect(): string
$this->forUpdate,
),
);

return implode(' ', $selectParts);
}

/**
Expand Down
17 changes: 17 additions & 0 deletions src/Query/With.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?php

declare(strict_types=1);

namespace Doctrine\DBAL\Query;

/** @internal */
final class With
{
/** @param string[] $columns */
public function __construct(
public readonly string $name,
public readonly string|QueryBuilder $query,
public readonly array $columns = [],
) {
}
}
36 changes: 36 additions & 0 deletions src/SQL/Builder/WithSQLBuilder.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?php

declare(strict_types=1);

namespace Doctrine\DBAL\SQL\Builder;

use Doctrine\DBAL\Query\With;

use function array_map;
use function count;
use function implode;
use function sprintf;

final class WithSQLBuilder
{
public function buildSQL(With ...$withParts): string
{
$parts = array_map(
static fn (With $part) => sprintf(
'%s%s AS (%s)',
$part->name,
self::columns($part->columns),
$part->query,
),
$withParts,
);

return 'WITH ' . implode(', ', $parts);
}

/** @param string[] $columns */
private static function columns(array $columns): string
{
return count($columns) > 0 ? '(' . implode(', ', $columns) . ')' : '';
}
}
128 changes: 128 additions & 0 deletions tests/Functional/Query/QueryBuilderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@

namespace Doctrine\DBAL\Tests\Functional\Query;

use Doctrine\DBAL\ArrayParameterType;
use Doctrine\DBAL\DriverManager;
use Doctrine\DBAL\Exception;
use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Platforms\DB2Platform;
use Doctrine\DBAL\Platforms\Exception\NotSupported;
use Doctrine\DBAL\Platforms\MariaDB1060Platform;
use Doctrine\DBAL\Platforms\MariaDBPlatform;
use Doctrine\DBAL\Platforms\MySQL80Platform;
use Doctrine\DBAL\Platforms\MySQLPlatform;
use Doctrine\DBAL\Platforms\SQLitePlatform;
use Doctrine\DBAL\Platforms\SQLServerPlatform;
use Doctrine\DBAL\Query\ForUpdate\ConflictResolutionMode;
use Doctrine\DBAL\Query\UnionType;
use Doctrine\DBAL\Schema\Table;
Expand Down Expand Up @@ -332,6 +335,117 @@ public function testUnionAndAddUnionWorksWithQueryBuilderPartsAndReturnsExpected
self::assertSame($expectedRows, $qb->executeQuery()->fetchAllAssociative());
}

public function testSelectWithCTENamedParameter(): void
{
if (! $this->platformSupportsCTEs()) {
self::markTestSkipped('The database platform does not support CTE.');
}

if (! $this->platformSupportsCTEColumnDefinition()) {
self::markTestSkipped('The database platform does not support CTE column definition.');
}

$expectedRows = $this->prepareExpectedRows([['virtual_id' => 1]]);
$qb = $this->connection->createQueryBuilder();

$cteQueryBuilder = $this->connection->createQueryBuilder();
$cteQueryBuilder->select('id AS virtual_id')
->from('for_update')
->where('virtual_id = :id');

$qb->with('cte_a', $cteQueryBuilder, ['virtual_id'])
->select('virtual_id')
->from('cte_a')
->setParameter('id', 1);

self::assertSame($expectedRows, $qb->executeQuery()->fetchAllAssociative());
}

public function testSelectWithCTEPositionalParameter(): void
{
if (! $this->platformSupportsCTEs()) {
self::markTestSkipped('The database platform does not support CTE.');
}

if (! $this->platformSupportsCTEColumnDefinition()) {
self::markTestSkipped('The database platform does not support CTE column definition.');
}

$expectedRows = $this->prepareExpectedRows([['virtual_id' => 1]]);
$qb = $this->connection->createQueryBuilder();

$cteQueryBuilder1 = $this->connection->createQueryBuilder();
$cteQueryBuilder1->select('id AS virtual_id')
->from('for_update')
->where($qb->expr()->eq('virtual_id', '?'));

$cteQueryBuilder2 = $this->connection->createQueryBuilder();
$cteQueryBuilder2->select('id AS virtual_id')
->from('for_update')
->where($qb->expr()->in('id', '?'));

$qb->with('cte_a', $cteQueryBuilder1, ['virtual_id'])
->with('cte_b', $cteQueryBuilder2, ['virtual_id'])
->select('a.virtual_id')
->from('cte_a', 'a')
->join('a', 'cte_b', 'b', 'a.virtual_id = b.virtual_id')
->setParameters([1, [1, 2]], [ParameterType::INTEGER, ArrayParameterType::INTEGER]);

self::assertSame($expectedRows, $qb->executeQuery()->fetchAllAssociative());
}

public function testSelectWithCTEUnion(): void
{
if (! $this->platformSupportsCTEs()) {
self::markTestSkipped('The database platform does not support CTE.');
}

$expectedRows = $this->prepareExpectedRows([['id' => 2], ['id' => 1]]);
$qb = $this->connection->createQueryBuilder();

$subQueryBuilder1 = $this->connection->createQueryBuilder();
$subQueryBuilder1->select('id')
->from('for_update')
->where($qb->expr()->eq('id', '?'));

$subQueryBuilder2 = $this->connection->createQueryBuilder();
$subQueryBuilder2->select('id')
->from('for_update')
->where($qb->expr()->eq('id', '?'));

$subQueryBuilder3 = $this->connection->createQueryBuilder();
$subQueryBuilder3->union($subQueryBuilder1)
->addUnion($subQueryBuilder2, UnionType::DISTINCT);

$qb->with('cte_a', $subQueryBuilder3)
->select('id')
->from('cte_a')
->orderBy('id', 'DESC')
->setParameters([1, 2]);

self::assertSame($expectedRows, $qb->executeQuery()->fetchAllAssociative());
}

public function testPlatformDoesNotSupportCTE(): void
{
if ($this->platformSupportsCTEs()) {
self::markTestSkipped('The database platform does support CTE.');
}

$qb = $this->connection->createQueryBuilder();

$cteQueryBuilder = $this->connection->createQueryBuilder();
$cteQueryBuilder->select('id')
->from('for_update');

$qb->with('cte_a', $cteQueryBuilder)
->select('id')
->from('cte_a');

self::expectException(NotSupported::class);
$qb->executeQuery();
}

/**
* @param array<array<string, int>> $rows
*
Expand Down Expand Up @@ -380,4 +494,18 @@ private function platformSupportsSkipLocked(): bool

return ! $platform instanceof SQLitePlatform;
}

private function platformSupportsCTEs(): bool
{
$platform = $this->connection->getDatabasePlatform();

return ! $platform instanceof MySQLPlatform || $platform instanceof MySQL80Platform;
}

private function platformSupportsCTEColumnDefinition(): bool
{
$platform = $this->connection->getDatabasePlatform();

return ! $platform instanceof SQLServerPlatform;
}
}
Loading

0 comments on commit 3a04999

Please sign in to comment.