diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index d583f33d64..229c248549 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -32,7 +32,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] steps: - uses: actions/checkout@v2 @@ -48,7 +49,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: @@ -56,21 +57,22 @@ jobs: args: > --manifest-path sqlx-core/Cargo.toml --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }} + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - uses: actions-rs/cargo@v1 with: command: check args: > --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }},macros + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }},macros test: name: Unit Test runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] steps: - uses: actions/checkout@v2 @@ -93,7 +95,7 @@ jobs: command: test args: > --manifest-path sqlx-core/Cargo.toml - --features offline,all-databases,all-types,runtime-${{ matrix.runtime }} + --features offline,all-databases,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} cli: name: CLI Binaries @@ -148,7 +150,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -165,14 +168,14 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-sqlite-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-sqlite-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: test args: > --no-default-features - --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }} + --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} -- --test-threads=1 env: @@ -183,8 +186,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - postgres: [12, 10, 9_6, 9_5] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + postgres: [13, 9_6] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -201,23 +205,24 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features postgres,all-types,runtime-${{ matrix.runtime }} + --features postgres,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - - run: docker-compose -f tests/docker-compose.yml run -d -p 
5432:5432 postgres_${{ matrix.postgres }} - - run: sleep 10 + - run: | + docker-compose -f tests/docker-compose.yml run -d -p 5432:5432 --name postgres_${{ matrix.postgres }} postgres_${{ matrix.postgres }} + docker exec postgres_${{ matrix.postgres }} bash -c "until pg_isready; do sleep 1; done" - uses: actions-rs/cargo@v1 with: command: test args: > --no-default-features - --features any,postgres,macros,all-types,runtime-${{ matrix.runtime }} + --features any,postgres,macros,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx @@ -226,7 +231,7 @@ jobs: command: test args: > --no-default-features - --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx?sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt @@ -235,8 +240,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mysql: [8, 5_7, 5_6] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mysql: [8, 5_6] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -253,13 +259,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }} + --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mysql_${{ matrix.mysql }} - run: sleep 60 @@ -269,7 +275,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -278,8 +284,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mariadb: [10_5, 10_4, 10_3, 10_2, 10_1] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mariadb: [10_6, 10_2] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -297,13 +304,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }} + --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -313,7 +320,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -322,8 +329,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mssql: [2019] - runtime: [async-std-native-tls, tokio-native-tls, 
actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mssql: [2019, 2017] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -340,13 +348,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mssql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mssql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mssql,all-types,runtime-${{ matrix.runtime }} + --features mssql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 1433:1433 mssql_${{ matrix.mssql }} - run: sleep 80 # MSSQL takes a "bit" to startup @@ -356,6 +364,6 @@ jobs: command: test args: > --no-default-features - --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mssql://sa:Password123!@localhost/sqlx diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e3addaecd..ca19f348ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,67 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.5.9 - 2021-10-01 + +A hotfix release to address the issue of the `sqlx` crate itself still depending on older versions of `sqlx-core` and +`sqlx-macros`. + +No other changes from `0.5.8`. + +## 0.5.8 - 2021-10-01 (Yanked; use 0.5.9) + +[A total of 24 pull requests][0.5.8-prs] were merged this release cycle! Some highlights: + +* [[#1289]] Support the `immutable` option on SQLite connections [[@djmarcin]] +* [[#1295]] Support custom initial options for SQLite [[@ghassmo]] + * Allows specifying custom `PRAGMA`s and overriding those set by SQLx. +* [[#1345]] Initial support for Postgres `COPY FROM/TO`[[@montanalow], [@abonander]] +* [[#1439]] Handle multiple waiting results correctly in MySQL [[@eagletmt]] + +[#1289]: https://github.com/launchbadge/sqlx/pull/1289 +[#1295]: https://github.com/launchbadge/sqlx/pull/1295 +[#1345]: https://github.com/launchbadge/sqlx/pull/1345 +[#1439]: https://github.com/launchbadge/sqlx/pull/1439 +[0.5.8-prs]: https://github.com/launchbadge/sqlx/pulls?q=is%3Apr+is%3Amerged+merged%3A2021-08-21..2021-10-01 + +## 0.5.7 - 2021-08-20 + +* [[#1392]] use `resolve_path` when getting path for `include_str!()` [[@abonander]] + * Fixes a regression introduced by [[#1332]]. +* [[#1393]] avoid recursively spawning tasks in `PgListener::drop()` [[@abonander]] + * Fixes a panic that occurs when `PgListener` is dropped in `async fn main()`. + +[#1392]: https://github.com/launchbadge/sqlx/pull/1392 +[#1393]: https://github.com/launchbadge/sqlx/pull/1393 + +## 0.5.6 - 2021-08-16 + +A large bugfix release, including but not limited to: + +* [[#1329]] Implement `MACADDR` type for Postgres [[@nomick]] +* [[#1363]] Fix `PortalSuspended` for array of composite types in Postgres [[@AtkinsChang]] +* [[#1320]] Reimplement `sqlx::Pool` internals using `futures-intrusive` [[@abonander]] + * This addresses a number of deadlocks/stalls on acquiring connections from the pool. 
+* [[#1332]] Macros: tell the compiler about external files/env vars to watch [[@abonander]] + * Includes `sqlx build-script` to create a `build.rs` to watch `migrations/` for changes. + * Nightly users can try `RUSTFLAGS=--cfg sqlx_macros_unstable` to tell the compiler + to watch `migrations/` for changes instead of using a build script. + * See the new section in the docs for `sqlx::migrate!()` for details. +* [[#1351]] Fix a few sources of segfaults/errors in SQLite driver [[@abonander]] + * Includes contributions from [[@link2ext]] and [[@madadam]]. +* [[#1323]] Keep track of column typing in SQLite EXPLAIN parsing [[@marshoepial]] + * This fixes errors in the macros when using `INSERT/UPDATE/DELETE ... RETURNING ...` in SQLite. + +[A total of 25 pull requests][0.5.6-prs] were merged this release cycle! + +[#1329]: https://github.com/launchbadge/sqlx/pull/1329 +[#1363]: https://github.com/launchbadge/sqlx/pull/1363 +[#1320]: https://github.com/launchbadge/sqlx/pull/1320 +[#1332]: https://github.com/launchbadge/sqlx/pull/1332 +[#1351]: https://github.com/launchbadge/sqlx/pull/1351 +[#1323]: https://github.com/launchbadge/sqlx/pull/1323 +[0.5.6-prs]: https://github.com/launchbadge/sqlx/pulls?q=is%3Apr+is%3Amerged+merged%3A2021-05-24..2021-08-17 + ## 0.5.5 - 2021-05-24 - [[#1242]] Fix infinite loop at compile time when using query macros [[@toshokan]] @@ -925,3 +986,12 @@ Fix docs.rs build by enabling a runtime feature in the docs.rs metadata in `Carg [@feikesteenbergen]: https://github.com/feikesteenbergen [@etcaton]: https://github.com/ETCaton [@toshokan]: https://github.com/toshokan +[@nomick]: https://github.com/nomick +[@marshoepial]: https://github.com/marshoepial +[@link2ext]: https://github.com/link2ext +[@madadam]: https://github.com/madadam +[@AtkinsChang]: https://github.com/AtkinsChang +[@djmarcin]: https://github.com/djmarcin +[@ghassmo]: https://github.com/ghassmo +[@eagletmt]: https://github.com/eagletmt +[@montanalow]: https://github.com/montanalow \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8c3f95670d..bf275b32f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,12 +359,6 @@ dependencies = [ "serde", ] -[[package]] -name = "build_const" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" - [[package]] name = "bumpalo" version = "3.6.1" @@ -546,13 +540,19 @@ dependencies = [ [[package]] name = "crc" -version = "1.8.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d663548de7f5cca343f1e0a48d14dcfb0e9eb4e079ec58883b7251539fa10aeb" +checksum = "10c2722795460108a7872e1cd933a85d6ec38abc4baecad51028f702da28889f" dependencies = [ - "build_const", + "crc-catalog", ] +[[package]] +name = "crc-catalog" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" + [[package]] name = "criterion" version = "0.3.4" @@ -646,9 +646,9 @@ dependencies = [ [[package]] name = "crypto-mac" -version = "0.10.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6" +checksum = "b1d1a86f49236c215f271d40892d5fc950490551400b02ef360692c29815c714" dependencies = [ "generic-array", "subtle", @@ -686,18 +686,6 @@ dependencies = [ "syn", ] -[[package]] -name = "dialoguer" -version = "0.8.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9dd058f8b65922819fabb4a41e7d1964e56344042c26efbccd465202c23fa0c" -dependencies = [ - "console", - "lazy_static", - "tempfile", - "zeroize", -] - [[package]] name = "difference" version = "2.0.0" @@ -722,6 +710,16 @@ dependencies = [ "dirs-sys", ] +[[package]] +name = "dirs-next" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf36e65a80337bea855cd4ef9b8401ffce06a7baedf2e85ec467b1ac3f6e82b6" +dependencies = [ + "cfg-if 1.0.0", + "dirs-sys-next", +] + [[package]] name = "dirs-sys" version = "0.3.6" @@ -733,6 +731,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "discard" version = "1.0.4" @@ -925,6 +934,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62007592ac46aa7c2b6416f7deb9a8a8f63a01e0f1d6e1787d5630170db2b63e" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.15" @@ -1028,9 +1048,9 @@ dependencies = [ [[package]] name = "git2" -version = "0.13.19" +version = "0.13.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17929de7239dea9f68aa14f94b2ab4974e7b24c1314275ffcc12a7758172fa18" +checksum = "d9831e983241f8c5591ed53f17d874833e2fa82cac2625f3888c50cbfe136cba" dependencies = [ "bitflags", "libc", @@ -1064,12 +1084,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62aca2aba2d62b4a7f5b33f3712cb1b0692779a56fb510499d5c0aa594daeaf3" -[[package]] -name = "hashbrown" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" - [[package]] name = "hashbrown" version = "0.11.2" @@ -1085,7 +1099,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" dependencies = [ - "hashbrown 0.11.2", + "hashbrown", ] [[package]] @@ -1114,9 +1128,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hmac" -version = "0.10.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" +checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" dependencies = [ "crypto-mac", "digest", @@ -1150,12 +1164,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.6.2" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" +checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" dependencies = [ "autocfg 1.0.1", - "hashbrown 0.9.1", + "hashbrown", ] [[package]] @@ -1269,9 +1283,9 @@ checksum = "18794a8ad5b29321f790b55d93dfba91e125cb1a9edbd4f8e3150acc771c1a5e" [[package]] name = "libgit2-sys" -version = "0.12.20+1.1.0" +version = "0.12.21+1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1e2f09917e00b9ad194ae72072bb5ada2cca16d8171a43e91ddba2afbb02664b" +checksum = "86271bacd72b2b9e854c3dcfb82efd538f15f870e4c11af66900effb462f6825" dependencies = [ "cc", "libc", @@ -1287,9 +1301,9 @@ checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a" [[package]] name = "libsqlite3-sys" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" +checksum = "abd5850c449b40bacb498b2bbdfaff648b1b055630073ba8db499caf2d0ea9f2" dependencies = [ "cc", "pkg-config", @@ -1327,6 +1341,16 @@ dependencies = [ "value-bag", ] +[[package]] +name = "mac_address" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9bb26482176bddeea173ceaa2acec85146d20cdcc631eafaf9d605d3d4fc23" +dependencies = [ + "nix 0.19.1", + "winapi", +] + [[package]] name = "maplit" version = "1.0.2" @@ -1432,6 +1456,30 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83450fe6a6142ddd95fb064b746083fc4ef1705fe81f64a64e1d4b39f54a1055" +dependencies = [ + "bitflags", + "cc", + "cfg-if 0.1.10", + "libc", +] + +[[package]] +name = "nix" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ccba0cfe4fdf15982d1674c69b1fd80bad427d293849982668dfe454bd61f2" +dependencies = [ + "bitflags", + "cc", + "cfg-if 1.0.0", + "libc", +] + [[package]] name = "nom" version = "6.1.2" @@ -1840,13 +1888,22 @@ checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" [[package]] name = "proc-macro2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" +checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" dependencies = [ "unicode-xid", ] +[[package]] +name = "promptly" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b99cfb0289110d969dd21637cfbc922584329bc9e5037c5e576325c615658509" +dependencies = [ + "rustyline", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2077,6 +2134,25 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustyline" +version = "6.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0d5e7b0219a3eadd5439498525d4765c59b7c993ef0c12244865cd2d988413" +dependencies = [ + "cfg-if 0.1.10", + "dirs-next", + "libc", + "log", + "memchr", + "nix 0.18.0", + "scopeguard", + "unicode-segmentation", + "unicode-width", + "utf8parse", + "winapi", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2334,7 +2410,7 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.5.5" +version = "0.5.9" dependencies = [ "anyhow", "async-std", @@ -2367,18 +2443,19 @@ dependencies = [ [[package]] name = "sqlx-cli" -version = "0.5.5" +version = "0.5.9" dependencies = [ "anyhow", "async-trait", "chrono", "clap 3.0.0-beta.2", + "clap_derive", "console", - "dialoguer", "dotenv", "futures 0.3.15", "glob", "openssl", + "promptly", "remove_dir_all 0.7.0", "serde", "serde_json", @@ -2389,7 +2466,7 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.5.5" +version = "0.5.9" dependencies = [ "ahash", "atoi", @@ -2411,6 +2488,7 @@ dependencies = [ "encoding_rs", "futures-channel", "futures-core", + "futures-intrusive", "futures-util", 
"generic-array", "getrandom", @@ -2418,11 +2496,13 @@ dependencies = [ "hashlink", "hex", "hmac", + "indexmap", "ipnetwork", "itoa", "libc", "libsqlite3-sys", "log", + "mac_address", "md-5", "memchr", "num-bigint 0.3.2", @@ -2501,6 +2581,15 @@ dependencies = [ "structopt", ] +[[package]] +name = "sqlx-example-postgres-transaction" +version = "0.1.0" +dependencies = [ + "async-std", + "futures", + "sqlx", +] + [[package]] name = "sqlx-example-sqlite-todos" version = "0.1.0" @@ -2515,11 +2604,10 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.5.5" +version = "0.5.9" dependencies = [ "dotenv", "either", - "futures 0.3.15", "heck", "hex", "once_cell", @@ -2536,7 +2624,7 @@ dependencies = [ [[package]] name = "sqlx-rt" -version = "0.5.5" +version = "0.5.9" dependencies = [ "actix-rt", "async-native-tls", @@ -2704,9 +2792,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" [[package]] name = "syn" -version = "1.0.72" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e8cdbefb79a9a5a65e0db8b47b723ee907b7c7f8496c76a1770b5c310bab82" +checksum = "1873d832550d4588c3dbc20f01361ab00bfe741048f71e3fecf145a7cc18b29c" dependencies = [ "proc-macro2", "quote", @@ -3039,6 +3127,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "936e4b492acfd135421d8dca4b1aa80a7bfc26e702ef3af710e0752684df5372" + [[package]] name = "uuid" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index 7336a90db7..a27f8848e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,12 +13,13 @@ members = [ "examples/postgres/listen", "examples/postgres/todos", "examples/postgres/mockable-todos", + "examples/postgres/transaction", "examples/sqlite/todos", ] [package] name = "sqlx" -version = "0.5.5" +version = "0.5.9" license = "MIT OR Apache-2.0" readme = "README.md" repository = "https://github.com/launchbadge/sqlx" @@ -60,6 +61,7 @@ all-types = [ "time", "chrono", "ipnetwork", + "mac_address", "uuid", "bit-vec", "bstr", @@ -122,6 +124,7 @@ bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] decimal = ["sqlx-core/decimal", "sqlx-macros/decimal"] chrono = ["sqlx-core/chrono", "sqlx-macros/chrono"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros/ipnetwork"] +mac_address = ["sqlx-core/mac_address", "sqlx-macros/mac_address"] uuid = ["sqlx-core/uuid", "sqlx-macros/uuid"] json = ["sqlx-core/json", "sqlx-macros/json"] time = ["sqlx-core/time", "sqlx-macros/time"] @@ -130,8 +133,8 @@ bstr = ["sqlx-core/bstr"] git2 = ["sqlx-core/git2"] [dependencies] -sqlx-core = { version = "0.5.5", path = "sqlx-core", default-features = false } -sqlx-macros = { version = "0.5.5", path = "sqlx-macros", default-features = false, optional = true } +sqlx-core = { version = "0.5.9", path = "sqlx-core", default-features = false } +sqlx-macros = { version = "0.5.9", path = "sqlx-macros", default-features = false, optional = true } [dev-dependencies] anyhow = "1.0.31" diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 0000000000..82ec4f4dde --- /dev/null +++ b/FAQ.md @@ -0,0 +1,200 @@ +SQLx Frequently Asked Questions +=============================== + +---------------------------------------------------------------- +### How can I do a `SELECT ... WHERE foo IN (...)` query? 
+
+In 0.6 SQLx will support binding arrays as a comma-separated list for every database,
+but unfortunately there's no general solution for that currently in SQLx itself.
+You would need to manually generate the query, at which point it
+cannot be used with the macros.
+
+However, **in Postgres** you can work around this limitation by binding the arrays directly and using `= ANY()`:
+
+```rust
+let db: PgPool = /* ... */;
+let foo_ids: Vec<i64> = vec![/* ... */];
+
+let foos = sqlx::query!(
+    "SELECT * FROM foo WHERE id = ANY($1)",
+    // a bug in the parameter typechecking code requires all array parameters to be slices
+    &foo_ids[..]
+)
+.fetch_all(&db)
+.await?;
+```
+
+Even when SQLx gains generic placeholder expansion for arrays, this will still be the optimal way to do it for Postgres,
+as comma-expansion means each possible length of the array generates a different query
+(and represents a combinatorial explosion if more than one array is used).
+
+Note that you can use any operator that returns a boolean, but beware that `!= ANY($1)` is **not equivalent** to `NOT IN (...)` as it effectively works like this:
+
+`lhs != ANY(rhs) -> false OR lhs != rhs[0] OR lhs != rhs[1] OR ... lhs != rhs[length(rhs) - 1]`
+
+The equivalent of `NOT IN (...)` would be `!= ALL($1)`:
+
+`lhs != ALL(rhs) -> true AND lhs != rhs[0] AND lhs != rhs[1] AND ... lhs != rhs[length(rhs) - 1]`
+
+Note that `ANY` with any operator, when passed an empty array, will return `false`, hence the leading `false OR ...`.
+Meanwhile, `ALL` with any operator, when passed an empty array, will return `true`, hence the leading `true AND ...`.
+
+See also: [Postgres Manual, Section 9.24: Row and Array Comparisons](https://www.postgresql.org/docs/current/functions-comparisons.html)
+
+-----
+### How can I bind an array to a `VALUES()` clause? How can I do bulk inserts?
+
+Like the above, SQLx does not currently support this in the general case, but it will in 0.6.
+
+However, **Postgres** also has a feature to save the day here! You can pass an array to `UNNEST()` and
+it will treat it as a temporary table:
+
+```rust
+let foo_texts: Vec<String> = vec![/* ... */];
+
+sqlx::query!(
+    // because `UNNEST()` is a generic function, Postgres needs the cast on the parameter here
+    // in order to know what type to expect there when preparing the query
+    "INSERT INTO foo(text_column) SELECT * FROM UNNEST($1::text[])",
+    &foo_texts[..]
+)
+.execute(&db)
+.await?;
+```
+
+`UNNEST()` can also take more than one array, in which case it'll treat each array as a column in the temporary table:
+
+```rust
+// this solution currently requires each column to be its own vector
+// in 0.6 we're aiming to allow binding iterators directly as arrays
+// so you can take a vector of structs and bind iterators mapping to each field
+let foo_texts: Vec<String> = vec![/* ... */];
+let foo_bools: Vec<bool> = vec![/* ... */];
+let foo_ints: Vec<i64> = vec![/* ... */];
+
+sqlx::query!(
+    "
+    INSERT INTO foo(text_column, bool_column, int_column)
+    SELECT * FROM UNNEST($1::text[], $2::bool[], $3::int8[])
+    ",
+    &foo_texts[..],
+    &foo_bools[..],
+    &foo_ints[..]
+)
+.execute(&db)
+.await?;
+```
+
+Again, even with comma-expanded lists in 0.6 this will likely still be the most performant way to run bulk inserts
+with Postgres--at least until we get around to implementing an interface for `COPY FROM STDIN`, though
+this solution with `UNNEST()` will still be more flexible as you can use it in queries that are more complex
+than just inserting into a table.
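To illustrate that flexibility, here is a sketch of `UNNEST()` driving a bulk `UPDATE` instead of an `INSERT`. The `foo` table with `id` and `text_column` columns is an assumption for illustration; this block is editorial and not part of the FAQ diff itself:

```rust
// pair each id with its replacement text by zipping the two arrays with UNNEST()
let ids: Vec<i64> = vec![/* ... */];
let new_texts: Vec<String> = vec![/* ... */];

sqlx::query!(
    "
    UPDATE foo
    SET text_column = updates.text
    FROM UNNEST($1::int8[], $2::text[]) AS updates(id, text)
    WHERE foo.id = updates.id
    ",
    &ids[..],
    &new_texts[..]
)
.execute(&db)
.await?;
```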
+
+Note that if some vectors are shorter than others, `UNNEST` will fill the corresponding columns with `NULL`s
+to match the longest vector.
+
+For example, if `foo_texts` is length 5, `foo_bools` is length 4, and `foo_ints` is length 3, the resulting table will
+look like this:
+
+| Row # | `text_column`  | `bool_column`  | `int_column`  |
+| ----- | -------------- | -------------- | ------------- |
+| 1     | `foo_texts[0]` | `foo_bools[0]` | `foo_ints[0]` |
+| 2     | `foo_texts[1]` | `foo_bools[1]` | `foo_ints[1]` |
+| 3     | `foo_texts[2]` | `foo_bools[2]` | `foo_ints[2]` |
+| 4     | `foo_texts[3]` | `foo_bools[3]` | `NULL`        |
+| 5     | `foo_texts[4]` | `NULL`         | `NULL`        |
+
+See Also:
+* [Postgres Manual, Section 7.2.1.4: Table Functions](https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-TABLEFUNCTIONS)
+* [Postgres Manual, Section 9.19: Array Functions and Operators](https://www.postgresql.org/docs/current/functions-array.html)
+
+----
+### How do I compile with the macros without needing a database, e.g. in CI?
+
+The macros support an offline mode which saves data for existing queries to a JSON file,
+so the macros can just read that file instead of talking to a database.
+
+See the following:
+
+* [the docs for `query!()`](https://docs.rs/sqlx/0.5.5/sqlx/macro.query.html#offline-mode-requires-the-offline-feature)
+* [the README for `sqlx-cli`](sqlx-cli/README.md#enable-building-in-offline-mode-with-query)
+
+To keep `sqlx-data.json` up-to-date you need to run `cargo sqlx prepare` before every commit that
+adds or changes a query; you can do this with a Git pre-commit hook (remember to make the hook executable):
+
+```shell
+$ echo "cargo sqlx prepare > /dev/null 2>&1; git add sqlx-data.json > /dev/null" > .git/hooks/pre-commit
+$ chmod +x .git/hooks/pre-commit
+```
+
+Note that this may make committing take some time as it'll cause your project to be recompiled, and
+as an ergonomic choice it does _not_ block committing if `cargo sqlx prepare` fails.
+
+We're working on a way for the macros to save their data to the filesystem automatically, which should be part of SQLx 0.6,
+so your pre-commit hook would then only need to stage the changed files.
+
+----
+
+### How do the query macros work under the hood?
+
+The macros work by talking to the database at compile time. When an SQL client asks to create a prepared statement
+from a query string, the response from the server typically includes information about the following:
+
+* the number of bind parameters, and their expected types if the database is capable of inferring that
+* the number, names and types of result columns, as well as the original table and column names before aliasing
+
+In MySQL/MariaDB, we also get a boolean flag signaling if a column is `NOT NULL`; however,
+in Postgres and SQLite we need to do a bit more work to determine whether a column can be `NULL` or not.
+
+After preparing, the Postgres driver will first look up the result columns in their source table and check if they have
+a `NOT NULL` constraint. Then, it will execute `EXPLAIN (VERBOSE, FORMAT JSON) <your query>` to determine which columns
+come from half-open joins (LEFT and RIGHT joins), which can make a normally `NOT NULL` column nullable. Since the
+`EXPLAIN VERBOSE` format is not stable or completely documented, this inference isn't perfect. However, it does err on
+the side of producing false-positives (marking a column nullable when it's actually `NOT NULL`) to avoid errors at runtime.
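To make the half-open-join case concrete, here is a hypothetical illustration (the `foo` and `bar` tables are assumptions for this sketch, not from the FAQ):

```rust
// `bar.value` is declared NOT NULL, but the LEFT JOIN can produce rows in which
// every column of `bar` is NULL, so the macro types `value` as an Option instead.
let rows = sqlx::query!(
    "SELECT foo.id, bar.value
     FROM foo
     LEFT JOIN bar ON bar.foo_id = foo.id"
)
.fetch_all(&pool)
.await?;
// `rows[0].id` stays non-optional (foo is the driving side of the join),
// while `rows[0].value` becomes an Option because of the LEFT JOIN
```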
+
+If you do encounter false-positives, please feel free to open an issue; make sure to include your query, any relevant
+schema, as well as the output of `EXPLAIN (VERBOSE, FORMAT JSON) <your query>`, which will make this easier to debug.
+
+The SQLite driver will pull the bytecode of the prepared statement and step through it to find any instructions
+that produce a null value for any column in the output.
+
+---
+### Why can't SQLx just look at my database schema/migrations and parse the SQL itself?
+
+Take a moment and think of the effort that would be required to do that.
+
+To implement this for a single database driver, SQLx would need to:
+
+* know how to parse SQL, and not just standard SQL but the specific dialect of that particular database
+* know how to analyze and typecheck SQL queries in the context of the original schema
+* if inferring schema from migrations it would need to simulate all the schema-changing effects of those migrations
+
+This is effectively reimplementing a good chunk of the database server's frontend,
+
+_and_ maintaining and ensuring correctness of that reimplementation,
+
+including bugs and idiosyncrasies,
+
+for the foreseeable future,
+
+for _every_ database we intend to support.
+
+Even Sisyphus would pity us.
+
+----
+
+### Why does my project using sqlx query macros not build on docs.rs?
+
+Docs.rs doesn't have access to your database, so it needs to be provided a `sqlx-data.json` file and be instructed to
+set the `SQLX_OFFLINE` environment variable to true while compiling your project. Luckily for us, docs.rs creates a
+`DOCS_RS` environment variable that we can access in a custom build script to achieve this functionality.
+
+To do so, first, make sure that you have run `cargo sqlx prepare` to generate a `sqlx-data.json` file in your project.
+
+Next, create a file called `build.rs` in the root of your project directory (at the same level as `Cargo.toml`). Add the following code to it:
+```rs
+fn main() {
+    // When building in docs.rs, we want to set SQLX_OFFLINE mode to true
+    if std::env::var_os("DOCS_RS").is_some() {
+        println!("cargo:rustc-env=SQLX_OFFLINE=true");
+    }
+}
+```
diff --git a/README.md b/README.md
index 7b14e4574b..fe89e37a1d 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,12 @@
+<div align="center">
+
+Have a question? [Be sure to check the FAQ first!](FAQ.md)
+
+</div>
+
+ SQLx is an async, pure Rust SQL crate featuring compile-time checked queries without a DSL. - **Truly Asynchronous**. Built from the ground-up using async/await for maximum concurrency. @@ -127,7 +133,7 @@ sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls" ] } #### Cargo Feature Flags -- `runtime-async-std-native-tls` (on by default): Use the `async-std` runtime and `native-tls` TLS backend. +- `runtime-async-std-native-tls`: Use the `async-std` runtime and `native-tls` TLS backend. - `runtime-async-std-rustls`: Use the `async-std` runtime and `rustls` TLS backend. @@ -173,6 +179,11 @@ sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls" ] } - `tls`: Add support for TLS connections. +- `offline`: Enables building the macros in offline mode when a live database is not available (such as CI). + - Requires `sqlx-cli` installed to use. See [sqlx-cli/README.md][readme-offline]. + +[readme-offline]: sqlx-cli/README.md#enable-building-in-offline-mode-with-query + ## SQLx is not an ORM! SQLx supports **compile-time checked queries**. It does not, however, do this by providing a Rust @@ -193,12 +204,24 @@ of SQLx. ## Usage +See the `examples/` folder for more in-depth usage. + ### Quickstart ```toml [dependencies] -sqlx = { version = "0.4.1", features = [ "postgres" ] } -async-std = { version = "1.6", features = [ "attributes" ] } +# PICK ONE: +# Async-std: +sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls", "postgres" ] } +async-std = { version = "1", features = [ "attributes" ] } + +# Tokio: +sqlx = { version = "0.5", features = [ "runtime-tokio-native-tls" , "postgres" ] } +tokio = { version = "1", features = ["full"] } + +# Actix-web: +sqlx = { version = "0.5", features = [ "runtime-actix-native-tls" , "postgres" ] } +actix-web = "3" ``` ```rust @@ -208,6 +231,7 @@ use sqlx::postgres::PgPoolOptions; #[async_std::main] // or #[tokio::main] +// or #[actix_web::main] async fn main() -> Result<(), sqlx::Error> { // Create a connection pool // for MySQL, use MySqlPoolOptions::new() @@ -217,7 +241,7 @@ async fn main() -> Result<(), sqlx::Error> { .max_connections(5) .connect("postgres://postgres:password@localhost/test").await?; - // Make a simple query to return the given parameter + // Make a simple query to return the given parameter (use a question mark `?` instead of `$1` for MySQL) let row: (i64,) = sqlx::query_as("SELECT $1") .bind(150_i64) .fetch_one(&pool).await?; @@ -352,12 +376,14 @@ Differences from `query()`: queries against; the database does not have to contain any data but must be the same kind (MySQL, Postgres, etc.) and have the same schema as the database you will be connecting to at runtime. - For convenience, you can use a .env file to set DATABASE_URL so that you don't have to pass it every time: + For convenience, you can use [a `.env` file][dotenv] to set DATABASE_URL so that you don't have to pass it every time: ``` DATABASE_URL=mysql://localhost/my_database ``` +[dotenv]: https://github.com/dotenv-rs/dotenv#examples + The biggest downside to `query!()` is that the output type cannot be named (due to Rust not officially supporting anonymous records). To address that, there is a `query_as!()` macro that is mostly identical except that you can name the output type. 
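As a quick sketch of that last point, `query_as!()` takes the output type's name as its first argument (the `Todo` struct and `todos` table here are assumptions for illustration, not from the README):

```rust
#[derive(Debug)]
struct Todo {
    id: i64,
    description: String,
    done: bool,
}

// query_as!() type-checks the query at compile time and maps each row to `Todo`
let todos: Vec<Todo> = sqlx::query_as!(
    Todo,
    "SELECT id, description, done FROM todos WHERE done = $1",
    false
)
.fetch_all(&pool)
.await?;
```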
diff --git a/examples/postgres/transaction/Cargo.toml b/examples/postgres/transaction/Cargo.toml
new file mode 100644
index 0000000000..a51d4d37f8
--- /dev/null
+++ b/examples/postgres/transaction/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "sqlx-example-postgres-transaction"
+version = "0.1.0"
+edition = "2018"
+workspace = "../../../"
+
+[dependencies]
+async-std = { version = "1.8.0", features = [ "attributes", "unstable" ] }
+sqlx = { path = "../../../", features = [ "postgres", "tls", "runtime-async-std-native-tls" ] }
+futures = "0.3.1"
diff --git a/examples/postgres/transaction/README.md b/examples/postgres/transaction/README.md
new file mode 100644
index 0000000000..2cfc1907c3
--- /dev/null
+++ b/examples/postgres/transaction/README.md
@@ -0,0 +1,18 @@
+# Postgres Transaction Example
+
+A simple example demonstrating how to obtain and roll back a transaction with Postgres.
+
+## Usage
+
+Declare the database URL. This example does not include any reading or writing of data.
+
+```
+export DATABASE_URL="postgres://postgres@localhost/postgres"
+```
+
+Run.
+
+```
+cargo run
+```
+
diff --git a/examples/postgres/transaction/migrations/20200718111257_todos.sql b/examples/postgres/transaction/migrations/20200718111257_todos.sql
new file mode 100644
index 0000000000..6599f8c10a
--- /dev/null
+++ b/examples/postgres/transaction/migrations/20200718111257_todos.sql
@@ -0,0 +1,6 @@
+CREATE TABLE IF NOT EXISTS todos
+(
+    id BIGSERIAL PRIMARY KEY,
+    description TEXT NOT NULL,
+    done BOOLEAN NOT NULL DEFAULT FALSE
+);
diff --git a/examples/postgres/transaction/src/main.rs b/examples/postgres/transaction/src/main.rs
new file mode 100644
index 0000000000..50539fe2ac
--- /dev/null
+++ b/examples/postgres/transaction/src/main.rs
@@ -0,0 +1,37 @@
+use sqlx::query;
+
+#[async_std::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let conn_str =
+        std::env::var("DATABASE_URL").expect("Env var DATABASE_URL is required for this example.");
+    let pool = sqlx::PgPool::connect(&conn_str).await?;
+
+    let mut transaction = pool.begin().await?;
+
+    let test_id = 1;
+    query!(
+        r#"INSERT INTO todos (id, description)
+        VALUES ( $1, $2 )
+        "#,
+        test_id,
+        "test todo"
+    )
+    .execute(&mut transaction)
+    .await?;
+
+    // check that the inserted todo can be fetched while the transaction is still open
+    let _ = query!(r#"SELECT FROM todos WHERE id = $1"#, test_id)
+        .fetch_one(&mut transaction)
+        .await?;
+
+    transaction.rollback().await?;
+
+    // check that the inserted todo is now gone
+    let inserted_todo = query!(r#"SELECT FROM todos WHERE id = $1"#, test_id)
+        .fetch_one(&pool)
+        .await;
+
+    assert!(inserted_todo.is_err());
+
+    Ok(())
+}
diff --git a/prep-release.sh b/prep-release.sh
new file mode 100755
index 0000000000..79c38330f1
--- /dev/null
+++ b/prep-release.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env sh
+set -ex
+
+VERSION=$1
+
+if [ -z "$VERSION" ]
+then
+    echo "USAGE: ./prep-release.sh <version>"
+    exit 1
+fi
+
+cargo set-version -p sqlx-rt "$VERSION"
+cargo set-version -p sqlx-core "$VERSION"
+cargo set-version -p sqlx-macros "$VERSION"
+cargo set-version -p sqlx "$VERSION"
+cargo set-version -p sqlx-cli "$VERSION"
\ No newline at end of file
diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml
index 0c0d92a0bb..a31483a872 100644
--- a/sqlx-cli/Cargo.toml
+++ b/sqlx-cli/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sqlx-cli"
-version = "0.5.5"
+version = "0.5.9"
 description = "Command-line utility for SQLx, the Rust SQL toolkit."
 edition = "2018"
 readme = "README.md"

@@ -27,20 +27,24 @@ path = "src/bin/cargo-sqlx.rs"
 [dependencies]
 dotenv = "0.15"
 tokio = { version = "1.0.1", features = ["macros", "rt", "rt-multi-thread"] }
-sqlx = { version = "0.5.5", path = "..", default-features = false, features = [
+sqlx = { version = "0.5.9", path = "..", default-features = false, features = [
     "runtime-async-std-native-tls",
     "migrate",
     "any",
     "offline",
 ] }
 futures = "0.3"
+# FIXME: we need to pin both of these versions until Clap 3.0 proper is released, then we can drop `clap_derive`
+# https://github.com/launchbadge/sqlx/issues/1378
+# https://github.com/clap-rs/clap/issues/2705
 clap = "=3.0.0-beta.2"
+clap_derive = "=3.0.0-beta.2"
 chrono = "0.4"
 anyhow = "1.0"
 url = { version = "2.1.1", default-features = false }
 async-trait = "0.1.30"
 console = "0.14.1"
-dialoguer = "0.8.0"
+promptly = "0.3.0"
 serde_json = "1.0.53"
 serde = { version = "1.0.110", features = ["derive"] }
 glob = "0.3.0"
diff --git a/sqlx-cli/README.md b/sqlx-cli/README.md
index f098dad8ab..2d64cf8d97 100644
--- a/sqlx-cli/README.md
+++ b/sqlx-cli/README.md
@@ -13,6 +13,9 @@ $ cargo install sqlx-cli

 # only for postgres
 $ cargo install sqlx-cli --no-default-features --features postgres
+
+# use vendored OpenSSL (build from source)
+$ cargo install sqlx-cli --features openssl-vendored
 ```

 ### Usage

@@ -49,6 +52,40 @@ $ sqlx migrate run

 Compares the migration history of the running database against the `migrations/` folder and runs
 any scripts that are still pending.

+#### Reverting Migrations
+
+If you would like to create _reversible_ migrations with corresponding "up" and "down" scripts, use the `-r` flag when creating new migrations:
+
+```bash
+$ sqlx migrate add -r <name>
+Creating migrations/20211001154420_<name>.up.sql
+Creating migrations/20211001154420_<name>.down.sql
+```
+
+After that, you can run these as above:
+
+```bash
+$ sqlx migrate run
+Applied migrations/20211001154420 (32.517835ms)
+```
+
+And reverts work as well:
+
+```bash
+$ sqlx migrate revert
+Applied 20211001154420/revert
+```
+
+**Note**: attempting to mix "simple" migrations with reversible migrations will result in an error.
+
+```bash
+$ sqlx migrate add <name>
+Creating migrations/20211001154420_<name>.sql
+
+$ sqlx migrate add -r <name>
+error: cannot mix reversible migrations with simple migrations. All migrations should be reversible or simple migrations
+```
+
 #### Enable building in "offline mode" with `query!()`

 Note: must be run as `cargo sqlx`.
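A minimal sketch of that offline workflow, using the commands described in this README (the project is assumed to enable SQLx's `offline` feature):

```bash
# with a live database available: cache type information for every query!() invocation
$ cargo sqlx prepare

# later, e.g. in CI without a database: build against the cached sqlx-data.json
$ SQLX_OFFLINE=true cargo build
```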
diff --git a/sqlx-cli/src/bin/cargo-sqlx.rs b/sqlx-cli/src/bin/cargo-sqlx.rs
index bfae83e995..a924af4244 100644
--- a/sqlx-cli/src/bin/cargo-sqlx.rs
+++ b/sqlx-cli/src/bin/cargo-sqlx.rs
@@ -1,5 +1,6 @@
 use clap::{crate_version, AppSettings, FromArgMatches, IntoApp};
 use console::style;
+use dotenv::dotenv;
 use sqlx_cli::Opt;
 use std::{env, process};

@@ -9,6 +10,7 @@ async fn main() {
     // so we want to notch out that superfluous "sqlx"
     let args = env::args_os().skip(2);

+    dotenv().ok();
     let matches = Opt::into_app()
         .version(crate_version!())
         .bin_name("cargo sqlx")
diff --git a/sqlx-cli/src/bin/sqlx.rs b/sqlx-cli/src/bin/sqlx.rs
index 0d18278577..e413581bb9 100644
--- a/sqlx-cli/src/bin/sqlx.rs
+++ b/sqlx-cli/src/bin/sqlx.rs
@@ -1,9 +1,11 @@
 use clap::{crate_version, FromArgMatches, IntoApp};
 use console::style;
+use dotenv::dotenv;
 use sqlx_cli::Opt;

 #[tokio::main]
 async fn main() {
+    dotenv().ok();
     let matches = Opt::into_app().version(crate_version!()).get_matches();

     // no special handling here
diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs
index 6babb21a36..7521b1fb68 100644
--- a/sqlx-cli/src/database.rs
+++ b/sqlx-cli/src/database.rs
@@ -1,6 +1,6 @@
 use crate::migrate;
 use console::style;
-use dialoguer::Confirm;
+use promptly::{prompt, ReadlineError};
 use sqlx::any::Any;
 use sqlx::migrate::MigrateDatabase;

@@ -13,16 +13,7 @@ pub async fn create(uri: &str) -> anyhow::Result<()> {
 }

 pub async fn drop(uri: &str, confirm: bool) -> anyhow::Result<()> {
-    if confirm
-        && !Confirm::new()
-            .with_prompt(format!(
-                "\nAre you sure you want to drop the database at {}?",
-                style(uri).cyan()
-            ))
-            .wait_for_newline(true)
-            .default(false)
-            .interact()?
-    {
+    if confirm && !ask_to_continue(uri) {
         return Ok(());
     }

@@ -42,3 +33,28 @@ pub async fn setup(migration_source: &str, uri: &str) -> anyhow::Result<()> {
     create(uri).await?;
     migrate::run(migration_source, uri, false, false).await
 }
+
+fn ask_to_continue(uri: &str) -> bool {
+    loop {
+        let r: Result<String, ReadlineError> =
+            prompt(format!("Drop database at {}? (y/n)", style(uri).cyan()));
+        match r {
+            Ok(response) => {
+                if response == "n" || response == "N" {
+                    return false;
+                } else if response == "y" || response == "Y" {
+                    return true;
+                } else {
+                    println!(
+                        "Response not recognized: {}\nPlease type 'y' or 'n' and press enter.",
+                        response
+                    );
+                }
+            }
+            Err(e) => {
+                println!("{}", e);
+                return false;
+            }
+        }
+    }
+}
diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs
index 5dd4aeefc6..d02f4fa3b4 100644
--- a/sqlx-cli/src/lib.rs
+++ b/sqlx-cli/src/lib.rs
@@ -1,7 +1,6 @@
+use anyhow::Result;
+
 use crate::opt::{Command, DatabaseCommand, MigrateCommand};
-use anyhow::anyhow;
-use dotenv::dotenv;
-use std::env;

 mod database;
 // mod migration;
@@ -12,15 +11,7 @@ mod prepare;

 pub use crate::opt::Opt;

-pub async fn run(opt: Opt) -> anyhow::Result<()> {
-    dotenv().ok();
-
-    let database_url = match opt.database_url {
-        Some(db_url) => db_url,
-        None => env::var("DATABASE_URL")
-            .map_err(|_| anyhow!("The DATABASE_URL environment variable must be set"))?,
-    };
-
+pub async fn run(opt: Opt) -> Result<()> {
     match opt.command {
         Command::Migrate(migrate) => match migrate.command {
             MigrateCommand::Add {
@@ -30,33 +21,47 @@ pub async fn run(opt: Opt) -> anyhow::Result<()> {
             MigrateCommand::Run {
                 dry_run,
                 ignore_missing,
+                database_url,
             } => migrate::run(&migrate.source, &database_url, dry_run, ignore_missing).await?,
             MigrateCommand::Revert {
                 dry_run,
                 ignore_missing,
+                database_url,
             } => migrate::revert(&migrate.source, &database_url, dry_run, ignore_missing).await?,
-            MigrateCommand::Info => migrate::info(&migrate.source, &database_url).await?,
+            MigrateCommand::Info { database_url } => {
+                migrate::info(&migrate.source, &database_url).await?
+            }
+            MigrateCommand::BuildScript { force } => migrate::build_script(&migrate.source, force)?,
         },

         Command::Database(database) => match database.command {
-            DatabaseCommand::Create => database::create(&database_url).await?,
-            DatabaseCommand::Drop { yes } => database::drop(&database_url, !yes).await?,
-            DatabaseCommand::Reset { yes, source } => {
-                database::reset(&source, &database_url, !yes).await?
+            DatabaseCommand::Create { database_url } => database::create(&database_url).await?,
+            DatabaseCommand::Drop { yes, database_url } => {
+                database::drop(&database_url, !yes).await?
+            }
-            DatabaseCommand::Setup { source } => database::setup(&source, &database_url).await?,
+            DatabaseCommand::Reset {
+                yes,
+                source,
+                database_url,
+            } => database::reset(&source, &database_url, !yes).await?,
+            DatabaseCommand::Setup {
+                source,
+                database_url,
+            } => database::setup(&source, &database_url).await?,
         },

         Command::Prepare {
             check: false,
             merged,
             args,
+            database_url,
         } => prepare::run(&database_url, merged, args)?,

         Command::Prepare {
             check: true,
             merged,
             args,
+            database_url,
         } => prepare::check(&database_url, merged, args)?,
     };
diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs
index 20d61f1985..523cf83fa4 100644
--- a/sqlx-cli/src/migrate.rs
+++ b/sqlx-cli/src/migrate.rs
@@ -42,6 +42,11 @@ pub async fn add(
 ) -> anyhow::Result<()> {
     fs::create_dir_all(migration_source).context("Unable to create migrations directory")?;

+    // check whether the migrations directory already contains any migrations
+    // (used below to decide whether to print the first-migration hint)
+    let has_existing_migrations = fs::read_dir(migration_source)
+        .map(|mut dir| dir.next().is_some())
+        .unwrap_or(false);
+
     let migrator = Migrator::new(Path::new(migration_source)).await?;
     // This checks if all existing migrations are of the same type as the reversible flag passed
     for migration in migrator.iter() {
@@ -74,6 +79,31 @@
         )?;
     }

+    if !has_existing_migrations {
+        let quoted_source = if migration_source != "migrations" {
+            format!("{:?}", migration_source)
+        } else {
+            "".to_string()
+        };
+
+        print!(
+            r#"
+Congratulations on creating your first migration!
+
+Did you know you can embed your migrations in your application binary?
+On startup, after creating your database connection or pool, add:
+
+sqlx::migrate!({}).run(<&your_pool OR &mut your_connection>).await?;
+
+Note that the compiler won't pick up new migrations if no Rust source files have changed.
+You can create a Cargo build script to work around this with `sqlx migrate build-script`.
+
+See: https://docs.rs/sqlx/0.5/sqlx/macro.migrate.html
+"#,
+            quoted_source
+        );
+    }
+
     Ok(())
 }
@@ -245,3 +275,30 @@
     Ok(())
 }
+
+pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> {
+    anyhow::ensure!(
+        Path::new("Cargo.toml").exists(),
+        "must be run in a Cargo project root"
+    );
+
+    anyhow::ensure!(
+        (force || !Path::new("build.rs").exists()),
+        "build.rs already exists; use --force to overwrite"
+    );
+
+    let contents = format!(
+        r#"// generated by `sqlx migrate build-script`
+fn main() {{
+    // trigger recompilation when a new migration is added
+    println!("cargo:rerun-if-changed={}");
+}}"#,
+        migration_source
+    );
+
+    fs::write("build.rs", contents)?;
+
+    println!("Created `build.rs`; be sure to check it into version control!");
+
+    Ok(())
+}
diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs
index 8d912668bf..20243a5e91 100644
--- a/sqlx-cli/src/opt.rs
+++ b/sqlx-cli/src/opt.rs
@@ -4,9 +4,6 @@ use clap::Clap;
 pub struct Opt {
     #[clap(subcommand)]
     pub command: Command,
-
-    #[clap(short = 'D', long)]
-    pub database_url: Option<String>,
 }

 #[derive(Clap, Debug)]
@@ -36,6 +33,10 @@ pub enum Command {
         /// Arguments to be passed to `cargo rustc ...`.
         #[clap(last = true)]
         args: Vec<String>,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },

     #[clap(alias = "mig")]
@@ -52,7 +53,11 @@ pub struct DatabaseOpt {

 #[derive(Clap, Debug)]
 pub enum DatabaseCommand {
     /// Creates the database specified in your DATABASE_URL.
-    Create,
+    Create {
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
+    },

     /// Drops the database specified in your DATABASE_URL.
     Drop {
         /// Automatic confirmation. Without this option, you will be prompted before dropping
         /// your database.
         #[clap(short)]
         yes: bool,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },

     /// Drops the database specified in your DATABASE_URL, re-creates it, and runs any pending migrations.
@@ -72,6 +81,10 @@
         /// Path to folder containing migrations.
         #[clap(long, default_value = "migrations")]
         source: String,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },

     /// Creates the database specified in your DATABASE_URL and runs any pending migrations.
@@ -79,6 +92,10 @@
         /// Path to folder containing migrations.
         #[clap(long, default_value = "migrations")]
         source: String,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },
 }

@@ -115,6 +132,10 @@
         /// Ignore applied migrations that are missing in the resolved migrations
         #[clap(long)]
         ignore_missing: bool,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },

     /// Revert the latest migration with a down file.
@@ -126,8 +147,25 @@
         /// Ignore applied migrations that are missing in the resolved migrations
         #[clap(long)]
         ignore_missing: bool,
+
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, short = 'D', env)]
+        database_url: String,
     },

     /// List all available migrations.
-    Info,
+    Info {
+        /// Location of the DB, by default will be read from the DATABASE_URL env var
+        #[clap(long, env)]
+        database_url: String,
+    },
+
+    /// Generate a `build.rs` to trigger recompilation when a new migration is added.
+    ///
+    /// Must be run in a Cargo project root.
+    BuildScript {
+        /// Overwrite the build script if it already exists.
+        #[clap(long)]
+        force: bool,
+    },
 }
diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index c9c6e5b46a..a6bd78f37d 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sqlx-core"
-version = "0.5.5"
+version = "0.5.9"
 repository = "https://github.com/launchbadge/sqlx"
 description = "Core of SQLx, the Rust SQL toolkit. Not intended to be used directly."
license = "MIT OR Apache-2.0" @@ -54,6 +54,7 @@ all-types = [ "bigdecimal", "decimal", "ipnetwork", + "mac_address", "json", "uuid", "bit-vec", @@ -100,7 +101,7 @@ offline = ["serde", "either/serde"] [dependencies] ahash = "0.7.2" atoi = "0.4.0" -sqlx-rt = { path = "../sqlx-rt", version = "0.5.5" } +sqlx-rt = { path = "../sqlx-rt", version = "0.5.9"} base64 = { version = "0.13.0", default-features = false, optional = true, features = ["std"] } bigdecimal_ = { version = "0.2.0", optional = true, package = "bigdecimal" } rust_decimal = { version = "1.8.1", optional = true } @@ -109,7 +110,7 @@ bitflags = { version = "1.2.1", default-features = false } bytes = "1.0.0" byteorder = { version = "1.3.4", default-features = false, features = ["std"] } chrono = { version = "0.4.11", default-features = false, features = ["clock"], optional = true } -crc = { version = "1.8.1", optional = true } +crc = { version = "2.0.0", optional = true } crossbeam-queue = "0.3.1" crossbeam-channel = "0.5.0" crossbeam-utils = { version = "0.8.1", default-features = false } @@ -119,14 +120,16 @@ encoding_rs = { version = "0.8.23", optional = true } either = "1.5.3" futures-channel = { version = "0.3.5", default-features = false, features = ["sink", "alloc", "std"] } futures-core = { version = "0.3.5", default-features = false } -futures-util = { version = "0.3.5", features = ["sink"] } +futures-intrusive = "0.4.0" +futures-util = { version = "0.3.5", default-features = false, features = ["alloc", "sink"] } generic-array = { version = "0.14.4", default-features = false, optional = true } hex = "0.4.2" -hmac = { version = "0.10.1", default-features = false, optional = true } +hmac = { version = "0.11.0", default-features = false, optional = true } itoa = "0.4.5" ipnetwork = { version = "0.17.0", default-features = false, optional = true } +mac_address = { version = "1.1", default-features = false, optional = true } libc = "0.2.71" -libsqlite3-sys = { version = "0.22.0", optional = true, default-features = false, features = [ +libsqlite3-sys = { version = "0.23.1", optional = true, default-features = false, features = [ "pkg-config", "vcpkg", "bundled", @@ -158,8 +161,9 @@ webpki-roots = { version = "0.21.0", optional = true } whoami = "1.0.1" stringprep = "0.1.2" bstr = { version = "0.2.14", default-features = false, features = ["std"], optional = true } -git2 = { version = "0.13.12", default-features = false, optional = true } +git2 = { version = "0.13.20", default-features = false, optional = true } hashlink = "0.7.0" +indexmap = "1.6.2" [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.2", features = ["js"] } diff --git a/sqlx-core/src/any/kind.rs b/sqlx-core/src/any/kind.rs index 8d5454ed45..b3278a9650 100644 --- a/sqlx-core/src/any/kind.rs +++ b/sqlx-core/src/any/kind.rs @@ -1,7 +1,7 @@ use crate::error::Error; use std::str::FromStr; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum AnyKind { #[cfg(feature = "postgres")] Postgres, diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index 1825ff939b..04fa659b74 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -223,7 +223,10 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.revert(migration), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(_conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => { + let _ = migration; + unimplemented!() + } } } } diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index a5f794820d..9dd7fea9f8 
100644 --- a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -1,5 +1,7 @@ //! Generic database driver with the specific driver selected at runtime. +use crate::executor::Executor; + #[macro_use] mod decode; @@ -45,6 +47,10 @@ pub type AnyPool = crate::pool::Pool; pub type AnyPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Any>`][Executor]. +pub trait AnyExecutor<'c>: Executor<'c, Database = Any> {} +impl<'c, T: Executor<'c, Database = Any>> AnyExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(AnyArguments<'q>); impl_executor_for_pool_connection!(Any, AnyConnection, AnyRow); diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 49b05699bf..245ffff178 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -36,7 +36,7 @@ pub enum Error { /// Error returned from the database. #[error("error returned from database: {0}")] - Database(Box), + Database(#[source] Box), /// Error communicating with the database backend. #[error("error communicating with the server: {0}")] @@ -105,6 +105,8 @@ pub enum Error { Migrate(#[source] Box), } +impl StdError for Box {} + impl Error { pub fn into_database_error(self) -> Option> { match self { diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index 3200a2f05d..1f24732da2 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -93,9 +93,9 @@ macro_rules! try_stream { ($($block:tt)*) => { crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move { macro_rules! r#yield { - ($v:expr) => { + ($v:expr) => {{ let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await; - } + }} } $($block)* diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 6b5b55a4ae..8f376cbfb0 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -15,7 +15,7 @@ pub struct BufStream where S: AsyncRead + AsyncWrite + Unpin, { - stream: S, + pub(crate) stream: S, // writes with `write` to the underlying stream are buffered // this can be flushed with `flush` diff --git a/sqlx-core/src/mssql/mod.rs b/sqlx-core/src/mssql/mod.rs index 068e77a750..ed3b325871 100644 --- a/sqlx-core/src/mssql/mod.rs +++ b/sqlx-core/src/mssql/mod.rs @@ -1,5 +1,7 @@ //! Microsoft SQL (MSSQL) database driver. +use crate::executor::Executor; + mod arguments; mod column; mod connection; @@ -32,6 +34,10 @@ pub use value::{MssqlValue, MssqlValueRef}; /// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL. pub type MssqlPool = crate::pool::Pool; +/// An alias for [`Executor<'_, Database = Mssql>`][Executor]. 
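+///
+/// A minimal usage sketch (the function and table name are illustrative, not part of this change):
+///
+/// ```rust,no_run
+/// # use sqlx::mssql::MssqlExecutor;
+/// async fn user_count<'c>(executor: impl MssqlExecutor<'c>) -> Result<i32, sqlx::Error> {
+///     // accepts `&mut MssqlConnection`, `&MssqlPool`, `&mut Transaction<'_, Mssql>`, ...
+///     sqlx::query_scalar("SELECT COUNT(*) FROM users").fetch_one(executor).await
+/// }
+/// ```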
+pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {} +impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(MssqlArguments); impl_executor_for_pool_connection!(Mssql, MssqlConnection, MssqlRow); diff --git a/sqlx-core/src/mssql/types/str.rs b/sqlx-core/src/mssql/types/str.rs index 4902d783be..048dd84cd3 100644 --- a/sqlx-core/src/mssql/types/str.rs +++ b/sqlx-core/src/mssql/types/str.rs @@ -5,6 +5,7 @@ use crate::mssql::io::MssqlBufMutExt; use crate::mssql::protocol::type_info::{Collation, CollationFlags, DataType, TypeInfo}; use crate::mssql::{Mssql, MssqlTypeInfo, MssqlValueRef}; use crate::types::Type; +use std::borrow::Cow; impl Type for str { fn type_info() -> MssqlTypeInfo { @@ -81,3 +82,33 @@ impl Decode<'_, Mssql> for String { .into_owned()) } } + +impl Encode<'_, Mssql> for Cow<'_, str> { + fn produces(&self) -> Option { + match self { + Cow::Borrowed(str) => <&str as Encode>::produces(str), + Cow::Owned(str) => <&str as Encode>::produces(&(str.as_ref())), + } + } + + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode_by_ref(str, buf), + Cow::Owned(str) => <&str as Encode>::encode_by_ref(&(str.as_ref()), buf), + } + } +} + +impl<'r> Decode<'r, Mssql> for Cow<'r, str> { + fn decode(value: MssqlValueRef<'r>) -> Result { + Ok(Cow::Owned( + value + .type_info + .0 + .encoding()? + .decode_without_bom_handling(value.as_bytes()?) + .0 + .into_owned(), + )) + } +} diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 012714d710..9cb0690bda 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::logger::QueryLogger; -use crate::mysql::connection::stream::Busy; +use crate::mysql::connection::stream::Waiting; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::Status; use crate::mysql::protocol::statement::{ @@ -93,7 +93,7 @@ impl MySqlConnection { let mut logger = QueryLogger::new(sql, self.log_settings.clone()); self.stream.wait_until_ready().await?; - self.stream.busy = Busy::Result; + self.stream.waiting.push_back(Waiting::Result); Ok(Box::pin(try_stream! 
{ // make a slot for the shared column data @@ -146,12 +146,12 @@ impl MySqlConnection { continue; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } // otherwise, this first packet is the start of the result-set metadata, - self.stream.busy = Busy::Row; + *self.stream.waiting.front_mut().unwrap() = Waiting::Row; let num_columns = packet.get_uint_lenenc() as usize; // column count @@ -179,11 +179,11 @@ impl MySqlConnection { if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one - self.stream.busy = Busy::Result; + *self.stream.waiting.front_mut().unwrap() = Waiting::Result; break; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index 509426a63d..4ade06beeb 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -16,7 +16,7 @@ mod executor; mod stream; mod tls; -pub(crate) use stream::{Busy, MySqlStream}; +pub(crate) use stream::{MySqlStream, Waiting}; const MAX_PACKET_SIZE: u32 = 1024; diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index 8b2f453608..e43cf253c6 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes}; @@ -16,15 +17,13 @@ pub struct MySqlStream { pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, - pub(crate) busy: Busy, + pub(crate) waiting: VecDeque, pub(crate) charset: CharSet, pub(crate) collation: Collation, } #[derive(Debug, PartialEq, Eq)] -pub(crate) enum Busy { - NotBusy, - +pub(crate) enum Waiting { // waiting for a result set Result, @@ -65,7 +64,7 @@ impl MySqlStream { } Ok(Self { - busy: Busy::NotBusy, + waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, @@ -80,32 +79,32 @@ impl MySqlStream { self.stream.flush().await?; } - while self.busy != Busy::NotBusy { - while self.busy == Busy::Row { + while !self.waiting.is_empty() { + while self.waiting.front() == Some(&Waiting::Row) { let packet = self.recv_packet().await?; if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.capabilities)?; - self.busy = if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - Busy::Result + if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + *self.waiting.front_mut().unwrap() = Waiting::Result; } else { - Busy::NotBusy + self.waiting.pop_front(); }; } } - while self.busy == Busy::Result { + while self.waiting.front() == Some(&Waiting::Result) { let packet = self.recv_packet().await?; if packet[0] == 0x00 || packet[0] == 0xff { let ok = packet.ok()?; if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); } } else { - self.busy = Busy::Row; + *self.waiting.front_mut().unwrap() = Waiting::Row; self.skip_result_metadata(packet).await?; } } @@ -150,7 +149,7 @@ impl MySqlStream { // TODO: packet joining if payload[0] == 0xff { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); // instead of letting this packet be looked at everywhere, we check here // and emit a proper Error diff --git a/sqlx-core/src/mysql/migrate.rs b/sqlx-core/src/mysql/migrate.rs index 248fd6298e..c3898e9fc2 100644 --- a/sqlx-core/src/mysql/migrate.rs +++ 
b/sqlx-core/src/mysql/migrate.rs @@ -8,7 +8,6 @@ use crate::mysql::{MySql, MySqlConnectOptions, MySqlConnection}; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; -use crc::crc32; use futures_core::future::BoxFuture; use std::str::FromStr; use std::time::Duration; @@ -266,9 +265,10 @@ async fn current_database(conn: &mut MySqlConnection) -> Result String { + const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll format!( "{:x}", - 0x3d32ad9e * (crc32::checksum_ieee(database_name.as_bytes()) as i64) + 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) ) } diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index dc7f969936..e108e8591f 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -1,5 +1,7 @@ //! **MySQL** database driver. +use crate::executor::Executor; + mod arguments; mod collation; mod column; @@ -39,6 +41,10 @@ pub type MySqlPool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MySQL. pub type MySqlPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = MySql>`][Executor]. +pub trait MySqlExecutor<'c>: Executor<'c, Database = MySql> {} +impl<'c, T: Executor<'c, Database = MySql>> MySqlExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(MySqlArguments); impl_executor_for_pool_connection!(MySql, MySqlConnection, MySqlRow); diff --git a/sqlx-core/src/mysql/transaction.rs b/sqlx-core/src/mysql/transaction.rs index b62fc143b5..97cb121d0e 100644 --- a/sqlx-core/src/mysql/transaction.rs +++ b/sqlx-core/src/mysql/transaction.rs @@ -2,7 +2,7 @@ use futures_core::future::BoxFuture; use crate::error::Error; use crate::executor::Executor; -use crate::mysql::connection::Busy; +use crate::mysql::connection::Waiting; use crate::mysql::protocol::text::Query; use crate::mysql::{MySql, MySqlConnection}; use crate::transaction::{ @@ -57,7 +57,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.transaction_depth; if depth > 0 { - conn.stream.busy = Busy::Result; + conn.stream.waiting.push_back(Waiting::Result); conn.stream.sequence_id = 0; conn.stream .write_packet(Query(&*rollback_ansi_transaction_sql(depth))); diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 5a261804bf..76e8b2985d 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use bytes::Buf; -use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; +use chrono::{DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -21,12 +21,14 @@ impl Type for DateTime { } } +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { Encode::::encode(&self.naive_utc(), buf) } } +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). 
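+///
+/// One way to make that assumption hold for every pooled connection is an `after_connect`
+/// hook; a sketch (the connection URL is illustrative):
+///
+/// ```rust,no_run
+/// # async fn example() -> Result<(), sqlx::Error> {
+/// use sqlx::{mysql::MySqlPoolOptions, Executor};
+///
+/// let pool = MySqlPoolOptions::new()
+///     .after_connect(|conn| Box::pin(async move {
+///         // pin the session to UTC so `DateTime<Utc>` round-trips losslessly
+///         conn.execute("SET time_zone = '+00:00';").await?;
+///         Ok(())
+///     }))
+///     .connect("mysql://root:password@localhost/sqlx")
+///     .await?;
+/// # Ok(())
+/// # }
+/// ```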
impl<'r> Decode<'r, MySql> for DateTime { fn decode(value: MySqlValueRef<'r>) -> Result { let naive: NaiveDateTime = Decode::::decode(value)?; @@ -35,6 +37,30 @@ impl<'r> Decode<'r, MySql> for DateTime { } } +impl Type for DateTime { + fn type_info() -> MySqlTypeInfo { + MySqlTypeInfo::binary(ColumnType::Timestamp) + } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } +} + +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). +impl Encode<'_, MySql> for DateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + Encode::::encode(&self.naive_utc(), buf) + } +} + +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). +impl<'r> Decode<'r, MySql> for DateTime { + fn decode(value: MySqlValueRef<'r>) -> Result { + Ok( as Decode<'r, MySql>>::decode(value)?.with_timezone(&Local)) + } +} + impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Time) diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index 19e3de62c9..076858901b 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -5,6 +5,7 @@ use crate::mysql::io::MySqlBufMutExt; use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; +use std::borrow::Cow; const COLLATE_UTF8_GENERAL_CI: u16 = 33; const COLLATE_UTF8_UNICODE_CI: u16 = 192; @@ -80,3 +81,18 @@ impl Decode<'_, MySql> for String { <&str as Decode>::decode(value).map(ToOwned::to_owned) } } + +impl Encode<'_, MySql> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + +impl<'r> Decode<'r, MySql> for Cow<'r, str> { + fn decode(value: MySqlValueRef<'r>) -> Result { + value.as_str().map(Cow::Borrowed) + } +} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 732c1a8c92..88864566c1 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,13 +1,17 @@ -use super::inner::{DecrementSizeGuard, SharedPool}; -use crate::connection::Connection; -use crate::database::Database; -use crate::error::Error; -use sqlx_rt::spawn; use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Instant; +use futures_intrusive::sync::SemaphoreReleaser; + +use crate::connection::Connection; +use crate::database::Database; +use crate::error::Error; + +use super::inner::{DecrementSizeGuard, SharedPool}; +use std::future::Future; + /// A connection managed by a [`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop. @@ -28,8 +32,8 @@ pub(super) struct Idle { /// RAII wrapper for connections being handled by functions that may drop them pub(super) struct Floating<'p, C> { - inner: C, - guard: DecrementSizeGuard<'p>, + pub(super) inner: C, + pub(super) guard: DecrementSizeGuard<'p>, } const DEREF_ERR: &str = "(bug) connection already released to pool"; @@ -57,43 +61,85 @@ impl DerefMut for PoolConnection { impl PoolConnection { /// Explicitly release a connection from the pool - pub fn release(mut self) -> DB::Connection { + #[deprecated = "renamed to `.detach()` for clarity"] + pub fn release(self) -> DB::Connection { + self.detach() + } + + /// Detach this connection from the pool, allowing it to open a replacement. 
+ /// + /// Note that if your application uses a single shared pool, this + /// effectively lets the application exceed the `max_connections` setting. + /// + /// If you want the pool to treat this connection as permanently checked-out, + /// use [`.leak()`][Self::leak] instead. + pub fn detach(mut self) -> DB::Connection { self.live .take() .expect("PoolConnection double-dropped") .float(&self.pool) .detach() } + + /// Detach this connection from the pool, treating it as permanently checked-out. + /// + /// This effectively will reduce the maximum capacity of the pool by 1 every time it is used. + /// + /// If you don't want to impact the pool's capacity, use [`.detach()`][Self::detach] instead. + pub fn leak(mut self) -> DB::Connection { + self.live.take().expect("PoolConnection double-dropped").raw + } + + /// Test the connection to make sure it is still live before returning it to the pool. + /// + /// This effectively runs the drop handler eagerly instead of spawning a task to do it. + pub(crate) fn return_to_pool(&mut self) -> impl Future + Send + 'static { + // we want these to happen synchronously so the drop handler doesn't try to spawn a task anyway + // this also makes the returned future `'static` + let live = self.live.take(); + let pool = self.pool.clone(); + + async move { + let mut floating = if let Some(live) = live { + live.float(&pool) + } else { + return; + }; + + // test the connection on-release to ensure it is still viable + // if an Executor future/stream is dropped during an `.await` call, the connection + // is likely to be left in an inconsistent state, in which case it should not be + // returned to the pool; also of course, if it was dropped due to an error + // this is simply a band-aid as SQLx-next (0.6) connections should be able + // to recover from cancellations + if let Err(e) = floating.raw.ping().await { + log::warn!( + "error occurred while testing the connection on-release: {}", + e + ); + + // we now consider the connection to be broken; just drop it to close + // trying to close gracefully might cause something weird to happen + drop(floating); + } else { + // if the connection is still viable, release it to the pool + pool.release(floating); + } + } + } } /// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from. 
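+///
+/// A sketch contrasting the ways a checked-out connection can go out of scope
+/// (assumes a pool named `pool`; see the method docs above):
+///
+/// ```rust,no_run
+/// # async fn example(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+/// let conn = pool.acquire().await?;
+/// drop(conn); // tested and returned to the pool in a background task
+///
+/// let conn = pool.acquire().await?;
+/// let _raw = conn.detach(); // leaves the pool; a replacement may be opened
+///
+/// let conn = pool.acquire().await?;
+/// let _raw = conn.leak(); // permanently counts against `max_connections`
+/// # Ok(())
+/// # }
+/// ```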
impl Drop for PoolConnection { fn drop(&mut self) { - if let Some(live) = self.live.take() { - let pool = self.pool.clone(); - spawn(async move { - let mut floating = live.float(&pool); - - // test the connection on-release to ensure it is still viable - // if an Executor future/stream is dropped during an `.await` call, the connection - // is likely to be left in an inconsistent state, in which case it should not be - // returned to the pool; also of course, if it was dropped due to an error - // this is simply a band-aid as SQLx-next (0.6) connections should be able - // to recover from cancellations - if let Err(e) = floating.raw.ping().await { - log::warn!( - "error occurred while testing the connection on-release: {}", - e - ); - - // we now consider the connection to be broken; just drop it to close - // trying to close gracefully might cause something weird to happen - drop(floating); - } else { - // if the connection is still viable, release it to th epool - pool.release(floating); - } - }); + if self.live.is_some() { + #[cfg(not(feature = "_rt-async-std"))] + if let Ok(handle) = sqlx_rt::Handle::try_current() { + handle.spawn(self.return_to_pool()); + } + + #[cfg(feature = "_rt-async-std")] + sqlx_rt::spawn(self.return_to_pool()); } } } @@ -102,7 +148,8 @@ impl Live { pub fn float(self, pool: &SharedPool) -> Floating<'_, Self> { Floating { inner: self, - guard: DecrementSizeGuard::new(pool), + // create a new guard from a previously leaked permit + guard: DecrementSizeGuard::new_permit(pool), } } @@ -128,13 +175,6 @@ impl DerefMut for Idle { } } -impl<'s, C> Floating<'s, C> { - pub fn into_leakable(self) -> C { - self.guard.cancel(); - self.inner - } -} - impl<'s, DB: Database> Floating<'s, Live> { pub fn new_live(conn: DB::Connection, guard: DecrementSizeGuard<'s>) -> Self { Self { @@ -161,6 +201,11 @@ impl<'s, DB: Database> Floating<'s, Live> { } } + pub async fn close(self) -> Result<(), Error> { + // `guard` is dropped as intended + self.inner.raw.close().await + } + pub fn detach(self) -> DB::Connection { self.inner.raw } @@ -174,10 +219,14 @@ impl<'s, DB: Database> Floating<'s, Live> { } impl<'s, DB: Database> Floating<'s, Idle> { - pub fn from_idle(idle: Idle, pool: &'s SharedPool) -> Self { + pub fn from_idle( + idle: Idle, + pool: &'s SharedPool, + permit: SemaphoreReleaser<'s>, + ) -> Self { Self { inner: idle, - guard: DecrementSizeGuard::new(pool), + guard: DecrementSizeGuard::from_permit(pool, permit), } } @@ -192,9 +241,12 @@ impl<'s, DB: Database> Floating<'s, Idle> { } } - pub async fn close(self) -> Result<(), Error> { + pub async fn close(self) -> DecrementSizeGuard<'s> { // `guard` is dropped as intended - self.inner.live.raw.close().await + if let Err(e) = self.inner.live.raw.close().await { + log::debug!("error occurred while closing the pool connection: {}", e); + } + self.guard } } diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index f9e5df43b3..d67cdfc0f1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -4,23 +4,28 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; use crate::pool::{deadline_as_timeout, PoolOptions}; -use crossbeam_queue::{ArrayQueue, SegQueue}; -use futures_core::task::{Poll, Waker}; -use futures_util::future; +use crossbeam_queue::ArrayQueue; + +use futures_intrusive::sync::{Semaphore, SemaphoreReleaser}; + use std::cmp; use std::mem; use std::ptr; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{Arc, Weak}; -use 
std::task::Context; +use std::sync::Arc; + use std::time::{Duration, Instant}; -type Waiters = SegQueue>; +/// The number of permits to release to wake all waiters, such as on `SharedPool::close()`. +/// +/// This should be large enough to realistically wake all tasks waiting on the pool without +/// potentially overflowing the permits count in the semaphore itself. +const WAKE_ALL_PERMITS: usize = usize::MAX / 2; pub(crate) struct SharedPool { pub(super) connect_options: ::Options, pub(super) idle_conns: ArrayQueue>, - waiters: Waiters, + pub(super) semaphore: Semaphore, pub(super) size: AtomicU32, is_closed: AtomicBool, pub(super) options: PoolOptions, @@ -31,10 +36,18 @@ impl SharedPool { options: PoolOptions, connect_options: ::Options, ) -> Arc { + let capacity = options.max_connections as usize; + + // ensure the permit count won't overflow if we release `WAKE_ALL_PERMITS` + // this assert should never fire on 64-bit targets as `max_connections` is a u32 + let _ = capacity + .checked_add(WAKE_ALL_PERMITS) + .expect("max_connections exceeds max capacity of the pool"); + let pool = Self { connect_options, - idle_conns: ArrayQueue::new(options.max_connections as usize), - waiters: SegQueue::new(), + idle_conns: ArrayQueue::new(capacity), + semaphore: Semaphore::new(options.fair, capacity), size: AtomicU32::new(0), is_closed: AtomicBool::new(false), options, @@ -61,148 +74,133 @@ impl SharedPool { } pub(super) async fn close(&self) { - self.is_closed.store(true, Ordering::Release); - while let Some(waker) = self.waiters.pop() { - if let Some(waker) = waker.upgrade() { - waker.wake(); - } + let already_closed = self.is_closed.swap(true, Ordering::AcqRel); + + if !already_closed { + // if we were the one to mark this closed, release enough permits to wake all waiters + // we can't just do `usize::MAX` because that would overflow + // and we can't do this more than once because that would _also_ overflow + self.semaphore.release(WAKE_ALL_PERMITS); } - // ensure we wait until the pool is actually closed - while self.size() > 0 { - if let Some(idle) = self.idle_conns.pop() { - if let Err(e) = Floating::from_idle(idle, self).close().await { - log::warn!("error occurred while closing the pool connection: {}", e); - } - } + // wait for all permits to be released + let _permits = self + .semaphore + .acquire(WAKE_ALL_PERMITS + (self.options.max_connections as usize)) + .await; - // yield to avoid starving the executor - sqlx_rt::yield_now().await; + while let Some(idle) = self.idle_conns.pop() { + let _ = idle.live.float(self).close().await; } } #[inline] - pub(super) fn try_acquire(&self) -> Option>> { - // don't cut in line - if self.options.fair && !self.waiters.is_empty() { + pub(super) fn try_acquire(&self) -> Option>> { + if self.is_closed() { return None; } - Some(self.pop_idle()?.into_live()) + + let permit = self.semaphore.try_acquire(1)?; + self.pop_idle(permit).ok() } - fn pop_idle(&self) -> Option>> { - if self.is_closed.load(Ordering::Acquire) { - return None; + fn pop_idle<'a>( + &'a self, + permit: SemaphoreReleaser<'a>, + ) -> Result>, SemaphoreReleaser<'a>> { + if let Some(idle) = self.idle_conns.pop() { + Ok(Floating::from_idle(idle, self, permit)) + } else { + Err(permit) } - - Some(Floating::from_idle(self.idle_conns.pop()?, self)) } pub(super) fn release(&self, mut floating: Floating<'_, Live>) { if let Some(test) = &self.options.after_release { if !test(&mut floating.raw) { - // drop the connection and do not return it to
the pool return; } } - let is_ok = self - .idle_conns - .push(floating.into_idle().into_leakable()) - .is_ok(); + let Floating { inner: idle, guard } = floating.into_idle(); - if !is_ok { + if !self.idle_conns.push(idle).is_ok() { panic!("BUG: connection queue overflow in release()"); } - wake_one(&self.waiters); + // NOTE: we need to make sure we drop the permit *after* we push to the idle queue + // don't decrease the size + guard.release_permit(); } /// Try to atomically increment the pool size for a new connection. /// /// Returns `None` if we are at max_connections or if the pool is closed. - pub(super) fn try_increment_size(&self) -> Option> { - if self.is_closed() { - return None; - } - - let mut size = self.size(); - - while size < self.options.max_connections { - match self - .size - .compare_exchange(size, size + 1, Ordering::AcqRel, Ordering::Acquire) - { - Ok(_) => return Some(DecrementSizeGuard::new(self)), - Err(new_size) => size = new_size, - } + pub(super) fn try_increment_size<'a>( + &'a self, + permit: SemaphoreReleaser<'a>, + ) -> Result, SemaphoreReleaser<'a>> { + match self + .size + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { + size.checked_add(1) + .filter(|size| size <= &self.options.max_connections) + }) { + // we successfully incremented the size + Ok(_) => Ok(DecrementSizeGuard::from_permit(self, permit)), + // the pool is at max capacity + Err(_) => Err(permit), } - - None } #[allow(clippy::needless_lifetimes)] pub(super) async fn acquire<'s>(&'s self) -> Result>, Error> { - let start = Instant::now(); - let deadline = start + self.options.connect_timeout; - let mut waited = !self.options.fair; - - // the strong ref of the `Weak` that we push to the queue - // initialized during the `timeout()` call below - // as long as we own this, we keep our place in line - let mut waiter: Option> = None; - - // Unless the pool has been closed ... - while !self.is_closed() { - // Don't cut in line unless no one is waiting - if waited || self.waiters.is_empty() { - // Attempt to immediately acquire a connection. This will return Some - // if there is an idle connection in our channel. 
- if let Some(conn) = self.pop_idle() { - if let Some(live) = check_conn(conn, &self.options).await { - return Ok(live); - } - } + if self.is_closed() { + return Err(Error::PoolClosed); + } - // check if we can open a new connection - if let Some(guard) = self.try_increment_size() { - // pool has slots available; open a new connection - return self.connection(deadline, guard).await; - } - } + let deadline = Instant::now() + self.options.connect_timeout; - if let Some(ref waiter) = waiter { - // return the waiter to the queue, note that this does put it to the back - // of the queue when it should ideally stay at the front - self.waiters.push(Arc::downgrade(&waiter.inner)); - } + sqlx_rt::timeout( + self.options.connect_timeout, + async { + loop { + let permit = self.semaphore.acquire(1).await; - sqlx_rt::timeout( - // Returns an error if `deadline` passes - deadline_as_timeout::(deadline)?, - // `poll_fn` gets us easy access to a `Waker` that we can push to our queue - future::poll_fn(|cx| -> Poll<()> { - let waiter = waiter.get_or_insert_with(|| Waiter::push_new(cx, &self.waiters)); - - if waiter.is_woken() { - waiter.actually_woke = true; - Poll::Ready(()) - } else { - Poll::Pending + if self.is_closed() { + return Err(Error::PoolClosed); } - }), - ) - .await - .map_err(|_| Error::PoolTimedOut)?; - if let Some(ref mut waiter) = waiter { - waiter.reset(); + // First attempt to pop a connection from the idle queue. + let guard = match self.pop_idle(permit) { + + // Then, check that we can use it... + Ok(conn) => match check_conn(conn, &self.options).await { + + // All good! + Ok(live) => return Ok(live), + + // if the connection isn't usable for one reason or another, + // we get the `DecrementSizeGuard` back to open a new one + Err(guard) => guard, + }, + Err(permit) => if let Ok(guard) = self.try_increment_size(permit) { + // we can open a new connection + guard + } else { + log::debug!("woke but was unable to acquire idle connection or open new one; retrying"); + continue; + } + }; + + // Attempt to connect... + return self.connection(deadline, guard).await; + } } - - waited = true; - } - - Err(Error::PoolClosed) + ) + .await + .map_err(|_| Error::PoolTimedOut)? 
} pub(super) async fn connection<'s>( @@ -277,14 +275,13 @@ fn is_beyond_idle(idle: &Idle, options: &PoolOptions) -> b async fn check_conn<'s: 'p, 'p, DB: Database>( mut conn: Floating<'s, Idle>, options: &'p PoolOptions, -) -> Option>> { +) -> Result>, DecrementSizeGuard<'s>> { // If the connection we pulled has expired, close the connection and // immediately create a new connection if is_beyond_lifetime(&conn, options) { // we're closing the connection either way // close the connection but don't really care about the result - let _ = conn.close().await; - return None; + return Err(conn.close().await); } else if options.test_before_acquire { // Check that the connection is still live if let Err(e) = conn.ping().await { @@ -293,18 +290,18 @@ async fn check_conn<'s: 'p, 'p, DB: Database>( // the error itself here isn't necessarily unexpected so WARN is too strong log::info!("ping on idle connection returned error: {}", e); // connection is broken so don't try to close nicely - return None; + return Err(conn.close().await); } } else if let Some(test) = &options.before_acquire { match test(&mut conn.live.raw).await { Ok(false) => { // connection was rejected by user-defined hook - return None; + return Err(conn.close().await); } Err(error) => { log::info!("in `before_acquire`: {}", error); - return None; + return Err(conn.close().await); } Ok(true) => {} @@ -312,7 +309,7 @@ async fn check_conn<'s: 'p, 'p, DB: Database>( } // No need to re-connect; connection is alive or we don't care - Some(conn.into_live()) + Ok(conn.into_live()) } /// if `max_lifetime` or `idle_timeout` is set, spawn a task that reaps senescent connections @@ -329,11 +326,9 @@ fn spawn_reaper(pool: &Arc>) { sqlx_rt::spawn(async move { while !pool.is_closed() { - // only reap idle connections when no tasks are waiting - if pool.waiters.is_empty() { + if !pool.idle_conns.is_empty() { do_reap(&pool).await; } - sqlx_rt::sleep(period).await; } }); @@ -346,7 +341,7 @@ async fn do_reap(pool: &SharedPool) { // collect connections to reap let (reap, keep) = (0..max_reaped) // only connections waiting in the queue - .filter_map(|_| pool.pop_idle()) + .filter_map(|_| pool.try_acquire()) .partition::, _>(|conn| { is_beyond_idle(conn, &pool.options) || is_beyond_lifetime(conn, &pool.options) }); @@ -361,38 +356,44 @@ async fn do_reap(pool: &SharedPool) { } } -fn wake_one(waiters: &Waiters) { - while let Some(weak) = waiters.pop() { - if let Some(waiter) = weak.upgrade() { - if waiter.wake() { - return; - } - } - } -} - /// RAII guard returned by `Pool::try_increment_size()` and others. /// /// Will decrement the pool size if dropped, to avoid semantically "leaking" connections /// (where the pool thinks it has more connections than it does). pub(in crate::pool) struct DecrementSizeGuard<'a> { size: &'a AtomicU32, - waiters: &'a Waiters, + semaphore: &'a Semaphore, dropped: bool, } impl<'a> DecrementSizeGuard<'a> { - pub fn new(pool: &'a SharedPool) -> Self { + /// Create a new guard that will release a semaphore permit on-drop. + pub fn new_permit(pool: &'a SharedPool) -> Self { Self { size: &pool.size, - waiters: &pool.waiters, + semaphore: &pool.semaphore, dropped: false, } } + pub fn from_permit( + pool: &'a SharedPool, + mut permit: SemaphoreReleaser<'a>, + ) -> Self { + // here we effectively take ownership of the permit + permit.disarm(); + Self::new_permit(pool) + } + /// Return `true` if the internal references point to the same fields in `SharedPool`. 
pub fn same_pool(&self, pool: &'a SharedPool) -> bool { - ptr::eq(self.size, &pool.size) && ptr::eq(self.waiters, &pool.waiters) + ptr::eq(self.size, &pool.size) + } + + /// Release the semaphore permit without decreasing the pool size. + fn release_permit(self) { + self.semaphore.release(1); + self.cancel(); } pub fn cancel(self) { @@ -405,73 +406,8 @@ impl Drop for DecrementSizeGuard<'_> { assert!(!self.dropped, "double-dropped!"); self.dropped = true; self.size.fetch_sub(1, Ordering::SeqCst); - wake_one(&self.waiters); - } -} - -struct WaiterInner { - woken: AtomicBool, - waker: Waker, -} - -impl WaiterInner { - /// Wake this waiter if it has not previously been woken. - /// - /// Return `true` if this waiter was newly woken, or `false` if it was already woken. - fn wake(&self) -> bool { - // if we were the thread to flip this boolean from false to true - if let Ok(_) = self - .woken - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - { - self.waker.wake_by_ref(); - return true; - } - false - } -} - -struct Waiter<'a> { - inner: Arc, - queue: &'a Waiters, - actually_woke: bool, -} - -impl<'a> Waiter<'a> { - fn push_new(cx: &mut Context<'_>, queue: &'a Waiters) -> Self { - let inner = Arc::new(WaiterInner { - woken: AtomicBool::new(false), - waker: cx.waker().clone(), - }); - - queue.push(Arc::downgrade(&inner)); - - Self { - inner, - queue, - actually_woke: false, - } - } - - fn is_woken(&self) -> bool { - self.inner.woken.load(Ordering::Acquire) - } - - fn reset(&mut self) { - self.inner - .woken - .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire) - .ok(); - self.actually_woke = false; - } -} - -impl Drop for Waiter<'_> { - fn drop(&mut self) { - // if we didn't actually wake to get a connection, wake the next task instead - if self.is_woken() && !self.actually_woke { - wake_one(self.queue); - } + // and here we release the permit we got on construction + self.semaphore.release(1); } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 2b7d370005..826e6534c6 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -112,11 +112,23 @@ pub use self::options::PoolOptions; /// /// Calls to `acquire()` are fair, i.e. fulfilled on a first-come, first-serve basis. /// -/// `Pool` is `Send`, `Sync` and `Clone`, so it should be created once at the start of your -/// application/daemon/web server/etc. and then shared with all tasks throughout its lifetime. How -/// best to accomplish this depends on your program architecture. +/// `Pool` is `Send`, `Sync` and `Clone`. It is intended to be created once at the start of your +/// application/daemon/web server/etc. and then shared with all tasks throughout the process' +/// lifetime. How best to accomplish this depends on your program architecture. /// -/// In Actix-Web, you can share a single pool with all request handlers using [web::Data]. +/// In Actix-Web, for example, you can share a single pool with all request handlers using [web::Data]. +/// +/// Cloning `Pool` is cheap as it is simply a reference-counted handle to the inner pool state. +/// When the last remaining handle to the pool is dropped, the connections owned by the pool are +/// immediately closed (also by dropping). `PoolConnection` returned by [Pool::acquire] and +/// `Transaction` returned by [Pool::begin] both implicitly hold a reference to the pool for +/// their lifetimes. 
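+///
+/// A minimal sharing sketch (assumes the `tokio` runtime; the connection URL is illustrative):
+///
+/// ```rust,no_run
+/// # async fn example() -> Result<(), sqlx::Error> {
+/// let pool = sqlx::PgPool::connect("postgres://postgres:password@localhost/sqlx").await?;
+///
+/// for _ in 0..4 {
+///     // cloning is cheap: each handle refers to the same shared pool state
+///     let pool = pool.clone();
+///     tokio::spawn(async move {
+///         let _ = sqlx::query("SELECT 1").execute(&pool).await;
+///     });
+/// }
+/// # Ok(())
+/// # }
+/// ```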
+/// +/// If you prefer to explicitly shutdown the pool and gracefully close its connections (which +/// depending on the database type, may include sending a message to the database server that the +/// connection is being closed), you can call [Pool::close] which causes all waiting and subsequent +/// calls to [Pool::acquire] to return [Error::PoolClosed], and waits until all connections have +/// been returned to the pool and gracefully closed. /// /// Type aliases are provided for each database to make it easier to sprinkle `Pool` through /// your codebase: @@ -126,7 +138,7 @@ pub use self::options::PoolOptions; /// * [PgPool][crate::postgres::PgPool] (PostgreSQL) /// * [SqlitePool][crate::sqlite::SqlitePool] (SQLite) /// -/// [web::Data]: https://docs.rs/actix-web/2.0.0/actix_web/web/struct.Data.html +/// [web::Data]: https://docs.rs/actix-web/3/actix_web/web/struct.Data.html /// /// ### Why Use a Pool? /// @@ -274,7 +286,9 @@ impl Pool { /// /// Returns `None` immediately if there are no idle connections available in the pool. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| conn.attach(&self.0)) + self.0 + .try_acquire() + .map(|conn| conn.into_live().attach(&self.0)) } /// Retrieves a new connection and immediately begins a new transaction. @@ -294,10 +308,29 @@ impl Pool { } } - /// Ends the use of a connection pool. Prevents any new connections - /// and will close all active connections when they are returned to the pool. + /// Shut down the connection pool, waiting for all connections to be gracefully closed. + /// + /// Upon `.await`ing this call, any currently waiting or subsequent calls to [Pool::acquire] and + /// the like will immediately return [Error::PoolClosed] and no new connections will be opened. + /// + /// Any connections currently idle in the pool will be immediately closed, including sending + /// a graceful shutdown message to the database server, if applicable. + /// + /// Checked-out connections are unaffected, but will be closed in the same manner when they are + /// returned to the pool. + /// + /// Does not resolve until all connections are returned to the pool and gracefully closed. + /// + /// ### Note: `async fn` + /// Because this is an `async fn`, the pool will *not* be marked as closed unless the + /// returned future is polled at least once. /// - /// Does not resolve until all connections are closed. + /// If you want to close the pool but don't want to wait for all connections to be gracefully + /// closed, you can do `pool.close().now_or_never()`, which polls the future exactly once + /// with a no-op waker. + // TODO: I don't want to change the signature right now in case it turns out to be a + // breaking change, but this probably should eagerly mark the pool as closed and then the + // returned future only needs to be awaited to gracefully close the connections. 
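+    ///
+    /// A usage sketch (assumes a pool named `pool`; `now_or_never` is provided by the
+    /// `futures` crate's `FutureExt`):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+    /// use futures::FutureExt;
+    ///
+    /// // graceful: wait until all connections are returned and closed
+    /// pool.close().await;
+    ///
+    /// // or: mark the pool closed without waiting (see the note above)
+    /// pool.close().now_or_never();
+    /// # Ok(())
+    /// # }
+    /// ```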
pub async fn close(&self) { self.0.close().await; } diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index a1b07f3721..32313808ff 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -231,19 +231,13 @@ impl PoolOptions { async fn init_min_connections(pool: &SharedPool) -> Result<(), Error> { for _ in 0..cmp::max(pool.options.min_connections, 1) { let deadline = Instant::now() + pool.options.connect_timeout; + let permit = pool.semaphore.acquire(1).await; // this guard will prevent us from exceeding `max_size` - if let Some(guard) = pool.try_increment_size() { + if let Ok(guard) = pool.try_increment_size(permit) { // [connect] will raise an error when past deadline let conn = pool.connection(deadline, guard).await?; - let is_ok = pool - .idle_conns - .push(conn.into_idle().into_leakable()) - .is_ok(); - - if !is_ok { - panic!("BUG: connection queue overflow in init_min_connections"); - } + pool.release(conn); } } diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index fc53bc0745..a14e2a1a69 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -402,13 +402,16 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 .fetch_all(&mut *self) .await?; - // patch up our null inference with data from EXPLAIN - let nullable_patch = self - .nullables_from_explain(stmt_id, meta.parameters.len()) - .await?; + // if it's CockroachDB, skip this step (#1248) + if !self.stream.parameter_statuses.contains_key("crdb_version") { + // patch up our null inference with data from EXPLAIN + let nullable_patch = self + .nullables_from_explain(stmt_id, meta.parameters.len()) + .await?; - for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { - *nullable = patch.or(*nullable); + for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { + *nullable = patch.or(*nullable); + } } Ok(nullables) diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index ac64940c2b..33d3948989 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -229,6 +229,10 @@ impl PgConnection { // patch holes created during encoding arguments.apply_patches(self, &metadata.parameters).await?; + // applying patches uses `fetch_optional`, which may produce a `PortalSuspended` message; + // consume messages until `ReadyForQuery` before bind and execute + self.wait_until_ready().await?; + // bind to attach the arguments to the statement and create a portal self.stream.write(Bind { portal: None, diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index d843a87da9..09b2c02bd2 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -16,7 +16,6 @@ use crate::error::Error; use crate::executor::Executor; use crate::ext::ustr::UStr; use crate::io::Decode; -use crate::postgres::connection::stream::PgStream; use crate::postgres::message::{ Close, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, }; @@ -24,6 +23,8 @@ use crate::postgres::statement::PgStatementMetadata; use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres}; use crate::transaction::Transaction; +pub use self::stream::PgStream; + pub(crate) mod describe; mod establish; mod executor; @@ -73,7 +74,7 @@ pub struct PgConnection { impl PgConnection { // will return when the connection is ready
for another query - async fn wait_until_ready(&mut self) -> Result<(), Error> { + pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.stream.wbuf.is_empty() { self.stream.flush().await?; } @@ -195,3 +196,21 @@ impl Connection for PgConnection { !self.stream.wbuf.is_empty() } } + +pub trait PgConnectionInfo { + /// the version number of the server in `libpq` format + fn server_version_num(&self) -> Option; +} + +impl PgConnectionInfo for PgConnection { + fn server_version_num(&self) -> Option { + self.stream.server_version_num + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl PgConnectionInfo for crate::pool::PoolConnection { + fn server_version_num(&self) -> Option { + self.stream.server_version_num + } +} diff --git a/sqlx-core/src/postgres/connection/sasl.rs b/sqlx-core/src/postgres/connection/sasl.rs index 905afe974c..809c8ea170 100644 --- a/sqlx-core/src/postgres/connection/sasl.rs +++ b/sqlx-core/src/postgres/connection/sasl.rs @@ -98,7 +98,7 @@ pub(crate) async fn authenticate( )?; // ClientKey := HMAC(SaltedPassword, "Client Key") - let mut mac = Hmac::::new_varkey(&salted_password).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Client Key"); let client_key = mac.finalize().into_bytes(); @@ -122,7 +122,7 @@ pub(crate) async fn authenticate( ); // ClientSignature := HMAC(StoredKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&stored_key).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&stored_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); let client_signature = mac.finalize().into_bytes(); @@ -135,13 +135,13 @@ pub(crate) async fn authenticate( .collect(); // ServerKey := HMAC(SaltedPassword, "Server Key") - let mut mac = Hmac::::new_varkey(&salted_password).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Server Key"); let server_key = mac.finalize().into_bytes(); // ServerSignature := HMAC(ServerKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&server_key).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&server_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); // client-final-message = client-final-message-without-proof "," proof @@ -197,7 +197,7 @@ fn gen_nonce() -> String { // Hi(str, salt, i): fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> { - let mut mac = Hmac::::new_varkey(s.as_bytes()).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(&salt); mac.update(&1u32.to_be_bytes()); @@ -206,7 +206,7 @@ fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error let mut hi = u; for _ in 1..iter_count { - let mut mac = Hmac::::new_varkey(s.as_bytes()).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(u.as_slice()); u = mac.finalize().into_bytes(); hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); diff --git a/sqlx-core/src/postgres/connection/stream.rs b/sqlx-core/src/postgres/connection/stream.rs index 0152d54c67..2b3d8c3b70 100644 --- a/sqlx-core/src/postgres/connection/stream.rs +++ b/sqlx-core/src/postgres/connection/stream.rs @@ -1,4 +1,6 @@ +use std::collections::BTreeMap; use std::ops::{Deref, DerefMut}; +use std::str::FromStr; use bytes::{Buf, Bytes}; use futures_channel::mpsc::UnboundedSender; @@ 
-13,7 +15,7 @@ use crate::net::MaybeTlsStream; use crate::net::Socket; -use crate::postgres::message::{Message, MessageFormat, Notice, Notification}; +use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity}; // the stream is a separate type from the connection to uphold the invariant where an instantiated @@ -34,6 +36,10 @@ pub struct PgStream { // this is set when creating a PgListener and only written to if that listener is // re-used for query execution in-between receiving messages pub(crate) notifications: Option>, + + pub(crate) parameter_statuses: BTreeMap, + + pub(crate) server_version_num: Option, } impl PgStream { @@ -49,6 +55,8 @@ impl PgStream { Ok(Self { inner, notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, }) } @@ -63,6 +71,8 @@ impl PgStream { Ok(Self { inner, notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, }) } } @@ -131,7 +141,18 @@ impl PgStream { // informs the frontend about the current (initial) // setting of backend parameters - // we currently have no use for that data so we promptly ignore this message + let ParameterStatus { name, value } = message.decode()?; + // TODO: handle `client_encoding`, `DateStyle` change + + match name.as_str() { + "server_version" => { + self.server_version_num = parse_server_version(&value); + } + _ => { + self.parameter_statuses.insert(name, value); + } + } + continue; } @@ -191,3 +212,68 @@ impl DerefMut for PgStream { &mut self.inner } } + +// reference: +// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 +fn parse_server_version(s: &str) -> Option { + let mut parts = Vec::::with_capacity(3); + + let mut from = 0; + let mut chs = s.char_indices().peekable(); + while let Some((i, ch)) = chs.next() { + match ch { + '.' => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + from = i + 1; + } else { + break; + } + } + _ if ch.is_digit(10) => { + if chs.peek().is_none() { + if let Ok(num) = u32::from_str(&s[from..]) { + parts.push(num); + } + break; + } + } + _ => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + } + break; + } + }; + } + + let version_num = match parts.as_slice() { + [major, minor, rev] => (100 * major + minor) * 100 + rev, + [major, minor] if *major >= 10 => 100 * 100 * major + minor, + [major, minor] => (100 * major + minor) * 100, + [major] => 100 * 100 * major, + _ => return None, + }; + + Some(version_num) +} + +#[cfg(test)] +mod tests { + use super::parse_server_version; + + #[test] + fn test_parse_server_version_num() { + // old style + assert_eq!(parse_server_version("9.6.1"), Some(90601)); + // new style + assert_eq!(parse_server_version("10.1"), Some(100001)); + // old style without minor version + assert_eq!(parse_server_version("9.6devel"), Some(90600)); + // new style without minor version, e.g. 
`10devel` + assert_eq!(parse_server_version("10devel"), Some(100000)); + assert_eq!(parse_server_version("13devel87"), Some(130000)); + // unknown + assert_eq!(parse_server_version("unknown"), None); + } +} diff --git a/sqlx-core/src/postgres/copy.rs b/sqlx-core/src/postgres/copy.rs new file mode 100644 index 0000000000..ddff8ce5d0 --- /dev/null +++ b/sqlx-core/src/postgres/copy.rs @@ -0,0 +1,342 @@ +use crate::error::{Error, Result}; +use crate::ext::async_stream::TryAsyncStream; +#[cfg(not(target_arch = "wasm32"))] +use crate::pool::{Pool, PoolConnection}; +use crate::postgres::connection::PgConnection; +use crate::postgres::message::{ + CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, +}; +#[cfg(not(target_arch = "wasm32"))] +use crate::postgres::Postgres; +use bytes::{BufMut, Bytes}; + +#[cfg(not(target_arch = "wasm32"))] +use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + +use smallvec::alloc::borrow::Cow; +use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use std::convert::TryFrom; +use std::ops::{Deref, DerefMut}; + +impl PgConnection { + /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data + /// to Postgres. This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw(&mut self, statement: &str) -> Result> { + PgCopyIn::begin(self, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + #[allow(clippy::needless_lifetimes)] + pub async fn copy_out_raw<'c>( + &'c mut self, + statement: &str, + ) -> Result>> { + pg_begin_copy_out(self, statement).await + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl Pool { + /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. + /// This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
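+    ///
+    /// A minimal sketch (assumes the `tokio` runtime, a `PgPool` named `pool`, and a table
+    /// matching the CSV file; names are illustrative):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(pool: sqlx::PgPool) -> Result<(), sqlx::Error> {
+    /// let mut copy = pool.copy_in_raw("COPY users FROM STDIN (FORMAT csv)").await?;
+    /// copy.read_from(tokio::fs::File::open("users.csv").await?).await?;
+    /// copy.finish().await?;
+    /// # Ok(())
+    /// # }
+    /// ```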
+ /// + /// A single connection will be checked out for the duration. + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw(&self, statement: &str) -> Result>> { + PgCopyIn::begin(self.acquire().await?, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and begin streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + pub async fn copy_out_raw(&self, statement: &str) -> Result>> { + pg_begin_copy_out(self.acquire().await?, statement).await + } +} + +/// A connection in streaming `COPY FROM STDIN` mode. +/// +/// Created by [PgConnection::copy_in_raw] or [Pool::copy_in_raw]. +/// +/// ### Note +/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection +/// will return an error the next time it is used. +#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] +pub struct PgCopyIn> { + conn: Option, + response: CopyResponse, +} + +impl> PgCopyIn { + async fn begin(mut conn: C, statement: &str) -> Result { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let response: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyInResponse) + .await?; + + Ok(PgCopyIn { + conn: Some(conn), + response, + }) + } + + /// Returns `true` if Postgres is expecting data in text or CSV format. + pub fn is_textual(&self) -> bool { + self.response.format == 0 + } + + /// Returns the number of columns expected in the input. + pub fn num_columns(&self) -> usize { + assert_eq!( + self.response.num_columns as usize, + self.response.format_codes.len(), + "num_columns does not match format_codes.len()" + ); + self.response.format_codes.len() + } + + /// Check if a column is expecting data in text format (`true`) or binary format (`false`). + /// + /// ### Panics + /// If `column` is out of range according to [`.num_columns()`][Self::num_columns]. + pub fn column_is_textual(&self, column: usize) -> bool { + self.response.format_codes[column] == 0 + } + + /// Send a chunk of `COPY` data. + /// + /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
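+    ///
+    /// A sketch (assumes `copy` was started with a text-format `COPY ... FROM STDIN`
+    /// on a two-column table):
+    ///
+    /// ```rust,ignore
+    /// copy.send("1\talice\n".as_bytes()).await?;
+    /// ```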
+ pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .stream + .send(CopyData(data)) + .await?; + + Ok(self) + } + + /// Copy data directly from `source` to the database without requiring an intermediate buffer. + /// + /// `source` will be read to the end. + /// + /// ### Note + /// You must still call either [Self::finish] or [Self::abort] to complete the process. + pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { + // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing + struct BufGuard<'s>(&'s mut Vec); + + impl Drop for BufGuard<'_> { + fn drop(&mut self) { + self.0.clear() + } + } + + let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); + + // flush any existing messages in the buffer and clear it + conn.stream.flush().await?; + + { + let buf_stream = &mut *conn.stream; + let stream = &mut buf_stream.stream; + + // ensures the buffer isn't left in an inconsistent state + let mut guard = BufGuard(&mut buf_stream.wbuf); + + let buf: &mut Vec = &mut guard.0; + buf.push(b'd'); // CopyData format code + buf.resize(5, 0); // reserve space for the length + + loop { + let read = match () { + // Tokio lets us read into the buffer without zeroing first + #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))] + _ if buf.len() != buf.capacity() => { + // in case we have some data in the buffer, which can occur + // if the previous write did not fill the buffer + buf.truncate(5); + source.read_buf(buf).await? + } + _ => { + // should be a no-op unless len != capacity + buf.resize(buf.capacity(), 0); + source.read(&mut buf[5..]).await? + } + }; + + if read == 0 { + break; + } + + let read32 = u32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + + (&mut buf[1..]).put_u32(read32 + 4); + + stream.write_all(&buf[..read + 5]).await?; + stream.flush().await?; + } + } + + Ok(self) + } + + /// Signal that the `COPY` process should be aborted and any data received should be discarded. + /// + /// The given message can be used for indicating the reason for the abort in the database logs. + /// + /// The server is expected to respond with an error, so only _unexpected_ errors are returned. + pub async fn abort(mut self, msg: impl Into) -> Result<()> { + let mut conn = self + .conn + .take() + .expect("PgCopyIn::fail_with: conn taken illegally"); + + conn.stream.send(CopyFail::new(msg)).await?; + + match conn.stream.recv().await { + Ok(msg) => Err(err_protocol!( + "fail_with: expected ErrorResponse, got: {:?}", + msg.format + )), + Err(Error::Database(e)) => { + match e.code() { + Some(Cow::Borrowed("57014")) => { + // postgres abort received error code + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + Ok(()) + } + _ => Err(Error::Database(e)), + } + } + Err(e) => Err(e), + } + } + + /// Signal that the `COPY` process is complete. + /// + /// The number of rows affected is returned. 
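+    ///
+    /// A sketch of a complete session (assumes a `PgConnection` named `conn` and a
+    /// table `users (id int4, name text)`; text format is tab-separated):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(conn: &mut sqlx::PgConnection) -> Result<(), sqlx::Error> {
+    /// let mut copy = conn.copy_in_raw("COPY users FROM STDIN").await?;
+    /// copy.send("1\talice\n2\tbob\n".as_bytes()).await?;
+    /// let rows = copy.finish().await?;
+    /// assert_eq!(rows, 2);
+    /// # Ok(())
+    /// # }
+    /// ```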
+ pub async fn finish(mut self) -> Result { + let mut conn = self + .conn + .take() + .expect("CopyWriter::finish: conn taken illegally"); + + conn.stream.send(CopyDone).await?; + let cc: CommandComplete = conn + .stream + .recv_expect(MessageFormat::CommandComplete) + .await?; + + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + + Ok(cc.rows_affected()) + } +} + +impl> Drop for PgCopyIn { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + conn.stream.write(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )); + } + } +} + +async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( + mut conn: C, + statement: &str, +) -> Result>> { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let _: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyOutResponse) + .await?; + + let stream: TryAsyncStream<'c, Bytes> = try_stream! { + loop { + let msg = conn.stream.recv().await?; + match msg.format { + MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + MessageFormat::CopyDone => { + let _ = msg.decode::()?; + conn.stream.recv_expect(MessageFormat::CommandComplete).await?; + conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + return Ok(()) + }, + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + } + } + }; + + Ok(Box::pin(stream)) +} diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 82c0460f4b..36a1cb5f5c 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -260,10 +260,23 @@ impl PgListener { impl Drop for PgListener { fn drop(&mut self) { if let Some(mut conn) = self.connection.take() { - // Unregister any listeners before returning the connection to the pool. - sqlx_rt::spawn(async move { + let fut = async move { let _ = conn.execute("UNLISTEN *").await; - }); + + // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task + // otherwise, it may trigger a panic if this task is dropped because the runtime is going away: + // https://github.com/launchbadge/sqlx/issues/1389 + conn.return_to_pool().await; + }; + + // Unregister any listeners before returning the connection to the pool. + #[cfg(not(feature = "_rt-async-std"))] + if let Ok(handle) = sqlx_rt::Handle::try_current() { + handle.spawn(fut); + } + + #[cfg(feature = "_rt-async-std")] + sqlx_rt::spawn(fut); } } } diff --git a/sqlx-core/src/postgres/message/copy.rs b/sqlx-core/src/postgres/message/copy.rs new file mode 100644 index 0000000000..58553d431b --- /dev/null +++ b/sqlx-core/src/postgres/message/copy.rs @@ -0,0 +1,96 @@ +use crate::error::Result; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use bytes::{Buf, BufMut, Bytes}; +use std::ops::Deref; + +/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` +pub struct CopyResponse { + pub format: i8, + pub num_columns: i16, + pub format_codes: Vec, +} + +pub struct CopyData(pub B); + +pub struct CopyFail { + pub message: String, +} + +pub struct CopyDone; + +impl Decode<'_> for CopyResponse { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let format = buf.get_i8(); + let num_columns = buf.get_i16(); + + let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); + + Ok(CopyResponse { + format, + num_columns, + format_codes, + }) + } +} + +impl Decode<'_> for CopyData { + fn decode_with(buf: Bytes, _: ()) -> Result { + // well.. 
that was easy + Ok(CopyData(buf)) + } +} + +impl> Encode<'_> for CopyData { + fn encode_with(&self, buf: &mut Vec, _context: ()) { + buf.push(b'd'); + buf.put_u32(self.0.len() as u32 + 4); + buf.extend_from_slice(&self.0); + } +} + +impl Decode<'_> for CopyFail { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + Ok(CopyFail { + message: buf.get_str_nul()?, + }) + } +} + +impl Encode<'_> for CopyFail { + fn encode_with(&self, buf: &mut Vec, _: ()) { + let len = 4 + self.message.len() + 1; + + buf.push(b'f'); // to pay respects + buf.put_u32(len as u32); + buf.put_str_nul(&self.message); + } +} + +impl CopyFail { + pub fn new(msg: impl Into) -> CopyFail { + CopyFail { + message: msg.into(), + } + } +} + +impl Decode<'_> for CopyDone { + fn decode_with(buf: Bytes, _: ()) -> Result { + if buf.is_empty() { + Ok(CopyDone) + } else { + Err(err_protocol!( + "expected no data for CopyDone, got: {:?}", + buf + )) + } + } +} + +impl Encode<'_> for CopyDone { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.reserve(4); + buf.push(b'c'); + buf.put_u32(4); + } +} diff --git a/sqlx-core/src/postgres/message/mod.rs b/sqlx-core/src/postgres/message/mod.rs index 6c8d1f3023..1261bff339 100644 --- a/sqlx-core/src/postgres/message/mod.rs +++ b/sqlx-core/src/postgres/message/mod.rs @@ -8,12 +8,14 @@ mod backend_key_data; mod bind; mod close; mod command_complete; +mod copy; mod data_row; mod describe; mod execute; mod flush; mod notification; mod parameter_description; +mod parameter_status; mod parse; mod password; mod query; @@ -31,12 +33,14 @@ pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; +pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; pub use flush::Flush; pub use notification::Notification; pub use parameter_description::ParameterDescription; +pub use parameter_status::ParameterStatus; pub use parse::Parse; pub use password::Password; pub use query::Query; @@ -57,6 +61,10 @@ pub enum MessageFormat { BindComplete, CloseComplete, CommandComplete, + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, DataRow, EmptyQueryResponse, ErrorResponse, @@ -96,6 +104,10 @@ impl MessageFormat { b'2' => MessageFormat::BindComplete, b'3' => MessageFormat::CloseComplete, b'C' => MessageFormat::CommandComplete, + b'd' => MessageFormat::CopyData, + b'c' => MessageFormat::CopyDone, + b'G' => MessageFormat::CopyInResponse, + b'H' => MessageFormat::CopyOutResponse, b'D' => MessageFormat::DataRow, b'E' => MessageFormat::ErrorResponse, b'I' => MessageFormat::EmptyQueryResponse, diff --git a/sqlx-core/src/postgres/message/parameter_status.rs b/sqlx-core/src/postgres/message/parameter_status.rs new file mode 100644 index 0000000000..ffd0ef1b60 --- /dev/null +++ b/sqlx-core/src/postgres/message/parameter_status.rs @@ -0,0 +1,62 @@ +use bytes::Bytes; + +use crate::error::Error; +use crate::io::{BufExt, Decode}; + +#[derive(Debug)] +pub struct ParameterStatus { + pub name: String, + pub value: String, +} + +impl Decode<'_> for ParameterStatus { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let name = buf.get_str_nul()?; + let value = buf.get_str_nul()?; + + Ok(Self { name, value }) + } +} + +#[test] +fn test_decode_parameter_status() { + const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; + + let m = ParameterStatus::decode(DATA.into()).unwrap(); + + assert_eq!(&m.name, "client_encoding"); + assert_eq!(&m.value, "UTF8") +} + +#[test] +fn 
test_decode_empty_parameter_status() { + const DATA: &[u8] = b"\x00\x00"; + + let m = ParameterStatus::decode(DATA.into()).unwrap(); + + assert!(m.name.is_empty()); + assert!(m.value.is_empty()); +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_decode_parameter_status(b: &mut test::Bencher) { + const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; + + b.iter(|| { + ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + }); +} + +#[test] +fn test_decode_parameter_status_response() { + const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; + + let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); + + assert_eq!(message.name, "crdb_version"); + assert_eq!( + message.value, + "CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)" + ); +} diff --git a/sqlx-core/src/postgres/migrate.rs b/sqlx-core/src/postgres/migrate.rs index 142918a69e..13bd2e3694 100644 --- a/sqlx-core/src/postgres/migrate.rs +++ b/sqlx-core/src/postgres/migrate.rs @@ -8,7 +8,6 @@ use crate::postgres::{PgConnectOptions, PgConnection, Postgres}; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; -use crc::crc32; use futures_core::future::BoxFuture; use std::str::FromStr; use std::time::Duration; @@ -25,9 +24,9 @@ fn parse_for_maintenance(uri: &str) -> Result<(PgConnectOptions, String), Error> .to_owned(); // switch us to the maintenance database - // use `postgres` _unless_ the current user is postgres, in which case, use `template1` + // use `postgres` _unless_ the database is postgres, in which case, use `template1` // this matches the behavior of the `createdb` util - options.database = if options.username == "postgres" { + options.database = if database == "postgres" { Some("template1".into()) } else { Some("postgres".into()) @@ -281,6 +280,7 @@ async fn current_database(conn: &mut PgConnection) -> Result i64 { + const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll - 0x3d32ad9e * (crc32::checksum_ieee(database_name.as_bytes()) as i64) + 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index c65f173c00..8482315d23 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -1,8 +1,11 @@ //! **PostgreSQL** database driver. +use crate::executor::Executor; + mod arguments; mod column; mod connection; +mod copy; mod database; mod error; mod io; @@ -27,7 +30,8 @@ mod migrate; pub use arguments::{PgArgumentBuffer, PgArguments}; pub use column::PgColumn; -pub use connection::PgConnection; +pub use connection::{PgConnection, PgConnectionInfo}; +pub use copy::PgCopyIn; pub use database::Postgres; pub use error::{PgDatabaseError, PgErrorPosition}; @@ -53,6 +57,10 @@ pub type PgPool = crate::pool::Pool; #[cfg(not(target_arch = "wasm32"))] pub type PgPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Postgres>`][Executor]. 
+pub trait PgExecutor<'c>: Executor<'c, Database = Postgres> {} +impl<'c, T: Executor<'c, Database = Postgres>> PgExecutor<'c> for T {} + impl_into_arguments_for_arguments!(PgArguments); #[cfg(not(target_arch = "wasm32"))] diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 6f85364a85..37c018f798 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -198,6 +198,8 @@ impl PgTypeInfo { .contains(self) { Some("ipnetwork") + } else if [PgTypeInfo::MACADDR].contains(self) { + Some("mac_address") } else if [PgTypeInfo::NUMERIC, PgTypeInfo::NUMERIC_ARRAY].contains(self) { Some("bigdecimal") } else { @@ -740,8 +742,11 @@ impl PgType { PgType::Custom(ty) => &ty.kind, - PgType::DeclareWithOid(_) | PgType::DeclareWithName(_) => { - unreachable!("(bug) use of unresolved type declaration [kind]") + PgType::DeclareWithOid(oid) => { + unreachable!("(bug) use of unresolved type declaration [oid={}]", oid); + } + PgType::DeclareWithName(name) => { + unreachable!("(bug) use of unresolved type declaration [name={}]", name); } } } diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs index e206b86b04..61ca06fcb7 100644 --- a/sqlx-core/src/postgres/types/decimal.rs +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -88,7 +88,8 @@ impl TryFrom<&'_ Decimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &Decimal) -> Result { - if decimal.is_zero() { + // `Decimal` added `is_zero()` as an inherent method in a more recent version + if Zero::is_zero(decimal) { return Ok(PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, diff --git a/sqlx-core/src/postgres/types/interval.rs b/sqlx-core/src/postgres/types/interval.rs index 8a5307adac..42805a56b5 100644 --- a/sqlx-core/src/postgres/types/interval.rs +++ b/sqlx-core/src/postgres/types/interval.rs @@ -148,9 +148,32 @@ impl TryFrom for PgInterval { /// Convert a `chrono::Duration` to a `PgInterval`. /// /// This returns an error if there is a loss of precision using nanoseconds or if there is a - /// microsecond or nanosecond overflow. + /// nanosecond overflow. 
fn try_from(value: chrono::Duration) -> Result { - value.to_std()?.try_into() + value + .num_nanoseconds() + .map_or::, _>( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |nanoseconds| { + if nanoseconds % 1000 != 0 { + return Err( + "PostgreSQL `INTERVAL` does not support nanoseconds precision".into(), + ); + } + Ok(()) + }, + )?; + + value.num_microseconds().map_or( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |microseconds| { + Ok(Self { + months: 0, + days: 0, + microseconds: microseconds, + }) + }, + ) } } @@ -283,6 +306,7 @@ fn test_encode_interval() { #[test] fn test_pginterval_std() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -292,11 +316,18 @@ fn test_pginterval_std() { &PgInterval::try_from(std::time::Duration::from_micros(27_000)).unwrap(), &interval ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(std::time::Duration::from_nanos(27_000_001)).is_err()); + + // Case when microsecond overflow occurs + assert!(PgInterval::try_from(std::time::Duration::from_secs(20_000_000_000_000)).is_err()); } #[test] #[cfg(feature = "chrono")] fn test_pginterval_chrono() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -306,11 +337,31 @@ fn test_pginterval_chrono() { &PgInterval::try_from(chrono::Duration::microseconds(27_000)).unwrap(), &interval ); + + // Case for negative duration + let interval = PgInterval { + days: 0, + months: 0, + microseconds: -27_000, + }; + assert_eq!( + &PgInterval::try_from(chrono::Duration::microseconds(-27_000)).unwrap(), + &interval + ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(chrono::Duration::nanoseconds(27_000_001)).is_err()); + assert!(PgInterval::try_from(chrono::Duration::nanoseconds(-27_000_001)).is_err()); + + // Case when nanosecond overflow occurs + assert!(PgInterval::try_from(chrono::Duration::seconds(10_000_000_000)).is_err()); + assert!(PgInterval::try_from(chrono::Duration::seconds(-10_000_000_000)).is_err()); } #[test] #[cfg(feature = "time")] fn test_pginterval_time() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -320,4 +371,23 @@ fn test_pginterval_time() { &PgInterval::try_from(time::Duration::microseconds(27_000)).unwrap(), &interval ); + + // Case for negative duration + let interval = PgInterval { + days: 0, + months: 0, + microseconds: -27_000, + }; + assert_eq!( + &PgInterval::try_from(time::Duration::microseconds(-27_000)).unwrap(), + &interval + ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(time::Duration::nanoseconds(27_000_001)).is_err()); + assert!(PgInterval::try_from(time::Duration::nanoseconds(-27_000_001)).is_err()); + + // Case when microsecond overflow occurs + assert!(PgInterval::try_from(time::Duration::seconds(10_000_000_000_000)).is_err()); + assert!(PgInterval::try_from(time::Duration::seconds(-10_000_000_000_000)).is_err()); } diff --git a/sqlx-core/src/postgres/types/ipnetwork.rs b/sqlx-core/src/postgres/types/ipnetwork.rs index 5d579e8648..84611814b2 100644 --- a/sqlx-core/src/postgres/types/ipnetwork.rs +++ b/sqlx-core/src/postgres/types/ipnetwork.rs @@ -38,6 +38,10 @@ impl Type for [IpNetwork] { fn type_info() -> PgTypeInfo { PgTypeInfo::INET_ARRAY } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY + } } impl Type for Vec { diff --git a/sqlx-core/src/postgres/types/mac_address.rs b/sqlx-core/src/postgres/types/mac_address.rs 
new file mode 100644 index 0000000000..37bd543217 --- /dev/null +++ b/sqlx-core/src/postgres/types/mac_address.rs @@ -0,0 +1,63 @@ +use mac_address::MacAddress; + +use std::convert::TryInto; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use crate::types::Type; + +impl Type for MacAddress { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::MACADDR + } +} + +impl Type for [MacAddress] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR_ARRAY + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[MacAddress] as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[MacAddress] as Type>::compatible(ty) + } +} + +impl Encode<'_, Postgres> for MacAddress { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + buf.extend_from_slice(&self.bytes()); // write just the address + IsNull::No + } + + fn size_hint(&self) -> usize { + 6 + } +} + +impl Decode<'_, Postgres> for MacAddress { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + return Ok(value.as_str()?.parse()?); + } + }; + + if bytes.len() == 6 { + return Ok(MacAddress::new(bytes.try_into().unwrap())); + } + + Err("invalid data received when expecting an MACADDR".into()) + } +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 3827d9dc6e..066a8c2309 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -73,6 +73,14 @@ //! |---------------------------------------|------------------------------------------------------| //! | `ipnetwork::IpNetwork` | INET, CIDR | //! +//! ### [`mac_address`](https://crates.io/crates/mac_address) +//! +//! Requires the `mac_address` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `mac_address::MacAddress` | MACADDR | +//! //! ### [`bit-vec`](https://crates.io/crates/bit-vec) //! //! Requires the `bit-vec` Cargo feature flag. @@ -194,6 +202,9 @@ mod json; #[cfg(feature = "ipnetwork")] mod ipnetwork; +#[cfg(feature = "mac_address")] +mod mac_address; + #[cfg(feature = "bit-vec")] mod bit_vec; diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 2ae47dcd63..f327726710 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -20,46 +20,102 @@ use std::{ /// /// Reading `MONEY` value in text format is not supported and will cause an error. /// +/// ### `locale_frac_digits` +/// This parameter corresponds to the number of digits after the decimal separator. +/// +/// This value must match what Postgres is expecting for the locale set in the database +/// or else the decimal value you see on the client side will not match the `money` value +/// on the server side. 
+/// +/// **For _most_ locales, this value is `2`.** +/// +/// If you're not sure what locale your database is set to or how many decimal digits it specifies, +/// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list +/// (you can ignore the `.utf8` prefix): +/// https://lh.2xlibre.net/values/frac_digits/ +/// +/// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute: +/// +/// ```sh +/// $ LC_MONETARY= locale -k frac_digits +/// ``` +/// +/// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database +/// server you should execute it there as available locales may differ between machines. +/// +/// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes +/// it's a sentinel value and defaults to 2: +/// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/cash.c#L114-L123 +/// /// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PgMoney(pub i64); +pub struct PgMoney( + /// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most + /// of them), this will be the value in whole cents. + /// + /// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`), + /// this will be `12345`. + /// + /// If the currency of your locale does not have fractional units, e.g. Yen, then this will + /// just be the units of the currency. + /// + /// See the type-level docs for an explanation of `locale_frac_units`. + pub i64, +); impl PgMoney { - /// Convert the money value into a [`BigDecimal`] using the correct - /// precision defined in the PostgreSQL settings. The default precision is - /// two. + /// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`BigDecimal`]: crate::types::BigDecimal #[cfg(feature = "bigdecimal")] - pub fn to_bigdecimal(self, scale: i64) -> bigdecimal::BigDecimal { + pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal { let digits = num_bigint::BigInt::from(self.0); - bigdecimal::BigDecimal::new(digits, scale) + bigdecimal::BigDecimal::new(digits, locale_frac_digits) } - /// Convert the money value into a [`Decimal`] using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert the money value into a [`Decimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { - rust_decimal::Decimal::new(self.0, scale) + pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal { + rust_decimal::Decimal::new(self.0, locale_frac_digits) } - /// Convert a [`Decimal`] value into money using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert a [`Decimal`] value into money using `locale_frac_digits`. /// - /// Conversion may involve a loss of precision. + /// See the type-level docs for an explanation of `locale_frac_digits`. + /// + /// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit. + /// If the value is larger than 63 bits it will be truncated. 
/// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self { - let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round(); + pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self { + use std::convert::TryFrom; + + // this is all we need to convert to our expected locale's `frac_digits` + decimal.rescale(locale_frac_digits); + + /// a mask to bitwise-AND with an `i64` to zero the sign bit + const SIGN_MASK: i64 = i64::MAX; + + let is_negative = decimal.is_sign_negative(); + let serialized = decimal.serialize(); - let mut buf: [u8; 8] = [0; 8]; - buf.copy_from_slice(¢s.serialize()[4..12]); + // interpret bytes `4..12` as an i64, ignoring the sign bit + // this is where truncation occurs + let value = i64::from_le_bytes( + *<&[u8; 8]>::try_from(&serialized[4..12]) + .expect("BUG: slice of serialized should be 8 bytes"), + ) & SIGN_MASK; // zero out the sign bit - Self(i64::from_le_bytes(buf)) + // negate if necessary + Self(if is_negative { -value } else { value }) } /// Convert a [`BigDecimal`](crate::types::BigDecimal) value into money using the correct precision @@ -67,12 +123,14 @@ impl PgMoney { #[cfg(feature = "bigdecimal")] pub fn from_bigdecimal( decimal: bigdecimal::BigDecimal, - scale: u32, + locale_frac_digits: u32, ) -> Result { use bigdecimal::ToPrimitive; - let multiplier = - bigdecimal::BigDecimal::new(num_bigint::BigInt::from(10i128.pow(scale)), 0); + let multiplier = bigdecimal::BigDecimal::new( + num_bigint::BigInt::from(10i128.pow(locale_frac_digits)), + 0, + ); let cents = decimal * multiplier; @@ -277,9 +335,25 @@ mod tests { #[test] #[cfg(feature = "decimal")] fn conversion_from_decimal_works() { - let dec = rust_decimal::Decimal::new(12345, 2); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2) + ); - assert_eq!(PgMoney(12345), PgMoney::from_decimal(dec, 2)); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12300), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2) + ); } #[test] diff --git a/sqlx-core/src/postgres/types/range.rs b/sqlx-core/src/postgres/types/range.rs index 760249f79c..59f689d9c0 100644 --- a/sqlx-core/src/postgres/types/range.rs +++ b/sqlx-core/src/postgres/types/range.rs @@ -142,6 +142,17 @@ impl Type for PgRange { } } +#[cfg(feature = "decimal")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + #[cfg(feature = "chrono")] impl Type for PgRange { fn type_info() -> PgTypeInfo { @@ -227,6 +238,13 @@ impl Type for [PgRange] { } } +#[cfg(feature = "decimal")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + #[cfg(feature = "chrono")] impl Type for [PgRange] { fn type_info() -> PgTypeInfo { @@ -288,6 +306,13 @@ impl Type for Vec> { } } +#[cfg(feature = "decimal")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + #[cfg(feature = "chrono")] impl Type for Vec> { fn type_info() -> PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 3607a4b898..7a721569d6 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ 
b/sqlx-core/src/postgres/types/str.rs @@ -4,6 +4,7 @@ use crate::error::BoxDynError; use crate::postgres::types::array_compatible; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; use crate::types::Type; +use std::borrow::Cow; impl Type for str { fn type_info() -> PgTypeInfo { @@ -22,6 +23,16 @@ impl Type for str { } } +impl Type for Cow<'_, str> { + fn type_info() -> PgTypeInfo { + <&str as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + impl Type for [&'_ str] { fn type_info() -> PgTypeInfo { PgTypeInfo::TEXT_ARRAY @@ -50,6 +61,15 @@ impl Encode<'_, Postgres> for &'_ str { } } +impl Encode<'_, Postgres> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + impl Encode<'_, Postgres> for String { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&str as Encode>::encode(&**self, buf) @@ -62,6 +82,12 @@ impl<'r> Decode<'r, Postgres> for &'r str { } } +impl<'r> Decode<'r, Postgres> for Cow<'r, str> { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(Cow::Borrowed(value.as_str()?)) + } +} + impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 89ce4b7b8a..e23d3f2581 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -10,7 +10,7 @@ use futures_core::stream::LocalBoxStream as BoxStream; use futures_util::{StreamExt, TryStreamExt}; use crate::arguments::IntoArguments; -use crate::database::{Database, HasArguments, HasStatement}; +use crate::database::{Database, HasArguments, HasStatement, HasStatementCache}; use crate::encode::Encode; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -97,6 +97,24 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, >::Arguments } } +impl<'q, DB, O, A> QueryAs<'q, DB, O, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.inner = self.inner.persistent(value); + self + } +} + // FIXME: This is very close, nearly 1:1 with `Map` // noinspection DuplicatedCode #[cfg(not(target_arch = "wasm32"))] diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 1f36caa6e0..8898ccf654 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -8,7 +8,7 @@ use futures_core::stream::LocalBoxStream as BoxStream; use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; -use crate::database::{Database, HasArguments, HasStatement}; +use crate::database::{Database, HasArguments, HasStatement, HasStatementCache}; use crate::encode::Encode; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -92,6 +92,24 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, >::Argum } } +impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. 
+ /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.inner = self.inner.persistent(value); + self + } +} + // FIXME: This is very close, nearly 1:1 with `Map` // noinspection DuplicatedCode #[cfg(not(target_arch = "wasm32"))] diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index 8bc9f9ceed..cb86e7e024 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( // fallback to [column_decltype] if !stepped && stmt.read_only() { stepped = true; - let _ = conn.worker.step(*stmt).await; + let _ = conn.worker.step(stmt).await; } let mut ty = stmt.column_type_info(col); diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index 20206a4388..ce8105a652 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -7,8 +7,8 @@ use crate::{ }; use libsqlite3_sys::{ sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, - SQLITE_OPEN_CREATE, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, - SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, + SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, + SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; use sqlx_rt::blocking; use std::io; @@ -29,13 +29,15 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result Result Result Result Result( +async fn prepare<'a>( + worker: &mut StatementWorker, statements: &'a mut StatementCache, statement: &'a mut Option, query: &str, @@ -39,7 +40,7 @@ fn prepare<'a>( if exists { // as this statement has been executed before, we reset before continuing // this also causes any rows that are from the statement to be inflated - statement.reset(); + statement.reset(worker).await?; } Ok(statement) @@ -61,19 +62,25 @@ fn bind( /// A structure holding sqlite statement handle and resetting the /// statement when it is dropped. -struct StatementResetter { - handle: StatementHandle, +struct StatementResetter<'a> { + handle: Arc, + worker: &'a mut StatementWorker, } -impl StatementResetter { - fn new(handle: StatementHandle) -> Self { - Self { handle } +impl<'a> StatementResetter<'a> { + fn new(worker: &'a mut StatementWorker, handle: &Arc) -> Self { + Self { + worker, + handle: Arc::clone(handle), + } } } -impl Drop for StatementResetter { +impl Drop for StatementResetter<'_> { fn drop(&mut self) { - self.handle.reset(); + // this method is designed to eagerly send the reset command + // so we don't need to await or spawn it + let _ = self.worker.reset(&self.handle); } } @@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let stmt = prepare(statements, statement, sql, persistent)?; + let stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // is dropped. 
`StatementResetter` will reliably reset the // statement even if the stream returned from `fetch_many` // is dropped early. - let _resetter = StatementResetter::new(*stmt); + let resetter = StatementResetter::new(worker, stmt); // bind values to the statement num_arguments += bind(stmt, &arguments, num_arguments)?; @@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - let s = worker.step(*stmt).await?; + let s = resetter.worker.step(stmt).await?; match s { Either::Left(changes) => { @@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { Either::Right(()) => { let (row, weak_values_ref) = SqliteRow::current( - *stmt, + stmt.to_ref(conn.to_ref()), columns, column_names ); @@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let virtual_stmt = prepare(statements, statement, sql, persistent)?; + let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -205,18 +212,21 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - match worker.step(*stmt).await? { + match worker.step(stmt).await? { Either::Left(_) => (), Either::Right(()) => { - let (row, weak_values_ref) = - SqliteRow::current(*stmt, columns, column_names); + let (row, weak_values_ref) = SqliteRow::current( + stmt.to_ref(self.handle.to_ref()), + columns, + column_names, + ); *last_row_values = Some(weak_values_ref); logger.increment_rows(); - virtual_stmt.reset(); + virtual_stmt.reset(worker).await?; return Ok(Some(row)); } } @@ -238,11 +248,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { handle: ref mut conn, ref mut statements, ref mut statement, + ref mut worker, .. 
} = self; // prepare statement object (or checkout from cache) - let statement = prepare(statements, statement, sql, true)?; + let statement = prepare(worker, statements, statement, sql, true).await?; let mut parameters = 0; let mut columns = None; diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 3797a3d0bb..14df95e6ac 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -17,6 +17,13 @@ const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */ const OP_INIT: &str = "Init"; const OP_GOTO: &str = "Goto"; const OP_COLUMN: &str = "Column"; +const OP_MAKE_RECORD: &str = "MakeRecord"; +const OP_INSERT: &str = "Insert"; +const OP_IDX_INSERT: &str = "IdxInsert"; +const OP_OPEN_READ: &str = "OpenRead"; +const OP_OPEN_WRITE: &str = "OpenWrite"; +const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral"; +const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex"; const OP_AGG_STEP: &str = "AggStep"; const OP_FUNCTION: &str = "Function"; const OP_MOVE: &str = "Move"; @@ -34,6 +41,7 @@ const OP_BLOB: &str = "Blob"; const OP_VARIABLE: &str = "Variable"; const OP_COUNT: &str = "Count"; const OP_ROWID: &str = "Rowid"; +const OP_NEWROWID: &str = "NewRowid"; const OP_OR: &str = "Or"; const OP_AND: &str = "And"; const OP_BIT_AND: &str = "BitAnd"; @@ -48,6 +56,21 @@ const OP_REMAINDER: &str = "Remainder"; const OP_CONCAT: &str = "Concat"; const OP_RESULT_ROW: &str = "ResultRow"; +#[derive(Debug, Clone, Eq, PartialEq)] +enum RegDataType { + Single(DataType), + Record(Vec), +} + +impl RegDataType { + fn map_to_datatype(self) -> DataType { + match self { + RegDataType::Single(d) => d, + RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context + } + } +} + #[allow(clippy::wildcard_in_or_patterns)] fn affinity_to_type(affinity: u8) -> DataType { match affinity { @@ -73,13 +96,19 @@ fn opcode_to_type(op: &str) -> DataType { } } +// Opcode Reference: https://sqlite.org/opcode.html pub(super) async fn explain( conn: &mut SqliteConnection, query: &str, ) -> Result<(Vec, Vec>), Error> { - let mut r = HashMap::::with_capacity(6); + // Registers + let mut r = HashMap::::with_capacity(6); + // Map between pointer and register let mut r_cursor = HashMap::>::with_capacity(6); + // Rows that pointers point to + let mut p = HashMap::>::with_capacity(6); + // Nullable columns let mut n = HashMap::::with_capacity(6); let program = @@ -119,15 +148,52 @@ pub(super) async fn explain( } OP_COLUMN => { - r_cursor.entry(p1).or_default().push(p3); + //Get the row stored at p1, or NULL; get the column stored at p2, or NULL + if let Some(record) = p.get(&p1) { + if let Some(col) = record.get(&p2) { + // insert into p3 the datatype of the col + r.insert(p3, RegDataType::Single(*col)); + // map between pointer p1 and register p3 + r_cursor.entry(p1).or_default().push(p3); + } else { + r.insert(p3, RegDataType::Single(DataType::Null)); + } + } else { + r.insert(p3, RegDataType::Single(DataType::Null)); + } + } + + OP_MAKE_RECORD => { + // p3 = Record([p1 .. 
p1 + p2]) + let mut record = Vec::with_capacity(p2 as usize); + for reg in p1..p1 + p2 { + record.push( + r.get(®) + .map(|d| d.clone().map_to_datatype()) + .unwrap_or(DataType::Null), + ); + } + r.insert(p3, RegDataType::Record(record)); + } + + OP_INSERT | OP_IDX_INSERT => { + if let Some(RegDataType::Record(record)) = r.get(&p2) { + if let Some(row) = p.get_mut(&p1) { + // Insert the record into wherever pointer p1 is + *row = (0..).zip(record.iter().copied()).collect(); + } + } + //Noop if the register p2 isn't a record, or if pointer p1 does not exist + } - // r[p3] = - r.insert(p3, DataType::Null); + OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { + //Create a new pointer which is referenced by p1 + p.insert(p1, HashMap::with_capacity(6)); } OP_VARIABLE => { // r[p2] = - r.insert(p2, DataType::Null); + r.insert(p2, RegDataType::Single(DataType::Null)); n.insert(p3, true); } @@ -136,7 +202,7 @@ pub(super) async fn explain( match from_utf8(p4).map_err(Error::protocol)? { "last_insert_rowid(0)" => { // last_insert_rowid() -> INTEGER - r.insert(p3, DataType::Int64); + r.insert(p3, RegDataType::Single(DataType::Int64)); n.insert(p3, n.get(&p3).copied().unwrap_or(false)); } @@ -145,9 +211,9 @@ pub(super) async fn explain( } OP_NULL_ROW => { - // all values of cursor X are potentially nullable - for column in &r_cursor[&p1] { - n.insert(*column, true); + // all registers that map to cursor X are potentially nullable + for register in &r_cursor[&p1] { + n.insert(*register, true); } } @@ -156,9 +222,9 @@ pub(super) async fn explain( if p4.starts_with("count(") { // count(_) -> INTEGER - r.insert(p3, DataType::Int64); + r.insert(p3, RegDataType::Single(DataType::Int64)); n.insert(p3, n.get(&p3).copied().unwrap_or(false)); - } else if let Some(v) = r.get(&p2).copied() { + } else if let Some(v) = r.get(&p2).cloned() { // r[p3] = AGG ( r[p2] ) r.insert(p3, v); let val = n.get(&p2).copied().unwrap_or(true); @@ -169,13 +235,13 @@ pub(super) async fn explain( OP_CAST => { // affinity(r[p1]) if let Some(v) = r.get_mut(&p1) { - *v = affinity_to_type(p2 as u8); + *v = RegDataType::Single(affinity_to_type(p2 as u8)); } } OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { // r[p2] = r[p1] - if let Some(v) = r.get(&p1).copied() { + if let Some(v) = r.get(&p1).cloned() { r.insert(p2, v); if let Some(null) = n.get(&p1).copied() { @@ -184,15 +250,16 @@ pub(super) async fn explain( } } - OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => { + OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID + | OP_NEWROWID => { // r[p2] = - r.insert(p2, opcode_to_type(&opcode)); + r.insert(p2, RegDataType::Single(opcode_to_type(&opcode))); n.insert(p2, n.get(&p2).copied().unwrap_or(false)); } OP_NOT => { // r[p2] = NOT r[p1] - if let Some(a) = r.get(&p1).copied() { + if let Some(a) = r.get(&p1).cloned() { r.insert(p2, a); let val = n.get(&p1).copied().unwrap_or(true); n.insert(p2, val); @@ -202,9 +269,16 @@ pub(super) async fn explain( OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { // r[p3] = r[p1] + r[p2] - match (r.get(&p1).copied(), r.get(&p2).copied()) { + match (r.get(&p1).cloned(), r.get(&p2).cloned()) { (Some(a), Some(b)) => { - r.insert(p3, if matches!(a, DataType::Null) { b } else { a }); + r.insert( + p3, + if matches!(a, RegDataType::Single(DataType::Null)) { + b + } else { + a + }, + ); } (Some(v), None) => { @@ -252,7 +326,11 @@ 
pub(super) async fn explain( if let Some(result) = result { for i in result { - output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); + output.push(SqliteTypeInfo( + r.remove(&i) + .map(|d| d.map_to_datatype()) + .unwrap_or(DataType::Null), + )); nullable.push(n.remove(&i)); } } diff --git a/sqlx-core/src/sqlite/connection/handle.rs b/sqlx-core/src/sqlite/connection/handle.rs index 6aa8f37667..c714fcc5f4 100644 --- a/sqlx-core/src/sqlite/connection/handle.rs +++ b/sqlx-core/src/sqlite/connection/handle.rs @@ -3,11 +3,24 @@ use std::ptr::NonNull; use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK}; use crate::sqlite::SqliteError; +use std::sync::Arc; /// Managed handle to the raw SQLite3 database handle. -/// The database handle will be closed when this is dropped. +/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist. #[derive(Debug)] -pub(crate) struct ConnectionHandle(pub(super) NonNull); +pub(crate) struct ConnectionHandle(Arc); + +/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own +/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()` +/// or `sqlite3_reset()`. +/// +/// Note that this does *not* actually give access to the database handle! +#[derive(Clone, Debug)] +pub(crate) struct ConnectionHandleRef(Arc); + +// Wrapper for `*mut sqlite3` which finalizes the handle on-drop. +#[derive(Debug)] +struct HandleInner(NonNull); // A SQLite3 handle is safe to send between threads, provided not more than // one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is @@ -20,19 +33,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull); unsafe impl Send for ConnectionHandle {} +// SAFETY: `Arc` normally only implements `Send` where `T: Sync` because it allows +// concurrent access. +// +// However, in this case we're only using `Arc` to prevent the database handle from being +// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus +// should *not* actually provide access to the database handle. +unsafe impl Send for ConnectionHandleRef {} + impl ConnectionHandle { #[inline] pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self { - Self(NonNull::new_unchecked(ptr)) + Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr)))) } #[inline] pub(crate) fn as_ptr(&self) -> *mut sqlite3 { - self.0.as_ptr() + self.0 .0.as_ptr() + } + + #[inline] + pub(crate) fn to_ref(&self) -> ConnectionHandleRef { + ConnectionHandleRef(Arc::clone(&self.0)) } } -impl Drop for ConnectionHandle { +impl Drop for HandleInner { fn drop(&mut self) { unsafe { // https://sqlite.org/c3ref/close.html diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 92926beef4..e001f08fa3 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -17,7 +17,7 @@ mod executor; mod explain; mod handle; -pub(crate) use handle::ConnectionHandle; +pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef}; /// A connection to a [Sqlite] database. 
pub struct SqliteConnection { @@ -62,9 +62,15 @@ impl Connection for SqliteConnection { type Options = SqliteConnectOptions; - fn close(self) -> BoxFuture<'static, Result<(), Error>> { - // nothing explicit to do; connection will close in drop - Box::pin(future::ok(())) + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + let shutdown = self.worker.shutdown(); + // Drop the statement worker and any outstanding statements, which should + // cover all references to the connection handle outside of the worker thread + drop(self); + // Ensure the worker thread has terminated + shutdown.await + }) } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { @@ -104,8 +110,7 @@ impl Connection for SqliteConnection { impl Drop for SqliteConnection { fn drop(&mut self) { - // before the connection handle is dropped, - // we must explicitly drop the statements as the drop-order in a struct is undefined + // explicitly drop statements before the connection handle is dropped self.statements.clear(); self.statement.take(); } diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index 6b31ff02b5..5be8cbfd92 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -5,6 +5,8 @@ // invariants. #![allow(unsafe_code)] +use crate::executor::Executor; + mod arguments; mod column; mod connection; @@ -43,6 +45,10 @@ pub type SqlitePool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for SQLite. pub type SqlitePoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Sqlite>`][Executor]. +pub trait SqliteExecutor<'c>: Executor<'c, Database = Sqlite> {} +impl<'c, T: Executor<'c, Database = Sqlite>> SqliteExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(SqliteArguments<'q>); impl_executor_for_pool_connection!(Sqlite, SqliteConnection, SqliteRow); diff --git a/sqlx-core/src/sqlite/options/connect.rs b/sqlx-core/src/sqlite/options/connect.rs index 6c29120a2f..cbd465ec31 100644 --- a/sqlx-core/src/sqlite/options/connect.rs +++ b/sqlx-core/src/sqlite/options/connect.rs @@ -18,20 +18,12 @@ impl ConnectOptions for SqliteConnectOptions { let mut conn = establish(self).await?; // send an initial sql statement comprised of options - // - // page_size must be set before any other action on the database. - // - // Note that locking_mode should be set before journal_mode; see - // https://www.sqlite.org/wal.html#use_of_wal_without_shared_memory . 
- let init = format!( - "PRAGMA page_size = {}; PRAGMA locking_mode = {}; PRAGMA journal_mode = {}; PRAGMA foreign_keys = {}; PRAGMA synchronous = {}; PRAGMA auto_vacuum = {}", - self.page_size, - self.locking_mode.as_str(), - self.journal_mode.as_str(), - if self.foreign_keys { "ON" } else { "OFF" }, - self.synchronous.as_str(), - self.auto_vacuum.as_str(), - ); + let mut init = String::new(); + + for (key, value) in self.pragmas.iter() { + use std::fmt::Write; + write!(init, "PRAGMA {} = {}; ", key, value).ok(); + } conn.execute(&*init).await?; diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index ba50bc05d6..9db122f355 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -14,6 +14,8 @@ pub use locking_mode::SqliteLockingMode; use std::{borrow::Cow, time::Duration}; pub use synchronous::SqliteSynchronous; +use indexmap::IndexMap; + /// Options and flags which can be used to configure a SQLite connection. /// /// A value of `SqliteConnectOptions` can be parsed from a connection URI, @@ -53,16 +55,13 @@ pub struct SqliteConnectOptions { pub(crate) in_memory: bool, pub(crate) read_only: bool, pub(crate) create_if_missing: bool, - pub(crate) journal_mode: SqliteJournalMode, - pub(crate) locking_mode: SqliteLockingMode, - pub(crate) foreign_keys: bool, pub(crate) shared_cache: bool, pub(crate) statement_cache_capacity: usize, pub(crate) busy_timeout: Duration, pub(crate) log_settings: LogSettings, - pub(crate) synchronous: SqliteSynchronous, - pub(crate) auto_vacuum: SqliteAutoVacuum, - pub(crate) page_size: u32, + pub(crate) immutable: bool, + pub(crate) pragmas: IndexMap, Cow<'static, str>>, + pub(crate) serialized: bool, } impl Default for SqliteConnectOptions { @@ -73,21 +72,45 @@ impl Default for SqliteConnectOptions { impl SqliteConnectOptions { pub fn new() -> Self { + // set default pragmas + let mut pragmas: IndexMap, Cow<'static, str>> = IndexMap::new(); + + let locking_mode: SqliteLockingMode = Default::default(); + let auto_vacuum: SqliteAutoVacuum = Default::default(); + + // page_size must be set before any other action on the database. + pragmas.insert("page_size".into(), "4096".into()); + + // Note that locking_mode should be set before journal_mode; see + // https://www.sqlite.org/wal.html#use_of_wal_without_shared_memory . + pragmas.insert("locking_mode".into(), locking_mode.as_str().into()); + + pragmas.insert( + "journal_mode".into(), + SqliteJournalMode::Wal.as_str().into(), + ); + + pragmas.insert("foreign_keys".into(), "ON".into()); + + pragmas.insert( + "synchronous".into(), + SqliteSynchronous::Full.as_str().into(), + ); + + pragmas.insert("auto_vacuum".into(), auto_vacuum.as_str().into()); + Self { filename: Cow::Borrowed(Path::new(":memory:")), in_memory: false, read_only: false, create_if_missing: false, - foreign_keys: true, shared_cache: false, statement_cache_capacity: 100, - journal_mode: SqliteJournalMode::Wal, - locking_mode: Default::default(), busy_timeout: Duration::from_secs(5), log_settings: Default::default(), - synchronous: SqliteSynchronous::Full, - auto_vacuum: Default::default(), - page_size: 4096, + immutable: false, + pragmas, + serialized: false, } } @@ -101,7 +124,10 @@ impl SqliteConnectOptions { /// /// By default, this is enabled. 
pub fn foreign_keys(mut self, on: bool) -> Self { - self.foreign_keys = on; + self.pragmas.insert( + "foreign_keys".into(), + (if on { "ON" } else { "OFF" }).into(), + ); self } @@ -118,7 +144,8 @@ impl SqliteConnectOptions { /// The default journal mode is WAL. For most use cases this can be significantly faster but /// there are [disadvantages](https://www.sqlite.org/wal.html). pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self { - self.journal_mode = mode; + self.pragmas + .insert("journal_mode".into(), mode.as_str().into()); self } @@ -126,7 +153,8 @@ impl SqliteConnectOptions { /// /// The default locking mode is NORMAL. pub fn locking_mode(mut self, mode: SqliteLockingMode) -> Self { - self.locking_mode = mode; + self.pragmas + .insert("locking_mode".into(), mode.as_str().into()); self } @@ -171,7 +199,8 @@ impl SqliteConnectOptions { /// The default synchronous settings is FULL. However, if durability is not a concern, /// then NORMAL is normally all one needs in WAL mode. pub fn synchronous(mut self, synchronous: SqliteSynchronous) -> Self { - self.synchronous = synchronous; + self.pragmas + .insert("synchronous".into(), synchronous.as_str().into()); self } @@ -179,7 +208,8 @@ impl SqliteConnectOptions { /// /// The default auto_vacuum setting is NONE. pub fn auto_vacuum(mut self, auto_vacuum: SqliteAutoVacuum) -> Self { - self.auto_vacuum = auto_vacuum; + self.pragmas + .insert("auto_vacuum".into(), auto_vacuum.as_str().into()); self } @@ -187,7 +217,33 @@ impl SqliteConnectOptions { /// /// The default page_size setting is 4096. pub fn page_size(mut self, page_size: u32) -> Self { - self.page_size = page_size; + self.pragmas + .insert("page_size".into(), page_size.to_string().into()); + self + } + + /// Sets custom initial pragma for the database connection. + pub fn pragma(mut self, key: K, value: V) -> Self + where + K: Into>, + V: Into>, + { + self.pragmas.insert(key.into(), value.into()); + self + } + + pub fn immutable(mut self, immutable: bool) -> Self { + self.immutable = immutable; + self + } + + /// Sets the [threading mode](https://www.sqlite.org/threadsafe.html) for the database connection. + /// + /// The default setting is `false` corersponding to using `OPEN_NOMUTEX`, if `true` then `OPEN_FULLMUTEX`. + /// + /// See [open](https://www.sqlite.org/c3ref/open.html) for more details. 
+ pub fn serialized(mut self, serialized: bool) -> Self { + self.serialized = serialized; self } } diff --git a/sqlx-core/src/sqlite/options/parse.rs b/sqlx-core/src/sqlite/options/parse.rs index 7c21adf469..f677df62b6 100644 --- a/sqlx-core/src/sqlite/options/parse.rs +++ b/sqlx-core/src/sqlite/options/parse.rs @@ -94,6 +94,20 @@ impl FromStr for SqliteConnectOptions { } }, + "immutable" => match &*value { + "true" | "1" => { + options.immutable = true; + } + "false" | "0" => { + options.immutable = false; + } + _ => { + return Err(Error::Configuration( + format!("unknown value {:?} for `immutable`", value).into(), + )); + } + }, + _ => { return Err(Error::Configuration( format!( diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 9f14ca58f0..4199915fe1 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -11,7 +11,7 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::row::Row; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementHandleRef}; use crate::sqlite::{Sqlite, SqliteColumn, SqliteValue, SqliteValueRef}; /// Implementation of [`Row`] for SQLite. @@ -23,7 +23,7 @@ pub struct SqliteRow { // IF the user drops the Row before iterating the stream (so // nearly all of our internal stream iterators), the executor moves on; otherwise, // it actually inflates this row with a list of owned sqlite3 values. - pub(crate) statement: StatementHandle, + pub(crate) statement: StatementHandleRef, pub(crate) values: Arc>, pub(crate) num_values: usize, @@ -48,7 +48,7 @@ impl SqliteRow { // returns a weak reference to an atomic list where the executor should inflate if its going // to increment the statement with [step] pub(crate) fn current( - statement: StatementHandle, + statement: StatementHandleRef, columns: &Arc>, column_names: &Arc>, ) -> (Self, Weak>) { diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index d1af117a7d..27e7b59020 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,5 +1,6 @@ use std::ffi::c_void; use std::ffi::CStr; + use std::os::raw::{c_char, c_int}; use std::ptr; use std::ptr::NonNull; @@ -9,21 +10,34 @@ use std::str::{from_utf8, from_utf8_unchecked}; use libsqlite3_sys::{ sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, - sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes, - sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, - sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, - sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, - sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt, - sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK, - SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob, + sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name, + sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, + sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, + sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_sql, + sqlite3_stmt, sqlite3_stmt_readonly, 
sqlite3_table_column_metadata, sqlite3_value, + SQLITE_MISUSE, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; +use crate::sqlite::connection::ConnectionHandleRef; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; - -#[derive(Debug, Copy, Clone)] -pub(crate) struct StatementHandle(pub(super) NonNull); +use std::ops::Deref; +use std::sync::Arc; + +#[derive(Debug)] +pub(crate) struct StatementHandle(NonNull); + +// wrapper for `Arc` which also holds a reference to the `ConnectionHandle` +#[derive(Clone, Debug)] +pub(crate) struct StatementHandleRef { + // NOTE: the ordering of fields here determines the drop order: + // https://doc.rust-lang.org/reference/destructors.html#destructors + // the statement *must* be dropped before the connection + statement: Arc, + connection: ConnectionHandleRef, +} // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. @@ -32,6 +46,14 @@ unsafe impl Send for StatementHandle {} unsafe impl Sync for StatementHandle {} impl StatementHandle { + pub(super) fn new(ptr: NonNull) -> Self { + Self(ptr) + } + + pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt { + self.0.as_ptr() + } + #[inline] pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 { // O(c) access to the connection handle for this statement handle @@ -280,7 +302,44 @@ impl StatementHandle { Ok(from_utf8(self.column_blob(index))?) } - pub(crate) fn reset(&self) { - unsafe { sqlite3_reset(self.0.as_ptr()) }; + pub(crate) fn clear_bindings(&self) { + unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; + } + + pub(crate) fn to_ref( + self: &Arc, + conn: ConnectionHandleRef, + ) -> StatementHandleRef { + StatementHandleRef { + statement: Arc::clone(self), + connection: conn, + } + } +} + +impl Drop for StatementHandle { + fn drop(&mut self) { + // SAFETY: we have exclusive access to the `StatementHandle` here + unsafe { + // https://sqlite.org/c3ref/finalize.html + let status = sqlite3_finalize(self.0.as_ptr()); + if status == SQLITE_MISUSE { + // Panic in case of detected misuse of SQLite API. + // + // sqlite3_finalize returns it at least in the + // case of detected double free, i.e. calling + // sqlite3_finalize on already finalized + // statement. 
+ panic!("Detected sqlite3_finalize misuse."); + } + } + } +} + +impl Deref for StatementHandleRef { + type Target = StatementHandle; + + fn deref(&self) -> &Self::Target { + &self.statement } } diff --git a/sqlx-core/src/sqlite/statement/mod.rs b/sqlx-core/src/sqlite/statement/mod.rs index dec11dcc17..97ca9f8685 100644 --- a/sqlx-core/src/sqlite/statement/mod.rs +++ b/sqlx-core/src/sqlite/statement/mod.rs @@ -12,7 +12,7 @@ mod handle; mod r#virtual; mod worker; -pub(crate) use handle::StatementHandle; +pub(crate) use handle::{StatementHandle, StatementHandleRef}; pub(crate) use r#virtual::VirtualStatement; pub(crate) use worker::StatementWorker; diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 0063e06508..3da6d33d64 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,13 +3,12 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementWorker}; use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; use crate::HashMap; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ - sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset, - sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, + sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, }; use smallvec::SmallVec; use std::i32; @@ -31,7 +30,7 @@ pub(crate) struct VirtualStatement { // underlying sqlite handles for each inner statement // a SQL query string in SQLite is broken up into N statements // we use a [`SmallVec`] to optimize for the most likely case of a single statement - pub(crate) handles: SmallVec<[StatementHandle; 1]>, + pub(crate) handles: SmallVec<[Arc; 1]>, // each set of columns pub(crate) columns: SmallVec<[Arc>; 1]>, @@ -92,7 +91,7 @@ fn prepare( query.advance(n); if let Some(handle) = NonNull::new(statement_handle) { - return Ok(Some(StatementHandle(handle))); + return Ok(Some(StatementHandle::new(handle))); } } @@ -126,7 +125,7 @@ impl VirtualStatement { conn: &mut ConnectionHandle, ) -> Result< Option<( - &StatementHandle, + &Arc, &mut Arc>, &Arc>, &mut Option>>, @@ -159,7 +158,7 @@ impl VirtualStatement { column_names.insert(name, i); } - self.handles.push(statement); + self.handles.push(Arc::new(statement)); self.columns.push(Arc::new(columns)); self.column_names.push(Arc::new(column_names)); self.last_row_values.push(None); @@ -177,20 +176,20 @@ impl VirtualStatement { ))) } - pub(crate) fn reset(&mut self) { + pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> { self.index = 0; for (i, handle) in self.handles.iter().enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - unsafe { - // Reset A Prepared Statement Object - // https://www.sqlite.org/c3ref/reset.html - // https://www.sqlite.org/c3ref/clear_bindings.html - sqlite3_reset(handle.0.as_ptr()); - sqlite3_clear_bindings(handle.0.as_ptr()); - } + // Reset A Prepared Statement Object + // https://www.sqlite.org/c3ref/reset.html + // https://www.sqlite.org/c3ref/clear_bindings.html + worker.reset(handle).await?; + handle.clear_bindings(); } + + Ok(()) } } @@ -198,20 +197,6 @@ impl Drop for VirtualStatement { fn drop(&mut self) { for (i, handle) in self.handles.drain(..).enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], 
-
-            unsafe {
-                // https://sqlite.org/c3ref/finalize.html
-                let status = sqlite3_finalize(handle.0.as_ptr());
-                if status == SQLITE_MISUSE {
-                    // Panic in case of detected misuse of SQLite API.
-                    //
-                    // sqlite3_finalize returns it at least in the
-                    // case of detected double free, i.e. calling
-                    // sqlite3_finalize on already finalized
-                    // statement.
-                    panic!("Detected sqlite3_finalize misuse.");
-                }
-            }
         }
     }
 }
diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs
index 8b1d229978..5a06f637b0 100644
--- a/sqlx-core/src/sqlite/statement/worker.rs
+++ b/sqlx-core/src/sqlite/statement/worker.rs
@@ -3,9 +3,14 @@ use crate::sqlite::statement::StatementHandle;
 use crossbeam_channel::{unbounded, Sender};
 use either::Either;
 use futures_channel::oneshot;
-use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW};
+use std::sync::{Arc, Weak};
 use std::thread;
 
+use crate::sqlite::connection::ConnectionHandleRef;
+
+use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW};
+use std::future::Future;
+
 // Each SQLite connection has a dedicated thread.
 
 // TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
@@ -18,31 +23,70 @@ pub(crate) struct StatementWorker {
 
 enum StatementWorkerCommand {
     Step {
-        statement: StatementHandle,
+        statement: Weak<StatementHandle>,
         tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
     },
+    Reset {
+        statement: Weak<StatementHandle>,
+        tx: oneshot::Sender<()>,
+    },
+    Shutdown {
+        tx: oneshot::Sender<()>,
+    },
 }
 
 impl StatementWorker {
-    pub(crate) fn new() -> Self {
+    pub(crate) fn new(conn: ConnectionHandleRef) -> Self {
         let (tx, rx) = unbounded();
 
         thread::spawn(move || {
             for cmd in rx {
                 match cmd {
                     StatementWorkerCommand::Step { statement, tx } => {
-                        let status = unsafe { sqlite3_step(statement.0.as_ptr()) };
+                        let statement = if let Some(statement) = statement.upgrade() {
+                            statement
+                        } else {
+                            // statement is already finalized, the sender shouldn't be expecting a response
+                            continue;
+                        };
 
-                        let resp = match status {
+                        // SAFETY: only the `StatementWorker` calls this function
+                        let status = unsafe { sqlite3_step(statement.as_ptr()) };
+                        let result = match status {
                             SQLITE_ROW => Ok(Either::Right(())),
                             SQLITE_DONE => Ok(Either::Left(statement.changes())),
                             _ => Err(statement.last_error().into()),
                         };
 
-                        let _ = tx.send(resp);
+                        let _ = tx.send(result);
+                    }
+                    StatementWorkerCommand::Reset { statement, tx } => {
+                        if let Some(statement) = statement.upgrade() {
+                            // SAFETY: this must be the only place we call `sqlite3_reset`
+                            unsafe { sqlite3_reset(statement.as_ptr()) };
+
+                            // `sqlite3_reset()` always returns either `SQLITE_OK`
+                            // or the last error code for the statement,
+                            // which should have already been handled;
+                            // so it's assumed the return value is safe to ignore.
+                            //
+                            // https://www.sqlite.org/c3ref/reset.html
+
+                            let _ = tx.send(());
+                        }
+                    }
+                    StatementWorkerCommand::Shutdown { tx } => {
+                        // drop the connection reference before sending confirmation
+                        // and ending the command loop
+                        drop(conn);
+                        let _ = tx.send(());
+                        return;
                     }
                 }
             }
+
+            // SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx`
+            drop(conn);
         });
 
         Self { tx }
@@ -50,14 +94,68 @@ impl StatementWorker {
     pub(crate) async fn step(
         &mut self,
-        statement: StatementHandle,
+        statement: &Arc<StatementHandle>,
     ) -> Result<Either<u64, ()>, Error> {
         let (tx, rx) = oneshot::channel();
 
         self.tx
-            .send(StatementWorkerCommand::Step { statement, tx })
+            .send(StatementWorkerCommand::Step {
+                statement: Arc::downgrade(statement),
+                tx,
+            })
             .map_err(|_| Error::WorkerCrashed)?;
 
         rx.await.map_err(|_| Error::WorkerCrashed)?
     }
+
+    /// Send a command to the worker to execute `sqlite3_reset()` next.
+    ///
+    /// This method sends the command eagerly, so you only need to await the returned
+    /// future if you want to wait until the reset has actually been executed.
+    ///
+    /// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error
+    /// in the statement execution, which should have already been handled by `step()`.
+    pub(crate) fn reset(
+        &mut self,
+        statement: &Arc<StatementHandle>,
+    ) -> impl Future<Output = Result<(), Error>> {
+        // execute the sending eagerly so we don't need to spawn the future
+        let (tx, rx) = oneshot::channel();
+
+        let send_res = self
+            .tx
+            .send(StatementWorkerCommand::Reset {
+                statement: Arc::downgrade(statement),
+                tx,
+            })
+            .map_err(|_| Error::WorkerCrashed);
+
+        async move {
+            send_res?;
+
+            // wait for the response
+            rx.await.map_err(|_| Error::WorkerCrashed)
+        }
+    }
+
+    /// Send a command to the worker to shut down the processing thread.
+    ///
+    /// A `WorkerCrashed` error may be returned if the thread has already stopped.
+    /// Subsequent calls to `step()`, `reset()`, or this method will fail with
+    /// `WorkerCrashed`. Ensure that any associated statements are dropped first.
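+    ///
+    /// A hedged usage sketch (an assumption about the intended call site, not part
+    /// of this change): the owning connection is expected to drop its statements
+    /// first, then await the confirmation, roughly:
+    ///
+    /// ```rust,ignore
+    /// // `statements` and `worker` are hypothetical locals standing in for the
+    /// // connection's fields; dropping the statements lets the worker finalize them
+    /// drop(statements);
+    /// worker.shutdown().await?; // stop the thread and wait for confirmation
+    /// ```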
+    pub(crate) fn shutdown(&mut self) -> impl Future<Output = Result<(), Error>> {
+        let (tx, rx) = oneshot::channel();
+
+        let send_res = self
+            .tx
+            .send(StatementWorkerCommand::Shutdown { tx })
+            .map_err(|_| Error::WorkerCrashed);
+
+        async move {
+            send_res?;
+
+            // wait for the response
+            rx.await.map_err(|_| Error::WorkerCrashed)
+        }
+    }
 }
diff --git a/sqlx-core/src/sqlite/types/chrono.rs b/sqlx-core/src/sqlite/types/chrono.rs
index cd01c3bde2..1ebb2c4f45 100644
--- a/sqlx-core/src/sqlite/types/chrono.rs
+++ b/sqlx-core/src/sqlite/types/chrono.rs
@@ -76,7 +76,7 @@ impl Encode<'_, Sqlite> for NaiveDate {
 
 impl Encode<'_, Sqlite> for NaiveTime {
     fn encode_by_ref(&self, buf: &mut Vec<SqliteArgumentValue<'_>>) -> IsNull {
-        Encode::<Sqlite>::encode(self.format("%T%.f%").to_string(), buf)
+        Encode::<Sqlite>::encode(self.format("%T%.f").to_string(), buf)
     }
 }
 
@@ -179,9 +179,11 @@ impl<'r> Decode<'r, Sqlite> for NaiveTime {
 
     // Loop over common time patterns, inspired by Diesel
     // https://github.com/diesel-rs/diesel/blob/93ab183bcb06c69c0aee4a7557b6798fd52dd0d8/diesel/src/sqlite/types/date_and_time/chrono.rs#L29-L47
+    #[rustfmt::skip] // don't like how rustfmt mangles the comments
     let sqlite_time_formats = &[
         // Most likely format
-        "%T.f", // Other formats in order of appearance in docs
+        "%T.f", "%T%.f",
+        // Other formats in order of appearance in docs
         "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z",
     ];
 
diff --git a/sqlx-core/src/sqlite/types/str.rs b/sqlx-core/src/sqlite/types/str.rs
index 6a3ed533b1..086597ef10 100644
--- a/sqlx-core/src/sqlite/types/str.rs
+++ b/sqlx-core/src/sqlite/types/str.rs
@@ -52,3 +52,23 @@ impl<'r> Decode<'r, Sqlite> for String {
         value.text().map(ToOwned::to_owned)
     }
 }
+
+impl<'q> Encode<'q, Sqlite> for Cow<'q, str> {
+    fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
+        args.push(SqliteArgumentValue::Text(self));
+
+        IsNull::No
+    }
+
+    fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
+        args.push(SqliteArgumentValue::Text(self.clone()));
+
+        IsNull::No
+    }
+}
+
+impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
+    fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
+        value.text().map(Cow::Borrowed)
+    }
+}
diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs
index 600daf0fdd..2bf4e3b5d2 100644
--- a/sqlx-core/src/types/mod.rs
+++ b/sqlx-core/src/types/mod.rs
@@ -75,6 +75,13 @@ pub mod ipnetwork {
     pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
 }
 
+#[cfg(feature = "mac_address")]
+#[cfg_attr(docsrs, doc(cfg(feature = "mac_address")))]
+pub mod mac_address {
+    #[doc(no_inline)]
+    pub use mac_address::MacAddress;
+}
+
 #[cfg(feature = "json")]
 pub use json::Json;
 
diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml
index b76185ed53..099049ea90 100644
--- a/sqlx-macros/Cargo.toml
+++ b/sqlx-macros/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sqlx-macros"
-version = "0.5.5"
+version = "0.5.9"
 repository = "https://github.com/launchbadge/sqlx"
 description = "Macros for SQLx, the rust SQL toolkit. Not intended to be used directly."
license = "MIT OR Apache-2.0" @@ -72,20 +72,20 @@ decimal = ["sqlx-core/decimal"] chrono = ["sqlx-core/chrono"] time = ["sqlx-core/time"] ipnetwork = ["sqlx-core/ipnetwork"] +mac_address = ["sqlx-core/mac_address"] uuid = ["sqlx-core/uuid"] bit-vec = ["sqlx-core/bit-vec"] json = ["sqlx-core/json", "serde_json"] [dependencies] dotenv = { version = "0.15.0", default-features = false } -futures = { version = "0.3.4", default-features = false, features = ["executor"] } hex = { version = "0.4.2", optional = true } heck = "0.3.1" either = "1.5.3" once_cell = "1.5.2" proc-macro2 = { version = "1.0.9", default-features = false } -sqlx-core = { version = "0.5.5", default-features = false, path = "../sqlx-core" } -sqlx-rt = { version = "0.5.5", default-features = false, path = "../sqlx-rt" } +sqlx-core = { version = "0.5.9", default-features = false, path = "../sqlx-core" } +sqlx-rt = { version = "0.5.9", default-features = false, path = "../sqlx-rt" } serde = { version = "1.0.111", features = ["derive"], optional = true } serde_json = { version = "1.0.30", features = ["preserve_order"], optional = true } sha2 = { version = "0.9.1", optional = true } diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 05f0a88bd6..5330bb3cd9 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -60,6 +60,9 @@ impl_database_ext! { #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, + #[cfg(feature = "mac_address")] + sqlx::types::mac_address::MacAddress, + #[cfg(feature = "json")] serde_json::Value, @@ -113,6 +116,9 @@ impl_database_ext! { #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], + #[cfg(feature = "mac_address")] + Vec | &[sqlx::types::mac_address::MacAddress], + #[cfg(feature = "json")] Vec | &[serde_json::Value], diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index bdf2812999..202b6b5a17 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -15,7 +15,7 @@ macro_rules! assert_attribute { macro_rules! 
     ($t:expr, $m:expr) => {
-        return Err(syn::Error::new_spanned($t, $m));
+        return Err(syn::Error::new_spanned($t, $m))
     };
 }
 
@@ -216,8 +216,6 @@ pub fn check_transparent_attributes(
         field
     );
 
-    assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
-
     let ch_attributes = parse_child_attributes(&field.attrs)?;
 
     assert_attribute!(
diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs
index 57de5b4421..8a4ea4a248 100644
--- a/sqlx-macros/src/lib.rs
+++ b/sqlx-macros/src/lib.rs
@@ -2,6 +2,10 @@
     not(any(feature = "postgres", feature = "mysql", feature = "offline")),
     allow(dead_code, unused_macros, unused_imports)
 )]
+#![cfg_attr(
+    any(sqlx_macros_unstable, procmacro2_semver_exempt),
+    feature(track_path, proc_macro_tracked_env)
+)]
 extern crate proc_macro;
 
 use proc_macro::TokenStream;
diff --git a/sqlx-macros/src/migrate.rs b/sqlx-macros/src/migrate.rs
index f10bc22318..018ba1b41e 100644
--- a/sqlx-macros/src/migrate.rs
+++ b/sqlx-macros/src/migrate.rs
@@ -24,7 +24,7 @@ struct QuotedMigration {
     version: i64,
     description: String,
     migration_type: QuotedMigrationType,
-    sql: String,
+    path: String,
     checksum: Vec<u8>,
 }
 
@@ -34,7 +34,7 @@ impl ToTokens for QuotedMigration {
             version,
             description,
             migration_type,
-            sql,
+            path,
             checksum,
         } = &self;
 
@@ -43,7 +43,8 @@ impl ToTokens for QuotedMigration {
                 version: #version,
                 description: ::std::borrow::Cow::Borrowed(#description),
                 migration_type: #migration_type,
-                sql: ::std::borrow::Cow::Borrowed(#sql),
+                // this tells the compiler to watch this path for changes
+                sql: ::std::borrow::Cow::Borrowed(include_str!(#path)),
                 checksum: ::std::borrow::Cow::Borrowed(&[
                     #(#checksum),*
                 ]),
@@ -59,7 +60,7 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result crate::Result crate::Result, query: &str) -> crate::Result {
-        serde_json::Deserializer::from_reader(BufReader::new(
+        let this = serde_json::Deserializer::from_reader(BufReader::new(
             File::open(path.as_ref()).map_err(|e| {
                 format!("failed to open path {}: {}", path.as_ref().display(), e)
             })?,
@@ -69,8 +69,22 @@ pub mod offline {
         .deserialize_map(DataFileVisitor {
             query,
             hash: hash_string(query),
-        })
-        .map_err(Into::into)
+        })?;
+
+        #[cfg(procmacro2_semver_exempt)]
+        {
+            let path = path.as_ref().canonicalize()?;
+            let path = path.to_str().ok_or_else(|| {
+                format!(
+                    "sqlx-data.json path cannot be represented as a string: {:?}",
+                    path
+                )
+            })?;
+
+            proc_macro::tracked_path::path(path);
+        }
+
+        Ok(this)
     }
 }
diff --git a/sqlx-macros/src/query/input.rs b/sqlx-macros/src/query/input.rs
index 86627d60b1..f3bce4a333 100644
--- a/sqlx-macros/src/query/input.rs
+++ b/sqlx-macros/src/query/input.rs
@@ -8,7 +8,7 @@ use syn::{ExprArray, Type};
 
 /// Macro input shared by `query!()` and `query_file!()`
 pub struct QueryMacroInput {
-    pub(super) src: String,
+    pub(super) sql: String,
 
     #[cfg_attr(not(feature = "offline"), allow(dead_code))]
     pub(super) src_span: Span,
@@ -18,6 +18,8 @@ pub struct QueryMacroInput {
     pub(super) arg_exprs: Vec<Expr>,
 
     pub(super) checked: bool,
+
+    pub(super) file_path: Option<String>,
 }
 
 enum QuerySrc {
@@ -94,12 +96,15 @@ impl Parse for QueryMacroInput {
 
         let arg_exprs = args.unwrap_or_default();
 
+        let file_path = src.file_path(src_span)?;
+
         Ok(QueryMacroInput {
-            src: src.resolve(src_span)?,
+            sql: src.resolve(src_span)?,
             src_span,
             record_type,
             arg_exprs,
             checked,
+            file_path,
         })
     }
 }
@@ -112,6 +117,27 @@ impl QuerySrc {
             QuerySrc::File(file) => read_file_src(&file, source_span),
         }
     }
+
+    fn file_path(&self, source_span: Span) -> syn::Result<Option<String>> {
+        if let QuerySrc::File(ref file) = *self {
+            let path = crate::common::resolve_path(file, source_span)?
+                .canonicalize()
+                .map_err(|e| syn::Error::new(source_span, e))?;
+
+            Ok(Some(
+                path.to_str()
+                    .ok_or_else(|| {
+                        syn::Error::new(
+                            source_span,
+                            "query file path cannot be represented as a string",
+                        )
+                    })?
+                    .to_string(),
+            ))
+        } else {
+            Ok(None)
+        }
+    }
 }
 
 fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs
index 44237c8901..58c5dc5f34 100644
--- a/sqlx-macros/src/query/mod.rs
+++ b/sqlx-macros/src/query/mod.rs
@@ -1,4 +1,6 @@
 use std::path::PathBuf;
+#[cfg(feature = "offline")]
+use std::sync::{Arc, Mutex};
 
 use once_cell::sync::Lazy;
 use proc_macro2::TokenStream;
@@ -28,71 +30,84 @@ mod input;
 mod output;
 
 struct Metadata {
+    #[allow(unused)]
     manifest_dir: PathBuf,
     offline: bool,
     database_url: Option<String>,
     #[cfg(feature = "offline")]
     target_dir: PathBuf,
     #[cfg(feature = "offline")]
-    workspace_root: PathBuf,
+    workspace_root: Arc<Mutex<Option<PathBuf>>>,
+}
+
+#[cfg(feature = "offline")]
+impl Metadata {
+    pub fn workspace_root(&self) -> PathBuf {
+        let mut root = self.workspace_root.lock().unwrap();
+        if root.is_none() {
+            use serde::Deserialize;
+            use std::process::Command;
+
+            let cargo = env("CARGO").expect("`CARGO` must be set");
+
+            let output = Command::new(&cargo)
+                .args(&["metadata", "--format-version=1"])
+                .current_dir(&self.manifest_dir)
+                .env_remove("__CARGO_FIX_PLZ")
+                .output()
+                .expect("Could not fetch metadata");
+
+            #[derive(Deserialize)]
+            struct CargoMetadata {
+                workspace_root: PathBuf,
+            }
+
+            let metadata: CargoMetadata =
+                serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
+
+            *root = Some(metadata.workspace_root);
+        }
+        root.clone().unwrap()
+    }
 }
 
 // If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
 // reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
 static METADATA: Lazy<Metadata> = Lazy::new(|| {
-    use std::env;
-
-    let manifest_dir: PathBuf = env::var("CARGO_MANIFEST_DIR")
+    let manifest_dir: PathBuf = env("CARGO_MANIFEST_DIR")
         .expect("`CARGO_MANIFEST_DIR` must be set")
         .into();
 
     #[cfg(feature = "offline")]
-    let target_dir =
-        env::var_os("CARGO_TARGET_DIR").map_or_else(|| "target".into(), |dir| dir.into());
+    let target_dir = env("CARGO_TARGET_DIR").map_or_else(|_| "target".into(), |dir| dir.into());
 
     // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
     // otherwise fallback to default dotenv behaviour.
     let env_path = manifest_dir.join(".env");
-    if env_path.exists() {
+
+    #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))]
+    let env_path = if env_path.exists() {
         let res = dotenv::from_path(&env_path);
         if let Err(e) = res {
             panic!("failed to load environment from {:?}, {}", env_path, e);
         }
+
+        Some(env_path)
     } else {
-        let _ = dotenv::dotenv();
+        dotenv::dotenv().ok()
+    };
+
+    // tell the compiler to watch the `.env` for changes, if applicable
+    #[cfg(procmacro2_semver_exempt)]
+    if let Some(env_path) = env_path.as_ref().and_then(|path| path.to_str()) {
+        proc_macro::tracked_path::path(env_path);
     }
 
-    // TODO: Switch to `var_os` after feature(osstring_ascii) is stable.
-    // Stabilization PR: https://github.com/rust-lang/rust/pull/80193
-    let offline = env::var("SQLX_OFFLINE")
+    let offline = env("SQLX_OFFLINE")
         .map(|s| s.eq_ignore_ascii_case("true") || s == "1")
         .unwrap_or(false);
 
-    let database_url = env::var("DATABASE_URL").ok();
-
-    #[cfg(feature = "offline")]
-    let workspace_root = {
-        use serde::Deserialize;
-        use std::process::Command;
-
-        let cargo = env::var_os("CARGO").expect("`CARGO` must be set");
-
-        let output = Command::new(&cargo)
-            .args(&["metadata", "--format-version=1"])
-            .current_dir(&manifest_dir)
-            .output()
-            .expect("Could not fetch metadata");
-
-        #[derive(Deserialize)]
-        struct CargoMetadata {
-            workspace_root: PathBuf,
-        }
-
-        let metadata: CargoMetadata =
-            serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
-
-        metadata.workspace_root
-    };
+    let database_url = env("DATABASE_URL").ok();
 
     Metadata {
         manifest_dir,
@@ -101,7 +116,7 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
         #[cfg(feature = "offline")]
         target_dir,
         #[cfg(feature = "offline")]
-        workspace_root,
+        workspace_root: Arc::new(Mutex::new(None)),
     }
 });
 
@@ -116,18 +131,20 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
         #[cfg(feature = "offline")]
         _ => {
             let data_file_path = METADATA.manifest_dir.join("sqlx-data.json");
-            let workspace_data_file_path = METADATA.workspace_root.join("sqlx-data.json");
 
             if data_file_path.exists() {
                 expand_from_file(input, data_file_path)
-            } else if workspace_data_file_path.exists() {
-                expand_from_file(input, workspace_data_file_path)
             } else {
-                Err(
-                    "`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \
+                let workspace_data_file_path = METADATA.workspace_root().join("sqlx-data.json");
+                if workspace_data_file_path.exists() {
+                    expand_from_file(input, workspace_data_file_path)
+                } else {
+                    Err(
+                        "`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \
                      and sqlx-data.json must exist, to use query macros"
-                        .into(),
-                )
+                            .into(),
+                    )
+                }
             }
         }
 
@@ -156,7 +173,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
             let data = block_on(async {
                 let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?;
-                QueryData::from_db(&mut conn, &input.src).await
+                QueryData::from_db(&mut conn, &input.sql).await
             })?;
 
             expand_with_data(input, data, false)
@@ -188,7 +205,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
             let data = block_on(async {
                 let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?;
-                QueryData::from_db(&mut conn, &input.src).await
+                QueryData::from_db(&mut conn, &input.sql).await
             })?;
 
             expand_with_data(input, data, false)
@@ -201,7 +218,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
             let data = block_on(async {
                 let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?;
-                QueryData::from_db(&mut conn, &input.src).await
+                QueryData::from_db(&mut conn, &input.sql).await
             })?;
 
             expand_with_data(input, data, false)
@@ -214,7 +231,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
             let data = block_on(async {
                 let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?;
-                QueryData::from_db(&mut conn, &input.src).await
+                QueryData::from_db(&mut conn, &input.sql).await
             })?;
 
             expand_with_data(input, data, false)
@@ -231,7 +248,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
 
 pub fn expand_from_file(input: QueryMacroInput, file: PathBuf) -> crate::Result<TokenStream> {
     use data::offline::DynQueryData;
 
-    let query_data = DynQueryData::from_data_file(file, &input.src)?;
+    let query_data = DynQueryData::from_data_file(file, &input.sql)?;
     assert!(!query_data.db_name.is_empty());
 
     match &*query_data.db_name {
@@ -312,7 +329,7 @@ where
         .all(|it| it.type_info().is_void())
     {
         let db_path = DB::db_path();
-        let sql = &input.src;
+        let sql = &input.sql;
 
         quote! {
             ::sqlx::query_with::<#db_path, _>(#sql, #query_args)
@@ -338,6 +355,7 @@ where
             |&output::RustColumn {
                  ref ident,
                  ref type_,
+                 ..
              }| quote!(#ident: #type_,),
         );
 
@@ -392,3 +410,16 @@ where
 
     Ok(ret_tokens)
 }
+
+/// Get the value of an environment variable, telling the compiler about it if applicable.
+fn env(name: &str) -> Result<String, std::env::VarError> {
+    #[cfg(procmacro2_semver_exempt)]
+    {
+        proc_macro::tracked_env::var(name)
+    }
+
+    #[cfg(not(procmacro2_semver_exempt))]
+    {
+        std::env::var(name)
+    }
+}
diff --git a/sqlx-macros/src/query/output.rs b/sqlx-macros/src/query/output.rs
index e7b482e044..f7d56646dd 100644
--- a/sqlx-macros/src/query/output.rs
+++ b/sqlx-macros/src/query/output.rs
@@ -14,6 +14,7 @@ use syn::Token;
 
 pub struct RustColumn {
     pub(super) ident: Ident,
+    pub(super) var_name: Ident,
     pub(super) type_: ColumnType,
 }
 
@@ -114,6 +115,9 @@ fn column_to_rust<DB: DatabaseExt>(describe: &Describe<DB>, i: usize) -> crate::
     };
 
     Ok(RustColumn {
+        // prefix the variable name we use in `quote_query_as!()` so it doesn't conflict
+        // https://github.com/launchbadge/sqlx/issues/1322
+        var_name: quote::format_ident!("sqlx_query_as_{}", decl.ident),
         ident: decl.ident,
         type_,
     })
@@ -129,7 +133,7 @@ pub fn quote_query_as<DB: DatabaseExt>(
             |(
                 i,
                 &RustColumn {
-                    ref ident,
+                    ref var_name,
                     ref type_,
                     ..
                 },
@@ -140,24 +144,32 @@ pub fn quote_query_as<DB: DatabaseExt>(
                 // binding to a `let` avoids confusing errors about
                 // "try expression alternatives have incompatible types"
                 // it doesn't seem to hurt inference in the other branches
-                let #ident = row.try_get_unchecked::<#type_, _>(#i)?;
+                let #var_name = row.try_get_unchecked::<#type_, _>(#i)?;
             },
             // type was overridden to be a wildcard so we fallback to the runtime check
-            (true, ColumnType::Wildcard) => quote! ( let #ident = row.try_get(#i)?; ),
+            (true, ColumnType::Wildcard) => quote! ( let #var_name = row.try_get(#i)?; ),
             (true, ColumnType::OptWildcard) => {
-                quote! ( let #ident = row.try_get::<::std::option::Option<_>, _>(#i)?; )
+                quote! ( let #var_name = row.try_get::<::std::option::Option<_>, _>(#i)?; )
             }
             // macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
-            (false, _) => quote!( let #ident = row.try_get_unchecked(#i)?; ),
         }
+            (false, _) => quote!( let #var_name = row.try_get_unchecked(#i)?; ),
+        }
     },
     );
 
     let ident = columns.iter().map(|col| &col.ident);
+    let var_name = columns.iter().map(|col| &col.var_name);
     let db_path = DB::db_path();
     let row_path = DB::row_path();
-    let sql = &input.src;
+
+    // if this query came from a file, use `include_str!()` to tell the compiler where it came from
+    let sql = if let Some(ref path) = &input.file_path {
+        quote::quote_spanned! { input.src_span => include_str!(#path) }
+    } else {
+        let sql = &input.sql;
+        quote! { #sql }
+    };
 
     quote! {
         ::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| {
@@ -165,7 +177,7 @@
 
             #(#instantiations)*
 
-            Ok(#out_ty { #(#ident: #ident),* })
+            Ok(#out_ty { #(#ident: #var_name),* })
         })
     }
 }
@@ -200,7 +212,7 @@ pub fn quote_query_scalar<DB: DatabaseExt>(
     };
 
     let db = DB::db_path();
-    let query = &input.src;
+    let query = &input.sql;
 
     Ok(quote! {
         ::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
diff --git a/sqlx-rt/Cargo.toml b/sqlx-rt/Cargo.toml
index 5d1cb0326e..5d9d8d523a 100644
--- a/sqlx-rt/Cargo.toml
+++ b/sqlx-rt/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "sqlx-rt"
-version = "0.5.5"
+version = "0.5.9"
 repository = "https://github.com/launchbadge/sqlx"
 license = "MIT OR Apache-2.0"
 description = "Runtime abstraction used by SQLx, the Rust SQL toolkit. Not intended to be used directly."
diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs
index 9b82302d99..ed28d79c73 100644
--- a/sqlx-rt/src/lib.rs
+++ b/sqlx-rt/src/lib.rs
@@ -37,7 +37,7 @@ pub use native_tls;
 ))]
 pub use tokio::{
     self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
-    net::TcpStream, task::spawn, task::yield_now, time::sleep, time::timeout,
+    net::TcpStream, runtime::Handle, task::spawn, task::yield_now, time::sleep, time::timeout,
 };
 
 #[cfg(all(
@@ -105,7 +105,8 @@ pub use tokio_rustls::{client::TlsStream, TlsConnector};
 #[macro_export]
 macro_rules! blocking {
     ($($expr:tt)*) => {
-        $crate::tokio::task::block_in_place(move || { $($expr)* })
+        $crate::tokio::task::spawn_blocking(move || { $($expr)* })
+            .await.expect("Blocking task failed to complete.")
     };
 }
diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs
index ccef6424e8..da70155e5f 100644
--- a/sqlx-test/src/lib.rs
+++ b/sqlx-test/src/lib.rs
@@ -197,7 +197,8 @@ macro_rules! __test_prepared_type {
 #[macro_export]
 macro_rules! MySql_query_for_test_prepared_type {
     () => {
-        "SELECT {0} <=> ?, {0}, ?"
+        // MySQL 8.0.27 changed `<=>` to return an unsigned integer
+        "SELECT CAST({0} <=> ? AS SIGNED INTEGER), {0}, ?"
     };
 }
diff --git a/src/lib.rs b/src/lib.rs
index f7ab668255..3681850bff 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -49,15 +49,15 @@ pub use sqlx_core::migrate;
     ),
     feature = "any"
 ))]
-pub use sqlx_core::any::{self, Any, AnyConnection, AnyPool};
+pub use sqlx_core::any::{self, Any, AnyConnection, AnyExecutor, AnyPool};
 
 #[cfg(feature = "mysql")]
 #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
-pub use sqlx_core::mysql::{self, MySql, MySqlConnection, MySqlPool};
+pub use sqlx_core::mysql::{self, MySql, MySqlConnection, MySqlExecutor, MySqlPool};
 
 #[cfg(feature = "mssql")]
 #[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
-pub use sqlx_core::mssql::{self, Mssql, MssqlConnection, MssqlPool};
+pub use sqlx_core::mssql::{self, Mssql, MssqlConnection, MssqlExecutor, MssqlPool};
 
 #[cfg(all(feature = "postgres", not(target_arch = "wasm32")))]
 #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
@@ -65,11 +65,11 @@ pub use sqlx_core::postgres::PgPool;
 
 #[cfg(feature = "postgres")]
 #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
-pub use sqlx_core::postgres::{self, PgConnection, Postgres};
+pub use sqlx_core::postgres::{self, PgConnection, PgExecutor, Postgres};
 
 #[cfg(feature = "sqlite")]
 #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
-pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool};
+pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqliteExecutor, SqlitePool};
 
 #[cfg(feature = "macros")]
 #[doc(hidden)]
diff --git a/src/macros.rs b/src/macros.rs
index fdabc1e1ef..a35a119a03 100644
--- a/src/macros.rs
+++ b/src/macros.rs
@@ -36,6 +36,7 @@
 /// | Zero or One | `.fetch_optional(...).await`| `sqlx::Result<Option<{adhoc struct}>>` | Extra rows are ignored. |
 /// | Exactly One | `.fetch_one(...).await` | `sqlx::Result<{adhoc struct}>` | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. |
 /// | At Least One | `.fetch(...)` | `impl Stream<Item = sqlx::Result<{adhoc struct}>>` | Call `.try_next().await` to get each row result. |
+/// | Multiple | `.fetch_all(...)` | `sqlx::Result<Vec<{adhoc struct}>>` | |
 ///
 /// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`.
 /// † Only callable if the query returns no columns; otherwise it's assumed the query *may* return at least one row.
@@ -459,6 +460,7 @@ macro_rules! query_file_unchecked (
 /// | Zero or One | `.fetch_optional(...).await`| `sqlx::Result<Option<T>>` | Extra rows are ignored. |
 /// | Exactly One | `.fetch_one(...).await` | `sqlx::Result<T>` | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. |
 /// | At Least One | `.fetch(...)` | `impl Stream<Item = sqlx::Result<T>>` | Call `.try_next().await` to get each row result. |
+/// | Multiple | `.fetch_all(...)` | `sqlx::Result<Vec<T>>` | |
 ///
 /// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`.
 /// (`.execute()` is omitted as this macro requires at least one column to be returned.)
@@ -715,6 +717,70 @@ macro_rules! query_file_scalar_unchecked (
 /// The directory must be relative to the project root (the directory containing `Cargo.toml`),
 /// unlike `include_str!()` which uses compiler internals to get the path of the file where it
 /// was invoked.
+///
+/// See [MigrationSource][crate::migrate::MigrationSource] for details on the structure of the `./migrations` directory.
+///
+/// ## Triggering Recompilation on Migration Changes
+/// In some cases when making changes to embedded migrations, such as adding a new migration without
+/// changing any Rust source files, you might find that `cargo build` doesn't actually do anything,
+/// or when you do `cargo run` your application isn't applying new migrations on startup.
+///
+/// This is because our ability to tell the compiler to watch external files for changes
+/// from a proc-macro is very limited. The compiler by default only re-runs proc macros when
+/// one or more source files have changed, because normally it shouldn't have to otherwise. SQLx is
+/// just weird in that external factors can change the output of proc macros, much to the chagrin of
+/// the compiler team and IDE plugin authors.
+///
+/// As of 0.5.6, we emit `include_str!()` with an absolute path for each migration, but that
+/// only works to get the compiler to watch _existing_ migration files for changes.
+///
+/// Our only options for telling it to watch the whole `migrations/` directory are either via the
+/// user creating a Cargo build script in their project, or using an unstable API on nightly
+/// governed by a `cfg`-flag.
+///
+/// #### Stable Rust: Cargo Build Script
+/// The only solution on stable Rust right now is to create a Cargo build script in your project
+/// and have it print `cargo:rerun-if-changed=migrations`:
+///
+/// `build.rs`
+/// ```
+/// fn main() {
+///     println!("cargo:rerun-if-changed=migrations");
+/// }
+/// ```
+///
+/// You can run `sqlx migrate build-script` to generate this file automatically.
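+///
+/// For context, a minimal sketch of consuming the embedded migrator at startup
+/// (the pool type and the `DATABASE_URL` environment variable here are assumptions
+/// for illustration, not requirements of this macro):
+///
+/// ```rust,ignore
+/// // `migrate!()` with no argument embeds `./migrations` at compile time
+/// let pool = sqlx::SqlitePool::connect(&std::env::var("DATABASE_URL")?).await?;
+/// sqlx::migrate!().run(&pool).await?;
+/// ```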
+///
+/// See: [The Cargo Book: 3.8 Build Scripts; Outputs of the Build Script](https://doc.rust-lang.org/stable/cargo/reference/build-scripts.html#outputs-of-the-build-script)
+///
+/// #### Nightly Rust: `cfg` Flag
+/// The `migrate!()` macro also listens to `--cfg sqlx_macros_unstable`, which will enable
+/// the `track_path` feature to directly tell the compiler to watch the `migrations/` directory:
+///
+/// ```sh,ignore
+/// $ env RUSTFLAGS='--cfg sqlx_macros_unstable' cargo build
+/// ```
+///
+/// Note that this unfortunately will trigger a full recompile of your dependency tree, at least
+/// the first time you use it. It also, of course, requires using a nightly compiler.
+///
+/// You can also set it in `build.rustflags` in `.cargo/config.toml`:
+/// ```toml,ignore
+/// [build]
+/// rustflags = ["--cfg sqlx_macros_unstable"]
+/// ```
+///
+/// And then continue building and running your project normally.
+///
+/// If you're building on nightly anyway, it would be extremely helpful if you could test
+/// this feature and report any bugs you find.
+///
+/// Subscribe to [the `track_path` tracking issue](https://github.com/rust-lang/rust/issues/73921)
+/// for discussion and the future stabilization of this feature.
+///
+/// For brevity and because it involves the same commitment to unstable features in `proc_macro`,
+/// if you're using `--cfg procmacro2_semver_exempt` it will also enable this feature
+/// (see [`proc-macro2` docs / Unstable Features](https://docs.rs/proc-macro2/1.0.27/proc_macro2/#unstable-features)).
 #[cfg(feature = "migrate")]
 #[macro_export]
 macro_rules! migrate {
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000000..bc2dc2327c
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,18 @@
+
+### Running Tests
+SQLx uses Docker to run many compatible database systems for integration testing. You'll need to [install Docker](https://docs.docker.com/engine/) to run the full suite. You can validate your Docker installation with:
+
+    $ docker run hello-world
+
+Start the databases with `docker-compose` before running tests:
+
+    $ docker-compose up
+
+Run all tests against all supported databases using:
+
+    $ ./x.py
+
+If you see test failures, or want to run a more specific set of tests against a specific database, you can specify both the features to be tested and the `DATABASE_URL`, e.g.
+
+    $ DATABASE_URL=mysql://root:password@127.0.0.1:49183/sqlx cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-async-std-native-tls
diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml
index 4d0e39848e..a6f0109692 100644
--- a/tests/docker-compose.yml
+++ b/tests/docker-compose.yml
@@ -2,7 +2,7 @@ version: "3"
 services:
     #
-    # MySQL 5.6.x, 5.7.x, 8.x
+    # MySQL 8.x, 5.7.x, 5.6.x
     # https://www.mysql.com/support/supportedplatforms/database.html
     #
 
@@ -40,12 +40,12 @@ services:
             MYSQL_DATABASE: sqlx
 
     #
-    # MariaDB 10.5, 10.4, 10.3, 10.2, 10.1
+    # MariaDB 10.6, 10.5, 10.4, 10.3, 10.2
     # https://mariadb.org/about/#maintenance-policy
     #
 
-    mariadb_10_5:
-        image: mariadb:10.5
+    mariadb_10_6:
+        image: mariadb:10.6
         volumes:
            - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         ports:
@@ -54,8 +54,8 @@ services:
             MYSQL_ROOT_PASSWORD: password
             MYSQL_DATABASE: sqlx
 
-    mariadb_10_4:
-        image: mariadb:10.4
+    mariadb_10_5:
+        image: mariadb:10.5
         volumes:
            - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         ports:
@@ -64,8 +64,8 @@ services:
             MYSQL_ROOT_PASSWORD: password
             MYSQL_DATABASE: sqlx
 
-    mariadb_10_3:
-        image: mariadb:10.3
+    mariadb_10_4:
+        image: mariadb:10.4
         volumes:
            - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         ports:
@@ -74,8 +74,8 @@ services:
             MYSQL_ROOT_PASSWORD: password
             MYSQL_DATABASE: sqlx
 
-    mariadb_10_2:
-        image: mariadb:10.2
+    mariadb_10_3:
+        image: mariadb:10.3
         volumes:
            - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         ports:
@@ -84,8 +84,8 @@ services:
             MYSQL_ROOT_PASSWORD: password
             MYSQL_DATABASE: sqlx
 
-    mariadb_10_1:
-        image: mariadb:10.1
+    mariadb_10_2:
+        image: mariadb:10.2
         volumes:
            - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         ports:
@@ -95,16 +95,35 @@ services:
             MYSQL_DATABASE: sqlx
 
     #
-    # PostgreSQL 12.x, 10.x, 9.6.x, 9.5.x
+    # PostgreSQL 13.x, 12.x, 11.x, 10.x, 9.6.x
     # https://www.postgresql.org/support/versioning/
     #
 
+    postgres_14:
+        build:
+            context: .
+            dockerfile: postgres/Dockerfile
+            args:
+                VERSION: 14rc1
+        ports:
+            - 5432
+        environment:
+            POSTGRES_DB: sqlx
+            POSTGRES_USER: postgres
+            POSTGRES_PASSWORD: password
+            POSTGRES_HOST_AUTH_METHOD: scram-sha-256
+            POSTGRES_INITDB_ARGS: --auth-host=scram-sha-256
+        volumes:
+            - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
+        command: >
+            -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
+
     postgres_13:
         build:
             context: .
             dockerfile: postgres/Dockerfile
             args:
-                VERSION: 13-beta1
+                VERSION: 13
         ports:
            - 5432
         environment:
@@ -123,7 +142,7 @@ services:
             context: .
             dockerfile: postgres/Dockerfile
             args:
-                VERSION: 12.3
+                VERSION: 12
         ports:
            - 5432
         environment:
@@ -137,55 +156,57 @@ services:
         command: >
             -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
 
-    postgres_10:
+    postgres_11:
         build:
             context: .
             dockerfile: postgres/Dockerfile
             args:
-                VERSION: 10.13
+                VERSION: 11
         ports:
            - 5432
         environment:
             POSTGRES_DB: sqlx
             POSTGRES_USER: postgres
             POSTGRES_PASSWORD: password
-            POSTGRES_HOST_AUTH_METHOD: trust
+            POSTGRES_HOST_AUTH_METHOD: scram-sha-256
+            POSTGRES_INITDB_ARGS: --auth-host=scram-sha-256
         volumes:
            - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         command: >
             -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
 
-    postgres_9_6:
+    postgres_10:
         build:
             context: .
             dockerfile: postgres/Dockerfile
             args:
-                VERSION: 9.6
+                VERSION: 10
         ports:
            - 5432
         environment:
             POSTGRES_DB: sqlx
             POSTGRES_USER: postgres
             POSTGRES_PASSWORD: password
-            POSTGRES_HOST_AUTH_METHOD: md5
+            POSTGRES_HOST_AUTH_METHOD: scram-sha-256
+            POSTGRES_INITDB_ARGS: --auth-host=scram-sha-256
         volumes:
            - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         command: >
             -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
 
-    postgres_9_5:
+    postgres_9_6:
         build:
             context: .
             dockerfile: postgres/Dockerfile
             args:
-                VERSION: 9.5
+                VERSION: 9.6
         ports:
            - 5432
         environment:
             POSTGRES_DB: sqlx
             POSTGRES_USER: postgres
             POSTGRES_PASSWORD: password
-            POSTGRES_HOST_AUTH_METHOD: password
+            POSTGRES_HOST_AUTH_METHOD: md5
         volumes:
            - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
         command: >
@@ -205,17 +226,19 @@ services:
         ports:
            - 1433
         environment:
-            ACCEPT_EULA: Y
+            ACCEPT_EULA: "Y"
             SA_PASSWORD: Password123!
 
     mssql_2017:
         build:
             context: .
-            dockerfile: mssql/Dockerfile
+            dockerfile: mssql/mssql-2017.dockerfile
             args:
                 VERSION: 2017-latest
+        ports:
+            - 1433
         environment:
-            ACCEPT_EULA: Y
+            ACCEPT_EULA: "Y"
             SA_PASSWORD: Password123!
 
     #
diff --git a/tests/docker.py b/tests/docker.py
index b4cdadd650..e664c38c6e 100644
--- a/tests/docker.py
+++ b/tests/docker.py
@@ -1,4 +1,5 @@
 import subprocess
+import sys
 import time
 
 from os import path
diff --git a/tests/migrate/macro.rs b/tests/migrate/macro.rs
index 9a3c16150e..7215046bef 100644
--- a/tests/migrate/macro.rs
+++ b/tests/migrate/macro.rs
@@ -7,6 +7,8 @@ static EMBEDDED: Migrator = sqlx::migrate!("tests/migrate/migrations");
 async fn same_output() -> anyhow::Result<()> {
     let runtime = Migrator::new(Path::new("tests/migrate/migrations")).await?;
 
+    assert_eq!(runtime.migrations.len(), EMBEDDED.migrations.len());
+
     for (e, r) in EMBEDDED.iter().zip(runtime.iter()) {
         assert_eq!(e.version, r.version);
         assert_eq!(e.description, r.description);
diff --git a/tests/mssql/mssql-2017.dockerfile b/tests/mssql/mssql-2017.dockerfile
new file mode 100644
index 0000000000..a2e0b58dae
--- /dev/null
+++ b/tests/mssql/mssql-2017.dockerfile
@@ -0,0 +1,19 @@
+# vim: set ft=dockerfile:
+ARG VERSION
+FROM mcr.microsoft.com/mssql/server:${VERSION}
+
+# Create a config directory
+RUN mkdir -p /usr/config
+WORKDIR /usr/config
+
+# Bundle config source
+COPY mssql/entrypoint.sh /usr/config/entrypoint.sh
+COPY mssql/configure-db.sh /usr/config/configure-db.sh
+COPY mssql/setup.sql /usr/config/setup.sql
+
+# Grant permissions for our scripts to be executable
+USER root
+RUN chmod +x /usr/config/entrypoint.sh
+RUN chmod +x /usr/config/configure-db.sh
+
+ENTRYPOINT ["/usr/config/entrypoint.sh"]
diff --git a/tests/mysql/macros.rs b/tests/mysql/macros.rs
index 9b9b436c98..80eb1b2e91 100644
--- a/tests/mysql/macros.rs
+++ b/tests/mysql/macros.rs
@@ -188,12 +188,13 @@ async fn test_column_override_nullable() -> anyhow::Result<()> {
 
 async fn with_test_row<'a>(
     conn: &'a mut MySqlConnection,
-) -> anyhow::Result<Transaction<'a, MySql>> {
+) -> anyhow::Result<(Transaction<'a, MySql>, MyInt)> {
     let mut transaction = conn.begin().await?;
-    sqlx::query!("INSERT INTO tweet(id, text, owner_id) VALUES (1, '#sqlx is pretty cool!', 1)")
+    let id = sqlx::query!("INSERT INTO tweet(text, owner_id) VALUES ('#sqlx is pretty cool!', 1)")
         .execute(&mut transaction)
-        .await?;
-    Ok(transaction)
+        .await?
+        .last_insert_id();
+    Ok((transaction, MyInt(id as i64)))
 }
 
 #[derive(PartialEq, Eq, Debug, sqlx::Type)]
@@ -211,13 +212,13 @@ struct OptionalRecord {
 #[sqlx_macros::test]
 async fn test_column_override_wildcard() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, id) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query_as!(Record, "select id as `id: _` from tweet")
         .fetch_one(&mut conn)
         .await?;
 
-    assert_eq!(record.id, MyInt(1));
+    assert_eq!(record.id, id);
 
     // this syntax is also useful for expressions
     let record = sqlx::query_as!(Record, "select * from (select 1 as `id: _`) records")
@@ -238,7 +239,7 @@
 #[sqlx_macros::test]
 async fn test_column_override_wildcard_not_null() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, _) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query_as!(Record, "select owner_id as `id!: _` from tweet")
         .fetch_one(&mut conn)
@@ -252,13 +253,13 @@
 #[sqlx_macros::test]
 async fn test_column_override_wildcard_nullable() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, id) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query_as!(OptionalRecord, "select id as `id?: _` from tweet")
         .fetch_one(&mut conn)
         .await?;
 
-    assert_eq!(record.id, Some(MyInt(1)));
+    assert_eq!(record.id, Some(id));
 
     Ok(())
 }
@@ -266,13 +267,13 @@
 #[sqlx_macros::test]
 async fn test_column_override_exact() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, id) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query!("select id as `id: MyInt` from tweet")
         .fetch_one(&mut conn)
         .await?;
 
-    assert_eq!(record.id, MyInt(1));
+    assert_eq!(record.id, id);
 
     // we can also support this syntax for expressions
     let record = sqlx::query!("select * from (select 1 as `id: MyInt`) records")
@@ -293,7 +294,7 @@
 #[sqlx_macros::test]
 async fn test_column_override_exact_not_null() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, _) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query!("select owner_id as `id!: MyInt` from tweet")
         .fetch_one(&mut conn)
@@ -307,13 +308,13 @@
 #[sqlx_macros::test]
 async fn test_column_override_exact_nullable() -> anyhow::Result<()> {
     let mut conn = new::<MySql>().await?;
-    let mut conn = with_test_row(&mut conn).await?;
+    let (mut conn, id) = with_test_row(&mut conn).await?;
 
     let record = sqlx::query!("select id as `id?: MyInt` from tweet")
         .fetch_one(&mut conn)
         .await?;
 
-    assert_eq!(record.id, Some(MyInt(1)));
+    assert_eq!(record.id, Some(id));
 
     Ok(())
 }
diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs
index baeaf9923a..d78009b4e2 100644
--- a/tests/mysql/mysql.rs
+++ b/tests/mysql/mysql.rs
@@ -387,3 +387,62 @@ async fn test_issue_622() -> anyhow::Result<()> {
 
     Ok(())
 }
+
+#[sqlx_macros::test]
+async fn it_can_work_with_transactions() -> anyhow::Result<()> {
+    let mut conn = new::<MySql>().await?;
+    conn.execute("CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);")
+        .await?;
+
+    // begin .. rollback
+
+    let mut tx = conn.begin().await?;
+    sqlx::query("INSERT INTO users (id) VALUES (?)")
+        .bind(1_i32)
+        .execute(&mut tx)
+        .await?;
+    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
+        .fetch_one(&mut tx)
+        .await?;
+    assert_eq!(count, 1);
+    tx.rollback().await?;
+    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
+        .fetch_one(&mut conn)
+        .await?;
+    assert_eq!(count, 0);
+
+    // begin .. commit
+
+    let mut tx = conn.begin().await?;
+    sqlx::query("INSERT INTO users (id) VALUES (?)")
+        .bind(1_i32)
+        .execute(&mut tx)
+        .await?;
+    tx.commit().await?;
+    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
+        .fetch_one(&mut conn)
+        .await?;
+    assert_eq!(count, 1);
+
+    // begin .. (drop)
+
+    {
+        let mut tx = conn.begin().await?;
+
+        sqlx::query("INSERT INTO users (id) VALUES (?)")
+            .bind(2)
+            .execute(&mut tx)
+            .await?;
+        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
+            .fetch_one(&mut tx)
+            .await?;
+        assert_eq!(count, 2);
+        // tx is dropped
+    }
+    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
+        .fetch_one(&mut conn)
+        .await?;
+    assert_eq!(count, 1);
+
+    Ok(())
+}
diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs
index e1c4d9f52b..9bd93e0f1b 100644
--- a/tests/mysql/types.rs
+++ b/tests/mysql/types.rs
@@ -235,7 +235,8 @@ mod json_tests {
     test_type!(json<JsonValue>(
         MySql,
-        "SELECT CAST({0} AS BINARY) <=> CAST(? AS BINARY), CAST({0} AS BINARY) as _2, ? as _3",
+        // MySQL 8.0.27 changed `<=>` to return an unsigned integer
+        "SELECT CAST(CAST({0} AS BINARY) <=> CAST(? AS BINARY) AS SIGNED INTEGER), CAST({0} AS BINARY) as _2, ? as _3",
         "'\"Hello, World\"'" == json!("Hello, World"),
         "'\"😎\"'" == json!("😎"),
         "'\"🙋‍♀️\"'" == json!("🙋‍♀️"),
@@ -250,7 +251,8 @@ mod json_tests {
     test_type!(json_struct<Json<Friend>>(
         MySql,
-        "SELECT CAST({0} AS BINARY) <=> CAST(? AS BINARY), CAST({0} AS BINARY) as _2, ? as _3",
+        // MySQL 8.0.27 changed `<=>` to return an unsigned integer
+        "SELECT CAST(CAST({0} AS BINARY) <=> CAST(? AS BINARY) AS SIGNED INTEGER), CAST({0} AS BINARY) as _2, ? as _3",
         "\'{\"name\":\"Joe\",\"age\":33}\'" == Json(Friend { name: "Joe".to_string(), age: 33 })
     ));
diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs
index bc770e050f..51d1f89bb8 100644
--- a/tests/postgres/macros.rs
+++ b/tests/postgres/macros.rs
@@ -105,7 +105,8 @@ async fn test_query_file() -> anyhow::Result<()> {
         .fetch_one(&mut conn)
         .await?;
 
-    println!("{:?}", account);
+    assert_eq!(account.id, 1);
+    assert_eq!(account.name, Option::<String>::None);
 
     Ok(())
 }
diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs
index 590f06b5c5..51dfbc6d37 100644
--- a/tests/postgres/postgres.rs
+++ b/tests/postgres/postgres.rs
@@ -1,8 +1,8 @@
-use futures::TryStreamExt;
+use futures::{StreamExt, TryStreamExt};
 use sqlx::postgres::{
     PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity,
 };
-use sqlx::postgres::{PgPoolOptions, PgRow, Postgres};
+use sqlx::postgres::{PgConnectionInfo, PgPoolOptions, PgRow, Postgres};
 use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo};
 use sqlx_test::{new, setup_if_needed};
 use std::env;
@@ -519,14 +519,19 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
     for i in 0..200 {
         let pool = pool.clone();
         sqlx_rt::spawn(async move {
-            loop {
+            for j in 0.. {
                 if let Err(e) = sqlx::query("select 1 + 1").execute(&pool).await {
                     // normal error at termination of the test
-                    if !matches!(e, sqlx::Error::PoolClosed) {
-                        eprintln!("pool task {} dying due to {}", i, e);
-                        break;
+                    if matches!(e, sqlx::Error::PoolClosed) {
+                        eprintln!("pool task {} exiting normally after {} iterations", i, j);
+                    } else {
+                        eprintln!("pool task {} dying due to {} after {} iterations", i, e, j);
                     }
+                    break;
                 }
+
+                // shouldn't be necessary if the pool is fair
+                // sqlx_rt::yield_now().await;
             }
         });
     }
@@ -547,6 +552,8 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
             })
             .await;
 
+            // this one is necessary since this is a hot loop,
+            // otherwise this task will never be descheduled
             sqlx_rt::yield_now().await;
         }
     });
@@ -961,6 +968,30 @@ async fn test_listener_cleanup() -> anyhow::Result<()> {
 
 #[sqlx_macros::test]
 async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
+    // Only supported in Postgres 11+
+    let mut conn = new::<Postgres>().await?;
+    if matches!(conn.server_version_num(), Some(version) if version < 110000) {
+        return Ok(());
+    }
+
+    conn.execute(
+        r#"
+DROP TABLE IF EXISTS heating_bills;
+DROP DOMAIN IF EXISTS winter_year_month;
+DROP TYPE IF EXISTS year_month;
+DROP DOMAIN IF EXISTS month_id;
+
+CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12);
+CREATE TYPE year_month AS (year INT4, month month_id);
+CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3);
+CREATE TABLE heating_bills (
+  month winter_year_month NOT NULL PRIMARY KEY,
+  cost INT4 NOT NULL
+);
+    "#,
+    )
+    .await?;
+
     #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
     struct MonthId(i16);
 
@@ -1032,41 +1063,176 @@ async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<
             sqlx::encode::IsNull::No
         }
     }
-    let mut conn = new::<Postgres>().await?;
 
-    {
-        let result = sqlx::query("DELETE FROM heating_bills;")
+    let result = sqlx::query("DELETE FROM heating_bills;")
         .execute(&mut conn)
         .await;
 
+    let result = result.unwrap();
+    assert_eq!(result.rows_affected(), 0);
+
+    let result =
+        sqlx::query("INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);")
+            .bind(WinterYearMonth {
+                year: 2021,
+                month: MonthId(1),
+            })
         .execute(&mut conn)
         .await;
 
-        let result = result.unwrap();
-        assert_eq!(result.rows_affected(), 1);
-    }
+    let result = result.unwrap();
+    assert_eq!(result.rows_affected(), 1);
 
-    {
-        let result = sqlx::query(
-            "INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);",
-        )
-        .bind(WinterYearMonth {
-            year: 2021,
-            month: MonthId(1),
-        })
+    let result = sqlx::query("DELETE FROM heating_bills;")
         .execute(&mut conn)
         .await;
 
-        let result = result.unwrap();
-        assert_eq!(result.rows_affected(), 1);
-    }
+    let result = result.unwrap();
+    assert_eq!(result.rows_affected(), 1);
+
+    Ok(())
+}
+
+#[sqlx_macros::test]
+async fn test_pg_server_num() -> anyhow::Result<()> {
+    use sqlx::postgres::PgConnectionInfo;
+
+    let conn = new::<Postgres>().await?;
+
+    assert!(conn.server_version_num().is_some());
+
+    Ok(())
+}
+
+#[sqlx_macros::test]
+async fn it_can_copy_in() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+    conn.execute(
+        r#"
+        CREATE TEMPORARY TABLE users (id INTEGER NOT NULL);
+    "#,
+    )
+    .await?;
+
+    let mut copy = conn
+        .copy_in_raw(
+            r#"
+        COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER);
+    "#,
+        )
+        .await?;
+
+    copy.send("id\n1\n2\n".as_bytes()).await?;
+    let rows = copy.finish().await?;
+    assert_eq!(rows, 2);
+
+    // conn is safe for reuse
+    let value = sqlx::query("select 1 + 1")
sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_abort_copy_in() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute( + r#" + CREATE TEMPORARY TABLE users (id INTEGER NOT NULL); + "#, + ) + .await?; + + let copy = conn + .copy_in_raw( + r#" + COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER); + "#, + ) + .await?; + + copy.abort("this is only a test").await?; + + // conn is safe for reuse + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_copy_out() -> anyhow::Result<()> { + let mut conn = new::().await?; { - let result = sqlx::query("DELETE FROM heating_bills;") - .execute(&mut conn) - .await; + let mut copy = conn + .copy_out_raw( + " + COPY (SELECT generate_series(1, 2) AS id) TO STDOUT WITH (FORMAT CSV, HEADER); + ", + ) + .await?; + + assert_eq!(copy.next().await.unwrap().unwrap(), "id\n"); + assert_eq!(copy.next().await.unwrap().unwrap(), "1\n"); + assert_eq!(copy.next().await.unwrap().unwrap(), "2\n"); + if copy.next().await.is_some() { + anyhow::bail!("Unexpected data from COPY"); + } + } + + // conn is safe for reuse + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(2i32, value); + + Ok(()) +} - let result = result.unwrap(); - assert_eq!(result.rows_affected(), 1); +#[sqlx_macros::test] +async fn test_issue_1254() -> anyhow::Result<()> { + #[derive(sqlx::Type)] + #[sqlx(type_name = "pair")] + struct Pair { + one: i32, + two: i32, } + // array for custom type is not supported, use wrapper + #[derive(sqlx::Type)] + #[sqlx(type_name = "_pair")] + struct Pairs(Vec); + + let mut conn = new::().await?; + conn.execute( + " +DROP TABLE IF EXISTS issue_1254; +DROP TYPE IF EXISTS pair; + +CREATE TYPE pair AS (one INT4, two INT4); +CREATE TABLE issue_1254 (id INT4 PRIMARY KEY, pairs PAIR[]); +", + ) + .await?; + + let result = sqlx::query("INSERT INTO issue_1254 VALUES($1, $2)") + .bind(0) + .bind(Pairs(vec![Pair { one: 94, two: 87 }])) + .execute(&mut conn) + .await?; + assert_eq!(result.rows_affected(), 1); + Ok(()) } diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index d013d43400..9818d139ba 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -29,11 +29,3 @@ CREATE TABLE products ( name TEXT, price NUMERIC CHECK (price > 0) ); - -CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12); -CREATE TYPE year_month AS (year INT4, month month_id); -CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3); -CREATE TABLE heating_bills ( - month winter_year_month NOT NULL PRIMARY KEY, - cost INT4 NOT NULL -); diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index a0aa64eb69..7b39de8596 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -167,6 +167,14 @@ test_type!(ipnetwork(Postgres, .unwrap(), )); +#[cfg(feature = "mac_address")] +test_type!(mac_address(Postgres, + "'00:01:02:03:04:05'::macaddr" + == "00:01:02:03:04:05" + .parse::() + .unwrap() +)); + #[cfg(feature = "bit-vec")] test_type!(bitvec( Postgres, @@ -201,6 +209,15 @@ test_type!(ipnetwork_vec>(Postgres, ] )); +#[cfg(feature = "mac_address")] +test_type!(mac_address_vec>(Postgres, + "'{01:02:03:04:05:06,FF:FF:FF:FF:FF:FF}'::macaddr[]" + == vec![ + 
"01:02:03:04:05:06".parse::().unwrap(), + "FF:FF:FF:FF:FF:FF".parse::().unwrap() + ] +)); + #[cfg(feature = "chrono")] mod chrono { use super::*; @@ -408,6 +425,13 @@ test_type!(bigdecimal(Postgres, "12345.6789::numeric" == "12345.6789".parse::().unwrap(), )); +#[cfg(feature = "bigdecimal")] +test_type!(numrange_bigdecimal>(Postgres, + "'(1.3,2.4)'::numrange" == PgRange::from( + (Bound::Excluded("1.3".parse::().unwrap()), + Bound::Excluded("2.4".parse::().unwrap()))) +)); + #[cfg(feature = "decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), @@ -419,6 +443,13 @@ test_type!(decimal(Postgres, "12345.6789::numeric" == sqlx::types::Decimal::from_str("12345.6789").unwrap(), )); +#[cfg(feature = "decimal")] +test_type!(numrange_decimal>(Postgres, + "'(1.3,2.4)'::numrange" == PgRange::from( + (Bound::Excluded(sqlx::types::Decimal::from_str("1.3").unwrap()), + Bound::Excluded(sqlx::types::Decimal::from_str("2.4").unwrap()))), +)); + const EXC2: Bound = Bound::Excluded(2); const EXC3: Bound = Bound::Excluded(3); const INC1: Bound = Bound::Included(1); diff --git a/tests/sqlite/.gitignore b/tests/sqlite/.gitignore new file mode 100644 index 0000000000..02a6711c35 --- /dev/null +++ b/tests/sqlite/.gitignore @@ -0,0 +1,2 @@ +sqlite.db + diff --git a/tests/sqlite/derives.rs b/tests/sqlite/derives.rs index bbbc3d673d..d91e012b30 100644 --- a/tests/sqlite/derives.rs +++ b/tests/sqlite/derives.rs @@ -1,5 +1,5 @@ use sqlx::Sqlite; -use sqlx_test::{new, test_type}; +use sqlx_test::test_type; #[derive(Debug, PartialEq, sqlx::Type)] #[repr(u32)] diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 02d935a1bd..90d59284ea 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -171,6 +171,21 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_describes_insert_with_returning() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *") + .await?; + + assert_eq!(d.columns().len(), 4); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.column(1).type_info().name(), "TEXT"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_describes_bad_statement() -> anyhow::Result<()> { let mut conn = new::().await?; diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db index 49913441df..a3d8d5cc29 100644 Binary files a/tests/sqlite/sqlite.db and b/tests/sqlite/sqlite.db differ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 1334e493c1..12f1834e8c 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -206,7 +206,8 @@ async fn it_executes_with_pool() -> anyhow::Result<()> { async fn it_opens_in_memory() -> anyhow::Result<()> { // If the filename is ":memory:", then a private, temporary in-memory database // is created for the connection. - let _ = SqliteConnection::connect(":memory:").await?; + let conn = SqliteConnection::connect(":memory:").await?; + conn.close().await?; Ok(()) } @@ -215,7 +216,8 @@ async fn it_opens_in_memory() -> anyhow::Result<()> { async fn it_opens_temp_on_disk() -> anyhow::Result<()> { // If the filename is an empty string, then a private, temporary on-disk database will // be created. 
diff --git a/tests/sqlite/.gitignore b/tests/sqlite/.gitignore
new file mode 100644
index 0000000000..02a6711c35
--- /dev/null
+++ b/tests/sqlite/.gitignore
@@ -0,0 +1,2 @@
+sqlite.db
+
diff --git a/tests/sqlite/derives.rs b/tests/sqlite/derives.rs
index bbbc3d673d..d91e012b30 100644
--- a/tests/sqlite/derives.rs
+++ b/tests/sqlite/derives.rs
@@ -1,5 +1,5 @@
 use sqlx::Sqlite;
-use sqlx_test::{new, test_type};
+use sqlx_test::test_type;
 
 #[derive(Debug, PartialEq, sqlx::Type)]
 #[repr(u32)]
diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs
index 02d935a1bd..90d59284ea 100644
--- a/tests/sqlite/describe.rs
+++ b/tests/sqlite/describe.rs
@@ -171,6 +171,21 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> {
     Ok(())
 }
 
+#[sqlx_macros::test]
+async fn it_describes_insert_with_returning() -> anyhow::Result<()> {
+    let mut conn = new::<Sqlite>().await?;
+
+    let d = conn
+        .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *")
+        .await?;
+
+    assert_eq!(d.columns().len(), 4);
+    assert_eq!(d.column(0).type_info().name(), "INTEGER");
+    assert_eq!(d.column(1).type_info().name(), "TEXT");
+
+    Ok(())
+}
+
 #[sqlx_macros::test]
 async fn it_describes_bad_statement() -> anyhow::Result<()> {
     let mut conn = new::<Sqlite>().await?;
diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db
index 49913441df..a3d8d5cc29 100644
Binary files a/tests/sqlite/sqlite.db and b/tests/sqlite/sqlite.db differ
diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs
index 1334e493c1..12f1834e8c 100644
--- a/tests/sqlite/sqlite.rs
+++ b/tests/sqlite/sqlite.rs
@@ -206,7 +206,8 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
 async fn it_opens_in_memory() -> anyhow::Result<()> {
     // If the filename is ":memory:", then a private, temporary in-memory database
     // is created for the connection.
-    let _ = SqliteConnection::connect(":memory:").await?;
+    let conn = SqliteConnection::connect(":memory:").await?;
+    conn.close().await?;
 
     Ok(())
 }
@@ -215,7 +216,8 @@ async fn it_opens_in_memory() -> anyhow::Result<()> {
 async fn it_opens_temp_on_disk() -> anyhow::Result<()> {
     // If the filename is an empty string, then a private, temporary on-disk database will
     // be created.
-    let _ = SqliteConnection::connect("").await?;
+    let conn = SqliteConnection::connect("").await?;
+    conn.close().await?;
 
     Ok(())
 }
@@ -536,3 +538,57 @@ async fn it_resets_prepared_statement_after_fetch_many() -> anyhow::Result<()> {
 
     Ok(())
 }
+
+// https://github.com/launchbadge/sqlx/issues/1300
+#[sqlx_macros::test]
+async fn concurrent_resets_dont_segfault() {
+    use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions};
+    use std::{str::FromStr, time::Duration};
+
+    let mut conn = SqliteConnectOptions::from_str(":memory:")
+        .unwrap()
+        .connect()
+        .await
+        .unwrap();
+
+    sqlx::query("CREATE TABLE stuff (name INTEGER, value INTEGER)")
+        .execute(&mut conn)
+        .await
+        .unwrap();
+
+    sqlx_rt::spawn(async move {
+        for i in 0..1000 {
+            sqlx::query("INSERT INTO stuff (name, value) VALUES (?, ?)")
+                .bind(i)
+                .bind(0)
+                .execute(&mut conn)
+                .await
+                .unwrap();
+        }
+    });
+
+    sqlx_rt::sleep(Duration::from_millis(1)).await;
+}
+
+// https://github.com/launchbadge/sqlx/issues/1419
+// note: this passes before and after the fix; you need to run it with `--nocapture`
+// to see the panic from the worker thread, which doesn't happen after the fix
+#[sqlx_macros::test]
+async fn row_dropped_after_connection_doesnt_panic() {
+    let mut conn = SqliteConnection::connect(":memory:").await.unwrap();
+
+    let books = sqlx::query("SELECT 'hello' AS title")
+        .fetch_all(&mut conn)
+        .await
+        .unwrap();
+
+    for book in &books {
+        // force the row to be inflated
+        let _title: String = book.get("title");
+    }
+
+    // hold `books` past the lifetime of `conn`
+    drop(conn);
+    sqlx_rt::sleep(std::time::Duration::from_secs(1)).await;
+    drop(books);
+}
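On the `close()` changes in `it_opens_in_memory` and `it_opens_temp_on_disk` above: awaiting `close()` gives a deterministic, graceful shutdown, whereas simply dropping the connection lets cleanup happen in the background. The pattern in isolation (a sketch, assuming an `anyhow` dependency):

```rust
use sqlx::{Connection, SqliteConnection};

// Sketch only: close() consumes the connection and awaits a graceful
// shutdown, rather than leaving cleanup to run after drop.
async fn open_then_close() -> anyhow::Result<()> {
    let conn = SqliteConnection::connect(":memory:").await?;
    conn.close().await?;
    Ok(())
}
```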
diff --git a/tests/x.py b/tests/x.py
index 2133beefe4..3fd77e8892 100755
--- a/tests/x.py
+++ b/tests/x.py
@@ -88,116 +88,101 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None):
 # check
 #
 
-run("cargo c", comment="check with a default set of features", tag="check")
-
-run(
-    "cargo c --no-default-features --features runtime-async-std-native-tls,all-databases,all-types,offline,macros",
-    comment="check with async-std",
-    tag="check_async_std"
-)
-
-run(
-    "cargo c --no-default-features --features runtime-tokio-native-tls,all-databases,all-types,offline,macros",
-    comment="check with tokio",
-    tag="check_tokio"
-)
-
-run(
-    "cargo c --no-default-features --features runtime-actix-native-tls,all-databases,all-types,offline,macros",
-    comment="check with actix",
-    tag="check_actix"
-)
+for runtime in ["async-std", "tokio", "actix"]:
+    for tls in ["native-tls", "rustls"]:
+        run(
+            f"cargo c --no-default-features --features all-databases,all-types,offline,macros,runtime-{runtime}-{tls}",
+            comment=f"check with {runtime}-{tls}",
+            tag=f"check_{runtime}_{tls}"
+        )
 
 #
 # unit test
 #
 
-run(
-    "cargo test --manifest-path sqlx-core/Cargo.toml --features all-databases,all-types",
-    comment="unit test core",
-    tag="unit"
-)
-
-run(
-    "cargo test --no-default-features --manifest-path sqlx-core/Cargo.toml --features all-databases,all-types,runtime-tokio-native-tls",
-    comment="unit test core",
-    tag="unit_tokio"
-)
+for runtime in ["async-std", "tokio", "actix"]:
+    for tls in ["native-tls", "rustls"]:
+        run(
+            f"cargo test --no-default-features --manifest-path sqlx-core/Cargo.toml --features all-databases,all-types,runtime-{runtime}-{tls}",
+            comment="unit test core",
+            tag=f"unit_{runtime}_{tls}"
+        )
 
 #
 # integration tests
 #
 
 for runtime in ["async-std", "tokio", "actix"]:
+    for tls in ["native-tls", "rustls"]:
 
-    #
-    # sqlite
-    #
-
-    run(
-        f"cargo test --no-default-features --features macros,offline,any,all-types,sqlite,runtime-{runtime}-native-tls",
-        comment=f"test sqlite",
-        service="sqlite",
-        tag=f"sqlite" if runtime == "async-std" else f"sqlite_{runtime}",
-    )
-
-    #
-    # postgres
-    #
-
-    for version in ["12", "10", "9_6", "9_5"]:
-        run(
-            f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-native-tls",
-            comment=f"test postgres {version}",
-            service=f"postgres_{version}",
-            tag=f"postgres_{version}" if runtime == "async-std" else f"postgres_{version}_{runtime}",
-        )
-
-    # +ssl
-    for version in ["12", "10", "9_6", "9_5"]:
-        run(
-            f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-native-tls",
-            comment=f"test postgres {version} ssl",
-            database_url_args="sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt",
-            service=f"postgres_{version}",
-            tag=f"postgres_{version}_ssl" if runtime == "async-std" else f"postgres_{version}_ssl_{runtime}",
-        )
-
-    #
-    # mysql
-    #
-
-    for version in ["8", "5_7", "5_6"]:
-        run(
-            f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-native-tls",
-            comment=f"test mysql {version}",
-            service=f"mysql_{version}",
-            tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}",
-        )
-
-    #
-    # mariadb
-    #
+        #
+        # sqlite
+        #
 
-    for version in ["10_5", "10_4", "10_3", "10_2", "10_1"]:
         run(
-            f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-native-tls",
-            comment=f"test mariadb {version}",
-            service=f"mariadb_{version}",
-            tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}",
+            f"cargo test --no-default-features --features macros,offline,any,all-types,sqlite,runtime-{runtime}-{tls}",
+            comment=f"test sqlite",
+            service="sqlite",
+            tag=f"sqlite" if runtime == "async-std" else f"sqlite_{runtime}",
         )
 
-    #
-    # mssql
-    #
-
-    for version in ["2019"]:
-        run(
-            f"cargo test --no-default-features --features macros,offline,any,all-types,mssql,runtime-{runtime}-native-tls",
-            comment=f"test mssql {version}",
-            service=f"mssql_{version}",
-            tag=f"mssql_{version}" if runtime == "async-std" else f"mssql_{version}_{runtime}",
-        )
+        #
+        # postgres
+        #
+
+        for version in ["13", "12", "11", "10", "9_6"]:
+            run(
+                f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}",
+                comment=f"test postgres {version}",
+                service=f"postgres_{version}",
+                tag=f"postgres_{version}" if runtime == "async-std" else f"postgres_{version}_{runtime}",
+            )
+
+        ## +ssl
+        for version in ["13", "12", "11", "10", "9_6"]:
+            run(
+                f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}",
+                comment=f"test postgres {version} ssl",
+                database_url_args="sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt",
+                service=f"postgres_{version}",
+                tag=f"postgres_{version}_ssl" if runtime == "async-std" else f"postgres_{version}_ssl_{runtime}",
+            )
+
+        #
+        # mysql
+        #
+
+        for version in ["8", "5_7", "5_6"]:
+            run(
+                f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-{tls}",
+                comment=f"test mysql {version}",
+                service=f"mysql_{version}",
+                tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}",
+            )
+
+        #
+        # mariadb
+        #
+
+        for version in ["10_6", "10_5", "10_4", "10_3", "10_2"]:
+            run(
+                f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-{tls}",
+                comment=f"test mariadb {version}",
+                service=f"mariadb_{version}",
+                tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}",
+            )
+
+        #
+        # mssql
+        #
+
+        for version in ["2019", "2017"]:
+            run(
+                f"cargo test --no-default-features --features macros,offline,any,all-types,mssql,runtime-{runtime}-{tls}",
+                comment=f"test mssql {version}",
+                service=f"mssql_{version}",
+                tag=f"mssql_{version}" if runtime == "async-std" else f"mssql_{version}_{runtime}",
+            )
 
 # TODO: Use [grcov] if available
 # ~/.cargo/bin/grcov tests/.cache/target/debug -s sqlx-core/ -t html --llvm --branch -o ./target/debug/coverage