diff --git a/.asf.yaml b/.asf.yaml index 364b9b254..ce27a54e3 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -22,7 +22,7 @@ github: description: "Apache Iceberg" - homepage: https://iceberg.apache.org/ + homepage: https://rust.iceberg.apache.org/ labels: - iceberg - apache @@ -42,17 +42,20 @@ github: required_approving_review_count: 1 required_linear_history: true + features: wiki: false issues: true - projects: false + projects: true collaborators: - Xuanwo - liurenjie1024 - JanKaul + ghp_branch: gh-pages + ghp_path: / notifications: - commits: commits@iceberg.apache.org - issues: issues@iceberg.apache.org - pullrequests: issues@iceberg.apache.org - jira_options: link label link label + commits: commits@iceberg.apache.org + issues: issues@iceberg.apache.org + pullrequests: issues@iceberg.apache.org + jira_options: link label link label diff --git a/.cargo/audit.toml b/.cargo/audit.toml new file mode 100644 index 000000000..5db5a9d81 --- /dev/null +++ b/.cargo/audit.toml @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[advisories] +ignore = [ + # rsa + # Marvin Attack: potential key recovery through timing sidechannels + # Issues: https://github.com/apache/iceberg-rust/issues/221 + "RUSTSEC-2023-0071", +] diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..908bda4b5 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +website export-ignore diff --git a/.github/actions/setup-builder/action.yml b/.github/actions/setup-builder/action.yml new file mode 100644 index 000000000..43de1cbaa --- /dev/null +++ b/.github/actions/setup-builder/action.yml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is heavily inspired by +# [datafusion](https://github.com/apache/datafusion/blob/main/.github/actions/setup-builder/action.yaml). +name: Prepare Rust Builder +description: 'Prepare Rust Build Environment' +inputs: + rust-version: + description: 'version of rust to install (e.g. 
stable)' + required: true + default: 'stable' +runs: + using: "composite" + steps: + - name: Setup Rust toolchain + shell: bash + run: | + echo "Installing ${{ inputs.rust-version }}" + rustup toolchain install ${{ inputs.rust-version }} + rustup default ${{ inputs.rust-version }} + rustup component add rustfmt clippy + - name: Fixup git permissions + # https://github.com/actions/checkout/issues/766 + shell: bash + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" \ No newline at end of file diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 000000000..0d65b1aa8 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Security audit + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + push: + paths: + - "**/Cargo.toml" + - "**/Cargo.lock" + + pull_request: + paths: + - "**/Cargo.toml" + - "**/Cargo.lock" + +jobs: + security_audit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-audit + run: cargo install cargo-audit + - name: Run audit check + run: cargo audit diff --git a/.github/workflows/bindings_python_ci.yml b/.github/workflows/bindings_python_ci.yml new file mode 100644 index 000000000..d4b1aa922 --- /dev/null +++ b/.github/workflows/bindings_python_ci.yml @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: Bindings Python CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} + cancel-in-progress: true + +jobs: + check-rust: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Check format + run: cargo fmt --all -- --check + - name: Check clippy + run: cargo clippy --all-targets --all-features -- -D warnings + + check-python: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install tools + run: | + pip install ruff + - name: Check format + working-directory: "bindings/python" + run: | + ruff format . --diff + - name: Check style + working-directory: "bindings/python" + run: | + ruff check . + + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.8 + - uses: PyO3/maturin-action@v1 + with: + working-directory: "bindings/python" + command: build + args: --out dist --sdist + - name: Run tests + working-directory: "bindings/python" + shell: bash + run: | + set -e + pip install hatch==1.12.0 + hatch run dev:pip install dist/pyiceberg_core-*.whl --force-reinstall + hatch run dev:test diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1523971a2..38f450bf7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,14 +29,28 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} cancel-in-progress: true +env: + rust_msrv: "1.77.1" + jobs: check: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest steps: - uses: actions/checkout@v4 - name: Check License Header - uses: apache/skywalking-eyes/header@v0.5.0 + uses: apache/skywalking-eyes/header@v0.6.0 + + - name: Install cargo-sort + run: make install-cargo-sort + + - name: Install taplo-cli + run: make install-taplo-cli - name: Cargo format run: make check-fmt @@ -50,8 +64,29 @@ jobs: - name: Cargo sort run: make cargo-sort + - name: Cargo Machete + run: make cargo-machete build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} + + - name: Build + run: make build + + build_with_no_default_features: runs-on: ${{ matrix.os }} strategy: matrix: @@ -62,15 +97,23 @@ jobs: steps: - uses: actions/checkout@v4 - name: Build - run: cargo build + run: cargo build -p iceberg --no-default-features unit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} + - name: Test run: cargo test --no-fail-fast --all-targets --all-features --workspace - + + - name: Async-std Test + run: cargo test --no-fail-fast --all-targets --no-default-features --features "async-std" --features "storage-all" --workspace + - name: Doc Test run: cargo test --no-fail-fast --doc --all-features --workspace diff --git a/.github/workflows/ci_typos.yml b/.github/workflows/ci_typos.yml index 51a6a7b91..da72929dd 100644 --- a/.github/workflows/ci_typos.yml +++ b/.github/workflows/ci_typos.yml @@ -41,7 +41,5 @@ jobs: FORCE_COLOR: 1 steps: - uses: actions/checkout@v4 - - run: curl -LsSf 
https://github.com/crate-ci/typos/releases/download/v1.14.8/typos-v1.14.8-x86_64-unknown-linux-musl.tar.gz | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin - - - name: do typos check with typos-cli - run: typos + - name: Check typos + uses: crate-ci/typos@v1.24.5 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 000000000..486d66246 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Publish + +on: + push: + tags: + - '*' + workflow_dispatch: + +env: + rust_msrv: "1.77.1" + +jobs: + publish: + runs-on: ubuntu-latest + strategy: + # Publish package one by one instead of flooding the registry + max-parallel: 1 + matrix: + # Order here is sensitive, as it will be used to determine the order of publishing + package: + - "crates/iceberg" + - "crates/catalog/glue" + - "crates/catalog/hms" + - "crates/catalog/memory" + - "crates/catalog/rest" + # sql is not ready for release yet. + # - "crates/catalog/sql" + - "crates/integrations/datafusion" + steps: + - uses: actions/checkout@v4 + + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} + + - name: Publish ${{ matrix.package }} + working-directory: ${{ matrix.package }} + # Only publish if it's a tag and the tag is not a pre-release + if: ${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }} + run: cargo publish --all-features + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/website.yml b/.github/workflows/website.yml new file mode 100644 index 000000000..bbe3e53c4 --- /dev/null +++ b/.github/workflows/website.yml @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: Website + +on: + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + + - name: Setup mdBook + uses: peaceiris/actions-mdbook@v2 + with: + mdbook-version: '0.4.36' + + - name: Build + working-directory: website + run: mdbook build + + - name: Copy asf file + run: cp .asf.yaml ./website/book/.asf.yaml + + - name: Build API docs + run: | + cargo doc --no-deps --workspace --all-features + cp -r target/doc ./website/book/api + + - name: Deploy to gh-pages + uses: peaceiris/actions-gh-pages@v4.0.0 + if: github.event_name == 'push' && github.ref_name == 'main' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: website/book + publish_branch: gh-pages diff --git a/.gitignore b/.gitignore index 72c34840c..a3f05e817 100644 --- a/.gitignore +++ b/.gitignore @@ -15,8 +15,14 @@ # specific language governing permissions and limitations # under the License. -/target -/Cargo.lock +target +Cargo.lock .idea .vscode -**/.DS_Store \ No newline at end of file +**/.DS_Store +dist/* +**/venv +*.so +*.pyc +*.whl +*.tar.gz diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 000000000..ea5e0779f --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +!.gitignore +!vcs.xml diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..6fd581ec8 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,30 @@ + + + + + + + + + + diff --git a/.licenserc.yaml b/.licenserc.yaml index cd362bc94..38aa58402 100644 --- a/.licenserc.yaml +++ b/.licenserc.yaml @@ -23,6 +23,12 @@ header: paths-ignore: - 'LICENSE' - 'NOTICE' + - '.gitattributes' - '**/*.json' - + # Generated content by mdbook + - 'website/book' + # Generated content by scripts + - '**/DEPENDENCIES.*.tsv' + # Release distributions + - 'dist/*' comment: on-failure diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d1a50b33..fc576c52f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,4 +24,278 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/) and this project adheres to [Semantic Versioning](https://semver.org/). 
-## Unreleased +## [v0.3.0] - 2024-08-14 + +* Smooth out release steps by @Fokko in https://github.com/apache/iceberg-rust/pull/197 +* refactor: remove support of manifest list format as a list of file path by @Dysprosium0626 in https://github.com/apache/iceberg-rust/pull/201 +* refactor: remove unwraps by @odysa in https://github.com/apache/iceberg-rust/pull/196 +* Fix: add required rust version in cargo.toml by @dp-0 in https://github.com/apache/iceberg-rust/pull/193 +* Fix the REST spec version by @Fokko in https://github.com/apache/iceberg-rust/pull/198 +* feat: Add Sync + Send to Catalog trait by @ZhengLin-Li in https://github.com/apache/iceberg-rust/pull/202 +* feat: Make thrift transport configurable by @DeaconDesperado in https://github.com/apache/iceberg-rust/pull/194 +* Add UnboundSortOrder by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/115 +* ci: Add workflow for publish by @Xuanwo in https://github.com/apache/iceberg-rust/pull/218 +* Add workflow for cargo audit by @sdd in https://github.com/apache/iceberg-rust/pull/217 +* docs: Add basic README for all crates by @Xuanwo in https://github.com/apache/iceberg-rust/pull/215 +* Follow naming convention from Iceberg's Java and Python implementations by @s-akhtar-baig in https://github.com/apache/iceberg-rust/pull/204 +* doc: Add download page by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/219 +* chore(deps): Update derive_builder requirement from 0.13.0 to 0.20.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/203 +* test: add FileIO s3 test by @odysa in https://github.com/apache/iceberg-rust/pull/220 +* ci: Ignore RUSTSEC-2023-0071 for no actions to take by @Xuanwo in https://github.com/apache/iceberg-rust/pull/222 +* feat: Add expression builder and display. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/169 +* chord: Add IssueNavigationLink for RustRover by @stream2000 in https://github.com/apache/iceberg-rust/pull/230 +* minor: Fix `double` API doc by @viirya in https://github.com/apache/iceberg-rust/pull/226 +* feat: add `UnboundPredicate::negate()` by @sdd in https://github.com/apache/iceberg-rust/pull/228 +* fix: Remove deprecated methods to pass ci by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/234 +* Implement basic Parquet data file reading capability by @sdd in https://github.com/apache/iceberg-rust/pull/207 +* chore: doc-test as a target by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/235 +* feat: add parquet writer by @ZENOTME in https://github.com/apache/iceberg-rust/pull/176 +* Add hive metastore catalog support (part 1/2) by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/237 +* chore: Enable projects. 
by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/247 +* refactor: Make plan_files as asynchronous stream by @viirya in https://github.com/apache/iceberg-rust/pull/243 +* feat: Implement binding expression by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/231 +* Implement Display instead of ToString by @lewiszlw in https://github.com/apache/iceberg-rust/pull/256 +* add rewrite_not by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/263 +* feat: init TableMetadataBuilder by @ZENOTME in https://github.com/apache/iceberg-rust/pull/262 +* Rename stat_table to table_exists in Catalog trait by @lewiszlw in https://github.com/apache/iceberg-rust/pull/257 +* feat (static table): implement a read-only table struct loaded from metadata by @a-agmon in https://github.com/apache/iceberg-rust/pull/259 +* feat: implement OAuth for catalog rest client by @TennyZhuang in https://github.com/apache/iceberg-rust/pull/254 +* docs: annotate precision and length to primitive types by @waynexia in https://github.com/apache/iceberg-rust/pull/270 +* build: Restore CI by making parquet and arrow version consistent by @viirya in https://github.com/apache/iceberg-rust/pull/280 +* Metadata Serde + default partition_specs and sort_orders by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/272 +* feat: make optional oauth param configurable by @himadripal in https://github.com/apache/iceberg-rust/pull/278 +* fix: enable public access to ManifestEntry properties by @a-agmon in https://github.com/apache/iceberg-rust/pull/284 +* feat: Implement the conversion from Arrow Schema to Iceberg Schema by @viirya in https://github.com/apache/iceberg-rust/pull/258 +* Rename function name to `add_manifests` by @viirya in https://github.com/apache/iceberg-rust/pull/293 +* Modify `Bind` calls so that they don't consume `self` and instead return a new struct, leaving the original unmoved by @sdd in https://github.com/apache/iceberg-rust/pull/290 +* Add hive metastore catalog support (part 2/2) by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/285 +* feat: implement prune column for schema by @Dysprosium0626 in https://github.com/apache/iceberg-rust/pull/261 +* chore(deps): Update reqwest requirement from ^0.11 to ^0.12 by @dependabot in https://github.com/apache/iceberg-rust/pull/296 +* Glue Catalog: Basic Setup + Test Infra (1/3) by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/294 +* feat: rest client respect prefix prop by @TennyZhuang in https://github.com/apache/iceberg-rust/pull/297 +* fix: HMS Catalog missing properties `fn create_namespace` by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/303 +* fix: renaming FileScanTask.data_file to data_manifest_entry by @a-agmon in https://github.com/apache/iceberg-rust/pull/300 +* feat: Make OAuth token server configurable by @whynick1 in https://github.com/apache/iceberg-rust/pull/305 +* feat: Glue Catalog - namespace operations (2/3) by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/304 +* feat: add transform_literal by @ZENOTME in https://github.com/apache/iceberg-rust/pull/287 +* feat: Complete predicate builders for all operators. 
by @QuakeWang in https://github.com/apache/iceberg-rust/pull/276 +* feat: Support customized header in Rest catalog client by @whynick1 in https://github.com/apache/iceberg-rust/pull/306 +* fix: chrono dep by @odysa in https://github.com/apache/iceberg-rust/pull/274 +* feat: Read Parquet data file with projection by @viirya in https://github.com/apache/iceberg-rust/pull/245 +* Fix day timestamp micro by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/312 +* feat: support uri redirect in rest client by @TennyZhuang in https://github.com/apache/iceberg-rust/pull/310 +* refine: separate parquet reader and arrow convert by @ZENOTME in https://github.com/apache/iceberg-rust/pull/313 +* chore: upgrade to rust-version 1.77.1 by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/316 +* Support identifier warehouses by @Fokko in https://github.com/apache/iceberg-rust/pull/308 +* feat: Project transform by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/309 +* Add Struct Accessors to BoundReferences by @sdd in https://github.com/apache/iceberg-rust/pull/317 +* Use `str` args rather than `String` in transform to avoid needing to clone strings by @sdd in https://github.com/apache/iceberg-rust/pull/325 +* chore(deps): Update pilota requirement from 0.10.0 to 0.11.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/327 +* chore(deps): Bump peaceiris/actions-mdbook from 1 to 2 by @dependabot in https://github.com/apache/iceberg-rust/pull/332 +* chore(deps): Bump peaceiris/actions-gh-pages from 3.9.3 to 4.0.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/333 +* chore(deps): Bump apache/skywalking-eyes from 0.5.0 to 0.6.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/328 +* Add `BoundPredicateVisitor` (alternate version) by @sdd in https://github.com/apache/iceberg-rust/pull/334 +* add `InclusiveProjection` Visitor by @sdd in https://github.com/apache/iceberg-rust/pull/335 +* feat: Implement the conversion from Iceberg Schema to Arrow Schema by @ZENOTME in https://github.com/apache/iceberg-rust/pull/277 +* Simplify expression when doing `{and,or}` operations by @Fokko in https://github.com/apache/iceberg-rust/pull/339 +* feat: Glue Catalog - table operations (3/3) by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/314 +* chore: update roadmap by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/336 +* Add `ManifestEvaluator`, used to filter manifests in table scans by @sdd in https://github.com/apache/iceberg-rust/pull/322 +* feat: init iceberg writer by @ZENOTME in https://github.com/apache/iceberg-rust/pull/275 +* Implement manifest filtering in `TableScan` by @sdd in https://github.com/apache/iceberg-rust/pull/323 +* Refactor: Extract `partition_filters` from `ManifestEvaluator` by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/360 +* Basic Integration with Datafusion by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/324 +* refactor: cache partition_schema in `fn plan_files()` by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/362 +* fix (manifest-list): added serde aliases to support both forms conventions by @a-agmon in https://github.com/apache/iceberg-rust/pull/365 +* feat: Extract FileRead and FileWrite trait by @Xuanwo in https://github.com/apache/iceberg-rust/pull/364 +* feat: Convert predicate to arrow filter and push down to parquet reader by @viirya in https://github.com/apache/iceberg-rust/pull/295 +* chore(deps): Update datafusion 
requirement from 37.0.0 to 38.0.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/369 +* chore(deps): Update itertools requirement from 0.12 to 0.13 by @dependabot in https://github.com/apache/iceberg-rust/pull/376 +* Add `InclusiveMetricsEvaluator` by @sdd in https://github.com/apache/iceberg-rust/pull/347 +* Rename V2 spec names by @gupteaj in https://github.com/apache/iceberg-rust/pull/380 +* feat: make file scan task serializable by @ZENOTME in https://github.com/apache/iceberg-rust/pull/377 +* Feature: Schema into_builder method by @c-thiel in https://github.com/apache/iceberg-rust/pull/381 +* replaced `i32` in `TableUpdate::SetDefaultSortOrder` to `i64` by @rwwwx in https://github.com/apache/iceberg-rust/pull/387 +* fix: make PrimitiveLiteral and Literal not be Ord by @ZENOTME in https://github.com/apache/iceberg-rust/pull/386 +* docs(writer/docker): fix small typos and wording by @jdockerty in https://github.com/apache/iceberg-rust/pull/389 +* feat: `StructAccessor.get` returns `Result>` instead of `Result` by @sdd in https://github.com/apache/iceberg-rust/pull/390 +* feat: add `ExpressionEvaluator` by @marvinlanhenke in https://github.com/apache/iceberg-rust/pull/363 +* Derive Clone for TableUpdate by @c-thiel in https://github.com/apache/iceberg-rust/pull/402 +* Add accessor for Schema identifier_field_ids by @c-thiel in https://github.com/apache/iceberg-rust/pull/388 +* deps: Bump arrow related crates to 52 by @Dysprosium0626 in https://github.com/apache/iceberg-rust/pull/403 +* SnapshotRetention::Tag max_ref_age_ms should be optional by @c-thiel in https://github.com/apache/iceberg-rust/pull/391 +* feat: Add storage features for iceberg by @Xuanwo in https://github.com/apache/iceberg-rust/pull/400 +* Implement BoundPredicateVisitor trait for ManifestFilterVisitor by @s-akhtar-baig in https://github.com/apache/iceberg-rust/pull/367 +* Add missing arrow predicate pushdown implementations for `StartsWith`, `NotStartsWith`, `In`, and `NotIn` by @sdd in https://github.com/apache/iceberg-rust/pull/404 +* feat: make BoundPredicate,Datum serializable by @ZENOTME in https://github.com/apache/iceberg-rust/pull/406 +* refactor: Upgrade hive_metastore to 0.1 by @Xuanwo in https://github.com/apache/iceberg-rust/pull/409 +* fix: Remove duplicate filter by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/414 +* Enhancement: refine the reader interface by @ZENOTME in https://github.com/apache/iceberg-rust/pull/401 +* refactor(catalog/rest): Split http client logic to separate mod by @Xuanwo in https://github.com/apache/iceberg-rust/pull/423 +* Remove #[allow(dead_code)] from the codebase by @vivek378521 in https://github.com/apache/iceberg-rust/pull/421 +* ci: use official typos github action by @shoothzj in https://github.com/apache/iceberg-rust/pull/426 +* feat: support lower_bound&&upper_bound for parquet writer by @ZENOTME in https://github.com/apache/iceberg-rust/pull/383 +* refactor: Implement ArrowAsyncFileWriter directly to remove tokio by @Xuanwo in https://github.com/apache/iceberg-rust/pull/427 +* chore: Don't enable reqwest default features by @Xuanwo in https://github.com/apache/iceberg-rust/pull/432 +* refactor(catalogs/rest): Split user config and runtime config by @Xuanwo in https://github.com/apache/iceberg-rust/pull/431 +* feat: runtime module by @odysa in https://github.com/apache/iceberg-rust/pull/233 +* fix: Fix namespace identifier in url by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/435 +* refactor(io): Split io into smaller 
mods by @Xuanwo in https://github.com/apache/iceberg-rust/pull/438 +* chore: Use once_cell to replace lazy_static by @Xuanwo in https://github.com/apache/iceberg-rust/pull/443 +* fix: Fix build while no-default-features enabled by @Xuanwo in https://github.com/apache/iceberg-rust/pull/442 +* chore(deps): Bump crate-ci/typos from 1.22.9 to 1.23.1 by @dependabot in https://github.com/apache/iceberg-rust/pull/447 +* docs: Refactor the README to be more user-oriented by @Xuanwo in https://github.com/apache/iceberg-rust/pull/444 +* feat: Add cargo machete by @vaibhawvipul in https://github.com/apache/iceberg-rust/pull/448 +* chore: Use nightly toolchain for check by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/445 +* reuse docker container to save compute resources by @thexiay in https://github.com/apache/iceberg-rust/pull/428 +* feat: Add macos runner for ci by @QuakeWang in https://github.com/apache/iceberg-rust/pull/441 +* chore: remove compose obsolete version (#452) by @yinheli in https://github.com/apache/iceberg-rust/pull/454 +* Refactor file_io_s3_test.rs by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/455 +* chore(deps): Bump crate-ci/typos from 1.23.1 to 1.23.2 by @dependabot in https://github.com/apache/iceberg-rust/pull/457 +* refine: move binary serialize in literal to datum by @ZENOTME in https://github.com/apache/iceberg-rust/pull/456 +* fix: Hms test on macos should use correct arch by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/461 +* Fix ManifestFile length calculation by @nooberfsh in https://github.com/apache/iceberg-rust/pull/466 +* chore(deps): Update typed-builder requirement from ^0.18 to ^0.19 by @dependabot in https://github.com/apache/iceberg-rust/pull/473 +* fix: use avro fixed to represent decimal by @xxchan in https://github.com/apache/iceberg-rust/pull/472 +* feat(catalog!): Deprecate rest.authorization-url in favor of oauth2-server-uri by @ndrluis in https://github.com/apache/iceberg-rust/pull/480 +* Alter `Transform::Day` to map partition types to `Date` rather than `Int` for consistency with reference implementation by @sdd in https://github.com/apache/iceberg-rust/pull/479 +* feat(iceberg): Add memory file IO support by @Xuanwo in https://github.com/apache/iceberg-rust/pull/481 +* Add memory catalog implementation by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/475 +* chore: Enable new rust code format settings by @Xuanwo in https://github.com/apache/iceberg-rust/pull/483 +* docs: Generate rust API docs by @Xuanwo in https://github.com/apache/iceberg-rust/pull/486 +* chore: Fix format of recent PRs by @Xuanwo in https://github.com/apache/iceberg-rust/pull/487 +* Rename folder to memory by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/490 +* chore(deps): Bump crate-ci/typos from 1.23.2 to 1.23.5 by @dependabot in https://github.com/apache/iceberg-rust/pull/493 +* View Spec implementation by @c-thiel in https://github.com/apache/iceberg-rust/pull/331 +* fix: Return error on reader task by @ndrluis in https://github.com/apache/iceberg-rust/pull/498 +* chore: Bump OpenDAL to 0.48 by @Xuanwo in https://github.com/apache/iceberg-rust/pull/500 +* feat: add check compatible func for primitive type by @ZENOTME in https://github.com/apache/iceberg-rust/pull/492 +* refactor(iceberg): Remove an extra config parse logic by @Xuanwo in https://github.com/apache/iceberg-rust/pull/499 +* feat: permit Datum Date<->Int type conversion by @sdd in https://github.com/apache/iceberg-rust/pull/496 +* Add 
additional S3 FileIO Attributes by @c-thiel in https://github.com/apache/iceberg-rust/pull/505 +* docs: Add links to dev docs by @Xuanwo in https://github.com/apache/iceberg-rust/pull/508 +* chore: Remove typo in README by @Xuanwo in https://github.com/apache/iceberg-rust/pull/509 +* feat: podman support by @alexyin1 in https://github.com/apache/iceberg-rust/pull/489 +* feat(table): Add debug and clone trait to static table struct by @ndrluis in https://github.com/apache/iceberg-rust/pull/510 +* Use namespace location or warehouse location if table location is missing by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/511 +* chore(deps): Bump crate-ci/typos from 1.23.5 to 1.23.6 by @dependabot in https://github.com/apache/iceberg-rust/pull/521 +* Concurrent table scans by @sdd in https://github.com/apache/iceberg-rust/pull/373 +* refactor: replace num_cpus with thread::available_parallelism by @SteveLauC in https://github.com/apache/iceberg-rust/pull/526 +* Fix: MappedLocalTime should not be exposed by @c-thiel in https://github.com/apache/iceberg-rust/pull/529 +* feat: Establish subproject pyiceberg_core by @Xuanwo in https://github.com/apache/iceberg-rust/pull/518 +* fix: complete miss attribute for map && list in avro schema by @ZENOTME in https://github.com/apache/iceberg-rust/pull/411 +* arrow/schema.rs: refactor tests by @AndreMouche in https://github.com/apache/iceberg-rust/pull/531 +* feat: initialise SQL Catalog by @callum-ryan in https://github.com/apache/iceberg-rust/pull/524 +* chore(deps): Bump actions/setup-python from 4 to 5 by @dependabot in https://github.com/apache/iceberg-rust/pull/536 +* feat(storage): support aws session token by @twuebi in https://github.com/apache/iceberg-rust/pull/530 +* Simplify PrimitiveLiteral by @ZENOTME in https://github.com/apache/iceberg-rust/pull/502 +* chore: bump opendal to 0.49 by @jdockerty in https://github.com/apache/iceberg-rust/pull/540 +* feat: support timestamp columns in row filters by @sdd in https://github.com/apache/iceberg-rust/pull/533 +* fix: don't silently drop errors encountered in table scan file planning by @sdd in https://github.com/apache/iceberg-rust/pull/535 +* chore(deps): Update sqlx requirement from 0.7.4 to 0.8.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/537 +* Fix main branch building break by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/541 +* feat: support for gcs storage by @jdockerty in https://github.com/apache/iceberg-rust/pull/520 +* feat: Allow FileIO to reuse http client by @Xuanwo in https://github.com/apache/iceberg-rust/pull/544 +* docs: Add an example to scan an iceberg table by @Xuanwo in https://github.com/apache/iceberg-rust/pull/545 +* Concurrent data file fetching and parallel RecordBatch processing by @sdd in https://github.com/apache/iceberg-rust/pull/515 +* doc: Add statement for contributors to avoid force push as much as possible by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/546 +* feat: Partition Binding and safe PartitionSpecBuilder by @c-thiel in https://github.com/apache/iceberg-rust/pull/491 + +## v0.2.0 - 2024-02-20 + +* chore: Setup project layout by @Xuanwo in https://github.com/apache/iceberg-rust/pull/1 +* ci: Fix version for apache/skywalking-eyes/header by @Xuanwo in https://github.com/apache/iceberg-rust/pull/4 +* feat: Implement serialize/deserialize for datatypes by @JanKaul in https://github.com/apache/iceberg-rust/pull/6 +* docs: Add CONTRIBUTING and finish project setup by @Xuanwo in 
https://github.com/apache/iceberg-rust/pull/7 +* feat: Add lookup tables to StructType by @JanKaul in https://github.com/apache/iceberg-rust/pull/12 +* feat: Implement error handling by @Xuanwo in https://github.com/apache/iceberg-rust/pull/13 +* chore: Use HashMap instead of BTreeMap for storing fields by id in StructType by @amogh-jahagirdar in https://github.com/apache/iceberg-rust/pull/14 +* chore: Change iceberg into workspace by @Xuanwo in https://github.com/apache/iceberg-rust/pull/15 +* feat: Use macro to define from error. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/17 +* feat: Introduce schema definition. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/19 +* refactor: Align data type with other implementation. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/21 +* chore: Ignore .idea by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/27 +* feat: Implement Iceberg values by @JanKaul in https://github.com/apache/iceberg-rust/pull/20 +* feat: Define schema post order visitor. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/25 +* feat: Add transform by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/26 +* fix: Fix build break in main branch by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/30 +* fix: Update github configuration to avoid conflicting merge by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/31 +* chore(deps): Bump apache/skywalking-eyes from 0.4.0 to 0.5.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/35 +* feat: Table metadata by @JanKaul in https://github.com/apache/iceberg-rust/pull/29 +* feat: Add utility methods to help conversion between literals. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/38 +* [comment] should be IEEE 754 rather than 753 by @zhjwpku in https://github.com/apache/iceberg-rust/pull/39 +* fix: Add doc test action by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/44 +* chore: Ping toolchain version by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/48 +* feat: Introduce conversion between iceberg schema and avro schema by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/40 +* feat: Allow Schema Serialization/deserialization by @y0psolo in https://github.com/apache/iceberg-rust/pull/46 +* chore: Add cargo sort check by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/51 +* chore(deps): Bump actions/checkout from 3 to 4 by @dependabot in https://github.com/apache/iceberg-rust/pull/58 +* Metadata integration tests by @JanKaul in https://github.com/apache/iceberg-rust/pull/57 +* feat: Introduce FileIO by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/53 +* feat: Add Catalog API by @Xuanwo in https://github.com/apache/iceberg-rust/pull/54 +* feat: support transform function by @ZENOTME in https://github.com/apache/iceberg-rust/pull/42 +* chore(deps): Update ordered-float requirement from 3.7.0 to 4.0.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/64 +* feat: Add public methods for catalog related structs by @Xuanwo in https://github.com/apache/iceberg-rust/pull/63 +* minor: Upgrade to latest toolchain by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/68 +* chore(deps): Update opendal requirement from 0.39 to 0.40 by @dependabot in https://github.com/apache/iceberg-rust/pull/65 +* refactor: Make directory for catalog by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/69 +* feat: 
support read Manifest List by @ZENOTME in https://github.com/apache/iceberg-rust/pull/56 +* chore(deps): Update apache-avro requirement from 0.15 to 0.16 by @dependabot in https://github.com/apache/iceberg-rust/pull/71 +* fix: avro bytes test for Literal by @JanKaul in https://github.com/apache/iceberg-rust/pull/80 +* chore(deps): Update opendal requirement from 0.40 to 0.41 by @dependabot in https://github.com/apache/iceberg-rust/pull/84 +* feat: manifest list writer by @barronw in https://github.com/apache/iceberg-rust/pull/76 +* feat: First version of rest catalog. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/78 +* chore(deps): Update typed-builder requirement from ^0.17 to ^0.18 by @dependabot in https://github.com/apache/iceberg-rust/pull/87 +* feat: Implement load table api. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/89 +* chroes:Manage dependencies using workspace. by @my-vegetable-has-exploded in https://github.com/apache/iceberg-rust/pull/93 +* minor: Provide Debug impl for pub structs #73 by @DeaconDesperado in https://github.com/apache/iceberg-rust/pull/92 +* feat: support ser/deser of value by @ZENOTME in https://github.com/apache/iceberg-rust/pull/82 +* fix: Migrate from tempdir to tempfile crate by @cdaudt in https://github.com/apache/iceberg-rust/pull/91 +* chore(deps): Update opendal requirement from 0.41 to 0.42 by @dependabot in https://github.com/apache/iceberg-rust/pull/101 +* chore(deps): Update itertools requirement from 0.11 to 0.12 by @dependabot in https://github.com/apache/iceberg-rust/pull/102 +* Replace i64 with DateTime by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/94 +* feat: Implement create table and update table api for rest catalog. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/97 +* Fix compile failures by @fqaiser94 in https://github.com/apache/iceberg-rust/pull/105 +* feat: replace 'Builder' with 'TypedBuilder' for 'Snapshot' by @xiaoyang-sde in https://github.com/apache/iceberg-rust/pull/110 +* chore: Upgrade uuid manually and remove pinned version by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/108 +* chore: Add cargo build and build guide by @manuzhang in https://github.com/apache/iceberg-rust/pull/111 +* feat: Add hms catalog layout by @Xuanwo in https://github.com/apache/iceberg-rust/pull/112 +* feat: support UnboundPartitionSpec by @my-vegetable-has-exploded in https://github.com/apache/iceberg-rust/pull/106 +* test: Add integration tests for rest catalog. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/109 +* chore(deps): Update opendal requirement from 0.42 to 0.43 by @dependabot in https://github.com/apache/iceberg-rust/pull/116 +* feat: support read/write Manifest by @ZENOTME in https://github.com/apache/iceberg-rust/pull/79 +* test: Remove binary manifest list avro file by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/118 +* refactor: Conversion between literal and json should depends on type. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/120 +* fix: fix parse partitions in manifest_list by @ZENOTME in https://github.com/apache/iceberg-rust/pull/122 +* feat: Add website layout by @Xuanwo in https://github.com/apache/iceberg-rust/pull/130 +* feat: Expression system. 
by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/132 +* website: Fix typo in book.toml by @Xuanwo in https://github.com/apache/iceberg-rust/pull/136 +* Set `ghp_{pages,path}` properties by @Fokko in https://github.com/apache/iceberg-rust/pull/138 +* chore: Upgrade toolchain to 1.75.0 by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/140 +* feat: Add roadmap and features status in README.md by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/134 +* Remove `publish:` section from `.asf.yaml` by @Fokko in https://github.com/apache/iceberg-rust/pull/141 +* chore(deps): Bump peaceiris/actions-gh-pages from 3.9.2 to 3.9.3 by @dependabot in https://github.com/apache/iceberg-rust/pull/143 +* chore(deps): Update opendal requirement from 0.43 to 0.44 by @dependabot in https://github.com/apache/iceberg-rust/pull/142 +* docs: Change homepage to rust.i.a.o by @Xuanwo in https://github.com/apache/iceberg-rust/pull/146 +* feat: Introduce basic file scan planning. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/129 +* chore: Update contributing guide. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/163 +* chore: Update reader api status by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/162 +* #154 : Add homepage to Cargo.toml by @hiirrxnn in https://github.com/apache/iceberg-rust/pull/160 +* Add formatting for toml files by @Tyler-Sch in https://github.com/apache/iceberg-rust/pull/167 +* chore(deps): Update env_logger requirement from 0.10.0 to 0.11.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/170 +* feat: init file writer interface by @ZENOTME in https://github.com/apache/iceberg-rust/pull/168 +* fix: Manifest parsing should consider schema evolution. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/171 +* docs: Add release guide for iceberg-rust by @Xuanwo in https://github.com/apache/iceberg-rust/pull/147 +* fix: Ignore negative statistics value by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/173 +* feat: Add user guide for website. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/178 +* chore(deps): Update derive_builder requirement from 0.12.0 to 0.13.0 by @dependabot in https://github.com/apache/iceberg-rust/pull/175 +* refactor: Replace unwrap by @odysa in https://github.com/apache/iceberg-rust/pull/183 +* feat: add handwritten serialize by @odysa in https://github.com/apache/iceberg-rust/pull/185 +* Fix: avro schema names for manifest and manifest_list by @JanKaul in https://github.com/apache/iceberg-rust/pull/182 +* feat: Bump hive_metastore to use pure rust thrift impl `volo` by @Xuanwo in https://github.com/apache/iceberg-rust/pull/174 +* feat: Bump version 0.2.0 to prepare for release. by @liurenjie1024 in https://github.com/apache/iceberg-rust/pull/181 +* fix: default_partition_spec using the partition_spec_id set by @odysa in https://github.com/apache/iceberg-rust/pull/190 +* Docs: Add required Cargo version to install guide by @manuzhang in https://github.com/apache/iceberg-rust/pull/191 +* chore(deps): Update opendal requirement from 0.44 to 0.45 by @dependabot in https://github.com/apache/iceberg-rust/pull/195 + +[v0.3.0]: https://github.com/apache/iceberg-rust/compare/v0.2.0...v0.3.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 712e7e27d..f66d3248e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -73,6 +73,8 @@ All pull requests should be reviewed by at least one iceberg-rust committer. 
All pull requests are squash merged. We generally discourage large pull requests that are over 300-500 lines of diff. If you would like to propose a change that is larger we suggest coming onto [Iceberg's DEV mailing list](mailto:dev@iceberg.apache.org) or [Slack #rust Channel](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1zbov3k6e-KtJfoaxp97YfX6dPz1Bk7A) and discuss it with us. This way we can talk through the solution and discuss if a change that large is even needed! This will produce a quicker response to the change and likely produce code that aligns better with our process. +When a pull request is under review, please avoid using force push as it makes it difficult for reviewer to track changes. If you need to keep the branch up to date with the main branch, consider using `git merge` instead. + ### CI Currently, iceberg-rust uses GitHub Actions to run tests. The workflows are defined in `.github/workflows`. @@ -91,6 +93,8 @@ The fastest way is: ### Bring your own toolbox +#### Install rust + iceberg-rust is primarily a Rust project. To build iceberg-rust, you will need to set up Rust development first. We highly recommend using [rustup](https://rustup.rs/) for the setup process. For Linux or MacOS, use the following command: @@ -108,11 +112,22 @@ $ cargo version cargo 1.69.0 (6e9a83356 2023-04-12) ``` +#### Install Docker or Podman + +Currently, iceberg-rust uses Docker to set up environment for integration tests. Podman is also supported. + +You can learn how to install Docker from [here](https://docs.docker.com/get-docker/). + +For macos users, you can install [OrbStack](https://orbstack.dev/) as a docker alternative. + +For podman users, refer to [Using Podman instead of Docker](docs/contributing/podman.md) + ## Build * To compile the project: `make build` * To check code styles: `make check` -* To run tests: `make test` +* To run unit tests only: `make unit-test` +* To run all tests: `make test` ## Code of Conduct diff --git a/Cargo.toml b/Cargo.toml index a59a4bb4c..8d04f6799 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,45 +17,81 @@ [workspace] resolver = "2" -members = ["crates/catalog/*", "crates/iceberg", "crates/test_utils"] +members = [ + "crates/catalog/*", + "crates/examples", + "crates/iceberg", + "crates/integrations/*", + "crates/test_utils", +] +exclude = ["bindings/python"] + +[workspace.package] +version = "0.3.0" +edition = "2021" +homepage = "https://rust.iceberg.apache.org/" + +repository = "https://github.com/apache/iceberg-rust" +license = "Apache-2.0" +rust-version = "1.77.1" [workspace.dependencies] anyhow = "1.0.72" -apache-avro = "0.16" -arrow-arith = { version = ">=46" } -arrow-array = { version = ">=46" } -arrow-schema = { version = ">=46" } +apache-avro = "0.17" +array-init = "2" +arrow-arith = { version = "52" } +arrow-array = { version = "52" } +arrow-ord = { version = "52" } +arrow-schema = { version = "52" } +arrow-select = { version = "52" } +arrow-string = { version = "52" } +async-stream = "0.3.5" async-trait = "0.1" +async-std = "1.12" +aws-config = "1.1.8" +aws-sdk-glue = "1.21" bimap = "0.6" bitvec = "1.0.1" -chrono = "0.4" -derive_builder = "0.12.0" +bytes = "1.5" +chrono = "0.4.34" +ctor = "0.2.8" +derive_builder = "0.20" either = "1" -env_logger = "0.10.0" +env_logger = "0.11.0" +fnv = "1" futures = "0.3" -iceberg = { path = "./crates/iceberg" } -iceberg-catalog-rest = { path = "./crates/catalog/rest" } -itertools = "0.12" -lazy_static = "1" -log = "^0.4" -mockito = "^1" +iceberg = { version = "0.3.0", path = 
"./crates/iceberg" } +iceberg-catalog-rest = { version = "0.3.0", path = "./crates/catalog/rest" } +iceberg-catalog-hms = { version = "0.3.0", path = "./crates/catalog/hms" } +iceberg-catalog-memory = { version = "0.3.0", path = "./crates/catalog/memory" } +itertools = "0.13" +log = "0.4" +mockito = "1" murmur3 = "0.5.2" once_cell = "1" -opendal = "0.43" -ordered-float = "4.0.0" -pretty_assertions = "1.4.0" +opendal = "0.49" +ordered-float = "4" +parquet = "52" +paste = "1" +pilota = "0.11.2" +pretty_assertions = "1.4" port_scanner = "0.1.5" -reqwest = { version = "^0.11", features = ["json"] } -rust_decimal = "1.31.0" -serde = { version = "^1.0", features = ["rc"] } +rand = "0.8" +regex = "1.10.5" +reqwest = { version = "0.12", default-features = false, features = ["json"] } +rust_decimal = "1.31" +serde = { version = "1", features = ["rc"] } serde_bytes = "0.11.8" -serde_derive = "^1.0" -serde_json = "^1.0" +serde_derive = "1" +serde_json = "1" serde_repr = "0.1.16" -serde_with = "3.4.0" +serde_with = "3.4" tempfile = "3.8" -tokio = { version = "1", features = ["macros"] } -typed-builder = "^0.18" +tokio = { version = "1", default-features = false } +typed-builder = "0.20" url = "2" urlencoding = "2" -uuid = "1.6.1" +uuid = { version = "1.6.1", features = ["v7"] } +volo-thrift = "0.10" +hive_metastore = "0.1" +tera = "1" diff --git a/Makefile b/Makefile index c34f6c97d..4ecc9bd88 100644 --- a/Makefile +++ b/Makefile @@ -17,26 +17,46 @@ .EXPORT_ALL_VARIABLES: -RUST_LOG = debug - build: - cargo build + cargo build --all-targets --all-features --workspace check-fmt: - cargo fmt --all -- --check + cargo fmt --all -- --check check-clippy: - cargo clippy --all-targets --all-features --workspace -- -D warnings + cargo clippy --all-targets --all-features --workspace -- -D warnings + +install-cargo-sort: + cargo install cargo-sort@1.0.9 -cargo-sort: - cargo install cargo-sort +cargo-sort: install-cargo-sort cargo sort -c -w -check: check-fmt check-clippy cargo-sort +install-cargo-machete: + cargo install cargo-machete + +cargo-machete: install-cargo-machete + cargo machete + +install-taplo-cli: + cargo install taplo-cli@0.9.0 + +fix-toml: install-taplo-cli + taplo fmt -unit-test: +check-toml: install-taplo-cli + taplo check + +check: check-fmt check-clippy cargo-sort check-toml cargo-machete + +doc-test: + cargo test --no-fail-fast --doc --all-features --workspace + +unit-test: doc-test cargo test --no-fail-fast --lib --all-features --workspace -test: +test: doc-test cargo test --no-fail-fast --all-targets --all-features --workspace - cargo test --no-fail-fast --doc --all-features --workspace \ No newline at end of file + +clean: + cargo clean diff --git a/README.md b/README.md index d7caa34bc..30168141b 100644 --- a/README.md +++ b/README.md @@ -17,18 +17,90 @@ ~ under the License. --> -# Apache Iceberg Rust +# Apache Iceberg™ Rust + + + +Rust implementation of [Apache Iceberg™](https://iceberg.apache.org/). 
+ +Working on [v0.3.0 Release Milestone](https://github.com/apache/iceberg-rust/milestone/2) + +## Components + +The Apache Iceberg Rust project is composed of the following components: + +| Name | Release | Docs | +|--------------------------|-----------------------------------------------------------------|-------------------------------------------------------------------------------------------------------| +| [iceberg] | [![iceberg image]][iceberg link] | [![docs release]][iceberg release docs] [![docs dev]][iceberg dev docs] | +| [iceberg-datafusion] | [![iceberg-datafusion image]][iceberg-datafusion link] | [![docs release]][iceberg-datafusion release docs] [![docs dev]][iceberg-datafusion dev docs] | +| [iceberg-catalog-glue] | [![iceberg-catalog-glue image]][iceberg-catalog-glue link] | [![docs release]][iceberg-catalog-glue release docs] [![docs dev]][iceberg-catalog-glue dev docs] | +| [iceberg-catalog-hms] | [![iceberg-catalog-hms image]][iceberg-catalog-hms link] | [![docs release]][iceberg-catalog-hms release docs] [![docs dev]][iceberg-catalog-hms dev docs] | +| [iceberg-catalog-memory] | [![iceberg-catalog-memory image]][iceberg-catalog-memory link] | [![docs release]][iceberg-catalog-memory release docs] [![docs dev]][iceberg-catalog-memory dev docs] | +| [iceberg-catalog-rest] | [![iceberg-catalog-rest image]][iceberg-catalog-rest link] | [![docs release]][iceberg-catalog-rest release docs] [![docs dev]][iceberg-catalog-rest dev docs] | + +[docs release]: https://img.shields.io/badge/docs-release-blue +[docs dev]: https://img.shields.io/badge/docs-dev-blue +[iceberg]: crates/iceberg/README.md +[iceberg image]: https://img.shields.io/crates/v/iceberg.svg +[iceberg link]: https://crates.io/crates/iceberg +[iceberg release docs]: https://docs.rs/iceberg +[iceberg dev docs]: https://rust.iceberg.apache.org/api/iceberg/ + +[iceberg-datafusion]: crates/integrations/datafusion/README.md +[iceberg-datafusion image]: https://img.shields.io/crates/v/iceberg-datafusion.svg +[iceberg-datafusion link]: https://crates.io/crates/iceberg-datafusion +[iceberg-datafusion dev docs]: https://rust.iceberg.apache.org/api/iceberg_datafusion/ +[iceberg-datafusion release docs]: https://docs.rs/iceberg-datafusion + +[iceberg-catalog-glue]: crates/catalog/glue/README.md +[iceberg-catalog-glue image]: https://img.shields.io/crates/v/iceberg-catalog-glue.svg +[iceberg-catalog-glue link]: https://crates.io/crates/iceberg-catalog-glue +[iceberg-catalog-glue release docs]: https://docs.rs/iceberg-catalog-glue +[iceberg-catalog-glue dev docs]: https://rust.iceberg.apache.org/api/iceberg_catalog_glue/ + +[iceberg-catalog-hms]: crates/catalog/hms/README.md +[iceberg-catalog-hms image]: https://img.shields.io/crates/v/iceberg-catalog-hms.svg +[iceberg-catalog-hms link]: https://crates.io/crates/iceberg-catalog-hms +[iceberg-catalog-hms release docs]: https://docs.rs/iceberg-catalog-hms +[iceberg-catalog-hms dev docs]: https://rust.iceberg.apache.org/api/iceberg_catalog_hms/ + +[iceberg-catalog-memory]: crates/catalog/memory/README.md +[iceberg-catalog-memory image]: https://img.shields.io/crates/v/iceberg-catalog-memory.svg +[iceberg-catalog-memory link]: https://crates.io/crates/iceberg-catalog-memory +[iceberg-catalog-memory release docs]: https://docs.rs/iceberg-catalog-memory +[iceberg-catalog-memory dev docs]: https://rust.iceberg.apache.org/api/iceberg_catalog_memory/ + +[iceberg-catalog-rest]: crates/catalog/rest/README.md +[iceberg-catalog-rest image]: 
https://img.shields.io/crates/v/iceberg-catalog-rest.svg +[iceberg-catalog-rest link]: https://crates.io/crates/iceberg-catalog-rest +[iceberg-catalog-rest release docs]: https://docs.rs/iceberg-catalog-rest +[iceberg-catalog-rest dev docs]: https://rust.iceberg.apache.org/api/iceberg_catalog_rest/ + +## Supported Rust Version + +Iceberg Rust is built and tested with stable Rust, and we keep a rolling MSRV (minimum supported Rust version). The +current MSRV is 1.77.1. + +We also use unstable Rust to run linters such as `clippy` and `rustfmt`. This does not affect downstream users, +and only the MSRV requirement applies. -Native Rust implementation of [Apache Iceberg](https://iceberg.apache.org/). ## Contribute -Iceberg is an active open-source project. We are always open to people who want to use it or contribute to it. Here are some ways to go. +Apache Iceberg is an active open-source project, governed under the Apache Software Foundation (ASF). We are always open to people who want to use or contribute to it. Here are some ways to get involved. - Start with [Contributing Guide](CONTRIBUTING.md). - Submit [Issues](https://github.com/apache/iceberg-rust/issues/new) for bug report or feature requests. -- Discuss at [dev mailing list](mailto:dev@iceberg.apache.org) ([subscribe](mailto:dev-subscribe@iceberg.apache.org?subject=(send%20this%20email%20to%20subscribe)) / [unsubscribe](mailto:dev-unsubscribe@iceberg.apache.org?subject=(send%20this%20email%20to%20unsubscribe)) / [archives](https://lists.apache.org/list.html?dev@iceberg.apache.org)) -- Talk to community directly at [Slack #rust channel](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1zbov3k6e-KtJfoaxp97YfX6dPz1Bk7A). +- Discuss + at [dev mailing list](mailto:dev@iceberg.apache.org) ([subscribe](mailto:dev-subscribe@iceberg.apache.org?subject=(send%20this%20email%20to%20subscribe)) / [unsubscribe](mailto:dev-unsubscribe@iceberg.apache.org?subject=(send%20this%20email%20to%20unsubscribe)) / [archives](https://lists.apache.org/list.html?dev@iceberg.apache.org)) +- Talk to the community directly + at [Slack #rust channel](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1zbov3k6e-KtJfoaxp97YfX6dPz1Bk7A). + +The Apache Iceberg community is built on the principles described in the [Apache Way](https://www.apache.org/theapacheway/index.html) and all who engage with the community are expected to be respectful, open, come with the best interests of the community in mind, and abide by the Apache Foundation [Code of Conduct](https://www.apache.org/foundation/policies/conduct.html). +## Users + +- [Databend](https://github.com/datafuselabs/databend/): An open-source cloud data warehouse that serves as a cost-effective alternative to Snowflake. +- [iceberg-catalog](https://github.com/hansetag/iceberg-catalog): A Rust implementation of the Iceberg REST Catalog specification. ## License diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml new file mode 100644 index 000000000..0260f788b --- /dev/null +++ b/bindings/python/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "pyiceberg_core_rust" +version = "0.0.1" +edition = "2021" +homepage = "https://rust.iceberg.apache.org" +rust-version = "1.77.1" +# This crate is used to build python bindings, we don't want to publish it +publish = false + +license = "Apache-2.0" +keywords = ["iceberg"] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +iceberg = { path = "../../crates/iceberg" } +pyo3 = { version = "0.21.1", features = ["extension-module"] } +arrow = { version = "52.2.0", features = ["pyarrow"] } diff --git a/bindings/python/README.md b/bindings/python/README.md new file mode 100644 index 000000000..fe4300e1f --- /dev/null +++ b/bindings/python/README.md @@ -0,0 +1,40 @@ + + +# Pyiceberg Core + +This project is used to build an iceberg-rust powered core for pyiceberg. + +## Setup + +```shell +pip install hatch==1.12.0 +``` + +## Build + +```shell +hatch run dev:develop +``` + +## Test + +```shell +hatch run dev:test +``` \ No newline at end of file diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml new file mode 100644 index 000000000..f1f0a100f --- /dev/null +++ b/bindings/python/pyproject.toml @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "pyiceberg_core" +version = "0.0.1" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[tool.maturin] +features = ["pyo3/extension-module"] +python-source = "python" +module-name = "pyiceberg_core.pyiceberg_core_rust" + +[tool.ruff.lint] +ignore = ["F403", "F405"] + +[tool.hatch.envs.dev] +dependencies = [ + "maturin>=1.0,<2.0", + "pytest>=8.3.2", + "pyarrow>=17.0.0", +] + +[tool.hatch.envs.dev.scripts] +develop = "maturin develop" +build = "maturin build --out dist --sdist" +test = "pytest" diff --git a/bindings/python/python/pyiceberg_core/__init__.py b/bindings/python/python/pyiceberg_core/__init__.py new file mode 100644 index 000000000..067bb6f07 --- /dev/null +++ b/bindings/python/python/pyiceberg_core/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .pyiceberg_core_rust import * + +__doc__ = pyiceberg_core_rust.__doc__ +__all__ = pyiceberg_core_rust.__all__ diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs new file mode 100644 index 000000000..5c3f77ff7 --- /dev/null +++ b/bindings/python/src/lib.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use iceberg::io::FileIOBuilder; +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; + +mod transform; + +#[pyfunction] +fn hello_world() -> PyResult<String> { + let _ = FileIOBuilder::new_fs_io().build().unwrap(); + Ok("Hello, world!".to_string()) +} + + +#[pymodule] +fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(hello_world, m)?)?; + + m.add_class::<transform::ArrowArrayTransform>()?; + Ok(()) +} diff --git a/bindings/python/src/transform.rs b/bindings/python/src/transform.rs new file mode 100644 index 000000000..8f4585b2a --- /dev/null +++ b/bindings/python/src/transform.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use iceberg::spec::Transform; +use iceberg::transform::create_transform_function; + +use arrow::{ + array::{make_array, Array, ArrayData}, +}; +use arrow::pyarrow::{FromPyArrow, ToPyArrow}; +use pyo3::{exceptions::PyValueError, prelude::*}; + +fn to_py_err(err: iceberg::Error) -> PyErr { + PyValueError::new_err(err.to_string()) +} + +#[pyclass] +pub struct ArrowArrayTransform { +} + +fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> { + // import the pyarrow array into an arrow-rs array + let array = ArrayData::from_pyarrow_bound(array.bind(py))?; + let array = make_array(array); + let transform_function = create_transform_function(&transform).map_err(to_py_err)?; + let array = transform_function.transform(array).map_err(to_py_err)?; + // export the transformed array back to pyarrow + let array = array.into_data(); + array.to_pyarrow(py) +} + +#[pymethods] +impl ArrowArrayTransform { + #[staticmethod] + pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Identity, py) + } + + #[staticmethod] + pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Void, py) + } + + #[staticmethod] + pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Year, py) + } + + #[staticmethod] + pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Month, py) + } + + #[staticmethod] + pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Day, py) + } + + #[staticmethod] + pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Hour, py) + } + + #[staticmethod] + pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Bucket(num_buckets), py) + } + + #[staticmethod] + pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> { + apply(array, Transform::Truncate(width), py) + } +} diff --git a/bindings/python/tests/test_basic.py b/bindings/python/tests/test_basic.py new file mode 100644 index 000000000..817793ba8 --- /dev/null +++ b/bindings/python/tests/test_basic.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or
more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyiceberg_core import hello_world + + +def test_hello_world(): + hello_world() diff --git a/bindings/python/tests/test_transform.py b/bindings/python/tests/test_transform.py new file mode 100644 index 000000000..1fa2d577a --- /dev/null +++ b/bindings/python/tests/test_transform.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import date, datetime + +import pyarrow as pa +import pytest +from pyiceberg_core import ArrowArrayTransform + + +def test_identity_transform(): + arr = pa.array([1, 2]) + result = ArrowArrayTransform.identity(arr) + assert result == arr + + +def test_bucket_transform(): + arr = pa.array([1, 2]) + result = ArrowArrayTransform.bucket(arr, 10) + expected = pa.array([6, 2], type=pa.int32()) + assert result == expected + + +def test_bucket_transform_fails_for_list_type_input(): + arr = pa.array([[1, 2], [3, 4]]) + with pytest.raises( + ValueError, + match=r"FeatureUnsupported => Unsupported data type for bucket transform", + ): + ArrowArrayTransform.bucket(arr, 10) + + +def test_bucket_chunked_array(): + chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) + result_chunks = [] + for arr in chunked.iterchunks(): + result_chunks.append(ArrowArrayTransform.bucket(arr, 10)) + + expected = pa.chunked_array( + [pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())] + ) + assert pa.chunked_array(result_chunks).equals(expected) + + +def test_year_transform(): + arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)]) + result = ArrowArrayTransform.year(arr) + expected = pa.array([0, 30], type=pa.int32()) + assert result == expected + + +def test_month_transform(): + arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)]) + result = ArrowArrayTransform.month(arr) + expected = pa.array([0, 30 * 12 + 3], type=pa.int32()) + assert result == expected + + +def test_day_transform(): + arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)]) + result = ArrowArrayTransform.day(arr) + expected = pa.array([0, 11048], type=pa.int32()) + assert result == expected + + +def test_hour_transform(): + arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)]) + result = ArrowArrayTransform.hour(arr) + expected = pa.array([19, 264420], type=pa.int32()) + assert result == expected + + +def test_truncate_transform(): + arr = pa.array(["this is a long string", "hi my name is sung"]) + result = ArrowArrayTransform.truncate(arr, 5) + expected = pa.array(["this ", "hi my"]) + assert result == expected diff --git a/crates/catalog/glue/Cargo.toml b/crates/catalog/glue/Cargo.toml new file mode 100644 index 000000000..0d2e1f983 --- /dev/null +++ b/crates/catalog/glue/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
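The Python tests above exercise the partition transforms through the bindings; the same machinery can also be driven directly from Rust. A minimal sketch, assuming the workspace's `iceberg` and `arrow` crates (not part of the diff itself):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use iceberg::spec::Transform;
use iceberg::transform::create_transform_function;

fn main() -> iceberg::Result<()> {
    // Same inputs as test_bucket_transform above.
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    let bucket = create_transform_function(&Transform::Bucket(10))?;
    let output = bucket.transform(input)?;
    // The Python test above expects buckets [6, 2] for inputs [1, 2].
    println!("{output:?}");
    Ok(())
}
```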
+ +[package] +name = "iceberg-catalog-glue" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } + +categories = ["database"] +description = "Apache Iceberg Glue Catalog Support" +repository = { workspace = true } +license = { workspace = true } +keywords = ["iceberg", "glue", "catalog"] + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +aws-config = { workspace = true } +aws-sdk-glue = { workspace = true } +iceberg = { workspace = true } +log = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +typed-builder = { workspace = true } +uuid = { workspace = true } + +[dev-dependencies] +ctor = { workspace = true } +iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } +port_scanner = { workspace = true } diff --git a/crates/catalog/glue/DEPENDENCIES.rust.tsv b/crates/catalog/glue/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..735d5447b --- /dev/null +++ b/crates/catalog/glue/DEPENDENCIES.rust.tsv @@ -0,0 +1,328 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X +adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayvec@0.7.4 X X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-ord@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-trait@0.1.81 X X +atoi@2.0.0 X +autocfg@1.3.0 X X +aws-config@1.5.5 X +aws-credential-types@1.2.0 X +aws-runtime@1.4.0 X +aws-sdk-glue@1.53.0 X +aws-sdk-sso@1.37.0 X +aws-sdk-ssooidc@1.38.0 X +aws-sdk-sts@1.37.0 X +aws-sigv4@1.2.3 X +aws-smithy-async@1.2.1 X +aws-smithy-http@0.60.9 X +aws-smithy-json@0.60.7 X +aws-smithy-query@0.60.7 X +aws-smithy-runtime@1.6.2 X +aws-smithy-runtime-api@1.7.2 X +aws-smithy-types@1.2.0 X +aws-smithy-xml@0.60.8 X +aws-types@1.3.3 X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.21.7 X X +base64@0.22.1 X X +base64-simd@0.8.0 X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitflags@2.6.0 X X +bitvec@1.0.1 X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +bytes-utils@0.1.4 X X +cc@1.1.11 X X +cfg-if@1.0.0 X X +chrono@0.4.38 X X +colorchoice@1.0.2 X X +const-oid@0.9.6 X X +const-random@0.1.18 X X +const-random-macro@0.1.16 X X +core-foundation@0.9.4 X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +deranged@0.3.11 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +equivalent@1.0.1 X X +fastrand@2.1.0 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X 
X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +h2@0.3.26 X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hmac@0.12.1 X X +home@0.5.9 X X +http@0.2.12 X X +http@1.1.0 X X +http-body@0.4.6 X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +httpdate@1.0.3 X X +humantime@2.1.0 X X +hyper@0.14.30 X +hyper@1.4.1 X +hyper-rustls@0.24.2 X X X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-glue@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X +idna@0.5.0 X X +indexmap@2.4.0 X X +integer-encoding@3.0.4 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +md-5@0.10.6 X X +memchr@2.7.4 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +murmur3@0.5.2 X X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-conv@0.1.0 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +object@0.36.3 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +openssl-probe@0.1.5 X X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +outref@0.5.1 X +parquet@52.2.0 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +powerfmt@0.2.0 X X +ppv-lite86@0.2.20 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X +radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc_version@0.4.0 X X +rustls@0.21.12 X X X +rustls@0.23.12 X X X +rustls-native-certs@0.6.3 X X X +rustls-pemfile@1.0.4 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.101.7 X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +schannel@0.1.23 X +sct@0.7.1 X X X +security-framework@2.11.1 X X +security-framework-sys@2.11.1 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +signal-hook-registry@1.4.2 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snap@1.1.1 X +socket2@0.5.7 X X +spin@0.9.8 X +static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +time@0.3.36 X X +time-core@0.1.2 X X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.24.1 X X +tokio-rustls@0.26.0 X X 
+tokio-util@0.7.11 X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-attributes@0.1.27 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +untrusted@0.9.0 X +url@2.5.2 X X +urlencoding@2.1.3 X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +vsimd@0.8.0 X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winreg@0.52.0 X +wyz@0.5.1 X +xmlparser@0.13.6 X X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/catalog/glue/README.md b/crates/catalog/glue/README.md new file mode 100644 index 000000000..fb7f6bf0f --- /dev/null +++ b/crates/catalog/glue/README.md @@ -0,0 +1,27 @@ + + +# Apache Iceberg Glue Catalog Official Native Rust Implementation + +[![crates.io](https://img.shields.io/crates/v/iceberg.svg)](https://crates.io/crates/iceberg-catalog-glue) +[![docs.rs](https://img.shields.io/docsrs/iceberg.svg)](https://docs.rs/iceberg/latest/iceberg-catalog-glue/) + +This crate contains the official Native Rust implementation of Apache Iceberg Glue Catalog. + +See the [API documentation](https://docs.rs/iceberg-catalog-glue/latest) for examples and the full API. diff --git a/crates/catalog/glue/src/catalog.rs b/crates/catalog/glue/src/catalog.rs new file mode 100644 index 000000000..18e30f3d0 --- /dev/null +++ b/crates/catalog/glue/src/catalog.rs @@ -0,0 +1,600 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
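Before the implementation that follows, a hedged usage sketch of the public API this crate adds; the warehouse bucket is a placeholder and a Tokio runtime is assumed:

```rust
use iceberg::Catalog;
use iceberg_catalog_glue::{GlueCatalog, GlueCatalogConfig};

#[tokio::main]
async fn main() -> iceberg::Result<()> {
    // Placeholder warehouse location; any path supported by FileIO works here.
    let config = GlueCatalogConfig::builder()
        .warehouse("s3://example-bucket/warehouse".to_string())
        .build();

    let catalog = GlueCatalog::new(config).await?;

    // List all Glue databases visible to this account/region.
    let namespaces = catalog.list_namespaces(None).await?;
    println!("{namespaces:?}");
    Ok(())
}
```

With no explicit `props`, `create_sdk_config` (later in this diff) simply loads the default AWS credential and region chain.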
+ +use std::collections::HashMap; +use std::fmt::Debug; + +use async_trait::async_trait; +use aws_sdk_glue::types::TableInput; +use iceberg::io::FileIO; +use iceberg::spec::{TableMetadata, TableMetadataBuilder}; +use iceberg::table::Table; +use iceberg::{ + Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, TableCreation, + TableIdent, +}; +use typed_builder::TypedBuilder; + +use crate::error::{from_aws_build_error, from_aws_sdk_error}; +use crate::utils::{ + convert_to_database, convert_to_glue_table, convert_to_namespace, create_metadata_location, + create_sdk_config, get_default_table_location, get_metadata_location, validate_namespace, +}; +use crate::with_catalog_id; + +#[derive(Debug, TypedBuilder)] +/// Glue Catalog configuration +pub struct GlueCatalogConfig { + #[builder(default, setter(strip_option))] + uri: Option, + #[builder(default, setter(strip_option))] + catalog_id: Option, + warehouse: String, + #[builder(default)] + props: HashMap, +} + +struct GlueClient(aws_sdk_glue::Client); + +/// Glue Catalog +pub struct GlueCatalog { + config: GlueCatalogConfig, + client: GlueClient, + file_io: FileIO, +} + +impl Debug for GlueCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GlueCatalog") + .field("config", &self.config) + .finish_non_exhaustive() + } +} + +impl GlueCatalog { + /// Create a new glue catalog + pub async fn new(config: GlueCatalogConfig) -> Result { + let sdk_config = create_sdk_config(&config.props, config.uri.as_ref()).await; + + let client = aws_sdk_glue::Client::new(&sdk_config); + + let file_io = FileIO::from_path(&config.warehouse)? + .with_props(&config.props) + .build()?; + + Ok(GlueCatalog { + config, + client: GlueClient(client), + file_io, + }) + } + /// Get the catalogs `FileIO` + pub fn file_io(&self) -> FileIO { + self.file_io.clone() + } +} + +#[async_trait] +impl Catalog for GlueCatalog { + /// List namespaces from glue catalog. + /// + /// Glue doesn't support nested namespaces. + /// We will return an empty list if parent is some. + async fn list_namespaces( + &self, + parent: Option<&NamespaceIdent>, + ) -> Result> { + if parent.is_some() { + return Ok(vec![]); + } + + let mut database_list: Vec = Vec::new(); + let mut next_token: Option = None; + + loop { + let builder = match &next_token { + Some(token) => self.client.0.get_databases().next_token(token), + None => self.client.0.get_databases(), + }; + let builder = with_catalog_id!(builder, self.config); + let resp = builder.send().await.map_err(from_aws_sdk_error)?; + + let dbs: Vec = resp + .database_list() + .iter() + .map(|db| NamespaceIdent::new(db.name().to_string())) + .collect(); + + database_list.extend(dbs); + + next_token = resp.next_token().map(ToOwned::to_owned); + if next_token.is_none() { + break; + } + } + + Ok(database_list) + } + + /// Creates a new namespace with the given identifier and properties. + /// + /// Attempts to create a namespace defined by the `namespace` + /// parameter and configured with the specified `properties`. + /// + /// This function can return an error in the following situations: + /// + /// - Errors from `validate_namespace` if the namespace identifier does not + /// meet validation criteria. + /// - Errors from `convert_to_database` if the properties cannot be + /// successfully converted into a database configuration. + /// - Errors from the underlying database creation process, converted using + /// `from_sdk_error`. 
+ async fn create_namespace( + &self, + namespace: &NamespaceIdent, + properties: HashMap, + ) -> Result { + let db_input = convert_to_database(namespace, &properties)?; + + let builder = self.client.0.create_database().database_input(db_input); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + Ok(Namespace::with_properties(namespace.clone(), properties)) + } + + /// Retrieves a namespace by its identifier. + /// + /// Validates the given namespace identifier and then queries the + /// underlying database client to fetch the corresponding namespace data. + /// Constructs a `Namespace` object with the retrieved data and returns it. + /// + /// This function can return an error in any of the following situations: + /// - If the provided namespace identifier fails validation checks + /// - If there is an error querying the database, returned by + /// `from_sdk_error`. + async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { + let db_name = validate_namespace(namespace)?; + + let builder = self.client.0.get_database().name(&db_name); + let builder = with_catalog_id!(builder, self.config); + + let resp = builder.send().await.map_err(from_aws_sdk_error)?; + + match resp.database() { + Some(db) => { + let namespace = convert_to_namespace(db); + Ok(namespace) + } + None => Err(Error::new( + ErrorKind::DataInvalid, + format!("Database with name: {} does not exist", db_name), + )), + } + } + + /// Checks if a namespace exists within the Glue Catalog. + /// + /// Validates the namespace identifier by querying the Glue Catalog + /// to determine if the specified namespace (database) exists. + /// + /// # Returns + /// A `Result` indicating the outcome of the check: + /// - `Ok(true)` if the namespace exists. + /// - `Ok(false)` if the namespace does not exist, identified by a specific + /// `EntityNotFoundException` variant. + /// - `Err(...)` if an error occurs during validation or the Glue Catalog + /// query, with the error encapsulating the issue. + async fn namespace_exists(&self, namespace: &NamespaceIdent) -> Result { + let db_name = validate_namespace(namespace)?; + + let builder = self.client.0.get_database().name(&db_name); + let builder = with_catalog_id!(builder, self.config); + + let resp = builder.send().await; + + match resp { + Ok(_) => Ok(true), + Err(err) => { + if err + .as_service_error() + .map(|e| e.is_entity_not_found_exception()) + == Some(true) + { + return Ok(false); + } + Err(from_aws_sdk_error(err)) + } + } + } + + /// Asynchronously updates properties of an existing namespace. + /// + /// Converts the given namespace identifier and properties into a database + /// representation and then attempts to update the corresponding namespace + /// in the Glue Catalog. + /// + /// # Returns + /// Returns `Ok(())` if the namespace update is successful. If the + /// namespace cannot be updated due to missing information or an error + /// during the update process, an `Err(...)` is returned. + async fn update_namespace( + &self, + namespace: &NamespaceIdent, + properties: HashMap, + ) -> Result<()> { + let db_name = validate_namespace(namespace)?; + let db_input = convert_to_database(namespace, &properties)?; + + let builder = self + .client + .0 + .update_database() + .name(&db_name) + .database_input(db_input); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + Ok(()) + } + + /// Asynchronously drops a namespace from the Glue Catalog. 
+ /// + /// Checks if the namespace is empty. If it still contains tables the + /// namespace will not be dropped, but an error is returned instead. + /// + /// # Returns + /// A `Result<()>` indicating the outcome: + /// - `Ok(())` signifies successful namespace deletion. + /// - `Err(...)` signifies failure to drop the namespace due to validation + /// errors, connectivity issues, or Glue Catalog constraints. + async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> { + let db_name = validate_namespace(namespace)?; + let table_list = self.list_tables(namespace).await?; + + if !table_list.is_empty() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Database with name: {} is not empty", &db_name), + )); + } + + let builder = self.client.0.delete_database().name(db_name); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + Ok(()) + } + + /// Asynchronously lists all tables within a specified namespace. + /// + /// # Returns + /// A `Result>`, which is: + /// - `Ok(vec![...])` containing a vector of `TableIdent` instances, each + /// representing a table within the specified namespace. + /// - `Err(...)` if an error occurs during namespace validation or while + /// querying the database. + async fn list_tables(&self, namespace: &NamespaceIdent) -> Result> { + let db_name = validate_namespace(namespace)?; + + let mut table_list: Vec = Vec::new(); + let mut next_token: Option = None; + + loop { + let builder = match &next_token { + Some(token) => self + .client + .0 + .get_tables() + .database_name(&db_name) + .next_token(token), + None => self.client.0.get_tables().database_name(&db_name), + }; + let builder = with_catalog_id!(builder, self.config); + let resp = builder.send().await.map_err(from_aws_sdk_error)?; + + let tables: Vec<_> = resp + .table_list() + .iter() + .map(|tbl| TableIdent::new(namespace.clone(), tbl.name().to_string())) + .collect(); + + table_list.extend(tables); + + next_token = resp.next_token().map(ToOwned::to_owned); + if next_token.is_none() { + break; + } + } + + Ok(table_list) + } + + /// Creates a new table within a specified namespace using the provided + /// table creation settings. + /// + /// # Returns + /// A `Result` wrapping a `Table` object representing the newly created + /// table. + /// + /// # Errors + /// This function may return an error in several cases, including invalid + /// namespace identifiers, failure to determine a default storage location, + /// issues generating or writing table metadata, and errors communicating + /// with the Glue Catalog. + async fn create_table( + &self, + namespace: &NamespaceIdent, + creation: TableCreation, + ) -> Result { + let db_name = validate_namespace(namespace)?; + let table_name = creation.name.clone(); + + let location = match &creation.location { + Some(location) => location.clone(), + None => { + let ns = self.get_namespace(namespace).await?; + get_default_table_location(&ns, &db_name, &table_name, &self.config.warehouse) + } + }; + + let metadata = TableMetadataBuilder::from_table_creation(creation)?.build()?; + let metadata_location = create_metadata_location(&location, 0)?; + + self.file_io + .new_output(&metadata_location)? 
+ .write(serde_json::to_vec(&metadata)?.into()) + .await?; + + let glue_table = convert_to_glue_table( + &table_name, + metadata_location.clone(), + &metadata, + metadata.properties(), + None, + )?; + + let builder = self + .client + .0 + .create_table() + .database_name(&db_name) + .table_input(glue_table); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + Table::builder() + .file_io(self.file_io()) + .metadata_location(metadata_location) + .metadata(metadata) + .identifier(TableIdent::new(NamespaceIdent::new(db_name), table_name)) + .build() + } + + /// Loads a table from the Glue Catalog and constructs a `Table` object + /// based on its metadata. + /// + /// # Returns + /// A `Result` wrapping a `Table` object that represents the loaded table. + /// + /// # Errors + /// This function may return an error in several scenarios, including: + /// - Failure to validate the namespace. + /// - Failure to retrieve the table from the Glue Catalog. + /// - Absence of metadata location information in the table's properties. + /// - Issues reading or deserializing the table's metadata file. + async fn load_table(&self, table: &TableIdent) -> Result<Table>
{ + let db_name = validate_namespace(table.namespace())?; + let table_name = table.name(); + + let builder = self + .client + .0 + .get_table() + .database_name(&db_name) + .name(table_name); + let builder = with_catalog_id!(builder, self.config); + + let glue_table_output = builder.send().await.map_err(from_aws_sdk_error)?; + + match glue_table_output.table() { + None => Err(Error::new( + ErrorKind::Unexpected, + format!( + "Table object for database: {} and table: {} does not exist", + db_name, table_name + ), + )), + Some(table) => { + let metadata_location = get_metadata_location(&table.parameters)?; + + let input_file = self.file_io.new_input(&metadata_location)?; + let metadata_content = input_file.read().await?; + let metadata = serde_json::from_slice::(&metadata_content)?; + + Table::builder() + .file_io(self.file_io()) + .metadata_location(metadata_location) + .metadata(metadata) + .identifier(TableIdent::new( + NamespaceIdent::new(db_name), + table_name.to_owned(), + )) + .build() + } + } + } + + /// Asynchronously drops a table from the database. + /// + /// # Errors + /// Returns an error if: + /// - The namespace provided in `table` cannot be validated + /// or does not exist. + /// - The underlying database client encounters an error while + /// attempting to drop the table. This includes scenarios where + /// the table does not exist. + /// - Any network or communication error occurs with the database backend. + async fn drop_table(&self, table: &TableIdent) -> Result<()> { + let db_name = validate_namespace(table.namespace())?; + let table_name = table.name(); + + let builder = self + .client + .0 + .delete_table() + .database_name(&db_name) + .name(table_name); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + Ok(()) + } + + /// Asynchronously checks the existence of a specified table + /// in the database. + /// + /// # Returns + /// - `Ok(true)` if the table exists in the database. + /// - `Ok(false)` if the table does not exist in the database. + /// - `Err(...)` if an error occurs during the process + async fn table_exists(&self, table: &TableIdent) -> Result { + let db_name = validate_namespace(table.namespace())?; + let table_name = table.name(); + + let builder = self + .client + .0 + .get_table() + .database_name(&db_name) + .name(table_name); + let builder = with_catalog_id!(builder, self.config); + + let resp = builder.send().await; + + match resp { + Ok(_) => Ok(true), + Err(err) => { + if err + .as_service_error() + .map(|e| e.is_entity_not_found_exception()) + == Some(true) + { + return Ok(false); + } + Err(from_aws_sdk_error(err)) + } + } + } + + /// Asynchronously renames a table within the database + /// or moves it between namespaces (databases). + /// + /// # Returns + /// - `Ok(())` on successful rename or move of the table. + /// - `Err(...)` if an error occurs during the process. 
+ async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> { + let src_db_name = validate_namespace(src.namespace())?; + let dest_db_name = validate_namespace(dest.namespace())?; + + let src_table_name = src.name(); + let dest_table_name = dest.name(); + + let builder = self + .client + .0 + .get_table() + .database_name(&src_db_name) + .name(src_table_name); + let builder = with_catalog_id!(builder, self.config); + + let glue_table_output = builder.send().await.map_err(from_aws_sdk_error)?; + + match glue_table_output.table() { + None => Err(Error::new( + ErrorKind::Unexpected, + format!( + "'Table' object for database: {} and table: {} does not exist", + src_db_name, src_table_name + ), + )), + Some(table) => { + let rename_table_input = TableInput::builder() + .name(dest_table_name) + .set_parameters(table.parameters.clone()) + .set_storage_descriptor(table.storage_descriptor.clone()) + .set_table_type(table.table_type.clone()) + .set_description(table.description.clone()) + .build() + .map_err(from_aws_build_error)?; + + let builder = self + .client + .0 + .create_table() + .database_name(&dest_db_name) + .table_input(rename_table_input); + let builder = with_catalog_id!(builder, self.config); + + builder.send().await.map_err(from_aws_sdk_error)?; + + let drop_src_table_result = self.drop_table(src).await; + + match drop_src_table_result { + Ok(_) => Ok(()), + Err(_) => { + let err_msg_src_table = format!( + "Failed to drop old table {}.{}.", + src_db_name, src_table_name + ); + + let drop_dest_table_result = self.drop_table(dest).await; + + match drop_dest_table_result { + Ok(_) => Err(Error::new( + ErrorKind::Unexpected, + format!( + "{} Rolled back table creation for {}.{}.", + err_msg_src_table, dest_db_name, dest_table_name + ), + )), + Err(_) => Err(Error::new( + ErrorKind::Unexpected, + format!( + "{} Failed to roll back table creation for {}.{}. Please clean up manually.", + err_msg_src_table, dest_db_name, dest_table_name + ), + )), + } + } + } + } + } + } + + async fn update_table(&self, _commit: TableCommit) -> Result
{ + Err(Error::new( + ErrorKind::FeatureUnsupported, + "Updating a table is not supported yet", + )) + } +} diff --git a/crates/catalog/glue/src/error.rs b/crates/catalog/glue/src/error.rs new file mode 100644 index 000000000..a94f6c220 --- /dev/null +++ b/crates/catalog/glue/src/error.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; + +use anyhow::anyhow; +use iceberg::{Error, ErrorKind}; + +/// Format AWS SDK error into iceberg error +pub(crate) fn from_aws_sdk_error(error: aws_sdk_glue::error::SdkError) -> Error +where T: Debug { + Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting aws skd error".to_string(), + ) + .with_source(anyhow!("aws sdk error: {:?}", error)) +} + +/// Format AWS Build error into iceberg error +pub(crate) fn from_aws_build_error(error: aws_sdk_glue::error::BuildError) -> Error { + Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting aws build error".to_string(), + ) + .with_source(anyhow!("aws build error: {:?}", error)) +} diff --git a/crates/catalog/glue/src/lib.rs b/crates/catalog/glue/src/lib.rs new file mode 100644 index 000000000..237657335 --- /dev/null +++ b/crates/catalog/glue/src/lib.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Iceberg Glue Catalog implementation. + +#![deny(missing_docs)] + +mod catalog; +mod error; +mod schema; +mod utils; +pub use catalog::*; +pub use utils::{ + AWS_ACCESS_KEY_ID, AWS_PROFILE_NAME, AWS_REGION_NAME, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, +}; diff --git a/crates/catalog/glue/src/schema.rs b/crates/catalog/glue/src/schema.rs new file mode 100644 index 000000000..bb676e36e --- /dev/null +++ b/crates/catalog/glue/src/schema.rs @@ -0,0 +1,482 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Property `iceberg.field.id` for `Column` +pub(crate) const ICEBERG_FIELD_ID: &str = "iceberg.field.id"; +/// Property `iceberg.field.optional` for `Column` +pub(crate) const ICEBERG_FIELD_OPTIONAL: &str = "iceberg.field.optional"; +/// Property `iceberg.field.current` for `Column` +pub(crate) const ICEBERG_FIELD_CURRENT: &str = "iceberg.field.current"; + +use std::collections::HashMap; + +use aws_sdk_glue::types::Column; +use iceberg::spec::{visit_schema, PrimitiveType, SchemaVisitor, TableMetadata}; +use iceberg::{Error, ErrorKind, Result}; + +use crate::error::from_aws_build_error; + +type GlueSchema = Vec; + +#[derive(Debug, Default)] +pub(crate) struct GlueSchemaBuilder { + schema: GlueSchema, + is_current: bool, + depth: usize, +} + +impl GlueSchemaBuilder { + /// Creates a new `GlueSchemaBuilder` from iceberg `Schema` + pub fn from_iceberg(metadata: &TableMetadata) -> Result { + let current_schema = metadata.current_schema(); + + let mut builder = Self { + schema: Vec::new(), + is_current: true, + depth: 0, + }; + + visit_schema(current_schema, &mut builder)?; + + builder.is_current = false; + + for schema in metadata.schemas_iter() { + if schema.schema_id() == current_schema.schema_id() { + continue; + } + + visit_schema(schema, &mut builder)?; + } + + Ok(builder) + } + + /// Returns the newly converted `GlueSchema` + pub fn build(self) -> GlueSchema { + self.schema + } + + /// Check if is in `StructType` while traversing schema + fn is_inside_struct(&self) -> bool { + self.depth > 0 + } +} + +impl SchemaVisitor for GlueSchemaBuilder { + type T = String; + + fn schema( + &mut self, + _schema: &iceberg::spec::Schema, + value: Self::T, + ) -> iceberg::Result { + Ok(value) + } + + fn before_struct_field(&mut self, _field: &iceberg::spec::NestedFieldRef) -> Result<()> { + self.depth += 1; + Ok(()) + } + + fn r#struct( + &mut self, + r#_struct: &iceberg::spec::StructType, + results: Vec, + ) -> iceberg::Result { + Ok(format!("struct<{}>", results.join(", "))) + } + + fn after_struct_field(&mut self, _field: &iceberg::spec::NestedFieldRef) -> Result<()> { + self.depth -= 1; + Ok(()) + } + + fn field( + &mut self, + field: &iceberg::spec::NestedFieldRef, + value: String, + ) -> iceberg::Result { + if self.is_inside_struct() { + return Ok(format!("{}:{}", field.name, &value)); + } + + let parameters = HashMap::from([ + (ICEBERG_FIELD_ID.to_string(), format!("{}", field.id)), + ( + ICEBERG_FIELD_OPTIONAL.to_string(), + format!("{}", field.required).to_lowercase(), + ), + ( + ICEBERG_FIELD_CURRENT.to_string(), + format!("{}", self.is_current).to_lowercase(), + ), + ]); + + let mut builder = Column::builder() + .name(field.name.clone()) + .r#type(&value) + .set_parameters(Some(parameters)); + + if let Some(comment) = field.doc.as_ref() { + builder = builder.comment(comment); + } + + let column = builder.build().map_err(from_aws_build_error)?; + + self.schema.push(column); + + Ok(value) + } + + fn list(&mut 
self, _list: &iceberg::spec::ListType, value: String) -> iceberg::Result { + Ok(format!("array<{}>", value)) + } + + fn map( + &mut self, + _map: &iceberg::spec::MapType, + key_value: String, + value: String, + ) -> iceberg::Result { + Ok(format!("map<{},{}>", key_value, value)) + } + + fn primitive(&mut self, p: &iceberg::spec::PrimitiveType) -> iceberg::Result { + let glue_type = match p { + PrimitiveType::Boolean => "boolean".to_string(), + PrimitiveType::Int => "int".to_string(), + PrimitiveType::Long => "bigint".to_string(), + PrimitiveType::Float => "float".to_string(), + PrimitiveType::Double => "double".to_string(), + PrimitiveType::Date => "date".to_string(), + PrimitiveType::Timestamp => "timestamp".to_string(), + PrimitiveType::TimestampNs => "timestamp_ns".to_string(), + PrimitiveType::TimestamptzNs => "timestamptz_ns".to_string(), + PrimitiveType::Time | PrimitiveType::String | PrimitiveType::Uuid => { + "string".to_string() + } + PrimitiveType::Binary | PrimitiveType::Fixed(_) => "binary".to_string(), + PrimitiveType::Decimal { precision, scale } => { + format!("decimal({},{})", precision, scale) + } + _ => { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + "Conversion from 'Timestamptz' is not supported", + )) + } + }; + + Ok(glue_type) + } +} + +#[cfg(test)] +mod tests { + use iceberg::spec::{Schema, TableMetadataBuilder}; + use iceberg::TableCreation; + + use super::*; + + fn create_metadata(schema: Schema) -> Result { + let table_creation = TableCreation::builder() + .name("my_table".to_string()) + .location("my_location".to_string()) + .schema(schema) + .build(); + let metadata = TableMetadataBuilder::from_table_creation(table_creation)?.build()?; + + Ok(metadata) + } + + fn create_column( + name: impl Into, + r#type: impl Into, + id: impl Into, + ) -> Result { + let parameters = HashMap::from([ + (ICEBERG_FIELD_ID.to_string(), id.into()), + (ICEBERG_FIELD_OPTIONAL.to_string(), "true".to_string()), + (ICEBERG_FIELD_CURRENT.to_string(), "true".to_string()), + ]); + + Column::builder() + .name(name) + .r#type(r#type) + .set_comment(None) + .set_parameters(Some(parameters)) + .build() + .map_err(from_aws_build_error) + } + + #[test] + fn test_schema_with_simple_fields() -> Result<()> { + let record = r#"{ + "type": "struct", + "schema-id": 1, + "fields": [ + { + "id": 1, + "name": "c1", + "required": true, + "type": "boolean" + }, + { + "id": 2, + "name": "c2", + "required": true, + "type": "int" + }, + { + "id": 3, + "name": "c3", + "required": true, + "type": "long" + }, + { + "id": 4, + "name": "c4", + "required": true, + "type": "float" + }, + { + "id": 5, + "name": "c5", + "required": true, + "type": "double" + }, + { + "id": 6, + "name": "c6", + "required": true, + "type": "decimal(2,2)" + }, + { + "id": 7, + "name": "c7", + "required": true, + "type": "date" + }, + { + "id": 8, + "name": "c8", + "required": true, + "type": "time" + }, + { + "id": 9, + "name": "c9", + "required": true, + "type": "timestamp" + }, + { + "id": 10, + "name": "c10", + "required": true, + "type": "string" + }, + { + "id": 11, + "name": "c11", + "required": true, + "type": "uuid" + }, + { + "id": 12, + "name": "c12", + "required": true, + "type": "fixed[4]" + }, + { + "id": 13, + "name": "c13", + "required": true, + "type": "binary" + } + ] + }"#; + + let schema = serde_json::from_str::(record)?; + let metadata = create_metadata(schema)?; + + let result = GlueSchemaBuilder::from_iceberg(&metadata)?.build(); + + let expected = vec![ + create_column("c1", "boolean", "1")?, + 
create_column("c2", "int", "2")?, + create_column("c3", "bigint", "3")?, + create_column("c4", "float", "4")?, + create_column("c5", "double", "5")?, + create_column("c6", "decimal(2,2)", "6")?, + create_column("c7", "date", "7")?, + create_column("c8", "string", "8")?, + create_column("c9", "timestamp", "9")?, + create_column("c10", "string", "10")?, + create_column("c11", "string", "11")?, + create_column("c12", "binary", "12")?, + create_column("c13", "binary", "13")?, + ]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_structs() -> Result<()> { + let record = r#"{ + "type": "struct", + "schema-id": 1, + "fields": [ + { + "id": 1, + "name": "person", + "required": true, + "type": { + "type": "struct", + "fields": [ + { + "id": 2, + "name": "name", + "required": true, + "type": "string" + }, + { + "id": 3, + "name": "age", + "required": false, + "type": "int" + } + ] + } + } + ] + }"#; + + let schema = serde_json::from_str::(record)?; + let metadata = create_metadata(schema)?; + + let result = GlueSchemaBuilder::from_iceberg(&metadata)?.build(); + + let expected = vec![create_column( + "person", + "struct", + "1", + )?]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_struct_inside_list() -> Result<()> { + let record = r#" + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "location", + "required": true, + "type": { + "type": "list", + "element-id": 2, + "element-required": true, + "element": { + "type": "struct", + "fields": [ + { + "id": 3, + "name": "latitude", + "required": false, + "type": "float" + }, + { + "id": 4, + "name": "longitude", + "required": false, + "type": "float" + } + ] + } + } + } + ] + } + "#; + + let schema = serde_json::from_str::(record)?; + let metadata = create_metadata(schema)?; + + let result = GlueSchemaBuilder::from_iceberg(&metadata)?.build(); + + let expected = vec![create_column( + "location", + "array>", + "1", + )?]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_nested_maps() -> Result<()> { + let record = r#" + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "quux", + "required": true, + "type": { + "type": "map", + "key-id": 2, + "key": "string", + "value-id": 3, + "value-required": true, + "value": { + "type": "map", + "key-id": 4, + "key": "string", + "value-id": 5, + "value-required": true, + "value": "int" + } + } + } + ] + } + "#; + + let schema = serde_json::from_str::(record)?; + let metadata = create_metadata(schema)?; + + let result = GlueSchemaBuilder::from_iceberg(&metadata)?.build(); + + let expected = vec![create_column("quux", "map>", "1")?]; + + assert_eq!(result, expected); + + Ok(()) + } +} diff --git a/crates/catalog/glue/src/utils.rs b/crates/catalog/glue/src/utils.rs new file mode 100644 index 000000000..a99fb19c7 --- /dev/null +++ b/crates/catalog/glue/src/utils.rs @@ -0,0 +1,518 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use aws_config::{BehaviorVersion, Region, SdkConfig}; +use aws_sdk_glue::config::Credentials; +use aws_sdk_glue::types::{Database, DatabaseInput, StorageDescriptor, TableInput}; +use iceberg::spec::TableMetadata; +use iceberg::{Error, ErrorKind, Namespace, NamespaceIdent, Result}; +use uuid::Uuid; + +use crate::error::from_aws_build_error; +use crate::schema::GlueSchemaBuilder; + +/// Property aws profile name +pub const AWS_PROFILE_NAME: &str = "profile_name"; +/// Property aws region +pub const AWS_REGION_NAME: &str = "region_name"; +/// Property aws access key +pub const AWS_ACCESS_KEY_ID: &str = "aws_access_key_id"; +/// Property aws secret access key +pub const AWS_SECRET_ACCESS_KEY: &str = "aws_secret_access_key"; +/// Property aws session token +pub const AWS_SESSION_TOKEN: &str = "aws_session_token"; +/// Parameter namespace description +const DESCRIPTION: &str = "description"; +/// Parameter namespace location uri +const LOCATION: &str = "location_uri"; +/// Property `metadata_location` for `TableInput` +const METADATA_LOCATION: &str = "metadata_location"; +/// Property `previous_metadata_location` for `TableInput` +const PREV_METADATA_LOCATION: &str = "previous_metadata_location"; +/// Property external table for `TableInput` +const EXTERNAL_TABLE: &str = "EXTERNAL_TABLE"; +/// Parameter key `table_type` for `TableInput` +const TABLE_TYPE: &str = "table_type"; +/// Parameter value `table_type` for `TableInput` +const ICEBERG: &str = "ICEBERG"; + +/// Creates an aws sdk configuration based on +/// provided properties and an optional endpoint URL. 
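+///
+/// A minimal usage sketch (crate-internal; the property keys are the `AWS_*`
+/// constants defined above, the values and endpoint are placeholders):
+///
+/// ```ignore
+/// let props = HashMap::from([
+///     (AWS_ACCESS_KEY_ID.to_string(), "my-access-id".to_string()),
+///     (AWS_SECRET_ACCESS_KEY.to_string(), "my-secret-key".to_string()),
+///     (AWS_REGION_NAME.to_string(), "us-east-1".to_string()),
+/// ]);
+/// // Explicit credentials/region from the properties take precedence over the
+/// // default provider chain; the endpoint override is optional.
+/// let sdk_config = create_sdk_config(&props, Some(&"http://localhost:5000".to_string())).await;
+/// ```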
+pub(crate) async fn create_sdk_config( + properties: &HashMap, + endpoint_uri: Option<&String>, +) -> SdkConfig { + let mut config = aws_config::defaults(BehaviorVersion::latest()); + + if let Some(endpoint) = endpoint_uri { + config = config.endpoint_url(endpoint) + }; + + if properties.is_empty() { + return config.load().await; + } + + if let (Some(access_key), Some(secret_key)) = ( + properties.get(AWS_ACCESS_KEY_ID), + properties.get(AWS_SECRET_ACCESS_KEY), + ) { + let session_token = properties.get(AWS_SESSION_TOKEN).cloned(); + let credentials_provider = + Credentials::new(access_key, secret_key, session_token, None, "properties"); + + config = config.credentials_provider(credentials_provider) + }; + + if let Some(profile_name) = properties.get(AWS_PROFILE_NAME) { + config = config.profile_name(profile_name); + } + + if let Some(region_name) = properties.get(AWS_REGION_NAME) { + let region = Region::new(region_name.clone()); + config = config.region(region); + } + + config.load().await +} + +/// Create `DatabaseInput` from `NamespaceIdent` and properties +pub(crate) fn convert_to_database( + namespace: &NamespaceIdent, + properties: &HashMap, +) -> Result { + let db_name = validate_namespace(namespace)?; + let mut builder = DatabaseInput::builder().name(db_name); + + for (k, v) in properties.iter() { + match k.as_ref() { + DESCRIPTION => { + builder = builder.description(v); + } + LOCATION => { + builder = builder.location_uri(v); + } + _ => { + builder = builder.parameters(k, v); + } + } + } + + builder.build().map_err(from_aws_build_error) +} + +/// Create `Namespace` from aws sdk glue `Database` +pub(crate) fn convert_to_namespace(database: &Database) -> Namespace { + let db_name = database.name().to_string(); + let mut properties = database + .parameters() + .map_or_else(HashMap::new, |p| p.clone()); + + if let Some(location_uri) = database.location_uri() { + properties.insert(LOCATION.to_string(), location_uri.to_string()); + }; + + if let Some(description) = database.description() { + properties.insert(DESCRIPTION.to_string(), description.to_string()); + } + + Namespace::with_properties(NamespaceIdent::new(db_name), properties) +} + +/// Converts Iceberg table metadata into an +/// AWS Glue `TableInput` representation. +/// +/// This function facilitates the integration of Iceberg tables with AWS Glue +/// by converting Iceberg table metadata into a Glue-compatible `TableInput` +/// structure. 
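+///
+/// A rough call sketch, mirroring the unit test further below (all values are
+/// illustrative):
+///
+/// ```ignore
+/// let metadata_location = create_metadata_location("s3a://warehouse/hive", 0)?;
+/// let table_input = convert_to_glue_table(
+///     "my_table",
+///     metadata_location,
+///     &metadata,          // iceberg::spec::TableMetadata
+///     &HashMap::new(),    // an optional "description" is picked up from here
+///     None,               // previous metadata location, if any
+/// )?;
+/// ```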
+pub(crate) fn convert_to_glue_table( + table_name: impl Into, + metadata_location: String, + metadata: &TableMetadata, + properties: &HashMap, + prev_metadata_location: Option, +) -> Result { + let glue_schema = GlueSchemaBuilder::from_iceberg(metadata)?.build(); + + let storage_descriptor = StorageDescriptor::builder() + .set_columns(Some(glue_schema)) + .location(&metadata_location) + .build(); + + let mut parameters = HashMap::from([ + (TABLE_TYPE.to_string(), ICEBERG.to_string()), + (METADATA_LOCATION.to_string(), metadata_location), + ]); + + if let Some(prev) = prev_metadata_location { + parameters.insert(PREV_METADATA_LOCATION.to_string(), prev); + } + + let mut table_input_builder = TableInput::builder() + .name(table_name) + .set_parameters(Some(parameters)) + .storage_descriptor(storage_descriptor) + .table_type(EXTERNAL_TABLE); + + if let Some(description) = properties.get(DESCRIPTION) { + table_input_builder = table_input_builder.description(description); + } + + let table_input = table_input_builder.build().map_err(from_aws_build_error)?; + + Ok(table_input) +} + +/// Checks if provided `NamespaceIdent` is valid +pub(crate) fn validate_namespace(namespace: &NamespaceIdent) -> Result { + let name = namespace.as_ref(); + + if name.len() != 1 { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid database name: {:?}, hierarchical namespaces are not supported", + namespace + ), + )); + } + + let name = name[0].clone(); + + if name.is_empty() { + return Err(Error::new( + ErrorKind::DataInvalid, + "Invalid database, provided namespace is empty.", + )); + } + + Ok(name) +} + +/// Get default table location from `Namespace` properties +pub(crate) fn get_default_table_location( + namespace: &Namespace, + db_name: impl AsRef, + table_name: impl AsRef, + warehouse: impl AsRef, +) -> String { + let properties = namespace.properties(); + + match properties.get(LOCATION) { + Some(location) => format!("{}/{}", location, table_name.as_ref()), + None => { + let warehouse_location = warehouse.as_ref().trim_end_matches('/'); + + format!( + "{}/{}.db/{}", + warehouse_location, + db_name.as_ref(), + table_name.as_ref() + ) + } + } +} + +/// Create metadata location from `location` and `version` +pub(crate) fn create_metadata_location(location: impl AsRef, version: i32) -> Result { + if version < 0 { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Table metadata version: '{}' must be a non-negative integer", + version + ), + )); + }; + + let version = format!("{:0>5}", version); + let id = Uuid::new_v4(); + let metadata_location = format!( + "{}/metadata/{}-{}.metadata.json", + location.as_ref(), + version, + id + ); + + Ok(metadata_location) +} + +/// Get metadata location from `GlueTable` parameters +pub(crate) fn get_metadata_location( + parameters: &Option>, +) -> Result { + match parameters { + Some(properties) => match properties.get(METADATA_LOCATION) { + Some(location) => Ok(location.to_string()), + None => Err(Error::new( + ErrorKind::DataInvalid, + format!("No '{}' set on table", METADATA_LOCATION), + )), + }, + None => Err(Error::new( + ErrorKind::DataInvalid, + "No 'parameters' set on table. Location of metadata is undefined", + )), + } +} + +#[macro_export] +/// Extends aws sdk builder with `catalog_id` if present +macro_rules! 
with_catalog_id { + ($builder:expr, $config:expr) => {{ + if let Some(catalog_id) = &$config.catalog_id { + $builder.catalog_id(catalog_id) + } else { + $builder + } + }}; +} + +#[cfg(test)] +mod tests { + use aws_sdk_glue::config::ProvideCredentials; + use aws_sdk_glue::types::Column; + use iceberg::spec::{NestedField, PrimitiveType, Schema, TableMetadataBuilder, Type}; + use iceberg::{Namespace, Result, TableCreation}; + + use super::*; + use crate::schema::{ICEBERG_FIELD_CURRENT, ICEBERG_FIELD_ID, ICEBERG_FIELD_OPTIONAL}; + + fn create_metadata(schema: Schema) -> Result { + let table_creation = TableCreation::builder() + .name("my_table".to_string()) + .location("my_location".to_string()) + .schema(schema) + .build(); + let metadata = TableMetadataBuilder::from_table_creation(table_creation)?.build()?; + + Ok(metadata) + } + + #[test] + fn test_get_metadata_location() -> Result<()> { + let params_valid = Some(HashMap::from([( + METADATA_LOCATION.to_string(), + "my_location".to_string(), + )])); + let params_missing_key = Some(HashMap::from([( + "not_here".to_string(), + "my_location".to_string(), + )])); + + let result_valid = get_metadata_location(¶ms_valid)?; + let result_missing_key = get_metadata_location(¶ms_missing_key); + let result_no_params = get_metadata_location(&None); + + assert_eq!(result_valid, "my_location"); + assert!(result_missing_key.is_err()); + assert!(result_no_params.is_err()); + + Ok(()) + } + + #[test] + fn test_convert_to_glue_table() -> Result<()> { + let table_name = "my_table".to_string(); + let location = "s3a://warehouse/hive".to_string(); + let metadata_location = create_metadata_location(location.clone(), 0)?; + let properties = HashMap::new(); + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 1, + "foo", + Type::Primitive(PrimitiveType::Int), + ) + .into()]) + .build()?; + + let metadata = create_metadata(schema)?; + + let parameters = HashMap::from([ + (ICEBERG_FIELD_ID.to_string(), "1".to_string()), + (ICEBERG_FIELD_OPTIONAL.to_string(), "true".to_string()), + (ICEBERG_FIELD_CURRENT.to_string(), "true".to_string()), + ]); + + let column = Column::builder() + .name("foo") + .r#type("int") + .set_parameters(Some(parameters)) + .set_comment(None) + .build() + .map_err(from_aws_build_error)?; + + let storage_descriptor = StorageDescriptor::builder() + .set_columns(Some(vec![column])) + .location(&metadata_location) + .build(); + + let result = + convert_to_glue_table(&table_name, metadata_location, &metadata, &properties, None)?; + + assert_eq!(result.name(), &table_name); + assert_eq!(result.description(), None); + assert_eq!(result.storage_descriptor, Some(storage_descriptor)); + + Ok(()) + } + + #[test] + fn test_create_metadata_location() -> Result<()> { + let location = "my_base_location"; + let valid_version = 0; + let invalid_version = -1; + + let valid_result = create_metadata_location(location, valid_version)?; + let invalid_result = create_metadata_location(location, invalid_version); + + assert!(valid_result.starts_with("my_base_location/metadata/00000-")); + assert!(valid_result.ends_with(".metadata.json")); + assert!(invalid_result.is_err()); + + Ok(()) + } + + #[test] + fn test_get_default_table_location() -> Result<()> { + let properties = HashMap::from([(LOCATION.to_string(), "db_location".to_string())]); + + let namespace = + Namespace::with_properties(NamespaceIdent::new("default".into()), properties); + let db_name = validate_namespace(namespace.name())?; + let table_name = "my_table"; + 
+ let expected = "db_location/my_table"; + let result = + get_default_table_location(&namespace, db_name, table_name, "warehouse_location"); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_get_default_table_location_warehouse() -> Result<()> { + let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let db_name = validate_namespace(namespace.name())?; + let table_name = "my_table"; + + let expected = "warehouse_location/default.db/my_table"; + let result = + get_default_table_location(&namespace, db_name, table_name, "warehouse_location"); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_convert_to_namespace() -> Result<()> { + let db = Database::builder() + .name("my_db") + .location_uri("my_location") + .description("my_description") + .build() + .map_err(from_aws_build_error)?; + + let properties = HashMap::from([ + (DESCRIPTION.to_string(), "my_description".to_string()), + (LOCATION.to_string(), "my_location".to_string()), + ]); + + let expected = + Namespace::with_properties(NamespaceIdent::new("my_db".to_string()), properties); + let result = convert_to_namespace(&db); + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_convert_to_database() -> Result<()> { + let namespace = NamespaceIdent::new("my_database".to_string()); + let properties = HashMap::from([(LOCATION.to_string(), "my_location".to_string())]); + + let result = convert_to_database(&namespace, &properties)?; + + assert_eq!("my_database", result.name()); + assert_eq!(Some("my_location".to_string()), result.location_uri); + + Ok(()) + } + + #[test] + fn test_validate_namespace() { + let valid_ns = Namespace::new(NamespaceIdent::new("ns".to_string())); + let empty_ns = Namespace::new(NamespaceIdent::new("".to_string())); + let hierarchical_ns = Namespace::new( + NamespaceIdent::from_vec(vec!["level1".to_string(), "level2".to_string()]).unwrap(), + ); + + let valid = validate_namespace(valid_ns.name()); + let empty = validate_namespace(empty_ns.name()); + let hierarchical = validate_namespace(hierarchical_ns.name()); + + assert!(valid.is_ok()); + assert!(empty.is_err()); + assert!(hierarchical.is_err()); + } + + #[tokio::test] + async fn test_config_with_custom_endpoint() { + let properties = HashMap::new(); + let endpoint_url = "http://custom_url:5000"; + + let sdk_config = create_sdk_config(&properties, Some(&endpoint_url.to_string())).await; + + let result = sdk_config.endpoint_url().unwrap(); + + assert_eq!(result, endpoint_url); + } + + #[tokio::test] + async fn test_config_with_properties() { + let properties = HashMap::from([ + (AWS_PROFILE_NAME.to_string(), "my_profile".to_string()), + (AWS_REGION_NAME.to_string(), "us-east-1".to_string()), + (AWS_ACCESS_KEY_ID.to_string(), "my-access-id".to_string()), + ( + AWS_SECRET_ACCESS_KEY.to_string(), + "my-secret-key".to_string(), + ), + (AWS_SESSION_TOKEN.to_string(), "my-token".to_string()), + ]); + + let sdk_config = create_sdk_config(&properties, None).await; + + let region = sdk_config.region().unwrap().as_ref(); + let credentials = sdk_config + .credentials_provider() + .unwrap() + .provide_credentials() + .await + .unwrap(); + + assert_eq!("us-east-1", region); + assert_eq!("my-access-id", credentials.access_key_id()); + assert_eq!("my-secret-key", credentials.secret_access_key()); + assert_eq!("my-token", credentials.session_token().unwrap()); + } +} diff --git a/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml b/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml 
new file mode 100644 index 000000000..0a2c938a7 --- /dev/null +++ b/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +services: + minio: + image: minio/minio:RELEASE.2024-03-07T00-43-48Z + expose: + - 9000 + - 9001 + environment: + - MINIO_ROOT_USER=admin + - MINIO_ROOT_PASSWORD=password + - MINIO_DOMAIN=minio + command: [ "server", "/data", "--console-address", ":9001" ] + + mc: + depends_on: + - minio + image: minio/mc:RELEASE.2024-03-07T00-31-49Z + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + entrypoint: > + /bin/sh -c " until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; /usr/bin/mc mb minio/warehouse; /usr/bin/mc policy set public minio/warehouse; tail -f /dev/null " + + moto: + image: motoserver/moto:5.0.3 + expose: + - 5000 diff --git a/crates/catalog/glue/tests/glue_catalog_test.rs b/crates/catalog/glue/tests/glue_catalog_test.rs new file mode 100644 index 000000000..d9c5b4e0b --- /dev/null +++ b/crates/catalog/glue/tests/glue_catalog_test.rs @@ -0,0 +1,367 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for glue catalog. 
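+//!
+//! These tests run against the docker-compose stack in `testdata/glue_catalog`:
+//! a `moto` container emulating the Glue API (port 5000) and a `minio`
+//! container providing S3 (port 9000). A condensed sketch of the catalog
+//! construction performed by `get_catalog` below (the actual endpoints depend
+//! on the container IPs resolved at runtime):
+//!
+//! ```ignore
+//! let config = GlueCatalogConfig::builder()
+//!     .uri("http://<moto-ip>:5000".to_string())
+//!     .warehouse("s3a://warehouse/hive".to_string())
+//!     .props(props) // AWS_* credentials plus S3_* settings for the warehouse
+//!     .build();
+//! let catalog = GlueCatalog::new(config).await?;
+//! ```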
+ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::RwLock; + +use ctor::{ctor, dtor}; +use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY}; +use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; +use iceberg::{Catalog, Namespace, NamespaceIdent, Result, TableCreation, TableIdent}; +use iceberg_catalog_glue::{ + GlueCatalog, GlueCatalogConfig, AWS_ACCESS_KEY_ID, AWS_REGION_NAME, AWS_SECRET_ACCESS_KEY, +}; +use iceberg_test_utils::docker::DockerCompose; +use iceberg_test_utils::{normalize_test_name, set_up}; +use port_scanner::scan_port_addr; +use tokio::time::sleep; + +const GLUE_CATALOG_PORT: u16 = 5000; +const MINIO_PORT: u16 = 9000; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); + +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + let docker_compose = DockerCompose::new( + normalize_test_name(module_path!()), + format!("{}/testdata/glue_catalog", env!("CARGO_MANIFEST_DIR")), + ); + docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +async fn get_catalog() -> GlueCatalog { + set_up(); + + let (glue_catalog_ip, minio_ip) = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + ( + docker_compose.get_container_ip("moto"), + docker_compose.get_container_ip("minio"), + ) + }; + let glue_socket_addr = SocketAddr::new(glue_catalog_ip, GLUE_CATALOG_PORT); + let minio_socket_addr = SocketAddr::new(minio_ip, MINIO_PORT); + while !scan_port_addr(glue_socket_addr) { + log::info!("Waiting for 1s glue catalog to ready..."); + sleep(std::time::Duration::from_millis(1000)).await; + } + + let props = HashMap::from([ + (AWS_ACCESS_KEY_ID.to_string(), "my_access_id".to_string()), + ( + AWS_SECRET_ACCESS_KEY.to_string(), + "my_secret_key".to_string(), + ), + (AWS_REGION_NAME.to_string(), "us-east-1".to_string()), + ( + S3_ENDPOINT.to_string(), + format!("http://{}", minio_socket_addr), + ), + (S3_ACCESS_KEY_ID.to_string(), "admin".to_string()), + (S3_SECRET_ACCESS_KEY.to_string(), "password".to_string()), + (S3_REGION.to_string(), "us-east-1".to_string()), + ]); + + let config = GlueCatalogConfig::builder() + .uri(format!("http://{}", glue_socket_addr)) + .warehouse("s3a://warehouse/hive".to_string()) + .props(props.clone()) + .build(); + + GlueCatalog::new(config).await.unwrap() +} + +async fn set_test_namespace(catalog: &GlueCatalog, namespace: &NamespaceIdent) -> Result<()> { + let properties = HashMap::new(); + catalog.create_namespace(namespace, properties).await?; + + Ok(()) +} + +fn set_table_creation(location: impl ToString, name: impl ToString) -> Result { + let schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?; + + let creation = TableCreation::builder() + .location(location.to_string()) + .name(name.to_string()) + .properties(HashMap::new()) + .schema(schema) + .build(); + + Ok(creation) +} + +#[tokio::test] +async fn test_rename_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_rename_table".into())); + + catalog + .create_namespace(namespace.name(), HashMap::new()) + .await?; + + let table = 
catalog.create_table(namespace.name(), creation).await?; + + let dest = TableIdent::new(namespace.name().clone(), "my_table_rename".to_string()); + + catalog.rename_table(table.identifier(), &dest).await?; + + let table = catalog.load_table(&dest).await?; + assert_eq!(table.identifier(), &dest); + + let src = TableIdent::new(namespace.name().clone(), "my_table".to_string()); + + let src_table_exists = catalog.table_exists(&src).await?; + assert!(!src_table_exists); + + Ok(()) +} + +#[tokio::test] +async fn test_table_exists() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_table_exists".into())); + + catalog + .create_namespace(namespace.name(), HashMap::new()) + .await?; + + let ident = TableIdent::new(namespace.name().clone(), "my_table".to_string()); + + let exists = catalog.table_exists(&ident).await?; + assert!(!exists); + + let table = catalog.create_table(namespace.name(), creation).await?; + + let exists = catalog.table_exists(table.identifier()).await?; + + assert!(exists); + + Ok(()) +} + +#[tokio::test] +async fn test_drop_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_drop_table".into())); + + catalog + .create_namespace(namespace.name(), HashMap::new()) + .await?; + + let table = catalog.create_table(namespace.name(), creation).await?; + + catalog.drop_table(table.identifier()).await?; + + let result = catalog.table_exists(table.identifier()).await?; + + assert!(!result); + + Ok(()) +} + +#[tokio::test] +async fn test_load_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_load_table".into())); + + catalog + .create_namespace(namespace.name(), HashMap::new()) + .await?; + + let expected = catalog.create_table(namespace.name(), creation).await?; + + let result = catalog + .load_table(&TableIdent::new( + namespace.name().clone(), + "my_table".to_string(), + )) + .await?; + + assert_eq!(result.identifier(), expected.identifier()); + assert_eq!(result.metadata_location(), expected.metadata_location()); + assert_eq!(result.metadata(), expected.metadata()); + + Ok(()) +} + +#[tokio::test] +async fn test_create_table() -> Result<()> { + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_create_table".to_string()); + set_test_namespace(&catalog, &namespace).await?; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + + let result = catalog.create_table(&namespace, creation).await?; + + assert_eq!(result.identifier().name(), "my_table"); + assert!(result + .metadata_location() + .is_some_and(|location| location.starts_with("s3a://warehouse/hive/metadata/00000-"))); + assert!( + catalog + .file_io() + .is_exist("s3a://warehouse/hive/metadata/") + .await? 
+ ); + + Ok(()) +} + +#[tokio::test] +async fn test_list_tables() -> Result<()> { + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_list_tables".to_string()); + set_test_namespace(&catalog, &namespace).await?; + + let expected = vec![]; + let result = catalog.list_tables(&namespace).await?; + + assert_eq!(result, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_drop_namespace() -> Result<()> { + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_drop_namespace".to_string()); + set_test_namespace(&catalog, &namespace).await?; + + let exists = catalog.namespace_exists(&namespace).await?; + assert!(exists); + + catalog.drop_namespace(&namespace).await?; + + let exists = catalog.namespace_exists(&namespace).await?; + assert!(!exists); + + Ok(()) +} + +#[tokio::test] +async fn test_update_namespace() -> Result<()> { + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_update_namespace".into()); + set_test_namespace(&catalog, &namespace).await?; + + let before_update = catalog.get_namespace(&namespace).await?; + let before_update = before_update.properties().get("description"); + + assert_eq!(before_update, None); + + let properties = HashMap::from([("description".to_string(), "my_update".to_string())]); + + catalog.update_namespace(&namespace, properties).await?; + + let after_update = catalog.get_namespace(&namespace).await?; + let after_update = after_update.properties().get("description"); + + assert_eq!(after_update, Some("my_update".to_string()).as_ref()); + + Ok(()) +} + +#[tokio::test] +async fn test_namespace_exists() -> Result<()> { + let catalog = get_catalog().await; + + let namespace = NamespaceIdent::new("test_namespace_exists".into()); + + let exists = catalog.namespace_exists(&namespace).await?; + assert!(!exists); + + set_test_namespace(&catalog, &namespace).await?; + + let exists = catalog.namespace_exists(&namespace).await?; + assert!(exists); + + Ok(()) +} + +#[tokio::test] +async fn test_get_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let namespace = NamespaceIdent::new("test_get_namespace".into()); + + let does_not_exist = catalog.get_namespace(&namespace).await; + assert!(does_not_exist.is_err()); + + set_test_namespace(&catalog, &namespace).await?; + + let result = catalog.get_namespace(&namespace).await?; + let expected = Namespace::new(namespace); + + assert_eq!(result, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_create_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let properties = HashMap::new(); + let namespace = NamespaceIdent::new("test_create_namespace".into()); + + let expected = Namespace::new(namespace.clone()); + + let result = catalog.create_namespace(&namespace, properties).await?; + + assert_eq!(result, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_list_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let namespace = NamespaceIdent::new("test_list_namespace".to_string()); + set_test_namespace(&catalog, &namespace).await?; + + let result = catalog.list_namespaces(None).await?; + assert!(result.contains(&namespace)); + + let empty_result = catalog.list_namespaces(Some(&namespace)).await?; + assert!(empty_result.is_empty()); + + Ok(()) +} diff --git a/crates/catalog/hms/Cargo.toml b/crates/catalog/hms/Cargo.toml index 61c03fddf..e7d4ec2f3 100644 --- a/crates/catalog/hms/Cargo.toml +++ b/crates/catalog/hms/Cargo.toml @@ -17,23 +17,32 @@ [package] name = "iceberg-catalog-hms" 
-version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } categories = ["database"] description = "Apache Iceberg Hive Metastore Catalog Support" -repository = "https://github.com/apache/iceberg-rust" -license = "Apache-2.0" +repository = { workspace = true } +license = { workspace = true } keywords = ["iceberg", "hive", "catalog"] [dependencies] +anyhow = { workspace = true } async-trait = { workspace = true } -hive_metastore = "0.0.1" +chrono = { workspace = true } +hive_metastore = { workspace = true } iceberg = { workspace = true } -# the thrift upstream suffered from no regular rust release. -# -# [test-rs](https://github.com/tent-rs) is an organization that helps resolves this -# issue. And [tent-thrift](https://github.com/tent-rs/thrift) is a fork of the thrift -# crate, built from the thrift upstream with only version bumped. -thrift = { package = "tent-thrift", version = "0.18.1" } +log = { workspace = true } +pilota = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } typed-builder = { workspace = true } +uuid = { workspace = true } +volo-thrift = { workspace = true } + +[dev-dependencies] +ctor = { workspace = true } +iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } +port_scanner = { workspace = true } diff --git a/crates/catalog/hms/DEPENDENCIES.rust.tsv b/crates/catalog/hms/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..f54295ca5 --- /dev/null +++ b/crates/catalog/hms/DEPENDENCIES.rust.tsv @@ -0,0 +1,328 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X +adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayref@0.3.8 X +arrayvec@0.7.4 X X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-ord@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-broadcast@0.7.1 X X +async-recursion@1.1.1 X X +async-trait@0.1.81 X X +atoi@2.0.0 X +autocfg@1.3.0 X X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.22.1 X X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitflags@2.6.0 X X +bitvec@1.0.1 X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +cc@1.1.11 X X +cfg-if@1.0.0 X X +cfg_aliases@0.1.1 X +chrono@0.4.38 X X +colorchoice@1.0.2 X X +concurrent-queue@2.5.0 X X +const-oid@0.9.6 X X +const-random@0.1.18 X X +const-random-macro@0.1.16 X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crossbeam-utils@0.8.20 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +dashmap@5.5.3 X +derivative@2.2.0 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +equivalent@1.0.1 X X 
+event-listener@5.3.1 X X +event-listener-strategy@0.5.2 X X +fastrand@2.1.0 X X +faststr@0.2.21 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hive_metastore@0.1.0 X +hmac@0.12.1 X X +home@0.5.9 X X +http@1.1.0 X X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +humantime@2.1.0 X X +hyper@1.4.1 X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-hms@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X +idna@0.5.0 X X +indexmap@2.4.0 X X +integer-encoding@3.0.4 X +integer-encoding@4.0.2 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lazy_static@1.5.0 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +linked-hash-map@0.5.6 X X +linkedbytes@0.1.8 X X +lock_api@0.4.12 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +md-5@0.10.6 X X +memchr@2.7.4 X X +memoffset@0.9.1 X +metainfo@0.7.12 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +motore@0.4.1 X X +motore-macros@0.4.1 X X +mur3@0.1.0 X +murmur3@0.5.2 X X +nix@0.28.0 X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +num_enum@0.7.3 X X X +num_enum_derive@0.7.3 X X X +object@0.36.3 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +page_size@0.6.0 X X +parking@2.2.0 X X +parking_lot@0.12.3 X X +parking_lot_core@0.9.10 X X +parquet@52.2.0 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +pilota@0.11.3 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +ppv-lite86@0.2.20 X X +proc-macro-crate@3.1.0 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X +radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +redox_syscall@0.5.3 X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc-hash@2.0.0 X X +rustc_version@0.4.0 X X +rustls@0.23.12 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +scopeguard@1.2.0 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +signal-hook-registry@1.4.2 X X +simdutf8@0.1.4 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snap@1.1.1 X +socket2@0.5.7 X X +sonic-rs@0.3.10 X +spin@0.9.8 X 
+static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@1.0.109 X X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.26.0 X X +tokio-stream@0.1.15 X +tokio-util@0.7.11 X +toml_datetime@0.6.8 X X +toml_edit@0.21.1 X X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-attributes@0.1.27 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +untrusted@0.9.0 X +url@2.5.2 X X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +volo@0.10.1 X X +volo-thrift@0.10.2 X X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +winapi@0.3.9 X X +winapi-i686-pc-windows-gnu@0.4.0 X X +winapi-x86_64-pc-windows-gnu@0.4.0 X X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winnow@0.5.40 X +winreg@0.52.0 X +wyz@0.5.1 X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/catalog/hms/README.md b/crates/catalog/hms/README.md new file mode 100644 index 000000000..bebb2200a --- /dev/null +++ b/crates/catalog/hms/README.md @@ -0,0 +1,27 @@ + + +# Apache Iceberg HiveMetaStore Catalog Official Native Rust Implementation + +[![crates.io](https://img.shields.io/crates/v/iceberg.svg)](https://crates.io/crates/iceberg-catalog-hms) +[![docs.rs](https://img.shields.io/docsrs/iceberg.svg)](https://docs.rs/iceberg/latest/iceberg-catalog-hms/) + +This crate contains the official Native Rust implementation of Apache Iceberg HiveMetaStore Catalog. + +See the [API documentation](https://docs.rs/iceberg-catalog-hms/latest) for examples and the full API. diff --git a/crates/catalog/hms/src/catalog.rs b/crates/catalog/hms/src/catalog.rs index 2b1fe2cc4..6e5db1968 100644 --- a/crates/catalog/hms/src/catalog.rs +++ b/crates/catalog/hms/src/catalog.rs @@ -15,49 +15,57 @@ // specific language governing permissions and limitations // under the License. 
-use super::utils::*; -use async_trait::async_trait; -use hive_metastore::{TThriftHiveMetastoreSyncClient, ThriftHiveMetastoreSyncClient}; -use iceberg::table::Table; -use iceberg::{Catalog, Namespace, NamespaceIdent, Result, TableCommit, TableCreation, TableIdent}; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::{Arc, Mutex}; -use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol}; -use thrift::transport::{ - ReadHalf, TBufferedReadTransport, TBufferedWriteTransport, TIoChannel, WriteHalf, +use std::net::ToSocketAddrs; + +use anyhow::anyhow; +use async_trait::async_trait; +use hive_metastore::{ + ThriftHiveMetastoreClient, ThriftHiveMetastoreClientBuilder, + ThriftHiveMetastoreGetDatabaseException, ThriftHiveMetastoreGetTableException, +}; +use iceberg::io::FileIO; +use iceberg::spec::{TableMetadata, TableMetadataBuilder}; +use iceberg::table::Table; +use iceberg::{ + Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, TableCreation, + TableIdent, }; use typed_builder::TypedBuilder; +use volo_thrift::MaybeException; + +use super::utils::*; +use crate::error::{from_io_error, from_thrift_error, from_thrift_exception}; + +/// Which variant of the thrift transport to communicate with HMS +/// See: +#[derive(Debug, Default)] +pub enum HmsThriftTransport { + /// Use the framed transport + Framed, + /// Use the buffered transport (default) + #[default] + Buffered, +} /// Hive metastore Catalog configuration. #[derive(Debug, TypedBuilder)] pub struct HmsCatalogConfig { address: String, + thrift_transport: HmsThriftTransport, + warehouse: String, + #[builder(default)] + props: HashMap, } -/// TODO: We only support binary protocol for now. -type HmsClientType = ThriftHiveMetastoreSyncClient< - TBinaryInputProtocol>>, - TBinaryOutputProtocol>>, ->; - -/// # TODO -/// -/// we are using the same connection everytime, we should support connection -/// pool in the future. -struct HmsClient(Arc>); - -impl HmsClient { - fn call(&self, f: impl FnOnce(&mut HmsClientType) -> thrift::Result) -> Result { - let mut client = self.0.lock().unwrap(); - f(&mut client).map_err(from_thrift_error) - } -} +struct HmsClient(ThriftHiveMetastoreClient); /// Hive metastore Catalog. pub struct HmsCatalog { config: HmsCatalogConfig, client: HmsClient, + file_io: FileIO, } impl Debug for HmsCatalog { @@ -71,24 +79,46 @@ impl Debug for HmsCatalog { impl HmsCatalog { /// Create a new hms catalog. pub fn new(config: HmsCatalogConfig) -> Result { - let mut channel = thrift::transport::TTcpChannel::new(); - channel - .open(config.address.as_str()) - .map_err(from_thrift_error)?; - let (i_chan, o_chan) = channel.split().map_err(from_thrift_error)?; - let i_chan = TBufferedReadTransport::new(i_chan); - let o_chan = TBufferedWriteTransport::new(o_chan); - let i_proto = TBinaryInputProtocol::new(i_chan, true); - let o_proto = TBinaryOutputProtocol::new(o_chan, true); - let client = ThriftHiveMetastoreSyncClient::new(i_proto, o_proto); + let address = config + .address + .as_str() + .to_socket_addrs() + .map_err(from_io_error)? 
+ .next() + .ok_or_else(|| { + Error::new( + ErrorKind::Unexpected, + format!("invalid address: {}", config.address), + ) + })?; + + let builder = ThriftHiveMetastoreClientBuilder::new("hms").address(address); + + let client = match &config.thrift_transport { + HmsThriftTransport::Framed => builder + .make_codec(volo_thrift::codec::default::DefaultMakeCodec::framed()) + .build(), + HmsThriftTransport::Buffered => builder + .make_codec(volo_thrift::codec::default::DefaultMakeCodec::buffered()) + .build(), + }; + + let file_io = FileIO::from_path(&config.warehouse)? + .with_props(&config.props) + .build()?; + Ok(Self { config, - client: HmsClient(Arc::new(Mutex::new(client))), + client: HmsClient(client), + file_io, }) } + /// Get the catalogs `FileIO` + pub fn file_io(&self) -> FileIO { + self.file_io.clone() + } } -/// Refer to for implementation details. #[async_trait] impl Catalog for HmsCatalog { /// HMS doesn't support nested namespaces. @@ -103,69 +133,377 @@ impl Catalog for HmsCatalog { let dbs = if parent.is_some() { return Ok(vec![]); } else { - self.client.call(|client| client.get_all_databases())? + self.client + .0 + .get_all_databases() + .await + .map(from_thrift_exception) + .map_err(from_thrift_error)?? }; - Ok(dbs.into_iter().map(NamespaceIdent::new).collect()) + Ok(dbs + .into_iter() + .map(|v| NamespaceIdent::new(v.into())) + .collect()) } + /// Creates a new namespace with the given identifier and properties. + /// + /// Attempts to create a namespace defined by the `namespace` + /// parameter and configured with the specified `properties`. + /// + /// This function can return an error in the following situations: + /// + /// - If `hive.metastore.database.owner-type` is specified without + /// `hive.metastore.database.owner`, + /// - Errors from `validate_namespace` if the namespace identifier does not + /// meet validation criteria. + /// - Errors from `convert_to_database` if the properties cannot be + /// successfully converted into a database configuration. + /// - Errors from the underlying database creation process, converted using + /// `from_thrift_error`. async fn create_namespace( &self, - _namespace: &NamespaceIdent, - _properties: HashMap, + namespace: &NamespaceIdent, + properties: HashMap, ) -> Result { - todo!() + let database = convert_to_database(namespace, &properties)?; + + self.client + .0 + .create_database(database) + .await + .map_err(from_thrift_error)?; + + Ok(Namespace::with_properties(namespace.clone(), properties)) } - async fn get_namespace(&self, _namespace: &NamespaceIdent) -> Result { - todo!() + /// Retrieves a namespace by its identifier. + /// + /// Validates the given namespace identifier and then queries the + /// underlying database client to fetch the corresponding namespace data. + /// Constructs a `Namespace` object with the retrieved data and returns it. + /// + /// This function can return an error in any of the following situations: + /// - If the provided namespace identifier fails validation checks + /// - If there is an error querying the database, returned by + /// `from_thrift_error`. 
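+ ///
+ /// A quick usage sketch (the namespace identifier is illustrative):
+ ///
+ /// ```ignore
+ /// let ns = catalog
+ ///     .get_namespace(&NamespaceIdent::new("default".to_string()))
+ ///     .await?;
+ /// println!("namespace properties: {:?}", ns.properties());
+ /// ```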
+ async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { + let name = validate_namespace(namespace)?; + + let db = self + .client + .0 + .get_database(name.into()) + .await + .map(from_thrift_exception) + .map_err(from_thrift_error)??; + + let ns = convert_to_namespace(&db)?; + + Ok(ns) } - async fn namespace_exists(&self, _namespace: &NamespaceIdent) -> Result { - todo!() + /// Checks if a namespace exists within the Hive Metastore. + /// + /// Validates the namespace identifier by querying the Hive Metastore + /// to determine if the specified namespace (database) exists. + /// + /// # Returns + /// A `Result` indicating the outcome of the check: + /// - `Ok(true)` if the namespace exists. + /// - `Ok(false)` if the namespace does not exist, identified by a specific + /// `UserException` variant. + /// - `Err(...)` if an error occurs during validation or the Hive Metastore + /// query, with the error encapsulating the issue. + async fn namespace_exists(&self, namespace: &NamespaceIdent) -> Result { + let name = validate_namespace(namespace)?; + + let resp = self.client.0.get_database(name.into()).await; + + match resp { + Ok(MaybeException::Ok(_)) => Ok(true), + Ok(MaybeException::Exception(ThriftHiveMetastoreGetDatabaseException::O1(_))) => { + Ok(false) + } + Ok(MaybeException::Exception(exception)) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", exception))), + Err(err) => Err(from_thrift_error(err)), + } } + /// Asynchronously updates properties of an existing namespace. + /// + /// Converts the given namespace identifier and properties into a database + /// representation and then attempts to update the corresponding namespace + /// in the Hive Metastore. + /// + /// # Returns + /// Returns `Ok(())` if the namespace update is successful. If the + /// namespace cannot be updated due to missing information or an error + /// during the update process, an `Err(...)` is returned. async fn update_namespace( &self, - _namespace: &NamespaceIdent, - _properties: HashMap, + namespace: &NamespaceIdent, + properties: HashMap, ) -> Result<()> { - todo!() + let db = convert_to_database(namespace, &properties)?; + + let name = match &db.name { + Some(name) => name, + None => { + return Err(Error::new( + ErrorKind::DataInvalid, + "Database name must be specified", + )) + } + }; + + self.client + .0 + .alter_database(name.clone(), db) + .await + .map_err(from_thrift_error)?; + + Ok(()) } - async fn drop_namespace(&self, _namespace: &NamespaceIdent) -> Result<()> { - todo!() + /// Asynchronously drops a namespace from the Hive Metastore. + /// + /// # Returns + /// A `Result<()>` indicating the outcome: + /// - `Ok(())` signifies successful namespace deletion. + /// - `Err(...)` signifies failure to drop the namespace due to validation + /// errors, connectivity issues, or Hive Metastore constraints. + async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> { + let name = validate_namespace(namespace)?; + + self.client + .0 + .drop_database(name.into(), false, false) + .await + .map_err(from_thrift_error)?; + + Ok(()) } - async fn list_tables(&self, _namespace: &NamespaceIdent) -> Result> { - todo!() + /// Asynchronously lists all tables within a specified namespace. + /// + /// # Returns + /// + /// A `Result>`, which is: + /// - `Ok(vec![...])` containing a vector of `TableIdent` instances, each + /// representing a table within the specified namespace. 
+ /// - `Err(...)` if an error occurs during namespace validation or while + /// querying the database. + async fn list_tables(&self, namespace: &NamespaceIdent) -> Result> { + let name = validate_namespace(namespace)?; + + let tables = self + .client + .0 + .get_all_tables(name.into()) + .await + .map(from_thrift_exception) + .map_err(from_thrift_error)??; + + let tables = tables + .iter() + .map(|table| TableIdent::new(namespace.clone(), table.to_string())) + .collect(); + + Ok(tables) } + /// Creates a new table within a specified namespace using the provided + /// table creation settings. + /// + /// # Returns + /// A `Result` wrapping a `Table` object representing the newly created + /// table. + /// + /// # Errors + /// This function may return an error in several cases, including invalid + /// namespace identifiers, failure to determine a default storage location, + /// issues generating or writing table metadata, and errors communicating + /// with the Hive Metastore. async fn create_table( &self, - _namespace: &NamespaceIdent, - _creation: TableCreation, + namespace: &NamespaceIdent, + creation: TableCreation, ) -> Result
{ - todo!() + let db_name = validate_namespace(namespace)?; + let table_name = creation.name.clone(); + + let location = match &creation.location { + Some(location) => location.clone(), + None => { + let ns = self.get_namespace(namespace).await?; + get_default_table_location(&ns, &table_name, &self.config.warehouse) + } + }; + + let metadata = TableMetadataBuilder::from_table_creation(creation)?.build()?; + let metadata_location = create_metadata_location(&location, 0)?; + + self.file_io + .new_output(&metadata_location)? + .write(serde_json::to_vec(&metadata)?.into()) + .await?; + + let hive_table = convert_to_hive_table( + db_name.clone(), + metadata.current_schema(), + table_name.clone(), + location, + metadata_location.clone(), + metadata.properties(), + )?; + + self.client + .0 + .create_table(hive_table) + .await + .map_err(from_thrift_error)?; + + Table::builder() + .file_io(self.file_io()) + .metadata_location(metadata_location) + .metadata(metadata) + .identifier(TableIdent::new(NamespaceIdent::new(db_name), table_name)) + .build() } - async fn load_table(&self, _table: &TableIdent) -> Result
{ - todo!() + /// Loads a table from the Hive Metastore and constructs a `Table` object + /// based on its metadata. + /// + /// # Returns + /// A `Result` wrapping a `Table` object that represents the loaded table. + /// + /// # Errors + /// This function may return an error in several scenarios, including: + /// - Failure to validate the namespace. + /// - Failure to retrieve the table from the Hive Metastore. + /// - Absence of metadata location information in the table's properties. + /// - Issues reading or deserializing the table's metadata file. + async fn load_table(&self, table: &TableIdent) -> Result
{ + let db_name = validate_namespace(table.namespace())?; + + let hive_table = self + .client + .0 + .get_table(db_name.clone().into(), table.name.clone().into()) + .await + .map(from_thrift_exception) + .map_err(from_thrift_error)??; + + let metadata_location = get_metadata_location(&hive_table.parameters)?; + + let metadata_content = self.file_io.new_input(&metadata_location)?.read().await?; + let metadata = serde_json::from_slice::(&metadata_content)?; + + Table::builder() + .file_io(self.file_io()) + .metadata_location(metadata_location) + .metadata(metadata) + .identifier(TableIdent::new( + NamespaceIdent::new(db_name), + table.name.clone(), + )) + .build() } - async fn drop_table(&self, _table: &TableIdent) -> Result<()> { - todo!() + /// Asynchronously drops a table from the database. + /// + /// # Errors + /// Returns an error if: + /// - The namespace provided in `table` cannot be validated + /// or does not exist. + /// - The underlying database client encounters an error while + /// attempting to drop the table. This includes scenarios where + /// the table does not exist. + /// - Any network or communication error occurs with the database backend. + async fn drop_table(&self, table: &TableIdent) -> Result<()> { + let db_name = validate_namespace(table.namespace())?; + + self.client + .0 + .drop_table(db_name.into(), table.name.clone().into(), false) + .await + .map_err(from_thrift_error)?; + + Ok(()) } - async fn stat_table(&self, _table: &TableIdent) -> Result { - todo!() + /// Asynchronously checks the existence of a specified table + /// in the database. + /// + /// # Returns + /// - `Ok(true)` if the table exists in the database. + /// - `Ok(false)` if the table does not exist in the database. + /// - `Err(...)` if an error occurs during the process + async fn table_exists(&self, table: &TableIdent) -> Result { + let db_name = validate_namespace(table.namespace())?; + let table_name = table.name.clone(); + + let resp = self + .client + .0 + .get_table(db_name.into(), table_name.into()) + .await; + + match resp { + Ok(MaybeException::Ok(_)) => Ok(true), + Ok(MaybeException::Exception(ThriftHiveMetastoreGetTableException::O2(_))) => Ok(false), + Ok(MaybeException::Exception(exception)) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", exception))), + Err(err) => Err(from_thrift_error(err)), + } } - async fn rename_table(&self, _src: &TableIdent, _dest: &TableIdent) -> Result<()> { - todo!() + /// Asynchronously renames a table within the database + /// or moves it between namespaces (databases). + /// + /// # Returns + /// - `Ok(())` on successful rename or move of the table. + /// - `Err(...)` if an error occurs during the process. + async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> { + let src_dbname = validate_namespace(src.namespace())?; + let dest_dbname = validate_namespace(dest.namespace())?; + + let src_tbl_name = src.name.clone(); + let dest_tbl_name = dest.name.clone(); + + let mut tbl = self + .client + .0 + .get_table(src_dbname.clone().into(), src_tbl_name.clone().into()) + .await + .map(from_thrift_exception) + .map_err(from_thrift_error)??; + + tbl.db_name = Some(dest_dbname.into()); + tbl.table_name = Some(dest_tbl_name.into()); + + self.client + .0 + .alter_table(src_dbname.into(), src_tbl_name.into(), tbl) + .await + .map_err(from_thrift_error)?; + + Ok(()) } async fn update_table(&self, _commit: TableCommit) -> Result
{ - todo!() + Err(Error::new( + ErrorKind::FeatureUnsupported, + "Updating a table is not supported yet", + )) } } diff --git a/crates/catalog/hms/src/error.rs b/crates/catalog/hms/src/error.rs new file mode 100644 index 000000000..15da3eaf6 --- /dev/null +++ b/crates/catalog/hms/src/error.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::io; + +use anyhow::anyhow; +use iceberg::{Error, ErrorKind}; +use volo_thrift::MaybeException; + +/// Format a thrift error into iceberg error. +/// +/// Please only throw this error when you are sure that the error is caused by thrift. +pub fn from_thrift_error(error: impl std::error::Error) -> Error { + Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", error)) +} + +/// Format a thrift exception into iceberg error. +pub fn from_thrift_exception(value: MaybeException) -> Result { + match value { + MaybeException::Ok(v) => Ok(v), + MaybeException::Exception(err) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", err))), + } +} + +/// Format an io error into iceberg error. +pub fn from_io_error(error: io::Error) -> Error { + Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting io error".to_string(), + ) + .with_source(error) +} diff --git a/crates/catalog/hms/src/lib.rs b/crates/catalog/hms/src/lib.rs index b75e74977..db0034d46 100644 --- a/crates/catalog/hms/src/lib.rs +++ b/crates/catalog/hms/src/lib.rs @@ -22,4 +22,6 @@ mod catalog; pub use catalog::*; +mod error; +mod schema; mod utils; diff --git a/crates/catalog/hms/src/schema.rs b/crates/catalog/hms/src/schema.rs new file mode 100644 index 000000000..4012098c2 --- /dev/null +++ b/crates/catalog/hms/src/schema.rs @@ -0,0 +1,460 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
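+
+// Note: `HiveSchemaBuilder` below walks an Iceberg schema with `visit_schema`
+// and renders each top-level field as a Hive `FieldSchema`, with nested types
+// emitted as Hive type strings (struct<...>, array<...>, map<...>). A minimal
+// sketch of how it is driven (see the unit tests further down for the expected
+// output):
+//
+// ```ignore
+// let columns: Vec<FieldSchema> = HiveSchemaBuilder::from_iceberg(&schema)?.build();
+// ```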
+ +use hive_metastore::FieldSchema; +use iceberg::spec::{visit_schema, PrimitiveType, Schema, SchemaVisitor}; +use iceberg::{Error, ErrorKind, Result}; + +type HiveSchema = Vec; + +#[derive(Debug, Default)] +pub(crate) struct HiveSchemaBuilder { + schema: HiveSchema, + depth: usize, +} + +impl HiveSchemaBuilder { + /// Creates a new `HiveSchemaBuilder` from iceberg `Schema` + pub fn from_iceberg(schema: &Schema) -> Result { + let mut builder = Self::default(); + visit_schema(schema, &mut builder)?; + Ok(builder) + } + + /// Returns the newly converted `HiveSchema` + pub fn build(self) -> HiveSchema { + self.schema + } + + /// Check if is in `StructType` while traversing schema + fn is_inside_struct(&self) -> bool { + self.depth > 0 + } +} + +impl SchemaVisitor for HiveSchemaBuilder { + type T = String; + + fn schema( + &mut self, + _schema: &iceberg::spec::Schema, + value: String, + ) -> iceberg::Result { + Ok(value) + } + + fn before_struct_field( + &mut self, + _field: &iceberg::spec::NestedFieldRef, + ) -> iceberg::Result<()> { + self.depth += 1; + Ok(()) + } + + fn r#struct( + &mut self, + r#_struct: &iceberg::spec::StructType, + results: Vec, + ) -> iceberg::Result { + Ok(format!("struct<{}>", results.join(", "))) + } + + fn after_struct_field( + &mut self, + _field: &iceberg::spec::NestedFieldRef, + ) -> iceberg::Result<()> { + self.depth -= 1; + Ok(()) + } + + fn field( + &mut self, + field: &iceberg::spec::NestedFieldRef, + value: String, + ) -> iceberg::Result { + if self.is_inside_struct() { + return Ok(format!("{}:{}", field.name, value)); + } + + self.schema.push(FieldSchema { + name: Some(field.name.clone().into()), + r#type: Some(value.clone().into()), + comment: field.doc.clone().map(|doc| doc.into()), + }); + + Ok(value) + } + + fn list(&mut self, _list: &iceberg::spec::ListType, value: String) -> iceberg::Result { + Ok(format!("array<{}>", value)) + } + + fn map( + &mut self, + _map: &iceberg::spec::MapType, + key_value: String, + value: String, + ) -> iceberg::Result { + Ok(format!("map<{},{}>", key_value, value)) + } + + fn primitive(&mut self, p: &iceberg::spec::PrimitiveType) -> iceberg::Result { + let hive_type = match p { + PrimitiveType::Boolean => "boolean".to_string(), + PrimitiveType::Int => "int".to_string(), + PrimitiveType::Long => "bigint".to_string(), + PrimitiveType::Float => "float".to_string(), + PrimitiveType::Double => "double".to_string(), + PrimitiveType::Date => "date".to_string(), + PrimitiveType::Timestamp => "timestamp".to_string(), + PrimitiveType::TimestampNs => "timestamp_ns".to_string(), + PrimitiveType::TimestamptzNs => "timestamptz_ns".to_string(), + PrimitiveType::Time | PrimitiveType::String | PrimitiveType::Uuid => { + "string".to_string() + } + PrimitiveType::Binary | PrimitiveType::Fixed(_) => "binary".to_string(), + PrimitiveType::Decimal { precision, scale } => { + format!("decimal({},{})", precision, scale) + } + _ => { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + "Conversion from 'Timestamptz' is not supported", + )) + } + }; + + Ok(hive_type) + } +} + +#[cfg(test)] +mod tests { + use iceberg::spec::Schema; + use iceberg::Result; + + use super::*; + + #[test] + fn test_schema_with_nested_maps() -> Result<()> { + let record = r#" + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "quux", + "required": true, + "type": { + "type": "map", + "key-id": 2, + "key": "string", + "value-id": 3, + "value-required": true, + "value": { + "type": "map", + "key-id": 4, + "key": "string", + "value-id": 
5, + "value-required": true, + "value": "int" + } + } + } + ] + } + "#; + + let schema = serde_json::from_str::(record)?; + + let result = HiveSchemaBuilder::from_iceberg(&schema)?.build(); + + let expected = vec![FieldSchema { + name: Some("quux".into()), + r#type: Some("map>".into()), + comment: None, + }]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_struct_inside_list() -> Result<()> { + let record = r#" + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "location", + "required": true, + "type": { + "type": "list", + "element-id": 2, + "element-required": true, + "element": { + "type": "struct", + "fields": [ + { + "id": 3, + "name": "latitude", + "required": false, + "type": "float" + }, + { + "id": 4, + "name": "longitude", + "required": false, + "type": "float" + } + ] + } + } + } + ] + } + "#; + + let schema = serde_json::from_str::(record)?; + + let result = HiveSchemaBuilder::from_iceberg(&schema)?.build(); + + let expected = vec![FieldSchema { + name: Some("location".into()), + r#type: Some("array>".into()), + comment: None, + }]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_structs() -> Result<()> { + let record = r#"{ + "type": "struct", + "schema-id": 1, + "fields": [ + { + "id": 1, + "name": "person", + "required": true, + "type": { + "type": "struct", + "fields": [ + { + "id": 2, + "name": "name", + "required": true, + "type": "string" + }, + { + "id": 3, + "name": "age", + "required": false, + "type": "int" + } + ] + } + } + ] + }"#; + + let schema = serde_json::from_str::(record)?; + + let result = HiveSchemaBuilder::from_iceberg(&schema)?.build(); + + let expected = vec![FieldSchema { + name: Some("person".into()), + r#type: Some("struct".into()), + comment: None, + }]; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_schema_with_simple_fields() -> Result<()> { + let record = r#"{ + "type": "struct", + "schema-id": 1, + "fields": [ + { + "id": 1, + "name": "c1", + "required": true, + "type": "boolean" + }, + { + "id": 2, + "name": "c2", + "required": true, + "type": "int" + }, + { + "id": 3, + "name": "c3", + "required": true, + "type": "long" + }, + { + "id": 4, + "name": "c4", + "required": true, + "type": "float" + }, + { + "id": 5, + "name": "c5", + "required": true, + "type": "double" + }, + { + "id": 6, + "name": "c6", + "required": true, + "type": "decimal(2,2)" + }, + { + "id": 7, + "name": "c7", + "required": true, + "type": "date" + }, + { + "id": 8, + "name": "c8", + "required": true, + "type": "time" + }, + { + "id": 9, + "name": "c9", + "required": true, + "type": "timestamp" + }, + { + "id": 10, + "name": "c10", + "required": true, + "type": "string" + }, + { + "id": 11, + "name": "c11", + "required": true, + "type": "uuid" + }, + { + "id": 12, + "name": "c12", + "required": true, + "type": "fixed[4]" + }, + { + "id": 13, + "name": "c13", + "required": true, + "type": "binary" + } + ] + }"#; + + let schema = serde_json::from_str::(record)?; + + let result = HiveSchemaBuilder::from_iceberg(&schema)?.build(); + + let expected = vec![ + FieldSchema { + name: Some("c1".into()), + r#type: Some("boolean".into()), + comment: None, + }, + FieldSchema { + name: Some("c2".into()), + r#type: Some("int".into()), + comment: None, + }, + FieldSchema { + name: Some("c3".into()), + r#type: Some("bigint".into()), + comment: None, + }, + FieldSchema { + name: Some("c4".into()), + r#type: Some("float".into()), + comment: None, + }, + FieldSchema { + 
name: Some("c5".into()), + r#type: Some("double".into()), + comment: None, + }, + FieldSchema { + name: Some("c6".into()), + r#type: Some("decimal(2,2)".into()), + comment: None, + }, + FieldSchema { + name: Some("c7".into()), + r#type: Some("date".into()), + comment: None, + }, + FieldSchema { + name: Some("c8".into()), + r#type: Some("string".into()), + comment: None, + }, + FieldSchema { + name: Some("c9".into()), + r#type: Some("timestamp".into()), + comment: None, + }, + FieldSchema { + name: Some("c10".into()), + r#type: Some("string".into()), + comment: None, + }, + FieldSchema { + name: Some("c11".into()), + r#type: Some("string".into()), + comment: None, + }, + FieldSchema { + name: Some("c12".into()), + r#type: Some("binary".into()), + comment: None, + }, + FieldSchema { + name: Some("c13".into()), + r#type: Some("binary".into()), + comment: None, + }, + ]; + + assert_eq!(result, expected); + + Ok(()) + } +} diff --git a/crates/catalog/hms/src/utils.rs b/crates/catalog/hms/src/utils.rs index 0daa52aa1..1e48d3fbd 100644 --- a/crates/catalog/hms/src/utils.rs +++ b/crates/catalog/hms/src/utils.rs @@ -15,13 +15,547 @@ // specific language governing permissions and limitations // under the License. -use iceberg::{Error, ErrorKind}; - -/// Format a thrift error into iceberg error. -pub fn from_thrift_error(error: thrift::Error) -> Error { - Error::new( - ErrorKind::Unexpected, - "operation failed for hitting thrift error".to_string(), - ) - .with_source(error) +use std::collections::HashMap; + +use chrono::Utc; +use hive_metastore::{Database, PrincipalType, SerDeInfo, StorageDescriptor}; +use iceberg::spec::Schema; +use iceberg::{Error, ErrorKind, Namespace, NamespaceIdent, Result}; +use pilota::{AHashMap, FastStr}; +use uuid::Uuid; + +use crate::schema::HiveSchemaBuilder; + +/// hive.metastore.database.owner setting +const HMS_DB_OWNER: &str = "hive.metastore.database.owner"; +/// hive.metastore.database.owner default setting +const HMS_DEFAULT_DB_OWNER: &str = "user.name"; +/// hive.metastore.database.owner-type setting +const HMS_DB_OWNER_TYPE: &str = "hive.metastore.database.owner-type"; +/// hive metatore `owner` property +const OWNER: &str = "owner"; +/// hive metatore `description` property +const COMMENT: &str = "comment"; +/// hive metatore `location` property +const LOCATION: &str = "location"; +/// hive metatore `metadat_location` property +const METADATA_LOCATION: &str = "metadata_location"; +/// hive metatore `external` property +const EXTERNAL: &str = "EXTERNAL"; +/// hive metatore `external_table` property +const EXTERNAL_TABLE: &str = "EXTERNAL_TABLE"; +/// hive metatore `table_type` property +const TABLE_TYPE: &str = "table_type"; +/// hive metatore `SerDeInfo` serialization_lib parameter +const SERIALIZATION_LIB: &str = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"; +/// hive metatore input format +const INPUT_FORMAT: &str = "org.apache.hadoop.mapred.FileInputFormat"; +/// hive metatore output format +const OUTPUT_FORMAT: &str = "org.apache.hadoop.mapred.FileOutputFormat"; + +/// Returns a `Namespace` by extracting database name and properties +/// from `hive_metastore::hms::Database` +pub(crate) fn convert_to_namespace(database: &Database) -> Result { + let mut properties = HashMap::new(); + + let name = database + .name + .as_ref() + .ok_or_else(|| Error::new(ErrorKind::DataInvalid, "Database name must be specified"))? 
+ .to_string(); + + if let Some(description) = &database.description { + properties.insert(COMMENT.to_string(), description.to_string()); + }; + + if let Some(location) = &database.location_uri { + properties.insert(LOCATION.to_string(), location.to_string()); + }; + + if let Some(owner) = &database.owner_name { + properties.insert(HMS_DB_OWNER.to_string(), owner.to_string()); + }; + + if let Some(owner_type) = database.owner_type { + let value = if owner_type == PrincipalType::USER { + "User" + } else if owner_type == PrincipalType::GROUP { + "Group" + } else if owner_type == PrincipalType::ROLE { + "Role" + } else { + unreachable!("Invalid owner type") + }; + + properties.insert(HMS_DB_OWNER_TYPE.to_string(), value.to_string()); + }; + + if let Some(params) = &database.parameters { + params.iter().for_each(|(k, v)| { + properties.insert(k.clone().into(), v.clone().into()); + }); + }; + + Ok(Namespace::with_properties( + NamespaceIdent::new(name), + properties, + )) +} + +/// Converts name and properties into `hive_metastore::hms::Database` +/// after validating the `namespace` and `owner-settings`. +pub(crate) fn convert_to_database( + namespace: &NamespaceIdent, + properties: &HashMap, +) -> Result { + let name = validate_namespace(namespace)?; + validate_owner_settings(properties)?; + + let mut db = Database::default(); + let mut parameters = AHashMap::new(); + + db.name = Some(name.into()); + + for (k, v) in properties { + match k.as_str() { + COMMENT => db.description = Some(v.clone().into()), + LOCATION => db.location_uri = Some(format_location_uri(v.clone()).into()), + HMS_DB_OWNER => db.owner_name = Some(v.clone().into()), + HMS_DB_OWNER_TYPE => { + let owner_type = match v.to_lowercase().as_str() { + "user" => PrincipalType::USER, + "group" => PrincipalType::GROUP, + "role" => PrincipalType::ROLE, + _ => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Invalid value for setting 'owner_type': {}", v), + )) + } + }; + db.owner_type = Some(owner_type); + } + _ => { + parameters.insert( + FastStr::from_string(k.clone()), + FastStr::from_string(v.clone()), + ); + } + } + } + + db.parameters = Some(parameters); + + // Set default owner, if none provided + // https://github.com/apache/iceberg/blob/main/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveHadoopUtil.java#L44 + if db.owner_name.is_none() { + db.owner_name = Some(HMS_DEFAULT_DB_OWNER.into()); + db.owner_type = Some(PrincipalType::USER); + } + + Ok(db) +} + +pub(crate) fn convert_to_hive_table( + db_name: String, + schema: &Schema, + table_name: String, + location: String, + metadata_location: String, + properties: &HashMap, +) -> Result { + let serde_info = SerDeInfo { + serialization_lib: Some(SERIALIZATION_LIB.into()), + ..Default::default() + }; + + let hive_schema = HiveSchemaBuilder::from_iceberg(schema)?.build(); + + let storage_descriptor = StorageDescriptor { + location: Some(location.into()), + cols: Some(hive_schema), + input_format: Some(INPUT_FORMAT.into()), + output_format: Some(OUTPUT_FORMAT.into()), + serde_info: Some(serde_info), + ..Default::default() + }; + + let parameters = AHashMap::from([ + (FastStr::from(EXTERNAL), FastStr::from("TRUE")), + (FastStr::from(TABLE_TYPE), FastStr::from("ICEBERG")), + ( + FastStr::from(METADATA_LOCATION), + FastStr::from(metadata_location), + ), + ]); + + let current_time_ms = get_current_time()?; + let owner = properties + .get(OWNER) + .map_or(HMS_DEFAULT_DB_OWNER.to_string(), |v| v.into()); + + Ok(hive_metastore::Table { + table_name: 
Some(table_name.into()), + db_name: Some(db_name.into()), + table_type: Some(EXTERNAL_TABLE.into()), + owner: Some(owner.into()), + create_time: Some(current_time_ms), + last_access_time: Some(current_time_ms), + sd: Some(storage_descriptor), + parameters: Some(parameters), + ..Default::default() + }) +} + +/// Checks if provided `NamespaceIdent` is valid. +pub(crate) fn validate_namespace(namespace: &NamespaceIdent) -> Result { + let name = namespace.as_ref(); + + if name.len() != 1 { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid database name: {:?}, hierarchical namespaces are not supported", + namespace + ), + )); + } + + let name = name[0].clone(); + + if name.is_empty() { + return Err(Error::new( + ErrorKind::DataInvalid, + "Invalid database, provided namespace is empty.", + )); + } + + Ok(name) +} + +/// Get default table location from `Namespace` properties +pub(crate) fn get_default_table_location( + namespace: &Namespace, + table_name: impl AsRef, + warehouse: impl AsRef, +) -> String { + let properties = namespace.properties(); + + let location = match properties.get(LOCATION) { + Some(location) => location, + None => warehouse.as_ref(), + }; + + format!("{}/{}", location, table_name.as_ref()) +} + +/// Create metadata location from `location` and `version` +pub(crate) fn create_metadata_location(location: impl AsRef, version: i32) -> Result { + if version < 0 { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Table metadata version: '{}' must be a non-negative integer", + version + ), + )); + }; + + let version = format!("{:0>5}", version); + let id = Uuid::new_v4(); + let metadata_location = format!( + "{}/metadata/{}-{}.metadata.json", + location.as_ref(), + version, + id + ); + + Ok(metadata_location) +} + +/// Get metadata location from `HiveTable` parameters +pub(crate) fn get_metadata_location( + parameters: &Option>, +) -> Result { + match parameters { + Some(properties) => match properties.get(METADATA_LOCATION) { + Some(location) => Ok(location.to_string()), + None => Err(Error::new( + ErrorKind::DataInvalid, + format!("No '{}' set on table", METADATA_LOCATION), + )), + }, + None => Err(Error::new( + ErrorKind::DataInvalid, + "No 'parameters' set on table. Location of metadata is undefined", + )), + } +} + +/// Formats location_uri by e.g. removing trailing slashes. +fn format_location_uri(location: String) -> String { + let mut location = location; + + if !location.starts_with('/') { + location = format!("/{}", location); + } + + if location.ends_with('/') && location.len() > 1 { + location.pop(); + } + + location +} + +/// Checks if `owner-settings` are valid. +/// If `owner_type` is set, then `owner` must also be set. 
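// For reference, `create_metadata_location` produces paths of the form
//   <table_location>/metadata/00000-<uuid>.metadata.json
// (the version is zero-padded to five digits, and negative versions are rejected).
// A hedged sketch of the property-driven conversion in `convert_to_database`
// (values are illustrative):
//
//   let ns = NamespaceIdent::new("my_namespace".into());
//   let props = HashMap::from([("comment".to_string(), "my_description".to_string())]);
//   let db = convert_to_database(&ns, &props)?; // default owner user.name/USER is applied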
+fn validate_owner_settings(properties: &HashMap) -> Result<()> { + let owner_is_set = properties.get(HMS_DB_OWNER).is_some(); + let owner_type_is_set = properties.get(HMS_DB_OWNER_TYPE).is_some(); + + if owner_type_is_set && !owner_is_set { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Setting '{}' without setting '{}' is not allowed", + HMS_DB_OWNER_TYPE, HMS_DB_OWNER + ), + )); + } + + Ok(()) +} + +fn get_current_time() -> Result { + let now = Utc::now(); + now.timestamp().try_into().map_err(|_| { + Error::new( + ErrorKind::Unexpected, + "Current time is out of range for i32", + ) + }) +} + +#[cfg(test)] +mod tests { + use iceberg::spec::{NestedField, PrimitiveType, Type}; + use iceberg::{Namespace, NamespaceIdent}; + + use super::*; + + #[test] + fn test_get_metadata_location() -> Result<()> { + let params_valid = Some(AHashMap::from([( + FastStr::new(METADATA_LOCATION), + FastStr::new("my_location"), + )])); + let params_missing_key = Some(AHashMap::from([( + FastStr::new("not_here"), + FastStr::new("my_location"), + )])); + + let result_valid = get_metadata_location(¶ms_valid)?; + let result_missing_key = get_metadata_location(¶ms_missing_key); + let result_no_params = get_metadata_location(&None); + + assert_eq!(result_valid, "my_location"); + assert!(result_missing_key.is_err()); + assert!(result_no_params.is_err()); + + Ok(()) + } + + #[test] + fn test_convert_to_hive_table() -> Result<()> { + let db_name = "my_db".to_string(); + let table_name = "my_table".to_string(); + let location = "s3a://warehouse/hms".to_string(); + let metadata_location = create_metadata_location(location.clone(), 0)?; + let properties = HashMap::new(); + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ]) + .build()?; + + let result = convert_to_hive_table( + db_name.clone(), + &schema, + table_name.clone(), + location.clone(), + metadata_location, + &properties, + )?; + + let serde_info = SerDeInfo { + serialization_lib: Some(SERIALIZATION_LIB.into()), + ..Default::default() + }; + + let hive_schema = HiveSchemaBuilder::from_iceberg(&schema)?.build(); + + let sd = StorageDescriptor { + location: Some(location.into()), + cols: Some(hive_schema), + input_format: Some(INPUT_FORMAT.into()), + output_format: Some(OUTPUT_FORMAT.into()), + serde_info: Some(serde_info), + ..Default::default() + }; + + assert_eq!(result.db_name, Some(db_name.into())); + assert_eq!(result.table_name, Some(table_name.into())); + assert_eq!(result.table_type, Some(EXTERNAL_TABLE.into())); + assert_eq!(result.owner, Some(HMS_DEFAULT_DB_OWNER.into())); + assert_eq!(result.sd, Some(sd)); + + Ok(()) + } + + #[test] + fn test_create_metadata_location() -> Result<()> { + let location = "my_base_location"; + let valid_version = 0; + let invalid_version = -1; + + let valid_result = create_metadata_location(location, valid_version)?; + let invalid_result = create_metadata_location(location, invalid_version); + + assert!(valid_result.starts_with("my_base_location/metadata/00000-")); + assert!(valid_result.ends_with(".metadata.json")); + assert!(invalid_result.is_err()); + + Ok(()) + } + + #[test] + fn test_get_default_table_location() -> Result<()> { + let properties = HashMap::from([(LOCATION.to_string(), "db_location".to_string())]); + + let namespace = + Namespace::with_properties(NamespaceIdent::new("default".into()), properties); + let 
table_name = "my_table"; + + let expected = "db_location/my_table"; + let result = get_default_table_location(&namespace, table_name, "warehouse_location"); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_get_default_table_location_warehouse() -> Result<()> { + let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let table_name = "my_table"; + + let expected = "warehouse_location/my_table"; + let result = get_default_table_location(&namespace, table_name, "warehouse_location"); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_convert_to_namespace() -> Result<()> { + let properties = HashMap::from([ + (COMMENT.to_string(), "my_description".to_string()), + (LOCATION.to_string(), "/my_location".to_string()), + (HMS_DB_OWNER.to_string(), "apache".to_string()), + (HMS_DB_OWNER_TYPE.to_string(), "User".to_string()), + ("key1".to_string(), "value1".to_string()), + ]); + + let ident = NamespaceIdent::new("my_namespace".into()); + let db = convert_to_database(&ident, &properties)?; + + let expected_ns = Namespace::with_properties(ident, properties); + let result_ns = convert_to_namespace(&db)?; + + assert_eq!(expected_ns, result_ns); + + Ok(()) + } + + #[test] + fn test_validate_owner_settings() { + let valid = HashMap::from([ + (HMS_DB_OWNER.to_string(), "apache".to_string()), + (HMS_DB_OWNER_TYPE.to_string(), "user".to_string()), + ]); + let invalid = HashMap::from([(HMS_DB_OWNER_TYPE.to_string(), "user".to_string())]); + + assert!(validate_owner_settings(&valid).is_ok()); + assert!(validate_owner_settings(&invalid).is_err()); + } + + #[test] + fn test_convert_to_database() -> Result<()> { + let ns = NamespaceIdent::new("my_namespace".into()); + let properties = HashMap::from([ + (COMMENT.to_string(), "my_description".to_string()), + (LOCATION.to_string(), "my_location".to_string()), + (HMS_DB_OWNER.to_string(), "apache".to_string()), + (HMS_DB_OWNER_TYPE.to_string(), "user".to_string()), + ("key1".to_string(), "value1".to_string()), + ]); + + let db = convert_to_database(&ns, &properties)?; + + assert_eq!(db.name, Some(FastStr::from("my_namespace"))); + assert_eq!(db.description, Some(FastStr::from("my_description"))); + assert_eq!(db.owner_name, Some(FastStr::from("apache"))); + assert_eq!(db.owner_type, Some(PrincipalType::USER)); + + if let Some(params) = db.parameters { + assert_eq!(params.get("key1"), Some(&FastStr::from("value1"))); + } + + Ok(()) + } + + #[test] + fn test_convert_to_database_with_default_user() -> Result<()> { + let ns = NamespaceIdent::new("my_namespace".into()); + let properties = HashMap::new(); + + let db = convert_to_database(&ns, &properties)?; + + assert_eq!(db.name, Some(FastStr::from("my_namespace"))); + assert_eq!(db.owner_name, Some(FastStr::from(HMS_DEFAULT_DB_OWNER))); + assert_eq!(db.owner_type, Some(PrincipalType::USER)); + + Ok(()) + } + + #[test] + fn test_validate_namespace() { + let valid_ns = Namespace::new(NamespaceIdent::new("ns".to_string())); + let empty_ns = Namespace::new(NamespaceIdent::new("".to_string())); + let hierarchical_ns = Namespace::new( + NamespaceIdent::from_vec(vec!["level1".to_string(), "level2".to_string()]).unwrap(), + ); + + let valid = validate_namespace(valid_ns.name()); + let empty = validate_namespace(empty_ns.name()); + let hierarchical = validate_namespace(hierarchical_ns.name()); + + assert!(valid.is_ok()); + assert!(empty.is_err()); + assert!(hierarchical.is_err()); + } + + #[test] + fn test_format_location_uri() { + let inputs = vec!["iceberg", "is/", 
"/nice/", "really/nice/", "/"]; + let outputs = vec!["/iceberg", "/is", "/nice", "/really/nice", "/"]; + + inputs.into_iter().zip(outputs).for_each(|(inp, out)| { + let location = format_location_uri(inp.to_string()); + assert_eq!(location, out); + }) + } } diff --git a/crates/catalog/hms/testdata/hms_catalog/Dockerfile b/crates/catalog/hms/testdata/hms_catalog/Dockerfile new file mode 100644 index 000000000..8392e174a --- /dev/null +++ b/crates/catalog/hms/testdata/hms_catalog/Dockerfile @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM --platform=$BUILDPLATFORM openjdk:8-jre-slim AS build + +ARG BUILDPLATFORM + +RUN apt-get update -qq && apt-get -qq -y install curl + +ENV AWSSDK_VERSION=2.20.18 +ENV HADOOP_VERSION=3.1.0 + +RUN curl https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.11.271/aws-java-sdk-bundle-1.11.271.jar -Lo /tmp/aws-java-sdk-bundle-1.11.271.jar +RUN curl https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/${HADOOP_VERSION}/hadoop-aws-${HADOOP_VERSION}.jar -Lo /tmp/hadoop-aws-${HADOOP_VERSION}.jar + + +FROM apache/hive:3.1.3 + +ENV AWSSDK_VERSION=2.20.18 +ENV HADOOP_VERSION=3.1.0 + +COPY --from=build /tmp/hadoop-aws-${HADOOP_VERSION}.jar /opt/hive/lib/hadoop-aws-${HADOOP_VERSION}.jar +COPY --from=build /tmp/aws-java-sdk-bundle-1.11.271.jar /opt/hive/lib/aws-java-sdk-bundle-1.11.271.jar +COPY core-site.xml /opt/hadoop/etc/hadoop/core-site.xml \ No newline at end of file diff --git a/crates/catalog/hms/testdata/hms_catalog/core-site.xml b/crates/catalog/hms/testdata/hms_catalog/core-site.xml new file mode 100644 index 000000000..f0583a0bc --- /dev/null +++ b/crates/catalog/hms/testdata/hms_catalog/core-site.xml @@ -0,0 +1,51 @@ + + + + + fs.defaultFS + s3a://warehouse/hive + + + fs.s3a.impl + org.apache.hadoop.fs.s3a.S3AFileSystem + + + fs.s3a.fast.upload + true + + + fs.s3a.endpoint + http://minio:9000 + + + fs.s3a.access.key + admin + + + fs.s3a.secret.key + password + + + fs.s3a.connection.ssl.enabled + false + + + fs.s3a.path.style.access + true + + \ No newline at end of file diff --git a/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml b/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml new file mode 100644 index 000000000..181fac149 --- /dev/null +++ b/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +services: + minio: + image: minio/minio:RELEASE.2024-03-07T00-43-48Z + expose: + - 9000 + - 9001 + environment: + - MINIO_ROOT_USER=admin + - MINIO_ROOT_PASSWORD=password + - MINIO_DOMAIN=minio + command: [ "server", "/data", "--console-address", ":9001" ] + + mc: + depends_on: + - minio + image: minio/mc:RELEASE.2024-03-07T00-31-49Z + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + entrypoint: > + /bin/sh -c " until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; /usr/bin/mc mb minio/warehouse; /usr/bin/mc policy set public minio/warehouse; tail -f /dev/null " + + hive-metastore: + image: iceberg-hive-metastore + build: ./ + platform: ${DOCKER_DEFAULT_PLATFORM} + expose: + - 9083 + environment: + SERVICE_NAME: "metastore" + SERVICE_OPTS: "-Dmetastore.warehouse.dir=s3a://warehouse/hive/" diff --git a/crates/catalog/hms/tests/hms_catalog_test.rs b/crates/catalog/hms/tests/hms_catalog_test.rs new file mode 100644 index 000000000..5b8004439 --- /dev/null +++ b/crates/catalog/hms/tests/hms_catalog_test.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for hms catalog. 
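// These tests stand up the docker-compose environment above (Hive Metastore on
// port 9083, MinIO on port 9000) and exercise `HmsCatalog` against it. A condensed,
// hedged sketch of the client construction used further down (the addresses are
// illustrative; the tests resolve them from the running containers):
//
//   let config = HmsCatalogConfig::builder()
//       .address("127.0.0.1:9083".to_string())
//       .thrift_transport(HmsThriftTransport::Buffered)
//       .warehouse("s3a://warehouse/hive".to_string())
//       .props(HashMap::from([
//           (S3_ENDPOINT.to_string(), "http://127.0.0.1:9000".to_string()),
//           (S3_ACCESS_KEY_ID.to_string(), "admin".to_string()),
//           (S3_SECRET_ACCESS_KEY.to_string(), "password".to_string()),
//           (S3_REGION.to_string(), "us-east-1".to_string()),
//       ]))
//       .build();
//   let catalog = HmsCatalog::new(config)?;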
+ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::RwLock; + +use ctor::{ctor, dtor}; +use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY}; +use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; +use iceberg::{Catalog, Namespace, NamespaceIdent, TableCreation, TableIdent}; +use iceberg_catalog_hms::{HmsCatalog, HmsCatalogConfig, HmsThriftTransport}; +use iceberg_test_utils::docker::DockerCompose; +use iceberg_test_utils::{normalize_test_name, set_up}; +use port_scanner::scan_port_addr; +use tokio::time::sleep; + +const HMS_CATALOG_PORT: u16 = 9083; +const MINIO_PORT: u16 = 9000; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); +type Result = std::result::Result; + +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + let docker_compose = DockerCompose::new( + normalize_test_name(module_path!()), + format!("{}/testdata/hms_catalog", env!("CARGO_MANIFEST_DIR")), + ); + docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +async fn get_catalog() -> HmsCatalog { + set_up(); + + let (hms_catalog_ip, minio_ip) = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + ( + docker_compose.get_container_ip("hive-metastore"), + docker_compose.get_container_ip("minio"), + ) + }; + let hms_socket_addr = SocketAddr::new(hms_catalog_ip, HMS_CATALOG_PORT); + let minio_socket_addr = SocketAddr::new(minio_ip, MINIO_PORT); + while !scan_port_addr(hms_socket_addr) { + log::info!("scan hms_socket_addr {} check", hms_socket_addr); + log::info!("Waiting for 1s hms catalog to ready..."); + sleep(std::time::Duration::from_millis(1000)).await; + } + + let props = HashMap::from([ + ( + S3_ENDPOINT.to_string(), + format!("http://{}", minio_socket_addr), + ), + (S3_ACCESS_KEY_ID.to_string(), "admin".to_string()), + (S3_SECRET_ACCESS_KEY.to_string(), "password".to_string()), + (S3_REGION.to_string(), "us-east-1".to_string()), + ]); + + let config = HmsCatalogConfig::builder() + .address(hms_socket_addr.to_string()) + .thrift_transport(HmsThriftTransport::Buffered) + .warehouse("s3a://warehouse/hive".to_string()) + .props(props) + .build(); + + HmsCatalog::new(config).unwrap() +} + +async fn set_test_namespace(catalog: &HmsCatalog, namespace: &NamespaceIdent) -> Result<()> { + let properties = HashMap::new(); + + catalog.create_namespace(namespace, properties).await?; + + Ok(()) +} + +fn set_table_creation(location: impl ToString, name: impl ToString) -> Result { + let schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?; + + let creation = TableCreation::builder() + .location(location.to_string()) + .name(name.to_string()) + .properties(HashMap::new()) + .schema(schema) + .build(); + + Ok(creation) +} + +#[tokio::test] +async fn test_rename_table() -> Result<()> { + let catalog = get_catalog().await; + let creation: TableCreation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_rename_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; + + let table: iceberg::table::Table = catalog.create_table(namespace.name(), creation).await?; + + let dest = 
TableIdent::new(namespace.name().clone(), "my_table_rename".to_string()); + + catalog.rename_table(table.identifier(), &dest).await?; + + let result = catalog.table_exists(&dest).await?; + + assert!(result); + + Ok(()) +} + +#[tokio::test] +async fn test_table_exists() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_table_exists".into())); + set_test_namespace(&catalog, namespace.name()).await?; + + let table = catalog.create_table(namespace.name(), creation).await?; + + let result = catalog.table_exists(table.identifier()).await?; + + assert!(result); + + Ok(()) +} + +#[tokio::test] +async fn test_drop_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_drop_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; + + let table = catalog.create_table(namespace.name(), creation).await?; + + catalog.drop_table(table.identifier()).await?; + + let result = catalog.table_exists(table.identifier()).await?; + + assert!(!result); + + Ok(()) +} + +#[tokio::test] +async fn test_load_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_load_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; + + let expected = catalog.create_table(namespace.name(), creation).await?; + + let result = catalog + .load_table(&TableIdent::new( + namespace.name().clone(), + "my_table".to_string(), + )) + .await?; + + assert_eq!(result.identifier(), expected.identifier()); + assert_eq!(result.metadata_location(), expected.metadata_location()); + assert_eq!(result.metadata(), expected.metadata()); + + Ok(()) +} + +#[tokio::test] +async fn test_create_table() -> Result<()> { + let catalog = get_catalog().await; + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_create_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; + + let result = catalog.create_table(namespace.name(), creation).await?; + + assert_eq!(result.identifier().name(), "my_table"); + assert!(result + .metadata_location() + .is_some_and(|location| location.starts_with("s3a://warehouse/hive/metadata/00000-"))); + assert!( + catalog + .file_io() + .is_exist("s3a://warehouse/hive/metadata/") + .await? 
+ ); + + Ok(()) +} + +#[tokio::test] +async fn test_list_tables() -> Result<()> { + let catalog = get_catalog().await; + let ns = Namespace::new(NamespaceIdent::new("test_list_tables".into())); + let result = catalog.list_tables(ns.name()).await?; + set_test_namespace(&catalog, ns.name()).await?; + + assert_eq!(result, vec![]); + + let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; + catalog.create_table(ns.name(), creation).await?; + let result = catalog.list_tables(ns.name()).await?; + + assert_eq!(result, vec![TableIdent::new( + ns.name().clone(), + "my_table".to_string() + )]); + + Ok(()) +} + +#[tokio::test] +async fn test_list_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let result_no_parent = catalog.list_namespaces(None).await?; + + let result_with_parent = catalog + .list_namespaces(Some(&NamespaceIdent::new("parent".into()))) + .await?; + + assert!(result_no_parent.contains(&NamespaceIdent::new("default".into()))); + assert!(result_with_parent.is_empty()); + + Ok(()) +} + +#[tokio::test] +async fn test_create_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let properties = HashMap::from([ + ("comment".to_string(), "my_description".to_string()), + ("location".to_string(), "my_location".to_string()), + ( + "hive.metastore.database.owner".to_string(), + "apache".to_string(), + ), + ( + "hive.metastore.database.owner-type".to_string(), + "user".to_string(), + ), + ("key1".to_string(), "value1".to_string()), + ]); + + let ns = Namespace::with_properties( + NamespaceIdent::new("test_create_namespace".into()), + properties.clone(), + ); + + let result = catalog.create_namespace(ns.name(), properties).await?; + + assert_eq!(result, ns); + + Ok(()) +} + +#[tokio::test] +async fn test_get_default_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let ns = Namespace::new(NamespaceIdent::new("default".into())); + let properties = HashMap::from([ + ("location".to_string(), "s3a://warehouse/hive".to_string()), + ( + "hive.metastore.database.owner-type".to_string(), + "Role".to_string(), + ), + ("comment".to_string(), "Default Hive database".to_string()), + ( + "hive.metastore.database.owner".to_string(), + "public".to_string(), + ), + ]); + + let expected = Namespace::with_properties(NamespaceIdent::new("default".into()), properties); + + let result = catalog.get_namespace(ns.name()).await?; + + assert_eq!(expected, result); + + Ok(()) +} + +#[tokio::test] +async fn test_namespace_exists() -> Result<()> { + let catalog = get_catalog().await; + + let ns_exists = Namespace::new(NamespaceIdent::new("default".into())); + let ns_not_exists = Namespace::new(NamespaceIdent::new("test_namespace_exists".into())); + + let result_exists = catalog.namespace_exists(ns_exists.name()).await?; + let result_not_exists = catalog.namespace_exists(ns_not_exists.name()).await?; + + assert!(result_exists); + assert!(!result_not_exists); + + Ok(()) +} + +#[tokio::test] +async fn test_update_namespace() -> Result<()> { + let catalog = get_catalog().await; + + let ns = NamespaceIdent::new("test_update_namespace".into()); + set_test_namespace(&catalog, &ns).await?; + let properties = HashMap::from([("comment".to_string(), "my_update".to_string())]); + + catalog.update_namespace(&ns, properties).await?; + + let db = catalog.get_namespace(&ns).await?; + + assert_eq!( + db.properties().get("comment"), + Some(&"my_update".to_string()) + ); + + Ok(()) +} + +#[tokio::test] +async fn test_drop_namespace() -> Result<()> { + let catalog = 
get_catalog().await; + + let ns = Namespace::new(NamespaceIdent::new("delete_me".into())); + + catalog.create_namespace(ns.name(), HashMap::new()).await?; + + let result = catalog.namespace_exists(ns.name()).await?; + assert!(result); + + catalog.drop_namespace(ns.name()).await?; + + let result = catalog.namespace_exists(ns.name()).await?; + assert!(!result); + + Ok(()) +} diff --git a/crates/catalog/memory/Cargo.toml b/crates/catalog/memory/Cargo.toml new file mode 100644 index 000000000..011479efc --- /dev/null +++ b/crates/catalog/memory/Cargo.toml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "iceberg-catalog-memory" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } + +categories = ["database"] +description = "Apache Iceberg Rust Memory Catalog API" +repository = { workspace = true } +license = { workspace = true } +keywords = ["iceberg", "memory", "catalog"] + +[dependencies] +async-trait = { workspace = true } +futures = { workspace = true } +iceberg = { workspace = true } +itertools = { workspace = true } +serde_json = { workspace = true } +uuid = { workspace = true, features = ["v4"] } + +[dev-dependencies] +regex = { workspace = true } +tempfile = { workspace = true } +tokio = { workspace = true } diff --git a/crates/catalog/memory/DEPENDENCIES.rust.tsv b/crates/catalog/memory/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..b4617eedb --- /dev/null +++ b/crates/catalog/memory/DEPENDENCIES.rust.tsv @@ -0,0 +1,276 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X +adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayvec@0.7.4 X X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-ord@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-trait@0.1.81 X X +atoi@2.0.0 X +autocfg@1.3.0 X X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.22.1 X X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitvec@1.0.1 X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +cc@1.1.11 X X +cfg-if@1.0.0 X X +chrono@0.4.38 X X +colorchoice@1.0.2 X X +const-oid@0.9.6 X X 
+const-random@0.1.18 X X +const-random-macro@0.1.16 X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +fastrand@2.1.0 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hmac@0.12.1 X X +home@0.5.9 X X +http@1.1.0 X X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +humantime@2.1.0 X X +hyper@1.4.1 X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X +idna@0.5.0 X X +integer-encoding@3.0.4 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +md-5@0.10.6 X X +memchr@2.7.4 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +murmur3@0.5.2 X X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +object@0.36.3 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +parquet@52.2.0 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +ppv-lite86@0.2.20 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X +radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc_version@0.4.0 X X +rustls@0.23.12 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snap@1.1.1 X +socket2@0.5.7 X X +spin@0.9.8 X +static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X 
X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.26.0 X X +tokio-util@0.7.11 X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +untrusted@0.9.0 X +url@2.5.2 X X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winreg@0.52.0 X +wyz@0.5.1 X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/catalog/memory/README.md b/crates/catalog/memory/README.md new file mode 100644 index 000000000..5b04f78ab --- /dev/null +++ b/crates/catalog/memory/README.md @@ -0,0 +1,27 @@ + + +# Apache Iceberg Memory Catalog Official Native Rust Implementation + +[![crates.io](https://img.shields.io/crates/v/iceberg-catalog-memory.svg)](https://crates.io/crates/iceberg-catalog-memory) +[![docs.rs](https://img.shields.io/docsrs/iceberg-catalog-memory.svg)](https://docs.rs/iceberg/latest/iceberg-catalog-memory/) + +This crate contains the official Native Rust implementation of Apache Iceberg Memory Catalog. + +See the [API documentation](https://docs.rs/iceberg-catalog-memory/latest) for examples and the full API. diff --git a/crates/catalog/memory/src/catalog.rs b/crates/catalog/memory/src/catalog.rs new file mode 100644 index 000000000..1da044821 --- /dev/null +++ b/crates/catalog/memory/src/catalog.rs @@ -0,0 +1,1678 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains memory catalog implementation. 
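// `MemoryCatalog` keeps all namespace and table state behind a single async
// `Mutex<NamespaceState>` and only touches storage through `FileIO` when writing
// table metadata JSON. A hedged construction sketch (mirrors the unit tests below;
// the warehouse path is illustrative):
//
//   let file_io = FileIOBuilder::new_fs_io().build()?;
//   let catalog = MemoryCatalog::new(file_io, Some("/tmp/warehouse".to_string()));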
+ +use std::collections::HashMap; + +use async_trait::async_trait; +use futures::lock::Mutex; +use iceberg::io::FileIO; +use iceberg::spec::{TableMetadata, TableMetadataBuilder}; +use iceberg::table::Table; +use iceberg::{ + Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, TableCreation, + TableIdent, +}; +use itertools::Itertools; +use uuid::Uuid; + +use crate::namespace_state::NamespaceState; + +/// namespace `location` property +const LOCATION: &str = "location"; + +/// Memory catalog implementation. +#[derive(Debug)] +pub struct MemoryCatalog { + root_namespace_state: Mutex, + file_io: FileIO, + warehouse_location: Option, +} + +impl MemoryCatalog { + /// Creates an memory catalog. + pub fn new(file_io: FileIO, warehouse_location: Option) -> Self { + Self { + root_namespace_state: Mutex::new(NamespaceState::default()), + file_io, + warehouse_location, + } + } +} + +#[async_trait] +impl Catalog for MemoryCatalog { + /// List namespaces inside the catalog. + async fn list_namespaces( + &self, + maybe_parent: Option<&NamespaceIdent>, + ) -> Result> { + let root_namespace_state = self.root_namespace_state.lock().await; + + match maybe_parent { + None => { + let namespaces = root_namespace_state + .list_top_level_namespaces() + .into_iter() + .map(|str| NamespaceIdent::new(str.to_string())) + .collect_vec(); + + Ok(namespaces) + } + Some(parent_namespace_ident) => { + let namespaces = root_namespace_state + .list_namespaces_under(parent_namespace_ident)? + .into_iter() + .map(|name| NamespaceIdent::new(name.to_string())) + .collect_vec(); + + Ok(namespaces) + } + } + } + + /// Create a new namespace inside the catalog. + async fn create_namespace( + &self, + namespace_ident: &NamespaceIdent, + properties: HashMap, + ) -> Result { + let mut root_namespace_state = self.root_namespace_state.lock().await; + + root_namespace_state.insert_new_namespace(namespace_ident, properties.clone())?; + let namespace = Namespace::with_properties(namespace_ident.clone(), properties); + + Ok(namespace) + } + + /// Get a namespace information from the catalog. + async fn get_namespace(&self, namespace_ident: &NamespaceIdent) -> Result { + let root_namespace_state = self.root_namespace_state.lock().await; + + let namespace = Namespace::with_properties( + namespace_ident.clone(), + root_namespace_state + .get_properties(namespace_ident)? + .clone(), + ); + + Ok(namespace) + } + + /// Check if namespace exists in catalog. + async fn namespace_exists(&self, namespace_ident: &NamespaceIdent) -> Result { + let guarded_namespaces = self.root_namespace_state.lock().await; + + Ok(guarded_namespaces.namespace_exists(namespace_ident)) + } + + /// Update a namespace inside the catalog. + /// + /// # Behavior + /// + /// The properties must be the full set of namespace. + async fn update_namespace( + &self, + namespace_ident: &NamespaceIdent, + properties: HashMap, + ) -> Result<()> { + let mut root_namespace_state = self.root_namespace_state.lock().await; + + root_namespace_state.replace_properties(namespace_ident, properties) + } + + /// Drop a namespace from the catalog. + async fn drop_namespace(&self, namespace_ident: &NamespaceIdent) -> Result<()> { + let mut root_namespace_state = self.root_namespace_state.lock().await; + + root_namespace_state.remove_existing_namespace(namespace_ident) + } + + /// List tables from namespace. 
+ async fn list_tables(&self, namespace_ident: &NamespaceIdent) -> Result> { + let root_namespace_state = self.root_namespace_state.lock().await; + + let table_names = root_namespace_state.list_tables(namespace_ident)?; + let table_idents = table_names + .into_iter() + .map(|table_name| TableIdent::new(namespace_ident.clone(), table_name.clone())) + .collect_vec(); + + Ok(table_idents) + } + + /// Create a new table inside the namespace. + async fn create_table( + &self, + namespace_ident: &NamespaceIdent, + table_creation: TableCreation, + ) -> Result
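// Implementation note (hedged; paths are illustrative): when `TableCreation.location`
// is not set, the table location is derived from the namespace `location` property,
// or else from the catalog's warehouse location plus the namespace path, e.g.
//   /tmp/warehouse/ns1/ns2/my_table
// and the first metadata file is written to
//   /tmp/warehouse/ns1/ns2/my_table/metadata/0-<uuid>.metadata.json
// Creation fails if neither a location nor a warehouse root is available. For example:
//
//   let creation = TableCreation::builder()
//       .name("my_table".to_string())
//       .schema(schema)
//       .build();                                      // no explicit .location(..)
//   let table = catalog.create_table(&ns_ident, creation).await?;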
{ + let mut root_namespace_state = self.root_namespace_state.lock().await; + + let table_name = table_creation.name.clone(); + let table_ident = TableIdent::new(namespace_ident.clone(), table_name); + + let (table_creation, location) = match table_creation.location.clone() { + Some(location) => (table_creation, location), + None => { + let namespace_properties = root_namespace_state.get_properties(namespace_ident)?; + let location_prefix = match namespace_properties.get(LOCATION) { + Some(namespace_location) => Ok(namespace_location.clone()), + None => match self.warehouse_location.clone() { + Some(warehouse_location) => Ok(format!("{}/{}", warehouse_location, namespace_ident.join("/"))), + None => Err(Error::new(ErrorKind::Unexpected, + format!( + "Cannot create table {:?}. No default path is set, please specify a location when creating a table.", + &table_ident + ))) + }, + }?; + + let location = format!("{}/{}", location_prefix, table_ident.name()); + + let new_table_creation = TableCreation { + location: Some(location.clone()), + ..table_creation + }; + + (new_table_creation, location) + } + }; + + let metadata = TableMetadataBuilder::from_table_creation(table_creation)?.build()?; + let metadata_location = format!( + "{}/metadata/{}-{}.metadata.json", + &location, + 0, + Uuid::new_v4() + ); + + self.file_io + .new_output(&metadata_location)? + .write(serde_json::to_vec(&metadata)?.into()) + .await?; + + root_namespace_state.insert_new_table(&table_ident, metadata_location.clone())?; + + Table::builder() + .file_io(self.file_io.clone()) + .metadata_location(metadata_location) + .metadata(metadata) + .identifier(table_ident) + .build() + } + + /// Load table from the catalog. + async fn load_table(&self, table_ident: &TableIdent) -> Result
{ + let root_namespace_state = self.root_namespace_state.lock().await; + + let metadata_location = root_namespace_state.get_existing_table_location(table_ident)?; + let input_file = self.file_io.new_input(metadata_location)?; + let metadata_content = input_file.read().await?; + let metadata = serde_json::from_slice::(&metadata_content)?; + + Table::builder() + .file_io(self.file_io.clone()) + .metadata_location(metadata_location.clone()) + .metadata(metadata) + .identifier(table_ident.clone()) + .build() + } + + /// Drop a table from the catalog. + async fn drop_table(&self, table_ident: &TableIdent) -> Result<()> { + let mut root_namespace_state = self.root_namespace_state.lock().await; + + root_namespace_state.remove_existing_table(table_ident) + } + + /// Check if a table exists in the catalog. + async fn table_exists(&self, table_ident: &TableIdent) -> Result { + let root_namespace_state = self.root_namespace_state.lock().await; + + root_namespace_state.table_exists(table_ident) + } + + /// Rename a table in the catalog. + async fn rename_table( + &self, + src_table_ident: &TableIdent, + dst_table_ident: &TableIdent, + ) -> Result<()> { + let mut root_namespace_state = self.root_namespace_state.lock().await; + + let mut new_root_namespace_state = root_namespace_state.clone(); + let metadata_location = new_root_namespace_state + .get_existing_table_location(src_table_ident)? + .clone(); + new_root_namespace_state.remove_existing_table(src_table_ident)?; + new_root_namespace_state.insert_new_table(dst_table_ident, metadata_location)?; + *root_namespace_state = new_root_namespace_state; + + Ok(()) + } + + /// Update a table to the catalog. + async fn update_table(&self, _commit: TableCommit) -> Result
{ + Err(Error::new( + ErrorKind::FeatureUnsupported, + "MemoryCatalog does not currently support updating tables.", + )) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::hash::Hash; + use std::iter::FromIterator; + + use iceberg::io::FileIOBuilder; + use iceberg::spec::{NestedField, PartitionSpec, PrimitiveType, Schema, SortOrder, Type}; + use regex::Regex; + use tempfile::TempDir; + + use super::*; + + fn temp_path() -> String { + let temp_dir = TempDir::new().unwrap(); + temp_dir.path().to_str().unwrap().to_string() + } + + fn new_memory_catalog() -> impl Catalog { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let warehouse_location = temp_path(); + MemoryCatalog::new(file_io, Some(warehouse_location)) + } + + async fn create_namespace<C: Catalog>(catalog: &C, namespace_ident: &NamespaceIdent) { + let _ = catalog + .create_namespace(namespace_ident, HashMap::new()) + .await + .unwrap(); + } + + async fn create_namespaces<C: Catalog>(catalog: &C, namespace_idents: &Vec<&NamespaceIdent>) { + for namespace_ident in namespace_idents { + let _ = create_namespace(catalog, namespace_ident).await; + } + } + + fn to_set<T: Hash + Eq>(vec: Vec<T>) -> HashSet<T> { + HashSet::from_iter(vec) + } + + fn simple_table_schema() -> Schema { + Schema::builder() + .with_fields(vec![NestedField::required( + 1, + "foo", + Type::Primitive(PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap() + } + + async fn create_table<C: Catalog>(catalog: &C, table_ident: &TableIdent) { + let _ = catalog + .create_table( + &table_ident.namespace, + TableCreation::builder() + .name(table_ident.name().into()) + .schema(simple_table_schema()) + .build(), + ) + .await + .unwrap(); + } + + async fn create_tables<C: Catalog>(catalog: &C, table_idents: Vec<&TableIdent>) { + for table_ident in table_idents { + create_table(catalog, table_ident).await; + } + } + + fn assert_table_eq(table: &Table, expected_table_ident: &TableIdent, expected_schema: &Schema) { + assert_eq!(table.identifier(), expected_table_ident); + + let metadata = table.metadata(); + + assert_eq!(metadata.current_schema().as_ref(), expected_schema); + + let expected_partition_spec = PartitionSpec::builder(expected_schema) + .with_spec_id(0) + .build() + .unwrap(); + + assert_eq!( + metadata + .partition_specs_iter() + .map(|p| p.as_ref()) + .collect_vec(), + vec![&expected_partition_spec] + ); + + let expected_sorted_order = SortOrder::builder() + .with_order_id(0) + .with_fields(vec![]) + .build(expected_schema) + .unwrap(); + + assert_eq!( + metadata + .sort_orders_iter() + .map(|s| s.as_ref()) + .collect_vec(), + vec![&expected_sorted_order] + ); + + assert_eq!(metadata.properties(), &HashMap::new()); + + assert!(!table.readonly()); + } + + const UUID_REGEX_STR: &str = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"; + + fn assert_table_metadata_location_matches(table: &Table, regex_str: &str) { + let actual = table.metadata_location().unwrap().to_string(); + let regex = Regex::new(regex_str).unwrap(); + assert!(regex.is_match(&actual)) + } + + #[tokio::test] + async fn test_list_namespaces_returns_empty_vector() { + let catalog = new_memory_catalog(); + + assert_eq!(catalog.list_namespaces(None).await.unwrap(), vec![]); + } + + #[tokio::test] + async fn test_list_namespaces_returns_single_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("abc".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert_eq!(catalog.list_namespaces(None).await.unwrap(), vec![ + namespace_ident + ]); + } + + 
#[tokio::test] + async fn test_list_namespaces_returns_multiple_namespaces() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![&namespace_ident_1, &namespace_ident_2]).await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1, namespace_ident_2]) + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_only_top_level_namespaces() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + ]) + .await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1, namespace_ident_3]) + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_no_namespaces_under_parent() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![&namespace_ident_1, &namespace_ident_2]).await; + + assert_eq!( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap(), + vec![] + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_namespace_under_parent() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::new("c".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + ]) + .await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1.clone(), namespace_ident_3]) + ); + + assert_eq!( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap(), + vec![NamespaceIdent::new("b".into())] + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_multiple_namespaces_under_parent() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("a".to_string()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "a"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_4 = NamespaceIdent::from_strs(vec!["a", "c"]).unwrap(); + let namespace_ident_5 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + &namespace_ident_4, + &namespace_ident_5, + ]) + .await; + + assert_eq!( + to_set( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap() + ), + to_set(vec![ + NamespaceIdent::new("a".into()), + NamespaceIdent::new("b".into()), + NamespaceIdent::new("c".into()), + ]) + ); + } + + #[tokio::test] + async fn test_namespace_exists_returns_false() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert!(!catalog + .namespace_exists(&NamespaceIdent::new("b".into())) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_namespace_exists_returns_true() { + let catalog = new_memory_catalog(); + let namespace_ident = 
NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert!(catalog.namespace_exists(&namespace_ident).await.unwrap()); + } + + #[tokio::test] + async fn test_create_namespace_with_empty_properties() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + + assert_eq!( + catalog + .create_namespace(&namespace_ident, HashMap::new()) + .await + .unwrap(), + Namespace::new(namespace_ident.clone()) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_create_namespace_with_properties() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("abc".into()); + + let mut properties: HashMap = HashMap::new(); + properties.insert("k".into(), "v".into()); + + assert_eq!( + catalog + .create_namespace(&namespace_ident, properties.clone()) + .await + .unwrap(), + Namespace::with_properties(namespace_ident.clone(), properties.clone()) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, properties) + ); + } + + #[tokio::test] + async fn test_create_namespace_throws_error_if_namespace_already_exists() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert_eq!( + catalog + .create_namespace(&namespace_ident, HashMap::new()) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => Cannot create namespace {:?}. Namespace already exists.", + &namespace_ident + ) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_create_nested_namespace() { + let catalog = new_memory_catalog(); + let parent_namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &parent_namespace_ident).await; + + let child_namespace_ident = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&child_namespace_ident, HashMap::new()) + .await + .unwrap(), + Namespace::new(child_namespace_ident.clone()) + ); + + assert_eq!( + catalog.get_namespace(&child_namespace_ident).await.unwrap(), + Namespace::with_properties(child_namespace_ident, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_create_deeply_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&namespace_ident_a_b_c, HashMap::new()) + .await + .unwrap(), + Namespace::new(namespace_ident_a_b_c.clone()) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b_c).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b_c, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_create_nested_namespace_throws_error_if_top_level_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let nested_namespace_ident = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&nested_namespace_ident, HashMap::new()) + .await + 
.unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + NamespaceIdent::new("a".into()) + ) + ); + + assert_eq!(catalog.list_namespaces(None).await.unwrap(), vec![]); + } + + #[tokio::test] + async fn test_create_deeply_nested_namespace_throws_error_if_intermediate_namespace_doesnt_exist( + ) { + let catalog = new_memory_catalog(); + + let namespace_ident_a = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident_a).await; + + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&namespace_ident_a_b_c, HashMap::new()) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + NamespaceIdent::from_strs(vec!["a", "b"]).unwrap() + ) + ); + + assert_eq!(catalog.list_namespaces(None).await.unwrap(), vec![ + namespace_ident_a.clone() + ]); + + assert_eq!( + catalog + .list_namespaces(Some(&namespace_ident_a)) + .await + .unwrap(), + vec![] + ); + } + + #[tokio::test] + async fn test_get_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("abc".into()); + + let mut properties: HashMap = HashMap::new(); + properties.insert("k".into(), "v".into()); + let _ = catalog + .create_namespace(&namespace_ident, properties.clone()) + .await + .unwrap(); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, properties) + ) + } + + #[tokio::test] + async fn test_get_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_get_deeply_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + create_namespaces(&catalog, &vec![ + &namespace_ident_a, + &namespace_ident_a_b, + &namespace_ident_a_b_c, + ]) + .await; + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b_c).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b_c, HashMap::new()) + ); + } + + #[tokio::test] + async fn test_get_namespace_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + create_namespace(&catalog, &NamespaceIdent::new("a".into())).await; + + let non_existent_namespace_ident = NamespaceIdent::new("b".into()); + assert_eq!( + catalog + .get_namespace(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + async fn test_update_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("abc".into()); + create_namespace(&catalog, &namespace_ident).await; + + let mut new_properties: HashMap = HashMap::new(); + new_properties.insert("k".into(), "v".into()); + + catalog + .update_namespace(&namespace_ident, new_properties.clone()) + .await + .unwrap(); + + assert_eq!( + 
catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, new_properties) + ) + } + + #[tokio::test] + async fn test_update_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let mut new_properties = HashMap::new(); + new_properties.insert("k".into(), "v".into()); + + catalog + .update_namespace(&namespace_ident_a_b, new_properties.clone()) + .await + .unwrap(); + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b, new_properties) + ); + } + + #[tokio::test] + async fn test_update_deeply_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + create_namespaces(&catalog, &vec![ + &namespace_ident_a, + &namespace_ident_a_b, + &namespace_ident_a_b_c, + ]) + .await; + + let mut new_properties = HashMap::new(); + new_properties.insert("k".into(), "v".into()); + + catalog + .update_namespace(&namespace_ident_a_b_c, new_properties.clone()) + .await + .unwrap(); + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b_c).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b_c, new_properties) + ); + } + + #[tokio::test] + async fn test_update_namespace_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + create_namespace(&catalog, &NamespaceIdent::new("abc".into())).await; + + let non_existent_namespace_ident = NamespaceIdent::new("def".into()); + assert_eq!( + catalog + .update_namespace(&non_existent_namespace_ident, HashMap::new()) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + async fn test_drop_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("abc".into()); + create_namespace(&catalog, &namespace_ident).await; + + catalog.drop_namespace(&namespace_ident).await.unwrap(); + + assert!(!catalog.namespace_exists(&namespace_ident).await.unwrap()) + } + + #[tokio::test] + async fn test_drop_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + catalog.drop_namespace(&namespace_ident_a_b).await.unwrap(); + + assert!(!catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + + assert!(catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + } + + #[tokio::test] + async fn test_drop_deeply_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + create_namespaces(&catalog, &vec![ + &namespace_ident_a, + &namespace_ident_a_b, + &namespace_ident_a_b_c, + ]) + .await; + + catalog + .drop_namespace(&namespace_ident_a_b_c) + .await + .unwrap(); + + 
assert!(!catalog + .namespace_exists(&namespace_ident_a_b_c) + .await + .unwrap()); + + assert!(catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + + assert!(catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + } + + #[tokio::test] + async fn test_drop_namespace_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let non_existent_namespace_ident = NamespaceIdent::new("abc".into()); + assert_eq!( + catalog + .drop_namespace(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + async fn test_drop_namespace_throws_error_if_nested_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + create_namespace(&catalog, &NamespaceIdent::new("a".into())).await; + + let non_existent_namespace_ident = + NamespaceIdent::from_vec(vec!["a".into(), "b".into()]).unwrap(); + assert_eq!( + catalog + .drop_namespace(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + async fn test_dropping_a_namespace_also_drops_namespaces_nested_under_that_one() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + catalog.drop_namespace(&namespace_ident_a).await.unwrap(); + + assert!(!catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + + assert!(!catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_create_table_with_location() { + let tmp_dir = TempDir::new().unwrap(); + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + let table_name = "abc"; + let location = tmp_dir.path().to_str().unwrap().to_string(); + let table_creation = TableCreation::builder() + .name(table_name.into()) + .location(location.clone()) + .schema(simple_table_schema()) + .build(); + + let expected_table_ident = TableIdent::new(namespace_ident.clone(), table_name.into()); + + assert_table_eq( + &catalog + .create_table(&namespace_ident, table_creation) + .await + .unwrap(), + &expected_table_ident, + &simple_table_schema(), + ); + + let table = catalog.load_table(&expected_table_ident).await.unwrap(); + + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + + assert!(table + .metadata_location() + .unwrap() + .to_string() + .starts_with(&location)) + } + + #[tokio::test] + async fn test_create_table_falls_back_to_namespace_location_if_table_location_is_missing() { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let warehouse_location = temp_path(); + let catalog = MemoryCatalog::new(file_io, Some(warehouse_location.clone())); + + let namespace_ident = NamespaceIdent::new("a".into()); + let mut namespace_properties = HashMap::new(); + let namespace_location = temp_path(); + namespace_properties.insert(LOCATION.to_string(), namespace_location.to_string()); + catalog + .create_namespace(&namespace_ident, namespace_properties) + .await + .unwrap(); + + let table_name = "tbl1"; + let expected_table_ident = TableIdent::new(namespace_ident.clone(), table_name.into()); + let 
expected_table_metadata_location_regex = format!( + "^{}/tbl1/metadata/0-{}.metadata.json$", + namespace_location, UUID_REGEX_STR, + ); + + let table = catalog + .create_table( + &namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + // no location specified for table + .build(), + ) + .await + .unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + + let table = catalog.load_table(&expected_table_ident).await.unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + } + + #[tokio::test] + async fn test_create_table_in_nested_namespace_falls_back_to_nested_namespace_location_if_table_location_is_missing( + ) { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let warehouse_location = temp_path(); + let catalog = MemoryCatalog::new(file_io, Some(warehouse_location.clone())); + + let namespace_ident = NamespaceIdent::new("a".into()); + let mut namespace_properties = HashMap::new(); + let namespace_location = temp_path(); + namespace_properties.insert(LOCATION.to_string(), namespace_location.to_string()); + catalog + .create_namespace(&namespace_ident, namespace_properties) + .await + .unwrap(); + + let nested_namespace_ident = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let mut nested_namespace_properties = HashMap::new(); + let nested_namespace_location = temp_path(); + nested_namespace_properties + .insert(LOCATION.to_string(), nested_namespace_location.to_string()); + catalog + .create_namespace(&nested_namespace_ident, nested_namespace_properties) + .await + .unwrap(); + + let table_name = "tbl1"; + let expected_table_ident = + TableIdent::new(nested_namespace_ident.clone(), table_name.into()); + let expected_table_metadata_location_regex = format!( + "^{}/tbl1/metadata/0-{}.metadata.json$", + nested_namespace_location, UUID_REGEX_STR, + ); + + let table = catalog + .create_table( + &nested_namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + // no location specified for table + .build(), + ) + .await + .unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + + let table = catalog.load_table(&expected_table_ident).await.unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + } + + #[tokio::test] + async fn test_create_table_falls_back_to_warehouse_location_if_both_table_location_and_namespace_location_are_missing( + ) { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let warehouse_location = temp_path(); + let catalog = MemoryCatalog::new(file_io, Some(warehouse_location.clone())); + + let namespace_ident = NamespaceIdent::new("a".into()); + // note: no location specified in namespace_properties + let namespace_properties = HashMap::new(); + catalog + .create_namespace(&namespace_ident, namespace_properties) + .await + .unwrap(); + + let table_name = "tbl1"; + let expected_table_ident = TableIdent::new(namespace_ident.clone(), table_name.into()); + let expected_table_metadata_location_regex = format!( + "^{}/a/tbl1/metadata/0-{}.metadata.json$", + warehouse_location, 
UUID_REGEX_STR + ); + + let table = catalog + .create_table( + &namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + // no location specified for table + .build(), + ) + .await + .unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + + let table = catalog.load_table(&expected_table_ident).await.unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + } + + #[tokio::test] + async fn test_create_table_in_nested_namespace_falls_back_to_warehouse_location_if_both_table_location_and_namespace_location_are_missing( + ) { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let warehouse_location = temp_path(); + let catalog = MemoryCatalog::new(file_io, Some(warehouse_location.clone())); + + let namespace_ident = NamespaceIdent::new("a".into()); + catalog + // note: no location specified in namespace_properties + .create_namespace(&namespace_ident, HashMap::new()) + .await + .unwrap(); + + let nested_namespace_ident = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + catalog + // note: no location specified in namespace_properties + .create_namespace(&nested_namespace_ident, HashMap::new()) + .await + .unwrap(); + + let table_name = "tbl1"; + let expected_table_ident = + TableIdent::new(nested_namespace_ident.clone(), table_name.into()); + let expected_table_metadata_location_regex = format!( + "^{}/a/b/tbl1/metadata/0-{}.metadata.json$", + warehouse_location, UUID_REGEX_STR + ); + + let table = catalog + .create_table( + &nested_namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + // no location specified for table + .build(), + ) + .await + .unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + + let table = catalog.load_table(&expected_table_ident).await.unwrap(); + assert_table_eq(&table, &expected_table_ident, &simple_table_schema()); + assert_table_metadata_location_matches(&table, &expected_table_metadata_location_regex); + } + + #[tokio::test] + async fn test_create_table_throws_error_if_table_location_and_namespace_location_and_warehouse_location_are_missing( + ) { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let catalog = MemoryCatalog::new(file_io, None); + + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + let table_name = "tbl1"; + let expected_table_ident = TableIdent::new(namespace_ident.clone(), table_name.into()); + + assert_eq!( + catalog + .create_table( + &namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + .build(), + ) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => Cannot create table {:?}. 
No default path is set, please specify a location when creating a table.", + &expected_table_ident + ) + ) + } + + #[tokio::test] + async fn test_create_table_throws_error_if_table_with_same_name_already_exists() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + let table_name = "tbl1"; + let table_ident = TableIdent::new(namespace_ident.clone(), table_name.into()); + create_table(&catalog, &table_ident).await; + + let tmp_dir = TempDir::new().unwrap(); + let location = tmp_dir.path().to_str().unwrap().to_string(); + + assert_eq!( + catalog + .create_table( + &namespace_ident, + TableCreation::builder() + .name(table_name.into()) + .schema(simple_table_schema()) + .location(location) + .build() + ) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => Cannot create table {:?}. Table already exists.", + &table_ident + ) + ); + } + + #[tokio::test] + async fn test_list_tables_returns_empty_vector() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert_eq!(catalog.list_tables(&namespace_ident).await.unwrap(), vec![]); + } + + #[tokio::test] + async fn test_list_tables_returns_a_single_table() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + + let table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + assert_eq!(catalog.list_tables(&namespace_ident).await.unwrap(), vec![ + table_ident + ]); + } + + #[tokio::test] + async fn test_list_tables_returns_multiple_tables() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + + let table_ident_1 = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + let table_ident_2 = TableIdent::new(namespace_ident.clone(), "tbl2".into()); + let _ = create_tables(&catalog, vec![&table_ident_1, &table_ident_2]).await; + + assert_eq!( + to_set(catalog.list_tables(&namespace_ident).await.unwrap()), + to_set(vec![table_ident_1, table_ident_2]) + ); + } + + #[tokio::test] + async fn test_list_tables_returns_tables_from_correct_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_1 = NamespaceIdent::new("n1".into()); + let namespace_ident_2 = NamespaceIdent::new("n2".into()); + create_namespaces(&catalog, &vec![&namespace_ident_1, &namespace_ident_2]).await; + + let table_ident_1 = TableIdent::new(namespace_ident_1.clone(), "tbl1".into()); + let table_ident_2 = TableIdent::new(namespace_ident_1.clone(), "tbl2".into()); + let table_ident_3 = TableIdent::new(namespace_ident_2.clone(), "tbl1".into()); + let _ = create_tables(&catalog, vec![ + &table_ident_1, + &table_ident_2, + &table_ident_3, + ]) + .await; + + assert_eq!( + to_set(catalog.list_tables(&namespace_ident_1).await.unwrap()), + to_set(vec![table_ident_1, table_ident_2]) + ); + + assert_eq!( + to_set(catalog.list_tables(&namespace_ident_2).await.unwrap()), + to_set(vec![table_ident_3]) + ); + } + + #[tokio::test] + async fn test_list_tables_returns_table_under_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, 
&vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let table_ident = TableIdent::new(namespace_ident_a_b.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + assert_eq!( + catalog.list_tables(&namespace_ident_a_b).await.unwrap(), + vec![table_ident] + ); + } + + #[tokio::test] + async fn test_list_tables_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let non_existent_namespace_ident = NamespaceIdent::new("n1".into()); + + assert_eq!( + catalog + .list_tables(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ), + ); + } + + #[tokio::test] + async fn test_drop_table() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + catalog.drop_table(&table_ident).await.unwrap(); + } + + #[tokio::test] + async fn test_drop_table_drops_table_under_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let table_ident = TableIdent::new(namespace_ident_a_b.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + catalog.drop_table(&table_ident).await.unwrap(); + + assert_eq!( + catalog.list_tables(&namespace_ident_a_b).await.unwrap(), + vec![] + ); + } + + #[tokio::test] + async fn test_drop_table_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let non_existent_namespace_ident = NamespaceIdent::new("n1".into()); + let non_existent_table_ident = + TableIdent::new(non_existent_namespace_ident.clone(), "tbl1".into()); + + assert_eq!( + catalog + .drop_table(&non_existent_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ), + ); + } + + #[tokio::test] + async fn test_drop_table_throws_error_if_table_doesnt_exist() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + + let non_existent_table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + + assert_eq!( + catalog + .drop_table(&non_existent_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such table: {:?}", + non_existent_table_ident + ), + ); + } + + #[tokio::test] + async fn test_table_exists_returns_true() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + assert!(catalog.table_exists(&table_ident).await.unwrap()); + } + + #[tokio::test] + async fn test_table_exists_returns_false() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let non_existent_table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + + assert!(!catalog + .table_exists(&non_existent_table_ident) + .await + .unwrap()); + } + + 
#[tokio::test] + async fn test_table_exists_under_nested_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let table_ident = TableIdent::new(namespace_ident_a_b.clone(), "tbl1".into()); + create_table(&catalog, &table_ident).await; + + assert!(catalog.table_exists(&table_ident).await.unwrap()); + + let non_existent_table_ident = TableIdent::new(namespace_ident_a_b.clone(), "tbl2".into()); + assert!(!catalog + .table_exists(&non_existent_table_ident) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_table_exists_throws_error_if_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let non_existent_namespace_ident = NamespaceIdent::new("n1".into()); + let non_existent_table_ident = + TableIdent::new(non_existent_namespace_ident.clone(), "tbl1".into()); + + assert_eq!( + catalog + .table_exists(&non_existent_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ), + ); + } + + #[tokio::test] + async fn test_rename_table_in_same_namespace() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let src_table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + let dst_table_ident = TableIdent::new(namespace_ident.clone(), "tbl2".into()); + create_table(&catalog, &src_table_ident).await; + + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap(); + + assert_eq!(catalog.list_tables(&namespace_ident).await.unwrap(), vec![ + dst_table_ident + ],); + } + + #[tokio::test] + async fn test_rename_table_across_namespaces() { + let catalog = new_memory_catalog(); + let src_namespace_ident = NamespaceIdent::new("a".into()); + let dst_namespace_ident = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![&src_namespace_ident, &dst_namespace_ident]).await; + let src_table_ident = TableIdent::new(src_namespace_ident.clone(), "tbl1".into()); + let dst_table_ident = TableIdent::new(dst_namespace_ident.clone(), "tbl2".into()); + create_table(&catalog, &src_table_ident).await; + + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap(); + + assert_eq!( + catalog.list_tables(&src_namespace_ident).await.unwrap(), + vec![], + ); + + assert_eq!( + catalog.list_tables(&dst_namespace_ident).await.unwrap(), + vec![dst_table_ident], + ); + } + + #[tokio::test] + async fn test_rename_table_src_table_is_same_as_dst_table() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let table_ident = TableIdent::new(namespace_ident.clone(), "tbl".into()); + create_table(&catalog, &table_ident).await; + + catalog + .rename_table(&table_ident, &table_ident) + .await + .unwrap(); + + assert_eq!(catalog.list_tables(&namespace_ident).await.unwrap(), vec![ + table_ident + ],); + } + + #[tokio::test] + async fn test_rename_table_across_nested_namespaces() { + let catalog = new_memory_catalog(); + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + 
create_namespaces(&catalog, &vec![ + &namespace_ident_a, + &namespace_ident_a_b, + &namespace_ident_a_b_c, + ]) + .await; + + let src_table_ident = TableIdent::new(namespace_ident_a_b_c.clone(), "tbl1".into()); + create_tables(&catalog, vec![&src_table_ident]).await; + + let dst_table_ident = TableIdent::new(namespace_ident_a_b.clone(), "tbl1".into()); + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap(); + + assert!(!catalog.table_exists(&src_table_ident).await.unwrap()); + + assert!(catalog.table_exists(&dst_table_ident).await.unwrap()); + } + + #[tokio::test] + async fn test_rename_table_throws_error_if_src_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + + let non_existent_src_namespace_ident = NamespaceIdent::new("n1".into()); + let src_table_ident = + TableIdent::new(non_existent_src_namespace_ident.clone(), "tbl1".into()); + + let dst_namespace_ident = NamespaceIdent::new("n2".into()); + create_namespace(&catalog, &dst_namespace_ident).await; + let dst_table_ident = TableIdent::new(dst_namespace_ident.clone(), "tbl1".into()); + + assert_eq!( + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_src_namespace_ident + ), + ); + } + + #[tokio::test] + async fn test_rename_table_throws_error_if_dst_namespace_doesnt_exist() { + let catalog = new_memory_catalog(); + let src_namespace_ident = NamespaceIdent::new("n1".into()); + let src_table_ident = TableIdent::new(src_namespace_ident.clone(), "tbl1".into()); + create_namespace(&catalog, &src_namespace_ident).await; + create_table(&catalog, &src_table_ident).await; + + let non_existent_dst_namespace_ident = NamespaceIdent::new("n2".into()); + let dst_table_ident = + TableIdent::new(non_existent_dst_namespace_ident.clone(), "tbl1".into()); + assert_eq!( + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_dst_namespace_ident + ), + ); + } + + #[tokio::test] + async fn test_rename_table_throws_error_if_src_table_doesnt_exist() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let src_table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + let dst_table_ident = TableIdent::new(namespace_ident.clone(), "tbl2".into()); + + assert_eq!( + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap_err() + .to_string(), + format!("Unexpected => No such table: {:?}", src_table_ident), + ); + } + + #[tokio::test] + async fn test_rename_table_throws_error_if_dst_table_already_exists() { + let catalog = new_memory_catalog(); + let namespace_ident = NamespaceIdent::new("n1".into()); + create_namespace(&catalog, &namespace_ident).await; + let src_table_ident = TableIdent::new(namespace_ident.clone(), "tbl1".into()); + let dst_table_ident = TableIdent::new(namespace_ident.clone(), "tbl2".into()); + create_tables(&catalog, vec![&src_table_ident, &dst_table_ident]).await; + + assert_eq!( + catalog + .rename_table(&src_table_ident, &dst_table_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => Cannot create table {:? }. 
Table already exists.", + &dst_table_ident + ), + ); + } +} diff --git a/crates/catalog/memory/src/lib.rs b/crates/catalog/memory/src/lib.rs new file mode 100644 index 000000000..8988ac7b2 --- /dev/null +++ b/crates/catalog/memory/src/lib.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Iceberg memory Catalog API implementation. + +#![deny(missing_docs)] + +mod catalog; +mod namespace_state; + +pub use catalog::*; diff --git a/crates/catalog/memory/src/namespace_state.rs b/crates/catalog/memory/src/namespace_state.rs new file mode 100644 index 000000000..a65319568 --- /dev/null +++ b/crates/catalog/memory/src/namespace_state.rs @@ -0,0 +1,298 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::{hash_map, HashMap}; + +use iceberg::{Error, ErrorKind, NamespaceIdent, Result, TableIdent}; +use itertools::Itertools; + +// Represents the state of a namespace +#[derive(Debug, Clone, Default)] +pub(crate) struct NamespaceState { + // Properties of this namespace + properties: HashMap, + // Namespaces nested inside this namespace + namespaces: HashMap, + // Mapping of tables to metadata locations in this namespace + table_metadata_locations: HashMap, +} + +fn no_such_namespace_err(namespace_ident: &NamespaceIdent) -> Result { + Err(Error::new( + ErrorKind::Unexpected, + format!("No such namespace: {:?}", namespace_ident), + )) +} + +fn no_such_table_err(table_ident: &TableIdent) -> Result { + Err(Error::new( + ErrorKind::Unexpected, + format!("No such table: {:?}", table_ident), + )) +} + +fn namespace_already_exists_err(namespace_ident: &NamespaceIdent) -> Result { + Err(Error::new( + ErrorKind::Unexpected, + format!( + "Cannot create namespace {:?}. Namespace already exists.", + namespace_ident + ), + )) +} + +fn table_already_exists_err(table_ident: &TableIdent) -> Result { + Err(Error::new( + ErrorKind::Unexpected, + format!( + "Cannot create table {:?}. 
Table already exists.", + table_ident + ), + )) +} + +impl NamespaceState { + // Returns the state of the given namespace or an error if doesn't exist + fn get_namespace(&self, namespace_ident: &NamespaceIdent) -> Result<&NamespaceState> { + let mut acc_name_parts = vec![]; + let mut namespace_state = self; + for next_name in namespace_ident.iter() { + acc_name_parts.push(next_name); + match namespace_state.namespaces.get(next_name) { + None => { + let namespace_ident = NamespaceIdent::from_strs(acc_name_parts)?; + return no_such_namespace_err(&namespace_ident); + } + Some(intermediate_namespace) => { + namespace_state = intermediate_namespace; + } + } + } + + Ok(namespace_state) + } + + // Returns the state of the given namespace or an error if doesn't exist + fn get_mut_namespace( + &mut self, + namespace_ident: &NamespaceIdent, + ) -> Result<&mut NamespaceState> { + let mut acc_name_parts = vec![]; + let mut namespace_state = self; + for next_name in namespace_ident.iter() { + acc_name_parts.push(next_name); + match namespace_state.namespaces.get_mut(next_name) { + None => { + let namespace_ident = NamespaceIdent::from_strs(acc_name_parts)?; + return no_such_namespace_err(&namespace_ident); + } + Some(intermediate_namespace) => { + namespace_state = intermediate_namespace; + } + } + } + + Ok(namespace_state) + } + + // Returns the state of the parent of the given namespace or an error if doesn't exist + fn get_mut_parent_namespace_of( + &mut self, + namespace_ident: &NamespaceIdent, + ) -> Result<(&mut NamespaceState, String)> { + match namespace_ident.split_last() { + None => Err(Error::new( + ErrorKind::DataInvalid, + "Namespace identifier can't be empty!", + )), + Some((child_namespace_name, parent_name_parts)) => { + let parent_namespace_state = if parent_name_parts.is_empty() { + Ok(self) + } else { + let parent_namespace_ident = NamespaceIdent::from_strs(parent_name_parts)?; + self.get_mut_namespace(&parent_namespace_ident) + }?; + + Ok((parent_namespace_state, child_namespace_name.clone())) + } + } + } + + // Returns any top-level namespaces + pub(crate) fn list_top_level_namespaces(&self) -> Vec<&String> { + self.namespaces.keys().collect_vec() + } + + // Returns any namespaces nested under the given namespace or an error if the given namespace does not exist + pub(crate) fn list_namespaces_under( + &self, + namespace_ident: &NamespaceIdent, + ) -> Result> { + let nested_namespace_names = self + .get_namespace(namespace_ident)? 
+ .namespaces + .keys() + .collect_vec(); + + Ok(nested_namespace_names) + } + + // Returns true if the given namespace exists, otherwise false + pub(crate) fn namespace_exists(&self, namespace_ident: &NamespaceIdent) -> bool { + self.get_namespace(namespace_ident).is_ok() + } + + // Inserts the given namespace or returns an error if it already exists + pub(crate) fn insert_new_namespace( + &mut self, + namespace_ident: &NamespaceIdent, + properties: HashMap<String, String>, + ) -> Result<()> { + let (parent_namespace_state, child_namespace_name) = + self.get_mut_parent_namespace_of(namespace_ident)?; + + match parent_namespace_state + .namespaces + .entry(child_namespace_name) + { + hash_map::Entry::Occupied(_) => namespace_already_exists_err(namespace_ident), + hash_map::Entry::Vacant(entry) => { + let _ = entry.insert(NamespaceState { + properties, + namespaces: HashMap::new(), + table_metadata_locations: HashMap::new(), + }); + + Ok(()) + } + } + } + + // Removes the given namespace or returns an error if it doesn't exist + pub(crate) fn remove_existing_namespace( + &mut self, + namespace_ident: &NamespaceIdent, + ) -> Result<()> { + let (parent_namespace_state, child_namespace_name) = + self.get_mut_parent_namespace_of(namespace_ident)?; + + match parent_namespace_state + .namespaces + .remove(&child_namespace_name) + { + None => no_such_namespace_err(namespace_ident), + Some(_) => Ok(()), + } + } + + // Returns the properties of the given namespace or an error if it doesn't exist + pub(crate) fn get_properties( + &self, + namespace_ident: &NamespaceIdent, + ) -> Result<&HashMap<String, String>> { + let properties = &self.get_namespace(namespace_ident)?.properties; + + Ok(properties) + } + + // Returns the mutable properties of the given namespace or an error if it doesn't exist + fn get_mut_properties( + &mut self, + namespace_ident: &NamespaceIdent, + ) -> Result<&mut HashMap<String, String>> { + let properties = &mut self.get_mut_namespace(namespace_ident)?.properties; + + Ok(properties) + } + + // Replaces the properties of the given namespace or returns an error if it doesn't exist + pub(crate) fn replace_properties( + &mut self, + namespace_ident: &NamespaceIdent, + new_properties: HashMap<String, String>, + ) -> Result<()> { + let properties = self.get_mut_properties(namespace_ident)?; + *properties = new_properties; + + Ok(()) + } + + // Returns the list of table names under the given namespace + pub(crate) fn list_tables(&self, namespace_ident: &NamespaceIdent) -> Result<Vec<&String>> { + let table_names = self + .get_namespace(namespace_ident)?
+ .table_metadata_locations + .keys() + .collect_vec(); + + Ok(table_names) + } + + // Returns true if the given table exists, otherwise false + pub(crate) fn table_exists(&self, table_ident: &TableIdent) -> Result { + let namespace_state = self.get_namespace(table_ident.namespace())?; + let table_exists = namespace_state + .table_metadata_locations + .contains_key(&table_ident.name); + + Ok(table_exists) + } + + // Returns the metadata location of the given table or an error if doesn't exist + pub(crate) fn get_existing_table_location(&self, table_ident: &TableIdent) -> Result<&String> { + let namespace = self.get_namespace(table_ident.namespace())?; + + match namespace.table_metadata_locations.get(table_ident.name()) { + None => no_such_table_err(table_ident), + Some(table_metadadata_location) => Ok(table_metadadata_location), + } + } + + // Inserts the given table or returns an error if it already exists + pub(crate) fn insert_new_table( + &mut self, + table_ident: &TableIdent, + metadata_location: String, + ) -> Result<()> { + let namespace = self.get_mut_namespace(table_ident.namespace())?; + + match namespace + .table_metadata_locations + .entry(table_ident.name().to_string()) + { + hash_map::Entry::Occupied(_) => table_already_exists_err(table_ident), + hash_map::Entry::Vacant(entry) => { + let _ = entry.insert(metadata_location); + + Ok(()) + } + } + } + + // Removes the given table or returns an error if doesn't exist + pub(crate) fn remove_existing_table(&mut self, table_ident: &TableIdent) -> Result<()> { + let namespace = self.get_mut_namespace(table_ident.namespace())?; + + match namespace + .table_metadata_locations + .remove(table_ident.name()) + { + None => no_such_table_err(table_ident), + Some(_) => Ok(()), + } + } +} diff --git a/crates/catalog/rest/Cargo.toml b/crates/catalog/rest/Cargo.toml index 883f55c02..add57183b 100644 --- a/crates/catalog/rest/Cargo.toml +++ b/crates/catalog/rest/Cargo.toml @@ -17,30 +17,35 @@ [package] name = "iceberg-catalog-rest" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } categories = ["database"] description = "Apache Iceberg Rust REST API" -repository = "https://github.com/apache/iceberg-rust" -license = "Apache-2.0" +repository = { workspace = true } +license = { workspace = true } keywords = ["iceberg", "rest", "catalog"] [dependencies] # async-trait = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } +http = "1.1.0" iceberg = { workspace = true } +itertools = { workspace = true } log = "0.4.20" reqwest = { workspace = true } serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } +tokio = { workspace = true, features = ["sync"] } typed-builder = { workspace = true } -urlencoding = { workspace = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] +ctor = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } mockito = { workspace = true } port_scanner = { workspace = true } diff --git a/crates/catalog/rest/DEPENDENCIES.rust.tsv b/crates/catalog/rest/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..43b4ed3d3 --- /dev/null +++ b/crates/catalog/rest/DEPENDENCIES.rust.tsv @@ -0,0 +1,288 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X 
+adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayvec@0.7.4 X X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-ord@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-trait@0.1.81 X X +atoi@2.0.0 X +atomic-waker@1.1.2 X X +autocfg@1.3.0 X X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.22.1 X X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitflags@2.6.0 X X +bitvec@1.0.1 X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +cc@1.1.11 X X +cfg-if@1.0.0 X X +chrono@0.4.38 X X +colorchoice@1.0.2 X X +const-oid@0.9.6 X X +const-random@0.1.18 X X +const-random-macro@0.1.16 X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +equivalent@1.0.1 X X +fastrand@2.1.0 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +h2@0.4.5 X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hmac@0.12.1 X X +home@0.5.9 X X +http@1.1.0 X X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +httpdate@1.0.3 X X +humantime@2.1.0 X X +hyper@1.4.1 X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg-catalog-rest@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X +idna@0.5.0 X X +indexmap@2.4.0 X X +integer-encoding@3.0.4 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +lock_api@0.4.12 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +md-5@0.10.6 X X +memchr@2.7.4 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +murmur3@0.5.2 X X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +object@0.36.3 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +parking_lot@0.12.3 X X +parking_lot_core@0.9.10 X X +parquet@52.2.0 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X 
X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +ppv-lite86@0.2.20 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X +radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +redox_syscall@0.5.3 X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc_version@0.4.0 X X +rustls@0.23.12 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +scopeguard@1.2.0 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snap@1.1.1 X +socket2@0.5.7 X X +spin@0.9.8 X +static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.26.0 X X +tokio-util@0.7.11 X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +untrusted@0.9.0 X +url@2.5.2 X X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winreg@0.52.0 X +wyz@0.5.1 X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/catalog/rest/README.md b/crates/catalog/rest/README.md new file mode 100644 index 000000000..e3bb70e94 --- /dev/null +++ b/crates/catalog/rest/README.md @@ -0,0 +1,27 @@ + + +# Apache Iceberg Rest Catalog Official Native Rust Implementation + +[![crates.io](https://img.shields.io/crates/v/iceberg.svg)](https://crates.io/crates/iceberg-catalog-rest) +[![docs.rs](https://img.shields.io/docsrs/iceberg.svg)](https://docs.rs/iceberg/latest/iceberg-catalog-rest/) + +This crate contains the official Native Rust implementation of Apache Iceberg Rest Catalog. 
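As a quick orientation for readers of this patch, the snippet below is a minimal usage sketch of the crate added here. It assumes a REST catalog server reachable at `http://localhost:8181` and a Tokio runtime; the builder, constructor and trait calls mirror the ones exercised by the tests later in this diff (note that `RestCatalog::new` is now synchronous).

```rust
use std::collections::HashMap;

use iceberg::Catalog;
use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig};

#[tokio::main]
async fn main() -> iceberg::Result<()> {
    let config = RestCatalogConfig::builder()
        .uri("http://localhost:8181".to_string())
        // Optional properties: credential, token, warehouse, header.*, prefix, ...
        .props(HashMap::from([(
            "credential".to_string(),
            "client_id:client_secret".to_string(),
        )]))
        .build();

    // `new` no longer performs I/O; the /v1/config endpoint is queried lazily
    // on the first catalog call.
    let catalog = RestCatalog::new(config);

    for ns in catalog.list_namespaces(None).await? {
        println!("namespace: {ns:?}");
    }
    Ok(())
}
```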
+ +See the [API documentation](https://docs.rs/iceberg-catalog-rest/latest) for examples and the full API. diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 7ccd108b6..1181c3cc1 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -18,35 +18,36 @@ //! This module contains rest catalog implementation. use std::collections::HashMap; +use std::str::FromStr; use async_trait::async_trait; -use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; -use reqwest::{Client, Request, Response, StatusCode}; -use serde::de::DeserializeOwned; -use typed_builder::TypedBuilder; -use urlencoding::encode; - -use crate::catalog::_serde::{ - CommitTableRequest, CommitTableResponse, CreateTableRequest, LoadTableResponse, -}; use iceberg::io::FileIO; use iceberg::table::Table; -use iceberg::Result; use iceberg::{ - Catalog, Error, ErrorKind, Namespace, NamespaceIdent, TableCommit, TableCreation, TableIdent, + Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, TableCreation, + TableIdent, +}; +use itertools::Itertools; +use reqwest::header::{ + HeaderMap, HeaderName, HeaderValue, {self}, }; +use reqwest::{Method, StatusCode, Url}; +use tokio::sync::OnceCell; +use typed_builder::TypedBuilder; -use self::_serde::{ - CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde, +use crate::client::HttpClient; +use crate::types::{ + CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest, ErrorResponse, + ListNamespaceResponse, ListTableResponse, LoadTableResponse, NamespaceSerde, RenameTableRequest, NO_CONTENT, OK, }; -const ICEBERG_REST_SPEC_VERSION: &str = "1.14"; +const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1"; const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); const PATH_V1: &str = "v1"; /// Rest catalog configuration. -#[derive(Debug, TypedBuilder)] +#[derive(Clone, Debug, TypedBuilder)] pub struct RestCatalogConfig { uri: String, #[builder(default, setter(strip_option))] @@ -57,48 +58,93 @@ pub struct RestCatalogConfig { } impl RestCatalogConfig { + fn url_prefixed(&self, parts: &[&str]) -> String { + [&self.uri, PATH_V1] + .into_iter() + .chain(self.props.get("prefix").map(|s| &**s)) + .chain(parts.iter().cloned()) + .join("/") + } + fn config_endpoint(&self) -> String { [&self.uri, PATH_V1, "config"].join("/") } + pub(crate) fn get_token_endpoint(&self) -> String { + if let Some(oauth2_uri) = self.props.get("oauth2-server-uri") { + oauth2_uri.to_string() + } else if let Some(auth_url) = self.props.get("rest.authorization-url") { + log::warn!( + "'rest.authorization-url' is deprecated and will be removed in version 0.4.0. \ + Please use 'oauth2-server-uri' instead." 
+ ); + auth_url.to_string() + } else { + [&self.uri, PATH_V1, "oauth", "tokens"].join("/") + } + } + fn namespaces_endpoint(&self) -> String { - [&self.uri, PATH_V1, "namespaces"].join("/") + self.url_prefixed(&["namespaces"]) } fn namespace_endpoint(&self, ns: &NamespaceIdent) -> String { - [&self.uri, PATH_V1, "namespaces", &ns.encode_in_url()].join("/") + self.url_prefixed(&["namespaces", &ns.to_url_string()]) } fn tables_endpoint(&self, ns: &NamespaceIdent) -> String { - [ - &self.uri, - PATH_V1, - "namespaces", - &ns.encode_in_url(), - "tables", - ] - .join("/") + self.url_prefixed(&["namespaces", &ns.to_url_string(), "tables"]) } fn rename_table_endpoint(&self) -> String { - [&self.uri, PATH_V1, "tables", "rename"].join("/") + self.url_prefixed(&["tables", "rename"]) } fn table_endpoint(&self, table: &TableIdent) -> String { - [ - &self.uri, - PATH_V1, + self.url_prefixed(&[ "namespaces", - &table.namespace.encode_in_url(), + &table.namespace.to_url_string(), "tables", - encode(&table.name).as_ref(), - ] - .join("/") + &table.name, + ]) } - fn try_create_rest_client(&self) -> Result { - //TODO: We will add oauth, ssl config, sigv4 later - let headers = HeaderMap::from_iter([ + /// Get the token from the config. + /// + /// Client will use `token` to send requests if it exists. + pub(crate) fn token(&self) -> Option { + self.props.get("token").cloned() + } + + /// Get the credentials from the config. Client will use `credential` + /// to fetch a new token if it exists. + /// + /// ## Output + /// + /// - `None`: No credential is set. + /// - `Some(None, client_secret)`: No client_id is set, use client_secret directly. + /// - `Some(Some(client_id), client_secret)`: Both client_id and client_secret are set. + pub(crate) fn credential(&self) -> Option<(Option, String)> { + let cred = self.props.get("credential")?; + + match cred.split_once(':') { + Some((client_id, client_secret)) => { + Some((Some(client_id.to_string()), client_secret.to_string())) + } + None => Some((None, cred.to_string())), + } + } + + /// Get the extra headers from config. + /// + /// We will include: + /// + /// - `content-type` + /// - `x-client-version` + /// - `user-agent` + /// - all headers specified by `header.xxx` in props. + pub(crate) fn extra_headers(&self) -> Result { + let mut headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), @@ -113,106 +159,160 @@ impl RestCatalogConfig { ), ]); - Ok(HttpClient( - Client::builder().default_headers(headers).build()?, - )) + for (key, value) in self + .props + .iter() + .filter(|(k, _)| k.starts_with("header.")) + // The unwrap here is safe since we are filtering the keys + .map(|(k, v)| (k.strip_prefix("header.").unwrap(), v)) + { + headers.insert( + HeaderName::from_str(key).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid header name: {key}"), + ) + .with_source(e) + })?, + HeaderValue::from_str(value).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid header value: {value}"), + ) + .with_source(e) + })?, + ); + } + + Ok(headers) } -} -#[derive(Debug)] -struct HttpClient(Client); - -impl HttpClient { - async fn query< - R: DeserializeOwned, - E: DeserializeOwned + Into, - const SUCCESS_CODE: u16, - >( - &self, - request: Request, - ) -> Result { - let resp = self.0.execute(request).await?; + /// Get the optional oauth params from the config.
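To make the property handling above concrete, here is a standalone sketch of how the `prefix` property is spliced into endpoint URLs and how the `credential` property is split into client id and secret. The helpers are hypothetical mirrors written only for illustration; the crate keeps the real methods private.

```rust
// Hypothetical mirror of the private `url_prefixed` helper.
fn url_prefixed(uri: &str, prefix: Option<&str>, parts: &[&str]) -> String {
    std::iter::once(uri)
        .chain(std::iter::once("v1"))
        .chain(prefix)
        .chain(parts.iter().copied())
        .collect::<Vec<_>>()
        .join("/")
}

// Hypothetical mirror of the private `credential` parsing.
fn parse_credential(cred: &str) -> (Option<&str>, &str) {
    match cred.split_once(':') {
        Some((client_id, client_secret)) => (Some(client_id), client_secret),
        None => (None, cred),
    }
}

fn main() {
    // With `prefix = "ice/warehouses/my"` the namespaces endpoint becomes:
    assert_eq!(
        url_prefixed("http://localhost:8181", Some("ice/warehouses/my"), &["namespaces"]),
        "http://localhost:8181/v1/ice/warehouses/my/namespaces"
    );
    // Without a prefix it stays at /v1/namespaces.
    assert_eq!(
        url_prefixed("http://localhost:8181", None, &["namespaces"]),
        "http://localhost:8181/v1/namespaces"
    );

    // `credential` is split on the first ':' into (client_id, client_secret).
    assert_eq!(parse_credential("client1:secret1"), (Some("client1"), "secret1"));
    assert_eq!(parse_credential("secret-only"), (None, "secret-only"));

    // A property such as `header.customized-header = some/value` is sent as the
    // HTTP header `customized-header: some/value` on every request.
}
```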
+ pub(crate) fn extra_oauth_params(&self) -> HashMap { + let mut params = HashMap::new(); - if resp.status().as_u16() == SUCCESS_CODE { - let text = resp.bytes().await?; - Ok(serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?) + if let Some(scope) = self.props.get("scope") { + params.insert("scope".to_string(), scope.to_string()); } else { - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?; - Err(e.into()) + params.insert("scope".to_string(), "catalog".to_string()); } - } - - async fn execute, const SUCCESS_CODE: u16>( - &self, - request: Request, - ) -> Result<()> { - let resp = self.0.execute(request).await?; - if resp.status().as_u16() == SUCCESS_CODE { - Ok(()) - } else { - let code = resp.status(); - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_context("code", code.to_string()) - .with_source(e) - })?; - Err(e.into()) + let optional_params = ["audience", "resource"]; + for param_name in optional_params { + if let Some(value) = self.props.get(param_name) { + params.insert(param_name.to_string(), value.to_string()); + } } + params } - /// More generic logic handling for special cases like head. - async fn do_execute>( - &self, - request: Request, - handler: impl FnOnce(&Response) -> Option, - ) -> Result { - let resp = self.0.execute(request).await?; - - if let Some(ret) = handler(&resp) { - Ok(ret) - } else { - let code = resp.status(); - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("code", code.to_string()) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?; - Err(e.into()) + /// Merge the config with the given config fetched from rest server. + pub(crate) fn merge_with_config(mut self, mut config: CatalogConfig) -> Self { + if let Some(uri) = config.overrides.remove("uri") { + self.uri = uri; } + + let mut props = config.defaults; + props.extend(self.props); + props.extend(config.overrides); + + self.props = props; + self } } +#[derive(Debug)] +struct RestContext { + client: HttpClient, + + /// Runtime config is fetched from rest server and stored here. + /// + /// It's could be different from the user config. + config: RestCatalogConfig, +} + +impl RestContext {} + /// Rest catalog implementation. #[derive(Debug)] pub struct RestCatalog { - config: RestCatalogConfig, - client: HttpClient, + /// User config is stored as-is and never be changed. + /// + /// It's could be different from the config fetched from the server and used at runtime. + user_config: RestCatalogConfig, + ctx: OnceCell, +} + +impl RestCatalog { + /// Creates a rest catalog from config. + pub fn new(config: RestCatalogConfig) -> Self { + Self { + user_config: config, + ctx: OnceCell::new(), + } + } + + /// Get the context from the catalog. 
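The runtime configuration used by the catalog is produced by `merge_with_config` above. The sketch below is a hypothetical helper shown only to illustrate the precedence it implements: server defaults, then user-supplied properties, then server overrides, with a `uri` override additionally replacing the catalog URI itself (see `test_config_override` later in this diff).

```rust
use std::collections::HashMap;

// Merge order implied by `merge_with_config`:
// server defaults < user-supplied props < server overrides.
fn merge(
    user_props: HashMap<String, String>,
    defaults: HashMap<String, String>,
    overrides: HashMap<String, String>,
) -> HashMap<String, String> {
    let mut props = defaults;
    props.extend(user_props);
    props.extend(overrides);
    props
}

fn main() {
    let user = HashMap::from([("warehouse".to_string(), "s3://user-bucket".to_string())]);
    let defaults = HashMap::from([("clients".to_string(), "4".to_string())]);
    let overrides =
        HashMap::from([("warehouse".to_string(), "s3://iceberg-catalog".to_string())]);

    let merged = merge(user, defaults, overrides);
    // The server override wins over the user value, and the default is kept.
    assert_eq!(merged.get("warehouse"), Some(&"s3://iceberg-catalog".to_string()));
    assert_eq!(merged.get("clients"), Some(&"4".to_string()));
}
```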
+ async fn context(&self) -> Result<&RestContext> { + self.ctx + .get_or_try_init(|| async { + let catalog_config = RestCatalog::load_config(&self.user_config).await?; + let config = self.user_config.clone().merge_with_config(catalog_config); + let client = HttpClient::new(&config)?; + + Ok(RestContext { config, client }) + }) + .await + } + + /// Load the runtime config from the server by user_config. + /// + /// It's required for a rest catalog to update it's config after creation. + async fn load_config(user_config: &RestCatalogConfig) -> Result { + let client = HttpClient::new(user_config)?; + + let mut request = client.request(Method::GET, user_config.config_endpoint()); + + if let Some(warehouse_location) = &user_config.warehouse { + request = request.query(&[("warehouse", warehouse_location)]); + } + + let config = client + .query::(request.build()?) + .await?; + Ok(config) + } + + async fn load_file_io( + &self, + metadata_location: Option<&str>, + extra_config: Option>, + ) -> Result { + let mut props = self.context().await?.config.props.clone(); + if let Some(config) = extra_config { + props.extend(config); + } + + // If the warehouse is a logical identifier instead of a URL we don't want + // to raise an exception + let warehouse_path = match self.context().await?.config.warehouse.as_deref() { + Some(url) if Url::parse(url).is_ok() => Some(url), + Some(_) => None, + None => None, + }; + + let file_io = match warehouse_path.or(metadata_location) { + Some(url) => FileIO::from_path(url)?.with_props(props).build()?, + None => { + return Err(Error::new( + ErrorKind::Unexpected, + "Unable to load file io, neither warehouse nor metadata location is set!", + ))? + } + }; + + Ok(file_io) + } } #[async_trait] @@ -222,12 +322,17 @@ impl Catalog for RestCatalog { &self, parent: Option<&NamespaceIdent>, ) -> Result> { - let mut request = self.client.0.get(self.config.namespaces_endpoint()); + let mut request = self.context().await?.client.request( + Method::GET, + self.context().await?.config.namespaces_endpoint(), + ); if let Some(ns) = parent { - request = request.query(&[("parent", ns.encode_in_url())]); + request = request.query(&[("parent", ns.to_url_string())]); } let resp = self + .context() + .await? .client .query::(request.build()?) .await?; @@ -245,9 +350,13 @@ impl Catalog for RestCatalog { properties: HashMap, ) -> Result { let request = self + .context() + .await? .client - .0 - .post(self.config.namespaces_endpoint()) + .request( + Method::POST, + self.context().await?.config.namespaces_endpoint(), + ) .json(&NamespaceSerde { namespace: namespace.as_ref().clone(), properties: Some(properties), @@ -255,6 +364,8 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -265,12 +376,18 @@ impl Catalog for RestCatalog { /// Get a namespace information from the catalog. async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { let request = self + .context() + .await? .client - .0 - .get(self.config.namespace_endpoint(namespace)) + .request( + Method::GET, + self.context().await?.config.namespace_endpoint(namespace), + ) .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -295,12 +412,18 @@ impl Catalog for RestCatalog { async fn namespace_exists(&self, ns: &NamespaceIdent) -> Result { let request = self + .context() + .await? 
.client - .0 - .head(self.config.namespace_endpoint(ns)) + .request( + Method::HEAD, + self.context().await?.config.namespace_endpoint(ns), + ) .build()?; - self.client + self.context() + .await? + .client .do_execute::(request, |resp| match resp.status() { StatusCode::NO_CONTENT => Some(true), StatusCode::NOT_FOUND => Some(false), @@ -312,12 +435,18 @@ impl Catalog for RestCatalog { /// Drop a namespace from the catalog. async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .delete(self.config.namespace_endpoint(namespace)) + .request( + Method::DELETE, + self.context().await?.config.namespace_endpoint(namespace), + ) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } @@ -325,12 +454,18 @@ impl Catalog for RestCatalog { /// List tables from namespace. async fn list_tables(&self, namespace: &NamespaceIdent) -> Result> { let request = self + .context() + .await? .client - .0 - .get(self.config.tables_endpoint(namespace)) + .request( + Method::GET, + self.context().await?.config.tables_endpoint(namespace), + ) .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -339,6 +474,11 @@ impl Catalog for RestCatalog { } /// Create a new table inside the namespace. + /// + /// In the resulting table, if there are any config properties that + /// are present in both the response from the REST server and the + /// config provided when creating this `RestCatalog` instance then + /// the value provided locally to the `RestCatalog` will take precedence. async fn create_table( &self, namespace: &NamespaceIdent, @@ -347,9 +487,13 @@ impl Catalog for RestCatalog { let table_ident = TableIdent::new(namespace.clone(), creation.name.clone()); let request = self + .context() + .await? .client - .0 - .post(self.config.tables_endpoint(namespace)) + .request( + Method::POST, + self.context().await?.config.tables_endpoint(namespace), + ) .json(&CreateTableRequest { name: creation.name, location: creation.location, @@ -367,13 +511,24 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; - let file_io = self.load_file_io(resp.metadata_location.as_deref(), resp.config)?; + let config = resp + .config + .unwrap_or_default() + .into_iter() + .chain(self.user_config.props.clone().into_iter()) + .collect(); + + let file_io = self + .load_file_io(resp.metadata_location.as_deref(), Some(config)) + .await?; - let table = Table::builder() + Table::builder() .identifier(table_ident) .file_io(file_io) .metadata(resp.metadata) @@ -383,25 +538,43 @@ impl Catalog for RestCatalog { "Metadata location missing in create table response!", ) })?) - .build(); - - Ok(table) + .build() } /// Load table from the catalog. + /// + /// If there are any config properties that are present in + /// both the response from the REST server and the config provided + /// when creating this `RestCatalog` instance then the value + /// provided locally to the `RestCatalog` will take precedence. async fn load_table(&self, table: &TableIdent) -> Result
{ let request = self + .context() + .await? .client - .0 - .get(self.config.table_endpoint(table)) + .request( + Method::GET, + self.context().await?.config.table_endpoint(table), + ) .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; - let file_io = self.load_file_io(resp.metadata_location.as_deref(), resp.config)?; + let config = resp + .config + .unwrap_or_default() + .into_iter() + .chain(self.user_config.props.clone().into_iter()) + .collect(); + + let file_io = self + .load_file_io(resp.metadata_location.as_deref(), Some(config)) + .await?; let table_builder = Table::builder() .identifier(table.clone()) @@ -409,34 +582,46 @@ impl Catalog for RestCatalog { .metadata(resp.metadata); if let Some(metadata_location) = resp.metadata_location { - Ok(table_builder.metadata_location(metadata_location).build()) + table_builder.metadata_location(metadata_location).build() } else { - Ok(table_builder.build()) + table_builder.build() } } /// Drop a table from the catalog. async fn drop_table(&self, table: &TableIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .delete(self.config.table_endpoint(table)) + .request( + Method::DELETE, + self.context().await?.config.table_endpoint(table), + ) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } /// Check if a table exists in the catalog. - async fn stat_table(&self, table: &TableIdent) -> Result { + async fn table_exists(&self, table: &TableIdent) -> Result { let request = self + .context() + .await? .client - .0 - .head(self.config.table_endpoint(table)) + .request( + Method::HEAD, + self.context().await?.config.table_endpoint(table), + ) .build()?; - self.client + self.context() + .await? + .client .do_execute::(request, |resp| match resp.status() { StatusCode::NO_CONTENT => Some(true), StatusCode::NOT_FOUND => Some(false), @@ -448,16 +633,22 @@ impl Catalog for RestCatalog { /// Rename a table in the catalog. async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .post(self.config.rename_table_endpoint()) + .request( + Method::POST, + self.context().await?.config.rename_table_endpoint(), + ) .json(&RenameTableRequest { source: src.clone(), destination: dest.clone(), }) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } @@ -465,9 +656,16 @@ impl Catalog for RestCatalog { /// Update table. async fn update_table(&self, mut commit: TableCommit) -> Result
{ let request = self + .context() + .await? .client - .0 - .post(self.config.table_endpoint(commit.identifier())) + .request( + Method::POST, + self.context() + .await? + .config + .table_endpoint(commit.identifier()), + ) .json(&CommitTableRequest { identifier: commit.identifier().clone(), requirements: commit.take_requirements(), @@ -476,307 +674,354 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; - let file_io = self.load_file_io(Some(&resp.metadata_location), None)?; - Ok(Table::builder() + let file_io = self + .load_file_io(Some(&resp.metadata_location), None) + .await?; + Table::builder() .identifier(commit.identifier().clone()) .file_io(file_io) .metadata(resp.metadata) .metadata_location(resp.metadata_location) - .build()) + .build() } } -impl RestCatalog { - /// Creates a rest catalog from config. - pub async fn new(config: RestCatalogConfig) -> Result { - let mut catalog = Self { - client: config.try_create_rest_client()?, - config, - }; - - catalog.update_config().await?; - catalog.client = catalog.config.try_create_rest_client()?; +#[cfg(test)] +mod tests { + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; - Ok(catalog) - } + use chrono::{TimeZone, Utc}; + use iceberg::spec::{ + FormatVersion, NestedField, NullOrder, Operation, PrimitiveType, Schema, Snapshot, + SnapshotLog, SortDirection, SortField, SortOrder, Summary, Transform, Type, + UnboundPartitionField, UnboundPartitionSpec, + }; + use iceberg::transaction::Transaction; + use mockito::{Mock, Server, ServerGuard}; + use serde_json::json; + use uuid::uuid; - async fn update_config(&mut self) -> Result<()> { - let mut request = self.client.0.get(self.config.config_endpoint()); + use super::*; - if let Some(warehouse_location) = &self.config.warehouse { - request = request.query(&[("warehouse", warehouse_location)]); - } + #[tokio::test] + async fn test_update_config() { + let mut server = Server::new_async().await; - let config = self - .client - .query::(request.build()?) - .await?; + let config_mock = server + .mock("GET", "/v1/config") + .with_status(200) + .with_body( + r#"{ + "overrides": { + "warehouse": "s3://iceberg-catalog" + }, + "defaults": {} + }"#, + ) + .create_async() + .await; - let mut props = config.defaults; - props.extend(self.config.props.clone()); - props.extend(config.overrides); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); - self.config.props = props; + assert_eq!( + catalog + .context() + .await + .unwrap() + .config + .props + .get("warehouse"), + Some(&"s3://iceberg-catalog".to_string()) + ); - Ok(()) + config_mock.assert_async().await; } - fn load_file_io( - &self, - metadata_location: Option<&str>, - extra_config: Option>, - ) -> Result { - let mut props = self.config.props.clone(); - if let Some(config) = extra_config { - props.extend(config); - } - - let file_io = match self.config.warehouse.as_deref().or(metadata_location) { - Some(url) => FileIO::from_path(url)?.with_props(props).build()?, - None => { - return Err(Error::new( - ErrorKind::Unexpected, - "Unable to load file io, neither warehouse nor metadata location is set!", - ))? 
- } - }; + async fn create_config_mock(server: &mut ServerGuard) -> Mock { + server + .mock("GET", "/v1/config") + .with_status(200) + .with_body( + r#"{ + "overrides": { + "warehouse": "s3://iceberg-catalog" + }, + "defaults": {} + }"#, + ) + .create_async() + .await + } - Ok(file_io) + async fn create_oauth_mock(server: &mut ServerGuard) -> Mock { + create_oauth_mock_with_path(server, "/v1/oauth/tokens").await } -} -/// Requests and responses for rest api. -mod _serde { - use std::collections::HashMap; + async fn create_oauth_mock_with_path(server: &mut ServerGuard, path: &str) -> Mock { + server + .mock("POST", path) + .with_status(200) + .with_body( + r#"{ + "access_token": "ey000000000000", + "token_type": "Bearer", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "expires_in": 86400 + }"#, + ) + .expect(2) + .create_async() + .await + } - use serde_derive::{Deserialize, Serialize}; + #[tokio::test] + async fn test_oauth() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; - use iceberg::spec::{Schema, SortOrder, TableMetadata, UnboundPartitionSpec}; - use iceberg::{Error, ErrorKind, Namespace, TableIdent, TableRequirement, TableUpdate}; + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); - pub(super) const OK: u16 = 200u16; - pub(super) const NO_CONTENT: u16 = 204u16; + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); - #[derive(Clone, Debug, Serialize, Deserialize)] - pub(super) struct CatalogConfig { - pub(super) overrides: HashMap, - pub(super) defaults: HashMap, + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); } - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ErrorResponse { - error: ErrorModel, - } + #[tokio::test] + async fn test_oauth_with_optional_param() { + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + props.insert("scope".to_string(), "custom_scope".to_string()); + props.insert("audience".to_string(), "custom_audience".to_string()); + props.insert("resource".to_string(), "custom_resource".to_string()); - impl From for Error { - fn from(resp: ErrorResponse) -> Error { - resp.error.into() - } - } + let mut server = Server::new_async().await; + let oauth_mock = server + .mock("POST", "/v1/oauth/tokens") + .match_body(mockito::Matcher::Regex("scope=custom_scope".to_string())) + .match_body(mockito::Matcher::Regex( + "audience=custom_audience".to_string(), + )) + .match_body(mockito::Matcher::Regex( + "resource=custom_resource".to_string(), + )) + .with_status(200) + .with_body( + r#"{ + "access_token": "ey000000000000", + "token_type": "Bearer", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "expires_in": 86400 + }"#, + ) + .expect(2) + .create_async() + .await; - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ErrorModel { - pub(super) message: String, - pub(super) r#type: String, - pub(super) code: u16, - pub(super) stack: Option>, - } + let config_mock = create_config_mock(&mut server).await; - impl From for Error { - fn from(value: ErrorModel) -> Self { - let mut error = Error::new(ErrorKind::DataInvalid, value.message) - .with_context("type", value.r#type) 
- .with_context("code", format!("{}", value.code)); + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); - if let Some(stack) = value.stack { - error = error.with_context("stack", stack.join("\n")); - } + let token = catalog.context().await.unwrap().client.token().await; - error - } + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); } - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct OAuthError { - pub(super) error: String, - pub(super) error_description: Option, - pub(super) error_uri: Option, - } + #[tokio::test] + async fn test_http_headers() { + let server = Server::new_async().await; + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let config = RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(); + let headers: HeaderMap = config.extra_headers().unwrap(); - impl From for Error { - fn from(value: OAuthError) -> Self { - let mut error = Error::new( - ErrorKind::DataInvalid, - format!("OAuthError: {}", value.error), - ); + let expected_headers = HeaderMap::from_iter([ + ( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ), + ( + HeaderName::from_static("x-client-version"), + HeaderValue::from_static(ICEBERG_REST_SPEC_VERSION), + ), + ( + header::USER_AGENT, + HeaderValue::from_str(&format!("iceberg-rs/{}", CARGO_PKG_VERSION)).unwrap(), + ), + ]); + assert_eq!(headers, expected_headers); + } - if let Some(desc) = value.error_description { - error = error.with_context("description", desc); - } + #[tokio::test] + async fn test_http_headers_with_custom_headers() { + let server = Server::new_async().await; + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + props.insert( + "header.content-type".to_string(), + "application/yaml".to_string(), + ); + props.insert( + "header.customized-header".to_string(), + "some/value".to_string(), + ); - if let Some(uri) = value.error_uri { - error = error.with_context("uri", uri); - } + let config = RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(); + let headers: HeaderMap = config.extra_headers().unwrap(); - error - } + let expected_headers = HeaderMap::from_iter([ + ( + header::CONTENT_TYPE, + HeaderValue::from_static("application/yaml"), + ), + ( + HeaderName::from_static("x-client-version"), + HeaderValue::from_static(ICEBERG_REST_SPEC_VERSION), + ), + ( + header::USER_AGENT, + HeaderValue::from_str(&format!("iceberg-rs/{}", CARGO_PKG_VERSION)).unwrap(), + ), + ( + HeaderName::from_static("customized-header"), + HeaderValue::from_static("some/value"), + ), + ]); + assert_eq!(headers, expected_headers); } - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct NamespaceSerde { - pub(super) namespace: Vec, - pub(super) properties: Option>, - } + #[tokio::test] + async fn test_oauth_with_deprecated_auth_url() { + let mut server = Server::new_async().await; + let config_mock = create_config_mock(&mut server).await; - impl TryFrom for super::Namespace { - type Error = Error; - fn try_from(value: NamespaceSerde) -> std::result::Result { - Ok(super::Namespace::with_properties( - super::NamespaceIdent::from_vec(value.namespace)?, - value.properties.unwrap_or_default(), - )) - } - } + let mut auth_server = Server::new_async().await; + let auth_server_path = "/some/path"; + let oauth_mock = 
create_oauth_mock_with_path(&mut auth_server, auth_server_path).await; - impl From<&Namespace> for NamespaceSerde { - fn from(value: &Namespace) -> Self { - Self { - namespace: value.name().as_ref().clone(), - properties: Some(value.properties().clone()), - } - } - } + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + props.insert( + "rest.authorization-url".to_string(), + format!("{}{}", auth_server.url(), auth_server_path).to_string(), + ); - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ListNamespaceResponse { - pub(super) namespaces: Vec>, - } + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct UpdateNamespacePropsRequest { - removals: Option>, - updates: Option>, - } + let token = catalog.context().await.unwrap().client.token().await; - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct UpdateNamespacePropsResponse { - updated: Vec, - removed: Vec, - missing: Option>, + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); } - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ListTableResponse { - pub(super) identifiers: Vec, - } + #[tokio::test] + async fn test_oauth_with_oauth2_server_uri() { + let mut server = Server::new_async().await; + let config_mock = create_config_mock(&mut server).await; - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct RenameTableRequest { - pub(super) source: TableIdent, - pub(super) destination: TableIdent, - } + let mut auth_server = Server::new_async().await; + let auth_server_path = "/some/path"; + let oauth_mock = create_oauth_mock_with_path(&mut auth_server, auth_server_path).await; - #[derive(Debug, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct LoadTableResponse { - pub(super) metadata_location: Option, - pub(super) metadata: TableMetadata, - pub(super) config: Option>, - } + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + props.insert( + "oauth2-server-uri".to_string(), + format!("{}{}", auth_server.url(), auth_server_path).to_string(), + ); - #[derive(Debug, Serialize, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct CreateTableRequest { - pub(super) name: String, - pub(super) location: Option, - pub(super) schema: Schema, - pub(super) partition_spec: Option, - pub(super) write_order: Option, - pub(super) stage_create: Option, - pub(super) properties: Option>, - } + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct CommitTableRequest { - pub(super) identifier: TableIdent, - pub(super) requirements: Vec, - pub(super) updates: Vec, - } + let token = catalog.context().await.unwrap().client.token().await; - #[derive(Debug, Serialize, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct CommitTableResponse { - pub(super) metadata_location: String, - pub(super) metadata: TableMetadata, + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); } -} - -#[cfg(test)] -mod tests { - use chrono::{TimeZone, Utc}; - use iceberg::spec::ManifestListLocation::ManifestListFile; - use iceberg::spec::{ - FormatVersion, NestedField, NullOrder, 
Operation, PrimitiveType, Schema, Snapshot, - SnapshotLog, SortDirection, SortField, SortOrder, Summary, Transform, Type, - UnboundPartitionField, UnboundPartitionSpec, - }; - use iceberg::transaction::Transaction; - use mockito::{Mock, Server, ServerGuard}; - use std::fs::File; - use std::io::BufReader; - use std::sync::Arc; - use uuid::uuid; - - use super::*; #[tokio::test] - async fn test_update_config() { + async fn test_config_override() { let mut server = Server::new_async().await; + let mut redirect_server = Server::new_async().await; + let new_uri = redirect_server.url(); let config_mock = server .mock("GET", "/v1/config") .with_status(200) .with_body( - r#"{ - "overrides": { - "warehouse": "s3://iceberg-catalog" - }, - "defaults": {} - }"#, + json!( + { + "overrides": { + "uri": new_uri, + "warehouse": "s3://iceberg-catalog", + "prefix": "ice/warehouses/my" + }, + "defaults": {}, + } + ) + .to_string(), ) .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); - - assert_eq!( - catalog.config.props.get("warehouse"), - Some(&"s3://iceberg-catalog".to_string()) - ); - - config_mock.assert_async().await; - } - - async fn create_config_mock(server: &mut ServerGuard) -> Mock { - server - .mock("GET", "/v1/config") - .with_status(200) + let list_ns_mock = redirect_server + .mock("GET", "/v1/ice/warehouses/my/namespaces") .with_body( r#"{ - "overrides": { - "warehouse": "s3://iceberg-catalog" - }, - "defaults": {} - }"#, + "namespaces": [] + }"#, ) .create_async() - .await + .await; + + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); + + let _namespaces = catalog.list_namespaces(None).await.unwrap(); + + config_mock.assert_async().await; + list_ns_mock.assert_async().await; } #[tokio::test] @@ -798,9 +1043,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog.list_namespaces(None).await.unwrap(); @@ -834,9 +1077,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog .create_namespace( @@ -876,9 +1117,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog .get_namespace(&NamespaceIdent::new("ns1".to_string())) @@ -908,9 +1147,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); assert!(catalog .namespace_exists(&NamespaceIdent::new("ns1".to_string())) @@ -933,9 +1170,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .drop_namespace(&NamespaceIdent::new("ns1".to_string())) @@ -972,9 +1207,7 @@ mod tests { .create_async() .await; - let catalog = 
RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let tables = catalog .list_tables(&NamespaceIdent::new("ns1".to_string())) @@ -1004,9 +1237,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .drop_table(&TableIdent::new( @@ -1032,12 +1263,10 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); assert!(catalog - .stat_table(&TableIdent::new( + .table_exists(&TableIdent::new( NamespaceIdent::new("ns1".to_string()), "table1".to_string(), )) @@ -1060,9 +1289,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .rename_table( @@ -1093,9 +1320,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table = catalog .load_table(&TableIdent::new( @@ -1118,7 +1343,7 @@ mod tests { ); assert_eq!( Utc.timestamp_millis_opt(1646787054459).unwrap(), - table.metadata().last_updated_ms() + table.metadata().last_updated_timestamp().unwrap() ); assert_eq!( vec![&Arc::new( @@ -1146,7 +1371,7 @@ mod tests { assert_eq!(vec![&Arc::new(Snapshot::builder() .with_snapshot_id(3497810964824022504) .with_timestamp_ms(1646787054459) - .with_manifest_list(ManifestListFile("s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro".to_string())) + .with_manifest_list("s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro") .with_sequence_number(0) .with_schema_id(0) .with_summary(Summary { @@ -1206,9 +1431,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table = catalog .load_table(&TableIdent::new( @@ -1245,9 +1468,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table_creation = TableCreation::builder() .name("test1".to_string()) @@ -1268,13 +1489,13 @@ mod tests { .properties(HashMap::from([("owner".to_string(), "testx".to_string())])) .partition_spec( UnboundPartitionSpec::builder() - .with_fields(vec![UnboundPartitionField::builder() + .add_partition_fields(vec![UnboundPartitionField::builder() .source_id(1) .transform(Transform::Truncate(3)) .name("id".to_string()) .build()]) - .build() - .unwrap(), + .unwrap() + .build(), ) .sort_order( SortOrder::builder() @@ -1286,7 +1507,7 @@ mod tests { .null_order(NullOrder::First) .build(), ) - .build() + .build_unbound() .unwrap(), ) .build(); @@ -1312,7 +1533,11 @@ mod tests { ); assert_eq!( 
1657810967051, - table.metadata().last_updated_ms().timestamp_millis() + table + .metadata() + .last_updated_timestamp() + .unwrap() + .timestamp_millis() ); assert_eq!( vec![&Arc::new( @@ -1387,9 +1612,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table_creation = TableCreation::builder() .name("test1".to_string()) @@ -1442,9 +1665,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table1 = { let file = File::open(format!( @@ -1462,6 +1683,7 @@ mod tests { .identifier(TableIdent::from_strs(["ns1", "test1"]).unwrap()) .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap()) .build() + .unwrap() }; let table = Transaction::new(&table1) @@ -1487,7 +1709,11 @@ mod tests { ); assert_eq!( 1657810967051, - table.metadata().last_updated_ms().timestamp_millis() + table + .metadata() + .last_updated_timestamp() + .unwrap() + .timestamp_millis() ); assert_eq!( vec![&Arc::new( @@ -1558,15 +1784,13 @@ mod tests { "type": "NoSuchTableException", "code": 404 } -} +} "#, ) .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table1 = { let file = File::open(format!( @@ -1584,6 +1808,7 @@ mod tests { .identifier(TableIdent::from_strs(["ns1", "test1"]).unwrap()) .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap()) .build() + .unwrap() }; let table_result = Transaction::new(&table1) diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs new file mode 100644 index 000000000..53dcd4cee --- /dev/null +++ b/crates/catalog/rest/src/client.rs @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Mutex; + +use iceberg::{Error, ErrorKind, Result}; +use reqwest::header::HeaderMap; +use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response}; +use serde::de::DeserializeOwned; + +use crate::types::{ErrorResponse, TokenResponse, OK}; +use crate::RestCatalogConfig; + +pub(crate) struct HttpClient { + client: Client, + + /// The token to be used for authentication. + /// + /// It's possible to fetch the token from the server while needed. + token: Mutex>, + /// The token endpoint to be used for authentication. 
+ token_endpoint: String, + /// The credential to be used for authentication. + credential: Option<(Option, String)>, + /// Extra headers to be added to each request. + extra_headers: HeaderMap, + /// Extra oauth parameters to be added to each authentication request. + extra_oauth_params: HashMap, +} + +impl Debug for HttpClient { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpClient") + .field("client", &self.client) + .field("extra_headers", &self.extra_headers) + .finish_non_exhaustive() + } +} + +impl HttpClient { + pub fn new(cfg: &RestCatalogConfig) -> Result { + Ok(HttpClient { + client: Client::new(), + + token: Mutex::new(cfg.token()), + token_endpoint: cfg.get_token_endpoint(), + credential: cfg.credential(), + extra_headers: cfg.extra_headers()?, + extra_oauth_params: cfg.extra_oauth_params(), + }) + } + + /// This API is testing only to assert the token. + #[cfg(test)] + pub(crate) async fn token(&self) -> Option { + let mut req = self + .request(Method::GET, &self.token_endpoint) + .build() + .unwrap(); + self.authenticate(&mut req).await.ok(); + self.token.lock().unwrap().clone() + } + + /// Authenticate the request by filling token. + /// + /// - If neither token nor credential is provided, this method will do nothing. + /// - If only credential is provided, this method will try to fetch token from the server. + /// - If token is provided, this method will use the token directly. + /// + /// # TODO + /// + /// Support refreshing token while needed. + async fn authenticate(&self, req: &mut Request) -> Result<()> { + // Clone the token from lock without holding the lock for entire function. + let token = { self.token.lock().expect("lock poison").clone() }; + + if self.credential.is_none() && token.is_none() { + return Ok(()); + } + + // Use token if provided. + if let Some(token) = &token { + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + return Ok(()); + } + + // Credential must exist here. + let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Credential must be provided for authentication", + ) + })?; + + let mut params = HashMap::with_capacity(4); + params.insert("grant_type", "client_credentials"); + if let Some(client_id) = client_id { + params.insert("client_id", client_id); + } + params.insert("client_secret", client_secret); + params.extend( + self.extra_oauth_params + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())), + ); + + let auth_req = self + .client + .request(Method::POST, &self.token_endpoint) + .form(¶ms) + .build()?; + let auth_resp = self.client.execute(auth_req).await?; + + let auth_res: TokenResponse = if auth_resp.status().as_u16() == OK { + let text = auth_resp.bytes().await?; + Ok(serde_json::from_slice(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?) 
+ } else { + let code = auth_resp.status(); + let text = auth_resp.bytes().await?; + let e: ErrorResponse = serde_json::from_slice(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(Error::from(e)) + }?; + let token = auth_res.access_token; + // Update token. + *self.token.lock().expect("lock poison") = Some(token.clone()); + // Insert token in request. + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + + Ok(()) + } + + #[inline] + pub fn request(&self, method: Method, url: U) -> RequestBuilder { + self.client.request(method, url) + } + + pub async fn query< + R: DeserializeOwned, + E: DeserializeOwned + Into, + const SUCCESS_CODE: u16, + >( + &self, + mut request: Request, + ) -> Result { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if resp.status().as_u16() == SUCCESS_CODE { + let text = resp.bytes().await?; + Ok(serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?) + } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(e.into()) + } + } + + pub async fn execute, const SUCCESS_CODE: u16>( + &self, + mut request: Request, + ) -> Result<()> { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if resp.status().as_u16() == SUCCESS_CODE { + Ok(()) + } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(e.into()) + } + } + + /// More generic logic handling for special cases like head. 
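The `authenticate` logic above boils down to a standard OAuth2 `client_credentials` exchange against the configured token endpoint. Below is a self-contained sketch of that exchange; the endpoint, credentials and the fixed `"catalog"` scope are placeholders, and the real client additionally forwards any configured `scope`, `audience` and `resource` values and caches the token in a `Mutex`.

```rust
use std::collections::HashMap;

use reqwest::Client;
use serde_derive::Deserialize;

#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
}

// Hypothetical helper mirroring the exchange performed by `authenticate`.
async fn fetch_token(
    token_endpoint: &str,
    client_id: Option<&str>,
    client_secret: &str,
) -> Result<String, Box<dyn std::error::Error>> {
    let mut params = HashMap::new();
    params.insert("grant_type", "client_credentials");
    if let Some(client_id) = client_id {
        params.insert("client_id", client_id);
    }
    params.insert("client_secret", client_secret);
    params.insert("scope", "catalog");

    let resp = Client::new()
        .post(token_endpoint)
        .form(&params)
        .send()
        .await?;
    let bytes = resp.bytes().await?;
    let token: TokenResponse = serde_json::from_slice(&bytes)?;

    // The token is then attached to every catalog request as
    // `Authorization: Bearer <token>`.
    Ok(token.access_token)
}
```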
+ pub async fn do_execute>( + &self, + mut request: Request, + handler: impl FnOnce(&Response) -> Option, + ) -> Result { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if let Some(ret) = handler(&resp) { + Ok(ret) + } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("code", code.to_string()) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?; + Err(e.into()) + } + } +} diff --git a/crates/catalog/rest/src/lib.rs b/crates/catalog/rest/src/lib.rs index 023fe7ab2..f94ee8781 100644 --- a/crates/catalog/rest/src/lib.rs +++ b/crates/catalog/rest/src/lib.rs @@ -20,4 +20,7 @@ #![deny(missing_docs)] mod catalog; +mod client; +mod types; + pub use catalog::*; diff --git a/crates/catalog/rest/src/types.rs b/crates/catalog/rest/src/types.rs new file mode 100644 index 000000000..11833a562 --- /dev/null +++ b/crates/catalog/rest/src/types.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
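Before the type definitions that follow, a brief illustration of the wire format they model. The structs below are hypothetical mirrors of the crate-private types, shown only to make the JSON shapes concrete; the error payload matches the mocked 404 responses used elsewhere in this patch, and response types such as `LoadTableResponse` rely on `#[serde(rename_all = "kebab-case")]` to map keys like `metadata-location`.

```rust
use serde_derive::Deserialize;

// Hypothetical mirrors of the crate-private wire types, for illustration only.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ErrorModel,
}

#[derive(Debug, Deserialize)]
struct ErrorModel {
    message: String,
    r#type: String,
    code: u16,
}

fn main() {
    let body = r#"{
        "error": {
            "message": "The given table does not exist",
            "type": "NoSuchTableException",
            "code": 404
        }
    }"#;

    let resp: ErrorResponse = serde_json::from_str(body).unwrap();
    assert_eq!(resp.error.code, 404);
    assert_eq!(resp.error.r#type, "NoSuchTableException");
    println!("server error: {}", resp.error.message);
}
```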
+ +use std::collections::HashMap; + +use iceberg::spec::{Schema, SortOrder, TableMetadata, UnboundPartitionSpec}; +use iceberg::{ + Error, ErrorKind, Namespace, NamespaceIdent, TableIdent, TableRequirement, TableUpdate, +}; +use serde_derive::{Deserialize, Serialize}; + +pub(super) const OK: u16 = 200u16; +pub(super) const NO_CONTENT: u16 = 204u16; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct CatalogConfig { + pub(super) overrides: HashMap, + pub(super) defaults: HashMap, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ErrorResponse { + error: ErrorModel, +} + +impl From for Error { + fn from(resp: ErrorResponse) -> Error { + resp.error.into() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ErrorModel { + pub(super) message: String, + pub(super) r#type: String, + pub(super) code: u16, + pub(super) stack: Option>, +} + +impl From for Error { + fn from(value: ErrorModel) -> Self { + let mut error = Error::new(ErrorKind::DataInvalid, value.message) + .with_context("type", value.r#type) + .with_context("code", format!("{}", value.code)); + + if let Some(stack) = value.stack { + error = error.with_context("stack", stack.join("\n")); + } + + error + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct OAuthError { + pub(super) error: String, + pub(super) error_description: Option, + pub(super) error_uri: Option, +} + +impl From for Error { + fn from(value: OAuthError) -> Self { + let mut error = Error::new( + ErrorKind::DataInvalid, + format!("OAuthError: {}", value.error), + ); + + if let Some(desc) = value.error_description { + error = error.with_context("description", desc); + } + + if let Some(uri) = value.error_uri { + error = error.with_context("uri", uri); + } + + error + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct TokenResponse { + pub(super) access_token: String, + pub(super) token_type: String, + pub(super) expires_in: Option, + pub(super) issued_token_type: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct NamespaceSerde { + pub(super) namespace: Vec, + pub(super) properties: Option>, +} + +impl TryFrom for Namespace { + type Error = Error; + fn try_from(value: NamespaceSerde) -> std::result::Result { + Ok(Namespace::with_properties( + NamespaceIdent::from_vec(value.namespace)?, + value.properties.unwrap_or_default(), + )) + } +} + +impl From<&Namespace> for NamespaceSerde { + fn from(value: &Namespace) -> Self { + Self { + namespace: value.name().as_ref().clone(), + properties: Some(value.properties().clone()), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ListNamespaceResponse { + pub(super) namespaces: Vec>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct UpdateNamespacePropsRequest { + removals: Option>, + updates: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct UpdateNamespacePropsResponse { + updated: Vec, + removed: Vec, + missing: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ListTableResponse { + pub(super) identifiers: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct RenameTableRequest { + pub(super) source: TableIdent, + pub(super) destination: TableIdent, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(super) struct LoadTableResponse { + pub(super) metadata_location: Option, + pub(super) metadata: TableMetadata, + pub(super) config: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] 
+#[serde(rename_all = "kebab-case")] +pub(super) struct CreateTableRequest { + pub(super) name: String, + pub(super) location: Option, + pub(super) schema: Schema, + pub(super) partition_spec: Option, + pub(super) write_order: Option, + pub(super) stage_create: Option, + pub(super) properties: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct CommitTableRequest { + pub(super) identifier: TableIdent, + pub(super) requirements: Vec, + pub(super) updates: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(super) struct CommitTableResponse { + pub(super) metadata_location: String, + pub(super) metadata: TableMetadata, +} diff --git a/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml b/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml index 5c101463f..34ba3c874 100644 --- a/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml +++ b/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -version: '3.8' +networks: + rest_bridge: services: rest: @@ -31,35 +32,36 @@ services: - CATALOG_S3_ENDPOINT=http://minio:9000 depends_on: - minio - links: - - minio:icebergdata.minio + networks: + rest_bridge: + aliases: + - icebergdata.minio expose: - 8181 minio: - image: minio/minio + image: minio/minio:RELEASE.2024-03-07T00-43-48Z environment: - MINIO_ROOT_USER=admin - MINIO_ROOT_PASSWORD=password - MINIO_DOMAIN=minio + hostname: icebergdata.minio + networks: + rest_bridge: expose: - 9001 - 9000 - command: [ "server", "/data", "--console-address", ":9001" ] + command: ["server", "/data", "--console-address", ":9001"] mc: depends_on: - minio - image: minio/mc + image: minio/mc:RELEASE.2024-03-07T00-31-49Z environment: - AWS_ACCESS_KEY_ID=admin - AWS_SECRET_ACCESS_KEY=password - AWS_REGION=us-east-1 entrypoint: > - /bin/sh -c " - until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; - /usr/bin/mc rm -r --force minio/icebergdata; - /usr/bin/mc mb minio/icebergdata; - /usr/bin/mc policy set public minio/icebergdata; - tail -f /dev/null - " \ No newline at end of file + /bin/sh -c " until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; /usr/bin/mc rm -r --force minio/icebergdata; /usr/bin/mc mb minio/icebergdata; /usr/bin/mc policy set public minio/icebergdata; tail -f /dev/null " + networks: + rest_bridge: diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs index a4d07955b..e98890a86 100644 --- a/crates/catalog/rest/tests/rest_catalog_test.rs +++ b/crates/catalog/rest/tests/rest_catalog_test.rs @@ -17,6 +17,11 @@ //! Integration tests for rest catalog. 
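Aside: the request/response types above rely on `#[serde(rename_all = "kebab-case")]` so that snake_case Rust fields map to the kebab-case keys used on the wire (for example `metadata-location`). A small, self-contained sketch of that behaviour, with an illustrative struct name and `serde`/`serde_json` assumed as dependencies:

```rust
use serde::{Deserialize, Serialize};

// Rust fields stay snake_case; on the wire they become kebab-case.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct LoadTableResponseSketch {
    metadata_location: Option<String>,
}

fn main() {
    let json = r#"{"metadata-location":"s3://bucket/metadata/v1.metadata.json"}"#;
    let resp: LoadTableResponseSketch = serde_json::from_str(json).unwrap();
    assert_eq!(
        resp.metadata_location.as_deref(),
        Some("s3://bucket/metadata/v1.metadata.json")
    );
    // Serializing produces the kebab-case key again.
    assert!(serde_json::to_string(&resp).unwrap().contains("metadata-location"));
}
```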
+use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::RwLock; + +use ctor::{ctor, dtor}; use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type}; use iceberg::transaction::Transaction; use iceberg::{Catalog, Namespace, NamespaceIdent, TableCreation, TableIdent}; @@ -24,55 +29,55 @@ use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig}; use iceberg_test_utils::docker::DockerCompose; use iceberg_test_utils::{normalize_test_name, set_up}; use port_scanner::scan_port_addr; -use std::collections::HashMap; use tokio::time::sleep; const REST_CATALOG_PORT: u16 = 8181; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); -struct TestFixture { - _docker_compose: DockerCompose, - rest_catalog: RestCatalog, -} - -async fn set_test_fixture(func: &str) -> TestFixture { - set_up(); +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); let docker_compose = DockerCompose::new( - normalize_test_name(format!("{}_{func}", module_path!())), + normalize_test_name(module_path!()), format!("{}/testdata/rest_catalog", env!("CARGO_MANIFEST_DIR")), ); - - // Start docker compose docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +async fn get_catalog() -> RestCatalog { + set_up(); - let rest_catalog_ip = docker_compose.get_container_ip("rest"); + let rest_catalog_ip = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + docker_compose.get_container_ip("rest") + }; - let read_port = format!("{}:{}", rest_catalog_ip, REST_CATALOG_PORT); - loop { - if !scan_port_addr(&read_port) { - log::info!("Waiting for 1s rest catalog to ready..."); - sleep(std::time::Duration::from_millis(1000)).await; - } else { - break; - } + let rest_socket_addr = SocketAddr::new(rest_catalog_ip, REST_CATALOG_PORT); + while !scan_port_addr(rest_socket_addr) { + log::info!("Waiting for 1s rest catalog to ready..."); + sleep(std::time::Duration::from_millis(1000)).await; } let config = RestCatalogConfig::builder() - .uri(format!("http://{}:{}", rest_catalog_ip, REST_CATALOG_PORT)) + .uri(format!("http://{}", rest_socket_addr)) .build(); - let rest_catalog = RestCatalog::new(config).await.unwrap(); - - TestFixture { - _docker_compose: docker_compose, - rest_catalog, - } + RestCatalog::new(config) } + #[tokio::test] async fn test_get_non_exist_namespace() { - let fixture = set_test_fixture("test_get_non_exist_namespace").await; + let catalog = get_catalog().await; - let result = fixture - .rest_catalog - .get_namespace(&NamespaceIdent::from_strs(["demo"]).unwrap()) + let result = catalog + .get_namespace(&NamespaceIdent::from_strs(["test_get_non_exist_namespace"]).unwrap()) .await; assert!(result.is_err()); @@ -84,7 +89,7 @@ async fn test_get_non_exist_namespace() { #[tokio::test] async fn test_get_namespace() { - let fixture = set_test_fixture("test_get_namespace").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), @@ -95,11 +100,10 @@ async fn test_get_namespace() { ); // Verify that namespace doesn't exist - assert!(fixture.rest_catalog.get_namespace(ns.name()).await.is_err()); + assert!(catalog.get_namespace(ns.name()).await.is_err()); // Create this namespace - let created_ns = fixture - .rest_catalog + let created_ns = catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -108,17 +112,17 @@ 
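Aside: the `before_all`/`after_all` hooks above share one docker compose environment across every test in the binary instead of starting one per test. A reduced sketch of that pattern (not part of this patch; the `String` handle stands in for `DockerCompose`, and the `ctor` crate is assumed as a dev-dependency):

```rust
use std::sync::RwLock;

use ctor::{ctor, dtor};

// Stand-in for the shared DockerCompose handle.
static SHARED_ENV: RwLock<Option<String>> = RwLock::new(None);

#[ctor]
fn before_all() {
    // Runs once, before the test harness starts any #[test] function.
    SHARED_ENV
        .write()
        .unwrap()
        .replace("environment-started".to_string());
}

#[dtor]
fn after_all() {
    // Runs once at process exit; dropping the value tears the environment down.
    SHARED_ENV.write().unwrap().take();
}

fn shared_env() -> String {
    SHARED_ENV.read().unwrap().as_ref().expect("env not started").clone()
}

#[test]
fn uses_shared_env() {
    assert_eq!(shared_env(), "environment-started");
}
```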
async fn test_get_namespace() { assert_map_contains(ns.properties(), created_ns.properties()); // Check that this namespace already exists - let get_ns = fixture.rest_catalog.get_namespace(ns.name()).await.unwrap(); + let get_ns = catalog.get_namespace(ns.name()).await.unwrap(); assert_eq!(ns.name(), get_ns.name()); assert_map_contains(ns.properties(), created_ns.properties()); } #[tokio::test] async fn test_list_namespace() { - let fixture = set_test_fixture("test_list_namespace").await; + let catalog = get_catalog().await; let ns1 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_list_namespace", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -126,7 +130,7 @@ async fn test_list_namespace() { ); let ns2 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "macos"]).unwrap(), + NamespaceIdent::from_strs(["test_list_namespace", "macos"]).unwrap(), HashMap::from([ ("owner".to_string(), "xuanwo".to_string()), ("community".to_string(), "apache".to_string()), @@ -134,42 +138,41 @@ async fn test_list_namespace() { ); // Currently this namespace doesn't exist, so it should return error. - assert!(fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + assert!(catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_namespace"]).unwrap() + )) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns1.name(), ns1.properties().clone()) .await .unwrap(); - fixture - .rest_catalog + catalog .create_namespace(ns2.name(), ns1.properties().clone()) .await .unwrap(); // List namespace - let mut nss = fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + let nss = catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_namespace"]).unwrap(), + )) .await .unwrap(); - nss.sort(); - assert_eq!(&nss[0], ns1.name()); - assert_eq!(&nss[1], ns2.name()); + assert!(nss.contains(ns1.name())); + assert!(nss.contains(ns2.name())); } #[tokio::test] async fn test_list_empty_namespace() { - let fixture = set_test_fixture("test_list_empty_namespace").await; + let catalog = get_catalog().await; let ns_apple = Namespace::with_properties( - NamespaceIdent::from_strs(["apple"]).unwrap(), + NamespaceIdent::from_strs(["test_list_empty_namespace", "apple"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -177,23 +180,20 @@ async fn test_list_empty_namespace() { ); // Currently this namespace doesn't exist, so it should return error. 
- assert!(fixture - .rest_catalog + assert!(catalog .list_namespaces(Some(ns_apple.name())) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns_apple.name(), ns_apple.properties().clone()) .await .unwrap(); // List namespace - let nss = fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + let nss = catalog + .list_namespaces(Some(ns_apple.name())) .await .unwrap(); assert!(nss.is_empty()); @@ -201,10 +201,10 @@ async fn test_list_empty_namespace() { #[tokio::test] async fn test_list_root_namespace() { - let fixture = set_test_fixture("test_list_root_namespace").await; + let catalog = get_catalog().await; let ns1 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_list_root_namespace", "apple", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -212,7 +212,7 @@ async fn test_list_root_namespace() { ); let ns2 = Namespace::with_properties( - NamespaceIdent::from_strs(["google", "android"]).unwrap(), + NamespaceIdent::from_strs(["test_list_root_namespace", "google", "android"]).unwrap(), HashMap::from([ ("owner".to_string(), "xuanwo".to_string()), ("community".to_string(), "apache".to_string()), @@ -220,38 +220,34 @@ async fn test_list_root_namespace() { ); // Currently this namespace doesn't exist, so it should return error. - assert!(fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + assert!(catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_root_namespace"]).unwrap() + )) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns1.name(), ns1.properties().clone()) .await .unwrap(); - fixture - .rest_catalog + catalog .create_namespace(ns2.name(), ns1.properties().clone()) .await .unwrap(); // List namespace - let mut nss = fixture.rest_catalog.list_namespaces(None).await.unwrap(); - nss.sort(); - - assert_eq!(&nss[0], &NamespaceIdent::from_strs(["apple"]).unwrap()); - assert_eq!(&nss[1], &NamespaceIdent::from_strs(["google"]).unwrap()); + let nss = catalog.list_namespaces(None).await.unwrap(); + assert!(nss.contains(&NamespaceIdent::from_strs(["test_list_root_namespace"]).unwrap())); } #[tokio::test] async fn test_create_table() { - let fixture = set_test_fixture("test_create_table").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_create_table", "apple", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -259,8 +255,7 @@ async fn test_create_table() { ); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -281,8 +276,7 @@ async fn test_create_table() { .schema(schema.clone()) .build(); - let table = fixture - .rest_catalog + let table = catalog .create_table(ns.name(), table_creation) .await .unwrap(); @@ -309,10 +303,10 @@ async fn test_create_table() { #[tokio::test] async fn test_update_table() { - let fixture = set_test_fixture("test_update_table").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_update_table", "apple", "ios"]).unwrap(), 
HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -320,8 +314,7 @@ async fn test_update_table() { ); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -343,8 +336,7 @@ async fn test_update_table() { .schema(schema.clone()) .build(); - let table = fixture - .rest_catalog + let table = catalog .create_table(ns.name(), table_creation) .await .unwrap(); @@ -358,7 +350,7 @@ async fn test_update_table() { let table2 = Transaction::new(&table) .set_properties(HashMap::from([("prop1".to_string(), "v1".to_string())])) .unwrap() - .commit(&fixture.rest_catalog) + .commit(&catalog) .await .unwrap(); @@ -374,3 +366,39 @@ fn assert_map_contains(map1: &HashMap, map2: &HashMap, +} + +#[derive(Debug)] +/// Sql catalog implementation. +pub struct SqlCatalog { + name: String, + connection: AnyPool, + _warehouse_location: String, + _fileio: FileIO, + sql_bind_style: SqlBindStyle, +} + +#[derive(Debug, PartialEq)] +/// Set the SQL parameter bind style to either $1..$N (Postgres style) or ? (SQLite/MySQL/MariaDB) +pub enum SqlBindStyle { + /// DollarNumeric uses parameters of the form `$1..$N``, which is the Postgres style + DollarNumeric, + /// QMark uses parameters of the form `?` which is the style for other dialects (SQLite/MySQL/MariaDB) + QMark, +} + +impl SqlCatalog { + /// Create new sql catalog instance + pub async fn new(config: SqlCatalogConfig) -> Result { + install_default_drivers(); + let max_connections: u32 = config + .props + .get("pool.max-connections") + .map(|v| v.parse().unwrap()) + .unwrap_or(MAX_CONNECTIONS); + let idle_timeout: u64 = config + .props + .get("pool.idle-timeout") + .map(|v| v.parse().unwrap()) + .unwrap_or(IDLE_TIMEOUT); + let test_before_acquire: bool = config + .props + .get("pool.test-before-acquire") + .map(|v| v.parse().unwrap()) + .unwrap_or(TEST_BEFORE_ACQUIRE); + + let pool = AnyPoolOptions::new() + .max_connections(max_connections) + .idle_timeout(Duration::from_secs(idle_timeout)) + .test_before_acquire(test_before_acquire) + .connect(&config.uri) + .await + .map_err(from_sqlx_error)?; + + sqlx::query(&format!( + "CREATE TABLE IF NOT EXISTS {CATALOG_TABLE_NAME} ( + {CATALOG_FIELD_CATALOG_NAME} VARCHAR(255) NOT NULL, + {CATALOG_FIELD_TABLE_NAMESPACE} VARCHAR(255) NOT NULL, + {CATALOG_FIELD_TABLE_NAME} VARCHAR(255) NOT NULL, + {CATALOG_FIELD_METADATA_LOCATION_PROP} VARCHAR(1000), + {CATALOG_FIELD_PREVIOUS_METADATA_LOCATION_PROP} VARCHAR(1000), + {CATALOG_FIELD_RECORD_TYPE} VARCHAR(5), + PRIMARY KEY ({CATALOG_FIELD_CATALOG_NAME}, {CATALOG_FIELD_TABLE_NAMESPACE}, {CATALOG_FIELD_TABLE_NAME}))" + )) + .execute(&pool) + .await + .map_err(from_sqlx_error)?; + + sqlx::query(&format!( + "CREATE TABLE IF NOT EXISTS {NAMESPACE_TABLE_NAME} ( + {CATALOG_FIELD_CATALOG_NAME} VARCHAR(255) NOT NULL, + {NAMESPACE_FIELD_NAME} VARCHAR(255) NOT NULL, + {NAMESPACE_FIELD_PROPERTY_KEY} VARCHAR(255), + {NAMESPACE_FIELD_PROPERTY_VALUE} VARCHAR(1000), + PRIMARY KEY ({CATALOG_FIELD_CATALOG_NAME}, {NAMESPACE_FIELD_NAME}, {NAMESPACE_FIELD_PROPERTY_KEY}))" + )) + .execute(&pool) + .await + .map_err(from_sqlx_error)?; + + Ok(SqlCatalog { + name: config.name.to_owned(), + connection: pool, + _warehouse_location: config.warehouse_location, + _fileio: config.file_io, + sql_bind_style: config.sql_bind_style, + }) + } + + /// SQLX Any does not implement PostgresSQL bindings, so we have to do this. 
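In concrete terms, the `DollarNumeric` mode rewrites `?` placeholders into `$1..$N` before the query reaches sqlx. A standalone sketch of what that rewrite produces (the free function `to_dollar_numeric` is illustrative, not the crate's API); like the method that follows, it does not special-case `?` characters inside SQL string literals:

```rust
// Rewrite `?` placeholders to Postgres-style `$1..$N` bindings.
fn to_dollar_numeric(query: &str) -> String {
    let mut count = 1;
    query
        .chars()
        .fold(String::with_capacity(query.len()), |mut acc, c| {
            if c == '?' {
                acc.push('$');
                acc.push_str(&count.to_string());
                count += 1;
            } else {
                acc.push(c);
            }
            acc
        })
}

fn main() {
    assert_eq!(
        to_dollar_numeric("SELECT * FROM t WHERE a = ? AND b = ?"),
        "SELECT * FROM t WHERE a = $1 AND b = $2"
    );
}
```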
+ fn replace_placeholders(&self, query: &str) -> String { + match self.sql_bind_style { + SqlBindStyle::DollarNumeric => { + let mut count = 1; + query + .chars() + .fold(String::with_capacity(query.len()), |mut acc, c| { + if c == '?' { + acc.push('$'); + acc.push_str(&count.to_string()); + count += 1; + } else { + acc.push(c); + } + acc + }) + } + _ => query.to_owned(), + } + } + + /// Fetch a vec of AnyRows from a given query + async fn fetch_rows(&self, query: &str, args: Vec>) -> Result> { + let query_with_placeholders = self.replace_placeholders(query); + + let mut sqlx_query = sqlx::query(&query_with_placeholders); + for arg in args { + sqlx_query = sqlx_query.bind(arg); + } + + sqlx_query + .fetch_all(&self.connection) + .await + .map_err(from_sqlx_error) + } + + /// Execute statements in a transaction, provided or not + async fn execute( + &self, + query: &str, + args: Vec>, + transaction: Option<&mut Transaction<'_, Any>>, + ) -> Result { + let query_with_placeholders = self.replace_placeholders(query); + + let mut sqlx_query = sqlx::query(&query_with_placeholders); + for arg in args { + sqlx_query = sqlx_query.bind(arg); + } + + match transaction { + Some(t) => sqlx_query.execute(&mut **t).await.map_err(from_sqlx_error), + None => { + let mut tx = self.connection.begin().await.map_err(from_sqlx_error)?; + let result = sqlx_query.execute(&mut *tx).await.map_err(from_sqlx_error); + let _ = tx.commit().await.map_err(from_sqlx_error); + result + } + } + } +} + +#[async_trait] +impl Catalog for SqlCatalog { + async fn list_namespaces( + &self, + parent: Option<&NamespaceIdent>, + ) -> Result> { + // UNION will remove duplicates. + let all_namespaces_stmt = format!( + "SELECT {CATALOG_FIELD_TABLE_NAMESPACE} + FROM {CATALOG_TABLE_NAME} + WHERE {CATALOG_FIELD_CATALOG_NAME} = ? + UNION + SELECT {NAMESPACE_FIELD_NAME} + FROM {NAMESPACE_TABLE_NAME} + WHERE {CATALOG_FIELD_CATALOG_NAME} = ?" + ); + + let namespace_rows = self + .fetch_rows(&all_namespaces_stmt, vec![ + Some(&self.name), + Some(&self.name), + ]) + .await?; + + let mut namespaces = HashSet::::with_capacity(namespace_rows.len()); + + if let Some(parent) = parent { + if self.namespace_exists(parent).await? { + let parent_str = parent.join("."); + + for row in namespace_rows.iter() { + let nsp = row.try_get::(0).map_err(from_sqlx_error)?; + // if parent = a, then we only want to see a.b, a.c returned. 
+ if nsp != parent_str && nsp.starts_with(&parent_str) { + namespaces.insert(NamespaceIdent::from_strs(nsp.split("."))?); + } + } + + Ok(namespaces.into_iter().collect::>()) + } else { + no_such_namespace_err(parent) + } + } else { + for row in namespace_rows.iter() { + let nsp = row.try_get::(0).map_err(from_sqlx_error)?; + let mut levels = nsp.split(".").collect::>(); + if !levels.is_empty() { + let first_level = levels.drain(..1).collect::>(); + namespaces.insert(NamespaceIdent::from_strs(first_level)?); + } + } + + Ok(namespaces.into_iter().collect::>()) + } + } + + async fn create_namespace( + &self, + namespace: &NamespaceIdent, + properties: HashMap, + ) -> Result { + let exists = self.namespace_exists(namespace).await?; + + if exists { + return Err(Error::new( + iceberg::ErrorKind::Unexpected, + format!("Namespace {:?} already exists", namespace), + )); + } + + let namespace_str = namespace.join("."); + let insert = format!( + "INSERT INTO {NAMESPACE_TABLE_NAME} ({CATALOG_FIELD_CATALOG_NAME}, {NAMESPACE_FIELD_NAME}, {NAMESPACE_FIELD_PROPERTY_KEY}, {NAMESPACE_FIELD_PROPERTY_VALUE}) + VALUES (?, ?, ?, ?)"); + if !properties.is_empty() { + let mut insert_properties = properties.clone(); + insert_properties.insert("exists".to_string(), "true".to_string()); + + let mut query_args = Vec::with_capacity(insert_properties.len() * 4); + let mut insert_stmt = insert.clone(); + for (index, (key, value)) in insert_properties.iter().enumerate() { + query_args.extend_from_slice(&[ + Some(self.name.as_str()), + Some(namespace_str.as_str()), + Some(key.as_str()), + Some(value.as_str()), + ]); + if index > 0 { + insert_stmt.push_str(", (?, ?, ?, ?)"); + } + } + + self.execute(&insert_stmt, query_args, None).await?; + + Ok(Namespace::with_properties( + namespace.clone(), + insert_properties, + )) + } else { + // set a default property of exists = true + self.execute( + &insert, + vec![ + Some(&self.name), + Some(&namespace_str), + Some("exists"), + Some("true"), + ], + None, + ) + .await?; + Ok(Namespace::with_properties(namespace.clone(), properties)) + } + } + + async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { + let exists = self.namespace_exists(namespace).await?; + if exists { + let namespace_props = self + .fetch_rows( + &format!( + "SELECT + {NAMESPACE_FIELD_NAME}, + {NAMESPACE_FIELD_PROPERTY_KEY}, + {NAMESPACE_FIELD_PROPERTY_VALUE} + FROM {NAMESPACE_TABLE_NAME} + WHERE {CATALOG_FIELD_CATALOG_NAME} = ? + AND {NAMESPACE_FIELD_NAME} = ?" + ), + vec![Some(&self.name), Some(&namespace.join("."))], + ) + .await?; + + let mut properties = HashMap::with_capacity(namespace_props.len()); + + for row in namespace_props { + let key = row + .try_get::(NAMESPACE_FIELD_PROPERTY_KEY) + .map_err(from_sqlx_error)?; + let value = row + .try_get::(NAMESPACE_FIELD_PROPERTY_VALUE) + .map_err(from_sqlx_error)?; + + properties.insert(key, value); + } + + Ok(Namespace::with_properties(namespace.clone(), properties)) + } else { + no_such_namespace_err(namespace) + } + } + + async fn namespace_exists(&self, namespace: &NamespaceIdent) -> Result { + let namespace_str = namespace.join("."); + + let table_namespaces = self + .fetch_rows( + &format!( + "SELECT 1 FROM {CATALOG_TABLE_NAME} + WHERE {CATALOG_FIELD_CATALOG_NAME} = ? + AND {CATALOG_FIELD_TABLE_NAMESPACE} = ? 
+ LIMIT 1" + ), + vec![Some(&self.name), Some(&namespace_str)], + ) + .await?; + + if !table_namespaces.is_empty() { + Ok(true) + } else { + let namespaces = self + .fetch_rows( + &format!( + "SELECT 1 FROM {NAMESPACE_TABLE_NAME} + WHERE {CATALOG_FIELD_CATALOG_NAME} = ? + AND {NAMESPACE_FIELD_NAME} = ? + LIMIT 1" + ), + vec![Some(&self.name), Some(&namespace_str)], + ) + .await?; + if !namespaces.is_empty() { + Ok(true) + } else { + Ok(false) + } + } + } + + async fn update_namespace( + &self, + namespace: &NamespaceIdent, + properties: HashMap, + ) -> Result<()> { + let exists = self.namespace_exists(namespace).await?; + if exists { + let existing_properties = self.get_namespace(namespace).await?.properties().clone(); + let namespace_str = namespace.join("."); + + let mut updates = vec![]; + let mut inserts = vec![]; + + for (key, value) in properties.iter() { + if existing_properties.contains_key(key) { + if existing_properties.get(key) != Some(value) { + updates.push((key, value)); + } + } else { + inserts.push((key, value)); + } + } + + let mut tx = self.connection.begin().await.map_err(from_sqlx_error)?; + let update_stmt = format!( + "UPDATE {NAMESPACE_TABLE_NAME} SET {NAMESPACE_FIELD_PROPERTY_VALUE} = ? + WHERE {CATALOG_FIELD_CATALOG_NAME} = ? + AND {NAMESPACE_FIELD_NAME} = ? + AND {NAMESPACE_FIELD_PROPERTY_KEY} = ?" + ); + + let insert_stmt = format!( + "INSERT INTO {NAMESPACE_TABLE_NAME} ({CATALOG_FIELD_CATALOG_NAME}, {NAMESPACE_FIELD_NAME}, {NAMESPACE_FIELD_PROPERTY_KEY}, {NAMESPACE_FIELD_PROPERTY_VALUE}) + VALUES (?, ?, ?, ?)" + ); + + for (key, value) in updates { + self.execute( + &update_stmt, + vec![ + Some(value), + Some(&self.name), + Some(&namespace_str), + Some(key), + ], + Some(&mut tx), + ) + .await?; + } + + for (key, value) in inserts { + self.execute( + &insert_stmt, + vec![ + Some(&self.name), + Some(&namespace_str), + Some(key), + Some(value), + ], + Some(&mut tx), + ) + .await?; + } + + let _ = tx.commit().await.map_err(from_sqlx_error)?; + + Ok(()) + } else { + no_such_namespace_err(namespace) + } + } + + async fn drop_namespace(&self, _namespace: &NamespaceIdent) -> Result<()> { + todo!() + } + + async fn list_tables(&self, _namespace: &NamespaceIdent) -> Result> { + todo!() + } + + async fn table_exists(&self, _identifier: &TableIdent) -> Result { + todo!() + } + + async fn drop_table(&self, _identifier: &TableIdent) -> Result<()> { + todo!() + } + + async fn load_table(&self, _identifier: &TableIdent) -> Result
{ + todo!() + } + + async fn create_table( + &self, + _namespace: &NamespaceIdent, + _creation: TableCreation, + ) -> Result<Table>
{ + todo!() + } + + async fn rename_table(&self, _src: &TableIdent, _dest: &TableIdent) -> Result<()> { + todo!() + } + + async fn update_table(&self, _commit: TableCommit) -> Result<Table>
{ + todo!() + } +} + +#[cfg(test)] +mod tests { + use std::collections::{HashMap, HashSet}; + use std::hash::Hash; + + use iceberg::io::FileIOBuilder; + use iceberg::{Catalog, Namespace, NamespaceIdent}; + use sqlx::migrate::MigrateDatabase; + use tempfile::TempDir; + + use crate::{SqlBindStyle, SqlCatalog, SqlCatalogConfig}; + + fn temp_path() -> String { + let temp_dir = TempDir::new().unwrap(); + temp_dir.path().to_str().unwrap().to_string() + } + + fn to_set(vec: Vec) -> HashSet { + HashSet::from_iter(vec) + } + + fn default_properties() -> HashMap { + HashMap::from([("exists".to_string(), "true".to_string())]) + } + + async fn new_sql_catalog(warehouse_location: String) -> impl Catalog { + let sql_lite_uri = format!("sqlite:{}", temp_path()); + sqlx::Sqlite::create_database(&sql_lite_uri).await.unwrap(); + + let config = SqlCatalogConfig::builder() + .uri(sql_lite_uri.to_string()) + .name("iceberg".to_string()) + .warehouse_location(warehouse_location) + .file_io(FileIOBuilder::new_fs_io().build().unwrap()) + .sql_bind_style(SqlBindStyle::QMark) + .build(); + + SqlCatalog::new(config).await.unwrap() + } + + async fn create_namespace(catalog: &C, namespace_ident: &NamespaceIdent) { + let _ = catalog + .create_namespace(namespace_ident, HashMap::new()) + .await + .unwrap(); + } + + async fn create_namespaces(catalog: &C, namespace_idents: &Vec<&NamespaceIdent>) { + for namespace_ident in namespace_idents { + let _ = create_namespace(catalog, namespace_ident).await; + } + } + + #[tokio::test] + async fn test_initialized() { + let warehouse_loc = temp_path(); + new_sql_catalog(warehouse_loc.clone()).await; + // catalog instantiation should not fail even if tables exist + new_sql_catalog(warehouse_loc.clone()).await; + new_sql_catalog(warehouse_loc.clone()).await; + } + + #[tokio::test] + async fn test_list_namespaces_returns_empty_vector() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + + assert_eq!(catalog.list_namespaces(None).await.unwrap(), vec![]); + } + + #[tokio::test] + async fn test_list_namespaces_returns_multiple_namespaces() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![&namespace_ident_1, &namespace_ident_2]).await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1, namespace_ident_2]) + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_only_top_level_namespaces() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + ]) + .await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1, namespace_ident_3]) + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_no_namespaces_under_parent() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![&namespace_ident_1, &namespace_ident_2]).await; + 
+ assert_eq!( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap(), + vec![] + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_namespace_under_parent() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_1 = NamespaceIdent::new("a".into()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::new("c".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + ]) + .await; + + assert_eq!( + to_set(catalog.list_namespaces(None).await.unwrap()), + to_set(vec![namespace_ident_1.clone(), namespace_ident_3]) + ); + + assert_eq!( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap(), + vec![NamespaceIdent::from_strs(vec!["a", "b"]).unwrap()] + ); + } + + #[tokio::test] + async fn test_list_namespaces_returns_multiple_namespaces_under_parent() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_1 = NamespaceIdent::new("a".to_string()); + let namespace_ident_2 = NamespaceIdent::from_strs(vec!["a", "a"]).unwrap(); + let namespace_ident_3 = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_4 = NamespaceIdent::from_strs(vec!["a", "c"]).unwrap(); + let namespace_ident_5 = NamespaceIdent::new("b".into()); + create_namespaces(&catalog, &vec![ + &namespace_ident_1, + &namespace_ident_2, + &namespace_ident_3, + &namespace_ident_4, + &namespace_ident_5, + ]) + .await; + + assert_eq!( + to_set( + catalog + .list_namespaces(Some(&namespace_ident_1)) + .await + .unwrap() + ), + to_set(vec![ + NamespaceIdent::from_strs(vec!["a", "a"]).unwrap(), + NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(), + NamespaceIdent::from_strs(vec!["a", "c"]).unwrap(), + ]) + ); + } + + #[tokio::test] + async fn test_namespace_exists_returns_false() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert!(!catalog + .namespace_exists(&NamespaceIdent::new("b".into())) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_namespace_exists_returns_true() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + assert!(catalog.namespace_exists(&namespace_ident).await.unwrap()); + } + + #[tokio::test] + async fn test_create_namespace_with_properties() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident = NamespaceIdent::new("abc".into()); + + let mut properties = default_properties(); + properties.insert("k".into(), "v".into()); + + assert_eq!( + catalog + .create_namespace(&namespace_ident, properties.clone()) + .await + .unwrap(), + Namespace::with_properties(namespace_ident.clone(), properties.clone()) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, properties) + ); + } + + #[tokio::test] + async fn test_create_namespace_throws_error_if_namespace_already_exists() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &namespace_ident).await; + + 
assert_eq!( + catalog + .create_namespace(&namespace_ident, HashMap::new()) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => Namespace {:?} already exists", + &namespace_ident + ) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident).await.unwrap(), + Namespace::with_properties(namespace_ident, default_properties()) + ); + } + + #[tokio::test] + async fn test_create_nested_namespace() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let parent_namespace_ident = NamespaceIdent::new("a".into()); + create_namespace(&catalog, &parent_namespace_ident).await; + + let child_namespace_ident = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&child_namespace_ident, HashMap::new()) + .await + .unwrap(), + Namespace::new(child_namespace_ident.clone()) + ); + + assert_eq!( + catalog.get_namespace(&child_namespace_ident).await.unwrap(), + Namespace::with_properties(child_namespace_ident, default_properties()) + ); + } + + #[tokio::test] + async fn test_create_deeply_nested_namespace() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + + assert_eq!( + catalog + .create_namespace(&namespace_ident_a_b_c, HashMap::new()) + .await + .unwrap(), + Namespace::new(namespace_ident_a_b_c.clone()) + ); + + assert_eq!( + catalog.get_namespace(&namespace_ident_a_b_c).await.unwrap(), + Namespace::with_properties(namespace_ident_a_b_c, default_properties()) + ); + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_drop_namespace() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident = NamespaceIdent::new("abc".into()); + create_namespace(&catalog, &namespace_ident).await; + + catalog.drop_namespace(&namespace_ident).await.unwrap(); + + assert!(!catalog.namespace_exists(&namespace_ident).await.unwrap()) + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_drop_nested_namespace() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + catalog.drop_namespace(&namespace_ident_a_b).await.unwrap(); + + assert!(!catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + + assert!(catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_drop_deeply_nested_namespace() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + let namespace_ident_a_b_c = NamespaceIdent::from_strs(vec!["a", "b", "c"]).unwrap(); + create_namespaces(&catalog, &vec![ + &namespace_ident_a, + &namespace_ident_a_b, + &namespace_ident_a_b_c, + ]) + .await; + + catalog + .drop_namespace(&namespace_ident_a_b_c) + .await + 
.unwrap(); + + assert!(!catalog + .namespace_exists(&namespace_ident_a_b_c) + .await + .unwrap()); + + assert!(catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + + assert!(catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_drop_namespace_throws_error_if_namespace_doesnt_exist() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + + let non_existent_namespace_ident = NamespaceIdent::new("abc".into()); + assert_eq!( + catalog + .drop_namespace(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_drop_namespace_throws_error_if_nested_namespace_doesnt_exist() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + create_namespace(&catalog, &NamespaceIdent::new("a".into())).await; + + let non_existent_namespace_ident = + NamespaceIdent::from_vec(vec!["a".into(), "b".into()]).unwrap(); + assert_eq!( + catalog + .drop_namespace(&non_existent_namespace_ident) + .await + .unwrap_err() + .to_string(), + format!( + "Unexpected => No such namespace: {:?}", + non_existent_namespace_ident + ) + ) + } + + #[tokio::test] + #[ignore = "drop_namespace not implemented"] + async fn test_dropping_a_namespace_does_not_drop_namespaces_nested_under_that_one() { + let warehouse_loc = temp_path(); + let catalog = new_sql_catalog(warehouse_loc).await; + let namespace_ident_a = NamespaceIdent::new("a".into()); + let namespace_ident_a_b = NamespaceIdent::from_strs(vec!["a", "b"]).unwrap(); + create_namespaces(&catalog, &vec![&namespace_ident_a, &namespace_ident_a_b]).await; + + catalog.drop_namespace(&namespace_ident_a).await.unwrap(); + + assert!(!catalog.namespace_exists(&namespace_ident_a).await.unwrap()); + + assert!(catalog + .namespace_exists(&namespace_ident_a_b) + .await + .unwrap()); + } +} diff --git a/crates/catalog/sql/src/error.rs b/crates/catalog/sql/src/error.rs new file mode 100644 index 000000000..cfefcc26a --- /dev/null +++ b/crates/catalog/sql/src/error.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use iceberg::{Error, ErrorKind, NamespaceIdent, Result}; + +/// Format an sqlx error into iceberg error. 
+pub fn from_sqlx_error(error: sqlx::Error) -> Error { + Error::new( + ErrorKind::Unexpected, + "operation failed for hitting sqlx error".to_string(), + ) + .with_source(error) +} + +pub fn no_such_namespace_err(namespace: &NamespaceIdent) -> Result { + Err(Error::new( + ErrorKind::Unexpected, + format!("No such namespace: {:?}", namespace), + )) +} diff --git a/crates/catalog/sql/src/lib.rs b/crates/catalog/sql/src/lib.rs new file mode 100644 index 000000000..6861dab3f --- /dev/null +++ b/crates/catalog/sql/src/lib.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Iceberg sql catalog implementation. + +#![deny(missing_docs)] + +mod catalog; +mod error; +pub use catalog::*; diff --git a/crates/examples/Cargo.toml b/crates/examples/Cargo.toml new file mode 100644 index 000000000..2fb3060c1 --- /dev/null +++ b/crates/examples/Cargo.toml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "iceberg-examples" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +rust-version = { workspace = true } + +[dependencies] +iceberg = { workspace = true } +iceberg-catalog-rest = { workspace = true } +tokio = { version = "1", features = ["full"] } + +[[example]] +name = "rest-catalog-namespace" +path = "src/rest_catalog_namespace.rs" + +[[example]] +name = "rest-catalog-table" +path = "src/rest_catalog_table.rs" diff --git a/crates/examples/README.md b/crates/examples/README.md new file mode 100644 index 000000000..335d2ea28 --- /dev/null +++ b/crates/examples/README.md @@ -0,0 +1,21 @@ + + +Example usage codes for `iceberg-rust`. Currently, these examples can't run directly since it requires setting up of +environments for catalogs, for example, rest catalog server. 
\ No newline at end of file diff --git a/crates/examples/src/rest_catalog_namespace.rs b/crates/examples/src/rest_catalog_namespace.rs new file mode 100644 index 000000000..0a508a7d8 --- /dev/null +++ b/crates/examples/src/rest_catalog_namespace.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use iceberg::{Catalog, NamespaceIdent}; +use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig}; + +#[tokio::main] +async fn main() { + // ANCHOR: create_catalog + // Create catalog + let config = RestCatalogConfig::builder() + .uri("http://localhost:8080".to_string()) + .build(); + + let catalog = RestCatalog::new(config); + // ANCHOR_END: create_catalog + + // ANCHOR: list_all_namespace + // List all namespaces + let all_namespaces = catalog.list_namespaces(None).await.unwrap(); + println!("Namespaces in current catalog: {:?}", all_namespaces); + // ANCHOR_END: list_all_namespace + + // ANCHOR: create_namespace + let namespace_id = + NamespaceIdent::from_vec(vec!["ns1".to_string(), "ns11".to_string()]).unwrap(); + // Create namespace + let ns = catalog + .create_namespace( + &namespace_id, + HashMap::from([("key1".to_string(), "value1".to_string())]), + ) + .await + .unwrap(); + + println!("Namespace created: {:?}", ns); + // ANCHOR_END: create_namespace +} diff --git a/crates/examples/src/rest_catalog_table.rs b/crates/examples/src/rest_catalog_table.rs new file mode 100644 index 000000000..a0a672f15 --- /dev/null +++ b/crates/examples/src/rest_catalog_table.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
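Aside: the namespace example above unwraps every call; when pointing at a catalog server that may not be running, the same flow can propagate failures with `?` instead. A short variant (not part of this patch) that assumes the same local REST endpoint at `http://localhost:8080`:

```rust
use iceberg::{Catalog, Result};
use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig};

#[tokio::main]
async fn main() -> Result<()> {
    let config = RestCatalogConfig::builder()
        .uri("http://localhost:8080".to_string())
        .build();
    let catalog = RestCatalog::new(config);

    // `?` returns the error to the caller instead of panicking when the
    // server is unreachable or the request fails.
    let namespaces = catalog.list_namespaces(None).await?;
    println!("namespaces: {namespaces:?}");
    Ok(())
}
```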
+ +use std::collections::HashMap; + +use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; +use iceberg::{Catalog, TableCreation, TableIdent}; +use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig}; + +#[tokio::main] +async fn main() { + // Create catalog + let config = RestCatalogConfig::builder() + .uri("http://localhost:8080".to_string()) + .build(); + + let catalog = RestCatalog::new(config); + + // ANCHOR: create_table + let table_id = TableIdent::from_strs(["default", "t1"]).unwrap(); + + let table_schema = Schema::builder() + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .build() + .unwrap(); + + // Create table + let table_creation = TableCreation::builder() + .name(table_id.name.clone()) + .schema(table_schema.clone()) + .properties(HashMap::from([("owner".to_string(), "testx".to_string())])) + .build(); + + let table = catalog + .create_table(&table_id.namespace, table_creation) + .await + .unwrap(); + + println!("Table created: {:?}", table.metadata()); + // ANCHOR_END: create_table + + // ANCHOR: load_table + let table2 = catalog + .load_table(&TableIdent::from_strs(["default", "t2"]).unwrap()) + .await + .unwrap(); + println!("{:?}", table2.metadata()); + // ANCHOR_END: load_table +} diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index b4867bbe4..6166d360d 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -17,35 +17,56 @@ [package] name = "iceberg" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } categories = ["database"] description = "Apache Iceberg Rust implementation" -repository = "https://github.com/apache/iceberg-rust" -license = "Apache-2.0" +repository = { workspace = true } +license = { workspace = true } keywords = ["iceberg"] +[features] +default = ["storage-memory", "storage-fs", "storage-s3", "tokio"] +storage-all = ["storage-memory", "storage-fs", "storage-s3", "storage-gcs"] + +storage-memory = ["opendal/services-memory"] +storage-fs = ["opendal/services-fs"] +storage-s3 = ["opendal/services-s3"] +storage-gcs = ["opendal/services-gcs"] + +async-std = ["dep:async-std"] +tokio = ["dep:tokio"] + [dependencies] anyhow = { workspace = true } apache-avro = { workspace = true } +array-init = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } +arrow-select = { workspace = true } +arrow-string = { workspace = true } +async-std = { workspace = true, optional = true, features = ["attributes"] } async-trait = { workspace = true } bimap = { workspace = true } bitvec = { workspace = true } +bytes = { workspace = true } chrono = { workspace = true } derive_builder = { workspace = true } -either = { workspace = true } +fnv = { workspace = true } futures = { workspace = true } itertools = { workspace = true } -lazy_static = { workspace = true } -log = { workspace = true } +moka = { version = "0.12.8", features = ["future"] } murmur3 = { workspace = true } once_cell = { workspace = true } opendal = { workspace = true } ordered-float = { workspace = true } +parquet = { workspace = true, features = ["async"] } +paste = 
{ workspace = true } reqwest = { workspace = true } rust_decimal = { workspace = true } serde = { workspace = true } @@ -54,12 +75,16 @@ serde_derive = { workspace = true } serde_json = { workspace = true } serde_repr = { workspace = true } serde_with = { workspace = true } +tokio = { workspace = true, optional = true } typed-builder = { workspace = true } url = { workspace = true } -urlencoding = { workspace = true } uuid = { workspace = true } [dev-dependencies] +ctor = { workspace = true } +iceberg-catalog-memory = { workspace = true } +iceberg_test_utils = { path = "../test_utils", features = ["tests"] } pretty_assertions = { workspace = true } +rand = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true } +tera = { workspace = true } diff --git a/crates/iceberg/DEPENDENCIES.rust.tsv b/crates/iceberg/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..b4617eedb --- /dev/null +++ b/crates/iceberg/DEPENDENCIES.rust.tsv @@ -0,0 +1,276 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X +adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayvec@0.7.4 X X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-ord@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-trait@0.1.81 X X +atoi@2.0.0 X +autocfg@1.3.0 X X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.22.1 X X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitvec@1.0.1 X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +cc@1.1.11 X X +cfg-if@1.0.0 X X +chrono@0.4.38 X X +colorchoice@1.0.2 X X +const-oid@0.9.6 X X +const-random@0.1.18 X X +const-random-macro@0.1.16 X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +fastrand@2.1.0 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hmac@0.12.1 X X +home@0.5.9 X X +http@1.1.0 X X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +humantime@2.1.0 X X +hyper@1.4.1 X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X 
+idna@0.5.0 X X +integer-encoding@3.0.4 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +md-5@0.10.6 X X +memchr@2.7.4 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +murmur3@0.5.2 X X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +object@0.36.3 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +parquet@52.2.0 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +ppv-lite86@0.2.20 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X +radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc_version@0.4.0 X X +rustls@0.23.12 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snap@1.1.1 X +socket2@0.5.7 X X +spin@0.9.8 X +static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.26.0 X X +tokio-util@0.7.11 X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +untrusted@0.9.0 X +url@2.5.2 X X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X 
X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winreg@0.52.0 X +wyz@0.5.1 X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/iceberg/README.md b/crates/iceberg/README.md new file mode 100644 index 000000000..b292303d8 --- /dev/null +++ b/crates/iceberg/README.md @@ -0,0 +1,59 @@ + + +# Apache Iceberg Official Native Rust Implementation + +[![crates.io](https://img.shields.io/crates/v/iceberg.svg)](https://crates.io/crates/iceberg) +[![docs.rs](https://img.shields.io/docsrs/iceberg.svg)](https://docs.rs/iceberg/latest/iceberg/) + +This crate contains the official Native Rust implementation of [Apache Iceberg](https://rust.iceberg.apache.org/). + +See the [API documentation](https://docs.rs/iceberg/latest) for examples and the full API. + +## Usage + +```rust +use futures::TryStreamExt; +use iceberg::io::{FileIO, FileIOBuilder}; +use iceberg::{Catalog, Result, TableIdent}; +use iceberg_catalog_memory::MemoryCatalog; + +#[tokio::main] +async fn main() -> Result<()> { + // Build your file IO. + let file_io = FileIOBuilder::new("memory").build()?; + // Connect to a catalog. + let catalog = MemoryCatalog::new(file_io, None); + // Load table from catalog. + let table = catalog + .load_table(&TableIdent::from_strs(["hello", "world"])?) + .await?; + // Build table scan. + let stream = table + .scan() + .select(["name", "id"]) + .build()? + .to_arrow() + .await?; + + // Consume this stream like arrow record batch stream. + let _data: Vec<_> = stream.try_collect().await?; + Ok(()) +} +``` diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs new file mode 100644 index 000000000..2076a958f --- /dev/null +++ b/crates/iceberg/src/arrow/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Conversion between Iceberg and Arrow schema + +mod schema; +pub use schema::*; +mod reader; +pub use reader::*; diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs new file mode 100644 index 000000000..592945544 --- /dev/null +++ b/crates/iceberg/src/arrow/reader.rs @@ -0,0 +1,1083 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parquet file data reader + +use std::collections::{HashMap, HashSet}; +use std::ops::Range; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_arith::boolean::{and, is_not_null, is_null, not, or}; +use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch}; +use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; +use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef}; +use arrow_string::like::starts_with; +use bytes::Bytes; +use fnv::FnvHashSet; +use futures::channel::mpsc::{channel, Sender}; +use futures::future::BoxFuture; +use futures::{try_join, SinkExt, StreamExt, TryFutureExt, TryStreamExt}; +use parquet::arrow::arrow_reader::{ArrowPredicateFn, ArrowReaderOptions, RowFilter}; +use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader}; +use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY}; +use parquet::file::metadata::ParquetMetaData; +use parquet::schema::types::{SchemaDescriptor, Type as ParquetType}; + +use crate::arrow::{arrow_schema_to_schema, get_arrow_datum}; +use crate::error::Result; +use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::visitors::row_group_metrics_evaluator::RowGroupMetricsEvaluator; +use crate::expr::{BoundPredicate, BoundReference}; +use crate::io::{FileIO, FileMetadata, FileRead}; +use crate::runtime::spawn; +use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream}; +use crate::spec::{Datum, Schema}; +use crate::utils::available_parallelism; +use crate::{Error, ErrorKind}; + +/// Builder to create ArrowReader +pub struct ArrowReaderBuilder { + batch_size: Option, + file_io: FileIO, + concurrency_limit_data_files: usize, + row_group_filtering_enabled: bool, +} + +impl ArrowReaderBuilder { + /// Create a new ArrowReaderBuilder + pub(crate) fn new(file_io: FileIO) -> Self { + let num_cpus = available_parallelism().get(); + + ArrowReaderBuilder { + batch_size: None, + file_io, + concurrency_limit_data_files: num_cpus, + row_group_filtering_enabled: true, + } + } + + /// Sets the max number of in flight data files that are being fetched + pub fn with_data_file_concurrency_limit(mut self, val: usize) -> Self { + self.concurrency_limit_data_files = val; + self + } + + /// Sets the desired size of batches in the response + /// to something other than the default + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } + + /// Determines whether to enable row group filtering. + pub fn with_row_group_filtering_enabled(mut self, row_group_filtering_enabled: bool) -> Self { + self.row_group_filtering_enabled = row_group_filtering_enabled; + self + } + + /// Build the ArrowReader. 
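+    ///
+    /// A minimal usage sketch (illustrative only; assumes an already
+    /// constructed `FileIO` named `file_io`):
+    ///
+    /// ```ignore
+    /// let reader = ArrowReaderBuilder::new(file_io)
+    ///     .with_batch_size(1024)
+    ///     .with_row_group_filtering_enabled(true)
+    ///     .build();
+    /// ```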
+ pub fn build(self) -> ArrowReader { + ArrowReader { + batch_size: self.batch_size, + file_io: self.file_io, + concurrency_limit_data_files: self.concurrency_limit_data_files, + row_group_filtering_enabled: self.row_group_filtering_enabled, + } + } +} + +/// Reads data from Parquet files +#[derive(Clone)] +pub struct ArrowReader { + batch_size: Option, + file_io: FileIO, + + /// the maximum number of data files that can be fetched at the same time + concurrency_limit_data_files: usize, + + row_group_filtering_enabled: bool, +} + +impl ArrowReader { + /// Take a stream of FileScanTasks and reads all the files. + /// Returns a stream of Arrow RecordBatches containing the data from the files + pub fn read(self, tasks: FileScanTaskStream) -> Result { + let file_io = self.file_io.clone(); + let batch_size = self.batch_size; + let concurrency_limit_data_files = self.concurrency_limit_data_files; + let row_group_filtering_enabled = self.row_group_filtering_enabled; + + let (tx, rx) = channel(concurrency_limit_data_files); + let mut channel_for_error = tx.clone(); + + spawn(async move { + let result = tasks + .map(|task| Ok((task, file_io.clone(), tx.clone()))) + .try_for_each_concurrent( + concurrency_limit_data_files, + |(file_scan_task, file_io, tx)| async move { + match file_scan_task { + Ok(task) => { + let file_path = task.data_file_path.to_string(); + + spawn(async move { + Self::process_file_scan_task( + task, + batch_size, + file_io, + tx, + row_group_filtering_enabled, + ) + .await + }) + .await + .map_err(|e| e.with_context("file_path", file_path)) + } + Err(err) => Err(err), + } + }, + ) + .await; + + if let Err(error) = result { + let _ = channel_for_error.send(Err(error)).await; + } + }); + + return Ok(rx.boxed()); + } + + async fn process_file_scan_task( + task: FileScanTask, + batch_size: Option, + file_io: FileIO, + mut tx: Sender>, + row_group_filtering_enabled: bool, + ) -> Result<()> { + // Get the metadata for the Parquet file we need to read and build + // a reader for the data within + let parquet_file = file_io.new_input(&task.data_file_path)?; + let (parquet_metadata, parquet_reader) = + try_join!(parquet_file.metadata(), parquet_file.reader())?; + let parquet_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader); + + // Start creating the record batch stream, which wraps the parquet file reader + let mut record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options( + parquet_file_reader, + // Page index will be required in upcoming row selection PR + ArrowReaderOptions::new().with_page_index(false), + ) + .await?; + + // Create a projection mask for the batch stream to select which columns in the + // Parquet file that we want in the response + let projection_mask = Self::get_arrow_projection_mask( + &task.project_field_ids, + &task.schema, + record_batch_stream_builder.parquet_schema(), + record_batch_stream_builder.schema(), + )?; + record_batch_stream_builder = record_batch_stream_builder.with_projection(projection_mask); + + if let Some(batch_size) = batch_size { + record_batch_stream_builder = record_batch_stream_builder.with_batch_size(batch_size); + } + + if let Some(predicate) = &task.predicate { + let (iceberg_field_ids, field_id_map) = Self::build_field_id_set_and_map( + record_batch_stream_builder.parquet_schema(), + predicate, + )?; + + let row_filter = Self::get_row_filter( + predicate, + record_batch_stream_builder.parquet_schema(), + &iceberg_field_ids, + &field_id_map, + )?; + record_batch_stream_builder = 
record_batch_stream_builder.with_row_filter(row_filter); + + let mut selected_row_groups = None; + if row_group_filtering_enabled { + let result = Self::get_selected_row_group_indices( + predicate, + record_batch_stream_builder.metadata(), + &field_id_map, + &task.schema, + )?; + + selected_row_groups = Some(result); + } + + if let Some(selected_row_groups) = selected_row_groups { + record_batch_stream_builder = + record_batch_stream_builder.with_row_groups(selected_row_groups); + } + } + + // Build the batch stream and send all the RecordBatches that it generates + // to the requester. + let mut record_batch_stream = record_batch_stream_builder.build()?; + while let Some(batch) = record_batch_stream.try_next().await? { + tx.send(Ok(batch)).await? + } + + Ok(()) + } + + fn build_field_id_set_and_map( + parquet_schema: &SchemaDescriptor, + predicate: &BoundPredicate, + ) -> Result<(HashSet, HashMap)> { + // Collects all Iceberg field IDs referenced in the filter predicate + let mut collector = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + visit(&mut collector, predicate)?; + + let iceberg_field_ids = collector.field_ids(); + let field_id_map = build_field_id_map(parquet_schema)?; + + Ok((iceberg_field_ids, field_id_map)) + } + + fn get_arrow_projection_mask( + field_ids: &[i32], + iceberg_schema_of_task: &Schema, + parquet_schema: &SchemaDescriptor, + arrow_schema: &ArrowSchemaRef, + ) -> Result { + if field_ids.is_empty() { + Ok(ProjectionMask::all()) + } else { + // Build the map between field id and column index in Parquet schema. + let mut column_map = HashMap::new(); + + let fields = arrow_schema.fields(); + let iceberg_schema = arrow_schema_to_schema(arrow_schema)?; + fields.filter_leaves(|idx, field| { + let field_id = field.metadata().get(PARQUET_FIELD_ID_META_KEY); + if field_id.is_none() { + return false; + } + + let field_id = i32::from_str(field_id.unwrap()); + if field_id.is_err() { + return false; + } + let field_id = field_id.unwrap(); + + if !field_ids.contains(&field_id) { + return false; + } + + let iceberg_field = iceberg_schema_of_task.field_by_id(field_id); + let parquet_iceberg_field = iceberg_schema.field_by_id(field_id); + + if iceberg_field.is_none() || parquet_iceberg_field.is_none() { + return false; + } + + if iceberg_field.unwrap().field_type != parquet_iceberg_field.unwrap().field_type { + return false; + } + + column_map.insert(field_id, idx); + true + }); + + if column_map.len() != field_ids.len() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Parquet schema {} and Iceberg schema {} do not match.", + iceberg_schema, iceberg_schema_of_task + ), + )); + } + + let mut indices = vec![]; + for field_id in field_ids { + if let Some(col_idx) = column_map.get(field_id) { + indices.push(*col_idx); + } else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Field {} is not found in Parquet schema.", field_id), + )); + } + } + Ok(ProjectionMask::leaves(parquet_schema, indices)) + } + } + + fn get_row_filter( + predicates: &BoundPredicate, + parquet_schema: &SchemaDescriptor, + iceberg_field_ids: &HashSet, + field_id_map: &HashMap, + ) -> Result { + // Collect Parquet column indices from field ids. + // If the field id is not found in Parquet schema, it will be ignored due to schema evolution. 
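+        // Sorting matters here: `ProjectionMask::leaves` projects columns in
+        // Parquet schema order, so with a sorted list a column's position in
+        // `column_indices` matches its position in the projected record batch
+        // that the Arrow predicate receives.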
+ let mut column_indices = iceberg_field_ids + .iter() + .filter_map(|field_id| field_id_map.get(field_id).cloned()) + .collect::>(); + column_indices.sort(); + + // The converter that converts `BoundPredicates` to `ArrowPredicates` + let mut converter = PredicateConverter { + parquet_schema, + column_map: field_id_map, + column_indices: &column_indices, + }; + + // After collecting required leaf column indices used in the predicate, + // creates the projection mask for the Arrow predicates. + let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone()); + let predicate_func = visit(&mut converter, predicates)?; + let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func); + Ok(RowFilter::new(vec![Box::new(arrow_predicate)])) + } + + fn get_selected_row_group_indices( + predicate: &BoundPredicate, + parquet_metadata: &Arc, + field_id_map: &HashMap, + snapshot_schema: &Schema, + ) -> Result> { + let row_groups_metadata = parquet_metadata.row_groups(); + let mut results = Vec::with_capacity(row_groups_metadata.len()); + + for (idx, row_group_metadata) in row_groups_metadata.iter().enumerate() { + if RowGroupMetricsEvaluator::eval( + predicate, + row_group_metadata, + field_id_map, + snapshot_schema, + )? { + results.push(idx); + } + } + + Ok(results) + } +} + +/// Build the map of parquet field id to Parquet column index in the schema. +fn build_field_id_map(parquet_schema: &SchemaDescriptor) -> Result> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. } => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column idx: {}, name: {}, type {:?} in schema doesn't have field id", + idx, + basic_info.name(), + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) +} + +/// A visitor to collect field ids from bound predicates. 
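+///
+/// A typical use, mirroring the unit tests below (illustrative only):
+///
+/// ```ignore
+/// let mut visitor = CollectFieldIdVisitor { field_ids: HashSet::default() };
+/// visit(&mut visitor, &bound_predicate)?;
+/// let field_ids = visitor.field_ids();
+/// ```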
+struct CollectFieldIdVisitor { + field_ids: HashSet, +} + +impl CollectFieldIdVisitor { + fn field_ids(self) -> HashSet { + self.field_ids + } +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + + fn always_true(&mut self) -> Result<()> { + Ok(()) + } + + fn always_false(&mut self) -> Result<()> { + Ok(()) + } + + fn and(&mut self, _lhs: (), _rhs: ()) -> Result<()> { + Ok(()) + } + + fn or(&mut self, _lhs: (), _rhs: ()) -> Result<()> { + Ok(()) + } + + fn not(&mut self, _inner: ()) -> Result<()> { + Ok(()) + } + + fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn not_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn is_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn not_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn less_than( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn greater_than( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn not_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn starts_with( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn r#in( + &mut self, + reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } + + fn not_in( + &mut self, + reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result<()> { + self.field_ids.insert(reference.field().id); + Ok(()) + } +} + +/// A visitor to convert Iceberg bound predicates to Arrow predicates. +struct PredicateConverter<'a> { + /// The Parquet schema descriptor. + pub parquet_schema: &'a SchemaDescriptor, + /// The map between field id and leaf column index in Parquet schema. + pub column_map: &'a HashMap, + /// The required column indices in Parquet schema for the predicates. 
+ pub column_indices: &'a Vec, +} + +impl PredicateConverter<'_> { + /// When visiting a bound reference, we return index of the leaf column in the + /// required column indices which is used to project the column in the record batch. + /// Return None if the field id is not found in the column map, which is possible + /// due to schema evolution. + fn bound_reference(&mut self, reference: &BoundReference) -> Result> { + // The leaf column's index in Parquet schema. + if let Some(column_idx) = self.column_map.get(&reference.field().id) { + if self.parquet_schema.get_column_root_idx(*column_idx) != *column_idx { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column `{}` in predicates isn't a root column in Parquet schema.", + reference.field().name + ), + )); + } + + // The leaf column's index in the required column indices. + let index = self + .column_indices + .iter() + .position(|&idx| idx == *column_idx).ok_or(Error::new(ErrorKind::DataInvalid, format!( + "Leave column `{}` in predicates cannot be found in the required column indices.", + reference.field().name + )))?; + + Ok(Some(index)) + } else { + Ok(None) + } + } + + /// Build an Arrow predicate that always returns true. + fn build_always_true(&self) -> Result> { + Ok(Box::new(|batch| { + Ok(BooleanArray::from(vec![true; batch.num_rows()])) + })) + } + + /// Build an Arrow predicate that always returns false. + fn build_always_false(&self) -> Result> { + Ok(Box::new(|batch| { + Ok(BooleanArray::from(vec![false; batch.num_rows()])) + })) + } +} + +/// Gets the leaf column from the record batch for the required column index. Only +/// supports top-level columns for now. +fn project_column( + batch: &RecordBatch, + column_idx: usize, +) -> std::result::Result { + let column = batch.column(column_idx); + + match column.data_type() { + DataType::Struct(_) => Err(ArrowError::SchemaError( + "Does not support struct column yet.".to_string(), + )), + _ => Ok(column.clone()), + } +} + +type PredicateResult = + dyn FnMut(RecordBatch) -> std::result::Result + Send + 'static; + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box; + + fn always_true(&mut self) -> Result> { + self.build_always_true() + } + + fn always_false(&mut self) -> Result> { + self.build_always_false() + } + + fn and( + &mut self, + mut lhs: Box, + mut rhs: Box, + ) -> Result> { + Ok(Box::new(move |batch| { + let left = lhs(batch.clone())?; + let right = rhs(batch)?; + and(&left, &right) + })) + } + + fn or( + &mut self, + mut lhs: Box, + mut rhs: Box, + ) -> Result> { + Ok(Box::new(move |batch| { + let left = lhs(batch.clone())?; + let right = rhs(batch)?; + or(&left, &right) + })) + } + + fn not(&mut self, mut inner: Box) -> Result> { + Ok(Box::new(move |batch| { + let pred_ret = inner(batch)?; + not(&pred_ret) + })) + } + + fn is_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + Ok(Box::new(move |batch| { + let column = project_column(&batch, idx)?; + is_null(&column) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + Ok(Box::new(move |batch| { + let column = project_column(&batch, idx)?; + is_not_null(&column) + })) + } else { + // A missing column, treating it as null. 
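+            // A null value can never satisfy `not_null`, so the predicate is always false.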
+ self.build_always_false() + } + } + + fn is_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result> { + if self.bound_reference(reference)?.is_some() { + self.build_always_true() + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn not_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result> { + if self.bound_reference(reference)?.is_some() { + self.build_always_false() + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } + + fn less_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + lt(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + lt_eq(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } + + fn greater_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + gt(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + gt_eq(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + eq(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn not_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + neq(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? 
{ + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + starts_with(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + + // update here if arrow ever adds a native not_starts_with + not(&starts_with(&left, literal.as_ref())?) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literals: Vec<_> = literals + .iter() + .map(|lit| get_arrow_datum(lit).unwrap()) + .collect(); + + Ok(Box::new(move |batch| { + // update this if arrow ever adds a native is_in kernel + let left = project_column(&batch, idx)?; + let mut acc = BooleanArray::from(vec![false; batch.num_rows()]); + for literal in &literals { + acc = or(&acc, &eq(&left, literal.as_ref())?)? + } + + Ok(acc) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } + } + + fn not_in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result> { + if let Some(idx) = self.bound_reference(reference)? { + let literals: Vec<_> = literals + .iter() + .map(|lit| get_arrow_datum(lit).unwrap()) + .collect(); + + Ok(Box::new(move |batch| { + // update this if arrow ever adds a native not_in kernel + let left = project_column(&batch, idx)?; + let mut acc = BooleanArray::from(vec![true; batch.num_rows()]); + for literal in &literals { + acc = and(&acc, &neq(&left, literal.as_ref())?)? + } + + Ok(acc) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } + } +} + +/// ArrowFileReader is a wrapper around a FileRead that impls parquets AsyncFileReader. +/// +/// # TODO +/// +/// [ParquetObjectReader](https://docs.rs/parquet/latest/src/parquet/arrow/async_reader/store.rs.html#64) +/// contains the following hints to speed up metadata loading, we can consider adding them to this struct: +/// +/// - `metadata_size_hint`: Provide a hint as to the size of the parquet file's footer. +/// - `preload_column_index`: Load the Column Index as part of [`Self::get_metadata`]. +/// - `preload_offset_index`: Load the Offset Index as part of [`Self::get_metadata`]. 
+struct ArrowFileReader { + meta: FileMetadata, + r: R, +} + +impl ArrowFileReader { + /// Create a new ArrowFileReader + fn new(meta: FileMetadata, r: R) -> Self { + Self { meta, r } + } +} + +impl AsyncFileReader for ArrowFileReader { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, parquet::errors::Result> { + Box::pin( + self.r + .read(range.start as _..range.end as _) + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))), + ) + } + + fn get_metadata(&mut self) -> BoxFuture<'_, parquet::errors::Result>> { + Box::pin(async move { + let file_size = self.meta.size; + let mut loader = MetadataLoader::load(self, file_size as usize, None).await?; + loader.load_page_index(false, false).await?; + Ok(Arc::new(loader.finish())) + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::sync::Arc; + + use crate::arrow::reader::CollectFieldIdVisitor; + use crate::expr::visitors::bound_predicate_visitor::visit; + use crate::expr::{Bind, Reference}; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + NestedField::optional(4, "qux", Type::Primitive(PrimitiveType::Float)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_collect_field_id() { + let schema = table_schema_simple(); + let expr = Reference::new("qux").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + + let mut visitor = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + visit(&mut visitor, &bound_expr).unwrap(); + + let mut expected = HashSet::default(); + expected.insert(4_i32); + + assert_eq!(visitor.field_ids, expected); + } + + #[test] + fn test_collect_field_id_with_and() { + let schema = table_schema_simple(); + let expr = Reference::new("qux") + .is_null() + .and(Reference::new("baz").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + + let mut visitor = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + visit(&mut visitor, &bound_expr).unwrap(); + + let mut expected = HashSet::default(); + expected.insert(4_i32); + expected.insert(3); + + assert_eq!(visitor.field_ids, expected); + } + + #[test] + fn test_collect_field_id_with_or() { + let schema = table_schema_simple(); + let expr = Reference::new("qux") + .is_null() + .or(Reference::new("baz").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + + let mut visitor = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + visit(&mut visitor, &bound_expr).unwrap(); + + let mut expected = HashSet::default(); + expected.insert(4_i32); + expected.insert(3); + + assert_eq!(visitor.field_ids, expected); + } +} diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs new file mode 100644 index 000000000..2ff43e0f0 --- /dev/null +++ b/crates/iceberg/src/arrow/schema.rs @@ -0,0 +1,1368 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Conversion between Arrow schema and Iceberg schema. + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::types::{ + validate_decimal_precision_and_scale, Decimal128Type, TimestampMicrosecondType, +}; +use arrow_array::{ + BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array, + PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray, +}; +use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; +use bitvec::macros::internal::funty::Fundamental; +use parquet::arrow::PARQUET_FIELD_ID_META_KEY; +use parquet::file::statistics::Statistics; +use rust_decimal::prelude::ToPrimitive; +use uuid::Uuid; + +use crate::error::Result; +use crate::spec::{ + Datum, ListType, MapType, NestedField, NestedFieldRef, PrimitiveLiteral, PrimitiveType, Schema, + SchemaVisitor, StructType, Type, +}; +use crate::{Error, ErrorKind}; + +/// When iceberg map type convert to Arrow map type, the default map field name is "key_value". +pub(crate) const DEFAULT_MAP_FIELD_NAME: &str = "key_value"; + +/// A post order arrow schema visitor. +/// +/// For order of methods called, please refer to [`visit_schema`]. +pub trait ArrowSchemaVisitor { + /// Return type of this visitor on arrow field. + type T; + + /// Return type of this visitor on arrow schema. + type U; + + /// Called before struct/list/map field. + fn before_field(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called after struct/list/map field. + fn after_field(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called before list element. + fn before_list_element(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called after list element. + fn after_list_element(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called before map key. + fn before_map_key(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called after map key. + fn after_map_key(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called before map value. + fn before_map_value(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called after map value. + fn after_map_value(&mut self, _field: &Field) -> Result<()> { + Ok(()) + } + + /// Called after schema's type visited. + fn schema(&mut self, schema: &ArrowSchema, values: Vec) -> Result; + + /// Called after struct's fields visited. + fn r#struct(&mut self, fields: &Fields, results: Vec) -> Result; + + /// Called after list fields visited. + fn list(&mut self, list: &DataType, value: Self::T) -> Result; + + /// Called after map's key and value fields visited. + fn map(&mut self, map: &DataType, key_value: Self::T, value: Self::T) -> Result; + + /// Called when see a primitive type. + fn primitive(&mut self, p: &DataType) -> Result; +} + +/// Visiting a type in post order. 
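+///
+/// Primitive types are passed straight to `visitor.primitive`; list, map and
+/// struct types recurse into their children first and then call the
+/// corresponding visitor method, giving post-order traversal.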
+fn visit_type(r#type: &DataType, visitor: &mut V) -> Result { + match r#type { + p if p.is_primitive() + || matches!( + p, + DataType::Boolean + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + ) => + { + visitor.primitive(p) + } + DataType::List(element_field) => visit_list(r#type, element_field, visitor), + DataType::LargeList(element_field) => visit_list(r#type, element_field, visitor), + DataType::FixedSizeList(element_field, _) => visit_list(r#type, element_field, visitor), + DataType::Map(field, _) => match field.data_type() { + DataType::Struct(fields) => { + if fields.len() != 2 { + return Err(Error::new( + ErrorKind::DataInvalid, + "Map field must have exactly 2 fields", + )); + } + + let key_field = &fields[0]; + let value_field = &fields[1]; + + let key_result = { + visitor.before_map_key(key_field)?; + let ret = visit_type(key_field.data_type(), visitor)?; + visitor.after_map_key(key_field)?; + ret + }; + + let value_result = { + visitor.before_map_value(value_field)?; + let ret = visit_type(value_field.data_type(), visitor)?; + visitor.after_map_value(value_field)?; + ret + }; + + visitor.map(r#type, key_result, value_result) + } + _ => Err(Error::new( + ErrorKind::DataInvalid, + "Map field must have struct type", + )), + }, + DataType::Struct(fields) => visit_struct(fields, visitor), + other => Err(Error::new( + ErrorKind::DataInvalid, + format!("Cannot visit Arrow data type: {other}"), + )), + } +} + +/// Visit list types in post order. +#[allow(dead_code)] +fn visit_list( + data_type: &DataType, + element_field: &Field, + visitor: &mut V, +) -> Result { + visitor.before_list_element(element_field)?; + let value = visit_type(element_field.data_type(), visitor)?; + visitor.after_list_element(element_field)?; + visitor.list(data_type, value) +} + +/// Visit struct type in post order. +#[allow(dead_code)] +fn visit_struct(fields: &Fields, visitor: &mut V) -> Result { + let mut results = Vec::with_capacity(fields.len()); + for field in fields { + visitor.before_field(field)?; + let result = visit_type(field.data_type(), visitor)?; + visitor.after_field(field)?; + results.push(result); + } + + visitor.r#struct(fields, results) +} + +/// Visit schema in post order. +#[allow(dead_code)] +fn visit_schema(schema: &ArrowSchema, visitor: &mut V) -> Result { + let mut results = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + visitor.before_field(field)?; + let result = visit_type(field.data_type(), visitor)?; + visitor.after_field(field)?; + results.push(result); + } + visitor.schema(schema, results) +} + +/// Convert Arrow schema to ceberg schema. 
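+///
+/// Every Arrow field must carry its Iceberg field id under the
+/// `PARQUET_FIELD_ID_META_KEY` metadata key, otherwise the conversion fails.
+/// A minimal sketch (illustrative only):
+///
+/// ```ignore
+/// use std::collections::HashMap;
+/// use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+/// use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
+///
+/// let field = Field::new("a", DataType::Int32, false).with_metadata(HashMap::from([(
+///     PARQUET_FIELD_ID_META_KEY.to_string(),
+///     "1".to_string(),
+/// )]));
+/// let iceberg_schema = arrow_schema_to_schema(&ArrowSchema::new(vec![field]))?;
+/// ```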
+#[allow(dead_code)] +pub fn arrow_schema_to_schema(schema: &ArrowSchema) -> Result { + let mut visitor = ArrowSchemaConverter::new(); + visit_schema(schema, &mut visitor) +} + +const ARROW_FIELD_DOC_KEY: &str = "doc"; + +fn get_field_id(field: &Field) -> Result { + if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) { + return value.parse::().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Failed to parse field id".to_string(), + ) + .with_context("value", value) + .with_source(e) + }); + } + Err(Error::new( + ErrorKind::DataInvalid, + "Field id not found in metadata", + )) +} + +fn get_field_doc(field: &Field) -> Option { + if let Some(value) = field.metadata().get(ARROW_FIELD_DOC_KEY) { + return Some(value.clone()); + } + None +} + +struct ArrowSchemaConverter; + +impl ArrowSchemaConverter { + #[allow(dead_code)] + fn new() -> Self { + Self {} + } + + fn convert_fields(fields: &Fields, field_results: &[Type]) -> Result> { + let mut results = Vec::with_capacity(fields.len()); + for i in 0..fields.len() { + let field = &fields[i]; + let field_type = &field_results[i]; + let id = get_field_id(field)?; + let doc = get_field_doc(field); + let nested_field = NestedField { + id, + doc, + name: field.name().clone(), + required: !field.is_nullable(), + field_type: Box::new(field_type.clone()), + initial_default: None, + write_default: None, + }; + results.push(Arc::new(nested_field)); + } + Ok(results) + } +} + +impl ArrowSchemaVisitor for ArrowSchemaConverter { + type T = Type; + type U = Schema; + + fn schema(&mut self, schema: &ArrowSchema, values: Vec) -> Result { + let fields = Self::convert_fields(schema.fields(), &values)?; + let builder = Schema::builder().with_fields(fields); + builder.build() + } + + fn r#struct(&mut self, fields: &Fields, results: Vec) -> Result { + let fields = Self::convert_fields(fields, &results)?; + Ok(Type::Struct(StructType::new(fields))) + } + + fn list(&mut self, list: &DataType, value: Self::T) -> Result { + let element_field = match list { + DataType::List(element_field) => element_field, + DataType::LargeList(element_field) => element_field, + DataType::FixedSizeList(element_field, _) => element_field, + _ => { + return Err(Error::new( + ErrorKind::DataInvalid, + "List type must have list data type", + )) + } + }; + + let id = get_field_id(element_field)?; + let doc = get_field_doc(element_field); + let mut element_field = + NestedField::list_element(id, value.clone(), !element_field.is_nullable()); + if let Some(doc) = doc { + element_field = element_field.with_doc(doc); + } + let element_field = Arc::new(element_field); + Ok(Type::List(ListType { element_field })) + } + + fn map(&mut self, map: &DataType, key_value: Self::T, value: Self::T) -> Result { + match map { + DataType::Map(field, _) => match field.data_type() { + DataType::Struct(fields) => { + if fields.len() != 2 { + return Err(Error::new( + ErrorKind::DataInvalid, + "Map field must have exactly 2 fields", + )); + } + + let key_field = &fields[0]; + let value_field = &fields[1]; + + let key_id = get_field_id(key_field)?; + let key_doc = get_field_doc(key_field); + let mut key_field = NestedField::map_key_element(key_id, key_value.clone()); + if let Some(doc) = key_doc { + key_field = key_field.with_doc(doc); + } + let key_field = Arc::new(key_field); + + let value_id = get_field_id(value_field)?; + let value_doc = get_field_doc(value_field); + let mut value_field = NestedField::map_value_element( + value_id, + value.clone(), + !value_field.is_nullable(), + ); + if let 
Some(doc) = value_doc { + value_field = value_field.with_doc(doc); + } + let value_field = Arc::new(value_field); + + Ok(Type::Map(MapType { + key_field, + value_field, + })) + } + _ => Err(Error::new( + ErrorKind::DataInvalid, + "Map field must have struct type", + )), + }, + _ => Err(Error::new( + ErrorKind::DataInvalid, + "Map type must have map data type", + )), + } + } + + fn primitive(&mut self, p: &DataType) -> Result { + match p { + DataType::Boolean => Ok(Type::Primitive(PrimitiveType::Boolean)), + DataType::Int32 => Ok(Type::Primitive(PrimitiveType::Int)), + DataType::Int64 => Ok(Type::Primitive(PrimitiveType::Long)), + DataType::Float32 => Ok(Type::Primitive(PrimitiveType::Float)), + DataType::Float64 => Ok(Type::Primitive(PrimitiveType::Double)), + DataType::Decimal128(p, s) => Type::decimal(*p as u32, *s as u32).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Failed to create decimal type".to_string(), + ) + .with_source(e) + }), + DataType::Date32 => Ok(Type::Primitive(PrimitiveType::Date)), + DataType::Time64(unit) if unit == &TimeUnit::Microsecond => { + Ok(Type::Primitive(PrimitiveType::Time)) + } + DataType::Timestamp(unit, None) if unit == &TimeUnit::Microsecond => { + Ok(Type::Primitive(PrimitiveType::Timestamp)) + } + DataType::Timestamp(unit, Some(zone)) + if unit == &TimeUnit::Microsecond + && (zone.as_ref() == "UTC" || zone.as_ref() == "+00:00") => + { + Ok(Type::Primitive(PrimitiveType::Timestamptz)) + } + DataType::Binary | DataType::LargeBinary => Ok(Type::Primitive(PrimitiveType::Binary)), + DataType::FixedSizeBinary(width) => { + Ok(Type::Primitive(PrimitiveType::Fixed(*width as u64))) + } + DataType::Utf8 | DataType::LargeUtf8 => Ok(Type::Primitive(PrimitiveType::String)), + _ => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported Arrow data type: {p}"), + )), + } + } +} + +struct ToArrowSchemaConverter; + +enum ArrowSchemaOrFieldOrType { + Schema(ArrowSchema), + Field(Field), + Type(DataType), +} + +impl SchemaVisitor for ToArrowSchemaConverter { + type T = ArrowSchemaOrFieldOrType; + + fn schema( + &mut self, + _schema: &crate::spec::Schema, + value: ArrowSchemaOrFieldOrType, + ) -> crate::Result { + let struct_type = match value { + ArrowSchemaOrFieldOrType::Type(DataType::Struct(fields)) => fields, + _ => unreachable!(), + }; + Ok(ArrowSchemaOrFieldOrType::Schema(ArrowSchema::new( + struct_type, + ))) + } + + fn field( + &mut self, + field: &crate::spec::NestedFieldRef, + value: ArrowSchemaOrFieldOrType, + ) -> crate::Result { + let ty = match value { + ArrowSchemaOrFieldOrType::Type(ty) => ty, + _ => unreachable!(), + }; + let metadata = if let Some(doc) = &field.doc { + HashMap::from([ + (PARQUET_FIELD_ID_META_KEY.to_string(), field.id.to_string()), + (ARROW_FIELD_DOC_KEY.to_string(), doc.clone()), + ]) + } else { + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), field.id.to_string())]) + }; + Ok(ArrowSchemaOrFieldOrType::Field( + Field::new(field.name.clone(), ty, !field.required).with_metadata(metadata), + )) + } + + fn r#struct( + &mut self, + _: &crate::spec::StructType, + results: Vec, + ) -> crate::Result { + let fields = results + .into_iter() + .map(|result| match result { + ArrowSchemaOrFieldOrType::Field(field) => field, + _ => unreachable!(), + }) + .collect(); + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Struct(fields))) + } + + fn list( + &mut self, + list: &crate::spec::ListType, + value: ArrowSchemaOrFieldOrType, + ) -> crate::Result { + let field = match self.field(&list.element_field, value)? 
{ + ArrowSchemaOrFieldOrType::Field(field) => field, + _ => unreachable!(), + }; + let meta = if let Some(doc) = &list.element_field.doc { + HashMap::from([ + ( + PARQUET_FIELD_ID_META_KEY.to_string(), + list.element_field.id.to_string(), + ), + (ARROW_FIELD_DOC_KEY.to_string(), doc.clone()), + ]) + } else { + HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + list.element_field.id.to_string(), + )]) + }; + let field = field.with_metadata(meta); + Ok(ArrowSchemaOrFieldOrType::Type(DataType::List(Arc::new( + field, + )))) + } + + fn map( + &mut self, + map: &crate::spec::MapType, + key_value: ArrowSchemaOrFieldOrType, + value: ArrowSchemaOrFieldOrType, + ) -> crate::Result { + let key_field = match self.field(&map.key_field, key_value)? { + ArrowSchemaOrFieldOrType::Field(field) => field, + _ => unreachable!(), + }; + let value_field = match self.field(&map.value_field, value)? { + ArrowSchemaOrFieldOrType::Field(field) => field, + _ => unreachable!(), + }; + let field = Field::new( + DEFAULT_MAP_FIELD_NAME, + DataType::Struct(vec![key_field, value_field].into()), + // Map field is always not nullable + false, + ); + + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Map( + field.into(), + false, + ))) + } + + fn primitive( + &mut self, + p: &crate::spec::PrimitiveType, + ) -> crate::Result { + match p { + crate::spec::PrimitiveType::Boolean => { + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Boolean)) + } + crate::spec::PrimitiveType::Int => Ok(ArrowSchemaOrFieldOrType::Type(DataType::Int32)), + crate::spec::PrimitiveType::Long => Ok(ArrowSchemaOrFieldOrType::Type(DataType::Int64)), + crate::spec::PrimitiveType::Float => { + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Float32)) + } + crate::spec::PrimitiveType::Double => { + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Float64)) + } + crate::spec::PrimitiveType::Decimal { precision, scale } => { + let (precision, scale) = { + let precision: u8 = precision.to_owned().try_into().map_err(|err| { + Error::new( + crate::ErrorKind::DataInvalid, + "incompatible precision for decimal type convert", + ) + .with_source(err) + })?; + let scale = scale.to_owned().try_into().map_err(|err| { + Error::new( + crate::ErrorKind::DataInvalid, + "incompatible scale for decimal type convert", + ) + .with_source(err) + })?; + (precision, scale) + }; + validate_decimal_precision_and_scale::(precision, scale).map_err( + |err| { + Error::new( + crate::ErrorKind::DataInvalid, + "incompatible precision and scale for decimal type convert", + ) + .with_source(err) + }, + )?; + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Decimal128( + precision, scale, + ))) + } + crate::spec::PrimitiveType::Date => { + Ok(ArrowSchemaOrFieldOrType::Type(DataType::Date32)) + } + crate::spec::PrimitiveType::Time => Ok(ArrowSchemaOrFieldOrType::Type( + DataType::Time64(TimeUnit::Microsecond), + )), + crate::spec::PrimitiveType::Timestamp => Ok(ArrowSchemaOrFieldOrType::Type( + DataType::Timestamp(TimeUnit::Microsecond, None), + )), + crate::spec::PrimitiveType::Timestamptz => Ok(ArrowSchemaOrFieldOrType::Type( + // Timestampz always stored as UTC + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + )), + crate::spec::PrimitiveType::TimestampNs => Ok(ArrowSchemaOrFieldOrType::Type( + DataType::Timestamp(TimeUnit::Nanosecond, None), + )), + crate::spec::PrimitiveType::TimestamptzNs => Ok(ArrowSchemaOrFieldOrType::Type( + // Store timestamptz_ns as UTC + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + )), + crate::spec::PrimitiveType::String => { + 
Ok(ArrowSchemaOrFieldOrType::Type(DataType::Utf8)) + } + crate::spec::PrimitiveType::Uuid => Ok(ArrowSchemaOrFieldOrType::Type( + DataType::FixedSizeBinary(16), + )), + crate::spec::PrimitiveType::Fixed(len) => Ok(ArrowSchemaOrFieldOrType::Type( + len.to_i32() + .map(DataType::FixedSizeBinary) + .unwrap_or(DataType::LargeBinary), + )), + crate::spec::PrimitiveType::Binary => { + Ok(ArrowSchemaOrFieldOrType::Type(DataType::LargeBinary)) + } + } + } +} + +/// Convert iceberg schema to an arrow schema. +pub fn schema_to_arrow_schema(schema: &crate::spec::Schema) -> crate::Result { + let mut converter = ToArrowSchemaConverter; + match crate::spec::visit_schema(schema, &mut converter)? { + ArrowSchemaOrFieldOrType::Schema(schema) => Ok(schema), + _ => unreachable!(), + } +} + +/// Convert Iceberg Datum to Arrow Datum. +pub(crate) fn get_arrow_datum(datum: &Datum) -> Result> { + match (datum.data_type(), datum.literal()) { + (PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => { + Ok(Box::new(BooleanArray::new_scalar(*value))) + } + (PrimitiveType::Int, PrimitiveLiteral::Int(value)) => { + Ok(Box::new(Int32Array::new_scalar(*value))) + } + (PrimitiveType::Long, PrimitiveLiteral::Long(value)) => { + Ok(Box::new(Int64Array::new_scalar(*value))) + } + (PrimitiveType::Float, PrimitiveLiteral::Float(value)) => { + Ok(Box::new(Float32Array::new_scalar(value.as_f32()))) + } + (PrimitiveType::Double, PrimitiveLiteral::Double(value)) => { + Ok(Box::new(Float64Array::new_scalar(value.as_f64()))) + } + (PrimitiveType::String, PrimitiveLiteral::String(value)) => { + Ok(Box::new(StringArray::new_scalar(value.as_str()))) + } + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => { + Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value))) + } + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new( + PrimitiveArray::::new(vec![*value; 1].into(), None) + .with_timezone("UTC"), + ))), + + (typ, _) => Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Converting datum from type {:?} to arrow not supported yet.", + typ + ), + )), + } +} + +macro_rules! get_parquet_stat_as_datum { + ($limit_type:ident) => { + paste::paste! 
{ + /// Gets the $limit_type value from a parquet Statistics struct, as a Datum + pub(crate) fn []( + primitive_type: &PrimitiveType, stats: &Statistics + ) -> Result> { + Ok(Some(match (primitive_type, stats) { + (PrimitiveType::Boolean, Statistics::Boolean(stats)) => Datum::bool(*stats.$limit_type()), + (PrimitiveType::Int, Statistics::Int32(stats)) => Datum::int(*stats.$limit_type()), + (PrimitiveType::Date, Statistics::Int32(stats)) => Datum::date(*stats.$limit_type()), + (PrimitiveType::Long, Statistics::Int64(stats)) => Datum::long(*stats.$limit_type()), + (PrimitiveType::Time, Statistics::Int64(stats)) => Datum::time_micros(*stats.$limit_type())?, + (PrimitiveType::Timestamp, Statistics::Int64(stats)) => { + Datum::timestamp_micros(*stats.$limit_type()) + } + (PrimitiveType::Timestamptz, Statistics::Int64(stats)) => { + Datum::timestamptz_micros(*stats.$limit_type()) + } + (PrimitiveType::TimestampNs, Statistics::Int64(stats)) => { + Datum::timestamp_nanos(*stats.$limit_type()) + } + (PrimitiveType::TimestamptzNs, Statistics::Int64(stats)) => { + Datum::timestamptz_nanos(*stats.$limit_type()) + } + (PrimitiveType::Float, Statistics::Float(stats)) => Datum::float(*stats.$limit_type()), + (PrimitiveType::Double, Statistics::Double(stats)) => Datum::double(*stats.$limit_type()), + (PrimitiveType::String, Statistics::ByteArray(stats)) => { + Datum::string(stats.$limit_type().as_utf8()?) + } + (PrimitiveType::Decimal { + precision: _, + scale: _, + }, Statistics::ByteArray(stats)) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from_le_bytes(stats.[<$limit_type _bytes>]().try_into()?)), + ) + } + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int32(stats)) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from(*stats.$limit_type())), + ) + } + + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int64(stats), + ) => { + Datum::new( + primitive_type.clone(), + PrimitiveLiteral::Int128(i128::from(*stats.$limit_type())), + ) + } + (PrimitiveType::Uuid, Statistics::FixedLenByteArray(stats)) => { + let raw = stats.[<$limit_type _bytes>](); + if raw.len() != 16 { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of uuid bytes.", + )); + } + Datum::uuid(Uuid::from_bytes( + raw[..16].try_into().unwrap(), + )) + } + (PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => { + let raw = stat.[<$limit_type _bytes>](); + if raw.len() != *len as usize { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of fixed bytes.", + )); + } + Datum::fixed(raw.to_vec()) + } + (PrimitiveType::Binary, Statistics::ByteArray(stat)) => { + Datum::binary(stat.[<$limit_type _bytes>]().to_vec()) + } + _ => { + return Ok(None); + } + })) + } + } + } +} + +get_parquet_stat_as_datum!(min); + +get_parquet_stat_as_datum!(max); + +impl TryFrom<&ArrowSchema> for crate::spec::Schema { + type Error = Error; + + fn try_from(schema: &ArrowSchema) -> crate::Result { + arrow_schema_to_schema(schema) + } +} + +impl TryFrom<&crate::spec::Schema> for ArrowSchema { + type Error = Error; + + fn try_from(schema: &crate::spec::Schema) -> crate::Result { + schema_to_arrow_schema(schema) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit}; + + use super::*; + use crate::spec::Schema; + + /// Create a simple field with metadata. 
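+    /// `value` is the stringified Iceberg field id stored under
+    /// `PARQUET_FIELD_ID_META_KEY`.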
+ fn simple_field(name: &str, ty: DataType, nullable: bool, value: &str) -> Field { + Field::new(name, ty, nullable).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + value.to_string(), + )])) + } + + fn arrow_schema_for_arrow_schema_to_schema_test() -> ArrowSchema { + let fields = Fields::from(vec![ + simple_field("key", DataType::Int32, false, "17"), + simple_field("value", DataType::Utf8, true, "18"), + ]); + + let r#struct = DataType::Struct(fields); + let map = DataType::Map( + Arc::new(simple_field(DEFAULT_MAP_FIELD_NAME, r#struct, false, "17")), + false, + ); + + let fields = Fields::from(vec![ + simple_field("aa", DataType::Int32, false, "18"), + simple_field("bb", DataType::Utf8, true, "19"), + simple_field( + "cc", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + "20", + ), + ]); + + let r#struct = DataType::Struct(fields); + + ArrowSchema::new(vec![ + simple_field("a", DataType::Int32, false, "2"), + simple_field("b", DataType::Int64, false, "1"), + simple_field("c", DataType::Utf8, false, "3"), + simple_field("n", DataType::Utf8, false, "21"), + simple_field( + "d", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + "4", + ), + simple_field("e", DataType::Boolean, true, "6"), + simple_field("f", DataType::Float32, false, "5"), + simple_field("g", DataType::Float64, false, "7"), + simple_field("p", DataType::Decimal128(10, 2), false, "27"), + simple_field("h", DataType::Date32, false, "8"), + simple_field("i", DataType::Time64(TimeUnit::Microsecond), false, "9"), + simple_field( + "j", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + false, + "10", + ), + simple_field( + "k", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + false, + "12", + ), + simple_field("l", DataType::Binary, false, "13"), + simple_field("o", DataType::LargeBinary, false, "22"), + simple_field("m", DataType::FixedSizeBinary(10), false, "11"), + simple_field( + "list", + DataType::List(Arc::new(simple_field( + "element", + DataType::Int32, + false, + "15", + ))), + true, + "14", + ), + simple_field( + "large_list", + DataType::LargeList(Arc::new(simple_field( + "element", + DataType::Utf8, + false, + "23", + ))), + true, + "24", + ), + simple_field( + "fixed_list", + DataType::FixedSizeList( + Arc::new(simple_field("element", DataType::Binary, false, "26")), + 10, + ), + true, + "25", + ), + simple_field("map", map, false, "16"), + simple_field("struct", r#struct, false, "17"), + ]) + } + + fn iceberg_schema_for_arrow_schema_to_schema_test() -> Schema { + let schema_json = r#"{ + "type":"struct", + "schema-id":0, + "fields":[ + { + "id":2, + "name":"a", + "required":true, + "type":"int" + }, + { + "id":1, + "name":"b", + "required":true, + "type":"long" + }, + { + "id":3, + "name":"c", + "required":true, + "type":"string" + }, + { + "id":21, + "name":"n", + "required":true, + "type":"string" + }, + { + "id":4, + "name":"d", + "required":false, + "type":"timestamp" + }, + { + "id":6, + "name":"e", + "required":false, + "type":"boolean" + }, + { + "id":5, + "name":"f", + "required":true, + "type":"float" + }, + { + "id":7, + "name":"g", + "required":true, + "type":"double" + }, + { + "id":27, + "name":"p", + "required":true, + "type":"decimal(10,2)" + }, + { + "id":8, + "name":"h", + "required":true, + "type":"date" + }, + { + "id":9, + "name":"i", + "required":true, + "type":"time" + }, + { + "id":10, + "name":"j", + "required":true, + "type":"timestamptz" + }, + { + "id":12, + "name":"k", + "required":true, + 
"type":"timestamptz" + }, + { + "id":13, + "name":"l", + "required":true, + "type":"binary" + }, + { + "id":22, + "name":"o", + "required":true, + "type":"binary" + }, + { + "id":11, + "name":"m", + "required":true, + "type":"fixed[10]" + }, + { + "id":14, + "name":"list", + "required": false, + "type": { + "type": "list", + "element-id": 15, + "element-required": true, + "element": "int" + } + }, + { + "id":24, + "name":"large_list", + "required": false, + "type": { + "type": "list", + "element-id": 23, + "element-required": true, + "element": "string" + } + }, + { + "id":25, + "name":"fixed_list", + "required": false, + "type": { + "type": "list", + "element-id": 26, + "element-required": true, + "element": "binary" + } + }, + { + "id":16, + "name":"map", + "required": true, + "type": { + "type": "map", + "key-id": 17, + "key": "int", + "value-id": 18, + "value-required": false, + "value": "string" + } + }, + { + "id":17, + "name":"struct", + "required": true, + "type": { + "type": "struct", + "fields": [ + { + "id":18, + "name":"aa", + "required":true, + "type":"int" + }, + { + "id":19, + "name":"bb", + "required":false, + "type":"string" + }, + { + "id":20, + "name":"cc", + "required":true, + "type":"timestamp" + } + ] + } + } + ], + "identifier-field-ids":[] + }"#; + + let schema: Schema = serde_json::from_str(schema_json).unwrap(); + schema + } + + #[test] + fn test_arrow_schema_to_schema() { + let arrow_schema = arrow_schema_for_arrow_schema_to_schema_test(); + let schema = iceberg_schema_for_arrow_schema_to_schema_test(); + let converted_schema = arrow_schema_to_schema(&arrow_schema).unwrap(); + assert_eq!(converted_schema, schema); + } + + fn arrow_schema_for_schema_to_arrow_schema_test() -> ArrowSchema { + let fields = Fields::from(vec![ + simple_field("key", DataType::Int32, false, "17"), + simple_field("value", DataType::Utf8, true, "18"), + ]); + + let r#struct = DataType::Struct(fields); + let map = DataType::Map( + Arc::new(Field::new(DEFAULT_MAP_FIELD_NAME, r#struct, false)), + false, + ); + + let fields = Fields::from(vec![ + simple_field("aa", DataType::Int32, false, "18"), + simple_field("bb", DataType::Utf8, true, "19"), + simple_field( + "cc", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + "20", + ), + ]); + + let r#struct = DataType::Struct(fields); + + ArrowSchema::new(vec![ + simple_field("a", DataType::Int32, false, "2"), + simple_field("b", DataType::Int64, false, "1"), + simple_field("c", DataType::Utf8, false, "3"), + simple_field("n", DataType::Utf8, false, "21"), + simple_field( + "d", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + "4", + ), + simple_field("e", DataType::Boolean, true, "6"), + simple_field("f", DataType::Float32, false, "5"), + simple_field("g", DataType::Float64, false, "7"), + simple_field("p", DataType::Decimal128(10, 2), false, "27"), + simple_field("h", DataType::Date32, false, "8"), + simple_field("i", DataType::Time64(TimeUnit::Microsecond), false, "9"), + simple_field( + "j", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + false, + "10", + ), + simple_field( + "k", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + false, + "12", + ), + simple_field("l", DataType::LargeBinary, false, "13"), + simple_field("o", DataType::LargeBinary, false, "22"), + simple_field("m", DataType::FixedSizeBinary(10), false, "11"), + simple_field( + "list", + DataType::List(Arc::new(simple_field( + "element", + DataType::Int32, + false, + "15", + ))), + true, + "14", + ), + 
simple_field( + "large_list", + DataType::List(Arc::new(simple_field( + "element", + DataType::Utf8, + false, + "23", + ))), + true, + "24", + ), + simple_field( + "fixed_list", + DataType::List(Arc::new(simple_field( + "element", + DataType::LargeBinary, + false, + "26", + ))), + true, + "25", + ), + simple_field("map", map, false, "16"), + simple_field("struct", r#struct, false, "17"), + simple_field("uuid", DataType::FixedSizeBinary(16), false, "26"), + ]) + } + + fn iceberg_schema_for_schema_to_arrow_schema() -> Schema { + let schema_json = r#"{ + "type":"struct", + "schema-id":0, + "fields":[ + { + "id":2, + "name":"a", + "required":true, + "type":"int" + }, + { + "id":1, + "name":"b", + "required":true, + "type":"long" + }, + { + "id":3, + "name":"c", + "required":true, + "type":"string" + }, + { + "id":21, + "name":"n", + "required":true, + "type":"string" + }, + { + "id":4, + "name":"d", + "required":false, + "type":"timestamp" + }, + { + "id":6, + "name":"e", + "required":false, + "type":"boolean" + }, + { + "id":5, + "name":"f", + "required":true, + "type":"float" + }, + { + "id":7, + "name":"g", + "required":true, + "type":"double" + }, + { + "id":27, + "name":"p", + "required":true, + "type":"decimal(10,2)" + }, + { + "id":8, + "name":"h", + "required":true, + "type":"date" + }, + { + "id":9, + "name":"i", + "required":true, + "type":"time" + }, + { + "id":10, + "name":"j", + "required":true, + "type":"timestamptz" + }, + { + "id":12, + "name":"k", + "required":true, + "type":"timestamptz" + }, + { + "id":13, + "name":"l", + "required":true, + "type":"binary" + }, + { + "id":22, + "name":"o", + "required":true, + "type":"binary" + }, + { + "id":11, + "name":"m", + "required":true, + "type":"fixed[10]" + }, + { + "id":14, + "name":"list", + "required": false, + "type": { + "type": "list", + "element-id": 15, + "element-required": true, + "element": "int" + } + }, + { + "id":24, + "name":"large_list", + "required": false, + "type": { + "type": "list", + "element-id": 23, + "element-required": true, + "element": "string" + } + }, + { + "id":25, + "name":"fixed_list", + "required": false, + "type": { + "type": "list", + "element-id": 26, + "element-required": true, + "element": "binary" + } + }, + { + "id":16, + "name":"map", + "required": true, + "type": { + "type": "map", + "key-id": 17, + "key": "int", + "value-id": 18, + "value-required": false, + "value": "string" + } + }, + { + "id":17, + "name":"struct", + "required": true, + "type": { + "type": "struct", + "fields": [ + { + "id":18, + "name":"aa", + "required":true, + "type":"int" + }, + { + "id":19, + "name":"bb", + "required":false, + "type":"string" + }, + { + "id":20, + "name":"cc", + "required":true, + "type":"timestamp" + } + ] + } + }, + { + "id":26, + "name":"uuid", + "required":true, + "type":"uuid" + } + ], + "identifier-field-ids":[] + }"#; + + let schema: Schema = serde_json::from_str(schema_json).unwrap(); + schema + } + + #[test] + fn test_schema_to_arrow_schema() { + let arrow_schema = arrow_schema_for_schema_to_arrow_schema_test(); + let schema = iceberg_schema_for_schema_to_arrow_schema(); + let converted_arrow_schema = schema_to_arrow_schema(&schema).unwrap(); + assert_eq!(converted_arrow_schema, arrow_schema); + } +} diff --git a/crates/iceberg/src/avro/mod.rs b/crates/iceberg/src/avro/mod.rs index bdccb2ff4..f2a9310e7 100644 --- a/crates/iceberg/src/avro/mod.rs +++ b/crates/iceberg/src/avro/mod.rs @@ -16,6 +16,5 @@ // under the License. //! Avro related codes. 
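For orientation on the Arrow ⇄ Iceberg conversion exercised by the tests above: field ids travel through Arrow field metadata under `PARQUET_FIELD_ID_META_KEY`. Below is a minimal sketch, assuming it sits next to those tests so the crate-internal `arrow_schema_to_schema` / `schema_to_arrow_schema` helpers are in scope; the field names and ids are made up for illustration.

```rust
#[test]
fn field_ids_round_trip_through_arrow_metadata() {
    use std::collections::HashMap;

    use arrow_schema::{DataType, Field, Schema as ArrowSchema};

    // Mirrors `simple_field` above: attach the Iceberg field id as field metadata.
    let with_id = |name: &str, ty: DataType, nullable: bool, id: i32| {
        Field::new(name, ty, nullable).with_metadata(HashMap::from([(
            PARQUET_FIELD_ID_META_KEY.to_string(),
            id.to_string(),
        )]))
    };

    let arrow_schema = ArrowSchema::new(vec![
        with_id("id", DataType::Int64, false, 1),
        with_id("name", DataType::Utf8, true, 2),
    ]);

    // Arrow -> Iceberg reads the metadata entry as the field id, and
    // Iceberg -> Arrow writes it back, so a simple schema round-trips unchanged.
    let iceberg_schema = arrow_schema_to_schema(&arrow_schema).unwrap();
    assert_eq!(schema_to_arrow_schema(&iceberg_schema).unwrap(), arrow_schema);
}
```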
-#[allow(dead_code)] mod schema; pub(crate) use schema::*; diff --git a/crates/iceberg/src/avro/schema.rs b/crates/iceberg/src/avro/schema.rs index f8420d460..cfcf38dea 100644 --- a/crates/iceberg/src/avro/schema.rs +++ b/crates/iceberg/src/avro/schema.rs @@ -18,22 +18,27 @@ //! Conversion between iceberg and avro schema. use std::collections::BTreeMap; -use crate::spec::{ - visit_schema, ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, - SchemaVisitor, StructType, Type, -}; -use crate::{ensure_data_valid, Error, ErrorKind, Result}; use apache_avro::schema::{ - DecimalSchema, FixedSchema, Name, RecordField as AvroRecordField, RecordFieldOrder, - RecordSchema, UnionSchema, + ArraySchema, DecimalSchema, FixedSchema, MapSchema, Name, RecordField as AvroRecordField, + RecordFieldOrder, RecordSchema, UnionSchema, }; use apache_avro::Schema as AvroSchema; use itertools::{Either, Itertools}; use serde_json::{Number, Value}; +use crate::spec::{ + visit_schema, ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, + SchemaVisitor, StructType, Type, +}; +use crate::{ensure_data_valid, Error, ErrorKind, Result}; + +const ELEMENT_ID: &str = "element-id"; const FILED_ID_PROP: &str = "field-id"; +const KEY_ID: &str = "key-id"; +const VALUE_ID: &str = "value-id"; const UUID_BYTES: usize = 16; const UUID_LOGICAL_TYPE: &str = "uuid"; +const MAP_LOGICAL_TYPE: &str = "map"; // # TODO: https://github.com/apache/iceberg-rust/issues/86 // This const may better to maintain in avro-rs. const LOGICAL_TYPE: &str = "logicalType"; @@ -123,8 +128,13 @@ impl SchemaVisitor for SchemaToAvroSchema { field_schema = avro_optional(field_schema)?; } - // TODO: We need to add element id prop here, but rust's avro schema doesn't support property except record schema. - Ok(Either::Left(AvroSchema::Array(Box::new(field_schema)))) + Ok(Either::Left(AvroSchema::Array(ArraySchema { + items: Box::new(field_schema), + attributes: BTreeMap::from([( + ELEMENT_ID.to_string(), + Value::Number(Number::from(list.element_field.id)), + )]), + }))) } fn map( @@ -140,7 +150,19 @@ impl SchemaVisitor for SchemaToAvroSchema { } if matches!(key_field_schema, AvroSchema::String) { - Ok(Either::Left(AvroSchema::Map(Box::new(value_field_schema)))) + Ok(Either::Left(AvroSchema::Map(MapSchema { + types: Box::new(value_field_schema), + attributes: BTreeMap::from([ + ( + KEY_ID.to_string(), + Value::Number(Number::from(map.key_field.id)), + ), + ( + VALUE_ID.to_string(), + Value::Number(Number::from(map.value_field.id)), + ), + ]), + }))) } else { // Avro map requires that key must be string type. Here we convert it to array if key is // not string type. 
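Because Avro maps only allow string keys, an Iceberg map with a non-string key is written as an Avro array of key/value records tagged with `logicalType: map`, the Iceberg field ids riding along as `field-id` attributes. A minimal sketch of that shape follows; the record name and the ids 101/102 are made up here, and the full round-trip is exercised in `test_schema_with_array_map` further down.

```rust
// Illustrative only: a `map<int, string>` with an optional value, as it is
// expected to appear on the Avro side once the key is not a string.
fn parse_array_encoded_map() {
    let avro_json = r#"
    {
      "type": "array",
      "logicalType": "map",
      "items": {
        "type": "record",
        "name": "k101_v102",
        "fields": [
          { "name": "key",   "type": "int",              "field-id": 101 },
          { "name": "value", "type": ["null", "string"], "field-id": 102 }
        ]
      }
    }
    "#;
    let schema = apache_avro::Schema::parse_str(avro_json).unwrap();
    assert!(matches!(schema, apache_avro::Schema::Array(_)));
}
```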
@@ -186,7 +208,13 @@ impl SchemaVisitor for SchemaToAvroSchema { fields, )?; - Ok(Either::Left(AvroSchema::Array(item_avro_schema.into()))) + Ok(Either::Left(AvroSchema::Array(ArraySchema { + items: Box::new(item_avro_schema), + attributes: BTreeMap::from([( + LOGICAL_TYPE.to_string(), + Value::String(MAP_LOGICAL_TYPE.to_string()), + )]), + }))) } } @@ -201,6 +229,8 @@ impl SchemaVisitor for SchemaToAvroSchema { PrimitiveType::Time => AvroSchema::TimeMicros, PrimitiveType::Timestamp => AvroSchema::TimestampMicros, PrimitiveType::Timestamptz => AvroSchema::TimestampMicros, + PrimitiveType::TimestampNs => AvroSchema::TimestampNanos, + PrimitiveType::TimestamptzNs => AvroSchema::TimestampNanos, PrimitiveType::String => AvroSchema::String, PrimitiveType::Uuid => avro_fixed_schema(UUID_BYTES, Some(UUID_LOGICAL_TYPE))?, PrimitiveType::Fixed(len) => avro_fixed_schema((*len) as usize, None)?, @@ -254,14 +284,28 @@ pub(crate) fn avro_fixed_schema(len: usize, logical_type: Option<&str>) -> Resul doc: None, size: len, attributes, + default: None, })) } pub(crate) fn avro_decimal_schema(precision: usize, scale: usize) -> Result { + // Avro decimal logical type annotates Avro bytes _or_ fixed types. + // https://avro.apache.org/docs/1.11.1/specification/_print/#decimal + // Iceberg spec: Stored as _fixed_ using the minimum number of bytes for the given precision. + // https://iceberg.apache.org/spec/#avro Ok(AvroSchema::Decimal(DecimalSchema { precision, scale, - inner: Box::new(AvroSchema::Bytes), + inner: Box::new(AvroSchema::Fixed(FixedSchema { + // Name is not restricted by the spec. + // Refer to iceberg-python https://github.com/apache/iceberg-python/blob/d8bc1ca9af7957ce4d4db99a52c701ac75db7688/pyiceberg/utils/schema_conversion.py#L574-L582 + name: Name::new(&format!("decimal_{precision}_{scale}")).unwrap(), + aliases: None, + doc: None, + size: crate::spec::Type::decimal_required_bytes(precision as u32)? 
as usize, + attributes: Default::default(), + default: None, + })), })) } @@ -287,8 +331,11 @@ pub(crate) trait AvroSchemaVisitor { fn union(&mut self, union: &UnionSchema, options: Vec) -> Result; - fn array(&mut self, array: &AvroSchema, item: Self::T) -> Result; - fn map(&mut self, map: &AvroSchema, value: Self::T) -> Result; + fn array(&mut self, array: &ArraySchema, item: Self::T) -> Result; + fn map(&mut self, map: &MapSchema, value: Self::T) -> Result; + // There are two representation for iceberg map in avro: array of key-value records, or map when keys are strings (optional), + // ref: https://iceberg.apache.org/spec/#avro + fn map_array(&mut self, array: &RecordSchema, key: Self::T, value: Self::T) -> Result; fn primitive(&mut self, schema: &AvroSchema) -> Result; } @@ -315,25 +362,73 @@ pub(crate) fn visit(schema: &AvroSchema, visitor: &mut V) visitor.union(union, option_results) } AvroSchema::Array(item) => { - let item_result = visit(item, visitor)?; - visitor.array(schema, item_result) + if let Some(logical_type) = item + .attributes + .get(LOGICAL_TYPE) + .and_then(|v| Value::as_str(v)) + { + if logical_type == MAP_LOGICAL_TYPE { + if let AvroSchema::Record(record_schema) = &*item.items { + let key = visit(&record_schema.fields[0].schema, visitor)?; + let value = visit(&record_schema.fields[1].schema, visitor)?; + return visitor.map_array(record_schema, key, value); + } else { + return Err(Error::new( + ErrorKind::DataInvalid, + "Can't convert avro map schema, item is not a record.", + )); + } + } else { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Logical type {logical_type} is not support in iceberg array type.", + ), + )); + } + } + let item_result = visit(&item.items, visitor)?; + visitor.array(item, item_result) } AvroSchema::Map(inner) => { - let item_result = visit(inner, visitor)?; - visitor.map(schema, item_result) + let item_result = visit(&inner.types, visitor)?; + visitor.map(inner, item_result) } schema => visitor.primitive(schema), } } -struct AvroSchemaToSchema { - next_id: i32, -} +struct AvroSchemaToSchema; impl AvroSchemaToSchema { - fn next_field_id(&mut self) -> i32 { - self.next_id += 1; - self.next_id + /// A convenient way to get element id(i32) from attributes. + #[inline] + fn get_element_id_from_attributes( + attributes: &BTreeMap, + name: &str, + ) -> Result { + attributes + .get(name) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Can't convert avro array schema, missing element id.", + ) + })? + .as_i64() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Can't convert avro array schema, element id is not a valid i64 number.", + ) + })? 
+ .try_into() + .map_err(|_| { + Error::new( + ErrorKind::DataInvalid, + "Can't convert avro array schema, element id is not a valid i32.", + ) + }) } } @@ -348,24 +443,12 @@ impl AvroSchemaVisitor for AvroSchemaToSchema { ) -> Result> { let mut fields = Vec::with_capacity(field_types.len()); for (avro_field, typ) in record.fields.iter().zip_eq(field_types) { - let field_id = avro_field - .custom_attributes - .get(FILED_ID_PROP) - .and_then(Value::as_i64) - .ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!("Can't convert field, missing field id: {avro_field:?}"), - ) - })?; + let field_id = + Self::get_element_id_from_attributes(&avro_field.custom_attributes, FILED_ID_PROP)?; let optional = is_avro_optional(&avro_field.schema); - let mut field = if optional { - NestedField::optional(field_id as i32, &avro_field.name, typ.unwrap()) - } else { - NestedField::required(field_id as i32, &avro_field.name, typ.unwrap()) - }; + let mut field = NestedField::new(field_id, &avro_field.name, typ.unwrap(), !optional); if let Some(doc) = &avro_field.doc { field = field.with_doc(doc); @@ -403,46 +486,31 @@ impl AvroSchemaVisitor for AvroSchemaToSchema { } } - fn array(&mut self, array: &AvroSchema, item: Option) -> Result { - if let AvroSchema::Array(item_schema) = array { - let element_field = NestedField::list_element( - self.next_field_id(), - item.unwrap(), - !is_avro_optional(item_schema), - ) - .into(); - Ok(Some(Type::List(ListType { element_field }))) - } else { - Err(Error::new( - ErrorKind::Unexpected, - "Expected avro array schema, but {array}", - )) - } + fn array(&mut self, array: &ArraySchema, item: Option) -> Result { + let element_field_id = Self::get_element_id_from_attributes(&array.attributes, ELEMENT_ID)?; + let element_field = NestedField::list_element( + element_field_id, + item.unwrap(), + !is_avro_optional(&array.items), + ) + .into(); + Ok(Some(Type::List(ListType { element_field }))) } - fn map(&mut self, map: &AvroSchema, value: Option) -> Result> { - if let AvroSchema::Map(value_schema) = map { - // Due to avro rust implementation's limitation, we can't store attributes in map schema, - // we will fix it later when it has been resolved. 
- let key_field = NestedField::map_key_element( - self.next_field_id(), - Type::Primitive(PrimitiveType::String), - ); - let value_field = NestedField::map_value_element( - self.next_field_id(), - value.unwrap(), - !is_avro_optional(value_schema), - ); - Ok(Some(Type::Map(MapType { - key_field: key_field.into(), - value_field: value_field.into(), - }))) - } else { - Err(Error::new( - ErrorKind::Unexpected, - "Expected avro map schema, but {map}", - )) - } + fn map(&mut self, map: &MapSchema, value: Option) -> Result> { + let key_field_id = Self::get_element_id_from_attributes(&map.attributes, KEY_ID)?; + let key_field = + NestedField::map_key_element(key_field_id, Type::Primitive(PrimitiveType::String)); + let value_field_id = Self::get_element_id_from_attributes(&map.attributes, VALUE_ID)?; + let value_field = NestedField::map_value_element( + value_field_id, + value.unwrap(), + !is_avro_optional(&map.types), + ); + Ok(Some(Type::Map(MapType { + key_field: key_field.into(), + value_field: value_field.into(), + }))) } fn primitive(&mut self, schema: &AvroSchema) -> Result> { @@ -453,6 +521,7 @@ impl AvroSchemaVisitor for AvroSchemaToSchema { AvroSchema::Date => Type::Primitive(PrimitiveType::Date), AvroSchema::TimeMicros => Type::Primitive(PrimitiveType::Time), AvroSchema::TimestampMicros => Type::Primitive(PrimitiveType::Timestamp), + AvroSchema::TimestampNanos => Type::Primitive(PrimitiveType::TimestampNs), AvroSchema::Boolean => Type::Primitive(PrimitiveType::Boolean), AvroSchema::Int => Type::Primitive(PrimitiveType::Int), AvroSchema::Long => Type::Primitive(PrimitiveType::Long), @@ -494,12 +563,53 @@ impl AvroSchemaVisitor for AvroSchemaToSchema { Ok(Some(typ)) } + + fn map_array( + &mut self, + array: &RecordSchema, + key: Option, + value: Option, + ) -> Result { + let key = key.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Can't convert avro map schema, missing key schema.", + ) + })?; + let value = value.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Can't convert avro map schema, missing value schema.", + ) + })?; + let key_id = Self::get_element_id_from_attributes( + &array.fields[0].custom_attributes, + FILED_ID_PROP, + )?; + let value_id = Self::get_element_id_from_attributes( + &array.fields[1].custom_attributes, + FILED_ID_PROP, + )?; + let key_field = NestedField::map_key_element(key_id, key); + let value_field = NestedField::map_value_element( + value_id, + value, + !is_avro_optional(&array.fields[1].schema), + ); + Ok(Some(Type::Map(MapType { + key_field: key_field.into(), + value_field: value_field.into(), + }))) + } } +// # TODO +// Fix this when we have used `avro_schema_to_schema` inner. +#[allow(unused)] /// Converts avro schema to iceberg schema. 
pub(crate) fn avro_schema_to_schema(avro_schema: &AvroSchema) -> Result { if let AvroSchema::Record(_) = avro_schema { - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; let typ = visit(avro_schema, &mut converter)?.expect("Iceberg schema should not be none."); if let Type::Struct(s) = typ { Schema::builder() @@ -521,12 +631,15 @@ pub(crate) fn avro_schema_to_schema(avro_schema: &AvroSchema) -> Result #[cfg(test)] mod tests { + use std::fs::read_to_string; + use std::sync::Arc; + + use apache_avro::schema::{Namespace, UnionSchema}; + use apache_avro::Schema as AvroSchema; + use super::*; use crate::avro::schema::AvroSchemaToSchema; use crate::spec::{ListType, MapType, NestedField, PrimitiveType, Schema, StructType, Type}; - use apache_avro::schema::{Namespace, UnionSchema}; - use apache_avro::Schema as AvroSchema; - use std::fs::read_to_string; fn read_test_data_file_to_avro_schema(filename: &str) -> AvroSchema { let input = read_to_string(format!( @@ -539,22 +652,27 @@ mod tests { AvroSchema::parse_str(input.as_str()).unwrap() } - fn check_schema_conversion( - avro_schema: AvroSchema, - expected_iceberg_schema: Schema, - check_avro_to_iceberg: bool, - ) { - if check_avro_to_iceberg { - let converted_iceberg_schema = avro_schema_to_schema(&avro_schema).unwrap(); - assert_eq!(expected_iceberg_schema, converted_iceberg_schema); - } + /// Help function to check schema conversion between avro and iceberg: + /// 1. avro to iceberg + /// 2. iceberg to avro + /// 3. iceberg to avro to iceberg back + fn check_schema_conversion(avro_schema: AvroSchema, iceberg_schema: Schema) { + // 1. avro to iceberg + let converted_iceberg_schema = avro_schema_to_schema(&avro_schema).unwrap(); + assert_eq!(iceberg_schema, converted_iceberg_schema); + // 2. 
iceberg to avro let converted_avro_schema = schema_to_avro_schema( avro_schema.name().unwrap().fullname(Namespace::None), - &expected_iceberg_schema, + &iceberg_schema, ) .unwrap(); assert_eq!(avro_schema, converted_avro_schema); + + // 3.iceberg to avro to iceberg back + let converted_avro_converted_iceberg_schema = + avro_schema_to_schema(&converted_avro_schema).unwrap(); + assert_eq!(iceberg_schema, converted_avro_converted_iceberg_schema); } #[test] @@ -633,7 +751,6 @@ mod tests { check_schema_conversion( read_test_data_file_to_avro_schema("avro_schema_manifest_file_v1.json"), iceberg_schema, - false, ); } @@ -682,7 +799,7 @@ mod tests { .unwrap() }; - check_schema_conversion(avro_schema, iceberg_schema, false); + check_schema_conversion(avro_schema, iceberg_schema); } #[test] @@ -705,7 +822,7 @@ mod tests { "field-id": 100 } ] -} +} "#, ) .unwrap() @@ -731,7 +848,7 @@ mod tests { .unwrap() }; - check_schema_conversion(avro_schema, iceberg_schema, false); + check_schema_conversion(avro_schema, iceberg_schema); } #[test] @@ -768,7 +885,7 @@ mod tests { "field-id": 100 } ] -} +} "#, ) .unwrap() @@ -808,7 +925,144 @@ mod tests { .unwrap() }; - check_schema_conversion(avro_schema, iceberg_schema, false); + check_schema_conversion(avro_schema, iceberg_schema); + } + + #[test] + fn test_schema_with_array_map() { + let avro_schema = { + AvroSchema::parse_str( + r#" +{ + "type": "record", + "name": "avro_schema", + "fields": [ + { + "name": "optional", + "type": { + "type": "array", + "items": { + "type": "record", + "name": "k102_v103", + "fields": [ + { + "name": "key", + "type": "boolean", + "field-id": 102 + }, + { + "name": "value", + "type": ["null", "boolean"], + "field-id": 103 + } + ] + }, + "default": [], + "element-id": 101, + "logicalType": "map" + }, + "field-id": 100 + },{ + "name": "required", + "type": { + "type": "array", + "items": { + "type": "record", + "name": "k105_v106", + "fields": [ + { + "name": "key", + "type": "boolean", + "field-id": 105 + }, + { + "name": "value", + "type": "boolean", + "field-id": 106 + } + ] + }, + "default": [], + "logicalType": "map" + }, + "field-id": 104 + }, { + "name": "string_map", + "type": { + "type": "map", + "values": ["null", "long"], + "key-id": 108, + "value-id": 109 + }, + "field-id": 107 + } + ] +} +"#, + ) + .unwrap() + }; + + let iceberg_schema = { + Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::required( + 100, + "optional", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 102, + PrimitiveType::Boolean.into(), + ) + .into(), + value_field: NestedField::map_value_element( + 103, + PrimitiveType::Boolean.into(), + false, + ) + .into(), + }), + )), + Arc::new(NestedField::required( + 104, + "required", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 105, + PrimitiveType::Boolean.into(), + ) + .into(), + value_field: NestedField::map_value_element( + 106, + PrimitiveType::Boolean.into(), + true, + ) + .into(), + }), + )), + Arc::new(NestedField::required( + 107, + "string_map", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 108, + PrimitiveType::String.into(), + ) + .into(), + value_field: NestedField::map_value_element( + 109, + PrimitiveType::Long.into(), + false, + ) + .into(), + }), + )), + ]) + .build() + .unwrap() + }; + + check_schema_conversion(avro_schema, iceberg_schema); } #[test] @@ -820,7 +1074,7 @@ mod tests { ]) .unwrap(); - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; let options = 
avro_schema .variants() @@ -832,7 +1086,7 @@ mod tests { #[test] fn test_string_type() { - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; let avro_schema = AvroSchema::String; assert_eq!( @@ -857,10 +1111,14 @@ mod tests { .unwrap() }; - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let AvroSchema::Map(avro_schema) = avro_schema else { + unreachable!() + }; + + let mut converter = AvroSchemaToSchema; let iceberg_type = Type::Map(MapType { - key_field: NestedField::map_key_element(1, PrimitiveType::String.into()).into(), - value_field: NestedField::map_value_element(2, PrimitiveType::Long.into(), false) + key_field: NestedField::map_key_element(101, PrimitiveType::String.into()).into(), + value_field: NestedField::map_value_element(102, PrimitiveType::Long.into(), false) .into(), }); @@ -884,7 +1142,7 @@ mod tests { .unwrap() }; - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; let iceberg_type = Type::from(PrimitiveType::Fixed(22)); @@ -896,7 +1154,7 @@ mod tests { #[test] fn test_unknown_primitive() { - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; assert!(converter.primitive(&AvroSchema::Duration).is_err()); } @@ -915,7 +1173,7 @@ mod tests { "type": "string" } ] -} +} "#, ) .unwrap() @@ -935,7 +1193,7 @@ mod tests { .unwrap() }; - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; assert_eq!( Type::decimal(25, 19).unwrap(), @@ -945,7 +1203,7 @@ mod tests { #[test] fn test_date_type() { - let mut converter = AvroSchemaToSchema { next_id: 0 }; + let mut converter = AvroSchemaToSchema; assert_eq!( Type::from(PrimitiveType::Date), diff --git a/crates/iceberg/src/catalog/mod.rs b/crates/iceberg/src/catalog/mod.rs index b68837593..aa2311b4a 100644 --- a/crates/iceberg/src/catalog/mod.rs +++ b/crates/iceberg/src/catalog/mod.rs @@ -17,25 +17,27 @@ //! Catalog API for Apache Iceberg +use std::collections::HashMap; +use std::fmt::Debug; +use std::mem::take; +use std::ops::Deref; + +use async_trait::async_trait; use serde_derive::{Deserialize, Serialize}; -use urlencoding::encode; +use typed_builder::TypedBuilder; +use uuid::Uuid; use crate::spec::{ - FormatVersion, Schema, Snapshot, SnapshotReference, SortOrder, UnboundPartitionSpec, + FormatVersion, Schema, Snapshot, SnapshotReference, SortOrder, TableMetadataBuilder, + UnboundPartitionSpec, ViewRepresentations, }; use crate::table::Table; use crate::{Error, ErrorKind, Result}; -use async_trait::async_trait; -use std::collections::HashMap; -use std::mem::take; -use std::ops::Deref; -use typed_builder::TypedBuilder; -use uuid::Uuid; /// The catalog API for Iceberg Rust. #[async_trait] -pub trait Catalog: std::fmt::Debug { - /// List namespaces from table. +pub trait Catalog: Debug + Sync + Send { + /// List namespaces inside the catalog. async fn list_namespaces(&self, parent: Option<&NamespaceIdent>) -> Result>; @@ -50,7 +52,7 @@ pub trait Catalog: std::fmt::Debug { async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result; /// Check if namespace exists in catalog. - async fn namespace_exists(&self, namesace: &NamespaceIdent) -> Result; + async fn namespace_exists(&self, namespace: &NamespaceIdent) -> Result; /// Update a namespace inside the catalog. /// @@ -83,7 +85,7 @@ pub trait Catalog: std::fmt::Debug { async fn drop_table(&self, table: &TableIdent) -> Result<()>; /// Check if a table exists in the catalog. 
- async fn stat_table(&self, table: &TableIdent) -> Result; + async fn table_exists(&self, table: &TableIdent) -> Result; /// Rename a table in the catalog. async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()>; @@ -122,9 +124,9 @@ impl NamespaceIdent { Self::from_vec(iter.into_iter().map(|s| s.to_string()).collect()) } - /// Returns url encoded format. - pub fn encode_in_url(&self) -> String { - encode(&self.as_ref().join("\u{1F}")).to_string() + /// Returns a string for used in url. + pub fn to_url_string(&self) -> String { + self.as_ref().join("\u{001f}") } /// Returns inner strings. @@ -227,7 +229,7 @@ pub struct TableCreation { /// The schema of the table. pub schema: Schema, /// The partition spec of the table, could be None. - #[builder(default, setter(strip_option))] + #[builder(default, setter(strip_option, into))] pub partition_spec: Option, /// The sort order of the table. #[builder(default, setter(strip_option))] @@ -331,7 +333,7 @@ pub enum TableRequirement { } /// TableUpdate represents an update to a table in the catalog. -#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] #[serde(tag = "action", rename_all = "kebab-case")] pub enum TableUpdate { /// Upgrade table's format version @@ -381,7 +383,7 @@ pub enum TableUpdate { #[serde(rename_all = "kebab-case")] SetDefaultSortOrder { /// Sort order ID to set as the default, or -1 to set last added sort order - sort_order_id: i32, + sort_order_id: i64, }, /// Add snapshot to table. #[serde(rename_all = "kebab-case")] @@ -427,20 +429,56 @@ pub enum TableUpdate { }, } +impl TableUpdate { + /// Applies the update to the table metadata builder. + pub fn apply(self, builder: TableMetadataBuilder) -> Result { + match self { + TableUpdate::AssignUuid { uuid } => builder.assign_uuid(uuid), + _ => unimplemented!(), + } + } +} + +/// ViewCreation represents the creation of a view in the catalog. +#[derive(Debug, TypedBuilder)] +pub struct ViewCreation { + /// The name of the view. + pub name: String, + /// The view's base location; used to create metadata file locations + pub location: String, + /// Representations for the view. + pub representations: ViewRepresentations, + /// The schema of the view. + pub schema: Schema, + /// The properties of the view. 
+ #[builder(default)] + pub properties: HashMap, + /// The default namespace to use when a reference in the SELECT is a single identifier + pub default_namespace: NamespaceIdent, + /// Default catalog to use when a reference in the SELECT does not contain a catalog + #[builder(default)] + pub default_catalog: Option, + /// A string to string map of summary metadata about the version + /// Typical keys are "engine-name" and "engine-version" + #[builder(default)] + pub summary: HashMap, +} + #[cfg(test)] mod tests { - use crate::spec::ManifestListLocation::ManifestListFile; + use std::collections::HashMap; + use std::fmt::Debug; + + use serde::de::DeserializeOwned; + use serde::Serialize; + use uuid::uuid; + use crate::spec::{ FormatVersion, NestedField, NullOrder, Operation, PrimitiveType, Schema, Snapshot, SnapshotReference, SnapshotRetention, SortDirection, SortField, SortOrder, Summary, - Transform, Type, UnboundPartitionField, UnboundPartitionSpec, + TableMetadataBuilder, Transform, Type, UnboundPartitionSpec, }; - use crate::{NamespaceIdent, TableIdent, TableRequirement, TableUpdate}; - use serde::de::DeserializeOwned; - use serde::Serialize; - use std::collections::HashMap; - use std::fmt::Debug; - use uuid::uuid; + use crate::{NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate}; #[test] fn test_create_table_id() { @@ -782,29 +820,13 @@ mod tests { "#, TableUpdate::AddSpec { spec: UnboundPartitionSpec::builder() - .with_unbound_partition_field( - UnboundPartitionField::builder() - .source_id(4) - .name("ts_day".to_string()) - .transform(Transform::Day) - .build(), - ) - .with_unbound_partition_field( - UnboundPartitionField::builder() - .source_id(1) - .name("id_bucket".to_string()) - .transform(Transform::Bucket(16)) - .build(), - ) - .with_unbound_partition_field( - UnboundPartitionField::builder() - .source_id(2) - .name("id_truncate".to_string()) - .transform(Transform::Truncate(4)) - .build(), - ) - .build() - .unwrap(), + .add_partition_field(4, "ts_day".to_string(), Transform::Day) + .unwrap() + .add_partition_field(1, "id_bucket".to_string(), Transform::Bucket(16)) + .unwrap() + .add_partition_field(2, "id_truncate".to_string(), Transform::Truncate(4)) + .unwrap() + .build(), }, ); } @@ -866,7 +888,7 @@ mod tests { .transform(Transform::Bucket(4)) .build(), ) - .build() + .build_unbound() .unwrap(), }; @@ -911,7 +933,7 @@ mod tests { .with_parent_snapshot_id(Some(3051729675574597000)) .with_timestamp_ms(1555100955770) .with_sequence_number(1) - .with_manifest_list(ManifestListFile("s3://a/b/2.avro".to_string())) + .with_manifest_list("s3://a/b/2.avro") .with_schema_id(1) .with_summary(Summary { operation: Operation::Append, @@ -972,7 +994,9 @@ mod tests { ref_name: "hank".to_string(), reference: SnapshotReference { snapshot_id: 1, - retention: SnapshotRetention::Tag { max_ref_age_ms: 1 }, + retention: SnapshotRetention::Tag { + max_ref_age_ms: Some(1), + }, }, }; @@ -1066,4 +1090,28 @@ mod tests { test_serde_json(json, update); } + + #[test] + fn test_table_update_apply() { + let table_creation = TableCreation::builder() + .location("s3://db/table".to_string()) + .name("table".to_string()) + .properties(HashMap::new()) + .schema(Schema::builder().build().unwrap()) + .build(); + let table_metadata = TableMetadataBuilder::from_table_creation(table_creation) + .unwrap() + .build() + .unwrap(); + let table_metadata_builder = TableMetadataBuilder::new(table_metadata); + + let uuid = uuid::Uuid::new_v4(); + let update = TableUpdate::AssignUuid { uuid }; + let 
updated_metadata = update + .apply(table_metadata_builder) + .unwrap() + .build() + .unwrap(); + assert_eq!(updated_metadata.uuid(), uuid); + } } diff --git a/crates/iceberg/src/error.rs b/crates/iceberg/src/error.rs index 55c010b09..2b69b4706 100644 --- a/crates/iceberg/src/error.rs +++ b/crates/iceberg/src/error.rs @@ -17,9 +17,9 @@ use std::backtrace::{Backtrace, BacktraceStatus}; use std::fmt; -use std::fmt::Debug; -use std::fmt::Display; -use std::fmt::Formatter; +use std::fmt::{Debug, Display, Formatter}; + +use chrono::{DateTime, TimeZone as _, Utc}; /// Result that is a wrapper of `Result` pub type Result = std::result::Result; @@ -325,6 +325,42 @@ define_from_err!( "Failed to convert decimal literal to rust decimal" ); +define_from_err!( + parquet::errors::ParquetError, + ErrorKind::Unexpected, + "Failed to read a Parquet file" +); + +define_from_err!( + futures::channel::mpsc::SendError, + ErrorKind::Unexpected, + "Failed to send a message to a channel" +); + +define_from_err!(std::io::Error, ErrorKind::Unexpected, "IO Operation failed"); + +/// Converts a timestamp in milliseconds to `DateTime`, handling errors. +/// +/// # Arguments +/// +/// * `timestamp_ms` - The timestamp in milliseconds to convert. +/// +/// # Returns +/// +/// This function returns a `Result, Error>` which is `Ok` with the `DateTime` if the conversion is successful, +/// or an `Err` with an appropriate error if the timestamp is ambiguous or invalid. +pub(crate) fn timestamp_ms_to_utc(timestamp_ms: i64) -> Result> { + match Utc.timestamp_millis_opt(timestamp_ms) { + chrono::LocalResult::Single(t) => Ok(t), + chrono::LocalResult::Ambiguous(_, _) => Err(Error::new( + ErrorKind::Unexpected, + "Ambiguous timestamp with two possible results", + )), + chrono::LocalResult::None => Err(Error::new(ErrorKind::DataInvalid, "Invalid timestamp")), + } + .map_err(|e| e.with_context("timestamp value", timestamp_ms.to_string())) +} + /// Helper macro to check arguments. /// /// diff --git a/crates/iceberg/src/expr/accessor.rs b/crates/iceberg/src/expr/accessor.rs new file mode 100644 index 000000000..51bfa7d39 --- /dev/null +++ b/crates/iceberg/src/expr/accessor.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use serde_derive::{Deserialize, Serialize}; + +use crate::spec::{Datum, Literal, PrimitiveType, Struct}; +use crate::{Error, ErrorKind, Result}; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +pub struct StructAccessor { + position: usize, + r#type: PrimitiveType, + inner: Option>, +} + +pub(crate) type StructAccessorRef = Arc; + +impl StructAccessor { + pub(crate) fn new(position: usize, r#type: PrimitiveType) -> Self { + StructAccessor { + position, + r#type, + inner: None, + } + } + + pub(crate) fn wrap(position: usize, inner: Box) -> Self { + StructAccessor { + position, + r#type: inner.r#type().clone(), + inner: Some(inner), + } + } + + pub(crate) fn position(&self) -> usize { + self.position + } + + pub(crate) fn r#type(&self) -> &PrimitiveType { + &self.r#type + } + + pub(crate) fn get<'a>(&'a self, container: &'a Struct) -> Result> { + match &self.inner { + None => { + if container.is_null_at_index(self.position) { + Ok(None) + } else if let Literal::Primitive(literal) = &container[self.position] { + Ok(Some(Datum::new(self.r#type().clone(), literal.clone()))) + } else { + Err(Error::new( + ErrorKind::Unexpected, + "Expected Literal to be Primitive", + )) + } + } + Some(inner) => { + if let Literal::Struct(wrapped) = &container[self.position] { + inner.get(wrapped) + } else { + Err(Error::new( + ErrorKind::Unexpected, + "Nested accessor should only be wrapping a Struct", + )) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::expr::accessor::StructAccessor; + use crate::spec::{Datum, Literal, PrimitiveType, Struct}; + + #[test] + fn test_single_level_accessor() { + let accessor = StructAccessor::new(1, PrimitiveType::Boolean); + + assert_eq!(accessor.r#type(), &PrimitiveType::Boolean); + assert_eq!(accessor.position(), 1); + + let test_struct = + Struct::from_iter(vec![Some(Literal::bool(false)), Some(Literal::bool(true))]); + + assert_eq!(accessor.get(&test_struct).unwrap(), Some(Datum::bool(true))); + } + + #[test] + fn test_single_level_accessor_null() { + let accessor = StructAccessor::new(1, PrimitiveType::Boolean); + + assert_eq!(accessor.r#type(), &PrimitiveType::Boolean); + assert_eq!(accessor.position(), 1); + + let test_struct = Struct::from_iter(vec![Some(Literal::bool(false)), None]); + + assert_eq!(accessor.get(&test_struct).unwrap(), None); + } + + #[test] + fn test_nested_accessor() { + let nested_accessor = StructAccessor::new(1, PrimitiveType::Boolean); + let accessor = StructAccessor::wrap(2, Box::new(nested_accessor)); + + assert_eq!(accessor.r#type(), &PrimitiveType::Boolean); + //assert_eq!(accessor.position(), 1); + + let nested_test_struct = + Struct::from_iter(vec![Some(Literal::bool(false)), Some(Literal::bool(true))]); + + let test_struct = Struct::from_iter(vec![ + Some(Literal::bool(false)), + Some(Literal::bool(false)), + Some(Literal::Struct(nested_test_struct)), + ]); + + assert_eq!(accessor.get(&test_struct).unwrap(), Some(Datum::bool(true))); + } + + #[test] + fn test_nested_accessor_null() { + let nested_accessor = StructAccessor::new(0, PrimitiveType::Boolean); + let accessor = StructAccessor::wrap(2, Box::new(nested_accessor)); + + assert_eq!(accessor.r#type(), &PrimitiveType::Boolean); + //assert_eq!(accessor.position(), 1); + + let nested_test_struct = Struct::from_iter(vec![None, Some(Literal::bool(true))]); + + let test_struct = Struct::from_iter(vec![ + Some(Literal::bool(false)), + Some(Literal::bool(false)), + Some(Literal::Struct(nested_test_struct)), + ]); + + 
assert_eq!(accessor.get(&test_struct).unwrap(), None); + } +} diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs new file mode 100644 index 000000000..5771aac5e --- /dev/null +++ b/crates/iceberg/src/expr/mod.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains expressions. + +mod term; +use serde::{Deserialize, Serialize}; +pub use term::*; +pub(crate) mod accessor; +mod predicate; +pub(crate) mod visitors; +use std::fmt::{Display, Formatter}; + +pub use predicate::*; + +use crate::spec::SchemaRef; + +/// Predicate operators used in expressions. +/// +/// The discriminant of this enum is used for determining the type of the operator, see +/// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`] +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[non_exhaustive] +#[repr(u16)] +pub enum PredicateOperator { + // Unary operators + IsNull = 101, + NotNull = 102, + IsNan = 103, + NotNan = 104, + + // Binary operators + LessThan = 201, + LessThanOrEq = 202, + GreaterThan = 203, + GreaterThanOrEq = 204, + Eq = 205, + NotEq = 206, + StartsWith = 207, + NotStartsWith = 208, + + // Set operators + In = 301, + NotIn = 302, +} + +impl Display for PredicateOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + PredicateOperator::IsNull => write!(f, "IS NULL"), + PredicateOperator::NotNull => write!(f, "IS NOT NULL"), + PredicateOperator::IsNan => write!(f, "IS NAN"), + PredicateOperator::NotNan => write!(f, "IS NOT NAN"), + PredicateOperator::LessThan => write!(f, "<"), + PredicateOperator::LessThanOrEq => write!(f, "<="), + PredicateOperator::GreaterThan => write!(f, ">"), + PredicateOperator::GreaterThanOrEq => write!(f, ">="), + PredicateOperator::Eq => write!(f, "="), + PredicateOperator::NotEq => write!(f, "!="), + PredicateOperator::In => write!(f, "IN"), + PredicateOperator::NotIn => write!(f, "NOT IN"), + PredicateOperator::StartsWith => write!(f, "STARTS WITH"), + PredicateOperator::NotStartsWith => write!(f, "NOT STARTS WITH"), + } + } +} + +impl PredicateOperator { + /// Check if this operator is unary operator. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::PredicateOperator; + /// assert!(PredicateOperator::IsNull.is_unary()); + /// ``` + pub fn is_unary(self) -> bool { + (self as u16) < (PredicateOperator::LessThan as u16) + } + + /// Check if this operator is binary operator. 
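The `repr(u16)` discriminants are what make these checks cheap: unary operators live in the 1xx range, binary in 2xx, set in 3xx, and `is_unary`/`is_binary`/`is_set` are plain range comparisons on the discriminant. A small sketch (illustrative, not taken from the crate's tests):

```rust
use iceberg::expr::PredicateOperator;

fn main() {
    // Discriminants group operators by arity.
    assert_eq!(PredicateOperator::IsNull as u16, 101);
    assert_eq!(PredicateOperator::LessThan as u16, 201);
    assert_eq!(PredicateOperator::In as u16, 301);

    assert!(PredicateOperator::NotNan.is_unary());
    assert!(PredicateOperator::StartsWith.is_binary());
    assert!(PredicateOperator::NotIn.is_set());

    // `negate` stays within the same group: the inverse of a binary op is binary.
    assert_eq!(PredicateOperator::Eq.negate(), PredicateOperator::NotEq);
}
```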
+ /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::PredicateOperator; + /// assert!(PredicateOperator::LessThan.is_binary()); + /// ``` + pub fn is_binary(self) -> bool { + ((self as u16) > (PredicateOperator::NotNan as u16)) + && ((self as u16) < (PredicateOperator::In as u16)) + } + + /// Check if this operator is set operator. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::PredicateOperator; + /// assert!(PredicateOperator::In.is_set()); + /// ``` + pub fn is_set(self) -> bool { + (self as u16) > (PredicateOperator::NotStartsWith as u16) + } + + /// Returns the predicate that is the inverse of self + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::PredicateOperator; + /// assert!(PredicateOperator::IsNull.negate() == PredicateOperator::NotNull); + /// assert!(PredicateOperator::IsNan.negate() == PredicateOperator::NotNan); + /// assert!(PredicateOperator::LessThan.negate() == PredicateOperator::GreaterThanOrEq); + /// assert!(PredicateOperator::GreaterThan.negate() == PredicateOperator::LessThanOrEq); + /// assert!(PredicateOperator::Eq.negate() == PredicateOperator::NotEq); + /// assert!(PredicateOperator::In.negate() == PredicateOperator::NotIn); + /// assert!(PredicateOperator::StartsWith.negate() == PredicateOperator::NotStartsWith); + /// ``` + pub fn negate(self) -> PredicateOperator { + match self { + PredicateOperator::IsNull => PredicateOperator::NotNull, + PredicateOperator::NotNull => PredicateOperator::IsNull, + PredicateOperator::IsNan => PredicateOperator::NotNan, + PredicateOperator::NotNan => PredicateOperator::IsNan, + PredicateOperator::LessThan => PredicateOperator::GreaterThanOrEq, + PredicateOperator::LessThanOrEq => PredicateOperator::GreaterThan, + PredicateOperator::GreaterThan => PredicateOperator::LessThanOrEq, + PredicateOperator::GreaterThanOrEq => PredicateOperator::LessThan, + PredicateOperator::Eq => PredicateOperator::NotEq, + PredicateOperator::NotEq => PredicateOperator::Eq, + PredicateOperator::In => PredicateOperator::NotIn, + PredicateOperator::NotIn => PredicateOperator::In, + PredicateOperator::StartsWith => PredicateOperator::NotStartsWith, + PredicateOperator::NotStartsWith => PredicateOperator::StartsWith, + } + } +} + +/// Bind expression to a schema. +pub trait Bind { + /// The type of the bound result. + type Bound; + /// Bind an expression to a schema. 
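A minimal sketch of `Bind` in use, assuming a one-column schema (`a`, a required long with field id 1). `Reference`, `Datum`, and the schema builder all appear elsewhere in this diff, and the single-literal `IN` simplification the second assertion relies on is implemented in `Predicate::bind` further down.

```rust
use std::sync::Arc;

use iceberg::expr::{Bind, BoundPredicate, Reference};
use iceberg::spec::{Datum, NestedField, PrimitiveType, Schema, Type};

fn main() -> iceberg::Result<()> {
    let schema = Arc::new(
        Schema::builder()
            .with_fields(vec![Arc::new(NestedField::required(
                1,
                "a",
                Type::Primitive(PrimitiveType::Long),
            ))])
            .build()?,
    );

    // Binding resolves the column name against the schema (case sensitively here)
    // and coerces the literal to the column's type.
    let bound = Reference::new("a")
        .less_than(Datum::long(10))
        .bind(schema.clone(), true)?;
    assert!(matches!(bound, BoundPredicate::Binary(_)));

    // Binding also simplifies: `a IN (5)` with a single literal becomes a plain
    // equality rather than a set expression.
    let bound = Reference::new("a")
        .is_in([Datum::long(5)])
        .bind(schema, true)?;
    assert!(matches!(bound, BoundPredicate::Binary(_)));
    Ok(())
}
```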
+ fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> crate::Result; +} + +#[cfg(test)] +mod tests { + use crate::expr::PredicateOperator; + + #[test] + fn test_unary() { + assert!(PredicateOperator::IsNull.is_unary()); + assert!(PredicateOperator::NotNull.is_unary()); + assert!(PredicateOperator::IsNan.is_unary()); + assert!(PredicateOperator::NotNan.is_unary()); + } + + #[test] + fn test_binary() { + assert!(PredicateOperator::LessThan.is_binary()); + assert!(PredicateOperator::LessThanOrEq.is_binary()); + assert!(PredicateOperator::GreaterThan.is_binary()); + assert!(PredicateOperator::GreaterThanOrEq.is_binary()); + assert!(PredicateOperator::Eq.is_binary()); + assert!(PredicateOperator::NotEq.is_binary()); + assert!(PredicateOperator::StartsWith.is_binary()); + assert!(PredicateOperator::NotStartsWith.is_binary()); + } + + #[test] + fn test_set() { + assert!(PredicateOperator::In.is_set()); + assert!(PredicateOperator::NotIn.is_set()); + } +} diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs new file mode 100644 index 000000000..acf21a5b1 --- /dev/null +++ b/crates/iceberg/src/expr/predicate.rs @@ -0,0 +1,1287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains predicate expressions. +//! Predicate expressions are used to filter data, and evaluates to a boolean value. For example, +//! `a > 10` is a predicate expression, and it evaluates to `true` if `a` is greater than `10`, + +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Not; + +use array_init::array_init; +use fnv::FnvHashSet; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::error::Result; +use crate::expr::{Bind, BoundReference, PredicateOperator, Reference}; +use crate::spec::{Datum, SchemaRef}; +use crate::{Error, ErrorKind}; + +/// Logical expression, such as `AND`, `OR`, `NOT`. 
+#[derive(PartialEq, Clone)] +pub struct LogicalExpression { + inputs: [Box; N], +} + +impl Serialize for LogicalExpression { + fn serialize(&self, serializer: S) -> std::result::Result + where S: serde::Serializer { + self.inputs.serialize(serializer) + } +} + +impl<'de, T: Deserialize<'de>, const N: usize> Deserialize<'de> for LogicalExpression { + fn deserialize(deserializer: D) -> std::result::Result + where D: serde::Deserializer<'de> { + let inputs = Vec::>::deserialize(deserializer)?; + Ok(LogicalExpression::new( + array_init::from_iter(inputs.into_iter()).ok_or_else(|| { + serde::de::Error::custom(format!("Failed to deserialize LogicalExpression: the len of inputs is not match with the len of LogicalExpression {}",N)) + })?, + )) + } +} + +impl Debug for LogicalExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LogicalExpression") + .field("inputs", &self.inputs) + .finish() + } +} + +impl LogicalExpression { + fn new(inputs: [Box; N]) -> Self { + Self { inputs } + } + + /// Return inputs of this logical expression. + pub fn inputs(&self) -> [&T; N] { + let mut ret: [&T; N] = [self.inputs[0].as_ref(); N]; + for (i, item) in ret.iter_mut().enumerate() { + *item = &self.inputs[i]; + } + ret + } +} + +impl Bind for LogicalExpression +where T::Bound: Sized +{ + type Bound = LogicalExpression; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result { + let mut outputs: [Option>; N] = array_init(|_| None); + for (i, input) in self.inputs.iter().enumerate() { + outputs[i] = Some(Box::new(input.bind(schema.clone(), case_sensitive)?)); + } + + // It's safe to use `unwrap` here since they are all `Some`. + let bound_inputs = array_init::from_iter(outputs.into_iter().map(Option::unwrap)).unwrap(); + Ok(LogicalExpression::new(bound_inputs)) + } +} + +/// Unary predicate, for example, `a IS NULL`. +#[derive(PartialEq, Clone, Serialize, Deserialize)] +pub struct UnaryExpression { + /// Operator of this predicate, must be single operand operator. + op: PredicateOperator, + /// Term of this predicate, for example, `a` in `a IS NULL`. + #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))] + term: T, +} + +impl Debug for UnaryExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UnaryExpression") + .field("op", &self.op) + .field("term", &self.term) + .finish() + } +} + +impl Display for UnaryExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.term, self.op) + } +} + +impl Bind for UnaryExpression { + type Bound = UnaryExpression; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema, case_sensitive)?; + Ok(UnaryExpression::new(self.op, bound_term)) + } +} + +impl UnaryExpression { + pub(crate) fn new(op: PredicateOperator, term: T) -> Self { + debug_assert!(op.is_unary()); + Self { op, term } + } + + /// Return the operator of this predicate. + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } + + /// Return the term of this predicate. + pub(crate) fn term(&self) -> &T { + &self.term + } +} + +/// Binary predicate, for example, `a > 10`. +#[derive(PartialEq, Clone, Serialize, Deserialize)] +pub struct BinaryExpression { + /// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc. + op: PredicateOperator, + /// Term of this predicate, for example, `a` in `a > 10`. 
+ #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))] + term: T, + /// Literal of this predicate, for example, `10` in `a > 10`. + literal: Datum, +} + +impl Debug for BinaryExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BinaryExpression") + .field("op", &self.op) + .field("term", &self.term) + .field("literal", &self.literal) + .finish() + } +} + +impl BinaryExpression { + pub(crate) fn new(op: PredicateOperator, term: T, literal: Datum) -> Self { + debug_assert!(op.is_binary()); + Self { op, term, literal } + } + + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } + + /// Return the literal of this predicate. + pub(crate) fn literal(&self) -> &Datum { + &self.literal + } + + /// Return the term of this predicate. + pub(crate) fn term(&self) -> &T { + &self.term + } +} + +impl Display for BinaryExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {} {}", self.term, self.op, self.literal) + } +} + +impl Bind for BinaryExpression { + type Bound = BinaryExpression; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(BinaryExpression::new( + self.op, + bound_term, + self.literal.clone(), + )) + } +} + +/// Set predicates, for example, `a in (1, 2, 3)`. +#[derive(PartialEq, Clone, Serialize, Deserialize)] +pub struct SetExpression { + /// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc. + op: PredicateOperator, + /// Term of this predicate, for example, `a` in `a in (1, 2, 3)`. + term: T, + /// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`. + literals: FnvHashSet, +} + +impl Debug for SetExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SetExpression") + .field("op", &self.op) + .field("term", &self.term) + .field("literal", &self.literals) + .finish() + } +} + +impl SetExpression { + pub(crate) fn new(op: PredicateOperator, term: T, literals: FnvHashSet) -> Self { + debug_assert!(op.is_set()); + Self { op, term, literals } + } + + /// Return the operator of this predicate. + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } + + pub(crate) fn literals(&self) -> &FnvHashSet { + &self.literals + } + + /// Return the term of this predicate. + pub(crate) fn term(&self) -> &T { + &self.term + } +} + +impl Bind for SetExpression { + type Bound = SetExpression; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(SetExpression::new( + self.op, + bound_term, + self.literals.clone(), + )) + } +} + +impl Display for SetExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut literal_strs = self.literals.iter().map(|l| format!("{}", l)); + + write!(f, "{} {} ({})", self.term, self.op, literal_strs.join(", ")) + } +} + +/// Unbound predicate expression before binding to a schema. +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub enum Predicate { + /// AlwaysTrue predicate, for example, `TRUE`. + AlwaysTrue, + /// AlwaysFalse predicate, for example, `FALSE`. + AlwaysFalse, + /// And predicate, for example, `a > 10 AND b < 20`. + And(LogicalExpression), + /// Or predicate, for example, `a > 10 OR b < 20`. + Or(LogicalExpression), + /// Not predicate, for example, `NOT (a > 10)`. + Not(LogicalExpression), + /// Unary expression, for example, `a IS NULL`. 
+ Unary(UnaryExpression), + /// Binary expression, for example, `a > 10`. + Binary(BinaryExpression), + /// Set predicates, for example, `a in (1, 2, 3)`. + Set(SetExpression), +} + +impl Bind for Predicate { + type Bound = BoundPredicate; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result { + match self { + Predicate::And(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right, + (left, right) => BoundPredicate::And(LogicalExpression::new([left, right])), + }) + } + Predicate::Not(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [inner] = bound_expr.inputs; + Ok(match inner { + e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse, + e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue, + e => BoundPredicate::Not(LogicalExpression::new([e])), + }) + } + Predicate::Or(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (l, r) + if matches!(&*r, &BoundPredicate::AlwaysTrue) + || matches!(&*l, &BoundPredicate::AlwaysTrue) => + { + BoundPredicate::AlwaysTrue + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right, + (left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])), + }) + } + Predicate::Unary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + match &bound_expr.op { + &PredicateOperator::IsNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysFalse); + } + } + &PredicateOperator::NotNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysTrue); + } + } + &PredicateOperator::IsNan | &PredicateOperator::NotNan => { + if !bound_expr.term.field().field_type.is_floating_type() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Expecting floating point type, but found {}", + bound_expr.term.field().field_type + ), + )); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator, but found {op}"), + )) + } + } + + Ok(BoundPredicate::Unary(bound_expr)) + } + Predicate::Binary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?; + Ok(BoundPredicate::Binary(BinaryExpression::new( + bound_expr.op, + bound_expr.term, + bound_literal, + ))) + } + Predicate::Set(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literals = bound_expr + .literals + .into_iter() + .map(|l| l.to(&bound_expr.term.field().field_type)) + .collect::>>()?; + + match &bound_expr.op { + &PredicateOperator::In => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysFalse); + } + if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + &PredicateOperator::NotIn => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysTrue); + } + 
if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator,but found {op}"), + )) + } + } + + Ok(BoundPredicate::Set(SetExpression::new( + bound_expr.op, + bound_expr.term, + bound_literals, + ))) + } + Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue), + Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse), + } + } +} + +impl Display for Predicate { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Predicate::AlwaysTrue => { + write!(f, "TRUE") + } + Predicate::AlwaysFalse => { + write!(f, "FALSE") + } + Predicate::And(expr) => { + write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) + } + Predicate::Or(expr) => { + write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1]) + } + Predicate::Not(expr) => { + write!(f, "NOT ({})", expr.inputs()[0]) + } + Predicate::Unary(expr) => { + write!(f, "{}", expr) + } + Predicate::Binary(expr) => { + write!(f, "{}", expr) + } + Predicate::Set(expr) => { + write!(f, "{}", expr) + } + } + } +} + +impl Predicate { + /// Combines two predicates with `AND`. + /// + /// # Example + /// + /// ```rust + /// use std::ops::Bound::Unbounded; + /// + /// use iceberg::expr::BoundPredicate::Unary; + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr1 = Reference::new("a").less_than(Datum::long(10)); + /// + /// let expr2 = Reference::new("b").less_than(Datum::long(20)); + /// + /// let expr = expr1.and(expr2); + /// + /// assert_eq!(&format!("{expr}"), "(a < 10) AND (b < 20)"); + /// ``` + pub fn and(self, other: Predicate) -> Predicate { + match (self, other) { + (Predicate::AlwaysFalse, _) => Predicate::AlwaysFalse, + (_, Predicate::AlwaysFalse) => Predicate::AlwaysFalse, + (Predicate::AlwaysTrue, rhs) => rhs, + (lhs, Predicate::AlwaysTrue) => lhs, + (lhs, rhs) => Predicate::And(LogicalExpression::new([Box::new(lhs), Box::new(rhs)])), + } + } + + /// Combines two predicates with `OR`. + /// + /// # Example + /// + /// ```rust + /// use std::ops::Bound::Unbounded; + /// + /// use iceberg::expr::BoundPredicate::Unary; + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr1 = Reference::new("a").less_than(Datum::long(10)); + /// + /// let expr2 = Reference::new("b").less_than(Datum::long(20)); + /// + /// let expr = expr1.or(expr2); + /// + /// assert_eq!(&format!("{expr}"), "(a < 10) OR (b < 20)"); + /// ``` + pub fn or(self, other: Predicate) -> Predicate { + match (self, other) { + (Predicate::AlwaysTrue, _) => Predicate::AlwaysTrue, + (_, Predicate::AlwaysTrue) => Predicate::AlwaysTrue, + (Predicate::AlwaysFalse, rhs) => rhs, + (lhs, Predicate::AlwaysFalse) => lhs, + (lhs, rhs) => Predicate::Or(LogicalExpression::new([Box::new(lhs), Box::new(rhs)])), + } + } + + /// Returns a predicate representing the negation ('NOT') of this one, + /// by using inverse predicates rather than wrapping in a `NOT`. + /// Used for `NOT` elimination. 
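The `and`/`or` constructors above fold the constant predicates instead of building a logical node; a small sketch of the resulting algebra (illustrative only):

```rust
use iceberg::expr::{Predicate, Reference};
use iceberg::spec::Datum;

fn main() {
    let p = Reference::new("a").less_than(Datum::long(10));

    // TRUE is the identity and FALSE the absorbing element of AND; dually for OR.
    assert_eq!(Predicate::AlwaysTrue.and(p.clone()), p);
    assert_eq!(Predicate::AlwaysFalse.and(p.clone()), Predicate::AlwaysFalse);
    assert_eq!(Predicate::AlwaysFalse.or(p.clone()), p);
    assert_eq!(Predicate::AlwaysTrue.or(p), Predicate::AlwaysTrue);
}
```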
+ /// + /// # Example + /// + /// ```rust + /// use std::ops::Bound::Unbounded; + /// + /// use iceberg::expr::BoundPredicate::Unary; + /// use iceberg::expr::{LogicalExpression, Predicate, Reference}; + /// use iceberg::spec::Datum; + /// let expr1 = Reference::new("a").less_than(Datum::long(10)); + /// let expr2 = Reference::new("b") + /// .less_than(Datum::long(5)) + /// .and(Reference::new("c").less_than(Datum::long(10))); + /// + /// let result = expr1.negate(); + /// assert_eq!(&format!("{result}"), "a >= 10"); + /// + /// let result = expr2.negate(); + /// assert_eq!(&format!("{result}"), "(b >= 5) OR (c >= 10)"); + /// ``` + pub fn negate(self) -> Predicate { + match self { + Predicate::AlwaysTrue => Predicate::AlwaysFalse, + Predicate::AlwaysFalse => Predicate::AlwaysTrue, + Predicate::And(expr) => Predicate::Or(LogicalExpression::new( + expr.inputs.map(|expr| Box::new(expr.negate())), + )), + Predicate::Or(expr) => Predicate::And(LogicalExpression::new( + expr.inputs.map(|expr| Box::new(expr.negate())), + )), + Predicate::Not(expr) => { + let LogicalExpression { inputs: [input_0] } = expr; + *input_0 + } + Predicate::Unary(expr) => { + Predicate::Unary(UnaryExpression::new(expr.op.negate(), expr.term)) + } + Predicate::Binary(expr) => Predicate::Binary(BinaryExpression::new( + expr.op.negate(), + expr.term, + expr.literal, + )), + Predicate::Set(expr) => Predicate::Set(SetExpression::new( + expr.op.negate(), + expr.term, + expr.literals, + )), + } + } + /// Simplifies the expression by removing `NOT` predicates, + /// directly negating the inner expressions instead. The transformation + /// applies logical laws (such as De Morgan's laws) to + /// recursively negate and simplify inner expressions within `NOT` + /// predicates. + /// + /// # Example + /// + /// ```rust + /// use std::ops::Not; + /// + /// use iceberg::expr::{LogicalExpression, Predicate, Reference}; + /// use iceberg::spec::Datum; + /// + /// let expression = Reference::new("a").less_than(Datum::long(5)).not(); + /// let result = expression.rewrite_not(); + /// + /// assert_eq!(&format!("{result}"), "a >= 5"); + /// ``` + pub fn rewrite_not(self) -> Predicate { + match self { + Predicate::And(expr) => { + let [left, right] = expr.inputs; + let new_left = Box::new(left.rewrite_not()); + let new_right = Box::new(right.rewrite_not()); + Predicate::And(LogicalExpression::new([new_left, new_right])) + } + Predicate::Or(expr) => { + let [left, right] = expr.inputs; + let new_left = Box::new(left.rewrite_not()); + let new_right = Box::new(right.rewrite_not()); + Predicate::Or(LogicalExpression::new([new_left, new_right])) + } + Predicate::Not(expr) => { + let [inner] = expr.inputs; + inner.negate() + } + Predicate::Unary(expr) => Predicate::Unary(expr), + Predicate::Binary(expr) => Predicate::Binary(expr), + Predicate::Set(expr) => Predicate::Set(expr), + Predicate::AlwaysTrue => Predicate::AlwaysTrue, + Predicate::AlwaysFalse => Predicate::AlwaysFalse, + } + } +} + +impl Not for Predicate { + type Output = Predicate; + + /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)`. + /// + /// This is different from [`Predicate::negate()`] since it doesn't rewrite expression, but + /// just adds a `NOT` operator. 
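+ /// The wrapping `NOT` can later be removed again with [`Predicate::rewrite_not()`].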
+ /// + /// # Example + /// + ///```rust + /// use std::ops::Bound::Unbounded; + /// + /// use iceberg::expr::BoundPredicate::Unary; + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr1 = Reference::new("a").less_than(Datum::long(10)); + /// + /// let expr = !expr1; + /// + /// assert_eq!(&format!("{expr}"), "NOT (a < 10)"); + /// ``` + fn not(self) -> Self::Output { + Predicate::Not(LogicalExpression::new([Box::new(self)])) + } +} + +/// Bound predicate expression after binding to a schema. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum BoundPredicate { + /// An expression always evaluates to true. + AlwaysTrue, + /// An expression always evaluates to false. + AlwaysFalse, + /// An expression combined by `AND`, for example, `a > 10 AND b < 20`. + And(LogicalExpression), + /// An expression combined by `OR`, for example, `a > 10 OR b < 20`. + Or(LogicalExpression), + /// An expression combined by `NOT`, for example, `NOT (a > 10)`. + Not(LogicalExpression), + /// Unary expression, for example, `a IS NULL`. + Unary(UnaryExpression), + /// Binary expression, for example, `a > 10`. + Binary(BinaryExpression), + /// Set predicates, for example, `a IN (1, 2, 3)`. + Set(SetExpression), +} + +impl Display for BoundPredicate { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + BoundPredicate::AlwaysTrue => { + write!(f, "True") + } + BoundPredicate::AlwaysFalse => { + write!(f, "False") + } + BoundPredicate::And(expr) => { + write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Or(expr) => { + write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Not(expr) => { + write!(f, "NOT ({})", expr.inputs()[0]) + } + BoundPredicate::Unary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Binary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Set(expr) => { + write!(f, "{}", expr) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue}; + use crate::expr::{Bind, BoundPredicate, Reference}; + use crate::spec::{Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type}; + + #[test] + fn test_logical_or_rewrite_not() { + let expression = Reference::new("b") + .less_than(Datum::long(5)) + .or(Reference::new("c").less_than(Datum::long(10))) + .not(); + + let expected = Reference::new("b") + .greater_than_or_equal_to(Datum::long(5)) + .and(Reference::new("c").greater_than_or_equal_to(Datum::long(10))); + + let result = expression.rewrite_not(); + + assert_eq!(result, expected); + } + + #[test] + fn test_logical_and_rewrite_not() { + let expression = Reference::new("b") + .less_than(Datum::long(5)) + .and(Reference::new("c").less_than(Datum::long(10))) + .not(); + + let expected = Reference::new("b") + .greater_than_or_equal_to(Datum::long(5)) + .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10))); + + let result = expression.rewrite_not(); + + assert_eq!(result, expected); + } + + #[test] + fn test_set_rewrite_not() { + let expression = Reference::new("a") + .is_in([Datum::int(5), Datum::int(6)]) + .not(); + + let expected = Reference::new("a").is_not_in([Datum::int(5), Datum::int(6)]); + + let result = expression.rewrite_not(); + + assert_eq!(result, expected); + } + + #[test] + fn test_binary_rewrite_not() { + let expression = Reference::new("a").less_than(Datum::long(5)).not(); + + let expected = 
Reference::new("a").greater_than_or_equal_to(Datum::long(5)); + + let result = expression.rewrite_not(); + + assert_eq!(result, expected); + } + + #[test] + fn test_unary_rewrite_not() { + let expression = Reference::new("a").is_null().not(); + + let expected = Reference::new("a").is_not_null(); + + let result = expression.rewrite_not(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_and_reduce_always_true_false() { + let true_or_expr = AlwaysTrue.and(Reference::new("b").less_than(Datum::long(5))); + assert_eq!(&format!("{true_or_expr}"), "b < 5"); + + let expr_or_true = Reference::new("b") + .less_than(Datum::long(5)) + .and(AlwaysTrue); + assert_eq!(&format!("{expr_or_true}"), "b < 5"); + + let false_or_expr = AlwaysFalse.and(Reference::new("b").less_than(Datum::long(5))); + assert_eq!(&format!("{false_or_expr}"), "FALSE"); + + let expr_or_false = Reference::new("b") + .less_than(Datum::long(5)) + .and(AlwaysFalse); + assert_eq!(&format!("{expr_or_false}"), "FALSE"); + } + + #[test] + fn test_predicate_or_reduce_always_true_false() { + let true_or_expr = AlwaysTrue.or(Reference::new("b").less_than(Datum::long(5))); + assert_eq!(&format!("{true_or_expr}"), "TRUE"); + + let expr_or_true = Reference::new("b").less_than(Datum::long(5)).or(AlwaysTrue); + assert_eq!(&format!("{expr_or_true}"), "TRUE"); + + let false_or_expr = AlwaysFalse.or(Reference::new("b").less_than(Datum::long(5))); + assert_eq!(&format!("{false_or_expr}"), "b < 5"); + + let expr_or_false = Reference::new("b") + .less_than(Datum::long(5)) + .or(AlwaysFalse); + assert_eq!(&format!("{expr_or_false}"), "b < 5"); + } + + #[test] + fn test_predicate_negate_and() { + let expression = Reference::new("b") + .less_than(Datum::long(5)) + .and(Reference::new("c").less_than(Datum::long(10))); + + let expected = Reference::new("b") + .greater_than_or_equal_to(Datum::long(5)) + .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10))); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_negate_or() { + let expression = Reference::new("b") + .greater_than_or_equal_to(Datum::long(5)) + .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10))); + + let expected = Reference::new("b") + .less_than(Datum::long(5)) + .and(Reference::new("c").less_than(Datum::long(10))); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_negate_not() { + let expression = Reference::new("b") + .greater_than_or_equal_to(Datum::long(5)) + .not(); + + let expected = Reference::new("b").greater_than_or_equal_to(Datum::long(5)); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_negate_unary() { + let expression = Reference::new("b").is_not_null(); + + let expected = Reference::new("b").is_null(); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_negate_binary() { + let expression = Reference::new("a").less_than(Datum::long(5)); + + let expected = Reference::new("a").greater_than_or_equal_to(Datum::long(5)); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + #[test] + fn test_predicate_negate_set() { + let expression = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); + + let expected = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); + + let result = expression.negate(); + + assert_eq!(result, expected); + } + + fn table_schema_simple() -> 
SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + NestedField::optional(4, "qux", Type::Primitive(PrimitiveType::Float)).into(), + ]) + .build() + .unwrap(), + ) + } + + fn test_bound_predicate_serialize_diserialize(bound_predicate: BoundPredicate) { + let serialized = serde_json::to_string(&bound_predicate).unwrap(); + let deserialized: BoundPredicate = serde_json::from_str(&serialized).unwrap(); + assert_eq!(bound_predicate, deserialized); + } + + #[test] + fn test_bind_is_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NULL"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_not_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_not_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_nan() { + let schema = table_schema_simple(); + let expr = Reference::new("qux").is_nan(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "qux IS NAN"); + + let schema_string = table_schema_simple(); + let expr_string = Reference::new("foo").is_nan(); + let bound_expr_string = expr_string.bind(schema_string, true); + assert!(bound_expr_string.is_err()); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_nan_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_nan(); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_is_not_nan() { + let schema = table_schema_simple(); + let expr = Reference::new("qux").is_not_nan(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "qux IS NOT NAN"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_is_not_nan_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_not_nan(); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_less_than() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar < 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_less_than_wrong_type() { 
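+ // `bar` is an Int column, so binding it against a string literal must fail.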
+ let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_less_than_or_eq() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than_or_equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar <= 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_less_than_or_eq_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than_or_equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_greater_than() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar > 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_greater_than_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_greater_than_or_eq() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar >= 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_greater_than_or_eq_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_equal_to() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar = 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_equal_to_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_not_equal_to() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").not_equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar != 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_equal_to_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").not_equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_starts_with() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").starts_with(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo STARTS WITH "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_starts_with_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").starts_with(Datum::string("abcd")); + let bound_expr = 
expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_not_starts_with() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").not_starts_with(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo NOT STARTS WITH "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_starts_with_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").not_starts_with(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar = 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_not_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar != 10"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_and() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .and(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_and_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + 
.less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_and_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_or() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .or(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_or_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_or_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_always_true() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); + } + + #[test] + fn test_bind_not_always_false() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"True"#); + test_bound_predicate_serialize_diserialize(bound_expr); + } +} diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs new file mode 100644 index 000000000..f83cebd99 --- /dev/null +++ b/crates/iceberg/src/expr/term.rs @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Term definition. + +use std::fmt::{Display, Formatter}; + +use fnv::FnvHashSet; +use serde::{Deserialize, Serialize}; + +use crate::expr::accessor::{StructAccessor, StructAccessorRef}; +use crate::expr::{ + BinaryExpression, Bind, Predicate, PredicateOperator, SetExpression, UnaryExpression, +}; +use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef}; +use crate::{Error, ErrorKind}; + +/// Unbound term before binding to a schema. +pub type Term = Reference; + +/// A named reference in an unbound expression. +/// For example, `a` in `a > 10`. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Reference { + name: String, +} + +impl Reference { + /// Create a new unbound reference. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + /// Return the name of this reference. + pub fn name(&self) -> &str { + &self.name + } +} + +impl Reference { + /// Creates an less than expression. For example, `a < 10`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").less_than(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a < 10"); + /// ``` + pub fn less_than(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + self, + datum, + )) + } + + /// Creates an less than or equal to expression. For example, `a <= 10`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").less_than_or_equal_to(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a <= 10"); + /// ``` + pub fn less_than_or_equal_to(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + self, + datum, + )) + } + + /// Creates an greater than expression. For example, `a > 10`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").greater_than(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a > 10"); + /// ``` + pub fn greater_than(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + self, + datum, + )) + } + + /// Creates a greater-than-or-equal-to than expression. For example, `a >= 10`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").greater_than_or_equal_to(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a >= 10"); + /// ``` + pub fn greater_than_or_equal_to(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + self, + datum, + )) + } + + /// Creates an equal-to expression. For example, `a = 10`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").equal_to(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a = 10"); + /// ``` + pub fn equal_to(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new(PredicateOperator::Eq, self, datum)) + } + + /// Creates a not equal-to expression. For example, `a!= 10`. 
+ /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").not_equal_to(Datum::long(10)); + /// + /// assert_eq!(&format!("{expr}"), "a != 10"); + /// ``` + pub fn not_equal_to(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new(PredicateOperator::NotEq, self, datum)) + } + + /// Creates a start-with expression. For example, `a STARTS WITH "foo"`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").starts_with(Datum::string("foo")); + /// + /// assert_eq!(&format!("{expr}"), r#"a STARTS WITH "foo""#); + /// ``` + pub fn starts_with(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + self, + datum, + )) + } + + /// Creates a not start-with expression. For example, `a NOT STARTS WITH 'foo'`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// + /// let expr = Reference::new("a").not_starts_with(Datum::string("foo")); + /// + /// assert_eq!(&format!("{expr}"), r#"a NOT STARTS WITH "foo""#); + /// ``` + pub fn not_starts_with(self, datum: Datum) -> Predicate { + Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + self, + datum, + )) + } + + /// Creates an is-nan expression. For example, `a IS NAN`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_nan(); + /// + /// assert_eq!(&format!("{expr}"), "a IS NAN"); + /// ``` + pub fn is_nan(self) -> Predicate { + Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNan, self)) + } + + /// Creates an is-not-nan expression. For example, `a IS NOT NAN`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_not_nan(); + /// + /// assert_eq!(&format!("{expr}"), "a IS NOT NAN"); + /// ``` + pub fn is_not_nan(self) -> Predicate { + Predicate::Unary(UnaryExpression::new(PredicateOperator::NotNan, self)) + } + + /// Creates an is-null expression. For example, `a IS NULL`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_null(); + /// + /// assert_eq!(&format!("{expr}"), "a IS NULL"); + /// ``` + pub fn is_null(self) -> Predicate { + Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNull, self)) + } + + /// Creates an is-not-null expression. For example, `a IS NOT NULL`. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_not_null(); + /// + /// assert_eq!(&format!("{expr}"), "a IS NOT NULL"); + /// ``` + pub fn is_not_null(self) -> Predicate { + Predicate::Unary(UnaryExpression::new(PredicateOperator::NotNull, self)) + } + + /// Creates an is-in expression. For example, `a IS IN (5, 6)`. 
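+ /// The literals are collected into a hash set, so duplicates are removed and the
+ /// display order of the elements is not guaranteed (see the assertion below).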
+ /// + /// # Example + /// + /// ```rust + /// use fnv::FnvHashSet; + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); + /// + /// let as_string = format!("{expr}"); + /// assert!(&as_string == "a IN (5, 6)" || &as_string == "a IN (6, 5)"); + /// ``` + pub fn is_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::In, + self, + FnvHashSet::from_iter(literals), + )) + } + + /// Creates an is-not-in expression. For example, `a IS NOT IN (5, 6)`. + /// + /// # Example + /// + /// ```rust + /// use fnv::FnvHashSet; + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); + /// + /// let as_string = format!("{expr}"); + /// assert!(&as_string == "a NOT IN (5, 6)" || &as_string == "a NOT IN (6, 5)"); + /// ``` + pub fn is_not_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + self, + FnvHashSet::from_iter(literals), + )) + } +} + +impl Display for Reference { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl Bind for Reference { + type Bound = BoundReference; + + fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> crate::Result { + let field = if case_sensitive { + schema.field_by_name(&self.name) + } else { + schema.field_by_name_case_insensitive(&self.name) + }; + + let field = field.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field {} not found in schema", self.name), + ) + })?; + + let accessor = schema.accessor_by_field_id(field.id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Accessor for Field {} not found", self.name), + ) + })?; + + Ok(BoundReference::new( + self.name.clone(), + field.clone(), + accessor.clone(), + )) + } +} + +/// A named reference in a bound expression after binding to a schema. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BoundReference { + // This maybe different from [`name`] filed in [`NestedField`] since this contains full path. + // For example, if the field is `a.b.c`, then `field.name` is `c`, but `original_name` is `a.b.c`. + column_name: String, + field: NestedFieldRef, + accessor: StructAccessorRef, +} + +impl BoundReference { + /// Creates a new bound reference. + pub fn new( + name: impl Into, + field: NestedFieldRef, + accessor: StructAccessorRef, + ) -> Self { + Self { + column_name: name.into(), + field, + accessor, + } + } + + /// Return the field of this reference. + pub fn field(&self) -> &NestedField { + &self.field + } + + /// Get this BoundReference's Accessor + pub fn accessor(&self) -> &StructAccessor { + &self.accessor + } +} + +impl Display for BoundReference { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.column_name) + } +} + +/// Bound term after binding to a schema. 
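+/// Currently this is always a [`BoundReference`].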
+pub type BoundTerm = BoundReference; + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::expr::accessor::StructAccessor; + use crate::expr::{Bind, BoundReference, Reference}; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_bind_reference() { + let schema = table_schema_simple(); + let reference = Reference::new("bar").bind(schema, true).unwrap(); + + let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int)); + let expected_ref = BoundReference::new( + "bar", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + accessor_ref.clone(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_case_insensitive() { + let schema = table_schema_simple(); + let reference = Reference::new("BAR").bind(schema, false).unwrap(); + + let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int)); + let expected_ref = BoundReference::new( + "BAR", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + accessor_ref.clone(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_failure() { + let schema = table_schema_simple(); + let result = Reference::new("bar_not_eix").bind(schema, true); + + assert!(result.is_err()); + } + + #[test] + fn test_bind_reference_case_insensitive_failure() { + let schema = table_schema_simple(); + let result = Reference::new("bar_non_exist").bind(schema, false); + assert!(result.is_err()); + } +} diff --git a/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs new file mode 100644 index 000000000..0858d1dcf --- /dev/null +++ b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs @@ -0,0 +1,741 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fnv::FnvHashSet; + +use crate::expr::{BoundPredicate, BoundReference, PredicateOperator}; +use crate::spec::Datum; +use crate::Result; + +/// A visitor for [`BoundPredicate`]s. Visits in post-order. 
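+/// Implementors supply one callback per predicate operator; the `visit` function
+/// walks the predicate tree and combines the child results bottom-up.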
+pub trait BoundPredicateVisitor { + /// The return type of this visitor + type T; + + /// Called after an `AlwaysTrue` predicate is visited + fn always_true(&mut self) -> Result; + + /// Called after an `AlwaysFalse` predicate is visited + fn always_false(&mut self) -> Result; + + /// Called after an `And` predicate is visited + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + + /// Called after an `Or` predicate is visited + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + + /// Called after a `Not` predicate is visited + fn not(&mut self, inner: Self::T) -> Result; + + /// Called after a predicate with an `IsNull` operator is visited + fn is_null( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `NotNull` operator is visited + fn not_null( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with an `IsNan` operator is visited + fn is_nan(&mut self, reference: &BoundReference, predicate: &BoundPredicate) + -> Result; + + /// Called after a predicate with a `NotNan` operator is visited + fn not_nan( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `LessThan` operator is visited + fn less_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `LessThanOrEq` operator is visited + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `GreaterThan` operator is visited + fn greater_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `GreaterThanOrEq` operator is visited + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with an `Eq` operator is visited + fn eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `NotEq` operator is visited + fn not_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `StartsWith` operator is visited + fn starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `NotStartsWith` operator is visited + fn not_starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with an `In` operator is visited + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + predicate: &BoundPredicate, + ) -> Result; + + /// Called after a predicate with a `NotIn` operator is visited + fn not_in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + predicate: &BoundPredicate, + ) -> Result; +} + +/// Visits a [`BoundPredicate`] with the provided visitor, +/// in post-order +pub(crate) fn visit( + visitor: &mut V, + predicate: &BoundPredicate, +) -> Result { + match predicate { + BoundPredicate::AlwaysTrue => visitor.always_true(), + BoundPredicate::AlwaysFalse => visitor.always_false(), + 
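+ // Logical operators: visit both inputs first, then combine their results (post-order).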
BoundPredicate::And(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = visit(visitor, left_pred)?; + let right_result = visit(visitor, right_pred)?; + + visitor.and(left_result, right_result) + } + BoundPredicate::Or(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = visit(visitor, left_pred)?; + let right_result = visit(visitor, right_pred)?; + + visitor.or(left_result, right_result) + } + BoundPredicate::Not(expr) => { + let [inner_pred] = expr.inputs(); + + let inner_result = visit(visitor, inner_pred)?; + + visitor.not(inner_result) + } + BoundPredicate::Unary(expr) => match expr.op() { + PredicateOperator::IsNull => visitor.is_null(expr.term(), predicate), + PredicateOperator::NotNull => visitor.not_null(expr.term(), predicate), + PredicateOperator::IsNan => visitor.is_nan(expr.term(), predicate), + PredicateOperator::NotNan => visitor.not_nan(expr.term(), predicate), + op => { + panic!("Unexpected op for unary predicate: {}", &op) + } + }, + BoundPredicate::Binary(expr) => { + let reference = expr.term(); + let literal = expr.literal(); + match expr.op() { + PredicateOperator::LessThan => visitor.less_than(reference, literal, predicate), + PredicateOperator::LessThanOrEq => { + visitor.less_than_or_eq(reference, literal, predicate) + } + PredicateOperator::GreaterThan => { + visitor.greater_than(reference, literal, predicate) + } + PredicateOperator::GreaterThanOrEq => { + visitor.greater_than_or_eq(reference, literal, predicate) + } + PredicateOperator::Eq => visitor.eq(reference, literal, predicate), + PredicateOperator::NotEq => visitor.not_eq(reference, literal, predicate), + PredicateOperator::StartsWith => visitor.starts_with(reference, literal, predicate), + PredicateOperator::NotStartsWith => { + visitor.not_starts_with(reference, literal, predicate) + } + op => { + panic!("Unexpected op for binary predicate: {}", &op) + } + } + } + BoundPredicate::Set(expr) => { + let reference = expr.term(); + let literals = expr.literals(); + match expr.op() { + PredicateOperator::In => visitor.r#in(reference, literals, predicate), + PredicateOperator::NotIn => visitor.not_in(reference, literals, predicate), + op => { + panic!("Unexpected op for set predicate: {}", &op) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use fnv::FnvHashSet; + + use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; + use crate::expr::{ + BinaryExpression, Bind, BoundPredicate, BoundReference, Predicate, PredicateOperator, + Reference, SetExpression, UnaryExpression, + }; + use crate::spec::{Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type}; + + struct TestEvaluator {} + impl BoundPredicateVisitor for TestEvaluator { + type T = bool; + + fn always_true(&mut self) -> crate::Result { + Ok(true) + } + + fn always_false(&mut self) -> crate::Result { + Ok(false) + } + + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: Self::T) -> crate::Result { + Ok(!inner) + } + + fn is_null( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn not_null( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn is_nan( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> 
crate::Result { + Ok(true) + } + + fn not_nan( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn less_than( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn less_than_or_eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn greater_than( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn greater_than_or_eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn not_eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn starts_with( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn not_starts_with( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + + fn r#in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(true) + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(false) + } + } + + fn create_test_schema() -> SchemaRef { + let schema = Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::required( + 1, + "a", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::required( + 2, + "b", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 3, + "c", + Type::Primitive(PrimitiveType::Float), + )), + ]) + .build() + .unwrap(); + + let schema_arc = Arc::new(schema); + schema_arc.clone() + } + + #[test] + fn test_always_true() { + let predicate = Predicate::AlwaysTrue; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_always_false() { + let predicate = Predicate::AlwaysFalse; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_logical_and() { + let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysFalse.and(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut 
test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_logical_or() { + let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + + let predicate = Predicate::AlwaysFalse.or(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not() { + let predicate = Predicate::AlwaysFalse.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + + let predicate = Predicate::AlwaysTrue.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_is_null() { + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("c"), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not_null() { + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("a"), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_is_nan() { + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("b"), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not_nan() { + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("b"), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_less_than() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_less_than_or_eq() { + let predicate = 
Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_greater_than() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_greater_than_or_eq() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_eq() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not_eq() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_starts_with() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not_starts_with() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_in() { + let predicate = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("a"), + FnvHashSet::from_iter(vec![Datum::int(1)]), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not_in() { + let predicate = Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + Reference::new("a"), + FnvHashSet::from_iter(vec![Datum::int(1)]), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } 
+} diff --git a/crates/iceberg/src/expr/visitors/expression_evaluator.rs b/crates/iceberg/src/expr/visitors/expression_evaluator.rs new file mode 100644 index 000000000..8f3c2971c --- /dev/null +++ b/crates/iceberg/src/expr/visitors/expression_evaluator.rs @@ -0,0 +1,796 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fnv::FnvHashSet; + +use super::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::{BoundPredicate, BoundReference}; +use crate::spec::{DataFile, Datum, PrimitiveLiteral, Struct}; +use crate::{Error, ErrorKind, Result}; + +/// Evaluates a [`DataFile`]'s partition [`Struct`] to check +/// if the partition tuples match the given [`BoundPredicate`]. +/// +/// Use within [`TableScan`] to prune the list of [`DataFile`]s +/// that could potentially match the TableScan's filter. +#[derive(Debug)] +pub(crate) struct ExpressionEvaluator { + /// The provided partition filter. + partition_filter: BoundPredicate, +} + +impl ExpressionEvaluator { + /// Creates a new [`ExpressionEvaluator`]. + pub(crate) fn new(partition_filter: BoundPredicate) -> Self { + Self { partition_filter } + } + + /// Evaluate this [`ExpressionEvaluator`]'s partition filter against + /// the provided [`DataFile`]'s partition [`Struct`]. Used by [`TableScan`] + /// to see if this [`DataFile`] could possibly contain data that matches + /// the scan's filter. + pub(crate) fn eval(&self, data_file: &DataFile) -> Result { + let mut visitor = ExpressionEvaluatorVisitor::new(data_file.partition()); + + visit(&mut visitor, &self.partition_filter) + } +} + +/// Acts as a visitor for [`ExpressionEvaluator`] to apply +/// evaluation logic to different parts of a data structure, +/// specifically for data file partitions. +#[derive(Debug)] +struct ExpressionEvaluatorVisitor<'a> { + /// Reference to a [`DataFile`]'s partition [`Struct`]. + partition: &'a Struct, +} + +impl<'a> ExpressionEvaluatorVisitor<'a> { + /// Creates a new [`ExpressionEvaluatorVisitor`]. + fn new(partition: &'a Struct) -> Self { + Self { partition } + } +} + +impl BoundPredicateVisitor for ExpressionEvaluatorVisitor<'_> { + type T = bool; + + fn always_true(&mut self) -> Result { + Ok(true) + } + + fn always_false(&mut self) -> Result { + Ok(false) + } + + fn and(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs || rhs) + } + + fn not(&mut self, _inner: bool) -> Result { + Err(Error::new(ErrorKind::Unexpected, "The evaluation of expressions should not be performed against Predicates that contain a Not operator. 
Ensure that \"Rewrite Not\" gets applied to the originating Predicate before binding it.")) + } + + fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + match reference.accessor().get(self.partition)? { + Some(_) => Ok(false), + None => Ok(true), + } + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(_) => Ok(true), + None => Ok(false), + } + } + + fn is_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(datum.is_nan()), + None => Ok(false), + } + } + + fn not_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(!datum.is_nan()), + None => Ok(true), + } + } + + fn less_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum < literal), + None => Ok(false), + } + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum <= literal), + None => Ok(false), + } + } + + fn greater_than( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum > literal), + None => Ok(false), + } + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum >= literal), + None => Ok(false), + } + } + + fn eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum == literal), + None => Ok(false), + } + } + + fn not_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(&datum != literal), + None => Ok(true), + } + } + + fn starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let Some(datum) = reference.accessor().get(self.partition)? else { + return Ok(false); + }; + + match (datum.literal(), literal.literal()) { + (PrimitiveLiteral::String(d), PrimitiveLiteral::String(l)) => Ok(d.starts_with(l)), + _ => Ok(false), + } + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + literal: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + Ok(!self.starts_with(reference, literal, _predicate)?) + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? { + Some(datum) => Ok(literals.contains(&datum)), + None => Ok(false), + } + } + + fn not_in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + match reference.accessor().get(self.partition)? 
{ + Some(datum) => Ok(!literals.contains(&datum)), + None => Ok(true), + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use fnv::FnvHashSet; + use predicate::SetExpression; + + use super::ExpressionEvaluator; + use crate::expr::visitors::inclusive_projection::InclusiveProjection; + use crate::expr::{ + predicate, BinaryExpression, Bind, BoundPredicate, Predicate, PredicateOperator, Reference, + UnaryExpression, + }; + use crate::spec::{ + DataContentType, DataFile, DataFileFormat, Datum, Literal, NestedField, PartitionSpec, + PartitionSpecRef, PrimitiveType, Schema, SchemaRef, Struct, Transform, Type, + UnboundPartitionField, + }; + use crate::Result; + + fn create_schema_and_partition_spec( + r#type: PrimitiveType, + ) -> Result<(SchemaRef, PartitionSpecRef)> { + let schema = Schema::builder() + .with_fields(vec![Arc::new(NestedField::optional( + 1, + "a", + Type::Primitive(r#type), + ))]) + .build()?; + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_fields(vec![UnboundPartitionField::builder() + .source_id(1) + .name("a".to_string()) + .field_id(1) + .transform(Transform::Identity) + .build()]) + .unwrap() + .build() + .unwrap(); + + Ok((Arc::new(schema), Arc::new(spec))) + } + + fn create_partition_filter( + schema: &Schema, + partition_spec: PartitionSpecRef, + predicate: &BoundPredicate, + case_sensitive: bool, + ) -> Result { + let partition_type = partition_spec.partition_type(schema)?; + let partition_fields = partition_type.fields().to_owned(); + + let partition_schema = Schema::builder() + .with_schema_id(partition_spec.spec_id()) + .with_fields(partition_fields) + .build()?; + + let mut inclusive_projection = InclusiveProjection::new(partition_spec); + + let partition_filter = inclusive_projection + .project(predicate)? 
+ .rewrite_not() + .bind(Arc::new(partition_schema), case_sensitive)?; + + Ok(partition_filter) + } + + fn create_expression_evaluator( + schema: &Schema, + partition_spec: PartitionSpecRef, + predicate: &BoundPredicate, + case_sensitive: bool, + ) -> Result { + let partition_filter = + create_partition_filter(schema, partition_spec, predicate, case_sensitive)?; + + Ok(ExpressionEvaluator::new(partition_filter)) + } + + fn create_data_file_float() -> DataFile { + let partition = Struct::from_iter([Some(Literal::float(1.0))]); + + DataFile { + content: DataContentType::Data, + file_path: "/test/path".to_string(), + file_format: DataFileFormat::Parquet, + partition, + record_count: 1, + file_size_in_bytes: 1, + column_sizes: HashMap::new(), + value_counts: HashMap::new(), + null_value_counts: HashMap::new(), + nan_value_counts: HashMap::new(), + lower_bounds: HashMap::new(), + upper_bounds: HashMap::new(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + + fn create_data_file_string() -> DataFile { + let partition = Struct::from_iter([Some(Literal::string("test str"))]); + + DataFile { + content: DataContentType::Data, + file_path: "/test/path".to_string(), + file_format: DataFileFormat::Parquet, + partition, + record_count: 1, + file_size_in_bytes: 1, + column_sizes: HashMap::new(), + value_counts: HashMap::new(), + null_value_counts: HashMap::new(), + nan_value_counts: HashMap::new(), + lower_bounds: HashMap::new(), + upper_bounds: HashMap::new(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + + #[test] + fn test_expr_or() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("a"), + Datum::float(1.0), + )) + .or(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("a"), + Datum::float(0.4), + ))) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_and() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("a"), + Datum::float(1.1), + )) + .and(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("a"), + Datum::float(0.4), + ))) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_not_in() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + Reference::new("a"), + FnvHashSet::from_iter([Datum::float(0.9), Datum::float(1.2), Datum::float(2.4)]), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = 
+ create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_in() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("a"), + FnvHashSet::from_iter([Datum::float(1.0), Datum::float(1.2), Datum::float(2.4)]), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_not_starts_with() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::String)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("a"), + Datum::string("not"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_string(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_starts_with() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::String)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("a"), + Datum::string("test"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_string(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_not_eq() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + Reference::new("a"), + Datum::float(0.9), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_eq() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("a"), + Datum::float(1.0), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_greater_than_or_eq() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = 
Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("a"), + Datum::float(1.0), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_greater_than() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("a"), + Datum::float(0.9), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_less_than_or_eq() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + Reference::new("a"), + Datum::float(1.0), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_less_than() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("a"), + Datum::float(1.1), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_is_not_nan() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("a"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_is_nan() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("a"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(!result); + + Ok(()) + } + + #[test] + fn test_expr_is_not_null() -> Result<()> { + let case_sensitive = true; + let (schema, 
partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("a"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } + + #[test] + fn test_expr_is_null() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("a"), + )) + .bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(!result); + + Ok(()) + } + + #[test] + fn test_expr_always_false() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::AlwaysFalse.bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(!result); + + Ok(()) + } + + #[test] + fn test_expr_always_true() -> Result<()> { + let case_sensitive = true; + let (schema, partition_spec) = create_schema_and_partition_spec(PrimitiveType::Float)?; + let predicate = Predicate::AlwaysTrue.bind(schema.clone(), case_sensitive)?; + + let expression_evaluator = + create_expression_evaluator(&schema, partition_spec, &predicate, case_sensitive)?; + + let data_file = create_data_file_float(); + + let result = expression_evaluator.eval(&data_file)?; + + assert!(result); + + Ok(()) + } +} diff --git a/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs b/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs new file mode 100644 index 000000000..a2ee4722f --- /dev/null +++ b/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs @@ -0,0 +1,2159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
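The file added below, inclusive_metrics_evaluator.rs, complements the partition-based ExpressionEvaluator above: instead of partition values it consults the per-column metrics recorded on each DataFile (null/NaN counts and lower/upper bounds). As a rough sketch of how a scan might chain the two evaluators when pruning files; the prune_files name, the candidate_files list, and the two pre-bound filters are assumptions for illustration, not part of this change:

// Illustrative only. Assumes the relevant items are in scope, e.g.
// use crate::expr::visitors::expression_evaluator::ExpressionEvaluator;
// use crate::expr::visitors::inclusive_metrics_evaluator::InclusiveMetricsEvaluator;
// use crate::expr::BoundPredicate;
// use crate::spec::DataFile;
fn prune_files(
    partition_filter: BoundPredicate,
    metrics_filter: &BoundPredicate,
    candidate_files: Vec<DataFile>,
) -> crate::Result<Vec<DataFile>> {
    let expr_evaluator = ExpressionEvaluator::new(partition_filter);
    let mut kept = Vec::new();
    for file in candidate_files {
        // Drop files whose partition tuple cannot satisfy the filter ...
        if !expr_evaluator.eval(&file)? {
            continue;
        }
        // ... and files whose column metrics rule out any matching row.
        if !InclusiveMetricsEvaluator::eval(metrics_filter, &file, false)? {
            continue;
        }
        kept.push(file);
    }
    Ok(kept)
}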
+ +use fnv::FnvHashSet; + +use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::{BoundPredicate, BoundReference}; +use crate::spec::{DataFile, Datum, PrimitiveLiteral}; +use crate::{Error, ErrorKind}; + +const IN_PREDICATE_LIMIT: usize = 200; +const ROWS_MIGHT_MATCH: crate::Result<bool> = Ok(true); +const ROWS_CANNOT_MATCH: crate::Result<bool> = Ok(false); + +pub(crate) struct InclusiveMetricsEvaluator<'a> { + data_file: &'a DataFile, +} + +impl<'a> InclusiveMetricsEvaluator<'a> { + fn new(data_file: &'a DataFile) -> Self { + InclusiveMetricsEvaluator { data_file } + } + + /// Evaluate this `InclusiveMetricsEvaluator`'s filter predicate against the + /// provided [`DataFile`]'s metrics. Used by [`TableScan`] to + /// see if this `DataFile` contains data that could match + /// the scan's filter. + pub(crate) fn eval( + filter: &'a BoundPredicate, + data_file: &'a DataFile, + include_empty_files: bool, + ) -> crate::Result<bool> { + if !include_empty_files && data_file.record_count == 0 { + return ROWS_CANNOT_MATCH; + } + + let mut evaluator = Self::new(data_file); + visit(&mut evaluator, filter) + } + + fn nan_count(&self, field_id: i32) -> Option<&u64> { + self.data_file.nan_value_counts.get(&field_id) + } + + fn null_count(&self, field_id: i32) -> Option<&u64> { + self.data_file.null_value_counts.get(&field_id) + } + + fn value_count(&self, field_id: i32) -> Option<&u64> { + self.data_file.value_counts.get(&field_id) + } + + fn lower_bound(&self, field_id: i32) -> Option<&Datum> { + self.data_file.lower_bounds.get(&field_id) + } + + fn upper_bound(&self, field_id: i32) -> Option<&Datum> { + self.data_file.upper_bounds.get(&field_id) + } + + fn contains_nans_only(&self, field_id: i32) -> bool { + let nan_count = self.nan_count(field_id); + let value_count = self.value_count(field_id); + + nan_count.is_some() && nan_count == value_count + } + + fn contains_nulls_only(&self, field_id: i32) -> bool { + let null_count = self.null_count(field_id); + let value_count = self.value_count(field_id); + + null_count.is_some() && null_count == value_count + } + + fn may_contain_null(&self, field_id: i32) -> bool { + if let Some(&null_count) = self.null_count(field_id) { + null_count > 0 + } else { + true + } + } + + fn visit_inequality( + &mut self, + reference: &BoundReference, + datum: &Datum, + cmp_fn: fn(&Datum, &Datum) -> bool, + use_lower_bound: bool, + ) -> crate::Result<bool> { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) || self.contains_nans_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + if datum.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more.
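+ // A NaN comparison literal cannot be ordered against the file's min/max bounds, + // so no pruning decision is possible and the file is conservatively kept.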
+ return ROWS_MIGHT_MATCH; + } + + let bound = if use_lower_bound { + self.lower_bound(field_id) + } else { + self.upper_bound(field_id) + }; + + if let Some(bound) = bound { + if cmp_fn(bound, datum) { + return ROWS_MIGHT_MATCH; + } + + return ROWS_CANNOT_MATCH; + } + + ROWS_MIGHT_MATCH + } +} + +impl BoundPredicateVisitor for InclusiveMetricsEvaluator<'_> { + type T = bool; + + fn always_true(&mut self) -> crate::Result { + ROWS_MIGHT_MATCH + } + + fn always_false(&mut self) -> crate::Result { + ROWS_CANNOT_MATCH + } + + fn and(&mut self, lhs: bool, rhs: bool) -> crate::Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: bool, rhs: bool) -> crate::Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: bool) -> crate::Result { + Ok(!inner) + } + + fn is_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + match self.null_count(field_id) { + Some(&0) => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_MIGHT_MATCH, + } + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + ROWS_MIGHT_MATCH + } + + fn is_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + match self.nan_count(field_id) { + Some(&0) => ROWS_CANNOT_MATCH, + _ if self.contains_nulls_only(field_id) => ROWS_CANNOT_MATCH, + _ => ROWS_MIGHT_MATCH, + } + } + + fn not_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.contains_nans_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + ROWS_MIGHT_MATCH + } + + fn less_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + self.visit_inequality(reference, datum, PartialOrd::lt, true) + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + self.visit_inequality(reference, datum, PartialOrd::le, true) + } + + fn greater_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + self.visit_inequality(reference, datum, PartialOrd::gt, false) + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + self.visit_inequality(reference, datum, PartialOrd::ge, false) + } + + fn eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) || self.contains_nans_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + if let Some(lower_bound) = self.lower_bound(field_id) { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. + return ROWS_MIGHT_MATCH; + } else if lower_bound.gt(datum) { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = self.upper_bound(field_id) { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. 
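+ // A NaN upper bound means the recorded min/max statistics cannot be trusted for + // range checks, so the file is conservatively kept.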
+ return ROWS_MIGHT_MATCH; + } else if upper_bound.lt(datum) { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_eq( + &mut self, + _reference: &BoundReference, + _datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. notEq(col, X) with (X, Y) + // doesn't guarantee that X is a value in col. + ROWS_MIGHT_MATCH + } + + fn starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + let PrimitiveLiteral::String(datum) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + if let Some(lower_bound) = self.lower_bound(field_id) { + let PrimitiveLiteral::String(lower_bound) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string lower_bound value", + )); + }; + + let prefix_length = lower_bound.chars().count().min(datum.chars().count()); + + // truncate lower bound so that its length + // is not greater than the length of prefix + let truncated_lower_bound = lower_bound.chars().take(prefix_length).collect::(); + if datum < &truncated_lower_bound { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = self.upper_bound(field_id) { + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string upper_bound value", + )); + }; + + let prefix_length = upper_bound.chars().count().min(datum.chars().count()); + + // truncate upper bound so that its length + // is not greater than the length of prefix + let truncated_upper_bound = upper_bound.chars().take(prefix_length).collect::(); + if datum > &truncated_upper_bound { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.may_contain_null(field_id) { + return ROWS_MIGHT_MATCH; + } + + // notStartsWith will match unless all values must start with the prefix. + // This happens when the lower and upper bounds both start with the prefix. 
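+ // For example, with bounds ("apple", "apricot") and prefix "ap", every value in the + // file sorts between the bounds and therefore also starts with "ap", so no row can + // satisfy NotStartsWith and the file can be skipped.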
+ + let PrimitiveLiteral::String(prefix) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + let Some(lower_bound) = self.lower_bound(field_id) else { + return ROWS_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(lower_bound_str) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string lower_bound value", + )); + }; + + if lower_bound_str < prefix { + // if lower is shorter than the prefix then lower doesn't start with the prefix + return ROWS_MIGHT_MATCH; + } + + let prefix_len = prefix.chars().count(); + + if lower_bound_str.chars().take(prefix_len).collect::() == *prefix { + // lower bound matches the prefix + + let Some(upper_bound) = self.upper_bound(field_id) else { + return ROWS_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string upper_bound value", + )); + }; + + // if upper is shorter than the prefix then upper can't start with the prefix + if upper_bound.chars().count() < prefix_len { + return ROWS_MIGHT_MATCH; + } + + if upper_bound.chars().take(prefix_len).collect::() == *prefix { + // both bounds match the prefix, so all rows must match the + // prefix and therefore do not satisfy the predicate + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) || self.contains_nans_only(field_id) { + return ROWS_CANNOT_MATCH; + } + + if literals.len() > IN_PREDICATE_LIMIT { + // skip evaluating the predicate if the number of values is too big + return ROWS_MIGHT_MATCH; + } + + if let Some(lower_bound) = self.lower_bound(field_id) { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROWS_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.ge(lower_bound)) { + // if all values are less than lower bound, rows cannot match. + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = self.upper_bound(field_id) { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROWS_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.le(upper_bound)) { + // if all values are greater than upper bound, rows cannot match. + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. notIn(col, {X, ...}) + // with (X, Y) doesn't guarantee that X is a value in col. 
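+ // For example, with bounds (5, 10), NotIn(col, {5, 10}) may still match rows holding + // 6..=9, and the recorded bounds may themselves be values that never occur in the + // column, so the file always has to be read.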
+ ROWS_MIGHT_MATCH + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::ops::Not; + use std::sync::Arc; + + use fnv::FnvHashSet; + + use crate::expr::visitors::inclusive_metrics_evaluator::InclusiveMetricsEvaluator; + use crate::expr::PredicateOperator::{ + Eq, GreaterThan, GreaterThanOrEq, In, IsNan, IsNull, LessThan, LessThanOrEq, NotEq, NotIn, + NotNan, NotNull, NotStartsWith, StartsWith, + }; + use crate::expr::{ + BinaryExpression, Bind, BoundPredicate, Predicate, Reference, SetExpression, + UnaryExpression, + }; + use crate::spec::{ + DataContentType, DataFile, DataFileFormat, Datum, NestedField, PartitionSpec, + PrimitiveType, Schema, Struct, Transform, Type, UnboundPartitionField, + }; + + const INT_MIN_VALUE: i32 = 30; + const INT_MAX_VALUE: i32 = 79; + + #[test] + fn test_data_file_no_partitions() { + let (table_schema_ref, _partition_spec_ref) = create_test_schema_and_partition_spec(); + + let partition_filter = Predicate::AlwaysTrue + .bind(table_schema_ref.clone(), false) + .unwrap(); + + let case_sensitive = false; + + let data_file = create_test_data_file(); + + let result = + InclusiveMetricsEvaluator::eval(&partition_filter, &data_file, case_sensitive).unwrap(); + + assert!(result); + } + + #[test] + fn test_all_nulls() { + let result = + InclusiveMetricsEvaluator::eval(¬_null("all_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!(!result, "Should skip: no non-null value in all null column"); + + let result = + InclusiveMetricsEvaluator::eval(&less_than("all_nulls", "a"), &get_test_file_1(), true) + .unwrap(); + assert!(!result, "Should skip: LessThan on an all null column"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_or_equal("all_nulls", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: LessThanOrEqual on an all null column" + ); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than("all_nulls", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: GreaterThan on an all null column"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_or_equal("all_nulls", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: GreaterThanOrEqual on an all null column" + ); + + let result = + InclusiveMetricsEvaluator::eval(&equal("all_nulls", "a"), &get_test_file_1(), true) + .unwrap(); + assert!(!result, "Should skip: Equal on an all null column"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("all_nulls", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: StartsWith on an all null column"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("all_nulls", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: NotStartsWith on an all null column"); + + let result = + InclusiveMetricsEvaluator::eval(¬_null("some_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with some nulls could contain a non-null value" + ); + + let result = + InclusiveMetricsEvaluator::eval(¬_null("no_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with all nulls contains a non-null value" + ); + } + + #[test] + fn test_no_nulls() { + let result = + InclusiveMetricsEvaluator::eval(&is_null("all_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with all nulls contains a non-null value" + ); + + let result = + 
InclusiveMetricsEvaluator::eval(&is_null("some_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with some nulls could contain a non-null value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_null("no_nulls"), &get_test_file_1(), true) + .unwrap(); + assert!( + !result, + "Should skip: col with no nulls can't contains a non-null value" + ); + } + + #[test] + fn test_is_nan() { + let result = + InclusiveMetricsEvaluator::eval(&is_nan("all_nans"), &get_test_file_1(), true).unwrap(); + assert!( + result, + "Should read: col with all nans must contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("some_nans"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with some nans could contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("no_nans"), &get_test_file_1(), true).unwrap(); + assert!( + !result, + "Should skip: col with no nans can't contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("all_nulls_double"), &get_test_file_1(), true) + .unwrap(); + assert!( + !result, + "Should skip: col with no nans can't contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("no_nan_stats"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: no guarantee col is nan-free without nan stats" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("all_nans_v1_stats"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with all nans must contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(&is_nan("nan_and_null_only"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with nans and nulls must contain a nan value" + ); + } + + #[test] + fn test_not_nan() { + let result = + InclusiveMetricsEvaluator::eval(¬_nan("all_nans"), &get_test_file_1(), true) + .unwrap(); + assert!( + !result, + "Should read: col with all nans must contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(¬_nan("some_nans"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with some nans could contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(¬_nan("no_nans"), &get_test_file_1(), true).unwrap(); + assert!( + result, + "Should read: col with no nans might contains a non-nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(¬_nan("all_nulls_double"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: col with no nans can't contains a nan value" + ); + + let result = + InclusiveMetricsEvaluator::eval(¬_nan("no_nan_stats"), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: no guarantee col is nan-free without nan stats" + ); + + let result = InclusiveMetricsEvaluator::eval( + ¬_nan("all_nans_v1_stats"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: col with all nans must contains a nan value" + ); + + let result = InclusiveMetricsEvaluator::eval( + ¬_nan("nan_and_null_only"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: col with nans and nulls may contain a non-nan value" + ); + } + + #[test] + fn test_required_column() { + let result = + InclusiveMetricsEvaluator::eval(¬_null("required"), &get_test_file_1(), true) + .unwrap(); + assert!(result, "Should read: required columns are always non-null"); + + 
let result = + InclusiveMetricsEvaluator::eval(&is_null("required"), &get_test_file_1(), true) + .unwrap(); + assert!(!result, "Should skip: required columns are always non-null"); + } + + #[test] + #[should_panic] + fn test_missing_column() { + let _result = + InclusiveMetricsEvaluator::eval(&less_than("missing", "a"), &get_test_file_1(), true); + } + + #[test] + fn test_missing_stats() { + let missing_stats_datafile = create_test_data_file(); + + let expressions = [ + less_than_int("no_stats", 5), + less_than_or_equal_int("no_stats", 30), + equal_int("no_stats", 70), + greater_than_int("no_stats", 78), + greater_than_or_equal_int("no_stats", 90), + not_equal_int("no_stats", 101), + is_null("no_stats"), + not_null("no_stats"), + // is_nan("no_stats"), + // not_nan("no_stats"), + ]; + + for expression in expressions { + let result = + InclusiveMetricsEvaluator::eval(&expression, &missing_stats_datafile, true) + .unwrap(); + + assert!( + result, + "Should read if stats are missing for {:?}", + &expression + ); + } + } + + #[test] + fn test_zero_record_file() { + let zero_records_datafile = create_zero_records_data_file(); + + let expressions = [ + less_than_int("no_stats", 5), + less_than_or_equal_int("no_stats", 30), + equal_int("no_stats", 70), + greater_than_int("no_stats", 78), + greater_than_or_equal_int("no_stats", 90), + not_equal_int("no_stats", 101), + is_null("no_stats"), + not_null("no_stats"), + // is_nan("no_stats"), + // not_nan("no_stats"), + ]; + + for expression in expressions { + let result = + InclusiveMetricsEvaluator::eval(&expression, &zero_records_datafile, true).unwrap(); + + assert!( + result, + "Should skip if data file has zero records (expression: {:?})", + &expression + ); + } + } + + #[test] + fn test_not() { + // Not sure if we need a test for this, as we'd expect, + // as a precondition, that rewrite-not has already been applied. 
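+ // Unlike ExpressionEvaluatorVisitor, which returns an error when it meets a Not + // predicate, InclusiveMetricsEvaluator implements `not` by negating the inner result, + // so these cases exercise that path directly rather than relying on rewrite_not().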
+ + let result = InclusiveMetricsEvaluator::eval( + ¬_less_than_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: not(false)"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_greater_than_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: not(true)"); + } + + #[test] + fn test_and() { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .and(Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 30), + ))); + + let bound_pred = filter.bind(schema.clone(), true).unwrap(); + + let result = + InclusiveMetricsEvaluator::eval(&bound_pred, &get_test_file_1(), true).unwrap(); + assert!(!result, "Should skip: and(false, true)"); + + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .and(Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 1), + ))); + + let bound_pred = filter.bind(schema.clone(), true).unwrap(); + + let result = + InclusiveMetricsEvaluator::eval(&bound_pred, &get_test_file_1(), true).unwrap(); + assert!(!result, "Should skip: and(false, false)"); + + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .and(Predicate::Binary(BinaryExpression::new( + LessThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE), + ))); + + let bound_pred = filter.bind(schema.clone(), true).unwrap(); + + let result = + InclusiveMetricsEvaluator::eval(&bound_pred, &get_test_file_1(), true).unwrap(); + assert!(result, "Should read: and(true, true)"); + } + + #[test] + fn test_or() { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .or(Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 30), + ))); + + let bound_pred = filter.bind(schema.clone(), true).unwrap(); + + let result = + InclusiveMetricsEvaluator::eval(&bound_pred, &get_test_file_1(), true).unwrap(); + assert!(result, "Should read: or(false, true)"); + + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .or(Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 1), + ))); + + let bound_pred = filter.bind(schema.clone(), true).unwrap(); + + let result = + InclusiveMetricsEvaluator::eval(&bound_pred, &get_test_file_1(), true).unwrap(); + assert!(!result, "Should skip: or(false, false)"); + } + + #[test] + fn test_integer_lt() { + let result = InclusiveMetricsEvaluator::eval( + &less_than_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range below lower bound (5 < 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_int("id", INT_MIN_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id range below lower bound (30 is not < 30)" + ); + + let result = 
InclusiveMetricsEvaluator::eval( + &less_than_int("id", INT_MIN_VALUE + 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: one possible id"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: many possible ids"); + } + + #[test] + fn test_integer_lt_eq() { + let result = InclusiveMetricsEvaluator::eval( + &less_than_or_equal_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range below lower bound (5 < 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_or_equal_int("id", INT_MIN_VALUE - 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range below lower bound (29 < 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_or_equal_int("id", INT_MIN_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: one possible id"); + + let result = InclusiveMetricsEvaluator::eval( + &less_than_or_equal_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: many possible ids"); + } + + #[test] + fn test_integer_gt() { + let result = InclusiveMetricsEvaluator::eval( + &greater_than_int("id", INT_MAX_VALUE + 6), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range above upper bound (85 > 79)"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id range above upper bound (79 is not > 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_int("id", INT_MAX_VALUE - 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: one possible id"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_int("id", INT_MAX_VALUE - 4), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: many possible ids"); + } + + #[test] + fn test_integer_gt_eq() { + let result = InclusiveMetricsEvaluator::eval( + &greater_than_or_equal_int("id", INT_MAX_VALUE + 6), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range above upper bound (85 < 79)"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_or_equal_int("id", INT_MAX_VALUE + 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id range above upper bound (80 > 79)"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_or_equal_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: one possible id"); + + let result = InclusiveMetricsEvaluator::eval( + &greater_than_or_equal_int("id", INT_MAX_VALUE - 4), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: many possible ids"); + } + + #[test] + fn test_integer_eq() { + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id below lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MIN_VALUE - 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id below lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MIN_VALUE), + &get_test_file_1(), + true, + 
) + .unwrap(); + assert!(result, "Should read: id equal to lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MAX_VALUE - 4), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id between lower and upper bounds"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to upper bound"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MAX_VALUE + 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id above upper bound"); + + let result = InclusiveMetricsEvaluator::eval( + &equal_int("id", INT_MAX_VALUE + 6), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: id above upper bound"); + } + + #[test] + fn test_integer_not_eq() { + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MIN_VALUE - 25), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id below lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MIN_VALUE - 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id below lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MIN_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to lower bound"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MAX_VALUE - 4), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id between lower and upper bounds"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MAX_VALUE), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to upper bound"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MAX_VALUE + 1), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id above upper bound"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_equal_int("id", INT_MAX_VALUE + 6), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id above upper bound"); + } + + #[test] + #[should_panic] + fn test_case_sensitive_integer_not_eq_rewritten() { + let _result = + InclusiveMetricsEvaluator::eval(&equal_int_not("ID", 5), &get_test_file_1(), true) + .unwrap(); + } + + #[test] + fn test_string_starts_with() { + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: no stats"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "a"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "aa"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "aaa"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "1s"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "1str1x"), + &get_test_file_3(), + true, + ) + .unwrap(); + 
assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "ff"), + &get_test_file_4(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "aB"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "dWX"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "5"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", "3str3x"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("some_empty", "房东整租霍"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: range does matches"); + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("all_nulls", ""), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + + // Note: This string has been created manually by taking + // the string "イロハニホヘト", which is an upper bound in + // the datafile returned by get_test_file_4(), truncating it + // to four character, and then appending the "ボ" character, + // which occupies the next code point after the 5th + // character in the string above, "ホ". + // In the Java implementation of Iceberg, this is done by + // the `truncateStringMax` function, but we don't yet have + // this implemented in iceberg-rust. 
+ let above_max = "イロハニボ"; + + let result = InclusiveMetricsEvaluator::eval( + &starts_with("required", above_max), + &get_test_file_4(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: range does not match"); + } + + #[test] + fn test_string_not_starts_with() { + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "a"), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: no stats"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "a"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "aa"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "aaa"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "1s"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "1str1x"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "ff"), + &get_test_file_4(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "aB"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "dWX"), + &get_test_file_2(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "5"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", "3str3x"), + &get_test_file_3(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + + let above_max = "イロハニホヘト"; + let result = InclusiveMetricsEvaluator::eval( + ¬_starts_with("required", above_max), + &get_test_file_4(), + true, + ) + .unwrap(); + assert!(result, "Should read: range matches"); + } + + #[test] + fn test_integer_in() { + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MIN_VALUE - 25, INT_MIN_VALUE - 24]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id below lower bound (5 < 30, 6 < 30)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MIN_VALUE - 2, INT_MIN_VALUE - 1]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id below lower bound (28 < 30, 29 < 30)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MIN_VALUE - 1, INT_MIN_VALUE]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to lower bound (30 == 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MAX_VALUE - 4, INT_MAX_VALUE - 3]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", 
&[INT_MAX_VALUE, INT_MAX_VALUE + 1]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to upper bound (79 == 79)"); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MAX_VALUE + 1, INT_MAX_VALUE + 2]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id above upper bound (80 > 79, 81 > 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_int("id", &[INT_MAX_VALUE + 6, INT_MAX_VALUE + 7]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + !result, + "Should skip: id above upper bound (85 > 79, 86 > 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_str("all_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(!result, "Should skip: in on all nulls column"); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_str("some_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: in on some nulls column"); + + let result = InclusiveMetricsEvaluator::eval( + &r#in_str("no_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: in on no nulls column"); + + let ids = (-400..=0).collect::>(); + let result = + InclusiveMetricsEvaluator::eval(&r#in_int("id", &ids), &get_test_file_1(), true) + .unwrap(); + assert!( + result, + "Should read: number of items in In expression greater than threshold" + ); + } + + #[test] + fn test_integer_not_in() { + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MIN_VALUE - 25, INT_MIN_VALUE - 24]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id below lower bound (5 < 30, 6 < 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MIN_VALUE - 2, INT_MIN_VALUE - 1]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: id below lower bound (28 < 30, 29 < 30)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MIN_VALUE - 1, INT_MIN_VALUE]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to lower bound (30 == 30)"); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MAX_VALUE - 4, INT_MAX_VALUE - 3]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MAX_VALUE, INT_MAX_VALUE + 1]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: id equal to upper bound (79 == 79)"); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MAX_VALUE + 1, INT_MAX_VALUE + 2]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: id above upper bound (80 > 79, 81 > 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_int("id", &[INT_MAX_VALUE + 6, INT_MAX_VALUE + 7]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!( + result, + "Should read: id above upper bound (85 > 79, 86 > 79)" + ); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_str("all_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: NotIn on all nulls column"); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_str("some_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + 
.unwrap(); + assert!(result, "Should read: NotIn on some nulls column"); + + let result = InclusiveMetricsEvaluator::eval( + &r#not_in_str("no_nulls", &["abc", "def"]), + &get_test_file_1(), + true, + ) + .unwrap(); + assert!(result, "Should read: NotIn on no nulls column"); + } + + fn create_test_schema_and_partition_spec() -> (Arc, Arc) { + let table_schema = Schema::builder() + .with_fields(vec![Arc::new(NestedField::optional( + 1, + "a", + Type::Primitive(PrimitiveType::Float), + ))]) + .build() + .unwrap(); + let table_schema_ref = Arc::new(table_schema); + + let partition_spec = PartitionSpec::builder(&table_schema_ref) + .with_spec_id(1) + .add_unbound_fields(vec![UnboundPartitionField::builder() + .source_id(1) + .name("a".to_string()) + .field_id(1) + .transform(Transform::Identity) + .build()]) + .unwrap() + .build() + .unwrap(); + let partition_spec_ref = Arc::new(partition_spec); + (table_schema_ref, partition_spec_ref) + } + + fn not_null(reference: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Unary(UnaryExpression::new(NotNull, Reference::new(reference))); + filter.bind(schema.clone(), true).unwrap() + } + + fn is_null(reference: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Unary(UnaryExpression::new(IsNull, Reference::new(reference))); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_nan(reference: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Unary(UnaryExpression::new(NotNan, Reference::new(reference))); + filter.bind(schema.clone(), true).unwrap() + } + + fn is_nan(reference: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Unary(UnaryExpression::new(IsNan, Reference::new(reference))); + filter.bind(schema.clone(), true).unwrap() + } + + fn less_than(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn less_than_or_equal(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThanOrEq, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn greater_than(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThan, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn greater_than_or_equal(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn equal(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + Eq, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn less_than_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + 
Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_less_than_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThan, + Reference::new(reference), + Datum::int(int_literal), + )) + .not(); + filter.bind(schema.clone(), true).unwrap() + } + + fn less_than_or_equal_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + LessThanOrEq, + Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn greater_than_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThan, + Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_greater_than_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThan, + Reference::new(reference), + Datum::int(int_literal), + )) + .not(); + filter.bind(schema.clone(), true).unwrap() + } + + fn greater_than_or_equal_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + GreaterThanOrEq, + Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn equal_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + Eq, + Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn equal_int_not(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + Eq, + Reference::new(reference), + Datum::int(int_literal), + )) + .not(); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_equal_int(reference: &str, int_literal: i32) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + NotEq, + Reference::new(reference), + Datum::int(int_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn starts_with(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + StartsWith, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_starts_with(reference: &str, str_literal: &str) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Binary(BinaryExpression::new( + NotStartsWith, + Reference::new(reference), + Datum::string(str_literal), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn in_int(reference: &str, int_literals: &[i32]) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Set(SetExpression::new( + In, + Reference::new(reference), + FnvHashSet::from_iter(int_literals.iter().map(|&lit| Datum::int(lit))), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn in_str(reference: &str, str_literals: &[&str]) -> BoundPredicate { + let schema = create_test_schema(); + let filter = 
Predicate::Set(SetExpression::new( + In, + Reference::new(reference), + FnvHashSet::from_iter(str_literals.iter().map(Datum::string)), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_in_int(reference: &str, int_literals: &[i32]) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Set(SetExpression::new( + NotIn, + Reference::new(reference), + FnvHashSet::from_iter(int_literals.iter().map(|&lit| Datum::int(lit))), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn not_in_str(reference: &str, str_literals: &[&str]) -> BoundPredicate { + let schema = create_test_schema(); + let filter = Predicate::Set(SetExpression::new( + NotIn, + Reference::new(reference), + FnvHashSet::from_iter(str_literals.iter().map(Datum::string)), + )); + filter.bind(schema.clone(), true).unwrap() + } + + fn create_test_schema() -> Arc { + let table_schema = Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::required( + 1, + "id", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::optional( + 2, + "no_stats", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::required( + 3, + "required", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 4, + "all_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 5, + "some_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 6, + "no_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 7, + "all_nans", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 8, + "some_nans", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 9, + "no_nans", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 10, + "all_nulls_double", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 11, + "all_nans_v1_stats", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 12, + "nan_and_null_only", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 13, + "no_nan_stats", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 14, + "some_empty", + Type::Primitive(PrimitiveType::String), + )), + ]) + .build() + .unwrap(); + + Arc::new(table_schema) + } + + fn create_test_data_file() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "/test/path".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 10, + file_size_in_bytes: 10, + column_sizes: Default::default(), + value_counts: Default::default(), + null_value_counts: Default::default(), + nan_value_counts: Default::default(), + lower_bounds: Default::default(), + upper_bounds: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + + fn create_zero_records_data_file() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "/test/path".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 0, + file_size_in_bytes: 10, + column_sizes: Default::default(), + value_counts: Default::default(), + null_value_counts: Default::default(), + nan_value_counts: Default::default(), + lower_bounds: Default::default(), + upper_bounds: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + 
equality_ids: vec![], + sort_order_id: None, + } + } + + fn get_test_file_1() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "/test/path".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 50, + file_size_in_bytes: 10, + + value_counts: HashMap::from([ + (4, 50), + (5, 50), + (6, 50), + (7, 50), + (8, 50), + (9, 50), + (10, 50), + (11, 50), + (12, 50), + (13, 50), + (14, 50), + ]), + + null_value_counts: HashMap::from([ + (4, 50), + (5, 10), + (6, 0), + (10, 50), + (11, 0), + (12, 1), + (14, 0), + ]), + + nan_value_counts: HashMap::from([(7, 50), (8, 10), (9, 0)]), + + lower_bounds: HashMap::from([ + (1, Datum::int(INT_MIN_VALUE)), + (11, Datum::float(f32::NAN)), + (12, Datum::double(f64::NAN)), + (14, Datum::string("")), + ]), + + upper_bounds: HashMap::from([ + (1, Datum::int(INT_MAX_VALUE)), + (11, Datum::float(f32::NAN)), + (12, Datum::double(f64::NAN)), + (14, Datum::string("房东整租霍营小区二层两居室")), + ]), + + column_sizes: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + fn get_test_file_2() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "file_2.avro".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 50, + file_size_in_bytes: 10, + + value_counts: HashMap::from([(3, 20)]), + + null_value_counts: HashMap::from([(3, 2)]), + + nan_value_counts: HashMap::default(), + + lower_bounds: HashMap::from([(3, Datum::string("aa"))]), + + upper_bounds: HashMap::from([(3, Datum::string("dC"))]), + + column_sizes: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + + fn get_test_file_3() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "file_3.avro".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 50, + file_size_in_bytes: 10, + + value_counts: HashMap::from([(3, 20)]), + + null_value_counts: HashMap::from([(3, 2)]), + + nan_value_counts: HashMap::default(), + + lower_bounds: HashMap::from([(3, Datum::string("1str1"))]), + + upper_bounds: HashMap::from([(3, Datum::string("3str3"))]), + + column_sizes: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } + + fn get_test_file_4() -> DataFile { + DataFile { + content: DataContentType::Data, + file_path: "file_4.avro".to_string(), + file_format: DataFileFormat::Parquet, + partition: Struct::empty(), + record_count: 50, + file_size_in_bytes: 10, + + value_counts: HashMap::from([(3, 20)]), + + null_value_counts: HashMap::from([(3, 2)]), + + nan_value_counts: HashMap::default(), + + lower_bounds: HashMap::from([(3, Datum::string("abc"))]), + + upper_bounds: HashMap::from([(3, Datum::string("イロハニホヘト"))]), + + column_sizes: Default::default(), + key_metadata: vec![], + split_offsets: vec![], + equality_ids: vec![], + sort_order_id: None, + } + } +} diff --git a/crates/iceberg/src/expr/visitors/inclusive_projection.rs b/crates/iceberg/src/expr/visitors/inclusive_projection.rs new file mode 100644 index 000000000..2087207ea --- /dev/null +++ b/crates/iceberg/src/expr/visitors/inclusive_projection.rs @@ -0,0 +1,455 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use fnv::FnvHashSet; + +use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::{BoundPredicate, BoundReference, Predicate}; +use crate::spec::{Datum, PartitionField, PartitionSpecRef}; +use crate::Error; + +pub(crate) struct InclusiveProjection { + partition_spec: PartitionSpecRef, + cached_parts: HashMap>, +} + +impl InclusiveProjection { + pub(crate) fn new(partition_spec: PartitionSpecRef) -> Self { + Self { + partition_spec, + cached_parts: HashMap::new(), + } + } + + fn get_parts_for_field_id(&mut self, field_id: i32) -> &Vec { + if let std::collections::hash_map::Entry::Vacant(e) = self.cached_parts.entry(field_id) { + let mut parts: Vec = vec![]; + for partition_spec_field in self.partition_spec.fields() { + if partition_spec_field.source_id == field_id { + parts.push(partition_spec_field.clone()) + } + } + + e.insert(parts); + } + + &self.cached_parts[&field_id] + } + + pub(crate) fn project(&mut self, predicate: &BoundPredicate) -> crate::Result { + visit(self, predicate) + } + + fn get_parts( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + // This could be made a bit neater if `try_reduce` ever becomes stable + self.get_parts_for_field_id(field_id) + .iter() + .try_fold(Predicate::AlwaysTrue, |res, part| { + Ok( + if let Some(pred_for_part) = part.transform.project(&part.name, predicate)? { + if res == Predicate::AlwaysTrue { + pred_for_part + } else { + res.and(pred_for_part) + } + } else { + res + }, + ) + }) + } +} + +impl BoundPredicateVisitor for InclusiveProjection { + type T = Predicate; + + fn always_true(&mut self) -> crate::Result { + Ok(Predicate::AlwaysTrue) + } + + fn always_false(&mut self) -> crate::Result { + Ok(Predicate::AlwaysFalse) + } + + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs.and(rhs)) + } + + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs.or(rhs)) + } + + fn not(&mut self, _inner: Self::T) -> crate::Result { + panic!("InclusiveProjection should not be performed against Predicates that contain a Not operator. 
Ensure that \"Rewrite Not\" gets applied to the originating Predicate before binding it.") + } + + fn is_null( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn not_null( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn is_nan( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn not_nan( + &mut self, + reference: &BoundReference, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn less_than( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn greater_than( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn not_eq( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn starts_with( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + _literal: &Datum, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn r#in( + &mut self, + reference: &BoundReference, + _literals: &FnvHashSet, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } + + fn not_in( + &mut self, + reference: &BoundReference, + _literals: &FnvHashSet, + predicate: &BoundPredicate, + ) -> crate::Result { + self.get_parts(reference, predicate) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::expr::visitors::inclusive_projection::InclusiveProjection; + use crate::expr::{Bind, Predicate, Reference}; + use crate::spec::{ + Datum, NestedField, PartitionField, PartitionSpec, PrimitiveType, Schema, Transform, Type, + UnboundPartitionField, + }; + + fn build_test_schema() -> Schema { + Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::required( + 1, + "a", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::required( + 2, + "date", + Type::Primitive(PrimitiveType::Date), + )), + Arc::new(NestedField::required( + 3, + "name", + Type::Primitive(PrimitiveType::String), + )), + ]) + .build() + .unwrap() + } + + #[test] + fn test_inclusive_projection_logic_ops() { + let schema = build_test_schema(); + + let partition_spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .build() + .unwrap(); + + let arc_schema = Arc::new(schema); + let arc_partition_spec = Arc::new(partition_spec); + + // this predicate contains only logic 
operators, + // AlwaysTrue, and AlwaysFalse. + let unbound_predicate = Predicate::AlwaysTrue + .and(Predicate::AlwaysFalse) + .or(Predicate::AlwaysTrue); + + let bound_predicate = unbound_predicate.bind(arc_schema.clone(), false).unwrap(); + + // applying InclusiveProjection to bound_predicate + // should result in the same Predicate as the original + // `unbound_predicate`, since `InclusiveProjection` + // simply unbinds logic ops, AlwaysTrue, and AlwaysFalse. + let mut inclusive_projection = InclusiveProjection::new(arc_partition_spec.clone()); + let result = inclusive_projection.project(&bound_predicate).unwrap(); + + assert_eq!(result.to_string(), "TRUE".to_string()) + } + + #[test] + fn test_inclusive_projection_identity_transform() { + let schema = build_test_schema(); + + let partition_spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field( + UnboundPartitionField::builder() + .source_id(1) + .name("a".to_string()) + .field_id(1) + .transform(Transform::Identity) + .build(), + ) + .unwrap() + .build() + .unwrap(); + + let arc_schema = Arc::new(schema); + let arc_partition_spec = Arc::new(partition_spec); + + let unbound_predicate = Reference::new("a").less_than(Datum::int(10)); + + let bound_predicate = unbound_predicate.bind(arc_schema.clone(), false).unwrap(); + + // applying InclusiveProjection to bound_predicate + // should result in the same Predicate as the original + // `unbound_predicate`, since we have just a single partition field, + // and it has an Identity transform + let mut inclusive_projection = InclusiveProjection::new(arc_partition_spec.clone()); + let result = inclusive_projection.project(&bound_predicate).unwrap(); + + let expected = "a < 10".to_string(); + + assert_eq!(result.to_string(), expected) + } + + #[test] + fn test_inclusive_projection_date_transforms() { + let schema = build_test_schema(); + + let partition_spec = PartitionSpec { + spec_id: 1, + fields: vec![ + PartitionField { + source_id: 2, + name: "year".to_string(), + field_id: 1000, + transform: Transform::Year, + }, + PartitionField { + source_id: 2, + name: "month".to_string(), + field_id: 1001, + transform: Transform::Month, + }, + PartitionField { + source_id: 2, + name: "day".to_string(), + field_id: 1002, + transform: Transform::Day, + }, + ], + }; + + let arc_schema = Arc::new(schema); + let arc_partition_spec = Arc::new(partition_spec); + + let unbound_predicate = + Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap()); + + let bound_predicate = unbound_predicate.bind(arc_schema.clone(), false).unwrap(); + + // applying InclusiveProjection to bound_predicate + // should result in a predicate that correctly handles + // year, month and date + let mut inclusive_projection = InclusiveProjection::new(arc_partition_spec.clone()); + let result = inclusive_projection.project(&bound_predicate).unwrap(); + + let expected = "((year <= 53) AND (month <= 647)) AND (day <= 19722)".to_string(); + + assert_eq!(result.to_string(), expected); + } + + #[test] + fn test_inclusive_projection_truncate_transform() { + let schema = build_test_schema(); + + let partition_spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field( + UnboundPartitionField::builder() + .source_id(3) + .name("name_truncate".to_string()) + .field_id(3) + .transform(Transform::Truncate(4)) + .build(), + ) + .unwrap() + .build() + .unwrap(); + + let arc_schema = Arc::new(schema); + let arc_partition_spec = Arc::new(partition_spec); + + let unbound_predicate = 
Reference::new("name").starts_with(Datum::string("Testy McTest")); + + let bound_predicate = unbound_predicate.bind(arc_schema.clone(), false).unwrap(); + + // applying InclusiveProjection to bound_predicate + // should result in the 'name STARTS WITH "Testy McTest"' + // predicate being transformed to 'name_truncate STARTS WITH "Test"', + // since a `Truncate(4)` partition will map values of + // name that start with "Testy McTest" into a partition + // for values of name that start with the first four letters + // of that, ie "Test". + let mut inclusive_projection = InclusiveProjection::new(arc_partition_spec.clone()); + let result = inclusive_projection.project(&bound_predicate).unwrap(); + + let expected = "name_truncate STARTS WITH \"Test\"".to_string(); + + assert_eq!(result.to_string(), expected) + } + + #[test] + fn test_inclusive_projection_bucket_transform() { + let schema = build_test_schema(); + + let partition_spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field( + UnboundPartitionField::builder() + .source_id(1) + .name("a_bucket[7]".to_string()) + .field_id(1) + .transform(Transform::Bucket(7)) + .build(), + ) + .unwrap() + .build() + .unwrap(); + + let arc_schema = Arc::new(schema); + let arc_partition_spec = Arc::new(partition_spec); + + let unbound_predicate = Reference::new("a").equal_to(Datum::int(10)); + + let bound_predicate = unbound_predicate.bind(arc_schema.clone(), false).unwrap(); + + // applying InclusiveProjection to bound_predicate + // should result in the "a = 10" predicate being + // transformed into "a = 2", since 10 gets bucketed + // to 2 with a Bucket(7) partition + let mut inclusive_projection = InclusiveProjection::new(arc_partition_spec.clone()); + let result = inclusive_projection.project(&bound_predicate).unwrap(); + + let expected = "a_bucket[7] = 2".to_string(); + + assert_eq!(result.to_string(), expected) + } +} diff --git a/crates/iceberg/src/expr/visitors/manifest_evaluator.rs b/crates/iceberg/src/expr/visitors/manifest_evaluator.rs new file mode 100644 index 000000000..3554d57a0 --- /dev/null +++ b/crates/iceberg/src/expr/visitors/manifest_evaluator.rs @@ -0,0 +1,1293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fnv::FnvHashSet; + +use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; +use crate::expr::{BoundPredicate, BoundReference}; +use crate::spec::{Datum, FieldSummary, ManifestFile, PrimitiveLiteral, Type}; +use crate::{Error, ErrorKind, Result}; + +/// Evaluates a [`ManifestFile`] to see if the partition summaries +/// match a provided [`BoundPredicate`]. +/// +/// Used by [`TableScan`] to prune the list of [`ManifestFile`]s +/// in which data might be found that matches the TableScan's filter. 
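// Editor's note (illustrative sketch, not part of this change): the evaluator
// defined below is meant to be driven from scan planning. Assuming the caller
// already holds a `partition_filter: BoundPredicate` projected onto the
// partition data and a hypothetical `manifest_files: Vec<ManifestFile>` taken
// from the snapshot's manifest list, pruning would look roughly like:
//
//     let evaluator = ManifestEvaluator::new(partition_filter);
//     manifest_files.retain(|mf| evaluator.eval(mf).unwrap_or(true));
//
// Only `ManifestEvaluator::new` and `eval` below are real; the surrounding
// variables are assumptions for illustration.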
+#[derive(Debug)] +pub(crate) struct ManifestEvaluator { + partition_filter: BoundPredicate, +} + +impl ManifestEvaluator { + pub(crate) fn new(partition_filter: BoundPredicate) -> Self { + Self { partition_filter } + } + + /// Evaluate this `ManifestEvaluator`'s filter predicate against the + /// provided [`ManifestFile`]'s partitions. Used by [`TableScan`] to + /// see if this `ManifestFile` could possibly contain data that matches + /// the scan's filter. + pub(crate) fn eval(&self, manifest_file: &ManifestFile) -> Result { + if manifest_file.partitions.is_empty() { + return Ok(true); + } + + let mut evaluator = ManifestFilterVisitor::new(&manifest_file.partitions); + + visit(&mut evaluator, &self.partition_filter) + } +} + +struct ManifestFilterVisitor<'a> { + partitions: &'a Vec, +} + +impl<'a> ManifestFilterVisitor<'a> { + fn new(partitions: &'a Vec) -> Self { + ManifestFilterVisitor { partitions } + } +} + +const ROWS_MIGHT_MATCH: Result = Ok(true); +const ROWS_CANNOT_MATCH: Result = Ok(false); +const IN_PREDICATE_LIMIT: usize = 200; + +impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { + type T = bool; + + fn always_true(&mut self) -> crate::Result { + ROWS_MIGHT_MATCH + } + + fn always_false(&mut self) -> crate::Result { + ROWS_CANNOT_MATCH + } + + fn and(&mut self, lhs: bool, rhs: bool) -> crate::Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: bool, rhs: bool) -> crate::Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: bool) -> crate::Result { + Ok(!inner) + } + + fn is_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(self.field_summary_for_reference(reference).contains_null) + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + + // contains_null encodes whether at least one partition value is null, + // lowerBound is null if all partition values are null + if ManifestFilterVisitor::are_all_null(field, &reference.field().field_type) { + ROWS_CANNOT_MATCH + } else { + ROWS_MIGHT_MATCH + } + } + + fn is_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + if let Some(contains_nan) = field.contains_nan { + if !contains_nan { + return ROWS_CANNOT_MATCH; + } + } + + if ManifestFilterVisitor::are_all_null(field, &reference.field().field_type) { + return ROWS_CANNOT_MATCH; + } + + ROWS_MIGHT_MATCH + } + + fn not_nan( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + if let Some(contains_nan) = field.contains_nan { + // check if all values are nan + if contains_nan && !field.contains_null && field.lower_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + } + ROWS_MIGHT_MATCH + } + + fn less_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + match &field.lower_bound { + Some(bound) if datum <= bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + match &field.lower_bound { + Some(bound) if datum < bound 
=> ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } + } + + fn greater_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + match &field.upper_bound { + Some(bound) if datum >= bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + match &field.upper_bound { + Some(bound) if datum > bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } + } + + fn eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + + if field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + if let Some(lower_bound) = &field.lower_bound { + if lower_bound > datum { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + if upper_bound < datum { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_eq( + &mut self, + _reference: &BoundReference, + _datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notEq(col, X) with (X, Y) doesn't guarantee that X is a value in col. + ROWS_MIGHT_MATCH + } + + fn starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + + if field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + let prefix = ManifestFilterVisitor::datum_as_str( + datum, + "Cannot perform starts_with on non-string value", + )?; + let prefix_len = prefix.len(); + + if let Some(lower_bound) = &field.lower_bound { + let lower_bound_str = ManifestFilterVisitor::datum_as_str( + lower_bound, + "Cannot perform starts_with on non-string lower bound", + )?; + let min_len = lower_bound_str.len().min(prefix_len); + if prefix.as_bytes().lt(&lower_bound_str.as_bytes()[..min_len]) { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + let upper_bound_str = ManifestFilterVisitor::datum_as_str( + upper_bound, + "Cannot perform starts_with on non-string upper bound", + )?; + let min_len = upper_bound_str.len().min(prefix_len); + if prefix.as_bytes().gt(&upper_bound_str.as_bytes()[..min_len]) { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + + if field.contains_null || field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_MIGHT_MATCH; + } + + let prefix = ManifestFilterVisitor::datum_as_str( + datum, + "Cannot perform not_starts_with on non-string value", + )?; + let prefix_len = prefix.len(); + + // not_starts_with will match unless all values must start with the prefix. This happens when + // the lower and upper bounds both start with the prefix. 
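// Editor's example (hypothetical bounds): with prefix "ab" and a summary whose
// lower bound is "abc" and upper bound is "abq", both bounds start with "ab",
// so every value in the range starts with "ab" and NOT_STARTS_WITH("ab") cannot
// match; if the upper bound were "ac" instead, values outside the prefix are
// possible and the manifest might still match.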
+ if let Some(lower_bound) = &field.lower_bound { + let lower_bound_str = ManifestFilterVisitor::datum_as_str( + lower_bound, + "Cannot perform not_starts_with on non-string lower bound", + )?; + + // if lower is shorter than the prefix then lower doesn't start with the prefix + if prefix_len > lower_bound_str.len() { + return ROWS_MIGHT_MATCH; + } + + if prefix + .as_bytes() + .eq(&lower_bound_str.as_bytes()[..prefix_len]) + { + if let Some(upper_bound) = &field.upper_bound { + let upper_bound_str = ManifestFilterVisitor::datum_as_str( + upper_bound, + "Cannot perform not_starts_with on non-string upper bound", + )?; + + // if upper is shorter than the prefix then upper can't start with the prefix + if prefix_len > upper_bound_str.len() { + return ROWS_MIGHT_MATCH; + } + + if prefix + .as_bytes() + .eq(&upper_bound_str.as_bytes()[..prefix_len]) + { + return ROWS_CANNOT_MATCH; + } + } + } + } + + ROWS_MIGHT_MATCH + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + let field = self.field_summary_for_reference(reference); + if field.lower_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + if literals.len() > IN_PREDICATE_LIMIT { + return ROWS_MIGHT_MATCH; + } + + if let Some(lower_bound) = &field.lower_bound { + if literals.iter().all(|datum| lower_bound > datum) { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + if literals.iter().all(|datum| upper_bound < datum) { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> crate::Result { + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. + ROWS_MIGHT_MATCH + } +} + +impl ManifestFilterVisitor<'_> { + fn field_summary_for_reference(&self, reference: &BoundReference) -> &FieldSummary { + let pos = reference.accessor().position(); + &self.partitions[pos] + } + + fn are_all_null(field: &FieldSummary, r#type: &Type) -> bool { + // contains_null encodes whether at least one partition value is null, + // lowerBound is null if all partition values are null + let mut all_null: bool = field.contains_null && field.lower_bound.is_none(); + + if all_null && r#type.is_floating_type() { + // floating point types may include NaN values, which we check separately. + // In case bounds don't include NaN value, contains_nan needs to be checked against. 
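// Editor's example (hypothetical summary): a float partition with
// contains_null = true, lower_bound = None and contains_nan = Some(true) is not
// considered all-null here, because the missing bounds may be explained by NaN
// values rather than nulls; only contains_nan = Some(false) allows concluding
// that every partition value is null.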
+ all_null = match field.contains_nan { + Some(val) => !val, + None => false, + } + } + + all_null + } + + fn datum_as_str<'a>(bound: &'a Datum, err_msg: &str) -> crate::Result<&'a String> { + let PrimitiveLiteral::String(bound) = bound.literal() else { + return Err(Error::new(ErrorKind::Unexpected, err_msg)); + }; + Ok(bound) + } +} + +#[cfg(test)] +mod test { + use std::ops::Not; + use std::sync::Arc; + + use fnv::FnvHashSet; + + use crate::expr::visitors::manifest_evaluator::ManifestEvaluator; + use crate::expr::{ + BinaryExpression, Bind, Predicate, PredicateOperator, Reference, SetExpression, + UnaryExpression, + }; + use crate::spec::{ + Datum, FieldSummary, ManifestContentType, ManifestFile, NestedField, PrimitiveType, Schema, + SchemaRef, Type, + }; + use crate::Result; + + const INT_MIN_VALUE: i32 = 30; + const INT_MAX_VALUE: i32 = 79; + + const STRING_MIN_VALUE: &str = "a"; + const STRING_MAX_VALUE: &str = "z"; + + fn create_schema() -> Result { + let schema = Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::required( + 1, + "id", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::optional( + 2, + "all_nulls_missing_nan", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 3, + "some_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 4, + "no_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 5, + "float", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 6, + "all_nulls_double", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 7, + "all_nulls_no_nans", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 8, + "all_nans", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 9, + "both_nan_and_null", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 10, + "no_nan_or_null", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 11, + "all_nulls_missing_nan_float", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 12, + "all_same_value_or_null", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 13, + "no_nulls_same_value_a", + Type::Primitive(PrimitiveType::String), + )), + ]) + .build()?; + + Ok(Arc::new(schema)) + } + + fn create_partitions() -> Vec { + vec![ + // id + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::int(INT_MIN_VALUE)), + upper_bound: Some(Datum::int(INT_MAX_VALUE)), + }, + // all_nulls_missing_nan + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + upper_bound: None, + }, + // some_nulls + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MAX_VALUE)), + }, + // no_nulls + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MAX_VALUE)), + }, + // float + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::float(0.0)), + upper_bound: Some(Datum::float(20.0)), + }, + // all_nulls_double + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + upper_bound: None, + }, + // all_nulls_no_nans + FieldSummary { + contains_null: true, + 
contains_nan: Some(false), + lower_bound: None, + upper_bound: None, + }, + // all_nans + FieldSummary { + contains_null: false, + contains_nan: Some(true), + lower_bound: None, + upper_bound: None, + }, + // both_nan_and_null + FieldSummary { + contains_null: true, + contains_nan: Some(true), + lower_bound: None, + upper_bound: None, + }, + // no_nan_or_null + FieldSummary { + contains_null: false, + contains_nan: Some(false), + lower_bound: Some(Datum::float(0.0)), + upper_bound: Some(Datum::float(20.0)), + }, + // all_nulls_missing_nan_float + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + upper_bound: None, + }, + // all_same_value_or_null + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MIN_VALUE)), + }, + // no_nulls_same_value_a + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MIN_VALUE)), + }, + ] + } + + fn create_manifest_file(partitions: Vec) -> ManifestFile { + ManifestFile { + manifest_path: "/test/path".to_string(), + manifest_length: 0, + partition_spec_id: 1, + content: ManifestContentType::Data, + sequence_number: 0, + min_sequence_number: 0, + added_snapshot_id: 0, + added_files_count: None, + existing_files_count: None, + deleted_files_count: None, + added_rows_count: None, + existing_rows_count: None, + deleted_rows_count: None, + partitions, + key_metadata: vec![], + } + } + + #[test] + fn test_always_true() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::AlwaysTrue.bind(schema.clone(), case_sensitive)?; + + assert!(ManifestEvaluator::new(filter).eval(&manifest_file)?); + + Ok(()) + } + + #[test] + fn test_always_false() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::AlwaysFalse.bind(schema.clone(), case_sensitive)?; + + assert!(!ManifestEvaluator::new(filter).eval(&manifest_file)?); + + Ok(()) + } + + #[test] + fn test_all_nulls() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // all_nulls_missing_nan + let all_nulls_missing_nan_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("all_nulls_missing_nan"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nulls_missing_nan_filter).eval(&manifest_file)?, + "Should skip: all nulls column with non-floating type contains all null" + ); + + // all_nulls_missing_nan_float + let all_nulls_missing_nan_float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("all_nulls_missing_nan_float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_float_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); + + // some_nulls + let some_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("some_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + 
ManifestEvaluator::new(some_nulls_filter).eval(&manifest_file)?, + "Should read: column with some nulls contains a non-null value" + ); + + // no_nulls + let no_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("no_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + + assert!( + ManifestEvaluator::new(no_nulls_filter).eval(&manifest_file)?, + "Should read: non-null column contains a non-null value" + ); + + Ok(()) + } + + #[test] + fn test_no_nulls() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // all_nulls_missing_nan + let all_nulls_missing_nan_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("all_nulls_missing_nan"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_filter).eval(&manifest_file)?, + "Should read: at least one null value in all null column" + ); + + // some_nulls + let some_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("some_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(some_nulls_filter).eval(&manifest_file)?, + "Should read: column with some nulls contains a null value" + ); + + // no_nulls + let no_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("no_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + + assert!( + !ManifestEvaluator::new(no_nulls_filter).eval(&manifest_file)?, + "Should skip: non-null column contains no null values" + ); + + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null column contains no null values" + ); + + Ok(()) + } + + #[test] + fn test_is_nan() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // float + let float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(float_filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); + + // all_nulls_double + let all_nulls_double_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nulls_double"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_double_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); + + // all_nulls_missing_nan_float + let all_nulls_missing_nan_float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nulls_missing_nan_float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_float_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); + + // all_nulls_no_nans + let all_nulls_no_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + 
Reference::new("all_nulls_no_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nulls_no_nans_filter).eval(&manifest_file)?, + "Should skip: no nan column doesn't contain nan value" + ); + + // all_nans + let all_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nans_filter).eval(&manifest_file)?, + "Should read: all_nans column contains nan value" + ); + + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null column contains nan value" + ); + + // no_nan_or_null + let no_nan_or_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("no_nan_or_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(no_nan_or_null_filter).eval(&manifest_file)?, + "Should skip: no_nan_or_null column doesn't contain nan value" + ); + + Ok(()) + } + + #[test] + fn test_not_nan() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // float + let float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(float_filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); + + // all_nulls_double + let all_nulls_double_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nulls_double"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_double_filter).eval(&manifest_file)?, + "Should read: all null column contains non nan value" + ); + + // all_nulls_no_nans + let all_nulls_no_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nulls_no_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_no_nans_filter).eval(&manifest_file)?, + "Should read: no_nans column contains non nan value" + ); + + // all_nans + let all_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nans_filter).eval(&manifest_file)?, + "Should skip: all nans column doesn't contain non nan value" + ); + + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null nans column contains non nan value" + ); + + // no_nan_or_null + let no_nan_or_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("no_nan_or_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(no_nan_or_null_filter).eval(&manifest_file)?, + "Should read: no_nan_or_null column 
contains non nan value" + ); + + Ok(()) + } + + #[test] + fn test_and() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .and(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 30), + ))) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); + + Ok(()) + } + + #[test] + fn test_or() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .or(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 1), + ))) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: or(false, false)" + ); + + Ok(()) + } + + #[test] + fn test_not() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .not() + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: not(false)" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .not() + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: not(true)" + ); + + Ok(()) + } + + #[test] + fn test_less_than() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range below lower bound (5 < 30)" + ); + + Ok(()) + } + + #[test] + fn test_less_than_or_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range below lower bound (5 < 30)" + ); + + Ok(()) + } + + #[test] + fn test_greater_than() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = 
create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 6), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range above upper bound (85 < 79)" + ); + + Ok(()) + } + + #[test] + fn test_greater_than_or_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 6), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range above upper bound (85 < 79)" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: one possible id" + ); + + Ok(()) + } + + #[test] + fn test_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id below lower bound" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id equal to lower bound" + ); + + Ok(()) + } + + #[test] + fn test_not_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id below lower bound" + ); + + Ok(()) + } + + #[test] + fn test_in() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 25), + Datum::int(INT_MIN_VALUE - 24), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id below lower bound (5 < 30, 6 < 30)" + ); + + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 1), + Datum::int(INT_MIN_VALUE), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id equal to lower bound (30 == 30)" + ); + + Ok(()) + } + + 
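    // Editor's illustrative sketch (not part of this patch): `r#in` falls back to
    // ROWS_MIGHT_MATCH once the literal set is larger than IN_PREDICATE_LIMIT
    // (200), so a very large IN list is never used to prune a manifest, even when
    // every literal is outside the bounds. A test along these lines could cover
    // that path; it reuses only helpers that already exist above.
    #[test]
    fn test_in_above_limit_is_not_used_for_pruning() -> Result<()> {
        let case_sensitive = true;
        let schema = create_schema()?;
        let manifest_file = create_manifest_file(create_partitions());

        // 400 distinct literals, all below the lower bound of `id`; the manifest is
        // still kept because the set exceeds IN_PREDICATE_LIMIT.
        let filter = Predicate::Set(SetExpression::new(
            PredicateOperator::In,
            Reference::new("id"),
            FnvHashSet::from_iter((1..=400).map(|i| Datum::int(INT_MIN_VALUE - 25 - i))),
        ))
        .bind(schema.clone(), case_sensitive)?;
        assert!(
            ManifestEvaluator::new(filter).eval(&manifest_file)?,
            "Should read: IN list longer than IN_PREDICATE_LIMIT is not used for pruning"
        );

        Ok(())
    }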
#[test] + fn test_not_in() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 25), + Datum::int(INT_MIN_VALUE - 24), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id below lower bound (5 < 30, 6 < 30)" + ); + + Ok(()) + } + + #[test] + fn test_starts_with() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("some_nulls"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: range matches" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("some_nulls"), + Datum::string("zzzz"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: range doesn't match" + ); + + Ok(()) + } + + #[test] + fn test_not_starts_with() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("some_nulls"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: range matches" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("no_nulls_same_value_a"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: all values start with the prefix" + ); + + Ok(()) + } +} diff --git a/crates/iceberg/src/expr/visitors/mod.rs b/crates/iceberg/src/expr/visitors/mod.rs new file mode 100644 index 000000000..06bfd8cda --- /dev/null +++ b/crates/iceberg/src/expr/visitors/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
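
One detail worth calling out before the next file: the `StartsWith`/`NotStartsWith` pruning implemented below for Parquet row-group statistics (and exercised in spirit by the manifest tests above) hinges on truncating the column's lower and upper bounds to the prefix length before comparing lexicographically. A standalone sketch of that comparison using only the standard library; the helper name and signature are illustrative, not crate API:

```rust
/// Could a column whose values all lie in [lower, upper] contain a string
/// starting with `prefix`? (Illustrative helper, not part of the crate.)
fn starts_with_might_match(prefix: &str, lower: &str, upper: &str) -> bool {
    let n = prefix.chars().count();
    // Truncate each bound to at most the prefix length, so a short prefix is
    // not rejected merely because it sorts below a longer lower bound
    // (e.g. prefix "ice" vs lower bound "icebox").
    let lo: String = lower.chars().take(n).collect();
    let hi: String = upper.chars().take(n).collect();
    // If the prefix sorts below the truncated lower bound or above the
    // truncated upper bound, no value in the range can start with it.
    !(prefix < lo.as_str() || prefix > hi.as_str())
}

fn main() {
    // Values mirror the byte-array bounds used in the row-group tests below.
    assert!(starts_with_might_match("iceberg", "h", "j")); // bounds straddle the prefix
    assert!(!starts_with_might_match("iceberg", "id", "ie")); // prefix sorts below the range
    assert!(!starts_with_might_match("iceberg", "h", "ib")); // prefix sorts above the range
}
```
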
+
+pub(crate) mod bound_predicate_visitor;
+pub(crate) mod expression_evaluator;
+pub(crate) mod inclusive_metrics_evaluator;
+pub(crate) mod inclusive_projection;
+pub(crate) mod manifest_evaluator;
+pub(crate) mod row_group_metrics_evaluator;
diff --git a/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs b/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs
new file mode 100644
index 000000000..4bf53d6ee
--- /dev/null
+++ b/crates/iceberg/src/expr/visitors/row_group_metrics_evaluator.rs
@@ -0,0 +1,1872 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Evaluates Parquet Row Group metrics
+
+use std::collections::HashMap;
+
+use fnv::FnvHashSet;
+use parquet::file::metadata::RowGroupMetaData;
+use parquet::file::statistics::Statistics;
+
+use crate::arrow::{get_parquet_stat_max_as_datum, get_parquet_stat_min_as_datum};
+use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
+use crate::expr::{BoundPredicate, BoundReference};
+use crate::spec::{Datum, PrimitiveLiteral, PrimitiveType, Schema};
+use crate::{Error, ErrorKind, Result};
+
+pub(crate) struct RowGroupMetricsEvaluator<'a> {
+    row_group_metadata: &'a RowGroupMetaData,
+    iceberg_field_id_to_parquet_column_index: &'a HashMap<i32, usize>,
+    snapshot_schema: &'a Schema,
+}
+
+const IN_PREDICATE_LIMIT: usize = 200;
+const ROW_GROUP_MIGHT_MATCH: Result<bool> = Ok(true);
+const ROW_GROUP_CANT_MATCH: Result<bool> = Ok(false);
+
+impl<'a> RowGroupMetricsEvaluator<'a> {
+    fn new(
+        row_group_metadata: &'a RowGroupMetaData,
+        field_id_map: &'a HashMap<i32, usize>,
+        snapshot_schema: &'a Schema,
+    ) -> Self {
+        Self {
+            row_group_metadata,
+            iceberg_field_id_to_parquet_column_index: field_id_map,
+            snapshot_schema,
+        }
+    }
+
+    /// Evaluate this `RowGroupMetricsEvaluator`'s filter predicate against the
+    /// provided [`RowGroupMetaData`]. Used by [`ArrowReader`] to
+    /// see if a Parquet file RowGroup could possibly contain data that matches
+    /// the scan's filter.
+ pub(crate) fn eval( + filter: &'a BoundPredicate, + row_group_metadata: &'a RowGroupMetaData, + field_id_map: &'a HashMap, + snapshot_schema: &'a Schema, + ) -> Result { + if row_group_metadata.num_rows() == 0 { + return ROW_GROUP_CANT_MATCH; + } + + let mut evaluator = Self::new(row_group_metadata, field_id_map, snapshot_schema); + + visit(&mut evaluator, filter) + } + + fn stats_for_field_id(&self, field_id: i32) -> Option<&Statistics> { + let parquet_column_index = *self + .iceberg_field_id_to_parquet_column_index + .get(&field_id)?; + self.row_group_metadata + .column(parquet_column_index) + .statistics() + } + + fn null_count(&self, field_id: i32) -> Option { + self.stats_for_field_id(field_id) + .map(|stats| stats.null_count()) + } + + fn value_count(&self) -> u64 { + self.row_group_metadata.num_rows() as u64 + } + + fn contains_nulls_only(&self, field_id: i32) -> bool { + let null_count = self.null_count(field_id); + let value_count = self.value_count(); + + null_count == Some(value_count) + } + + fn may_contain_null(&self, field_id: i32) -> bool { + if let Some(null_count) = self.null_count(field_id) { + null_count > 0 + } else { + true + } + } + + fn stats_and_type_for_field_id( + &self, + field_id: i32, + ) -> Result> { + let Some(stats) = self.stats_for_field_id(field_id) else { + // No statistics for column + return Ok(None); + }; + + let Some(field) = self.snapshot_schema.field_by_id(field_id) else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Could not find a field with id '{}' in the snapshot schema", + &field_id + ), + )); + }; + + let Some(primitive_type) = field.field_type.as_primitive_type() else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Could not determine the PrimitiveType for field id '{}'", + &field_id + ), + )); + }; + + Ok(Some((stats, primitive_type.clone()))) + } + + fn min_value(&self, field_id: i32) -> Result> { + let Some((stats, primitive_type)) = self.stats_and_type_for_field_id(field_id)? else { + return Ok(None); + }; + + if !stats.has_min_max_set() { + return Ok(None); + } + + get_parquet_stat_min_as_datum(&primitive_type, stats) + } + + fn max_value(&self, field_id: i32) -> Result> { + let Some((stats, primitive_type)) = self.stats_and_type_for_field_id(field_id)? else { + return Ok(None); + }; + + if !stats.has_min_max_set() { + return Ok(None); + } + + get_parquet_stat_max_as_datum(&primitive_type, stats) + } + + fn visit_inequality( + &mut self, + reference: &BoundReference, + datum: &Datum, + cmp_fn: fn(&Datum, &Datum) -> bool, + use_lower_bound: bool, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if datum.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. 
+ return ROW_GROUP_MIGHT_MATCH; + } + + let bound = if use_lower_bound { + self.min_value(field_id) + } else { + self.max_value(field_id) + }?; + + if let Some(bound) = bound { + if cmp_fn(&bound, datum) { + return ROW_GROUP_MIGHT_MATCH; + } + + return ROW_GROUP_CANT_MATCH; + } + + ROW_GROUP_MIGHT_MATCH + } +} + +impl BoundPredicateVisitor for RowGroupMetricsEvaluator<'_> { + type T = bool; + + fn always_true(&mut self) -> Result { + ROW_GROUP_MIGHT_MATCH + } + + fn always_false(&mut self) -> Result { + ROW_GROUP_CANT_MATCH + } + + fn and(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: bool, rhs: bool) -> Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: bool) -> Result { + Ok(!inner) + } + + fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + let field_id = reference.field().id; + + match self.null_count(field_id) { + Some(0) => ROW_GROUP_CANT_MATCH, + Some(_) => ROW_GROUP_MIGHT_MATCH, + None => ROW_GROUP_MIGHT_MATCH, + } + } + + fn not_null( + &mut self, + reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + ROW_GROUP_MIGHT_MATCH + } + + fn is_nan(&mut self, _reference: &BoundReference, _predicate: &BoundPredicate) -> Result { + // NaN counts not in ColumnChunkMetadata Statistics + ROW_GROUP_MIGHT_MATCH + } + + fn not_nan( + &mut self, + _reference: &BoundReference, + _predicate: &BoundPredicate, + ) -> Result { + // NaN counts not in ColumnChunkMetadata Statistics + ROW_GROUP_MIGHT_MATCH + } + + fn less_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::lt, true) + } + + fn less_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::le, true) + } + + fn greater_than( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::gt, false) + } + + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + self.visit_inequality(reference, datum, PartialOrd::ge, false) + } + + fn eq( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if let Some(lower_bound) = self.min_value(field_id)? { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } else if lower_bound.gt(datum) { + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. + // See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } else if upper_bound.lt(datum) { + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_eq( + &mut self, + _reference: &BoundReference, + _datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. 
notEq(col, X) with (X, Y) + // doesn't guarantee that X is a value in col. + ROW_GROUP_MIGHT_MATCH + } + + fn starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + let PrimitiveLiteral::String(datum) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + if let Some(lower_bound) = self.min_value(field_id)? { + let PrimitiveLiteral::String(lower_bound) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string lower_bound value", + )); + }; + + let prefix_length = lower_bound.chars().count().min(datum.chars().count()); + + // truncate lower bound so that its length + // is not greater than the length of prefix + let truncated_lower_bound = lower_bound.chars().take(prefix_length).collect::(); + if datum < &truncated_lower_bound { + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? { + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string upper_bound value", + )); + }; + + let prefix_length = upper_bound.chars().count().min(datum.chars().count()); + + // truncate upper bound so that its length + // is not greater than the length of prefix + let truncated_upper_bound = upper_bound.chars().take(prefix_length).collect::(); + if datum > &truncated_upper_bound { + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_starts_with( + &mut self, + reference: &BoundReference, + datum: &Datum, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.may_contain_null(field_id) { + return ROW_GROUP_MIGHT_MATCH; + } + + // notStartsWith will match unless all values must start with the prefix. + // This happens when the lower and upper bounds both start with the prefix. + + let PrimitiveLiteral::String(prefix) = datum.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use StartsWith operator on non-string values", + )); + }; + + let Some(lower_bound) = self.min_value(field_id)? else { + return ROW_GROUP_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(lower_bound_str) = lower_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string lower_bound value", + )); + }; + + if lower_bound_str < prefix { + // if lower is shorter than the prefix then lower doesn't start with the prefix + return ROW_GROUP_MIGHT_MATCH; + } + + let prefix_len = prefix.chars().count(); + + if lower_bound_str.chars().take(prefix_len).collect::() == *prefix { + // lower bound matches the prefix + + let Some(upper_bound) = self.max_value(field_id)? 
else { + return ROW_GROUP_MIGHT_MATCH; + }; + + let PrimitiveLiteral::String(upper_bound) = upper_bound.literal() else { + return Err(Error::new( + ErrorKind::Unexpected, + "Cannot use NotStartsWith operator on non-string upper_bound value", + )); + }; + + // if upper is shorter than the prefix then upper can't start with the prefix + if upper_bound.chars().count() < prefix_len { + return ROW_GROUP_MIGHT_MATCH; + } + + if upper_bound.chars().take(prefix_len).collect::() == *prefix { + // both bounds match the prefix, so all rows must match the + // prefix and therefore do not satisfy the predicate + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn r#in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + let field_id = reference.field().id; + + if self.contains_nulls_only(field_id) { + return ROW_GROUP_CANT_MATCH; + } + + if literals.len() > IN_PREDICATE_LIMIT { + // skip evaluating the predicate if the number of values is too big + return ROW_GROUP_MIGHT_MATCH; + } + + if let Some(lower_bound) = self.min_value(field_id)? { + if lower_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.ge(&lower_bound)) { + // if all values are less than lower bound, rows cannot match. + return ROW_GROUP_CANT_MATCH; + } + } + + if let Some(upper_bound) = self.max_value(field_id)? { + if upper_bound.is_nan() { + // NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. + return ROW_GROUP_MIGHT_MATCH; + } + + if !literals.iter().any(|datum| datum.le(&upper_bound)) { + // if all values are greater than upper bound, rows cannot match. + return ROW_GROUP_CANT_MATCH; + } + } + + ROW_GROUP_MIGHT_MATCH + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + _predicate: &BoundPredicate, + ) -> Result { + // Because the bounds are not necessarily a min or max value, + // this cannot be answered using them. notIn(col, {X, ...}) + // with (X, Y) doesn't guarantee that X is a value in col. 
+ ROW_GROUP_MIGHT_MATCH + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use parquet::basic::{LogicalType as ParquetLogicalType, Type as ParquetPhysicalType}; + use parquet::data_type::ByteArray; + use parquet::file::metadata::{ColumnChunkMetaData, RowGroupMetaData}; + use parquet::file::statistics::Statistics; + use parquet::schema::types::{ + ColumnDescriptor, ColumnPath, SchemaDescriptor, Type as parquetSchemaType, + }; + use rand::{thread_rng, Rng}; + + use super::RowGroupMetricsEvaluator; + use crate::expr::{Bind, Reference}; + use crate::spec::{Datum, NestedField, PrimitiveType, Schema, Type}; + use crate::Result; + + #[test] + fn eval_matches_no_rows_for_empty_row_group() -> Result<()> { + let row_group_metadata = create_row_group_metadata(0, 0, None, 0, None)?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + + Ok(()) + } + + #[test] + fn eval_true_for_row_group_no_bounds_present() -> Result<()> { + let row_group_metadata = create_row_group_metadata(1, 1, None, 1, None)?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_null_filter_not_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_not_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_all_null_filter_is_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_none_null_filter_not_null() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_not_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_none_null_filter_is_null() -> 
Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_null() + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_datum_nan_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(f32::NAN)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_missing_bound_valid_other_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_failing_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(0.9), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_passing_bound_filter_inequality() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .greater_than(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + 
assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(None, None, None, 1, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_lower_nan_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(f32::NAN), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_lower_gt_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(1.5), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_upper_nan_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(f32::NAN), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_upper_lt_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(0.5), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_good_bounds_than_datum_filter_eq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(2.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + 
&row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_bounds_eq_datum_filter_neq() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(1.0), Some(1.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .not_equal_to(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_starts_with() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_string_filter_datum() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 0, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .starts_with(Datum::float(1.0)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_utf8_lower_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // min val of 0xff is not valid utf-8 string. 
Max val of 0x20 is valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from(vec![255u8])), + Some(ByteArray::from(vec![32u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_starts_with_non_utf8_upper_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("ice".as_bytes())), + Some(ByteArray::from(vec![255u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_false_for_starts_with_meta_all_nulls() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_starts_with_datum_below_min_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("id".as_bytes())), + Some(ByteArray::from("ie".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_starts_with_datum_above_max_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("h".as_bytes())), + Some(ByteArray::from("ib".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_starts_with_datum_between_bounds() -> Result<()> { + let row_group_metadata = 
create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("h".as_bytes())), + Some(ByteArray::from("j".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_meta_all_nulls_filter_not_starts_with() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 1, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_error_for_not_starts_with_non_utf8_lower_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // min val of 0xff is not valid utf-8 string. Max val of 0x20 is valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from(vec![255u8])), + Some(ByteArray::from(vec![32u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_error_for_not_starts_with_non_utf8_upper_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from(vec![255u8])), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + ); + + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_no_min_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + None, + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn 
eval_true_for_not_starts_with_datum_longer_min_max_bound() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("ice".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_datum_matches_lower_no_upper() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + None, + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_not_starts_with_datum_matches_lower_upper_shorter() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("icy".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_not_starts_with_datum_matches_lower_and_upper() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .not_starts_with(Datum::string("iceberg")) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_false_for_meta_all_nulls_filter_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 1, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_in([Datum::string("ice"), Datum::string("berg")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = 
RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_too_many_literals_filter_is_in() -> Result<()> { + let mut rng = thread_rng(); + + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(11.0), Some(12.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in(std::iter::repeat_with(|| Datum::float(rng.gen_range(0.0..10.0))).take(1000)) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_missing_bounds_filter_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + Some(Statistics::byte_array(None, None, None, 0, false)), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_in([Datum::string("ice")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_lower_bound_is_nan_filter_is_in() -> Result<()> { + // TODO: should this be false, since the max stat + // is lower than the min val in the set? + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(f32::NAN), Some(1.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_lower_bound_greater_than_all_vals_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(4.0), None, None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_true_for_nan_upper_bound_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + Some(Statistics::float(Some(0.0), Some(f32::NAN), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + #[test] + fn eval_false_for_upper_bound_below_all_vals_is_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 
1, + 1, + Some(Statistics::float(Some(0.0), Some(1.0), None, 0, false)), + 1, + None, + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_float") + .is_in([Datum::float(2.0), Datum::float(3.0)]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(!result); + Ok(()) + } + + #[test] + fn eval_true_for_not_in() -> Result<()> { + let row_group_metadata = create_row_group_metadata( + 1, + 1, + None, + 1, + // Max val of 0xFF is not valid utf8 + Some(Statistics::byte_array( + Some(ByteArray::from("iceberg".as_bytes())), + Some(ByteArray::from("iceberg".as_bytes())), + None, + 0, + false, + )), + )?; + + let (iceberg_schema_ref, field_id_map) = build_iceberg_schema_and_field_map()?; + + let filter = Reference::new("col_string") + .is_not_in([Datum::string("iceberg")]) + .bind(iceberg_schema_ref.clone(), false)?; + + let result = RowGroupMetricsEvaluator::eval( + &filter, + &row_group_metadata, + &field_id_map, + iceberg_schema_ref.as_ref(), + )?; + + assert!(result); + Ok(()) + } + + fn build_iceberg_schema_and_field_map() -> Result<(Arc, HashMap)> { + let iceberg_schema = Schema::builder() + .with_fields([ + Arc::new(NestedField::new( + 1, + "col_float", + Type::Primitive(PrimitiveType::Float), + false, + )), + Arc::new(NestedField::new( + 2, + "col_string", + Type::Primitive(PrimitiveType::String), + false, + )), + ]) + .build()?; + let iceberg_schema_ref = Arc::new(iceberg_schema); + + let field_id_map = HashMap::from_iter([(1, 0), (2, 1)]); + + Ok((iceberg_schema_ref, field_id_map)) + } + + fn build_parquet_schema_descriptor() -> Result> { + let field_1 = Arc::new( + parquetSchemaType::primitive_type_builder("col_float", ParquetPhysicalType::FLOAT) + .with_id(Some(1)) + .build()?, + ); + + let field_2 = Arc::new( + parquetSchemaType::primitive_type_builder( + "col_string", + ParquetPhysicalType::BYTE_ARRAY, + ) + .with_id(Some(2)) + .with_logical_type(Some(ParquetLogicalType::String)) + .build()?, + ); + + let group_type = Arc::new( + parquetSchemaType::group_type_builder("all") + .with_id(Some(1000)) + .with_fields(vec![field_1, field_2]) + .build()?, + ); + + let schema_descriptor = SchemaDescriptor::new(group_type); + let schema_descriptor_arc = Arc::new(schema_descriptor); + Ok(schema_descriptor_arc) + } + + fn create_row_group_metadata( + num_rows: i64, + col_1_num_vals: i64, + col_1_stats: Option, + col_2_num_vals: i64, + col_2_stats: Option, + ) -> Result { + let schema_descriptor_arc = build_parquet_schema_descriptor()?; + + let column_1_desc_ptr = Arc::new(ColumnDescriptor::new( + schema_descriptor_arc.column(0).self_type_ptr(), + 1, + 1, + ColumnPath::new(vec!["col_float".to_string()]), + )); + + let column_2_desc_ptr = Arc::new(ColumnDescriptor::new( + schema_descriptor_arc.column(1).self_type_ptr(), + 1, + 1, + ColumnPath::new(vec!["col_string".to_string()]), + )); + + let mut col_1_meta = + ColumnChunkMetaData::builder(column_1_desc_ptr).set_num_values(col_1_num_vals); + if let Some(stats1) = col_1_stats { + col_1_meta = col_1_meta.set_statistics(stats1) + } + + let mut col_2_meta = + ColumnChunkMetaData::builder(column_2_desc_ptr).set_num_values(col_2_num_vals); + if let Some(stats2) = col_2_stats { + col_2_meta = col_2_meta.set_statistics(stats2) + } + + let row_group_metadata = RowGroupMetaData::builder(schema_descriptor_arc) + .set_num_rows(num_rows) + 
.set_column_metadata(vec![ + col_1_meta.build()?, + // .set_statistics(Statistics::float(None, None, None, 1, false)) + col_2_meta.build()?, + ]) + .build(); + + Ok(row_group_metadata?) + } +} diff --git a/crates/iceberg/src/io.rs b/crates/iceberg/src/io/file_io.rs similarity index 60% rename from crates/iceberg/src/io.rs rename to crates/iceberg/src/io/file_io.rs index 3a7c85f42..9af398270 100644 --- a/crates/iceberg/src/io.rs +++ b/crates/iceberg/src/io/file_io.rs @@ -15,69 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! File io implementation. -//! -//! # How to build `FileIO` -//! -//! We provided a `FileIOBuilder` to build `FileIO` from scratch. For example: -//! ```rust -//! use iceberg::io::{FileIOBuilder, S3_REGION}; -//! -//! let file_io = FileIOBuilder::new("s3") -//! .with_prop(S3_REGION, "us-east-1") -//! .build() -//! .unwrap(); -//! ``` -//! -//! Or you can pass a path to ask `FileIO` to infer schema for you: -//! ```rust -//! use iceberg::io::{FileIO, S3_REGION}; -//! let file_io = FileIO::from_path("s3://bucket/a") -//! .unwrap() -//! .with_prop(S3_REGION, "us-east-1") -//! .build() -//! .unwrap(); -//! ``` -//! -//! # How to use `FileIO` -//! -//! Currently `FileIO` provides simple methods for file operations: -//! -//! - `delete`: Delete file. -//! - `is_exist`: Check if file exists. -//! - `new_input`: Create input file for reading. -//! - `new_output`: Create output file for writing. - -use std::{collections::HashMap, sync::Arc}; - -use crate::{error::Result, Error, ErrorKind}; -use futures::{AsyncRead, AsyncSeek, AsyncWrite}; -use once_cell::sync::Lazy; -use opendal::{Operator, Scheme}; +use std::collections::HashMap; +use std::ops::Range; +use std::sync::Arc; + +use bytes::Bytes; +use opendal::Operator; use url::Url; -/// Following are arguments for [s3 file io](https://py.iceberg.apache.org/configuration/#s3). -/// S3 endopint. -pub const S3_ENDPOINT: &str = "s3.endpoint"; -/// S3 access key id. -pub const S3_ACCESS_KEY_ID: &str = "s3.access-key-id"; -/// S3 secret access key. -pub const S3_SECRET_ACCESS_KEY: &str = "s3.secret-access-key"; -/// S3 region. -pub const S3_REGION: &str = "s3.region"; - -/// A mapping from iceberg s3 configuration key to [`opendal::Operator`] configuration key. -static S3_CONFIG_MAPPING: Lazy> = Lazy::new(|| { - let mut m = HashMap::with_capacity(4); - m.insert(S3_ENDPOINT, "endpoint"); - m.insert(S3_ACCESS_KEY_ID, "access_key_id"); - m.insert(S3_SECRET_ACCESS_KEY, "secret_access_key"); - m.insert(S3_REGION, "region"); - - m -}); - -const DEFAULT_ROOT_PATH: &str = "/"; +use super::storage::Storage; +use crate::{Error, ErrorKind, Result}; /// FileIO implementation, used to manipulate files in underlying storage. /// @@ -90,59 +37,6 @@ pub struct FileIO { inner: Arc, } -/// Builder for [`FileIO`]. -#[derive(Debug)] -pub struct FileIOBuilder { - /// This is used to infer scheme of operator. - /// - /// If this is `None`, then [`FileIOBuilder::build`](FileIOBuilder::build) will build a local file io. - scheme_str: Option, - /// Arguments for operator. - props: HashMap, -} - -impl FileIOBuilder { - /// Creates a new builder with scheme. - pub fn new(scheme_str: impl ToString) -> Self { - Self { - scheme_str: Some(scheme_str.to_string()), - props: HashMap::default(), - } - } - - /// Creates a new builder for local file io. - pub fn new_fs_io() -> Self { - Self { - scheme_str: None, - props: HashMap::default(), - } - } - - /// Add argument for operator. 
- pub fn with_prop(mut self, key: impl ToString, value: impl ToString) -> Self { - self.props.insert(key.to_string(), value.to_string()); - self - } - - /// Add argument for operator. - pub fn with_props( - mut self, - args: impl IntoIterator, - ) -> Self { - self.props - .extend(args.into_iter().map(|e| (e.0.to_string(), e.1.to_string()))); - self - } - - /// Builds [`FileIO`]. - pub fn build(self) -> Result { - let storage = Storage::build(self)?; - Ok(FileIO { - inner: Arc::new(storage), - }) - } -} - impl FileIO { /// Try to infer file io scheme from path. /// @@ -150,7 +44,7 @@ impl FileIO { /// If it's not a valid url, will try to detect if it's a file path. /// /// Otherwise will return parsing error. - pub fn from_path(path: impl AsRef) -> Result { + pub fn from_path(path: impl AsRef) -> crate::Result { let url = Url::parse(path.as_ref()) .map_err(Error::from) .or_else(|e| { @@ -204,6 +98,95 @@ impl FileIO { } } +/// Builder for [`FileIO`]. +#[derive(Debug)] +pub struct FileIOBuilder { + /// This is used to infer scheme of operator. + /// + /// If this is `None`, then [`FileIOBuilder::build`](FileIOBuilder::build) will build a local file io. + scheme_str: Option, + /// Arguments for operator. + props: HashMap, +} + +impl FileIOBuilder { + /// Creates a new builder with scheme. + pub fn new(scheme_str: impl ToString) -> Self { + Self { + scheme_str: Some(scheme_str.to_string()), + props: HashMap::default(), + } + } + + /// Creates a new builder for local file io. + pub fn new_fs_io() -> Self { + Self { + scheme_str: None, + props: HashMap::default(), + } + } + + /// Fetch the scheme string. + /// + /// The scheme_str will be empty if it's None. + pub(crate) fn into_parts(self) -> (String, HashMap) { + (self.scheme_str.unwrap_or_default(), self.props) + } + + /// Add argument for operator. + pub fn with_prop(mut self, key: impl ToString, value: impl ToString) -> Self { + self.props.insert(key.to_string(), value.to_string()); + self + } + + /// Add argument for operator. + pub fn with_props( + mut self, + args: impl IntoIterator, + ) -> Self { + self.props + .extend(args.into_iter().map(|e| (e.0.to_string(), e.1.to_string()))); + self + } + + /// Builds [`FileIO`]. + pub fn build(self) -> crate::Result { + let storage = Storage::build(self)?; + Ok(FileIO { + inner: Arc::new(storage), + }) + } +} + +/// The struct the represents the metadata of a file. +/// +/// TODO: we can add last modified time, content type, etc. in the future. +pub struct FileMetadata { + /// The size of the file. + pub size: u64, +} + +/// Trait for reading file. +/// +/// # TODO +/// +/// It's possible for us to remove the async_trait, but we need to figure +/// out how to handle the object safety. +#[async_trait::async_trait] +pub trait FileRead: Send + Unpin + 'static { + /// Read file content with given range. + /// + /// TODO: we can support reading non-contiguous bytes in the future. + async fn read(&self, range: Range) -> crate::Result; +} + +#[async_trait::async_trait] +impl FileRead for opendal::Reader { + async fn read(&self, range: Range) -> crate::Result { + Ok(opendal::Reader::read(self, range).await?.to_bytes()) + } +} + /// Input file is used for reading from files. #[derive(Debug)] pub struct InputFile { @@ -214,11 +197,6 @@ pub struct InputFile { relative_path_pos: usize, } -/// Trait for reading file. -pub trait FileRead: AsyncRead + AsyncSeek {} - -impl FileRead for T where T: AsyncRead + AsyncSeek {} - impl InputFile { /// Absolute path to root uri. 
pub fn location(&self) -> &str { @@ -226,23 +204,70 @@ impl InputFile { } /// Check if file exists. - pub async fn exists(&self) -> Result { + pub async fn exists(&self) -> crate::Result { Ok(self .op .is_exist(&self.path[self.relative_path_pos..]) .await?) } - /// Creates [`InputStream`] for reading. - pub async fn reader(&self) -> Result { + /// Fetch and returns metadata of file. + pub async fn metadata(&self) -> crate::Result { + let meta = self.op.stat(&self.path[self.relative_path_pos..]).await?; + + Ok(FileMetadata { + size: meta.content_length(), + }) + } + + /// Read and returns whole content of file. + /// + /// For continues reading, use [`Self::reader`] instead. + pub async fn read(&self) -> crate::Result { + Ok(self + .op + .read(&self.path[self.relative_path_pos..]) + .await? + .to_bytes()) + } + + /// Creates [`FileRead`] for continues reading. + /// + /// For one-time reading, use [`Self::read`] instead. + pub async fn reader(&self) -> crate::Result { Ok(self.op.reader(&self.path[self.relative_path_pos..]).await?) } } /// Trait for writing file. -pub trait FileWrite: AsyncWrite {} +/// +/// # TODO +/// +/// It's possible for us to remove the async_trait, but we need to figure +/// out how to handle the object safety. +#[async_trait::async_trait] +pub trait FileWrite: Send + Unpin + 'static { + /// Write bytes to file. + /// + /// TODO: we can support writing non-contiguous bytes in the future. + async fn write(&mut self, bs: Bytes) -> crate::Result<()>; -impl FileWrite for T where T: AsyncWrite {} + /// Close file. + /// + /// Calling close on closed file will generate an error. + async fn close(&mut self) -> crate::Result<()>; +} + +#[async_trait::async_trait] +impl FileWrite for opendal::Writer { + async fn write(&mut self, bs: Bytes) -> crate::Result<()> { + Ok(opendal::Writer::write(self, bs).await?) + } + + async fn close(&mut self) -> crate::Result<()> { + Ok(opendal::Writer::close(self).await?) + } +} /// Output file is used for writing to files.. #[derive(Debug)] @@ -261,7 +286,7 @@ impl OutputFile { } /// Checks if file exists. - pub async fn exists(&self) -> Result { + pub async fn exists(&self) -> crate::Result { Ok(self .op .is_exist(&self.path[self.relative_path_pos..]) @@ -277,122 +302,38 @@ impl OutputFile { } } - /// Creates output file for writing. - pub async fn writer(&self) -> Result { - Ok(self.op.writer(&self.path[self.relative_path_pos..]).await?) - } -} - -// We introduce this because I don't want to handle unsupported `Scheme` in every method. -#[derive(Debug)] -enum Storage { - LocalFs { - op: Operator, - }, - S3 { - scheme_str: String, - props: HashMap, - }, -} - -impl Storage { - /// Creates operator from path. - /// - /// # Arguments - /// - /// * path: It should be *absolute* path starting with scheme string used to construct [`FileIO`]. + /// Create a new output file with given bytes. /// - /// # Returns + /// # Notes /// - /// The return value consists of two parts: - /// - /// * An [`opendal::Operator`] instance used to operate on file. - /// * Relative path to the root uri of [`opendal::Operator`]. 
- /// - fn create_operator<'a>(&self, path: &'a impl AsRef) -> Result<(Operator, &'a str)> { - let path = path.as_ref(); - match self { - Storage::LocalFs { op } => { - if let Some(stripped) = path.strip_prefix("file:/") { - Ok((op.clone(), stripped)) - } else { - Ok((op.clone(), &path[1..])) - } - } - Storage::S3 { scheme_str, props } => { - let mut props = props.clone(); - let url = Url::parse(path)?; - let bucket = url.host_str().ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!("Invalid s3 url: {}, missing bucket", path), - ) - })?; - - props.insert("bucket".to_string(), bucket.to_string()); - - let prefix = format!("{}://{}/", scheme_str, bucket); - if path.starts_with(&prefix) { - Ok((Operator::via_map(Scheme::S3, props)?, &path[prefix.len()..])) - } else { - Err(Error::new( - ErrorKind::DataInvalid, - format!("Invalid s3 url: {}, should start with {}", path, prefix), - )) - } - } - } + /// Calling `write` will overwrite the file if it exists. + /// For continues writing, use [`Self::writer`]. + pub async fn write(&self, bs: Bytes) -> crate::Result<()> { + let mut writer = self.writer().await?; + writer.write(bs).await?; + writer.close().await } - /// Parse scheme. - fn parse_scheme(scheme: &str) -> Result { - match scheme { - "file" | "" => Ok(Scheme::Fs), - "s3" | "s3a" => Ok(Scheme::S3), - s => Ok(s.parse::()?), - } - } - - /// Convert iceberg config to opendal config. - fn build(file_io_builder: FileIOBuilder) -> Result { - let scheme_str = file_io_builder.scheme_str.unwrap_or("".to_string()); - let scheme = Self::parse_scheme(&scheme_str)?; - let mut new_props = HashMap::default(); - new_props.insert("root".to_string(), DEFAULT_ROOT_PATH.to_string()); - - match scheme { - Scheme::Fs => Ok(Self::LocalFs { - op: Operator::via_map(Scheme::Fs, new_props)?, - }), - Scheme::S3 => { - for prop in file_io_builder.props { - if let Some(op_key) = S3_CONFIG_MAPPING.get(prop.0.as_str()) { - new_props.insert(op_key.to_string(), prop.1); - } - } - - Ok(Self::S3 { - scheme_str, - props: new_props, - }) - } - _ => Err(Error::new( - ErrorKind::FeatureUnsupported, - format!("Constructing file io from scheme: {scheme} not supported now",), - )), - } + /// Creates output file for continues writing. + /// + /// # Notes + /// + /// For one-time writing, use [`Self::write`] instead. + pub async fn writer(&self) -> crate::Result> { + Ok(Box::new( + self.op.writer(&self.path[self.relative_path_pos..]).await?, + )) } } #[cfg(test)] mod tests { + use std::fs::File; use std::io::Write; - - use std::{fs::File, path::Path}; + use std::path::Path; use futures::io::AllowStdIo; - use futures::{AsyncReadExt, AsyncWriteExt}; - + use futures::AsyncReadExt; use tempfile::TempDir; use super::{FileIO, FileIOBuilder}; @@ -476,9 +417,7 @@ mod tests { assert!(!output_file.exists().await.unwrap()); { - let mut writer = output_file.writer().await.unwrap(); - writer.write_all(content.as_bytes()).await.unwrap(); - writer.close().await.unwrap(); + output_file.write(content.into()).await.unwrap(); } assert_eq!(&full_path, output_file.location()); diff --git a/crates/iceberg/src/io/mod.rs b/crates/iceberg/src/io/mod.rs new file mode 100644 index 000000000..52a1da23a --- /dev/null +++ b/crates/iceberg/src/io/mod.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
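For the write side, a minimal sketch of the two `OutputFile` paths introduced here, one-shot `write` versus a streaming `FileWrite` writer; the path and payloads are placeholders:

```rust
use bytes::Bytes;
use iceberg::io::FileIO;
use iceberg::Result;

async fn write_examples(file_io: &FileIO, path: &str) -> Result<()> {
    let output = file_io.new_output(path)?;

    // One-shot write: overwrites the file if it already exists.
    output.write(Bytes::from_static(b"hello iceberg")).await?;

    // Streaming write: push several chunks, then close exactly once.
    let mut writer = output.writer().await?;
    writer.write(Bytes::from_static(b"chunk-1")).await?;
    writer.write(Bytes::from_static(b"chunk-2")).await?;
    writer.close().await
}
```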
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! File io implementation. +//! +//! # How to build `FileIO` +//! +//! We provided a `FileIOBuilder` to build `FileIO` from scratch. For example: +//! +//! ```rust +//! use iceberg::io::{FileIOBuilder, S3_REGION}; +//! use iceberg::Result; +//! +//! # fn test() -> Result<()> { +//! // Build a memory file io. +//! let file_io = FileIOBuilder::new("memory").build()?; +//! // Build an fs file io. +//! let file_io = FileIOBuilder::new("fs").build()?; +//! // Build an s3 file io. +//! let file_io = FileIOBuilder::new("s3") +//! .with_prop(S3_REGION, "us-east-1") +//! .build()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! Or you can pass a path to ask `FileIO` to infer schema for you: +//! +//! ```rust +//! use iceberg::io::{FileIO, S3_REGION}; +//! use iceberg::Result; +//! +//! # fn test() -> Result<()> { +//! // Build a memory file io. +//! let file_io = FileIO::from_path("memory:///")?.build()?; +//! // Build an fs file io. +//! let file_io = FileIO::from_path("fs:///tmp")?.build()?; +//! // Build an s3 file io. +//! let file_io = FileIO::from_path("s3://bucket/a")? +//! .with_prop(S3_REGION, "us-east-1") +//! .build()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # How to use `FileIO` +//! +//! Currently `FileIO` provides simple methods for file operations: +//! +//! - `delete`: Delete file. +//! - `is_exist`: Check if file exists. +//! - `new_input`: Create input file for reading. +//! - `new_output`: Create output file for writing. + +mod file_io; +pub use file_io::*; + +mod storage; +#[cfg(feature = "storage-memory")] +mod storage_memory; +#[cfg(feature = "storage-memory")] +use storage_memory::*; +#[cfg(feature = "storage-s3")] +mod storage_s3; +#[cfg(feature = "storage-s3")] +pub use storage_s3::*; +pub(crate) mod object_cache; +#[cfg(feature = "storage-fs")] +mod storage_fs; + +#[cfg(feature = "storage-fs")] +use storage_fs::*; +#[cfg(feature = "storage-gcs")] +mod storage_gcs; +#[cfg(feature = "storage-gcs")] +pub use storage_gcs::*; diff --git a/crates/iceberg/src/io/object_cache.rs b/crates/iceberg/src/io/object_cache.rs new file mode 100644 index 000000000..3b89a4e6a --- /dev/null +++ b/crates/iceberg/src/io/object_cache.rs @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
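Putting the module-level operations (`new_input`, `new_output`, `is_exist`, `delete`) together, a small end-to-end sketch against the in-memory backend; this requires the `storage-memory` feature, and the exact path form used below is an assumption for illustration:

```rust
use bytes::Bytes;
use iceberg::io::FileIOBuilder;
use iceberg::Result;

async fn memory_roundtrip() -> Result<()> {
    let file_io = FileIOBuilder::new("memory").build()?;
    let path = "memory:/demo/data.txt";

    file_io.new_output(path)?.write(Bytes::from_static(b"42")).await?;
    assert!(file_io.is_exist(path).await?);

    let content = file_io.new_input(path)?.read().await?;
    assert_eq!(&content[..], &b"42"[..]);

    file_io.delete(path).await?;
    assert!(!file_io.is_exist(path).await?);
    Ok(())
}
```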
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::io::FileIO; +use crate::spec::{ + FormatVersion, Manifest, ManifestFile, ManifestList, SchemaId, SnapshotRef, TableMetadataRef, +}; +use crate::{Error, ErrorKind, Result}; + +const DEFAULT_CACHE_SIZE_BYTES: u64 = 32 * 1024 * 1024; // 32MB + +#[derive(Clone, Debug)] +pub(crate) enum CachedItem { + ManifestList(Arc), + Manifest(Arc), +} + +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +pub(crate) enum CachedObjectKey { + ManifestList((String, FormatVersion, SchemaId)), + Manifest(String), +} + +/// Caches metadata objects deserialized from immutable files +#[derive(Clone, Debug)] +pub struct ObjectCache { + cache: moka::future::Cache, + file_io: FileIO, + cache_disabled: bool, +} + +impl ObjectCache { + /// Creates a new [`ObjectCache`] + /// with the default cache size + pub(crate) fn new(file_io: FileIO) -> Self { + Self::new_with_capacity(file_io, DEFAULT_CACHE_SIZE_BYTES) + } + + /// Creates a new [`ObjectCache`] + /// with a specific cache size + pub(crate) fn new_with_capacity(file_io: FileIO, cache_size_bytes: u64) -> Self { + if cache_size_bytes == 0 { + Self::with_disabled_cache(file_io) + } else { + Self { + cache: moka::future::Cache::builder() + .weigher(|_, val: &CachedItem| match val { + CachedItem::ManifestList(item) => size_of_val(item.as_ref()), + CachedItem::Manifest(item) => size_of_val(item.as_ref()), + } as u32) + .max_capacity(cache_size_bytes) + .build(), + file_io, + cache_disabled: false, + } + } + } + + /// Creates a new [`ObjectCache`] + /// with caching disabled + pub(crate) fn with_disabled_cache(file_io: FileIO) -> Self { + Self { + cache: moka::future::Cache::new(0), + file_io, + cache_disabled: true, + } + } + + /// Retrieves an Arc [`Manifest`] from the cache + /// or retrieves one from FileIO and parses it if not present + pub(crate) async fn get_manifest(&self, manifest_file: &ManifestFile) -> Result> { + if self.cache_disabled { + return manifest_file + .load_manifest(&self.file_io) + .await + .map(Arc::new); + } + + let key = CachedObjectKey::Manifest(manifest_file.manifest_path.clone()); + + let cache_entry = self + .cache + .entry_by_ref(&key) + .or_try_insert_with(self.fetch_and_parse_manifest(manifest_file)) + .await + .map_err(|err| { + Error::new( + ErrorKind::Unexpected, + format!("Failed to load manifest {}", manifest_file.manifest_path), + ) + .with_source(err) + })? + .into_value(); + + match cache_entry { + CachedItem::Manifest(arc_manifest) => Ok(arc_manifest), + _ => Err(Error::new( + ErrorKind::Unexpected, + format!("cached object for key '{:?}' is not a Manifest", key), + )), + } + } + + /// Retrieves an Arc [`ManifestList`] from the cache + /// or retrieves one from FileIO and parses it if not present + pub(crate) async fn get_manifest_list( + &self, + snapshot: &SnapshotRef, + table_metadata: &TableMetadataRef, + ) -> Result> { + if self.cache_disabled { + return snapshot + .load_manifest_list(&self.file_io, table_metadata) + .await + .map(Arc::new); + } + + let key = CachedObjectKey::ManifestList(( + snapshot.manifest_list().to_string(), + table_metadata.format_version, + snapshot.schema_id().unwrap(), + )); + let cache_entry = self + .cache + .entry_by_ref(&key) + .or_try_insert_with(self.fetch_and_parse_manifest_list(snapshot, table_metadata)) + .await + .map_err(|err| Error::new(ErrorKind::Unexpected, err.as_ref().message()))? 
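Since `ObjectCache` is crate-private, it is only constructed inside the crate; a condensed sketch of how a caller might choose between the constructors shown above (the `build_cache` helper is hypothetical):

```rust
use std::sync::Arc;

use crate::io::object_cache::ObjectCache;
use crate::io::FileIO;

// Hypothetical helper: pick a cache variant from a configured size.
fn build_cache(file_io: FileIO, cache_size_bytes: u64) -> Arc<ObjectCache> {
    let cache = if cache_size_bytes == 0 {
        // Zero capacity disables caching entirely.
        ObjectCache::with_disabled_cache(file_io)
    } else {
        ObjectCache::new_with_capacity(file_io, cache_size_bytes)
    };
    Arc::new(cache)
}
```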
+ .into_value(); + + match cache_entry { + CachedItem::ManifestList(arc_manifest_list) => Ok(arc_manifest_list), + _ => Err(Error::new( + ErrorKind::Unexpected, + format!("cached object for path '{:?}' is not a manifest list", key), + )), + } + } + + async fn fetch_and_parse_manifest(&self, manifest_file: &ManifestFile) -> Result { + let manifest = manifest_file.load_manifest(&self.file_io).await?; + + Ok(CachedItem::Manifest(Arc::new(manifest))) + } + + async fn fetch_and_parse_manifest_list( + &self, + snapshot: &SnapshotRef, + table_metadata: &TableMetadataRef, + ) -> Result { + let manifest_list = snapshot + .load_manifest_list(&self.file_io, table_metadata) + .await?; + + Ok(CachedItem::ManifestList(Arc::new(manifest_list))) + } +} + +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::TempDir; + use tera::{Context, Tera}; + use uuid::Uuid; + + use super::*; + use crate::io::{FileIO, OutputFile}; + use crate::spec::{ + DataContentType, DataFileBuilder, DataFileFormat, FormatVersion, Literal, Manifest, + ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus, + ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID, + }; + use crate::table::Table; + use crate::TableIdent; + + struct TableTestFixture { + table_location: String, + table: Table, + } + + impl TableTestFixture { + fn new() -> Self { + let tmp_dir = TempDir::new().unwrap(); + let table_location = tmp_dir.path().join("table1"); + let manifest_list1_location = table_location.join("metadata/manifests_list_1.avro"); + let manifest_list2_location = table_location.join("metadata/manifests_list_2.avro"); + let table_metadata1_location = table_location.join("metadata/v1.json"); + + let file_io = FileIO::from_path(table_location.as_os_str().to_str().unwrap()) + .unwrap() + .build() + .unwrap(); + + let table_metadata = { + let template_json_str = fs::read_to_string(format!( + "{}/testdata/example_table_metadata_v2.json", + env!("CARGO_MANIFEST_DIR") + )) + .unwrap(); + let mut context = Context::new(); + context.insert("table_location", &table_location); + context.insert("manifest_list_1_location", &manifest_list1_location); + context.insert("manifest_list_2_location", &manifest_list2_location); + context.insert("table_metadata_1_location", &table_metadata1_location); + + let metadata_json = Tera::one_off(&template_json_str, &context, false).unwrap(); + serde_json::from_str::(&metadata_json).unwrap() + }; + + let table = Table::builder() + .metadata(table_metadata) + .identifier(TableIdent::from_strs(["db", "table1"]).unwrap()) + .file_io(file_io.clone()) + .metadata_location(table_metadata1_location.as_os_str().to_str().unwrap()) + .build() + .unwrap(); + + Self { + table_location: table_location.to_str().unwrap().to_string(), + table, + } + } + + fn next_manifest_file(&self) -> OutputFile { + self.table + .file_io() + .new_output(format!( + "{}/metadata/manifest_{}.avro", + self.table_location, + Uuid::new_v4() + )) + .unwrap() + } + + async fn setup_manifest_files(&mut self) { + let current_snapshot = self.table.metadata().current_snapshot().unwrap(); + let current_schema = current_snapshot.schema(self.table.metadata()).unwrap(); + let current_partition_spec = self.table.metadata().default_partition_spec().unwrap(); + + // Write data files + let data_file_manifest = ManifestWriter::new( + self.next_manifest_file(), + current_snapshot.snapshot_id(), + vec![], + ) + .write(Manifest::new( + ManifestMetadata::builder() + .schema((*current_schema).clone()) + .content(ManifestContentType::Data) 
+ .format_version(FormatVersion::V2) + .partition_spec((**current_partition_spec).clone()) + .schema_id(current_schema.schema_id()) + .build(), + vec![ManifestEntry::builder() + .status(ManifestStatus::Added) + .data_file( + DataFileBuilder::default() + .content(DataContentType::Data) + .file_path(format!("{}/1.parquet", &self.table_location)) + .file_format(DataFileFormat::Parquet) + .file_size_in_bytes(100) + .record_count(1) + .partition(Struct::from_iter([Some(Literal::long(100))])) + .build() + .unwrap(), + ) + .build()], + )) + .await + .unwrap(); + + // Write to manifest list + let mut manifest_list_write = ManifestListWriter::v2( + self.table + .file_io() + .new_output(current_snapshot.manifest_list()) + .unwrap(), + current_snapshot.snapshot_id(), + current_snapshot + .parent_snapshot_id() + .unwrap_or(EMPTY_SNAPSHOT_ID), + current_snapshot.sequence_number(), + ); + manifest_list_write + .add_manifests(vec![data_file_manifest].into_iter()) + .unwrap(); + manifest_list_write.close().await.unwrap(); + } + } + + #[tokio::test] + async fn test_get_manifest_list_and_manifest_from_disabled_cache() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + let object_cache = ObjectCache::with_disabled_cache(fixture.table.file_io().clone()); + + let result_manifest_list = object_cache + .get_manifest_list( + fixture.table.metadata().current_snapshot().unwrap(), + &fixture.table.metadata_ref(), + ) + .await + .unwrap(); + + assert_eq!(result_manifest_list.entries().len(), 1); + + let manifest_file = result_manifest_list.entries().first().unwrap(); + let result_manifest = object_cache.get_manifest(manifest_file).await.unwrap(); + + assert_eq!( + result_manifest + .entries() + .first() + .unwrap() + .file_path() + .split("/") + .last() + .unwrap(), + "1.parquet" + ); + } + + #[tokio::test] + async fn test_get_manifest_list_and_manifest_from_default_cache() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + let object_cache = ObjectCache::new(fixture.table.file_io().clone()); + + // not in cache + let result_manifest_list = object_cache + .get_manifest_list( + fixture.table.metadata().current_snapshot().unwrap(), + &fixture.table.metadata_ref(), + ) + .await + .unwrap(); + + assert_eq!(result_manifest_list.entries().len(), 1); + + // retrieve cached version + let result_manifest_list = object_cache + .get_manifest_list( + fixture.table.metadata().current_snapshot().unwrap(), + &fixture.table.metadata_ref(), + ) + .await + .unwrap(); + + assert_eq!(result_manifest_list.entries().len(), 1); + + let manifest_file = result_manifest_list.entries().first().unwrap(); + + // not in cache + let result_manifest = object_cache.get_manifest(manifest_file).await.unwrap(); + + assert_eq!( + result_manifest + .entries() + .first() + .unwrap() + .file_path() + .split("/") + .last() + .unwrap(), + "1.parquet" + ); + + // retrieve cached version + let result_manifest = object_cache.get_manifest(manifest_file).await.unwrap(); + + assert_eq!( + result_manifest + .entries() + .first() + .unwrap() + .file_path() + .split("/") + .last() + .unwrap(), + "1.parquet" + ); + } +} diff --git a/crates/iceberg/src/io/storage.rs b/crates/iceberg/src/io/storage.rs new file mode 100644 index 000000000..682b1d33e --- /dev/null +++ b/crates/iceberg/src/io/storage.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +#[cfg(feature = "storage-gcs")] +use opendal::services::GcsConfig; +#[cfg(feature = "storage-s3")] +use opendal::services::S3Config; +use opendal::{Operator, Scheme}; + +use super::FileIOBuilder; +use crate::{Error, ErrorKind}; + +/// The storage carries all supported storage services in iceberg +#[derive(Debug)] +pub(crate) enum Storage { + #[cfg(feature = "storage-memory")] + Memory, + #[cfg(feature = "storage-fs")] + LocalFs, + #[cfg(feature = "storage-s3")] + S3 { + /// s3 storage could have `s3://` and `s3a://`. + /// Storing the scheme string here to return the correct path. + scheme_str: String, + /// uses the same client for one FileIO Storage. + /// + /// TODO: allow users to configure this client. + client: reqwest::Client, + config: Arc, + }, + #[cfg(feature = "storage-gcs")] + Gcs { config: Arc }, +} + +impl Storage { + /// Convert iceberg config to opendal config. + pub(crate) fn build(file_io_builder: FileIOBuilder) -> crate::Result { + let (scheme_str, props) = file_io_builder.into_parts(); + let scheme = Self::parse_scheme(&scheme_str)?; + + match scheme { + #[cfg(feature = "storage-memory")] + Scheme::Memory => Ok(Self::Memory), + #[cfg(feature = "storage-fs")] + Scheme::Fs => Ok(Self::LocalFs), + #[cfg(feature = "storage-s3")] + Scheme::S3 => Ok(Self::S3 { + scheme_str, + client: reqwest::Client::new(), + config: super::s3_config_parse(props)?.into(), + }), + #[cfg(feature = "storage-gcs")] + Scheme::Gcs => Ok(Self::Gcs { + config: super::gcs_config_parse(props)?.into(), + }), + _ => Err(Error::new( + ErrorKind::FeatureUnsupported, + format!("Constructing file io from scheme: {scheme} not supported now",), + )), + } + } + + /// Creates operator from path. + /// + /// # Arguments + /// + /// * path: It should be *absolute* path starting with scheme string used to construct [`FileIO`]. + /// + /// # Returns + /// + /// The return value consists of two parts: + /// + /// * An [`opendal::Operator`] instance used to operate on file. + /// * Relative path to the root uri of [`opendal::Operator`]. 
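One consequence of keeping `scheme_str` on the `S3` variant is that `s3://` and `s3a://` URIs both work, because the later path-prefix check is built from the original scheme string. A minimal sketch (the region value is a placeholder):

```rust
use iceberg::io::{FileIOBuilder, S3_REGION};
use iceberg::Result;

fn build_s3a_file_io() -> Result<()> {
    // `s3` and `s3a` resolve to the same S3 storage; the original scheme
    // string is retained so `s3a://bucket/...` paths validate correctly.
    let _file_io = FileIOBuilder::new("s3a")
        .with_prop(S3_REGION, "us-east-1")
        .build()?;
    Ok(())
}
```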
+ pub(crate) fn create_operator<'a>( + &self, + path: &'a impl AsRef, + ) -> crate::Result<(Operator, &'a str)> { + let path = path.as_ref(); + match self { + #[cfg(feature = "storage-memory")] + Storage::Memory => { + let op = super::memory_config_build()?; + + if let Some(stripped) = path.strip_prefix("memory:/") { + Ok((op, stripped)) + } else { + Ok((op, &path[1..])) + } + } + #[cfg(feature = "storage-fs")] + Storage::LocalFs => { + let op = super::fs_config_build()?; + + if let Some(stripped) = path.strip_prefix("file:/") { + Ok((op, stripped)) + } else { + Ok((op, &path[1..])) + } + } + #[cfg(feature = "storage-s3")] + Storage::S3 { + scheme_str, + client, + config, + } => { + let op = super::s3_config_build(client, config, path)?; + let op_info = op.info(); + + // Check prefix of s3 path. + let prefix = format!("{}://{}/", scheme_str, op_info.name()); + if path.starts_with(&prefix) { + Ok((op, &path[prefix.len()..])) + } else { + Err(Error::new( + ErrorKind::DataInvalid, + format!("Invalid s3 url: {}, should start with {}", path, prefix), + )) + } + } + #[cfg(feature = "storage-gcs")] + Storage::Gcs { config } => { + let operator = super::gcs_config_build(config, path)?; + let prefix = format!("gs://{}/", operator.info().name()); + if path.starts_with(&prefix) { + Ok((operator, &path[prefix.len()..])) + } else { + Err(Error::new( + ErrorKind::DataInvalid, + format!("Invalid gcs url: {}, should start with {}", path, prefix), + )) + } + } + #[cfg(all( + not(feature = "storage-s3"), + not(feature = "storage-fs"), + not(feature = "storage-gcs") + ))] + _ => Err(Error::new( + ErrorKind::FeatureUnsupported, + "No storage service has been enabled", + )), + } + } + + /// Parse scheme. + fn parse_scheme(scheme: &str) -> crate::Result { + match scheme { + "memory" => Ok(Scheme::Memory), + "file" | "" => Ok(Scheme::Fs), + "s3" | "s3a" => Ok(Scheme::S3), + "gs" => Ok(Scheme::Gcs), + s => Ok(s.parse::()?), + } + } +} diff --git a/crates/iceberg/src/io/storage_fs.rs b/crates/iceberg/src/io/storage_fs.rs new file mode 100644 index 000000000..ff38d7613 --- /dev/null +++ b/crates/iceberg/src/io/storage_fs.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use opendal::services::FsConfig; +use opendal::Operator; + +use crate::Result; + +/// Build new opendal operator from give path. 
+pub(crate) fn fs_config_build() -> Result { + let mut cfg = FsConfig::default(); + cfg.root = Some("/".to_string()); + + Ok(Operator::from_config(cfg)?.finish()) +} diff --git a/crates/iceberg/src/io/storage_gcs.rs b/crates/iceberg/src/io/storage_gcs.rs new file mode 100644 index 000000000..0a2410799 --- /dev/null +++ b/crates/iceberg/src/io/storage_gcs.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +//! Google Cloud Storage properties + +use std::collections::HashMap; + +use opendal::services::GcsConfig; +use opendal::Operator; +use url::Url; + +use crate::{Error, ErrorKind, Result}; + +// Reference: https://github.com/apache/iceberg/blob/main/gcp/src/main/java/org/apache/iceberg/gcp/GCPProperties.java + +/// Google Cloud Project ID +pub const GCS_PROJECT_ID: &str = "gcs.project-id"; +/// Google Cloud Storage endpoint +pub const GCS_SERVICE_PATH: &str = "gcs.service.path"; +/// Google Cloud user project +pub const GCS_USER_PROJECT: &str = "gcs.user-project"; +/// Allow unauthenticated requests +pub const GCS_NO_AUTH: &str = "gcs.no-auth"; + +/// Parse iceberg properties to [`GcsConfig`]. +pub(crate) fn gcs_config_parse(mut m: HashMap) -> Result { + let mut cfg = GcsConfig::default(); + + if let Some(endpoint) = m.remove(GCS_SERVICE_PATH) { + cfg.endpoint = Some(endpoint); + } + + if m.remove(GCS_NO_AUTH).is_some() { + cfg.allow_anonymous = true; + cfg.disable_vm_metadata = true; + cfg.disable_config_load = true; + } + + Ok(cfg) +} + +/// Build a new OpenDAL [`Operator`] based on a provided [`GcsConfig`]. +pub(crate) fn gcs_config_build(cfg: &GcsConfig, path: &str) -> Result { + let url = Url::parse(path)?; + let bucket = url.host_str().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid gcs url: {}, bucket is required", path), + ) + })?; + + let mut cfg = cfg.clone(); + cfg.bucket = bucket.to_string(); + Ok(Operator::from_config(cfg)?.finish()) +} diff --git a/crates/iceberg/src/io/storage_memory.rs b/crates/iceberg/src/io/storage_memory.rs new file mode 100644 index 000000000..ffc082d83 --- /dev/null +++ b/crates/iceberg/src/io/storage_memory.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
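To tie the GCS properties above to the builder API, a short sketch of pointing the `gs` backend at a local emulator with authentication disabled; the endpoint URL is a placeholder:

```rust
use iceberg::io::{FileIOBuilder, GCS_NO_AUTH, GCS_SERVICE_PATH};
use iceberg::Result;

fn build_gcs_file_io_for_emulator() -> Result<()> {
    let _file_io = FileIOBuilder::new("gs")
        // Placeholder endpoint for a local fake GCS server.
        .with_prop(GCS_SERVICE_PATH, "http://localhost:4443")
        // Any value for GCS_NO_AUTH enables anonymous access.
        .with_prop(GCS_NO_AUTH, "true")
        .build()?;
    Ok(())
}
```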
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use opendal::services::MemoryConfig; +use opendal::Operator; + +use crate::Result; + +pub(crate) fn memory_config_build() -> Result { + Ok(Operator::from_config(MemoryConfig::default())?.finish()) +} diff --git a/crates/iceberg/src/io/storage_s3.rs b/crates/iceberg/src/io/storage_s3.rs new file mode 100644 index 000000000..60e97ab45 --- /dev/null +++ b/crates/iceberg/src/io/storage_s3.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use opendal::raw::HttpClient; +use opendal::services::S3Config; +use opendal::{Configurator, Operator}; +use url::Url; + +use crate::{Error, ErrorKind, Result}; + +/// Following are arguments for [s3 file io](https://py.iceberg.apache.org/configuration/#s3). +/// S3 endpoint. +pub const S3_ENDPOINT: &str = "s3.endpoint"; +/// S3 access key id. +pub const S3_ACCESS_KEY_ID: &str = "s3.access-key-id"; +/// S3 secret access key. +pub const S3_SECRET_ACCESS_KEY: &str = "s3.secret-access-key"; +/// S3 session token. +/// This is required when using temporary credentials. +pub const S3_SESSION_TOKEN: &str = "s3.session-token"; +/// S3 region. +pub const S3_REGION: &str = "s3.region"; +/// Region to use for the S3 client. +/// +/// This takes precedence over [`S3_REGION`]. +pub const CLIENT_REGION: &str = "client.region"; +/// S3 Path Style Access. +pub const S3_PATH_STYLE_ACCESS: &str = "s3.path-style-access"; +/// S3 Server Side Encryption Type. +pub const S3_SSE_TYPE: &str = "s3.sse.type"; +/// S3 Server Side Encryption Key. +/// If S3 encryption type is kms, input is a KMS Key ID. +/// In case this property is not set, default key "aws/s3" is used. +/// If encryption type is custom, input is a custom base-64 AES256 symmetric key. +pub const S3_SSE_KEY: &str = "s3.sse.key"; +/// S3 Server Side Encryption MD5. +pub const S3_SSE_MD5: &str = "s3.sse.md5"; +/// If set, all AWS clients will assume a role of the given ARN, instead of using the default +/// credential chain. +pub const S3_ASSUME_ROLE_ARN: &str = "client.assume-role.arn"; +/// Optional external ID used to assume an IAM role. +pub const S3_ASSUME_ROLE_EXTERNAL_ID: &str = "client.assume-role.external-id"; +/// Optional session name used to assume an IAM role. +pub const S3_ASSUME_ROLE_SESSION_NAME: &str = "client.assume-role.session-name"; + +/// Parse iceberg props to s3 config. 
+pub(crate) fn s3_config_parse(mut m: HashMap) -> Result { + let mut cfg = S3Config::default(); + if let Some(endpoint) = m.remove(S3_ENDPOINT) { + cfg.endpoint = Some(endpoint); + }; + if let Some(access_key_id) = m.remove(S3_ACCESS_KEY_ID) { + cfg.access_key_id = Some(access_key_id); + }; + if let Some(secret_access_key) = m.remove(S3_SECRET_ACCESS_KEY) { + cfg.secret_access_key = Some(secret_access_key); + }; + if let Some(session_token) = m.remove(S3_SESSION_TOKEN) { + cfg.session_token = Some(session_token); + }; + if let Some(region) = m.remove(S3_REGION) { + cfg.region = Some(region); + }; + if let Some(region) = m.remove(CLIENT_REGION) { + cfg.region = Some(region); + }; + if let Some(path_style_access) = m.remove(S3_PATH_STYLE_ACCESS) { + if ["true", "True", "1"].contains(&path_style_access.as_str()) { + cfg.enable_virtual_host_style = true; + } + }; + if let Some(arn) = m.remove(S3_ASSUME_ROLE_ARN) { + cfg.role_arn = Some(arn); + } + if let Some(external_id) = m.remove(S3_ASSUME_ROLE_EXTERNAL_ID) { + cfg.external_id = Some(external_id); + }; + if let Some(session_name) = m.remove(S3_ASSUME_ROLE_SESSION_NAME) { + cfg.role_session_name = Some(session_name); + }; + let s3_sse_key = m.remove(S3_SSE_KEY); + if let Some(sse_type) = m.remove(S3_SSE_TYPE) { + match sse_type.to_lowercase().as_str() { + // No Server Side Encryption + "none" => {} + // S3 SSE-S3 encryption (S3 managed keys). https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingServerSideEncryption.html + "s3" => { + cfg.server_side_encryption = Some("AES256".to_string()); + } + // S3 SSE KMS, either using default or custom KMS key. https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingKMSEncryption.html + "kms" => { + cfg.server_side_encryption = Some("aws:kms".to_string()); + cfg.server_side_encryption_aws_kms_key_id = s3_sse_key; + } + // S3 SSE-C, using customer managed keys. https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + "custom" => { + cfg.server_side_encryption_customer_algorithm = Some("AES256".to_string()); + cfg.server_side_encryption_customer_key = s3_sse_key; + cfg.server_side_encryption_customer_key_md5 = m.remove(S3_SSE_MD5); + } + _ => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid {}: {}. Expected one of (custom, kms, s3, none)", + S3_SSE_TYPE, sse_type + ), + )); + } + } + }; + + Ok(cfg) +} + +/// Build new opendal operator from give path. +pub(crate) fn s3_config_build( + client: &reqwest::Client, + cfg: &S3Config, + path: &str, +) -> Result { + let url = Url::parse(path)?; + let bucket = url.host_str().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid s3 url: {}, missing bucket", path), + ) + })?; + + let builder = cfg + .clone() + .into_builder() + // Set bucket name. + .bucket(bucket) + // Set http client we want to use. + .http_client(HttpClient::with(client.clone())); + + Ok(Operator::new(builder)?.finish()) +} diff --git a/crates/iceberg/src/lib.rs b/crates/iceberg/src/lib.rs index 378164838..d6c5010d3 100644 --- a/crates/iceberg/src/lib.rs +++ b/crates/iceberg/src/lib.rs @@ -15,7 +15,41 @@ // specific language governing permissions and limitations // under the License. -//! Native Rust implementation of Apache Iceberg +//! Apache Iceberg Official Native Rust Implementation +//! +//! # Examples +//! +//! ## Scan A Table +//! +//! ```rust, no_run +//! use futures::TryStreamExt; +//! use iceberg::io::{FileIO, FileIOBuilder}; +//! use iceberg::{Catalog, Result, TableIdent}; +//! 
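The SSE branch above maps the `s3.sse.*` properties onto the underlying client settings; for example, requesting SSE-KMS looks like this (the key ARN and region are placeholders):

```rust
use iceberg::io::{FileIOBuilder, S3_REGION, S3_SSE_KEY, S3_SSE_TYPE};
use iceberg::Result;

fn build_s3_file_io_with_kms() -> Result<()> {
    let _file_io = FileIOBuilder::new("s3")
        .with_prop(S3_REGION, "us-east-1")
        // "kms" selects SSE-KMS; S3_SSE_KEY carries the (placeholder) key id.
        .with_prop(S3_SSE_TYPE, "kms")
        .with_prop(S3_SSE_KEY, "arn:aws:kms:us-east-1:111122223333:key/example")
        .build()?;
    Ok(())
}
```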
use iceberg_catalog_memory::MemoryCatalog; +//! +//! #[tokio::main] +//! async fn main() -> Result<()> { +//! // Build your file IO. +//! let file_io = FileIOBuilder::new("memory").build()?; +//! // Connect to a catalog. +//! let catalog = MemoryCatalog::new(file_io, None); +//! // Load table from catalog. +//! let table = catalog +//! .load_table(&TableIdent::from_strs(["hello", "world"])?) +//! .await?; +//! // Build table scan. +//! let stream = table +//! .scan() +//! .select(["name", "id"]) +//! .build()? +//! .to_arrow() +//! .await?; +//! +//! // Consume this stream like arrow record batch stream. +//! let _data: Vec<_> = stream.try_collect().await?; +//! Ok(()) +//! } +//! ``` #![deny(missing_docs)] @@ -23,26 +57,29 @@ extern crate derive_builder; mod error; -pub use error::Error; -pub use error::ErrorKind; -pub use error::Result; +pub use error::{Error, ErrorKind, Result}; mod catalog; -pub use catalog::Catalog; -pub use catalog::Namespace; -pub use catalog::NamespaceIdent; -pub use catalog::TableCommit; -pub use catalog::TableCreation; -pub use catalog::TableIdent; -pub use catalog::TableRequirement; -pub use catalog::TableUpdate; - -#[allow(dead_code)] + +pub use catalog::{ + Catalog, Namespace, NamespaceIdent, TableCommit, TableCreation, TableIdent, TableRequirement, + TableUpdate, ViewCreation, +}; + pub mod table; mod avro; pub mod io; pub mod spec; +pub mod scan; + +pub mod expr; pub mod transaction; pub mod transform; + +mod runtime; + +pub mod arrow; +mod utils; +pub mod writer; diff --git a/crates/iceberg/src/runtime/mod.rs b/crates/iceberg/src/runtime/mod.rs new file mode 100644 index 000000000..65c30e82c --- /dev/null +++ b/crates/iceberg/src/runtime/mod.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This module contains the async runtime abstraction for iceberg. 
+ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub enum JoinHandle { + #[cfg(feature = "tokio")] + Tokio(tokio::task::JoinHandle), + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + AsyncStd(async_std::task::JoinHandle), + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + Unimplemented(Box), +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + #[cfg(feature = "tokio")] + JoinHandle::Tokio(handle) => Pin::new(handle) + .poll(cx) + .map(|h| h.expect("tokio spawned task failed")), + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + JoinHandle::AsyncStd(handle) => Pin::new(handle).poll(cx), + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + JoinHandle::Unimplemented(_) => unimplemented!("no runtime has been enabled"), + } + } +} + +#[allow(dead_code)] +pub fn spawn(f: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "tokio")] + return JoinHandle::Tokio(tokio::task::spawn(f)); + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + return JoinHandle::AsyncStd(async_std::task::spawn(f)); + + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + unimplemented!("no runtime has been enabled") +} + +#[allow(dead_code)] +pub fn spawn_blocking(f: F) -> JoinHandle +where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, +{ + #[cfg(feature = "tokio")] + return JoinHandle::Tokio(tokio::task::spawn_blocking(f)); + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f)); + + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + unimplemented!("no runtime has been enabled") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_tokio_spawn() { + let handle = spawn(async { 1 + 1 }); + assert_eq!(handle.await, 2); + } + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_tokio_spawn_blocking() { + let handle = spawn_blocking(|| 1 + 1); + assert_eq!(handle.await, 2); + } + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + #[async_std::test] + async fn test_async_std_spawn() { + let handle = spawn(async { 1 + 1 }); + assert_eq!(handle.await, 2); + } + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + #[async_std::test] + async fn test_async_std_spawn_blocking() { + let handle = spawn_blocking(|| 1 + 1); + assert_eq!(handle.await, 2); + } +} diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs new file mode 100644 index 000000000..f1cb86ab3 --- /dev/null +++ b/crates/iceberg/src/scan.rs @@ -0,0 +1,1610 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Table scan api. + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use arrow_array::RecordBatch; +use futures::channel::mpsc::{channel, Sender}; +use futures::stream::BoxStream; +use futures::{SinkExt, StreamExt, TryFutureExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::arrow::ArrowReaderBuilder; +use crate::expr::visitors::expression_evaluator::ExpressionEvaluator; +use crate::expr::visitors::inclusive_metrics_evaluator::InclusiveMetricsEvaluator; +use crate::expr::visitors::inclusive_projection::InclusiveProjection; +use crate::expr::visitors::manifest_evaluator::ManifestEvaluator; +use crate::expr::{Bind, BoundPredicate, Predicate}; +use crate::io::object_cache::ObjectCache; +use crate::io::FileIO; +use crate::runtime::spawn; +use crate::spec::{ + DataContentType, DataFileFormat, ManifestContentType, ManifestEntryRef, ManifestFile, + ManifestList, Schema, SchemaRef, SnapshotRef, TableMetadataRef, +}; +use crate::table::Table; +use crate::utils::available_parallelism; +use crate::{Error, ErrorKind, Result}; + +/// A stream of [`FileScanTask`]. +pub type FileScanTaskStream = BoxStream<'static, Result>; +/// A stream of arrow [`RecordBatch`]es. +pub type ArrowRecordBatchStream = BoxStream<'static, Result>; + +/// Builder to create table scan. +pub struct TableScanBuilder<'a> { + table: &'a Table, + // Empty column names means to select all columns + column_names: Vec, + snapshot_id: Option, + batch_size: Option, + case_sensitive: bool, + filter: Option, + concurrency_limit_data_files: usize, + concurrency_limit_manifest_entries: usize, + concurrency_limit_manifest_files: usize, + row_group_filtering_enabled: bool, +} + +impl<'a> TableScanBuilder<'a> { + pub(crate) fn new(table: &'a Table) -> Self { + let num_cpus = available_parallelism().get(); + + Self { + table, + column_names: vec![], + snapshot_id: None, + batch_size: None, + case_sensitive: true, + filter: None, + concurrency_limit_data_files: num_cpus, + concurrency_limit_manifest_entries: num_cpus, + concurrency_limit_manifest_files: num_cpus, + row_group_filtering_enabled: true, + } + } + + /// Sets the desired size of batches in the response + /// to something other than the default + pub fn with_batch_size(mut self, batch_size: Option) -> Self { + self.batch_size = batch_size; + self + } + + /// Sets the scan's case sensitivity + pub fn with_case_sensitive(mut self, case_sensitive: bool) -> Self { + self.case_sensitive = case_sensitive; + self + } + + /// Specifies a predicate to use as a filter + pub fn with_filter(mut self, predicate: Predicate) -> Self { + // calls rewrite_not to remove Not nodes, which must be absent + // when applying the manifest evaluator + self.filter = Some(predicate.rewrite_not()); + self + } + + /// Select all columns. + pub fn select_all(mut self) -> Self { + self.column_names.clear(); + self + } + + /// Select some columns of the table. + pub fn select(mut self, column_names: impl IntoIterator) -> Self { + self.column_names = column_names + .into_iter() + .map(|item| item.to_string()) + .collect(); + self + } + + /// Set the snapshot to scan. When not set, it uses current snapshot. 
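Combining the builder setters above, a sketch of a filtered, projected scan; it assumes the `Reference` and `Datum` helpers exported from the `expr` and `spec` modules, and the column names are placeholders:

```rust
use futures::TryStreamExt;
use iceberg::expr::Reference;
use iceberg::spec::Datum;
use iceberg::table::Table;
use iceberg::Result;

async fn scan_filtered(table: &Table) -> Result<()> {
    let stream = table
        .scan()
        .select(["id", "name"])
        .with_filter(Reference::new("id").greater_than(Datum::long(100)))
        .with_batch_size(Some(1024))
        .with_case_sensitive(true)
        .build()?
        .to_arrow()
        .await?;

    let _batches: Vec<_> = stream.try_collect().await?;
    Ok(())
}
```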
+ pub fn snapshot_id(mut self, snapshot_id: i64) -> Self { + self.snapshot_id = Some(snapshot_id); + self + } + + /// Sets the concurrency limit for both manifest files and manifest + /// entries for this scan + pub fn with_concurrency_limit(mut self, limit: usize) -> Self { + self.concurrency_limit_manifest_files = limit; + self.concurrency_limit_manifest_entries = limit; + self.concurrency_limit_data_files = limit; + self + } + + /// Sets the data file concurrency limit for this scan + pub fn with_data_file_concurrency_limit(mut self, limit: usize) -> Self { + self.concurrency_limit_data_files = limit; + self + } + + /// Sets the manifest entry concurrency limit for this scan + pub fn with_manifest_entry_concurrency_limit(mut self, limit: usize) -> Self { + self.concurrency_limit_manifest_entries = limit; + self + } + + /// Determines whether to enable row group filtering. + /// When enabled, if a read is performed with a filter predicate, + /// then the metadata for each row group in the parquet file is + /// evaluated against the filter predicate and row groups + /// that cant contain matching rows will be skipped entirely. + /// + /// Defaults to enabled, as it generally improves performance or + /// keeps it the same, with performance degradation unlikely. + pub fn with_row_group_filtering_enabled(mut self, row_group_filtering_enabled: bool) -> Self { + self.row_group_filtering_enabled = row_group_filtering_enabled; + self + } + + /// Build the table scan. + pub fn build(self) -> Result { + let snapshot = match self.snapshot_id { + Some(snapshot_id) => self + .table + .metadata() + .snapshot_by_id(snapshot_id) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Snapshot with id {} not found", snapshot_id), + ) + })? + .clone(), + None => self + .table + .metadata() + .current_snapshot() + .ok_or_else(|| { + Error::new( + ErrorKind::FeatureUnsupported, + "Can't scan table without snapshots", + ) + })? + .clone(), + }; + + let schema = snapshot.schema(self.table.metadata())?; + + // Check that all column names exist in the schema. + if !self.column_names.is_empty() { + for column_name in &self.column_names { + if schema.field_by_name(column_name).is_none() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Column {} not found in table. Schema: {}", + column_name, schema + ), + )); + } + } + } + + let mut field_ids = vec![]; + for column_name in &self.column_names { + let field_id = schema.field_id_by_name(column_name).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Column {} not found in table. Schema: {}", + column_name, schema + ), + ) + })?; + + let field = schema + .as_struct() + .field_by_id(field_id) + .ok_or_else(|| { + Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}", + column_name, schema + ), + ) + })?; + + if !field.field_type.is_primitive() { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Column {} is not a primitive type. Schema: {}", + column_name, schema + ), + )); + } + + field_ids.push(field_id); + } + + let snapshot_bound_predicate = if let Some(ref predicates) = self.filter { + Some(predicates.bind(schema.clone(), true)?) 
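The concurrency and row-group-filtering knobs can likewise be set before `build`; a short sketch (the limit value is arbitrary):

```rust
use iceberg::scan::TableScan;
use iceberg::table::Table;
use iceberg::Result;

fn tuned_scan(table: &Table) -> Result<TableScan> {
    table
        .scan()
        // Cap manifest files, manifest entries and data files at 4 concurrent tasks.
        .with_concurrency_limit(4)
        // Disable row-group pruning, e.g. while debugging filter behaviour.
        .with_row_group_filtering_enabled(false)
        .build()
}
```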
+ } else { + None + }; + + let plan_context = PlanContext { + snapshot, + table_metadata: self.table.metadata_ref(), + snapshot_schema: schema, + case_sensitive: self.case_sensitive, + predicate: self.filter.map(Arc::new), + snapshot_bound_predicate: snapshot_bound_predicate.map(Arc::new), + object_cache: self.table.object_cache(), + field_ids: Arc::new(field_ids), + partition_filter_cache: Arc::new(PartitionFilterCache::new()), + manifest_evaluator_cache: Arc::new(ManifestEvaluatorCache::new()), + expression_evaluator_cache: Arc::new(ExpressionEvaluatorCache::new()), + }; + + Ok(TableScan { + batch_size: self.batch_size, + column_names: self.column_names, + file_io: self.table.file_io().clone(), + plan_context, + concurrency_limit_data_files: self.concurrency_limit_data_files, + concurrency_limit_manifest_entries: self.concurrency_limit_manifest_entries, + concurrency_limit_manifest_files: self.concurrency_limit_manifest_files, + row_group_filtering_enabled: self.row_group_filtering_enabled, + }) + } +} + +/// Table scan. +#[derive(Debug)] +pub struct TableScan { + plan_context: PlanContext, + batch_size: Option, + file_io: FileIO, + column_names: Vec, + /// The maximum number of manifest files that will be + /// retrieved from [`FileIO`] concurrently + concurrency_limit_manifest_files: usize, + + /// The maximum number of [`ManifestEntry`]s that will + /// be processed in parallel + concurrency_limit_manifest_entries: usize, + + /// The maximum number of [`ManifestEntry`]s that will + /// be processed in parallel + concurrency_limit_data_files: usize, + + row_group_filtering_enabled: bool, +} + +/// PlanContext wraps a [`SnapshotRef`] alongside all the other +/// objects that are required to perform a scan file plan. +#[derive(Debug)] +struct PlanContext { + snapshot: SnapshotRef, + + table_metadata: TableMetadataRef, + snapshot_schema: SchemaRef, + case_sensitive: bool, + predicate: Option>, + snapshot_bound_predicate: Option>, + object_cache: Arc, + field_ids: Arc>, + + partition_filter_cache: Arc, + manifest_evaluator_cache: Arc, + expression_evaluator_cache: Arc, +} + +impl TableScan { + /// Returns a stream of [`FileScanTask`]s. 
+ pub async fn plan_files(&self) -> Result { + let concurrency_limit_manifest_files = self.concurrency_limit_manifest_files; + let concurrency_limit_manifest_entries = self.concurrency_limit_manifest_entries; + + // used to stream ManifestEntryContexts between stages of the file plan operation + let (manifest_entry_ctx_tx, manifest_entry_ctx_rx) = + channel(concurrency_limit_manifest_files); + // used to stream the results back to the caller + let (file_scan_task_tx, file_scan_task_rx) = channel(concurrency_limit_manifest_entries); + + let manifest_list = self.plan_context.get_manifest_list().await?; + + // get the [`ManifestFile`]s from the [`ManifestList`], filtering out any + // whose content type is not Data or whose partitions cannot match this + // scan's filter + let manifest_file_contexts = self + .plan_context + .build_manifest_file_contexts(manifest_list, manifest_entry_ctx_tx)?; + + let mut channel_for_manifest_error = file_scan_task_tx.clone(); + + // Concurrently load all [`Manifest`]s and stream their [`ManifestEntry`]s + spawn(async move { + let result = futures::stream::iter(manifest_file_contexts) + .try_for_each_concurrent(concurrency_limit_manifest_files, |ctx| async move { + ctx.fetch_manifest_and_stream_manifest_entries().await + }) + .await; + + if let Err(error) = result { + let _ = channel_for_manifest_error.send(Err(error)).await; + } + }); + + let mut channel_for_manifest_entry_error = file_scan_task_tx.clone(); + + // Process the [`ManifestEntry`] stream in parallel + spawn(async move { + let result = manifest_entry_ctx_rx + .map(|me_ctx| Ok((me_ctx, file_scan_task_tx.clone()))) + .try_for_each_concurrent( + concurrency_limit_manifest_entries, + |(manifest_entry_context, tx)| async move { + spawn(async move { + Self::process_manifest_entry(manifest_entry_context, tx).await + }) + .await + }, + ) + .await; + + if let Err(error) = result { + let _ = channel_for_manifest_entry_error.send(Err(error)).await; + } + }); + + return Ok(file_scan_task_rx.boxed()); + } + + /// Returns an [`ArrowRecordBatchStream`]. + pub async fn to_arrow(&self) -> Result { + let mut arrow_reader_builder = ArrowReaderBuilder::new(self.file_io.clone()) + .with_data_file_concurrency_limit(self.concurrency_limit_data_files) + .with_row_group_filtering_enabled(self.row_group_filtering_enabled); + + if let Some(batch_size) = self.batch_size { + arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size); + } + + arrow_reader_builder.build().read(self.plan_files().await?) + } + + /// Returns a reference to the column names of the table scan. + pub fn column_names(&self) -> &[String] { + &self.column_names + } + /// Returns a reference to the snapshot of the table scan. 
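Because planning is decoupled from execution, `plan_files` can be used on its own to inspect which data files a scan would touch; a brief sketch:

```rust
use futures::TryStreamExt;
use iceberg::table::Table;
use iceberg::Result;

async fn inspect_plan(table: &Table) -> Result<()> {
    // `plan_files` yields `FileScanTask`s without reading any data files.
    let scan = table.scan().select_all().build()?;
    let tasks: Vec<_> = scan.plan_files().await?.try_collect().await?;

    for task in &tasks {
        println!("{} ({} bytes)", task.data_file_path, task.length);
    }
    Ok(())
}
```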
+ pub fn snapshot(&self) -> &SnapshotRef { + &self.plan_context.snapshot + } + + async fn process_manifest_entry( + manifest_entry_context: ManifestEntryContext, + mut file_scan_task_tx: Sender>, + ) -> Result<()> { + // skip processing this manifest entry if it has been marked as deleted + if !manifest_entry_context.manifest_entry.is_alive() { + return Ok(()); + } + + // abort the plan if we encounter a manifest entry whose data file's + // content type is currently unsupported + if manifest_entry_context.manifest_entry.content_type() != DataContentType::Data { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + "Only Data files currently supported", + )); + } + + if let Some(ref bound_predicates) = manifest_entry_context.bound_predicates { + let BoundPredicates { + ref snapshot_bound_predicate, + ref partition_bound_predicate, + } = bound_predicates.as_ref(); + + let expression_evaluator_cache = + manifest_entry_context.expression_evaluator_cache.as_ref(); + + let expression_evaluator = expression_evaluator_cache.get( + manifest_entry_context.partition_spec_id, + partition_bound_predicate, + )?; + + // skip any data file whose partition data indicates that it can't contain + // any data that matches this scan's filter + if !expression_evaluator.eval(manifest_entry_context.manifest_entry.data_file())? { + return Ok(()); + } + + // skip any data file whose metrics don't match this scan's filter + if !InclusiveMetricsEvaluator::eval( + snapshot_bound_predicate, + manifest_entry_context.manifest_entry.data_file(), + false, + )? { + return Ok(()); + } + } + + // congratulations! the manifest entry has made its way through the + // entire plan without getting filtered out. Create a corresponding + // FileScanTask and push it to the result stream + file_scan_task_tx + .send(Ok(manifest_entry_context.into_file_scan_task())) + .await?; + + Ok(()) + } +} + +struct BoundPredicates { + partition_bound_predicate: BoundPredicate, + snapshot_bound_predicate: BoundPredicate, +} + +/// Wraps a [`ManifestFile`] alongside the objects that are needed +/// to process it in a thread-safe manner +struct ManifestFileContext { + manifest_file: ManifestFile, + + sender: Sender, + + field_ids: Arc>, + bound_predicates: Option>, + object_cache: Arc, + snapshot_schema: SchemaRef, + expression_evaluator_cache: Arc, +} + +/// Wraps a [`ManifestEntryRef`] alongside the objects that are needed +/// to process it in a thread-safe manner +struct ManifestEntryContext { + manifest_entry: ManifestEntryRef, + + expression_evaluator_cache: Arc, + field_ids: Arc>, + bound_predicates: Option>, + partition_spec_id: i32, + snapshot_schema: SchemaRef, +} + +impl ManifestFileContext { + /// Consumes this [`ManifestFileContext`], fetching its Manifest from FileIO and then + /// streaming its constituent [`ManifestEntries`] to the channel provided in the context + async fn fetch_manifest_and_stream_manifest_entries(self) -> Result<()> { + let ManifestFileContext { + object_cache, + manifest_file, + bound_predicates, + snapshot_schema, + field_ids, + mut sender, + expression_evaluator_cache, + .. 
+ } = self; + + let manifest = object_cache.get_manifest(&manifest_file).await?; + + for manifest_entry in manifest.entries() { + let manifest_entry_context = ManifestEntryContext { + // TODO: refactor to avoid clone + manifest_entry: manifest_entry.clone(), + expression_evaluator_cache: expression_evaluator_cache.clone(), + field_ids: field_ids.clone(), + partition_spec_id: manifest_file.partition_spec_id, + bound_predicates: bound_predicates.clone(), + snapshot_schema: snapshot_schema.clone(), + }; + + sender + .send(manifest_entry_context) + .map_err(|_| Error::new(ErrorKind::Unexpected, "mpsc channel SendError")) + .await?; + } + + Ok(()) + } +} + +impl ManifestEntryContext { + /// consume this `ManifestEntryContext`, returning a `FileScanTask` + /// created from it + fn into_file_scan_task(self) -> FileScanTask { + FileScanTask { + start: 0, + length: self.manifest_entry.file_size_in_bytes(), + record_count: Some(self.manifest_entry.record_count()), + + data_file_path: self.manifest_entry.file_path().to_string(), + data_file_content: self.manifest_entry.content_type(), + data_file_format: self.manifest_entry.file_format(), + + schema: self.snapshot_schema, + project_field_ids: self.field_ids.to_vec(), + predicate: self + .bound_predicates + .map(|x| x.as_ref().snapshot_bound_predicate.clone()), + } + } +} + +impl PlanContext { + async fn get_manifest_list(&self) -> Result> { + self.object_cache + .as_ref() + .get_manifest_list(&self.snapshot, &self.table_metadata) + .await + } + + fn get_partition_filter(&self, manifest_file: &ManifestFile) -> Result> { + let partition_spec_id = manifest_file.partition_spec_id; + + let partition_filter = self.partition_filter_cache.get( + partition_spec_id, + &self.table_metadata, + &self.snapshot_schema, + self.case_sensitive, + self.predicate + .as_ref() + .ok_or(Error::new( + ErrorKind::Unexpected, + "Expected a predicate but none present", + ))? + .as_ref() + .bind(self.snapshot_schema.clone(), self.case_sensitive)?, + )?; + + Ok(partition_filter) + } + + fn build_manifest_file_contexts( + &self, + manifest_list: Arc, + sender: Sender, + ) -> Result>>> { + let filtered_entries = manifest_list + .entries() + .iter() + .filter(|manifest_file| manifest_file.content == ManifestContentType::Data); + + // TODO: Ideally we could ditch this intermediate Vec as we return an iterator. + let mut filtered_mfcs = vec![]; + if self.predicate.is_some() { + for manifest_file in filtered_entries { + let partition_bound_predicate = self.get_partition_filter(manifest_file)?; + + // evaluate the ManifestFile against the partition filter. Skip + // if it cannot contain any matching rows + if self + .manifest_evaluator_cache + .get( + manifest_file.partition_spec_id, + partition_bound_predicate.clone(), + ) + .eval(manifest_file)? 
+                {
+                    let mfc = self.create_manifest_file_context(
+                        manifest_file,
+                        Some(partition_bound_predicate),
+                        sender.clone(),
+                    );
+                    filtered_mfcs.push(Ok(mfc));
+                }
+            }
+        } else {
+            for manifest_file in filtered_entries {
+                let mfc = self.create_manifest_file_context(manifest_file, None, sender.clone());
+                filtered_mfcs.push(Ok(mfc));
+            }
+        }
+
+        Ok(Box::new(filtered_mfcs.into_iter()))
+    }
+
+    fn create_manifest_file_context(
+        &self,
+        manifest_file: &ManifestFile,
+        partition_filter: Option<Arc<BoundPredicate>>,
+        sender: Sender<ManifestEntryContext>,
+    ) -> ManifestFileContext {
+        let bound_predicates =
+            if let (Some(ref partition_bound_predicate), Some(snapshot_bound_predicate)) =
+                (partition_filter, &self.snapshot_bound_predicate)
+            {
+                Some(Arc::new(BoundPredicates {
+                    partition_bound_predicate: partition_bound_predicate.as_ref().clone(),
+                    snapshot_bound_predicate: snapshot_bound_predicate.as_ref().clone(),
+                }))
+            } else {
+                None
+            };
+
+        ManifestFileContext {
+            manifest_file: manifest_file.clone(),
+            bound_predicates,
+            sender,
+            object_cache: self.object_cache.clone(),
+            snapshot_schema: self.snapshot_schema.clone(),
+            field_ids: self.field_ids.clone(),
+            expression_evaluator_cache: self.expression_evaluator_cache.clone(),
+        }
+    }
+}
+
+/// Manages the caching of [`BoundPredicate`] objects
+/// for [`PartitionSpec`]s based on partition spec id.
+#[derive(Debug)]
+struct PartitionFilterCache(RwLock<HashMap<i32, Arc<BoundPredicate>>>);
+
+impl PartitionFilterCache {
+    /// Creates a new [`PartitionFilterCache`]
+    /// with an empty internal HashMap.
+    fn new() -> Self {
+        Self(RwLock::new(HashMap::new()))
+    }
+
+    /// Retrieves a [`BoundPredicate`] from the cache
+    /// or computes it if not present.
+    fn get(
+        &self,
+        spec_id: i32,
+        table_metadata: &TableMetadataRef,
+        schema: &SchemaRef,
+        case_sensitive: bool,
+        filter: BoundPredicate,
+    ) -> Result<Arc<BoundPredicate>> {
+        // we need a block here to ensure that the `read()` gets dropped before we hit the `write()`
+        // below, otherwise we hit deadlock
+        {
+            let read = self.0.read().map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "PartitionFilterCache RwLock was poisoned",
+                )
+            })?;
+
+            if read.contains_key(&spec_id) {
+                return Ok(read.get(&spec_id).unwrap().clone());
+            }
+        }
+
+        let partition_spec = table_metadata
+            .partition_spec_by_id(spec_id)
+            .ok_or(Error::new(
+                ErrorKind::Unexpected,
+                format!("Could not find partition spec for id {}", spec_id),
+            ))?;
+
+        let partition_type = partition_spec.partition_type(schema.as_ref())?;
+        let partition_fields = partition_type.fields().to_owned();
+        let partition_schema = Arc::new(
+            Schema::builder()
+                .with_schema_id(partition_spec.spec_id)
+                .with_fields(partition_fields)
+                .build()?,
+        );
+
+        let mut inclusive_projection = InclusiveProjection::new(partition_spec.clone());
+
+        let partition_filter = inclusive_projection
+            .project(&filter)?
+            .rewrite_not()
+            .bind(partition_schema.clone(), case_sensitive)?;
+
+        self.0
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "PartitionFilterCache RwLock was poisoned",
+                )
+            })?
+            .insert(spec_id, Arc::new(partition_filter));
+
+        let read = self.0.read().map_err(|_| {
+            Error::new(
+                ErrorKind::Unexpected,
+                "PartitionFilterCache RwLock was poisoned",
+            )
+        })?;
+
+        Ok(read.get(&spec_id).unwrap().clone())
+    }
+}
+
+/// Manages the caching of [`ManifestEvaluator`] objects
+/// for [`PartitionSpec`]s based on partition spec id.
+#[derive(Debug)]
+struct ManifestEvaluatorCache(RwLock<HashMap<i32, Arc<ManifestEvaluator>>>);
+
+impl ManifestEvaluatorCache {
+    /// Creates a new [`ManifestEvaluatorCache`]
+    /// with an empty internal HashMap.
+    fn new() -> Self {
+        Self(RwLock::new(HashMap::new()))
+    }
+
+    /// Retrieves a [`ManifestEvaluator`] from the cache
+    /// or computes it if not present.
+    fn get(&self, spec_id: i32, partition_filter: Arc<BoundPredicate>) -> Arc<ManifestEvaluator> {
+        // we need a block here to ensure that the `read()` gets dropped before we hit the `write()`
+        // below, otherwise we hit deadlock
+        {
+            let read = self
+                .0
+                .read()
+                .map_err(|_| {
+                    Error::new(
+                        ErrorKind::Unexpected,
+                        "ManifestEvaluatorCache RwLock was poisoned",
+                    )
+                })
+                .unwrap();
+
+            if read.contains_key(&spec_id) {
+                return read.get(&spec_id).unwrap().clone();
+            }
+        }
+
+        self.0
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ManifestEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap()
+            .insert(
+                spec_id,
+                Arc::new(ManifestEvaluator::new(partition_filter.as_ref().clone())),
+            );
+
+        let read = self
+            .0
+            .read()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ManifestEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap();
+
+        read.get(&spec_id).unwrap().clone()
+    }
+}
+
+/// Manages the caching of [`ExpressionEvaluator`] objects
+/// for [`PartitionSpec`]s based on partition spec id.
+#[derive(Debug)]
+struct ExpressionEvaluatorCache(RwLock<HashMap<i32, Arc<ExpressionEvaluator>>>);
+
+impl ExpressionEvaluatorCache {
+    /// Creates a new [`ExpressionEvaluatorCache`]
+    /// with an empty internal HashMap.
+    fn new() -> Self {
+        Self(RwLock::new(HashMap::new()))
+    }
+
+    /// Retrieves an [`ExpressionEvaluator`] from the cache
+    /// or computes it if not present.
+    fn get(
+        &self,
+        spec_id: i32,
+        partition_filter: &BoundPredicate,
+    ) -> Result<Arc<ExpressionEvaluator>> {
+        // we need a block here to ensure that the `read()` gets dropped before we hit the `write()`
+        // below, otherwise we hit deadlock
+        {
+            let read = self.0.read().map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })?;
+
+            if read.contains_key(&spec_id) {
+                return Ok(read.get(&spec_id).unwrap().clone());
+            }
+        }
+
+        self.0
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap()
+            .insert(
+                spec_id,
+                Arc::new(ExpressionEvaluator::new(partition_filter.clone())),
+            );
+
+        let read = self
+            .0
+            .read()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap();
+
+        Ok(read.get(&spec_id).unwrap().clone())
+    }
+}
+
+/// A task to scan part of file.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FileScanTask {
+    /// The start offset of the file to scan.
+    pub start: u64,
+    /// The length of the file to scan.
+    pub length: u64,
+    /// The number of records in the file to scan.
+    ///
+    /// This is an optional field, and only available if we are
+    /// reading the entire data file.
+    pub record_count: Option<u64>,
+
+    /// The data file path corresponding to the task.
+    pub data_file_path: String,
+    /// The content type of the file to scan.
+    pub data_file_content: DataContentType,
+    /// The format of the file to scan.
+    pub data_file_format: DataFileFormat,
+
+    /// The schema of the file to scan.
+    pub schema: SchemaRef,
+    /// The field ids to project.
+    pub project_field_ids: Vec<i32>,
+    /// The predicate to filter.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub predicate: Option<BoundPredicate>,
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::fs;
+    use std::fs::File;
+    use std::sync::Arc;
+
+    use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
+    use futures::{stream, TryStreamExt};
+    use parquet::arrow::{ArrowWriter, PARQUET_FIELD_ID_META_KEY};
+    use parquet::basic::Compression;
+    use parquet::file::properties::WriterProperties;
+    use tempfile::TempDir;
+    use tera::{Context, Tera};
+    use uuid::Uuid;
+
+    use crate::arrow::ArrowReaderBuilder;
+    use crate::expr::{BoundPredicate, Reference};
+    use crate::io::{FileIO, OutputFile};
+    use crate::scan::FileScanTask;
+    use crate::spec::{
+        DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest,
+        ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus,
+        ManifestWriter, NestedField, PrimitiveType, Schema, Struct, TableMetadata, Type,
+        EMPTY_SNAPSHOT_ID,
+    };
+    use crate::table::Table;
+    use crate::TableIdent;
+
+    struct TableTestFixture {
+        table_location: String,
+        table: Table,
+    }
+
+    impl TableTestFixture {
+        fn new() -> Self {
+            let tmp_dir = TempDir::new().unwrap();
+            let table_location = tmp_dir.path().join("table1");
+            let manifest_list1_location = table_location.join("metadata/manifests_list_1.avro");
+            let manifest_list2_location = table_location.join("metadata/manifests_list_2.avro");
+            let table_metadata1_location = table_location.join("metadata/v1.json");
+
+            let file_io = FileIO::from_path(table_location.as_os_str().to_str().unwrap())
+                .unwrap()
+                .build()
+                .unwrap();
+
+            let table_metadata = {
+                let template_json_str = fs::read_to_string(format!(
+                    "{}/testdata/example_table_metadata_v2.json",
+                    env!("CARGO_MANIFEST_DIR")
+                ))
+                .unwrap();
+                let mut context = Context::new();
+                context.insert("table_location", &table_location);
+                context.insert("manifest_list_1_location", &manifest_list1_location);
+                context.insert("manifest_list_2_location", &manifest_list2_location);
+                context.insert("table_metadata_1_location", &table_metadata1_location);
+
+                let metadata_json = Tera::one_off(&template_json_str, &context, false).unwrap();
+                serde_json::from_str::<TableMetadata>(&metadata_json).unwrap()
+            };
+
+            let table = Table::builder()
+                .metadata(table_metadata)
+                .identifier(TableIdent::from_strs(["db", "table1"]).unwrap())
+                .file_io(file_io.clone())
+                .metadata_location(table_metadata1_location.as_os_str().to_str().unwrap())
+                .build()
+                .unwrap();
+
+            Self {
+                table_location: table_location.to_str().unwrap().to_string(),
+                table,
+            }
+        }
+
+        fn next_manifest_file(&self) -> OutputFile {
+            self.table
+                .file_io()
+                .new_output(format!(
+                    "{}/metadata/manifest_{}.avro",
+                    self.table_location,
+                    Uuid::new_v4()
+                ))
+                .unwrap()
+        }
+
+        async fn setup_manifest_files(&mut self) {
+            let current_snapshot = self.table.metadata().current_snapshot().unwrap();
+            let parent_snapshot = current_snapshot
+                .parent_snapshot(self.table.metadata())
+                .unwrap();
+            let current_schema = current_snapshot.schema(self.table.metadata()).unwrap();
+            let current_partition_spec = self.table.metadata().default_partition_spec().unwrap();
+
+            // Write data files
+            let data_file_manifest = ManifestWriter::new(
+                self.next_manifest_file(),
+                current_snapshot.snapshot_id(),
+                vec![],
+            )
+            .write(Manifest::new(
+                ManifestMetadata::builder()
+                    .schema((*current_schema).clone())
+                    .content(ManifestContentType::Data)
+                    .format_version(FormatVersion::V2)
+
.partition_spec((**current_partition_spec).clone()) + .schema_id(current_schema.schema_id()) + .build(), + vec![ + ManifestEntry::builder() + .status(ManifestStatus::Added) + .data_file( + DataFileBuilder::default() + .content(DataContentType::Data) + .file_path(format!("{}/1.parquet", &self.table_location)) + .file_format(DataFileFormat::Parquet) + .file_size_in_bytes(100) + .record_count(1) + .partition(Struct::from_iter([Some(Literal::long(100))])) + .build() + .unwrap(), + ) + .build(), + ManifestEntry::builder() + .status(ManifestStatus::Deleted) + .snapshot_id(parent_snapshot.snapshot_id()) + .sequence_number(parent_snapshot.sequence_number()) + .file_sequence_number(parent_snapshot.sequence_number()) + .data_file( + DataFileBuilder::default() + .content(DataContentType::Data) + .file_path(format!("{}/2.parquet", &self.table_location)) + .file_format(DataFileFormat::Parquet) + .file_size_in_bytes(100) + .record_count(1) + .partition(Struct::from_iter([Some(Literal::long(200))])) + .build() + .unwrap(), + ) + .build(), + ManifestEntry::builder() + .status(ManifestStatus::Existing) + .snapshot_id(parent_snapshot.snapshot_id()) + .sequence_number(parent_snapshot.sequence_number()) + .file_sequence_number(parent_snapshot.sequence_number()) + .data_file( + DataFileBuilder::default() + .content(DataContentType::Data) + .file_path(format!("{}/3.parquet", &self.table_location)) + .file_format(DataFileFormat::Parquet) + .file_size_in_bytes(100) + .record_count(1) + .partition(Struct::from_iter([Some(Literal::long(300))])) + .build() + .unwrap(), + ) + .build(), + ], + )) + .await + .unwrap(); + + // Write to manifest list + let mut manifest_list_write = ManifestListWriter::v2( + self.table + .file_io() + .new_output(current_snapshot.manifest_list()) + .unwrap(), + current_snapshot.snapshot_id(), + current_snapshot + .parent_snapshot_id() + .unwrap_or(EMPTY_SNAPSHOT_ID), + current_snapshot.sequence_number(), + ); + manifest_list_write + .add_manifests(vec![data_file_manifest].into_iter()) + .unwrap(); + manifest_list_write.close().await.unwrap(); + + // prepare data + let schema = { + let fields = vec![ + arrow_schema::Field::new("x", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "1".to_string(), + )])), + arrow_schema::Field::new("y", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "2".to_string(), + )])), + arrow_schema::Field::new("z", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )])), + arrow_schema::Field::new("a", arrow_schema::DataType::Utf8, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "4".to_string(), + )])), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + // 4 columns: + // x: [1, 1, 1, 1, ...] 
+ let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; + + let mut values = vec![2; 512]; + values.append(vec![3; 200].as_mut()); + values.append(vec![4; 300].as_mut()); + values.append(vec![5; 12].as_mut()); + + // y: [2, 2, 2, 2, ..., 3, 3, 3, 3, ..., 4, 4, 4, 4, ..., 5, 5, 5, 5] + let col2 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + + let mut values = vec![3; 512]; + values.append(vec![4; 512].as_mut()); + + // z: [3, 3, 3, 3, ..., 4, 4, 4, 4] + let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + + // a: ["Apache", "Apache", "Apache", ..., "Iceberg", "Iceberg", "Iceberg"] + let mut values = vec!["Apache"; 512]; + values.append(vec!["Iceberg"; 512].as_mut()); + let col4 = Arc::new(StringArray::from_iter_values(values)) as ArrayRef; + + let to_write = + RecordBatch::try_new(schema.clone(), vec![col1, col2, col3, col4]).unwrap(); + + // Write the Parquet files + let props = WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(); + + for n in 1..=3 { + let file = File::create(format!("{}/{}.parquet", &self.table_location, n)).unwrap(); + let mut writer = + ArrowWriter::try_new(file, to_write.schema(), Some(props.clone())).unwrap(); + + writer.write(&to_write).expect("Writing batch"); + + // writer must be closed to write footer + writer.close().unwrap(); + } + } + } + + #[test] + fn test_table_scan_columns() { + let table = TableTestFixture::new().table; + + let table_scan = table.scan().select(["x", "y"]).build().unwrap(); + assert_eq!(vec!["x", "y"], table_scan.column_names); + + let table_scan = table + .scan() + .select(["x", "y"]) + .select(["z"]) + .build() + .unwrap(); + assert_eq!(vec!["z"], table_scan.column_names); + } + + #[test] + fn test_select_all() { + let table = TableTestFixture::new().table; + + let table_scan = table.scan().select_all().build().unwrap(); + assert!(table_scan.column_names.is_empty()); + } + + #[test] + fn test_select_no_exist_column() { + let table = TableTestFixture::new().table; + + let table_scan = table.scan().select(["x", "y", "z", "a", "b"]).build(); + assert!(table_scan.is_err()); + } + + #[test] + fn test_table_scan_default_snapshot_id() { + let table = TableTestFixture::new().table; + + let table_scan = table.scan().build().unwrap(); + assert_eq!( + table.metadata().current_snapshot().unwrap().snapshot_id(), + table_scan.snapshot().snapshot_id() + ); + } + + #[test] + fn test_table_scan_non_exist_snapshot_id() { + let table = TableTestFixture::new().table; + + let table_scan = table.scan().snapshot_id(1024).build(); + assert!(table_scan.is_err()); + } + + #[test] + fn test_table_scan_with_snapshot_id() { + let table = TableTestFixture::new().table; + + let table_scan = table + .scan() + .snapshot_id(3051729675574597004) + .build() + .unwrap(); + assert_eq!(table_scan.snapshot().snapshot_id(), 3051729675574597004); + } + + #[tokio::test] + async fn test_plan_files_no_deletions() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for current snapshot and plan files + let table_scan = fixture.table.scan().build().unwrap(); + let mut tasks = table_scan + .plan_files() + .await + .unwrap() + .try_fold(vec![], |mut acc, task| async move { + acc.push(task); + Ok(acc) + }) + .await + .unwrap(); + + assert_eq!(tasks.len(), 2); + + tasks.sort_by_key(|t| t.data_file_path.to_string()); + + // Check first task is added data file + assert_eq!( + tasks[0].data_file_path, + format!("{}/1.parquet", 
&fixture.table_location) + ); + + // Check second task is existing data file + assert_eq!( + tasks[1].data_file_path, + format!("{}/3.parquet", &fixture.table_location) + ); + } + + #[tokio::test] + async fn test_open_parquet_no_deletions() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for current snapshot and plan files + let table_scan = fixture.table.scan().build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + let col = batches[0].column_by_name("x").unwrap(); + + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + } + + #[tokio::test] + async fn test_open_parquet_no_deletions_by_separate_reader() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for current snapshot and plan files + let table_scan = fixture.table.scan().build().unwrap(); + + let mut plan_task: Vec<_> = table_scan + .plan_files() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + assert_eq!(plan_task.len(), 2); + + let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build(); + let batch_stream = reader + .clone() + .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))]))) + .unwrap(); + let batche1: Vec<_> = batch_stream.try_collect().await.unwrap(); + + let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build(); + let batch_stream = reader + .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))]))) + .unwrap(); + let batche2: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batche1, batche2); + } + + #[tokio::test] + async fn test_open_parquet_with_projection() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for current snapshot and plan files + let table_scan = fixture.table.scan().select(["x", "z"]).build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_columns(), 2); + + let col1 = batches[0].column_by_name("x").unwrap(); + let int64_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + + let col2 = batches[0].column_by_name("z").unwrap(); + let int64_arr = col2.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 3); + } + + #[tokio::test] + async fn test_filter_on_arrow_lt() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y < 3 + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").less_than(Datum::long(3)); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("x").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + + let col = batches[0].column_by_name("y").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 2); + } + + #[tokio::test] + async fn test_filter_on_arrow_gt_eq() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y >= 5 + let mut builder = fixture.table.scan(); + 
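+        // The fixture's `y` column is written as [2; 512] ++ [3; 200] ++ [4; 300] ++ [5; 12],
+        // so a `y >= 5` filter should leave only the final 12 rows.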
let predicate = Reference::new("y").greater_than_or_equal_to(Datum::long(5)); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 12); + + let col = batches[0].column_by_name("x").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + + let col = batches[0].column_by_name("y").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 5); + } + + #[tokio::test] + async fn test_filter_on_arrow_is_null() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y is null + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").is_null(); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 0); + } + + #[tokio::test] + async fn test_filter_on_arrow_is_not_null() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y is not null + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").is_not_null(); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + assert_eq!(batches[0].num_rows(), 1024); + } + + #[tokio::test] + async fn test_filter_on_arrow_lt_and_gt() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y < 5 AND z >= 4 + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y") + .less_than(Datum::long(5)) + .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4))); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + assert_eq!(batches[0].num_rows(), 500); + + let col = batches[0].column_by_name("x").unwrap(); + let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 500])) as ArrayRef; + assert_eq!(col, &expected_x); + + let col = batches[0].column_by_name("y").unwrap(); + let mut values = vec![]; + values.append(vec![3; 200].as_mut()); + values.append(vec![4; 300].as_mut()); + let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + assert_eq!(col, &expected_y); + + let col = batches[0].column_by_name("z").unwrap(); + let expected_z = Arc::new(Int64Array::from_iter_values(vec![4; 500])) as ArrayRef; + assert_eq!(col, &expected_z); + } + + #[tokio::test] + async fn test_filter_on_arrow_lt_or_gt() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y < 5 AND z >= 4 + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y") + .less_than(Datum::long(5)) + .or(Reference::new("z").greater_than_or_equal_to(Datum::long(4))); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + 
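+        // Given the fixture data, `y < 5` matches the first 1012 rows and `z >= 4`
+        // matches the last 512, so the OR of the two covers all 1024 rows.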
assert_eq!(batches[0].num_rows(), 1024); + + let col = batches[0].column_by_name("x").unwrap(); + let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; + assert_eq!(col, &expected_x); + + let col = batches[0].column_by_name("y").unwrap(); + let mut values = vec![2; 512]; + values.append(vec![3; 200].as_mut()); + values.append(vec![4; 300].as_mut()); + values.append(vec![5; 12].as_mut()); + let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + assert_eq!(col, &expected_y); + + let col = batches[0].column_by_name("z").unwrap(); + let mut values = vec![3; 512]; + values.append(vec![4; 512].as_mut()); + let expected_z = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + assert_eq!(col, &expected_z); + } + + #[tokio::test] + async fn test_filter_on_arrow_startswith() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a STARTSWITH "Ice" + let mut builder = fixture.table.scan(); + let predicate = Reference::new("a").starts_with(Datum::string("Ice")); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Iceberg"); + } + + #[tokio::test] + async fn test_filter_on_arrow_not_startswith() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a NOT STARTSWITH "Ice" + let mut builder = fixture.table.scan(); + let predicate = Reference::new("a").not_starts_with(Datum::string("Ice")); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Apache"); + } + + #[tokio::test] + async fn test_filter_on_arrow_in() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a IN ("Sioux", "Iceberg") + let mut builder = fixture.table.scan(); + let predicate = + Reference::new("a").is_in([Datum::string("Sioux"), Datum::string("Iceberg")]); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Iceberg"); + } + + #[tokio::test] + async fn test_filter_on_arrow_not_in() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a NOT IN ("Sioux", "Iceberg") + let mut builder = fixture.table.scan(); + let predicate = + Reference::new("a").is_not_in([Datum::string("Sioux"), Datum::string("Iceberg")]); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + 
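+        // The fixture's `a` column holds 512 "Apache" rows followed by 512 "Iceberg" rows;
+        // only the 512 "Apache" rows fall outside the ("Sioux", "Iceberg") set.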
assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Apache"); + } + + #[test] + fn test_file_scan_task_serialize_deserialize() { + let test_fn = |task: FileScanTask| { + let serialized = serde_json::to_string(&task).unwrap(); + let deserialized: FileScanTask = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(task.data_file_path, deserialized.data_file_path); + assert_eq!(task.start, deserialized.start); + assert_eq!(task.length, deserialized.length); + assert_eq!(task.project_field_ids, deserialized.project_field_ids); + assert_eq!(task.predicate, deserialized.predicate); + assert_eq!(task.schema, deserialized.schema); + }; + + // without predicate + let schema = Arc::new( + Schema::builder() + .with_fields(vec![Arc::new(NestedField::required( + 1, + "x", + Type::Primitive(PrimitiveType::Binary), + ))]) + .build() + .unwrap(), + ); + let task = FileScanTask { + data_file_path: "data_file_path".to_string(), + data_file_content: DataContentType::Data, + start: 0, + length: 100, + project_field_ids: vec![1, 2, 3], + predicate: None, + schema: schema.clone(), + record_count: Some(100), + data_file_format: DataFileFormat::Parquet, + }; + test_fn(task); + + // with predicate + let task = FileScanTask { + data_file_path: "data_file_path".to_string(), + data_file_content: DataContentType::Data, + start: 0, + length: 100, + project_field_ids: vec![1, 2, 3], + predicate: Some(BoundPredicate::AlwaysTrue), + schema, + record_count: None, + data_file_format: DataFileFormat::Avro, + }; + test_fn(task); + } +} diff --git a/crates/iceberg/src/spec/datatypes.rs b/crates/iceberg/src/spec/datatypes.rs index 172cb6417..d38245960 100644 --- a/crates/iceberg/src/spec/datatypes.rs +++ b/crates/iceberg/src/spec/datatypes.rs @@ -17,20 +17,23 @@ /*! * Data Types -*/ -use crate::ensure_data_valid; -use crate::error::Result; -use crate::spec::datatypes::_decimal::{MAX_PRECISION, REQUIRED_LENGTH}; + */ +use std::collections::HashMap; +use std::convert::identity; +use std::fmt; +use std::ops::Index; +use std::sync::{Arc, OnceLock}; + use ::serde::de::{MapAccess, Visitor}; use serde::de::{Error, IntoDeserializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value as JsonValue; -use std::convert::identity; -use std::sync::Arc; -use std::sync::OnceLock; -use std::{collections::HashMap, fmt, ops::Index}; use super::values::Literal; +use crate::ensure_data_valid; +use crate::error::Result; +use crate::spec::datatypes::_decimal::{MAX_PRECISION, REQUIRED_LENGTH}; +use crate::spec::PrimitiveLiteral; /// Field name for list type. pub(crate) const LIST_FILED_NAME: &str = "element"; @@ -41,38 +44,37 @@ pub(crate) const MAX_DECIMAL_BYTES: u32 = 24; pub(crate) const MAX_DECIMAL_PRECISION: u32 = 38; mod _decimal { - use lazy_static::lazy_static; + use once_cell::sync::Lazy; use crate::spec::{MAX_DECIMAL_BYTES, MAX_DECIMAL_PRECISION}; - lazy_static! 
{ - // Max precision of bytes, starts from 1 - pub(super) static ref MAX_PRECISION: [u32; MAX_DECIMAL_BYTES as usize] = { - let mut ret: [u32; 24] = [0; 24]; - for (i, prec) in ret.iter_mut().enumerate() { - *prec = 2f64.powi((8 * (i + 1) - 1) as i32).log10().floor() as u32; - } + // Max precision of bytes, starts from 1 + pub(super) static MAX_PRECISION: Lazy<[u32; MAX_DECIMAL_BYTES as usize]> = Lazy::new(|| { + let mut ret: [u32; 24] = [0; 24]; + for (i, prec) in ret.iter_mut().enumerate() { + *prec = 2f64.powi((8 * (i + 1) - 1) as i32).log10().floor() as u32; + } - ret - }; + ret + }); - // Required bytes of precision, starts from 1 - pub(super) static ref REQUIRED_LENGTH: [u32; MAX_DECIMAL_PRECISION as usize] = { - let mut ret: [u32; MAX_DECIMAL_PRECISION as usize] = [0; MAX_DECIMAL_PRECISION as usize]; + // Required bytes of precision, starts from 1 + pub(super) static REQUIRED_LENGTH: Lazy<[u32; MAX_DECIMAL_PRECISION as usize]> = + Lazy::new(|| { + let mut ret: [u32; MAX_DECIMAL_PRECISION as usize] = + [0; MAX_DECIMAL_PRECISION as usize]; for (i, required_len) in ret.iter_mut().enumerate() { for j in 0..MAX_PRECISION.len() { - if MAX_PRECISION[j] >= ((i+1) as u32) { - *required_len = (j+1) as u32; + if MAX_PRECISION[j] >= ((i + 1) as u32) { + *required_len = (j + 1) as u32; break; } } } ret - }; - - } + }); } #[derive(Debug, PartialEq, Eq, Clone)] @@ -112,6 +114,30 @@ impl Type { matches!(self, Type::Struct(_)) } + /// Whether the type is nested type. + #[inline(always)] + pub fn is_nested(&self) -> bool { + matches!(self, Type::Struct(_) | Type::List(_) | Type::Map(_)) + } + + /// Convert Type to reference of PrimitiveType + pub fn as_primitive_type(&self) -> Option<&PrimitiveType> { + if let Type::Primitive(primitive_type) = self { + Some(primitive_type) + } else { + None + } + } + + /// Convert Type to StructType + pub fn to_struct_type(self) -> Option { + if let Type::Struct(struct_type) = self { + Some(struct_type) + } else { + None + } + } + /// Return max precision for decimal given [`num_bytes`] bytes. #[inline(always)] pub fn decimal_max_precision(num_bytes: u32) -> Result { @@ -135,6 +161,15 @@ impl Type { ensure_data_valid!(precision > 0 && precision <= MAX_DECIMAL_PRECISION, "Decimals with precision larger than {MAX_DECIMAL_PRECISION} are not supported: {precision}",); Ok(Type::Primitive(PrimitiveType::Decimal { precision, scale })) } + + /// Check if it's float or double type. + #[inline(always)] + pub fn is_floating_type(&self) -> bool { + matches!( + self, + Type::Primitive(PrimitiveType::Float) | Type::Primitive(PrimitiveType::Double) + ) + } } impl From for Type { @@ -162,7 +197,7 @@ impl From for Type { } /// Primitive data types -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)] #[serde(rename_all = "lowercase", remote = "Self")] pub enum PrimitiveType { /// True or False @@ -171,28 +206,32 @@ pub enum PrimitiveType { Int, /// 64-bit signed integer Long, - /// 32-bit IEEE 754 floating bit. + /// 32-bit IEEE 754 floating point. Float, - /// 64-bit IEEE 754 floating bit. + /// 64-bit IEEE 754 floating point. Double, /// Fixed point decimal Decimal { - /// Precision + /// Precision, must be 38 or less precision: u32, /// Scale scale: u32, }, /// Calendar date without timezone or time. Date, - /// Time of day without date or timezone. + /// Time of day in microsecond precision, without date or timezone. 
Time, - /// Timestamp without timezone + /// Timestamp in microsecond precision, without timezone Timestamp, - /// Timestamp with timezone + /// Timestamp in microsecond precision, with timezone Timestamptz, + /// Timestamp in nanosecond precision, without timezone + TimestampNs, + /// Timestamp in nanosecond precision with timezone + TimestamptzNs, /// Arbitrary-length character sequences encoded in utf-8 String, - /// Universally Unique Identifiers + /// Universally Unique Identifiers, should use 16-byte fixed Uuid, /// Fixed length byte array Fixed(u64), @@ -200,11 +239,34 @@ pub enum PrimitiveType { Binary, } +impl PrimitiveType { + /// Check whether literal is compatible with the type. + pub fn compatible(&self, literal: &PrimitiveLiteral) -> bool { + matches!( + (self, literal), + (PrimitiveType::Boolean, PrimitiveLiteral::Boolean(_)) + | (PrimitiveType::Int, PrimitiveLiteral::Int(_)) + | (PrimitiveType::Long, PrimitiveLiteral::Long(_)) + | (PrimitiveType::Float, PrimitiveLiteral::Float(_)) + | (PrimitiveType::Double, PrimitiveLiteral::Double(_)) + | (PrimitiveType::Decimal { .. }, PrimitiveLiteral::Int128(_)) + | (PrimitiveType::Date, PrimitiveLiteral::Int(_)) + | (PrimitiveType::Time, PrimitiveLiteral::Long(_)) + | (PrimitiveType::Timestamp, PrimitiveLiteral::Long(_)) + | (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(_)) + | (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(_)) + | (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(_)) + | (PrimitiveType::String, PrimitiveLiteral::String(_)) + | (PrimitiveType::Uuid, PrimitiveLiteral::UInt128(_)) + | (PrimitiveType::Fixed(_), PrimitiveLiteral::Binary(_)) + | (PrimitiveType::Binary, PrimitiveLiteral::Binary(_)) + ) + } +} + impl Serialize for Type { fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { + where S: Serializer { let type_serde = _serde::SerdeType::from(self); type_serde.serialize(serializer) } @@ -212,9 +274,7 @@ impl Serialize for Type { impl<'de> Deserialize<'de> for Type { fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { + where D: Deserializer<'de> { let type_serde = _serde::SerdeType::deserialize(deserializer)?; Ok(Type::from(type_serde)) } @@ -222,9 +282,7 @@ impl<'de> Deserialize<'de> for Type { impl<'de> Deserialize<'de> for PrimitiveType { fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { + where D: Deserializer<'de> { let s = String::deserialize(deserializer)?; if s.starts_with("decimal") { deserialize_decimal(s.into_deserializer()) @@ -238,9 +296,7 @@ impl<'de> Deserialize<'de> for PrimitiveType { impl Serialize for PrimitiveType { fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { + where S: Serializer { match self { PrimitiveType::Decimal { precision, scale } => { serialize_decimal(precision, scale, serializer) @@ -252,9 +308,7 @@ impl Serialize for PrimitiveType { } fn deserialize_decimal<'de, D>(deserializer: D) -> std::result::Result -where - D: Deserializer<'de>, -{ +where D: Deserializer<'de> { let s = String::deserialize(deserializer)?; let (precision, scale) = s .trim_start_matches(r"decimal(") @@ -280,9 +334,7 @@ where } fn deserialize_fixed<'de, D>(deserializer: D) -> std::result::Result -where - D: Deserializer<'de>, -{ +where D: Deserializer<'de> { let fixed = String::deserialize(deserializer)? 
.trim_start_matches(r"fixed[") .trim_end_matches(']') @@ -295,9 +347,7 @@ where } fn serialize_fixed(value: &u64, serializer: S) -> std::result::Result -where - S: Serializer, -{ +where S: Serializer { serializer.serialize_str(&format!("fixed[{value}]")) } @@ -316,6 +366,8 @@ impl fmt::Display for PrimitiveType { PrimitiveType::Time => write!(f, "time"), PrimitiveType::Timestamp => write!(f, "timestamp"), PrimitiveType::Timestamptz => write!(f, "timestamptz"), + PrimitiveType::TimestampNs => write!(f, "timestamp_ns"), + PrimitiveType::TimestamptzNs => write!(f, "timestamptz_ns"), PrimitiveType::String => write!(f, "string"), PrimitiveType::Uuid => write!(f, "uuid"), PrimitiveType::Fixed(size) => write!(f, "fixed({})", size), @@ -325,7 +377,7 @@ impl fmt::Display for PrimitiveType { } /// DataType for a specific struct -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, Default)] #[serde(rename = "struct", tag = "type")] pub struct StructType { /// Struct fields @@ -339,9 +391,7 @@ pub struct StructType { impl<'de> Deserialize<'de> for StructType { fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { + where D: Deserializer<'de> { #[derive(Deserialize)] #[serde(field_identifier, rename_all = "lowercase")] enum Field { @@ -359,9 +409,7 @@ impl<'de> Deserialize<'de> for StructType { } fn visit_map(self, mut map: V) -> std::result::Result - where - V: MapAccess<'de>, - { + where V: MapAccess<'de> { let mut fields = None; while let Some(key) = map.next_key()? { match key { @@ -541,6 +589,19 @@ impl From for SerdeNestedField { pub type NestedFieldRef = Arc; impl NestedField { + /// Construct a new field. + pub fn new(id: i32, name: impl ToString, field_type: Type, required: bool) -> Self { + Self { + id, + name: name.to_string(), + required, + field_type: Box::new(field_type), + doc: None, + initial_default: None, + write_default: None, + } + } + /// Construct a required field. pub fn required(id: i32, name: impl ToString, field_type: Type) -> Self { Self { @@ -634,14 +695,23 @@ pub struct ListType { pub element_field: NestedFieldRef, } +impl ListType { + /// Construct a list type with the given element field. + pub fn new(element_field: NestedFieldRef) -> Self { + Self { element_field } + } +} + /// Module for type serialization/deserialization. pub(super) mod _serde { + use std::borrow::Cow; + + use serde_derive::{Deserialize, Serialize}; + use crate::spec::datatypes::Type::Map; use crate::spec::datatypes::{ ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, StructType, Type, }; - use serde_derive::{Deserialize, Serialize}; - use std::borrow::Cow; /// List type for serialization and deserialization #[derive(Serialize, Deserialize)] @@ -749,14 +819,23 @@ pub struct MapType { pub value_field: NestedFieldRef, } +impl MapType { + /// Construct a map type with the given key and value fields. 
+ pub fn new(key_field: NestedFieldRef, value_field: NestedFieldRef) -> Self { + Self { + key_field, + value_field, + } + } +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; use uuid::Uuid; - use crate::spec::values::PrimitiveLiteral; - use super::*; + use crate::spec::values::PrimitiveLiteral; fn check_type_serde(json: &str, expected_type: Type) { let desered_type: Type = serde_json::from_str(json).unwrap(); @@ -862,11 +941,15 @@ mod tests { Type::Struct(StructType { fields: vec![ NestedField::required(1, "id", Type::Primitive(PrimitiveType::Uuid)) - .with_initial_default(Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::parse_str("0db3e2a8-9d1d-42b9-aa7b-74ebe558dceb").unwrap(), + .with_initial_default(Literal::Primitive(PrimitiveLiteral::UInt128( + Uuid::parse_str("0db3e2a8-9d1d-42b9-aa7b-74ebe558dceb") + .unwrap() + .as_u128(), ))) - .with_write_default(Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::parse_str("ec5911be-b0a7-458c-8438-c9a3e53cffae").unwrap(), + .with_write_default(Literal::Primitive(PrimitiveLiteral::UInt128( + Uuid::parse_str("ec5911be-b0a7-458c-8438-c9a3e53cffae") + .unwrap() + .as_u128(), ))) .into(), NestedField::optional(2, "data", Type::Primitive(PrimitiveType::Int)).into(), @@ -931,11 +1014,15 @@ mod tests { let struct_type = Type::Struct(StructType::new(vec![ NestedField::required(1, "id", Type::Primitive(PrimitiveType::Uuid)) - .with_initial_default(Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::parse_str("0db3e2a8-9d1d-42b9-aa7b-74ebe558dceb").unwrap(), + .with_initial_default(Literal::Primitive(PrimitiveLiteral::UInt128( + Uuid::parse_str("0db3e2a8-9d1d-42b9-aa7b-74ebe558dceb") + .unwrap() + .as_u128(), ))) - .with_write_default(Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::parse_str("ec5911be-b0a7-458c-8438-c9a3e53cffae").unwrap(), + .with_write_default(Literal::Primitive(PrimitiveLiteral::UInt128( + Uuid::parse_str("ec5911be-b0a7-458c-8438-c9a3e53cffae") + .unwrap() + .as_u128(), ))) .into(), NestedField::optional(2, "data", Type::Primitive(PrimitiveType::Int)).into(), @@ -1053,4 +1140,37 @@ mod tests { assert_eq!(5, Type::decimal_required_bytes(10).unwrap()); assert_eq!(16, Type::decimal_required_bytes(38).unwrap()); } + + #[test] + fn test_primitive_type_compatitable() { + let pairs = vec![ + (PrimitiveType::Boolean, PrimitiveLiteral::Boolean(true)), + (PrimitiveType::Int, PrimitiveLiteral::Int(1)), + (PrimitiveType::Long, PrimitiveLiteral::Long(1)), + (PrimitiveType::Float, PrimitiveLiteral::Float(1.0.into())), + (PrimitiveType::Double, PrimitiveLiteral::Double(1.0.into())), + ( + PrimitiveType::Decimal { + precision: 9, + scale: 2, + }, + PrimitiveLiteral::Int128(1), + ), + (PrimitiveType::Date, PrimitiveLiteral::Int(1)), + (PrimitiveType::Time, PrimitiveLiteral::Long(1)), + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(1)), + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(1)), + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(1)), + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(1)), + ( + PrimitiveType::Uuid, + PrimitiveLiteral::UInt128(Uuid::new_v4().as_u128()), + ), + (PrimitiveType::Fixed(8), PrimitiveLiteral::Binary(vec![1])), + (PrimitiveType::Binary, PrimitiveLiteral::Binary(vec![1])), + ]; + for (ty, literal) in pairs { + assert!(ty.compatible(&literal)); + } + } } diff --git a/crates/iceberg/src/spec/manifest.rs b/crates/iceberg/src/spec/manifest.rs index 563837aea..f0dfdf47c 100644 --- a/crates/iceberg/src/spec/manifest.rs +++ b/crates/iceberg/src/spec/manifest.rs @@ -16,33 +16,38 @@ // under 
the License. //! Manifest for Iceberg. -use self::_const_schema::{manifest_schema_v1, manifest_schema_v2}; +use std::cmp::min; +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; +use apache_avro::{from_value, to_value, Reader as AvroReader, Writer as AvroWriter}; +use bytes::Bytes; +use serde_derive::{Deserialize, Serialize}; +use serde_json::to_vec; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use typed_builder::TypedBuilder; + +use self::_const_schema::{manifest_schema_v1, manifest_schema_v2}; use super::{ - FieldSummary, FormatVersion, ManifestContentType, ManifestListEntry, PartitionSpec, Schema, - Struct, + Datum, FieldSummary, FormatVersion, ManifestContentType, ManifestFile, PartitionSpec, Schema, + SchemaId, Struct, INITIAL_SEQUENCE_NUMBER, UNASSIGNED_SEQUENCE_NUMBER, }; -use super::{Literal, UNASSIGNED_SEQUENCE_NUMBER}; +use crate::error::Result; use crate::io::OutputFile; use crate::spec::PartitionField; use crate::{Error, ErrorKind}; -use apache_avro::{from_value, to_value, Reader as AvroReader, Writer as AvroWriter}; -use futures::AsyncWriteExt; -use serde_json::to_vec; -use std::cmp::min; -use std::collections::HashMap; -use std::str::FromStr; /// A manifest contains metadata and a list of entries. #[derive(Debug, PartialEq, Eq, Clone)] pub struct Manifest { metadata: ManifestMetadata, - entries: Vec, + entries: Vec, } impl Manifest { - /// Parse manifest from bytes of avro file. - pub fn parse_avro(bs: &[u8]) -> Result { + /// Parse manifest metadata and entries from bytes of avro file. + pub(crate) fn try_from_avro_bytes(bs: &[u8]) -> Result<(ManifestMetadata, Vec)> { let reader = AvroReader::new(bs)?; // Parse manifest metadata @@ -62,7 +67,7 @@ impl Manifest { from_value::<_serde::ManifestEntryV1>(&value?)? .try_into(&partition_type, &metadata.schema) }) - .collect::, Error>>()? + .collect::>>()? } FormatVersion::V2 => { let schema = manifest_schema_v2(partition_type.clone())?; @@ -73,11 +78,36 @@ impl Manifest { from_value::<_serde::ManifestEntryV2>(&value?)? .try_into(&partition_type, &metadata.schema) }) - .collect::, Error>>()? + .collect::>>()? } }; - Ok(Manifest { metadata, entries }) + Ok((metadata, entries)) + } + + /// Parse manifest from bytes of avro file. + pub fn parse_avro(bs: &[u8]) -> Result { + let (metadata, entries) = Self::try_from_avro_bytes(bs)?; + Ok(Self::new(metadata, entries)) + } + + /// Entries slice. + pub fn entries(&self) -> &[ManifestEntryRef] { + &self.entries + } + + /// Consume this Manifest, returning its constituent parts + pub fn into_parts(self) -> (Vec, ManifestMetadata) { + let Self { entries, metadata } = self; + (entries, metadata) + } + + /// Constructor from [`ManifestMetadata`] and [`ManifestEntry`]s. + pub fn new(metadata: ManifestMetadata, entries: Vec) -> Self { + Self { + metadata, + entries: entries.into_iter().map(Arc::new).collect(), + } } } @@ -167,14 +197,14 @@ impl ManifestWriter { let entry = self .field_summary .remove(&field.source_id) - .unwrap_or(FieldSummary::default()); + .unwrap_or_default(); partition_summary.push(entry); } partition_summary } - /// Write a manifest entry. - pub async fn write(mut self, manifest: Manifest) -> Result { + /// Write a manifest. 
+ pub async fn write(mut self, manifest: Manifest) -> Result { // Create the avro writer let partition_type = manifest .metadata @@ -199,14 +229,14 @@ impl ManifestWriter { )?; avro_writer.add_user_metadata( "partition-spec".to_string(), - to_vec(&manifest.metadata.partition_spec.fields).map_err(|err| { + to_vec(&manifest.metadata.partition_spec.fields()).map_err(|err| { Error::new(ErrorKind::DataInvalid, "Fail to serialize partition spec") .with_source(err) })?, )?; avro_writer.add_user_metadata( "partition-spec-id".to_string(), - manifest.metadata.partition_spec.spec_id.to_string(), + manifest.metadata.partition_spec.spec_id().to_string(), )?; avro_writer.add_user_metadata( "format-version".to_string(), @@ -252,45 +282,41 @@ impl ManifestWriter { self.update_field_summary(&entry); let value = match manifest.metadata.format_version { - FormatVersion::V1 => { - to_value(_serde::ManifestEntryV1::try_from(entry, &partition_type)?)? - .resolve(&avro_schema)? - } - FormatVersion::V2 => { - to_value(_serde::ManifestEntryV2::try_from(entry, &partition_type)?)? - .resolve(&avro_schema)? - } + FormatVersion::V1 => to_value(_serde::ManifestEntryV1::try_from( + (*entry).clone(), + &partition_type, + )?)? + .resolve(&avro_schema)?, + FormatVersion::V2 => to_value(_serde::ManifestEntryV2::try_from( + (*entry).clone(), + &partition_type, + )?)? + .resolve(&avro_schema)?, }; avro_writer.append(value)?; } - let length = avro_writer.flush()?; let content = avro_writer.into_inner()?; - let mut writer = self.output.writer().await?; - writer.write_all(&content).await.map_err(|err| { - Error::new(ErrorKind::Unexpected, "Fail to write Manifest Entry").with_source(err) - })?; - writer.close().await.map_err(|err| { - Error::new(ErrorKind::Unexpected, "Fail to write Manifest Entry").with_source(err) - })?; + let length = content.len(); + self.output.write(Bytes::from(content)).await?; let partition_summary = - self.get_field_summary_vec(&manifest.metadata.partition_spec.fields); + self.get_field_summary_vec(manifest.metadata.partition_spec.fields()); - Ok(ManifestListEntry { + Ok(ManifestFile { manifest_path: self.output.location().to_string(), manifest_length: length as i64, - partition_spec_id: manifest.metadata.partition_spec.spec_id, + partition_spec_id: manifest.metadata.partition_spec.spec_id(), content: manifest.metadata.content, // sequence_number and min_sequence_number with UNASSIGNED_SEQUENCE_NUMBER will be replace with // real sequence number in `ManifestListWriter`. 
sequence_number: UNASSIGNED_SEQUENCE_NUMBER, min_sequence_number: self.min_seq_num.unwrap_or(UNASSIGNED_SEQUENCE_NUMBER), added_snapshot_id: self.snapshot_id, - added_data_files_count: Some(self.added_files), - existing_data_files_count: Some(self.existing_files), - deleted_data_files_count: Some(self.deleted_files), + added_files_count: Some(self.added_files), + existing_files_count: Some(self.existing_files), + deleted_files_count: Some(self.deleted_files), added_rows_count: Some(self.added_rows), existing_rows_count: Some(self.existing_rows), deleted_rows_count: Some(self.deleted_rows), @@ -307,13 +333,11 @@ mod _const_schema { use apache_avro::Schema as AvroSchema; use once_cell::sync::Lazy; - use crate::{ - avro::schema_to_avro_schema, - spec::{ - ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, - }, - Error, + use crate::avro::schema_to_avro_schema; + use crate::spec::{ + ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, }; + use crate::Error; static STATUS: Lazy = { Lazy::new(|| { @@ -637,8 +661,8 @@ mod _const_schema { ])), )), ]; - let schema = Schema::builder().with_fields(fields).build().unwrap(); - schema_to_avro_schema("manifest", &schema) + let schema = Schema::builder().with_fields(fields).build()?; + schema_to_avro_schema("manifest_entry", &schema) } pub(super) fn manifest_schema_v1(partition_type: StructType) -> Result { @@ -671,19 +695,19 @@ mod _const_schema { ])), )), ]; - let schema = Schema::builder().with_fields(fields).build().unwrap(); - schema_to_avro_schema("manifest", &schema) + let schema = Schema::builder().with_fields(fields).build()?; + schema_to_avro_schema("manifest_entry", &schema) } } /// Meta data of a manifest that is stored in the key-value metadata of the Avro file -#[derive(Debug, PartialEq, Clone, Eq)] +#[derive(Debug, PartialEq, Clone, Eq, TypedBuilder)] pub struct ManifestMetadata { /// The table schema at the time the manifest /// was written schema: Schema, /// ID of the schema used to write the manifest as a string - schema_id: i32, + schema_id: SchemaId, /// The partition spec used to write the manifest partition_spec: PartitionSpec, /// Table format version number of the manifest as a string @@ -694,7 +718,7 @@ pub struct ManifestMetadata { impl ManifestMetadata { /// Parse from metadata in avro file. - pub fn parse(meta: &HashMap>) -> Result { + pub fn parse(meta: &HashMap>) -> Result { let schema = { let bs = meta.get("schema").ok_or_else(|| { Error::new( @@ -781,10 +805,13 @@ impl ManifestMetadata { } } +/// Reference to [`ManifestEntry`]. +pub type ManifestEntryRef = Arc; + /// A manifest is an immutable Avro file that lists data files or delete /// files, along with each file’s partition data tuple, metrics, and tracking /// information. -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, TypedBuilder)] pub struct ManifestEntry { /// field: 0 /// @@ -794,16 +821,19 @@ pub struct ManifestEntry { /// /// Snapshot id where the file was added, or deleted if status is 2. /// Inherited when null. + #[builder(default, setter(strip_option))] snapshot_id: Option, /// field id: 3 /// /// Data sequence number of the file. /// Inherited when null and status is 1 (added). + #[builder(default, setter(strip_option))] sequence_number: Option, /// field id: 4 /// /// File sequence number indicating when the file was added. /// Inherited when null and status is 1 (added). 
+ #[builder(default, setter(strip_option))] file_sequence_number: Option, /// field id: 2 /// @@ -819,6 +849,69 @@ impl ManifestEntry { ManifestStatus::Added | ManifestStatus::Existing ) } + + /// Content type of this manifest entry. + #[inline] + pub fn content_type(&self) -> DataContentType { + self.data_file.content + } + + /// File format of this manifest entry. + #[inline] + pub fn file_format(&self) -> DataFileFormat { + self.data_file.file_format + } + + /// Data file path of this manifest entry. + #[inline] + pub fn file_path(&self) -> &str { + &self.data_file.file_path + } + + /// Data file record count of the manifest entry. + #[inline] + pub fn record_count(&self) -> u64 { + self.data_file.record_count + } + + /// Inherit data from manifest list, such as snapshot id, sequence number. + pub(crate) fn inherit_data(&mut self, snapshot_entry: &ManifestFile) { + if self.snapshot_id.is_none() { + self.snapshot_id = Some(snapshot_entry.added_snapshot_id); + } + + if self.sequence_number.is_none() + && (self.status == ManifestStatus::Added + || snapshot_entry.sequence_number == INITIAL_SEQUENCE_NUMBER) + { + self.sequence_number = Some(snapshot_entry.sequence_number); + } + + if self.file_sequence_number.is_none() + && (self.status == ManifestStatus::Added + || snapshot_entry.sequence_number == INITIAL_SEQUENCE_NUMBER) + { + self.file_sequence_number = Some(snapshot_entry.sequence_number); + } + } + + /// Data sequence number. + #[inline] + pub fn sequence_number(&self) -> Option { + self.sequence_number + } + + /// File size in bytes. + #[inline] + pub fn file_size_in_bytes(&self) -> u64 { + self.data_file.file_size_in_bytes + } + + /// get a reference to the actual data file + #[inline] + pub fn data_file(&self) -> &DataFile { + &self.data_file + } } /// Used to track additions and deletions in ManifestEntry. @@ -837,7 +930,7 @@ pub enum ManifestStatus { impl TryFrom for ManifestStatus { type Error = Error; - fn try_from(v: i32) -> Result { + fn try_from(v: i32) -> Result { match v { 0 => Ok(ManifestStatus::Existing), 1 => Ok(ManifestStatus::Added), @@ -851,34 +944,34 @@ impl TryFrom for ManifestStatus { } /// Data file carries data file path, partition tuple, metrics, … -#[derive(Debug, PartialEq, Clone, Eq)] +#[derive(Debug, PartialEq, Clone, Eq, Builder)] pub struct DataFile { /// field id: 134 /// /// Type of content stored by the data file: data, equality deletes, /// or position deletes (all v1 files are data files) - content: DataContentType, + pub(crate) content: DataContentType, /// field id: 100 /// /// Full URI for the file with FS scheme - file_path: String, + pub(crate) file_path: String, /// field id: 101 /// /// String file format name, avro, orc or parquet - file_format: DataFileFormat, + pub(crate) file_format: DataFileFormat, /// field id: 102 /// /// Partition data tuple, schema based on the partition spec output using /// partition field ids for the struct field ids - partition: Struct, + pub(crate) partition: Struct, /// field id: 103 /// /// Number of records in this file - record_count: u64, + pub(crate) record_count: u64, /// field id: 104 /// /// Total file size in bytes - file_size_in_bytes: u64, + pub(crate) file_size_in_bytes: u64, /// field id: 108 /// key field id: 117 /// value field id: 118 @@ -886,26 +979,30 @@ pub struct DataFile { /// Map from column id to the total size on disk of all regions that /// store the column. Does not include bytes necessary to read other /// columns, like footers. 
Leave null for row-oriented formats (Avro) - column_sizes: HashMap, + #[builder(default)] + pub(crate) column_sizes: HashMap, /// field id: 109 /// key field id: 119 /// value field id: 120 /// /// Map from column id to number of values in the column (including null /// and NaN values) - value_counts: HashMap, + #[builder(default)] + pub(crate) value_counts: HashMap, /// field id: 110 /// key field id: 121 /// value field id: 122 /// /// Map from column id to number of null values in the column - null_value_counts: HashMap, + #[builder(default)] + pub(crate) null_value_counts: HashMap, /// field id: 137 /// key field id: 138 /// value field id: 139 /// /// Map from column id to number of NaN values in the column - nan_value_counts: HashMap, + #[builder(default)] + pub(crate) nan_value_counts: HashMap, /// field id: 125 /// key field id: 126 /// value field id: 127 @@ -917,7 +1014,8 @@ pub struct DataFile { /// Reference: /// /// - [Binary single-value serialization](https://iceberg.apache.org/spec/#binary-single-value-serialization) - lower_bounds: HashMap, + #[builder(default)] + pub(crate) lower_bounds: HashMap, /// field id: 128 /// key field id: 129 /// value field id: 130 @@ -929,17 +1027,20 @@ pub struct DataFile { /// Reference: /// /// - [Binary single-value serialization](https://iceberg.apache.org/spec/#binary-single-value-serialization) - upper_bounds: HashMap, + #[builder(default)] + pub(crate) upper_bounds: HashMap, /// field id: 131 /// /// Implementation-specific key metadata for encryption - key_metadata: Vec, + #[builder(default)] + pub(crate) key_metadata: Vec, /// field id: 132 /// element field id: 133 /// /// Split offsets for the data file. For example, all row group offsets /// in a Parquet file. Must be sorted ascending - split_offsets: Vec, + #[builder(default)] + pub(crate) split_offsets: Vec, /// field id: 135 /// element field id: 136 /// @@ -947,7 +1048,8 @@ pub struct DataFile { /// Required when content is EqualityDeletes and should be null /// otherwise. Fields with ids listed in this column must be present /// in the delete file - equality_ids: Vec, + #[builder(default)] + pub(crate) equality_ids: Vec, /// field id: 140 /// /// ID representing sort order for this file. @@ -958,12 +1060,96 @@ pub struct DataFile { /// sorted by file and position, not a table order, and should set sort /// order id to null. Readers must ignore sort order id for position /// delete files. - sort_order_id: Option, + #[builder(default, setter(strip_option))] + pub(crate) sort_order_id: Option, } +impl DataFile { + /// Get the content type of the data file (data, equality deletes, or position deletes) + pub fn content_type(&self) -> DataContentType { + self.content + } + /// Get the file path as full URI with FS scheme + pub fn file_path(&self) -> &str { + &self.file_path + } + /// Get the file format of the file (avro, orc or parquet). + pub fn file_format(&self) -> DataFileFormat { + self.file_format + } + /// Get the partition values of the file. + pub fn partition(&self) -> &Struct { + &self.partition + } + /// Get the record count in the data file. + pub fn record_count(&self) -> u64 { + self.record_count + } + /// Get the file size in bytes. + pub fn file_size_in_bytes(&self) -> u64 { + self.file_size_in_bytes + } + /// Get the column sizes. + /// Map from column id to the total size on disk of all regions that + /// store the column. Does not include bytes necessary to read other + /// columns, like footers. 
Null for row-oriented formats (Avro) + pub fn column_sizes(&self) -> &HashMap { + &self.column_sizes + } + /// Get the columns value counts for the data file. + /// Map from column id to number of values in the column (including null + /// and NaN values) + pub fn value_counts(&self) -> &HashMap { + &self.value_counts + } + /// Get the null value counts of the data file. + /// Map from column id to number of null values in the column + pub fn null_value_counts(&self) -> &HashMap { + &self.null_value_counts + } + /// Get the nan value counts of the data file. + /// Map from column id to number of NaN values in the column + pub fn nan_value_counts(&self) -> &HashMap { + &self.nan_value_counts + } + /// Get the lower bounds of the data file values per column. + /// Map from column id to lower bound in the column serialized as binary. + pub fn lower_bounds(&self) -> &HashMap { + &self.lower_bounds + } + /// Get the upper bounds of the data file values per column. + /// Map from column id to upper bound in the column serialized as binary. + pub fn upper_bounds(&self) -> &HashMap { + &self.upper_bounds + } + /// Get the Implementation-specific key metadata for the data file. + pub fn key_metadata(&self) -> &[u8] { + &self.key_metadata + } + /// Get the split offsets of the data file. + /// For example, all row group offsets in a Parquet file. + pub fn split_offsets(&self) -> &[i64] { + &self.split_offsets + } + /// Get the equality ids of the data file. + /// Field ids used to determine row equality in equality delete files. + /// null when content is not EqualityDeletes. + pub fn equality_ids(&self) -> &[i32] { + &self.equality_ids + } + /// Get the sort order id of the data file. + /// Only data files and equality delete files should be + /// written with a non-null order id. Position deletes are required to be + /// sorted by file and position, not a table order, and should set sort + /// order id to null. Readers must ignore sort order id for position + /// delete files. + pub fn sort_order_id(&self) -> Option { + self.sort_order_id + } +} /// Type of content stored by the data file: data, equality deletes, or /// position deletes (all v1 files are data files) -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum DataContentType { /// value: 0 Data = 0, @@ -976,7 +1162,7 @@ pub enum DataContentType { impl TryFrom for DataContentType { type Error = Error; - fn try_from(v: i32) -> Result { + fn try_from(v: i32) -> Result { match v { 0 => Ok(DataContentType::Data), 1 => Ok(DataContentType::PositionDeletes), @@ -990,7 +1176,7 @@ impl TryFrom for DataContentType { } /// Format of this data. 
-#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, SerializeDisplay, DeserializeFromStr)] pub enum DataFileFormat { /// Avro file format: Avro, @@ -1003,7 +1189,7 @@ pub enum DataFileFormat { impl FromStr for DataFileFormat { type Err = Error; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "avro" => Ok(Self::Avro), "orc" => Ok(Self::Orc), @@ -1016,34 +1202,25 @@ impl FromStr for DataFileFormat { } } -impl ToString for DataFileFormat { - fn to_string(&self) -> String { +impl std::fmt::Display for DataFileFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - DataFileFormat::Avro => "avro", - DataFileFormat::Orc => "orc", - DataFileFormat::Parquet => "parquet", + DataFileFormat::Avro => write!(f, "avro"), + DataFileFormat::Orc => write!(f, "orc"), + DataFileFormat::Parquet => write!(f, "parquet"), } - .to_string() } } mod _serde { use std::collections::HashMap; - use serde_bytes::ByteBuf; use serde_derive::{Deserialize, Serialize}; use serde_with::serde_as; - use crate::spec::Literal; - use crate::spec::RawLiteral; - use crate::spec::Schema; - use crate::spec::Struct; - use crate::spec::StructType; - use crate::spec::Type; - use crate::Error; - use crate::ErrorKind; - use super::ManifestEntry; + use crate::spec::{Datum, Literal, RawLiteral, Schema, Struct, StructType, Type}; + use crate::{Error, ErrorKind}; #[derive(Serialize, Deserialize)] pub(super) struct ManifestEntryV2 { @@ -1242,28 +1419,32 @@ mod _serde { fn parse_bytes_entry( v: Vec, schema: &Schema, - ) -> Result, Error> { + ) -> Result, Error> { let mut m = HashMap::with_capacity(v.len()); for entry in v { - let data_type = &schema - .field_by_id(entry.key) - .ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!("Can't find field id {} for upper/lower_bounds", entry.key), - ) - })? - .field_type; - m.insert(entry.key, Literal::try_from_bytes(&entry.value, data_type)?); + // We ignore the entry if the field is not found in the schema, due to schema evolution. + if let Some(field) = schema.field_by_id(entry.key) { + let data_type = field + .field_type + .as_primitive_type() + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("field {} is not a primitive type", field.name), + ) + })? + .clone(); + m.insert(entry.key, Datum::try_from_bytes(&entry.value, data_type)?); + } } Ok(m) } - fn to_bytes_entry(v: HashMap) -> Vec { + fn to_bytes_entry(v: impl IntoIterator) -> Vec { v.into_iter() .map(|e| BytesEntry { key: e.0, - value: Into::::into(e.1), + value: e.1.to_bytes(), }) .collect() } @@ -1278,7 +1459,11 @@ mod _serde { fn parse_i64_entry(v: Vec) -> Result, Error> { let mut m = HashMap::with_capacity(v.len()); for entry in v { - m.insert(entry.key, entry.value.try_into()?); + // We ignore the entry if it's value is negative since these entries are supposed to be used for + // counting, which should never be negative. 
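Since `DataFileFormat` now derives `SerializeDisplay`/`DeserializeFromStr` and implements `Display` instead of `ToString`, its textual form and `FromStr` should round-trip. A minimal sketch of that round-trip, assuming the `iceberg::spec` import path and the crate-level `Result` alias:

```rust
use std::str::FromStr;

use iceberg::spec::DataFileFormat;

fn main() -> iceberg::Result<()> {
    // `from_str` lower-cases its input, so any casing is accepted.
    let fmt = DataFileFormat::from_str("PARQUET")?;
    // `Display` always renders the lower-case spec name.
    assert_eq!(fmt.to_string(), "parquet");
    Ok(())
}
```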
+ if let Ok(v) = entry.value.try_into() { + m.insert(entry.key, v); + } } Ok(m) } @@ -1294,22 +1479,38 @@ mod _serde { }) .collect() } + + #[cfg(test)] + mod tests { + use std::collections::HashMap; + + use crate::spec::manifest::_serde::{parse_i64_entry, I64Entry}; + + #[test] + fn test_parse_negative_manifest_entry() { + let entries = vec![I64Entry { key: 1, value: -1 }, I64Entry { + key: 2, + value: 3, + }]; + + let ret = parse_i64_entry(entries).unwrap(); + + let expected_ret = HashMap::from([(2, 3)]); + assert_eq!(ret, expected_ret, "Negative i64 entry should be ignored!"); + } + } } #[cfg(test)] mod tests { use std::fs; + use std::sync::Arc; use tempfile::TempDir; use super::*; use crate::io::FileIOBuilder; - use crate::spec::NestedField; - use crate::spec::PrimitiveType; - use crate::spec::Struct; - use crate::spec::Transform; - use crate::spec::Type; - use std::sync::Arc; + use crate::spec::{Literal, NestedField, PrimitiveType, Struct, Transform, Type}; #[tokio::test] async fn test_parse_manifest_v2_unpartition() { @@ -1377,6 +1578,11 @@ mod tests { "v_ts_ntz", Type::Primitive(PrimitiveType::Timestamp), )), + Arc::new(NestedField::optional( + 12, + "v_ts_ns_ntz", + Type::Primitive(PrimitiveType::TimestampNs + ))), ]) .build() .unwrap(), @@ -1388,7 +1594,7 @@ mod tests { format_version: FormatVersion::V2, }, entries: vec![ - ManifestEntry { + Arc::new(ManifestEntry { status: ManifestStatus::Added, snapshot_id: None, sequence_number: None, @@ -1411,7 +1617,7 @@ mod tests { equality_ids: Vec::new(), sort_order_id: None, } - } + }) ] }; @@ -1485,6 +1691,11 @@ mod tests { "v_ts_ntz", Type::Primitive(PrimitiveType::Timestamp), )), + Arc::new(NestedField::optional( + 12, + "v_ts_ns_ntz", + Type::Primitive(PrimitiveType::TimestampNs + ))) ]) .build() .unwrap(), @@ -1508,7 +1719,7 @@ mod tests { content: ManifestContentType::Data, format_version: FormatVersion::V2, }, - entries: vec![ManifestEntry { + entries: vec![Arc::new(ManifestEntry { status: ManifestStatus::Added, snapshot_id: None, sequence_number: None, @@ -1519,8 +1730,8 @@ mod tests { file_path: "s3a://icebergdata/demo/s1/t1/data/00000-0-378b56f5-5c52-4102-a2c2-f05f8a7cbe4a-00000.parquet".to_string(), partition: Struct::from_iter( vec![ - (1000, Some(Literal::int(1)), "v_int".to_string()), - (1001, Some(Literal::long(1000)), "v_long".to_string()) + Some(Literal::int(1)), + Some(Literal::long(1000)), ] .into_iter() ), @@ -1573,7 +1784,7 @@ mod tests { equality_ids: vec![], sort_order_id: None, }, - }], + })], }; let writer = |output_file: OutputFile| ManifestWriter::new(output_file, 1, vec![]); @@ -1617,7 +1828,7 @@ mod tests { content: ManifestContentType::Data, format_version: FormatVersion::V1, }, - entries: vec![ManifestEntry { + entries: vec![Arc::new(ManifestEntry { status: ManifestStatus::Added, snapshot_id: Some(0), sequence_number: Some(0), @@ -1633,14 +1844,14 @@ mod tests { value_counts: HashMap::from([(1,1),(2,1),(3,1)]), null_value_counts: HashMap::from([(1,0),(2,0),(3,0)]), nan_value_counts: HashMap::new(), - lower_bounds: HashMap::from([(1,Literal::int(1)),(2,Literal::string("a")),(3,Literal::string("AC/DC"))]), - upper_bounds: HashMap::from([(1,Literal::int(1)),(2,Literal::string("a")),(3,Literal::string("AC/DC"))]), + lower_bounds: HashMap::from([(1,Datum::int(1)),(2,Datum::string("a")),(3,Datum::string("AC/DC"))]), + upper_bounds: HashMap::from([(1,Datum::int(1)),(2,Datum::string("a")),(3,Datum::string("AC/DC"))]), key_metadata: vec![], split_offsets: vec![4], equality_ids: vec![], sort_order_id: Some(0), } - }], 
+ })], }; let writer = @@ -1687,7 +1898,7 @@ mod tests { format_version: FormatVersion::V1, }, entries: vec![ - ManifestEntry { + Arc::new(ManifestEntry { status: ManifestStatus::Added, snapshot_id: Some(0), sequence_number: Some(0), @@ -1697,14 +1908,11 @@ mod tests { file_path: "s3://testbucket/prod/db/sample/data/category=x/00010-1-d5c93668-1e52-41ac-92a6-bba590cbf249-00001.parquet".to_string(), file_format: DataFileFormat::Parquet, partition: Struct::from_iter( - vec![( - 1000, + vec![ Some( - Literal::try_from_bytes(&[120], &Type::Primitive(PrimitiveType::String)) - .unwrap() + Literal::string("x"), ), - "category".to_string() - )] + ] .into_iter() ), record_count: 1, @@ -1714,21 +1922,21 @@ mod tests { null_value_counts: HashMap::from([(1, 0), (2, 0), (3, 0)]), nan_value_counts: HashMap::new(), lower_bounds: HashMap::from([ - (1, Literal::long(1)), - (2, Literal::string("a")), - (3, Literal::string("x")) + (1, Datum::long(1)), + (2, Datum::string("a")), + (3, Datum::string("x")) ]), upper_bounds: HashMap::from([ - (1, Literal::long(1)), - (2, Literal::string("a")), - (3, Literal::string("x")) + (1, Datum::long(1)), + (2, Datum::string("a")), + (3, Datum::string("x")) ]), key_metadata: vec![], split_offsets: vec![4], equality_ids: vec![], sort_order_id: Some(0), }, - } + }) ] }; @@ -1737,14 +1945,164 @@ mod tests { let entry = test_manifest_read_write(manifest, writer).await; assert_eq!(entry.partitions.len(), 1); - assert_eq!(entry.partitions[0].lower_bound, Some(Literal::string("x"))); - assert_eq!(entry.partitions[0].upper_bound, Some(Literal::string("x"))); + assert_eq!(entry.partitions[0].lower_bound, Some(Datum::string("x"))); + assert_eq!(entry.partitions[0].upper_bound, Some(Datum::string("x"))); + } + + #[tokio::test] + async fn test_parse_manifest_with_schema_evolution() { + let manifest = Manifest { + metadata: ManifestMetadata { + schema_id: 0, + schema: Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::optional( + 1, + "id", + Type::Primitive(PrimitiveType::Long), + )), + Arc::new(NestedField::optional( + 2, + "v_int", + Type::Primitive(PrimitiveType::Int), + )), + ]) + .build() + .unwrap(), + partition_spec: PartitionSpec { + spec_id: 0, + fields: vec![], + }, + content: ManifestContentType::Data, + format_version: FormatVersion::V2, + }, + entries: vec![Arc::new(ManifestEntry { + status: ManifestStatus::Added, + snapshot_id: None, + sequence_number: None, + file_sequence_number: None, + data_file: DataFile { + content: DataContentType::Data, + file_format: DataFileFormat::Parquet, + file_path: "s3a://icebergdata/demo/s1/t1/data/00000-0-378b56f5-5c52-4102-a2c2-f05f8a7cbe4a-00000.parquet".to_string(), + partition: Struct::empty(), + record_count: 1, + file_size_in_bytes: 5442, + column_sizes: HashMap::from([ + (1, 61), + (2, 73), + (3, 61), + ]), + value_counts: HashMap::default(), + null_value_counts: HashMap::default(), + nan_value_counts: HashMap::new(), + lower_bounds: HashMap::from([ + (1, Datum::long(1)), + (2, Datum::int(2)), + (3, Datum::string("x")) + ]), + upper_bounds: HashMap::from([ + (1, Datum::long(1)), + (2, Datum::int(2)), + (3, Datum::string("x")) + ]), + key_metadata: vec![], + split_offsets: vec![4], + equality_ids: vec![], + sort_order_id: None, + }, + })], + }; + + let writer = |output_file: OutputFile| ManifestWriter::new(output_file, 1, vec![]); + + let (avro_bytes, _) = write_manifest(&manifest, writer).await; + + // The parse should succeed. 
+ let actual_manifest = Manifest::parse_avro(avro_bytes.as_slice()).unwrap(); + + // Compared with original manifest, the lower_bounds and upper_bounds no longer has data for field 3, and + // other parts should be same. + let expected_manifest = Manifest { + metadata: ManifestMetadata { + schema_id: 0, + schema: Schema::builder() + .with_fields(vec![ + Arc::new(NestedField::optional( + 1, + "id", + Type::Primitive(PrimitiveType::Long), + )), + Arc::new(NestedField::optional( + 2, + "v_int", + Type::Primitive(PrimitiveType::Int), + )), + ]) + .build() + .unwrap(), + partition_spec: PartitionSpec { + spec_id: 0, + fields: vec![], + }, + content: ManifestContentType::Data, + format_version: FormatVersion::V2, + }, + entries: vec![Arc::new(ManifestEntry { + status: ManifestStatus::Added, + snapshot_id: None, + sequence_number: None, + file_sequence_number: None, + data_file: DataFile { + content: DataContentType::Data, + file_format: DataFileFormat::Parquet, + file_path: "s3a://icebergdata/demo/s1/t1/data/00000-0-378b56f5-5c52-4102-a2c2-f05f8a7cbe4a-00000.parquet".to_string(), + partition: Struct::empty(), + record_count: 1, + file_size_in_bytes: 5442, + column_sizes: HashMap::from([ + (1, 61), + (2, 73), + (3, 61), + ]), + value_counts: HashMap::default(), + null_value_counts: HashMap::default(), + nan_value_counts: HashMap::new(), + lower_bounds: HashMap::from([ + (1, Datum::long(1)), + (2, Datum::int(2)), + ]), + upper_bounds: HashMap::from([ + (1, Datum::long(1)), + (2, Datum::int(2)), + ]), + key_metadata: vec![], + split_offsets: vec![4], + equality_ids: vec![], + sort_order_id: None, + }, + })], + }; + + assert_eq!(actual_manifest, expected_manifest); } async fn test_manifest_read_write( manifest: Manifest, writer_builder: impl FnOnce(OutputFile) -> ManifestWriter, - ) -> ManifestListEntry { + ) -> ManifestFile { + let (bs, res) = write_manifest(&manifest, writer_builder).await; + let actual_manifest = Manifest::parse_avro(bs.as_slice()).unwrap(); + + assert_eq!(actual_manifest, manifest); + res + } + + /// Utility method which writes out a manifest and returns the bytes. + async fn write_manifest( + manifest: &Manifest, + writer_builder: impl FnOnce(OutputFile) -> ManifestWriter, + ) -> (Vec, ManifestFile) { let temp_dir = TempDir::new().unwrap(); let path = temp_dir.path().join("test_manifest.avro"); let io = FileIOBuilder::new_fs_io().build().unwrap(); @@ -1753,10 +2111,6 @@ mod tests { let res = writer.write(manifest.clone()).await.unwrap(); // Verify manifest - let bs = fs::read(path).expect("read_file must succeed"); - let actual_manifest = Manifest::parse_avro(bs.as_slice()).unwrap(); - - assert_eq!(actual_manifest, manifest); - res + (fs::read(path).expect("read_file must succeed"), res) } } diff --git a/crates/iceberg/src/spec/manifest_list.rs b/crates/iceberg/src/spec/manifest_list.rs index 76b8b53dd..3aaecf12d 100644 --- a/crates/iceberg/src/spec/manifest_list.rs +++ b/crates/iceberg/src/spec/manifest_list.rs @@ -17,18 +17,19 @@ //! ManifestList for Iceberg. 
-use std::{collections::HashMap, str::FromStr}; +use std::collections::HashMap; +use std::str::FromStr; -use crate::{io::OutputFile, spec::Literal, Error, ErrorKind}; -use apache_avro::{from_value, types::Value, Reader, Writer}; -use futures::AsyncWriteExt; +use apache_avro::types::Value; +use apache_avro::{from_value, Reader, Writer}; +use bytes::Bytes; -use self::{ - _const_schema::{MANIFEST_LIST_AVRO_SCHEMA_V1, MANIFEST_LIST_AVRO_SCHEMA_V2}, - _serde::{ManifestListEntryV1, ManifestListEntryV2}, -}; - -use super::{FormatVersion, StructType}; +use self::_const_schema::{MANIFEST_LIST_AVRO_SCHEMA_V1, MANIFEST_LIST_AVRO_SCHEMA_V2}; +use self::_serde::{ManifestFileV1, ManifestFileV2}; +use super::{Datum, FormatVersion, Manifest, StructType}; +use crate::error::Result; +use crate::io::{FileIO, OutputFile}; +use crate::{Error, ErrorKind}; /// Placeholder for sequence number. The field with this value must be replaced with the actual sequence number before it write. pub const UNASSIGNED_SEQUENCE_NUMBER: i64 = -1; @@ -49,7 +50,7 @@ pub const UNASSIGNED_SEQUENCE_NUMBER: i64 = -1; #[derive(Debug, Clone, PartialEq)] pub struct ManifestList { /// Entries in a manifest list. - entries: Vec, + entries: Vec, } impl ManifestList { @@ -57,26 +58,31 @@ impl ManifestList { pub fn parse_with_version( bs: &[u8], version: FormatVersion, - partition_types: &HashMap, - ) -> Result { + partition_type_provider: impl Fn(i32) -> Result>, + ) -> Result { match version { FormatVersion::V1 => { let reader = Reader::with_schema(&MANIFEST_LIST_AVRO_SCHEMA_V1, bs)?; - let values = Value::Array(reader.collect::, _>>()?); - from_value::<_serde::ManifestListV1>(&values)?.try_into(partition_types) + let values = Value::Array(reader.collect::, _>>()?); + from_value::<_serde::ManifestListV1>(&values)?.try_into(partition_type_provider) } FormatVersion::V2 => { - let reader = Reader::with_schema(&MANIFEST_LIST_AVRO_SCHEMA_V2, bs)?; - let values = Value::Array(reader.collect::, _>>()?); - from_value::<_serde::ManifestListV2>(&values)?.try_into(partition_types) + let reader = Reader::new(bs)?; + let values = Value::Array(reader.collect::, _>>()?); + from_value::<_serde::ManifestListV2>(&values)?.try_into(partition_type_provider) } } } /// Get the entries in the manifest list. - pub fn entries(&self) -> &[ManifestListEntry] { + pub fn entries(&self) -> &[ManifestFile] { &self.entries } + + /// Take ownership of the entries in the manifest list, consuming it + pub fn consume_entries(self) -> impl IntoIterator { + Box::new(self.entries.into_iter()) + } } /// A manifest list writer. @@ -163,45 +169,42 @@ impl ManifestListWriter { } } - /// Append manifest entries to be written. - pub fn add_manifest_entries( - &mut self, - manifest_entries: impl Iterator, - ) -> Result<(), Error> { + /// Append manifests to be written. 
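With the `&HashMap` argument replaced by a `partition_type_provider` closure, callers of `ManifestList::parse_with_version` resolve partition types lazily per `partition_spec_id`. A hedged sketch mirroring the tests further below (import paths assumed):

```rust
use iceberg::spec::{FormatVersion, ManifestList, StructType};

// Parse a V2 manifest list, looking up the partition type only for spec id 1.
fn parse(bytes: &[u8], partition_type: &StructType) -> iceberg::Result<ManifestList> {
    ManifestList::parse_with_version(bytes, FormatVersion::V2, |spec_id| {
        if spec_id == 1 {
            Ok(Some(partition_type.clone()))
        } else {
            // Only valid for manifests whose field summaries carry no partitions;
            // otherwise the summary conversion reports a DataInvalid error.
            Ok(None)
        }
    })
}
```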
+ pub fn add_manifests(&mut self, manifests: impl Iterator) -> Result<()> { match self.format_version { FormatVersion::V1 => { - for manifest_entry in manifest_entries { - let manifest_entry: ManifestListEntryV1 = manifest_entry.try_into()?; - self.avro_writer.append_ser(manifest_entry)?; + for manifest in manifests { + let manifes: ManifestFileV1 = manifest.try_into()?; + self.avro_writer.append_ser(manifes)?; } } FormatVersion::V2 => { - for mut manifest_entry in manifest_entries { - if manifest_entry.sequence_number == UNASSIGNED_SEQUENCE_NUMBER { - if manifest_entry.added_snapshot_id != self.snapshot_id { + for mut manifest in manifests { + if manifest.sequence_number == UNASSIGNED_SEQUENCE_NUMBER { + if manifest.added_snapshot_id != self.snapshot_id { return Err(Error::new( ErrorKind::DataInvalid, format!( "Found unassigned sequence number for a manifest from snapshot {}.", - manifest_entry.added_snapshot_id + manifest.added_snapshot_id ), )); } - manifest_entry.sequence_number = self.sequence_number; + manifest.sequence_number = self.sequence_number; } - if manifest_entry.min_sequence_number == UNASSIGNED_SEQUENCE_NUMBER { - if manifest_entry.added_snapshot_id != self.snapshot_id { + if manifest.min_sequence_number == UNASSIGNED_SEQUENCE_NUMBER { + if manifest.added_snapshot_id != self.snapshot_id { return Err(Error::new( ErrorKind::DataInvalid, format!( "Found unassigned sequence number for a manifest from snapshot {}.", - manifest_entry.added_snapshot_id + manifest.added_snapshot_id ), )); } - manifest_entry.min_sequence_number = self.sequence_number; + manifest.min_sequence_number = self.sequence_number; } - let manifest_entry: ManifestListEntryV2 = manifest_entry.try_into()?; + let manifest_entry: ManifestFileV2 = manifest.try_into()?; self.avro_writer.append_ser(manifest_entry)?; } } @@ -210,11 +213,11 @@ impl ManifestListWriter { } /// Write the manifest list to the output file. 
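`add_manifests` now consumes `ManifestFile` values and stamps unassigned sequence numbers for manifests added by the writer's own snapshot. A hedged usage sketch, mirroring the V2 writer tests below (the `FileIO::new_output` call and the meaning of the third constructor argument are assumptions):

```rust
use iceberg::io::FileIOBuilder;
use iceberg::spec::{ManifestFile, ManifestListWriter};

// Write a V2 manifest list; UNASSIGNED_SEQUENCE_NUMBER entries from this
// snapshot get the writer's sequence number on the way out.
async fn write_list(
    path: &str,
    snapshot_id: i64,
    sequence_number: i64,
    manifests: Vec<ManifestFile>,
) -> iceberg::Result<()> {
    let io = FileIOBuilder::new_fs_io().build()?;
    let output_file = io.new_output(path)?;
    let mut writer = ManifestListWriter::v2(output_file, snapshot_id, 0, sequence_number);
    writer.add_manifests(manifests.into_iter())?;
    writer.close().await
}
```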
- pub async fn close(self) -> Result<(), Error> { + pub async fn close(self) -> Result<()> { let data = self.avro_writer.into_inner()?; let mut writer = self.output_file.writer().await?; - writer.write_all(&data).await.unwrap(); - writer.close().await.unwrap(); + writer.write(Bytes::from(data)).await?; + writer.close().await?; Ok(()) } } @@ -226,9 +229,9 @@ mod _const_schema { use apache_avro::Schema as AvroSchema; use once_cell::sync::Lazy; - use crate::{ - avro::schema_to_avro_schema, - spec::{ListType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type}, + use crate::avro::schema_to_avro_schema; + use crate::spec::{ + ListType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, }; static MANIFEST_PATH: Lazy = { @@ -298,7 +301,7 @@ mod _const_schema { Lazy::new(|| { Arc::new(NestedField::required( 504, - "added_data_files_count", + "added_files_count", Type::Primitive(PrimitiveType::Int), )) }) @@ -316,7 +319,7 @@ mod _const_schema { Lazy::new(|| { Arc::new(NestedField::required( 505, - "existing_data_files_count", + "existing_files_count", Type::Primitive(PrimitiveType::Int), )) }) @@ -334,7 +337,7 @@ mod _const_schema { Lazy::new(|| { Arc::new(NestedField::required( 506, - "deleted_data_files_count", + "deleted_files_count", Type::Primitive(PrimitiveType::Int), )) }) @@ -493,15 +496,15 @@ mod _const_schema { }; pub(super) static MANIFEST_LIST_AVRO_SCHEMA_V1: Lazy = - Lazy::new(|| schema_to_avro_schema("manifest_list", &V1_SCHEMA).unwrap()); + Lazy::new(|| schema_to_avro_schema("manifest_file", &V1_SCHEMA).unwrap()); pub(super) static MANIFEST_LIST_AVRO_SCHEMA_V2: Lazy = - Lazy::new(|| schema_to_avro_schema("manifest_list", &V2_SCHEMA).unwrap()); + Lazy::new(|| schema_to_avro_schema("manifest_file", &V2_SCHEMA).unwrap()); } /// Entry in a manifest list. 
#[derive(Debug, PartialEq, Clone)] -pub struct ManifestListEntry { +pub struct ManifestFile { /// field: 500 /// /// Location of the manifest file @@ -538,17 +541,17 @@ pub struct ManifestListEntry { /// /// Number of entries in the manifest that have status ADDED, when null /// this is assumed to be non-zero - pub added_data_files_count: Option, + pub added_files_count: Option, /// field: 505 /// /// Number of entries in the manifest that have status EXISTING (0), /// when null this is assumed to be non-zero - pub existing_data_files_count: Option, + pub existing_files_count: Option, /// field: 506 /// /// Number of entries in the manifest that have status DELETED (2), /// when null this is assumed to be non-zero - pub deleted_data_files_count: Option, + pub deleted_files_count: Option, /// field: 512 /// /// Number of rows in all of files in the manifest that have status @@ -589,7 +592,7 @@ pub enum ManifestContentType { impl FromStr for ManifestContentType { type Err = Error; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s { "data" => Ok(ManifestContentType::Data), "deletes" => Ok(ManifestContentType::Deletes), @@ -601,11 +604,11 @@ impl FromStr for ManifestContentType { } } -impl ToString for ManifestContentType { - fn to_string(&self) -> String { +impl std::fmt::Display for ManifestContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ManifestContentType::Data => "data".to_string(), - ManifestContentType::Deletes => "deletes".to_string(), + ManifestContentType::Data => write!(f, "data"), + ManifestContentType::Deletes => write!(f, "deletes"), } } } @@ -613,7 +616,7 @@ impl ToString for ManifestContentType { impl TryFrom for ManifestContentType { type Error = Error; - fn try_from(value: i32) -> Result { + fn try_from(value: i32) -> std::result::Result { match value { 0 => Ok(ManifestContentType::Data), 1 => Ok(ManifestContentType::Deletes), @@ -628,6 +631,24 @@ impl TryFrom for ManifestContentType { } } +impl ManifestFile { + /// Load [`Manifest`]. + /// + /// This method will also initialize inherited values of [`ManifestEntry`], such as `sequence_number`. + pub async fn load_manifest(&self, file_io: &FileIO) -> Result { + let avro = file_io.new_input(&self.manifest_path)?.read().await?; + + let (metadata, mut entries) = Manifest::try_from_avro_bytes(&avro)?; + + // Let entries inherit values from the manifest list entry. + for entry in &mut entries { + entry.inherit_data(self); + } + + Ok(Manifest::new(metadata, entries)) + } +} + /// Field summary for partition field in the spec. /// /// Each field in the list corresponds to a field in the manifest file’s partition spec. @@ -645,40 +666,36 @@ pub struct FieldSummary { /// field: 510 /// The minimum value for the field in the manifests /// partitions. - pub lower_bound: Option, + pub lower_bound: Option, /// field: 511 /// The maximum value for the field in the manifests /// partitions. - pub upper_bound: Option, + pub upper_bound: Option, } /// This is a helper module that defines types to help with serialization/deserialization. -/// For deserialization the input first gets read into either the [ManifestListEntryV1] or [ManifestListEntryV2] struct -/// and then converted into the [ManifestListEntry] struct. Serialization works the other way around. -/// [ManifestListEntryV1] and [ManifestListEntryV2] are internal struct that are only used for serialization and deserialization. 
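`ManifestFile::load_manifest` combines reading the manifest Avro file with the `inherit_data` pass, so entries come back with snapshot ids and sequence numbers already filled in. A hedged sketch of walking the loaded entries with the new `ManifestEntry` accessors (a `Manifest::entries()` accessor is assumed here):

```rust
use iceberg::io::FileIO;
use iceberg::spec::ManifestFile;

// Print the data files referenced by one manifest-list entry.
async fn print_data_files(file_io: &FileIO, manifest_file: &ManifestFile) -> iceberg::Result<()> {
    let manifest = manifest_file.load_manifest(file_io).await?;
    for entry in manifest.entries() {
        println!(
            "{} ({:?}, {} rows, {} bytes)",
            entry.file_path(),
            entry.content_type(),
            entry.record_count(),
            entry.file_size_in_bytes(),
        );
    }
    Ok(())
}
```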
+/// For deserialization the input first gets read into either the [ManifestFileV1] or [ManifestFileV2] struct +/// and then converted into the [ManifestFile] struct. Serialization works the other way around. +/// [ManifestFileV1] and [ManifestFileV2] are internal struct that are only used for serialization and deserialization. pub(super) mod _serde { - use std::collections::HashMap; - pub use serde_bytes::ByteBuf; use serde_derive::{Deserialize, Serialize}; - use crate::{ - spec::{Literal, StructType, Type}, - Error, - }; - - use super::ManifestListEntry; + use super::ManifestFile; + use crate::error::Result; + use crate::spec::{Datum, PrimitiveType, StructType}; + use crate::Error; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub(crate) struct ManifestListV2 { - entries: Vec, + entries: Vec, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(transparent)] pub(crate) struct ManifestListV1 { - entries: Vec, + entries: Vec, } impl ManifestListV2 { @@ -686,8 +703,8 @@ pub(super) mod _serde { /// The convert of [entries] need the partition_type info so use this function instead of [std::TryFrom] trait. pub fn try_into( self, - partition_types: &HashMap, - ) -> Result { + partition_type_provider: impl Fn(i32) -> Result>, + ) -> Result { Ok(super::ManifestList { entries: self .entries @@ -695,7 +712,7 @@ pub(super) mod _serde { .map(|v| { let partition_spec_id = v.partition_spec_id; let manifest_path = v.manifest_path.clone(); - v.try_into(partition_types.get(&partition_spec_id)) + v.try_into(partition_type_provider(partition_spec_id)?.as_ref()) .map_err(|err| { err.with_context("manifest file path", manifest_path) .with_context( @@ -704,7 +721,7 @@ pub(super) mod _serde { ) }) }) - .collect::, _>>()?, + .collect::>>()?, }) } } @@ -712,13 +729,13 @@ pub(super) mod _serde { impl TryFrom for ManifestListV2 { type Error = Error; - fn try_from(value: super::ManifestList) -> Result { + fn try_from(value: super::ManifestList) -> std::result::Result { Ok(Self { entries: value .entries .into_iter() .map(TryInto::try_into) - .collect::, _>>()?, + .collect::, _>>()?, }) } } @@ -728,8 +745,8 @@ pub(super) mod _serde { /// The convert of [entries] need the partition_type info so use this function instead of [std::TryFrom] trait. 
pub fn try_into( self, - partition_types: &HashMap, - ) -> Result { + partition_type_provider: impl Fn(i32) -> Result>, + ) -> Result { Ok(super::ManifestList { entries: self .entries @@ -737,7 +754,7 @@ pub(super) mod _serde { .map(|v| { let partition_spec_id = v.partition_spec_id; let manifest_path = v.manifest_path.clone(); - v.try_into(partition_types.get(&partition_spec_id)) + v.try_into(partition_type_provider(partition_spec_id)?.as_ref()) .map_err(|err| { err.with_context("manifest file path", manifest_path) .with_context( @@ -746,7 +763,7 @@ pub(super) mod _serde { ) }) }) - .collect::, _>>()?, + .collect::>>()?, }) } } @@ -754,19 +771,19 @@ pub(super) mod _serde { impl TryFrom for ManifestListV1 { type Error = Error; - fn try_from(value: super::ManifestList) -> Result { + fn try_from(value: super::ManifestList) -> std::result::Result { Ok(Self { entries: value .entries .into_iter() .map(TryInto::try_into) - .collect::, _>>()?, + .collect::, _>>()?, }) } } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - pub(super) struct ManifestListEntryV1 { + pub(super) struct ManifestFileV1 { pub manifest_path: String, pub manifest_length: i64, pub partition_spec_id: i32, @@ -781,8 +798,11 @@ pub(super) mod _serde { pub key_metadata: Option, } + // Aliases were added to fields that were renamed in Iceberg 1.5.0 (https://github.com/apache/iceberg/pull/5338), in order to support both conventions/versions. + // In the current implementation deserialization is done using field names, and therefore these fields may appear as either. + // see issue that raised this here: https://github.com/apache/iceberg-rust/issues/338 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - pub(super) struct ManifestListEntryV2 { + pub(super) struct ManifestFileV2 { pub manifest_path: String, pub manifest_length: i64, pub partition_spec_id: i32, @@ -790,9 +810,12 @@ pub(super) mod _serde { pub sequence_number: i64, pub min_sequence_number: i64, pub added_snapshot_id: i64, - pub added_data_files_count: i32, - pub existing_data_files_count: i32, - pub deleted_data_files_count: i32, + #[serde(alias = "added_data_files_count", alias = "added_files_count")] + pub added_files_count: i32, + #[serde(alias = "existing_data_files_count", alias = "existing_files_count")] + pub existing_files_count: i32, + #[serde(alias = "deleted_data_files_count", alias = "deleted_files_count")] + pub deleted_files_count: i32, pub added_rows_count: i64, pub existing_rows_count: i64, pub deleted_rows_count: i64, @@ -812,17 +835,17 @@ pub(super) mod _serde { /// Converts the [FieldSummary] into a [super::FieldSummary]. /// [lower_bound] and [upper_bound] are converted into [Literal]s need the type info so use /// this function instead of [std::TryFrom] trait. 
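The `#[serde(alias = ...)]` attributes on `ManifestFileV2` above let deserialization accept both the pre-1.5.0 `*_data_files_count` names and the renamed `*_files_count` names. A self-contained sketch of that pattern, using JSON instead of Avro purely for brevity (`serde` with the `derive` feature and `serde_json` are assumed as dependencies):

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug, PartialEq)]
struct Counts {
    // Either spelling maps onto the same struct field.
    #[serde(alias = "added_data_files_count", alias = "added_files_count")]
    added_files_count: i32,
}

fn main() {
    let old: Counts = serde_json::from_str(r#"{"added_data_files_count": 3}"#).unwrap();
    let new: Counts = serde_json::from_str(r#"{"added_files_count": 3}"#).unwrap();
    assert_eq!(old, new);
}
```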
- pub(crate) fn try_into(self, r#type: &Type) -> Result { + pub(crate) fn try_into(self, r#type: &PrimitiveType) -> Result { Ok(super::FieldSummary { contains_null: self.contains_null, contains_nan: self.contains_nan, lower_bound: self .lower_bound - .map(|v| Literal::try_from_bytes(&v, r#type)) + .map(|v| Datum::try_from_bytes(&v, r#type.clone())) .transpose()?, upper_bound: self .upper_bound - .map(|v| Literal::try_from_bytes(&v, r#type)) + .map(|v| Datum::try_from_bytes(&v, r#type.clone())) .transpose()?, }) } @@ -831,7 +854,7 @@ pub(super) mod _serde { fn try_convert_to_field_summary( partitions: Option>, partition_type: Option<&StructType>, - ) -> Result, Error> { + ) -> Result> { if let Some(partitions) = partitions { if let Some(partition_type) = partition_type { let partition_types = partition_type.fields(); @@ -848,8 +871,15 @@ pub(super) mod _serde { partitions .into_iter() .zip(partition_types) - .map(|(v, field)| v.try_into(&field.field_type)) - .collect::, _>>() + .map(|(v, field)| { + v.try_into(field.field_type.as_primitive_type().ok_or_else(|| { + Error::new( + crate::ErrorKind::DataInvalid, + "Invalid partition spec. Field type is not primitive", + ) + })?) + }) + .collect::>>() } else { Err(Error::new( crate::ErrorKind::DataInvalid, @@ -861,15 +891,12 @@ pub(super) mod _serde { } } - impl ManifestListEntryV2 { - /// Converts the [ManifestListEntryV2] into a [ManifestListEntry]. + impl ManifestFileV2 { + /// Converts the [ManifestFileV2] into a [ManifestFile]. /// The convert of [partitions] need the partition_type info so use this function instead of [std::TryFrom] trait. - pub fn try_into( - self, - partition_type: Option<&StructType>, - ) -> Result { + pub fn try_into(self, partition_type: Option<&StructType>) -> Result { let partitions = try_convert_to_field_summary(self.partitions, partition_type)?; - Ok(ManifestListEntry { + Ok(ManifestFile { manifest_path: self.manifest_path, manifest_length: self.manifest_length, partition_spec_id: self.partition_spec_id, @@ -877,9 +904,9 @@ pub(super) mod _serde { sequence_number: self.sequence_number, min_sequence_number: self.min_sequence_number, added_snapshot_id: self.added_snapshot_id, - added_data_files_count: Some(self.added_data_files_count.try_into()?), - existing_data_files_count: Some(self.existing_data_files_count.try_into()?), - deleted_data_files_count: Some(self.deleted_data_files_count.try_into()?), + added_files_count: Some(self.added_files_count.try_into()?), + existing_files_count: Some(self.existing_files_count.try_into()?), + deleted_files_count: Some(self.deleted_files_count.try_into()?), added_rows_count: Some(self.added_rows_count.try_into()?), existing_rows_count: Some(self.existing_rows_count.try_into()?), deleted_rows_count: Some(self.deleted_rows_count.try_into()?), @@ -889,28 +916,25 @@ pub(super) mod _serde { } } - impl ManifestListEntryV1 { - /// Converts the [ManifestListEntryV1] into a [ManifestListEntry]. + impl ManifestFileV1 { + /// Converts the [MManifestFileV1] into a [ManifestFile]. /// The convert of [partitions] need the partition_type info so use this function instead of [std::TryFrom] trait. 
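Field summaries (and the data-file bound maps) now hold typed `Datum` values decoded with the spec's single-value binary serialization, rather than raw `Literal`s. A small round-trip sketch, assuming the `to_bytes`/`try_from_bytes` signatures used in the serializers above and `Datum: PartialEq`:

```rust
use iceberg::spec::{Datum, PrimitiveType};

fn main() -> iceberg::Result<()> {
    let lower = Datum::long(1);
    // Single-value binary form, as stored in lower_bound/upper_bound maps.
    let bytes = lower.to_bytes();
    let decoded = Datum::try_from_bytes(&bytes, PrimitiveType::Long)?;
    assert_eq!(lower, decoded);
    Ok(())
}
```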
- pub fn try_into( - self, - partition_type: Option<&StructType>, - ) -> Result { + pub fn try_into(self, partition_type: Option<&StructType>) -> Result { let partitions = try_convert_to_field_summary(self.partitions, partition_type)?; - Ok(ManifestListEntry { + Ok(ManifestFile { manifest_path: self.manifest_path, manifest_length: self.manifest_length, partition_spec_id: self.partition_spec_id, added_snapshot_id: self.added_snapshot_id, - added_data_files_count: self + added_files_count: self .added_data_files_count .map(TryInto::try_into) .transpose()?, - existing_data_files_count: self + existing_files_count: self .existing_data_files_count .map(TryInto::try_into) .transpose()?, - deleted_data_files_count: self + deleted_files_count: self .deleted_data_files_count .map(TryInto::try_into) .transpose()?, @@ -943,8 +967,8 @@ pub(super) mod _serde { .map(|v| FieldSummary { contains_null: v.contains_null, contains_nan: v.contains_nan, - lower_bound: v.lower_bound.map(|v| v.into()), - upper_bound: v.upper_bound.map(|v| v.into()), + lower_bound: v.lower_bound.map(|v| v.to_bytes()), + upper_bound: v.upper_bound.map(|v| v.to_bytes()), }) .collect(), ) @@ -959,10 +983,10 @@ pub(super) mod _serde { } } - impl TryFrom for ManifestListEntryV2 { + impl TryFrom for ManifestFileV2 { type Error = Error; - fn try_from(value: ManifestListEntry) -> Result { + fn try_from(value: ManifestFile) -> std::result::Result { let partitions = convert_to_serde_field_summary(value.partitions); let key_metadata = convert_to_serde_key_metadata(value.key_metadata); Ok(Self { @@ -973,30 +997,30 @@ pub(super) mod _serde { sequence_number: value.sequence_number, min_sequence_number: value.min_sequence_number, added_snapshot_id: value.added_snapshot_id, - added_data_files_count: value - .added_data_files_count + added_files_count: value + .added_files_count .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "added_data_files_count in ManifestListEntryV2 should be require", + "added_data_files_count in ManifestFileV2 should be require", ) })? .try_into()?, - existing_data_files_count: value - .existing_data_files_count + existing_files_count: value + .existing_files_count .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "existing_data_files_count in ManifestListEntryV2 should be require", + "existing_data_files_count in ManifestFileV2 should be require", ) })? .try_into()?, - deleted_data_files_count: value - .deleted_data_files_count + deleted_files_count: value + .deleted_files_count .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "deleted_data_files_count in ManifestListEntryV2 should be require", + "deleted_data_files_count in ManifestFileV2 should be require", ) })? .try_into()?, @@ -1005,7 +1029,7 @@ pub(super) mod _serde { .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "added_rows_count in ManifestListEntryV2 should be require", + "added_rows_count in ManifestFileV2 should be require", ) })? .try_into()?, @@ -1014,7 +1038,7 @@ pub(super) mod _serde { .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "existing_rows_count in ManifestListEntryV2 should be require", + "existing_rows_count in ManifestFileV2 should be require", ) })? .try_into()?, @@ -1023,7 +1047,7 @@ pub(super) mod _serde { .ok_or_else(|| { Error::new( crate::ErrorKind::DataInvalid, - "deleted_rows_count in ManifestListEntryV2 should be require", + "deleted_rows_count in ManifestFileV2 should be require", ) })? 
.try_into()?, @@ -1033,10 +1057,10 @@ pub(super) mod _serde { } } - impl TryFrom for ManifestListEntryV1 { + impl TryFrom for ManifestFileV1 { type Error = Error; - fn try_from(value: ManifestListEntry) -> Result { + fn try_from(value: ManifestFile) -> std::result::Result { let partitions = convert_to_serde_field_summary(value.partitions); let key_metadata = convert_to_serde_key_metadata(value.key_metadata); Ok(Self { @@ -1045,15 +1069,15 @@ pub(super) mod _serde { partition_spec_id: value.partition_spec_id, added_snapshot_id: value.added_snapshot_id, added_data_files_count: value - .added_data_files_count + .added_files_count .map(TryInto::try_into) .transpose()?, existing_data_files_count: value - .existing_data_files_count + .existing_files_count .map(TryInto::try_into) .transpose()?, deleted_data_files_count: value - .deleted_data_files_count + .deleted_files_count .map(TryInto::try_into) .transpose()?, added_rows_count: value.added_rows_count.map(TryInto::try_into).transpose()?, @@ -1074,26 +1098,26 @@ pub(super) mod _serde { #[cfg(test)] mod test { - use std::{collections::HashMap, fs, sync::Arc}; + use std::collections::HashMap; + use std::fs; + use std::sync::Arc; + use apache_avro::{Reader, Schema}; use tempfile::TempDir; - use crate::{ - io::FileIOBuilder, - spec::{ - manifest_list::{_serde::ManifestListV1, UNASSIGNED_SEQUENCE_NUMBER}, - FieldSummary, Literal, ManifestContentType, ManifestList, ManifestListEntry, - ManifestListWriter, NestedField, PrimitiveType, StructType, Type, - }, - }; - use super::_serde::ManifestListV2; + use crate::io::FileIOBuilder; + use crate::spec::manifest_list::_serde::ManifestListV1; + use crate::spec::{ + Datum, FieldSummary, ManifestContentType, ManifestFile, ManifestList, ManifestListWriter, + NestedField, PrimitiveType, StructType, Type, UNASSIGNED_SEQUENCE_NUMBER, + }; #[tokio::test] async fn test_parse_manifest_list_v1() { let manifest_list = ManifestList { entries: vec![ - ManifestListEntry { + ManifestFile { manifest_path: "/opt/bitnami/spark/warehouse/db/table/metadata/10d28031-9739-484c-92db-cdf2975cead4-m0.avro".to_string(), manifest_length: 5806, partition_spec_id: 0, @@ -1101,9 +1125,9 @@ mod test { sequence_number: 0, min_sequence_number: 0, added_snapshot_id: 1646658105718557341, - added_data_files_count: Some(3), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(3), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), @@ -1126,14 +1150,14 @@ mod test { ); writer - .add_manifest_entries(manifest_list.entries.clone().into_iter()) + .add_manifests(manifest_list.entries.clone().into_iter()) .unwrap(); writer.close().await.unwrap(); let bs = fs::read(full_path).expect("read_file must succeed"); let parsed_manifest_list = - ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V1, &HashMap::new()) + ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V1, |_id| Ok(None)) .unwrap(); assert_eq!(manifest_list, parsed_manifest_list); @@ -1143,7 +1167,7 @@ mod test { async fn test_parse_manifest_list_v2() { let manifest_list = ManifestList { entries: vec![ - ManifestListEntry { + ManifestFile { manifest_path: "s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m0.avro".to_string(), manifest_length: 6926, partition_spec_id: 1, @@ -1151,16 +1175,16 @@ mod test { sequence_number: 1, min_sequence_number: 1, added_snapshot_id: 377075049360453639, - 
added_data_files_count: Some(1), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(1), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::long(1)), upper_bound: Some(Literal::long(1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::long(1)), upper_bound: Some(Datum::long(1))}], key_metadata: vec![], }, - ManifestListEntry { + ManifestFile { manifest_path: "s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m1.avro".to_string(), manifest_length: 6926, partition_spec_id: 2, @@ -1168,13 +1192,13 @@ mod test { sequence_number: 1, min_sequence_number: 1, added_snapshot_id: 377075049360453639, - added_data_files_count: Some(1), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(1), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::float(1.1)), upper_bound: Some(Literal::float(2.1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::float(1.1)), upper_bound: Some(Datum::float(2.1))}], key_metadata: vec![], } ] @@ -1194,35 +1218,36 @@ mod test { ); writer - .add_manifest_entries(manifest_list.entries.clone().into_iter()) + .add_manifests(manifest_list.entries.clone().into_iter()) .unwrap(); writer.close().await.unwrap(); let bs = fs::read(full_path).expect("read_file must succeed"); - let parsed_manifest_list = ManifestList::parse_with_version( - &bs, - crate::spec::FormatVersion::V2, - &HashMap::from([ - ( - 1, - StructType::new(vec![Arc::new(NestedField::required( - 1, - "test", - Type::Primitive(PrimitiveType::Long), - ))]), - ), - ( - 2, - StructType::new(vec![Arc::new(NestedField::required( + let parsed_manifest_list = + ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V2, |id| { + Ok(HashMap::from([ + ( 1, - "test", - Type::Primitive(PrimitiveType::Float), - ))]), - ), - ]), - ) - .unwrap(); + StructType::new(vec![Arc::new(NestedField::required( + 1, + "test", + Type::Primitive(PrimitiveType::Long), + ))]), + ), + ( + 2, + StructType::new(vec![Arc::new(NestedField::required( + 1, + "test", + Type::Primitive(PrimitiveType::Float), + ))]), + ), + ]) + .get(&id) + .cloned()) + }) + .unwrap(); assert_eq!(manifest_list, parsed_manifest_list); } @@ -1230,7 +1255,7 @@ mod test { #[test] fn test_serialize_manifest_list_v1() { let manifest_list:ManifestListV1 = ManifestList { - entries: vec![ManifestListEntry { + entries: vec![ManifestFile { manifest_path: "/opt/bitnami/spark/warehouse/db/table/metadata/10d28031-9739-484c-92db-cdf2975cead4-m0.avro".to_string(), manifest_length: 5806, partition_spec_id: 0, @@ -1238,9 +1263,9 @@ mod test { sequence_number: 0, min_sequence_number: 0, added_snapshot_id: 1646658105718557341, - added_data_files_count: Some(3), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(3), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), @@ -1258,7 +1283,7 @@ 
mod test { #[test] fn test_serialize_manifest_list_v2() { let manifest_list:ManifestListV2 = ManifestList { - entries: vec![ManifestListEntry { + entries: vec![ManifestFile { manifest_path: "s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m0.avro".to_string(), manifest_length: 6926, partition_spec_id: 1, @@ -1266,27 +1291,27 @@ mod test { sequence_number: 1, min_sequence_number: 1, added_snapshot_id: 377075049360453639, - added_data_files_count: Some(1), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(1), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::long(1)), upper_bound: Some(Literal::long(1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::long(1)), upper_bound: Some(Datum::long(1))}], key_metadata: vec![], }] }.try_into().unwrap(); let result = serde_json::to_string(&manifest_list).unwrap(); assert_eq!( result, - r#"[{"manifest_path":"s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m0.avro","manifest_length":6926,"partition_spec_id":1,"content":0,"sequence_number":1,"min_sequence_number":1,"added_snapshot_id":377075049360453639,"added_data_files_count":1,"existing_data_files_count":0,"deleted_data_files_count":0,"added_rows_count":3,"existing_rows_count":0,"deleted_rows_count":0,"partitions":[{"contains_null":false,"contains_nan":false,"lower_bound":[1,0,0,0,0,0,0,0],"upper_bound":[1,0,0,0,0,0,0,0]}],"key_metadata":null}]"# + r#"[{"manifest_path":"s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m0.avro","manifest_length":6926,"partition_spec_id":1,"content":0,"sequence_number":1,"min_sequence_number":1,"added_snapshot_id":377075049360453639,"added_files_count":1,"existing_files_count":0,"deleted_files_count":0,"added_rows_count":3,"existing_rows_count":0,"deleted_rows_count":0,"partitions":[{"contains_null":false,"contains_nan":false,"lower_bound":[1,0,0,0,0,0,0,0],"upper_bound":[1,0,0,0,0,0,0,0]}],"key_metadata":null}]"# ); } #[tokio::test] async fn test_manifest_list_writer_v1() { let expected_manifest_list = ManifestList { - entries: vec![ManifestListEntry { + entries: vec![ManifestFile { manifest_path: "/opt/bitnami/spark/warehouse/db/table/metadata/10d28031-9739-484c-92db-cdf2975cead4-m0.avro".to_string(), manifest_length: 5806, partition_spec_id: 1, @@ -1294,13 +1319,13 @@ mod test { sequence_number: 0, min_sequence_number: 0, added_snapshot_id: 1646658105718557341, - added_data_files_count: Some(3), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(3), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::long(1)), upper_bound: Some(Literal::long(1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::long(1)), upper_bound: Some(Datum::long(1))}], key_metadata: vec![], }] }; @@ -1312,24 +1337,26 @@ mod test { let mut writer = ManifestListWriter::v1(output_file, 1646658105718557341, 0); writer - 
.add_manifest_entries(expected_manifest_list.entries.clone().into_iter()) + .add_manifests(expected_manifest_list.entries.clone().into_iter()) .unwrap(); writer.close().await.unwrap(); let bs = fs::read(path).unwrap(); - let manifest_list = ManifestList::parse_with_version( - &bs, - crate::spec::FormatVersion::V1, - &HashMap::from([( + + let partition_types = HashMap::from([( + 1, + StructType::new(vec![Arc::new(NestedField::required( 1, - StructType::new(vec![Arc::new(NestedField::required( - 1, - "test", - Type::Primitive(PrimitiveType::Long), - ))]), - )]), - ) - .unwrap(); + "test", + Type::Primitive(PrimitiveType::Long), + ))]), + )]); + + let manifest_list = + ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V1, move |id| { + Ok(partition_types.get(&id).cloned()) + }) + .unwrap(); assert_eq!(manifest_list, expected_manifest_list); temp_dir.close().unwrap(); @@ -1340,7 +1367,7 @@ mod test { let snapshot_id = 377075049360453639; let seq_num = 1; let mut expected_manifest_list = ManifestList { - entries: vec![ManifestListEntry { + entries: vec![ManifestFile { manifest_path: "s3a://icebergdata/demo/s1/t1/metadata/05ffe08b-810f-49b3-a8f4-e88fc99b254a-m0.avro".to_string(), manifest_length: 6926, partition_spec_id: 1, @@ -1348,13 +1375,13 @@ mod test { sequence_number: UNASSIGNED_SEQUENCE_NUMBER, min_sequence_number: UNASSIGNED_SEQUENCE_NUMBER, added_snapshot_id: snapshot_id, - added_data_files_count: Some(1), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(1), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::long(1)), upper_bound: Some(Literal::long(1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::long(1)), upper_bound: Some(Datum::long(1))}], key_metadata: vec![], }] }; @@ -1366,24 +1393,24 @@ mod test { let mut writer = ManifestListWriter::v2(output_file, snapshot_id, 0, seq_num); writer - .add_manifest_entries(expected_manifest_list.entries.clone().into_iter()) + .add_manifests(expected_manifest_list.entries.clone().into_iter()) .unwrap(); writer.close().await.unwrap(); let bs = fs::read(path).unwrap(); - let manifest_list = ManifestList::parse_with_version( - &bs, - crate::spec::FormatVersion::V2, - &HashMap::from([( + let partition_types = HashMap::from([( + 1, + StructType::new(vec![Arc::new(NestedField::required( 1, - StructType::new(vec![Arc::new(NestedField::required( - 1, - "test", - Type::Primitive(PrimitiveType::Long), - ))]), - )]), - ) - .unwrap(); + "test", + Type::Primitive(PrimitiveType::Long), + ))]), + )]); + let manifest_list = + ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V2, move |id| { + Ok(partition_types.get(&id).cloned()) + }) + .unwrap(); expected_manifest_list.entries[0].sequence_number = seq_num; expected_manifest_list.entries[0].min_sequence_number = seq_num; assert_eq!(manifest_list, expected_manifest_list); @@ -1394,7 +1421,7 @@ mod test { #[tokio::test] async fn test_manifest_list_writer_v1_as_v2() { let expected_manifest_list = ManifestList { - entries: vec![ManifestListEntry { + entries: vec![ManifestFile { manifest_path: "/opt/bitnami/spark/warehouse/db/table/metadata/10d28031-9739-484c-92db-cdf2975cead4-m0.avro".to_string(), manifest_length: 5806, partition_spec_id: 1, @@ -1402,13 +1429,13 
@@ mod test { sequence_number: 0, min_sequence_number: 0, added_snapshot_id: 1646658105718557341, - added_data_files_count: Some(3), - existing_data_files_count: Some(0), - deleted_data_files_count: Some(0), + added_files_count: Some(3), + existing_files_count: Some(0), + deleted_files_count: Some(0), added_rows_count: Some(3), existing_rows_count: Some(0), deleted_rows_count: Some(0), - partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Literal::long(1)), upper_bound: Some(Literal::long(1))}], + partitions: vec![FieldSummary { contains_null: false, contains_nan: Some(false), lower_bound: Some(Datum::long(1)), upper_bound: Some(Datum::long(1))}], key_metadata: vec![], }] }; @@ -1420,26 +1447,74 @@ mod test { let mut writer = ManifestListWriter::v2(output_file, 1646658105718557341, 0, 1); writer - .add_manifest_entries(expected_manifest_list.entries.clone().into_iter()) + .add_manifests(expected_manifest_list.entries.clone().into_iter()) .unwrap(); writer.close().await.unwrap(); let bs = fs::read(path).unwrap(); - let manifest_list = ManifestList::parse_with_version( - &bs, - crate::spec::FormatVersion::V2, - &HashMap::from([( + + let partition_types = HashMap::from([( + 1, + StructType::new(vec![Arc::new(NestedField::required( 1, - StructType::new(vec![Arc::new(NestedField::required( - 1, - "test", - Type::Primitive(PrimitiveType::Long), - ))]), - )]), - ) - .unwrap(); + "test", + Type::Primitive(PrimitiveType::Long), + ))]), + )]); + + let manifest_list = + ManifestList::parse_with_version(&bs, crate::spec::FormatVersion::V2, move |id| { + Ok(partition_types.get(&id).cloned()) + }) + .unwrap(); assert_eq!(manifest_list, expected_manifest_list); temp_dir.close().unwrap(); } + + #[tokio::test] + async fn test_manifest_list_v2_deserializer_aliases() { + // reading avro manifest file generated by iceberg 1.4.0 + let avro_1_path = "testdata/manifests_lists/manifest-list-v2-1.avro"; + let bs_1 = fs::read(avro_1_path).unwrap(); + let avro_1_fields = read_avro_schema_fields_as_str(bs_1.clone()).await; + assert_eq!( + avro_1_fields, + "manifest_path, manifest_length, partition_spec_id, content, sequence_number, min_sequence_number, added_snapshot_id, added_data_files_count, existing_data_files_count, deleted_data_files_count, added_rows_count, existing_rows_count, deleted_rows_count, partitions" + ); + // reading avro manifest file generated by iceberg 1.5.0 + let avro_2_path = "testdata/manifests_lists/manifest-list-v2-2.avro"; + let bs_2 = fs::read(avro_2_path).unwrap(); + let avro_2_fields = read_avro_schema_fields_as_str(bs_2.clone()).await; + assert_eq!( + avro_2_fields, + "manifest_path, manifest_length, partition_spec_id, content, sequence_number, min_sequence_number, added_snapshot_id, added_files_count, existing_files_count, deleted_files_count, added_rows_count, existing_rows_count, deleted_rows_count, partitions" + ); + // deserializing both files to ManifestList struct + let _manifest_list_1 = + ManifestList::parse_with_version(&bs_1, crate::spec::FormatVersion::V2, move |_id| { + Ok(Some(StructType::new(vec![]))) + }) + .unwrap(); + let _manifest_list_2 = + ManifestList::parse_with_version(&bs_2, crate::spec::FormatVersion::V2, move |_id| { + Ok(Some(StructType::new(vec![]))) + }) + .unwrap(); + } + + async fn read_avro_schema_fields_as_str(bs: Vec) -> String { + let reader = Reader::new(&bs[..]).unwrap(); + let schema = reader.writer_schema(); + let fields: String = match schema { + Schema::Record(record) => record + .fields + .iter() + 
.map(|field| field.name.clone()) + .collect::>() + .join(", "), + _ => "".to_string(), + }; + fields + } } diff --git a/crates/iceberg/src/spec/mod.rs b/crates/iceberg/src/spec/mod.rs index 199fc4a16..793f00d34 100644 --- a/crates/iceberg/src/spec/mod.rs +++ b/crates/iceberg/src/spec/mod.rs @@ -27,6 +27,8 @@ mod sort; mod table_metadata; mod transform; mod values; +mod view_metadata; +mod view_version; pub use datatypes::*; pub use manifest::*; @@ -38,3 +40,5 @@ pub use sort::*; pub use table_metadata::*; pub use transform::*; pub use values::*; +pub use view_metadata::*; +pub use view_version::*; diff --git a/crates/iceberg/src/spec/partition.rs b/crates/iceberg/src/spec/partition.rs index 9388820a2..36763df7e 100644 --- a/crates/iceberg/src/spec/partition.rs +++ b/crates/iceberg/src/spec/partition.rs @@ -17,14 +17,18 @@ /*! * Partitioning -*/ -use serde::{Deserialize, Serialize}; + */ use std::sync::Arc; + +use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; -use crate::{Error, ErrorKind}; +use super::transform::Transform; +use super::{NestedField, Schema, StructType}; +use crate::{Error, ErrorKind, Result}; -use super::{transform::Transform, NestedField, Schema, StructType}; +pub(crate) const UNPARTITIONED_LAST_ASSIGNED_ID: i32 = 999; +pub(crate) const DEFAULT_PARTITION_SPEC_ID: i32 = 0; /// Reference to [`PartitionSpec`]. pub type PartitionSpecRef = Arc; @@ -43,22 +47,37 @@ pub struct PartitionField { pub transform: Transform, } +impl PartitionField { + /// To unbound partition field + pub fn into_unbound(self) -> UnboundPartitionField { + self.into() + } +} + /// Partition spec that defines how to produce a tuple of partition values from a record. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default, Builder)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)] #[serde(rename_all = "kebab-case")] -#[builder(setter(prefix = "with"))] pub struct PartitionSpec { /// Identifier for PartitionSpec - pub spec_id: i32, + pub(crate) spec_id: i32, /// Details of the partition spec - #[builder(setter(each(name = "with_partition_field")))] - pub fields: Vec, + pub(crate) fields: Vec, } impl PartitionSpec { - /// Create partition spec builer - pub fn builder() -> PartitionSpecBuilder { - PartitionSpecBuilder::default() + /// Create partition spec builder + pub fn builder(schema: &Schema) -> PartitionSpecBuilder { + PartitionSpecBuilder::new(schema) + } + + /// Spec id of the partition spec + pub fn spec_id(&self) -> i32 { + self.spec_id + } + + /// Fields of the partition spec + pub fn fields(&self) -> &[PartitionField] { + &self.fields } /// Returns if the partition spec is unpartitioned. @@ -73,7 +92,7 @@ impl PartitionSpec { } /// Returns the partition type of this partition spec. - pub fn partition_type(&self, schema: &Schema) -> Result { + pub fn partition_type(&self, schema: &Schema) -> Result { let mut fields = Vec::with_capacity(self.fields.len()); for partition_field in &self.fields { let field = schema @@ -95,6 +114,67 @@ impl PartitionSpec { } Ok(StructType::new(fields)) } + + /// Turn this partition spec into an unbound partition spec. + /// + /// The `field_id` is retained as `partition_id` in the unbound partition spec. + pub fn into_unbound(self) -> UnboundPartitionSpec { + self.into() + } + + /// Check if this partition spec is compatible with another partition spec. + /// + /// Returns true if the partition spec is equal to the other spec with partition field ids ignored and + /// spec_id ignored. 
The following must be identical: + /// * The number of fields + /// * Field order + /// * Field names + /// * Source column ids + /// * Transforms + pub fn is_compatible_with(&self, other: &UnboundPartitionSpec) -> bool { + if self.fields.len() != other.fields.len() { + return false; + } + + for (this_field, other_field) in self.fields.iter().zip(&other.fields) { + if this_field.source_id != other_field.source_id + || this_field.transform != other_field.transform + || this_field.name != other_field.name + { + return false; + } + } + + true + } + + /// Check if this partition spec has sequential partition ids. + /// Sequential ids start from 1000 and increment by 1 for each field. + /// This is required for spec version 1 + pub fn has_sequential_ids(&self) -> bool { + for (index, field) in self.fields.iter().enumerate() { + let expected_id = (UNPARTITIONED_LAST_ASSIGNED_ID as i64) + .checked_add(1) + .and_then(|id| id.checked_add(index as i64)) + .unwrap_or(i64::MAX); + + if field.field_id as i64 != expected_id { + return false; + } + } + + true + } + + /// Get the highest field id in the partition spec. + /// If the partition spec is unpartitioned, it returns the last unpartitioned last assigned id (999). + pub fn highest_field_id(&self) -> i32 { + self.fields + .iter() + .map(|f| f.field_id) + .max() + .unwrap_or(UNPARTITIONED_LAST_ASSIGNED_ID) + } } /// Reference to [`UnboundPartitionSpec`]. @@ -108,7 +188,7 @@ pub struct UnboundPartitionField { /// A partition field id that is used to identify a partition field and is unique within a partition spec. /// In v2 table metadata, it is unique across all partition specs. #[builder(default, setter(strip_option))] - pub partition_id: Option, + pub field_id: Option, /// A partition name. pub name: String, /// A transform that is applied to the source column to produce a partition value. @@ -116,30 +196,448 @@ pub struct UnboundPartitionField { } /// Unbound partition spec can be built without a schema and later bound to a schema. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default, Builder)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)] #[serde(rename_all = "kebab-case")] -#[builder(setter(prefix = "with"))] pub struct UnboundPartitionSpec { /// Identifier for PartitionSpec - #[builder(default, setter(strip_option))] - pub spec_id: Option, + pub(crate) spec_id: Option, /// Details of the partition spec - #[builder(setter(each(name = "with_unbound_partition_field")))] - pub fields: Vec, + pub(crate) fields: Vec, } impl UnboundPartitionSpec { - /// Create unbound partition spec builer + /// Create unbound partition spec builder pub fn builder() -> UnboundPartitionSpecBuilder { UnboundPartitionSpecBuilder::default() } + + /// Bind this unbound partition spec to a schema. 
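As an illustrative aside (not part of the patch itself), the unbound-to-bound flow introduced here can be sketched using only the APIs defined in this file, assuming `schema` contains a field with id 1 and the code runs inside a function returning `Result`:

    // Build an unbound spec, then bind it against a schema; binding validates
    // names and transforms and auto-assigns partition field ids starting at 1000.
    let spec = UnboundPartitionSpec::builder()
        .with_spec_id(1)
        .add_partition_field(1, "id_bucket", Transform::Bucket(16))?
        .build()
        .bind(&schema)?;
    assert_eq!(1000, spec.fields()[0].field_id);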
+ pub fn bind(self, schema: &Schema) -> Result { + PartitionSpecBuilder::new_from_unbound(self, schema)?.build() + } + + /// Spec id of the partition spec + pub fn spec_id(&self) -> Option { + self.spec_id + } + + /// Fields of the partition spec + pub fn fields(&self) -> &[UnboundPartitionField] { + &self.fields + } + + /// Change the spec id of the partition spec + pub fn with_spec_id(self, spec_id: i32) -> Self { + Self { + spec_id: Some(spec_id), + ..self + } + } +} + +impl From for UnboundPartitionField { + fn from(field: PartitionField) -> Self { + UnboundPartitionField { + source_id: field.source_id, + field_id: Some(field.field_id), + name: field.name, + transform: field.transform, + } + } +} + +impl From for UnboundPartitionSpec { + fn from(spec: PartitionSpec) -> Self { + UnboundPartitionSpec { + spec_id: Some(spec.spec_id), + fields: spec.fields.into_iter().map(Into::into).collect(), + } + } +} + +/// Create a new UnboundPartitionSpec +#[derive(Debug, Default)] +pub struct UnboundPartitionSpecBuilder { + spec_id: Option, + fields: Vec, +} + +impl UnboundPartitionSpecBuilder { + /// Create a new partition spec builder with the given schema. + pub fn new() -> Self { + Self { + spec_id: None, + fields: vec![], + } + } + + /// Set the spec id for the partition spec. + pub fn with_spec_id(mut self, spec_id: i32) -> Self { + self.spec_id = Some(spec_id); + self + } + + /// Add a new partition field to the partition spec from an unbound partition field. + pub fn add_partition_field( + self, + source_id: i32, + target_name: impl ToString, + transformation: Transform, + ) -> Result { + let field = UnboundPartitionField { + source_id, + field_id: None, + name: target_name.to_string(), + transform: transformation, + }; + self.add_partition_field_internal(field) + } + + /// Add multiple partition fields to the partition spec. + pub fn add_partition_fields( + self, + fields: impl IntoIterator, + ) -> Result { + let mut builder = self; + for field in fields { + builder = builder.add_partition_field_internal(field)?; + } + Ok(builder) + } + + fn add_partition_field_internal(mut self, field: UnboundPartitionField) -> Result { + self.check_name_set_and_unique(&field.name)?; + self.check_for_redundant_partitions(field.source_id, &field.transform)?; + if let Some(partition_field_id) = field.field_id { + self.check_partition_id_unique(partition_field_id)?; + } + self.fields.push(field); + Ok(self) + } + + /// Build the unbound partition spec. + pub fn build(self) -> UnboundPartitionSpec { + UnboundPartitionSpec { + spec_id: self.spec_id, + fields: self.fields, + } + } +} + +/// Create valid partition specs for a given schema. +#[derive(Debug)] +pub struct PartitionSpecBuilder<'a> { + spec_id: Option, + last_assigned_field_id: i32, + fields: Vec, + schema: &'a Schema, +} + +impl<'a> PartitionSpecBuilder<'a> { + /// Create a new partition spec builder with the given schema. + pub fn new(schema: &'a Schema) -> Self { + Self { + spec_id: None, + fields: vec![], + last_assigned_field_id: UNPARTITIONED_LAST_ASSIGNED_ID, + schema, + } + } + + /// Create a new partition spec builder from an existing unbound partition spec. + pub fn new_from_unbound(unbound: UnboundPartitionSpec, schema: &'a Schema) -> Result { + let mut builder = + Self::new(schema).with_spec_id(unbound.spec_id.unwrap_or(DEFAULT_PARTITION_SPEC_ID)); + + for field in unbound.fields { + builder = builder.add_unbound_field(field)?; + } + Ok(builder) + } + + /// Set the last assigned field id for the partition spec. 
+ /// + /// Set this field when a new partition spec is created for an existing TableMetaData. + /// As `field_id` must be unique in V2 metadata, this should be set to + /// the highest field id used previously. + pub fn with_last_assigned_field_id(mut self, last_assigned_field_id: i32) -> Self { + self.last_assigned_field_id = last_assigned_field_id; + self + } + + /// Set the spec id for the partition spec. + pub fn with_spec_id(mut self, spec_id: i32) -> Self { + self.spec_id = Some(spec_id); + self + } + + /// Add a new partition field to the partition spec. + pub fn add_partition_field( + self, + source_name: impl AsRef, + target_name: impl Into, + transform: Transform, + ) -> Result { + let source_id = self + .schema + .field_by_name(source_name.as_ref()) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot find source column with name: {} in schema", + source_name.as_ref() + ), + ) + })? + .id; + let field = UnboundPartitionField { + source_id, + field_id: None, + name: target_name.into(), + transform, + }; + + self.add_unbound_field(field) + } + + /// Add a new partition field to the partition spec. + /// + /// If partition field id is set, it is used as the field id. + /// Otherwise, a new `field_id` is assigned. + pub fn add_unbound_field(mut self, field: UnboundPartitionField) -> Result { + self.check_name_set_and_unique(&field.name)?; + self.check_for_redundant_partitions(field.source_id, &field.transform)?; + Self::check_name_does_not_collide_with_schema(&field, self.schema)?; + Self::check_transform_compatibility(&field, self.schema)?; + if let Some(partition_field_id) = field.field_id { + self.check_partition_id_unique(partition_field_id)?; + } + + // Non-fallible from here + self.fields.push(field); + Ok(self) + } + + /// Wrapper around `with_unbound_fields` to add multiple partition fields. + pub fn add_unbound_fields( + self, + fields: impl IntoIterator, + ) -> Result { + let mut builder = self; + for field in fields { + builder = builder.add_unbound_field(field)?; + } + Ok(builder) + } + + /// Build a bound partition spec with the given schema. + pub fn build(self) -> Result { + let fields = Self::set_field_ids(self.fields, self.last_assigned_field_id)?; + Ok(PartitionSpec { + spec_id: self.spec_id.unwrap_or(DEFAULT_PARTITION_SPEC_ID), + fields, + }) + } + + fn set_field_ids( + fields: Vec, + last_assigned_field_id: i32, + ) -> Result> { + let mut last_assigned_field_id = last_assigned_field_id; + // Already assigned partition ids. If we see one of these during iteration, + // we skip it. + let assigned_ids = fields + .iter() + .filter_map(|f| f.field_id) + .collect::>(); + + fn _check_add_1(prev: i32) -> Result { + prev.checked_add(1).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Cannot assign more partition ids. 
Overflow.", + ) + }) + } + + let mut bound_fields = Vec::with_capacity(fields.len()); + for field in fields.into_iter() { + let partition_field_id = if let Some(partition_field_id) = field.field_id { + last_assigned_field_id = std::cmp::max(last_assigned_field_id, partition_field_id); + partition_field_id + } else { + last_assigned_field_id = _check_add_1(last_assigned_field_id)?; + while assigned_ids.contains(&last_assigned_field_id) { + last_assigned_field_id = _check_add_1(last_assigned_field_id)?; + } + last_assigned_field_id + }; + + bound_fields.push(PartitionField { + source_id: field.source_id, + field_id: partition_field_id, + name: field.name, + transform: field.transform, + }) + } + + Ok(bound_fields) + } + + /// Ensure that the partition name is unique among columns in the schema. + /// Duplicate names are allowed if: + /// 1. The column is sourced from the column with the same name. + /// 2. AND the transformation is identity + fn check_name_does_not_collide_with_schema( + field: &UnboundPartitionField, + schema: &Schema, + ) -> Result<()> { + match schema.field_by_name(field.name.as_str()) { + Some(schema_collision) => { + if field.transform == Transform::Identity { + if schema_collision.id == field.source_id { + Ok(()) + } else { + Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot create identity partition sourced from different field in schema. Field name '{}' has id `{}` in schema but partition source id is `{}`", + field.name, schema_collision.id, field.source_id + ), + )) + } + } else { + Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot create partition with name: '{}' that conflicts with schema field and is not an identity transform.", + field.name + ), + )) + } + } + None => Ok(()), + } + } + + /// Ensure that the transformation of the field is compatible with type of the field + /// in the schema. Implicitly also checks if the source field exists in the schema. + fn check_transform_compatibility(field: &UnboundPartitionField, schema: &Schema) -> Result<()> { + let schema_field = schema.field_by_id(field.source_id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot find partition source field with id `{}` in schema", + field.source_id + ), + ) + })?; + + if field.transform != Transform::Void { + if !schema_field.field_type.is_primitive() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot partition by non-primitive source field: '{}'.", + schema_field.field_type + ), + )); + } + + if field + .transform + .result_type(&schema_field.field_type) + .is_err() + { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid source type: '{}' for transform: '{}'.", + schema_field.field_type, + field.transform.dedup_name() + ), + )); + } + } + + Ok(()) + } +} + +/// Contains checks that are common to both PartitionSpecBuilder and UnboundPartitionSpecBuilder +trait CorePartitionSpecValidator { + /// Ensure that the partition name is unique among the partition fields and is not empty. + fn check_name_set_and_unique(&self, name: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::new( + ErrorKind::DataInvalid, + "Cannot use empty partition name", + )); + } + + if self.fields().iter().any(|f| f.name == name) { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Cannot use partition name more than once: {}", name), + )); + } + Ok(()) + } + + /// For a single source-column transformations must be unique. 
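For illustration, the redundancy rule above is exactly what the builder tests further down exercise; a hedged sketch of the failure mode (illustrative only):

    // A second bucket[16] over source id 1 is redundant even under a different name.
    let err = UnboundPartitionSpec::builder()
        .add_partition_field(1, "id_bucket", Transform::Bucket(16))
        .unwrap()
        .add_partition_field(1, "id_bucket_other", Transform::Bucket(16))
        .unwrap_err();
    assert!(err.message().contains("redundant partition"));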
+ fn check_for_redundant_partitions(&self, source_id: i32, transform: &Transform) -> Result<()> { + let collision = self.fields().iter().find(|f| { + f.source_id == source_id && f.transform.dedup_name() == transform.dedup_name() + }); + + if let Some(collision) = collision { + Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot add redundant partition with source id `{}` and transform `{}`. A partition with the same source id and transform already exists with name `{}`", + source_id, transform.dedup_name(), collision.name + ), + )) + } else { + Ok(()) + } + } + + /// Check field / partition_id unique within the partition spec if set + fn check_partition_id_unique(&self, field_id: i32) -> Result<()> { + if self.fields().iter().any(|f| f.field_id == Some(field_id)) { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot use field id more than once in one PartitionSpec: {}", + field_id + ), + )); + } + + Ok(()) + } + + fn fields(&self) -> &Vec; +} + +impl CorePartitionSpecValidator for PartitionSpecBuilder<'_> { + fn fields(&self) -> &Vec { + &self.fields + } +} + +impl CorePartitionSpecValidator for UnboundPartitionSpecBuilder { + fn fields(&self) -> &Vec { + &self.fields + } } #[cfg(test)] mod tests { - use crate::spec::Type; - use super::*; + use crate::spec::Type; #[test] fn test_partition_spec() { @@ -184,9 +682,21 @@ mod tests { #[test] fn test_is_unpartitioned() { - let partition_spec = PartitionSpec::builder() + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + let partition_spec = PartitionSpec::builder(&schema) .with_spec_id(1) - .with_fields(vec![]) .build() .unwrap(); assert!( @@ -194,23 +704,20 @@ mod tests { "Empty partition spec should be unpartitioned" ); - let partition_spec = PartitionSpec::builder() - .with_partition_field( - PartitionField::builder() + let partition_spec = PartitionSpec::builder(&schema) + .add_unbound_fields(vec![ + UnboundPartitionField::builder() .source_id(1) - .field_id(1) .name("id".to_string()) .transform(Transform::Identity) .build(), - ) - .with_partition_field( - PartitionField::builder() + UnboundPartitionField::builder() .source_id(2) - .field_id(2) - .name("name".to_string()) + .name("name_string".to_string()) .transform(Transform::Void) .build(), - ) + ]) + .unwrap() .with_spec_id(1) .build() .unwrap(); @@ -219,24 +726,21 @@ mod tests { "Partition spec with one non void transform should not be unpartitioned" ); - let partition_spec = PartitionSpec::builder() + let partition_spec = PartitionSpec::builder(&schema) .with_spec_id(1) - .with_partition_field( - PartitionField::builder() + .add_unbound_fields(vec![ + UnboundPartitionField::builder() .source_id(1) - .field_id(1) - .name("id".to_string()) + .name("id_void".to_string()) .transform(Transform::Void) .build(), - ) - .with_partition_field( - PartitionField::builder() + UnboundPartitionField::builder() .source_id(2) - .field_id(2) - .name("name".to_string()) + .name("name_void".to_string()) .transform(Transform::Void) .build(), - ) + ]) + .unwrap() .build() .unwrap(); assert!( @@ -252,17 +756,17 @@ mod tests { "spec-id": 1, "fields": [ { "source-id": 4, - "partition-id": 1000, + "field-id": 1000, "name": "ts_day", "transform": "day" }, { "source-id": 1, - "partition-id": 1001, + "field-id": 1001, "name": "id_bucket", "transform": 
"bucket[16]" }, { "source-id": 2, - "partition-id": 1002, + "field-id": 1002, "name": "id_truncate", "transform": "truncate[4]" } ] @@ -273,17 +777,17 @@ mod tests { assert_eq!(Some(1), partition_spec.spec_id); assert_eq!(4, partition_spec.fields[0].source_id); - assert_eq!(Some(1000), partition_spec.fields[0].partition_id); + assert_eq!(Some(1000), partition_spec.fields[0].field_id); assert_eq!("ts_day", partition_spec.fields[0].name); assert_eq!(Transform::Day, partition_spec.fields[0].transform); assert_eq!(1, partition_spec.fields[1].source_id); - assert_eq!(Some(1001), partition_spec.fields[1].partition_id); + assert_eq!(Some(1001), partition_spec.fields[1].field_id); assert_eq!("id_bucket", partition_spec.fields[1].name); assert_eq!(Transform::Bucket(16), partition_spec.fields[1].transform); assert_eq!(2, partition_spec.fields[2].source_id); - assert_eq!(Some(1002), partition_spec.fields[2].partition_id); + assert_eq!(Some(1002), partition_spec.fields[2].field_id); assert_eq!("id_truncate", partition_spec.fields[2].name); assert_eq!(Transform::Truncate(4), partition_spec.fields[2].transform); @@ -300,7 +804,7 @@ mod tests { assert_eq!(None, partition_spec.spec_id); assert_eq!(4, partition_spec.fields[0].source_id); - assert_eq!(None, partition_spec.fields[0].partition_id); + assert_eq!(None, partition_spec.fields[0].field_id); assert_eq!("ts_day", partition_spec.fields[0].name); assert_eq!(Transform::Day, partition_spec.fields[0].transform); } @@ -375,7 +879,7 @@ mod tests { NestedField::optional( partition_spec.fields[0].field_id, &partition_spec.fields[0].name, - Type::Primitive(crate::spec::PrimitiveType::Int) + Type::Primitive(crate::spec::PrimitiveType::Date) ) ); assert_eq!( @@ -489,4 +993,686 @@ mod tests { assert!(partition_spec.partition_type(&schema).is_err()); } + + #[test] + fn test_builder_disallow_duplicate_names() { + UnboundPartitionSpec::builder() + .add_partition_field(1, "ts_day".to_string(), Transform::Day) + .unwrap() + .add_partition_field(2, "ts_day".to_string(), Transform::Day) + .unwrap_err(); + } + + #[test] + fn test_builder_disallow_duplicate_field_ids() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + PartitionSpec::builder(&schema) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(1000), + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1000), + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap_err(); + } + + #[test] + fn test_builder_auto_assign_field_ids() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + NestedField::required( + 3, + "ts", + Type::Primitive(crate::spec::PrimitiveType::Timestamp), + ) + .into(), + ]) + .build() + .unwrap(); + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + name: "id".to_string(), + transform: Transform::Identity, + field_id: Some(1012), + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + name: 
"name_void".to_string(), + transform: Transform::Void, + field_id: None, + }) + .unwrap() + // Should keep its ID even if its lower + .add_unbound_field(UnboundPartitionField { + source_id: 3, + name: "year".to_string(), + transform: Transform::Year, + field_id: Some(1), + }) + .unwrap() + .build() + .unwrap(); + + assert_eq!(1012, spec.fields[0].field_id); + assert_eq!(1013, spec.fields[1].field_id); + assert_eq!(1, spec.fields[2].field_id); + } + + #[test] + fn test_builder_valid_schema() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + PartitionSpec::builder(&schema) + .with_spec_id(1) + .build() + .unwrap(); + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_partition_field("id", "id_bucket[16]", Transform::Bucket(16)) + .unwrap() + .build() + .unwrap(); + + assert_eq!(spec, PartitionSpec { + spec_id: 1, + fields: vec![PartitionField { + source_id: 1, + field_id: 1000, + name: "id_bucket[16]".to_string(), + transform: Transform::Bucket(16), + }] + }); + } + + #[test] + fn test_collision_with_schema_name() { + let schema = Schema::builder() + .with_fields(vec![NestedField::required( + 1, + "id", + Type::Primitive(crate::spec::PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap(); + + PartitionSpec::builder(&schema) + .with_spec_id(1) + .build() + .unwrap(); + + let err = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap_err(); + assert!(err.message().contains("conflicts with schema")) + } + + #[test] + fn test_builder_collision_is_ok_for_identity_transforms() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "number", + Type::Primitive(crate::spec::PrimitiveType::Int), + ) + .into(), + ]) + .build() + .unwrap(); + + PartitionSpec::builder(&schema) + .with_spec_id(1) + .build() + .unwrap(); + + PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + // Not OK for different source id + PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: None, + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap_err(); + } + + #[test] + fn test_builder_all_source_ids_must_exist() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + NestedField::required( + 3, + "ts", + Type::Primitive(crate::spec::PrimitiveType::Timestamp), + ) + .into(), + ]) + .build() + .unwrap(); + + // Valid + PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_fields(vec![ + UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }, + UnboundPartitionField { + source_id: 2, + field_id: None, + name: "name".to_string(), + 
transform: Transform::Identity, + }, + ]) + .unwrap() + .build() + .unwrap(); + + // Invalid + PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_fields(vec![ + UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }, + UnboundPartitionField { + source_id: 4, + field_id: None, + name: "name".to_string(), + transform: Transform::Identity, + }, + ]) + .unwrap_err(); + } + + #[test] + fn test_builder_disallows_redundant() { + let err = UnboundPartitionSpec::builder() + .with_spec_id(1) + .add_partition_field(1, "id_bucket[16]".to_string(), Transform::Bucket(16)) + .unwrap() + .add_partition_field( + 1, + "id_bucket_with_other_name".to_string(), + Transform::Bucket(16), + ) + .unwrap_err(); + assert!(err.message().contains("redundant partition")); + } + + #[test] + fn test_builder_incompatible_transforms_disallowed() { + let schema = Schema::builder() + .with_fields(vec![NestedField::required( + 1, + "id", + Type::Primitive(crate::spec::PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap(); + + PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_year".to_string(), + transform: Transform::Year, + }) + .unwrap_err(); + } + + #[test] + fn test_build_unbound_specs_without_partition_id() { + let spec = UnboundPartitionSpec::builder() + .with_spec_id(1) + .add_partition_fields(vec![UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket[16]".to_string(), + transform: Transform::Bucket(16), + }]) + .unwrap() + .build(); + + assert_eq!(spec, UnboundPartitionSpec { + spec_id: Some(1), + fields: vec![UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket[16]".to_string(), + transform: Transform::Bucket(16), + }] + }); + } + + #[test] + fn test_is_compatible_with() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let partition_spec_1 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + let partition_spec_2 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + assert!(partition_spec_1.is_compatible_with(&partition_spec_2.into_unbound())); + } + + #[test] + fn test_not_compatible_with_transform_different() { + let schema = Schema::builder() + .with_fields(vec![NestedField::required( + 1, + "id", + Type::Primitive(crate::spec::PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap(); + + let partition_spec_1 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + let partition_spec_2 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: 
Transform::Bucket(32), + }) + .unwrap() + .build() + .unwrap(); + + assert!(!partition_spec_1.is_compatible_with(&partition_spec_2.into_unbound())); + } + + #[test] + fn test_not_compatible_with_source_id_different() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let partition_spec_1 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + let partition_spec_2 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + assert!(!partition_spec_1.is_compatible_with(&partition_spec_2.into_unbound())); + } + + #[test] + fn test_not_compatible_with_order_different() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let partition_spec_1 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: None, + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + let partition_spec_2 = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: None, + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Bucket(16), + }) + .unwrap() + .build() + .unwrap(); + + assert!(!partition_spec_1.is_compatible_with(&partition_spec_2.into_unbound())); + } + + #[test] + fn test_highest_field_id_unpartitioned() { + let spec = PartitionSpec::builder(&Schema::builder().with_fields(vec![]).build().unwrap()) + .with_spec_id(1) + .build() + .unwrap(); + + assert_eq!(UNPARTITIONED_LAST_ASSIGNED_ID, spec.highest_field_id()); + } + + #[test] + fn test_highest_field_id() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(1001), + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1000), + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + assert_eq!(1001, spec.highest_field_id()); + } + + #[test] + fn test_has_sequential_ids() { + let schema = 
Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(1000), + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1001), + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + assert_eq!(1000, spec.fields[0].field_id); + assert_eq!(1001, spec.fields[1].field_id); + assert!(spec.has_sequential_ids()); + } + + #[test] + fn test_sequential_ids_must_start_at_1000() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(999), + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1000), + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + assert_eq!(999, spec.fields[0].field_id); + assert_eq!(1000, spec.fields[1].field_id); + assert!(!spec.has_sequential_ids()); + } + + #[test] + fn test_sequential_ids_must_have_no_gaps() { + let schema = Schema::builder() + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(crate::spec::PrimitiveType::Int)) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(); + + let spec = PartitionSpec::builder(&schema) + .with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(1000), + name: "id".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1002), + name: "name".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(); + + assert_eq!(1000, spec.fields[0].field_id); + assert_eq!(1002, spec.fields[1].field_id); + assert!(!spec.has_sequential_ids()); + } } diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs index 724498b45..63a9e3cb4 100644 --- a/crates/iceberg/src/spec/schema.rs +++ b/crates/iceberg/src/spec/schema.rs @@ -17,31 +17,36 @@ //! This module defines schema in iceberg. 
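The schema.rs changes below add case-insensitive name lookup, per-field value accessors, and column pruning. A hedged usage sketch based only on the APIs added in this diff, assuming a schema whose string field `foo` has id 1 and a caller returning `Result`:

    // Case-insensitive lookup resolves "FOO" to the same field as "foo".
    let field = schema.field_by_name_case_insensitive("FOO").unwrap();
    assert_eq!(1, field.id);

    // Project the schema down to the selected field ids only.
    let pruned = prune_columns(&schema, [1], false)?;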
+use std::collections::{HashMap, HashSet}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use _serde::SchemaEnum; +use bimap::BiHashMap; +use itertools::{zip_eq, Itertools}; +use serde::{Deserialize, Serialize}; + +use super::NestedField; use crate::error::Result; +use crate::expr::accessor::StructAccessor; use crate::spec::datatypes::{ ListType, MapType, NestedFieldRef, PrimitiveType, StructType, Type, LIST_FILED_NAME, MAP_KEY_FIELD_NAME, MAP_VALUE_FIELD_NAME, }; use crate::{ensure_data_valid, Error, ErrorKind}; -use bimap::BiHashMap; -use itertools::Itertools; -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -use _serde::SchemaEnum; +/// Type alias for schema id. +pub type SchemaId = i32; /// Reference to [`Schema`]. pub type SchemaRef = Arc; -const DEFAULT_SCHEMA_ID: i32 = 0; +const DEFAULT_SCHEMA_ID: SchemaId = 0; /// Defines schema in iceberg. #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(try_from = "SchemaEnum", into = "SchemaEnum")] pub struct Schema { r#struct: StructType, - schema_id: i32, + schema_id: SchemaId, highest_field_id: i32, identifier_field_ids: HashSet, @@ -49,7 +54,10 @@ pub struct Schema { id_to_field: HashMap, name_to_id: HashMap, + lowercase_name_to_id: HashMap, id_to_name: HashMap, + + field_id_to_accessor: HashMap>, } impl PartialEq for Schema { @@ -98,7 +106,7 @@ impl SchemaBuilder { /// Builds the schema. pub fn build(self) -> Result { - let highest_field_id = self.fields.iter().map(|f| f.id).max().unwrap_or(0); + let field_id_to_accessor = self.build_accessors(); let r#struct = StructType::new(self.fields); let id_to_field = index_by_id(&r#struct)?; @@ -115,26 +123,90 @@ impl SchemaBuilder { index.indexes() }; + let lowercase_name_to_id = name_to_id + .iter() + .map(|(k, v)| (k.to_lowercase(), *v)) + .collect(); + + let highest_field_id = id_to_field.keys().max().cloned().unwrap_or(0); + Ok(Schema { r#struct, schema_id: self.schema_id, highest_field_id, identifier_field_ids: self.identifier_field_ids, - alias_to_id: self.alias_to_id, id_to_field, name_to_id, + lowercase_name_to_id, id_to_name, + + field_id_to_accessor, }) } + fn build_accessors(&self) -> HashMap> { + let mut map = HashMap::new(); + + for (pos, field) in self.fields.iter().enumerate() { + match field.field_type.as_ref() { + Type::Primitive(prim_type) => { + // add an accessor for this field + let accessor = Arc::new(StructAccessor::new(pos, prim_type.clone())); + map.insert(field.id, accessor.clone()); + } + + Type::Struct(nested) => { + // add accessors for nested fields + for (field_id, accessor) in Self::build_accessors_nested(nested.fields()) { + let new_accessor = Arc::new(StructAccessor::wrap(pos, accessor)); + map.insert(field_id, new_accessor.clone()); + } + } + _ => { + // Accessors don't get built for Map or List types + } + } + } + + map + } + + fn build_accessors_nested(fields: &[NestedFieldRef]) -> Vec<(i32, Box)> { + let mut results = vec![]; + for (pos, field) in fields.iter().enumerate() { + match field.field_type.as_ref() { + Type::Primitive(prim_type) => { + let accessor = Box::new(StructAccessor::new(pos, prim_type.clone())); + results.push((field.id, accessor)); + } + Type::Struct(nested) => { + let nested_accessors = Self::build_accessors_nested(nested.fields()); + + let wrapped_nested_accessors = + nested_accessors.into_iter().map(|(id, accessor)| { + let new_accessor = Box::new(StructAccessor::wrap(pos, accessor)); + (id, new_accessor.clone()) + }); + + 
results.extend(wrapped_nested_accessors); + } + _ => { + // Accessors don't get built for Map or List types + } + } + } + + results + } + fn validate_identifier_ids( r#struct: &StructType, id_to_field: &HashMap, identifier_field_ids: impl Iterator, ) -> Result<()> { - let id_to_parent = index_parents(r#struct); + let id_to_parent = index_parents(r#struct)?; for identifier_field_id in identifier_field_ids { let field = id_to_field.get(&identifier_field_id).ok_or_else(|| { Error::new( @@ -196,6 +268,16 @@ impl Schema { } } + /// Create a new schema builder from a schema. + pub fn into_builder(self) -> SchemaBuilder { + SchemaBuilder { + schema_id: self.schema_id, + fields: self.r#struct.fields().to_vec(), + alias_to_id: self.alias_to_id, + identifier_field_ids: self.identifier_field_ids, + } + } + /// Get field by field id. pub fn field_by_id(&self, field_id: i32) -> Option<&NestedFieldRef> { self.id_to_field.get(&field_id) @@ -210,6 +292,15 @@ impl Schema { .and_then(|id| self.field_by_id(*id)) } + /// Get field by field name, but in case-insensitive way. + /// + /// Both full name and short name could work here. + pub fn field_by_name_case_insensitive(&self, field_name: &str) -> Option<&NestedFieldRef> { + self.lowercase_name_to_id + .get(&field_name.to_lowercase()) + .and_then(|id| self.field_by_id(*id)) + } + /// Get field by alias. pub fn field_by_alias(&self, alias: &str) -> Option<&NestedFieldRef> { self.alias_to_id @@ -225,7 +316,7 @@ impl Schema { /// Returns [`schema_id`]. #[inline] - pub fn schema_id(&self) -> i32 { + pub fn schema_id(&self) -> SchemaId { self.schema_id } @@ -235,6 +326,12 @@ impl Schema { &self.r#struct } + /// Returns [`identifier_field_ids`]. + #[inline] + pub fn identifier_field_ids(&self) -> impl Iterator + '_ { + self.identifier_field_ids.iter().copied() + } + /// Get field id by full name. pub fn field_id_by_name(&self, name: &str) -> Option { self.name_to_id.get(name).copied() @@ -244,6 +341,11 @@ impl Schema { pub fn name_by_field_id(&self, field_id: i32) -> Option<&str> { self.id_to_name.get(&field_id).map(String::as_str) } + + /// Get an accessor for retrieving data in a struct + pub fn accessor_by_field_id(&self, field_id: i32) -> Option> { + self.field_id_to_accessor.get(&field_id).cloned() + } } impl Display for Schema { @@ -361,7 +463,7 @@ pub fn visit_schema(schema: &Schema, visitor: &mut V) -> Resul visitor.schema(schema, result) } -/// Creates an field id to field map. +/// Creates a field id to field map. pub fn index_by_id(r#struct: &StructType) -> Result> { struct IndexById(HashMap); @@ -404,7 +506,7 @@ pub fn index_by_id(r#struct: &StructType) -> Result } /// Creates a field id to parent field id map. 
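As a side note, the accessors built above can be fetched by field id and used to read a value out of a row; this mirrors `test_build_accessors` further down (hedged sketch, where `row` is an assumed `crate::spec::Struct` whose first column holds the string "foo value"):

    // Accessors are precomputed when the schema is built; look one up and read a row.
    let accessor = schema.accessor_by_field_id(1).expect("no accessor for field 1");
    let value = accessor.get(&row)?; // yields Option<Datum>
    assert_eq!(Some(Datum::string("foo value")), value);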
-pub fn index_parents(r#struct: &StructType) -> HashMap { +pub fn index_parents(r#struct: &StructType) -> Result> { struct IndexByParent { parents: Vec, result: HashMap, @@ -485,8 +587,8 @@ pub fn index_parents(r#struct: &StructType) -> HashMap { parents: vec![], result: HashMap::new(), }; - visit_struct(r#struct, &mut index).unwrap(); - index.result + visit_struct(r#struct, &mut index)?; + Ok(index.result) } #[derive(Default)] @@ -624,16 +726,238 @@ impl SchemaVisitor for IndexByName { } } +struct PruneColumn { + selected: HashSet, + select_full_types: bool, +} + +/// Visit a schema and returns only the fields selected by id set +pub fn prune_columns( + schema: &Schema, + selected: impl IntoIterator, + select_full_types: bool, +) -> Result { + let mut visitor = PruneColumn::new(HashSet::from_iter(selected), select_full_types); + let result = visit_schema(schema, &mut visitor); + + match result { + Ok(s) => { + if let Some(struct_type) = s { + Ok(struct_type) + } else { + Ok(Type::Struct(StructType::default())) + } + } + Err(e) => Err(e), + } +} + +impl PruneColumn { + fn new(selected: HashSet, select_full_types: bool) -> Self { + Self { + selected, + select_full_types, + } + } + + fn project_selected_struct(projected_field: Option) -> Result { + match projected_field { + // If the field is a StructType, return it as such + Some(Type::Struct(s)) => Ok(s), + Some(_) => Err(Error::new( + ErrorKind::Unexpected, + "Projected field with struct type must be struct".to_string(), + )), + // If projected_field is None or not a StructType, return an empty StructType + None => Ok(StructType::default()), + } + } + fn project_list(list: &ListType, element_result: Type) -> Result { + if *list.element_field.field_type == element_result { + return Ok(list.clone()); + } + Ok(ListType { + element_field: Arc::new(NestedField { + id: list.element_field.id, + name: list.element_field.name.clone(), + required: list.element_field.required, + field_type: Box::new(element_result), + doc: list.element_field.doc.clone(), + initial_default: list.element_field.initial_default.clone(), + write_default: list.element_field.write_default.clone(), + }), + }) + } + fn project_map(map: &MapType, value_result: Type) -> Result { + if *map.value_field.field_type == value_result { + return Ok(map.clone()); + } + Ok(MapType { + key_field: map.key_field.clone(), + value_field: Arc::new(NestedField { + id: map.value_field.id, + name: map.value_field.name.clone(), + required: map.value_field.required, + field_type: Box::new(value_result), + doc: map.value_field.doc.clone(), + initial_default: map.value_field.initial_default.clone(), + write_default: map.value_field.write_default.clone(), + }), + }) + } +} + +impl SchemaVisitor for PruneColumn { + type T = Option; + + fn schema(&mut self, _schema: &Schema, value: Option) -> Result> { + Ok(Some(value.unwrap())) + } + + fn field(&mut self, field: &NestedFieldRef, value: Option) -> Result> { + if self.selected.contains(&field.id) { + if self.select_full_types { + Ok(Some(*field.field_type.clone())) + } else if field.field_type.is_struct() { + return Ok(Some(Type::Struct(PruneColumn::project_selected_struct( + value, + )?))); + } else if !field.field_type.is_nested() { + return Ok(Some(*field.field_type.clone())); + } else { + return Err(Error::new( + ErrorKind::DataInvalid, + "Can't project list or map field directly when not selecting full type." 
+ .to_string(), + ) + .with_context("field_id", field.id.to_string()) + .with_context("field_type", field.field_type.to_string())); + } + } else { + Ok(value) + } + } + + fn r#struct( + &mut self, + r#struct: &StructType, + results: Vec>, + ) -> Result> { + let fields = r#struct.fields(); + let mut selected_field = Vec::with_capacity(fields.len()); + let mut same_type = true; + + for (field, projected_type) in zip_eq(fields.iter(), results.iter()) { + if let Some(projected_type) = projected_type { + if *field.field_type == *projected_type { + selected_field.push(field.clone()); + } else { + same_type = false; + let new_field = NestedField { + id: field.id, + name: field.name.clone(), + required: field.required, + field_type: Box::new(projected_type.clone()), + doc: field.doc.clone(), + initial_default: field.initial_default.clone(), + write_default: field.write_default.clone(), + }; + selected_field.push(Arc::new(new_field)); + } + } + } + + if !selected_field.is_empty() { + if selected_field.len() == fields.len() && same_type { + return Ok(Some(Type::Struct(r#struct.clone()))); + } else { + return Ok(Some(Type::Struct(StructType::new(selected_field)))); + } + } + Ok(None) + } + + fn list(&mut self, list: &ListType, value: Option) -> Result> { + if self.selected.contains(&list.element_field.id) { + if self.select_full_types { + Ok(Some(Type::List(list.clone()))) + } else if list.element_field.field_type.is_struct() { + let projected_struct = PruneColumn::project_selected_struct(value).unwrap(); + return Ok(Some(Type::List(PruneColumn::project_list( + list, + Type::Struct(projected_struct), + )?))); + } else if list.element_field.field_type.is_primitive() { + return Ok(Some(Type::List(list.clone()))); + } else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Cannot explicitly project List or Map types, List element {} of type {} was selected", list.element_field.id, list.element_field.field_type), + )); + } + } else if let Some(result) = value { + Ok(Some(Type::List(PruneColumn::project_list(list, result)?))) + } else { + Ok(None) + } + } + + fn map( + &mut self, + map: &MapType, + _key_value: Option, + value: Option, + ) -> Result> { + if self.selected.contains(&map.value_field.id) { + if self.select_full_types { + Ok(Some(Type::Map(map.clone()))) + } else if map.value_field.field_type.is_struct() { + let projected_struct = + PruneColumn::project_selected_struct(Some(value.unwrap())).unwrap(); + return Ok(Some(Type::Map(PruneColumn::project_map( + map, + Type::Struct(projected_struct), + )?))); + } else if map.value_field.field_type.is_primitive() { + return Ok(Some(Type::Map(map.clone()))); + } else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Cannot explicitly project List or Map types, Map value {} of type {} was selected", map.value_field.id, map.value_field.field_type), + )); + } + } else if let Some(value_result) = value { + return Ok(Some(Type::Map(PruneColumn::project_map( + map, + value_result, + )?))); + } else if self.selected.contains(&map.key_field.id) { + Ok(Some(Type::Map(map.clone()))) + } else { + Ok(None) + } + } + + fn primitive(&mut self, _p: &PrimitiveType) -> Result> { + Ok(None) + } +} + pub(super) mod _serde { /// This is a helper module that defines types to help with serialization/deserialization. /// For deserialization the input first gets read into either the [SchemaV1] or [SchemaV2] struct /// and then converted into the [Schema] struct. Serialization works the other way around. 
/// [SchemaV1] and [SchemaV2] are internal struct that are only used for serialization and deserialization. - use serde::{Deserialize, Serialize}; - - use crate::{spec::StructType, Error, Result}; + use serde::Deserialize; + /// This is a helper module that defines types to help with serialization/deserialization. + /// For deserialization the input first gets read into either the [SchemaV1] or [SchemaV2] struct + /// and then converted into the [Schema] struct. Serialization works the other way around. + /// [SchemaV1] and [SchemaV2] are internal struct that are only used for serialization and deserialization. + use serde::Serialize; use super::{Schema, DEFAULT_SCHEMA_ID}; + use crate::spec::StructType; + use crate::{Error, Result}; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] @@ -736,15 +1060,17 @@ pub(super) mod _serde { #[cfg(test)] mod tests { + use std::collections::{HashMap, HashSet}; + + use super::DEFAULT_SCHEMA_ID; use crate::spec::datatypes::Type::{List, Map, Primitive, Struct}; use crate::spec::datatypes::{ ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, StructType, Type, }; use crate::spec::schema::Schema; use crate::spec::schema::_serde::{SchemaEnum, SchemaV1, SchemaV2}; - use std::collections::HashMap; - - use super::DEFAULT_SCHEMA_ID; + use crate::spec::values::Map as MapValue; + use crate::spec::{prune_columns, Datum, Literal}; fn check_schema_serde(json: &str, expected_type: Schema, _expected_enum: SchemaEnum) { let desered_type: Schema = serde_json::from_str(json).unwrap(); @@ -874,7 +1200,7 @@ mod tests { (schema, record) } - fn table_schema_nested() -> Schema { + pub fn table_schema_nested() -> Schema { Schema::builder() .with_schema_id(1) .with_identifier_field_ids(vec![2]) @@ -969,13 +1295,13 @@ mod tests { #[test] fn test_schema_display() { - let expected_str = r#" + let expected_str = " table { - 1: foo: optional string - 2: bar: required int - 3: baz: optional boolean + 1: foo: optional string\x20 + 2: bar: required int\x20 + 3: baz: optional boolean\x20 } -"#; +"; assert_eq!(expected_str, format!("\n{}", table_schema_simple().0)); } @@ -999,6 +1325,15 @@ table { .contains("Invalid schema: multiple fields for name baz")); } + #[test] + fn test_schema_into_builder() { + let original_schema = table_schema_nested(); + let builder = original_schema.clone().into_builder(); + let schema = builder.build().unwrap(); + + assert_eq!(original_schema, schema); + } + #[test] fn test_schema_index_by_name() { let expected_name_to_id = HashMap::from( @@ -1030,6 +1365,42 @@ table { assert_eq!(&expected_name_to_id, &schema.name_to_id); } + #[test] + fn test_schema_index_by_name_case_insensitive() { + let expected_name_to_id = HashMap::from( + [ + ("fOo", 1), + ("Bar", 2), + ("BAz", 3), + ("quX", 4), + ("quX.ELEment", 5), + ("qUUx", 6), + ("QUUX.KEY", 7), + ("QUUX.Value", 8), + ("qUUX.VALUE.Key", 9), + ("qUux.VaLue.Value", 10), + ("lOCAtION", 11), + ("LOCAtioN.ELeMENt", 12), + ("LoCATion.element.LATitude", 13), + ("locatION.ElemeNT.LONgitude", 14), + ("LOCAtiON.LATITUDE", 13), + ("LOCATION.LONGITUDE", 14), + ("PERSon", 15), + ("PERSON.Name", 16), + ("peRSON.AGe", 17), + ] + .map(|e| (e.0.to_string(), e.1)), + ); + + let schema = table_schema_nested(); + for (name, id) in expected_name_to_id { + assert_eq!( + Some(id), + schema.field_by_name_case_insensitive(&name).map(|f| f.id) + ); + } + } + #[test] fn test_schema_find_column_name() { let expected_column_name = HashMap::from([ @@ -1284,4 +1655,586 @@ table { ); } } + + #[test] + fn 
test_build_accessors() { + let schema = table_schema_nested(); + + let test_struct = crate::spec::Struct::from_iter(vec![ + Some(Literal::string("foo value")), + Some(Literal::int(1002)), + Some(Literal::bool(true)), + Some(Literal::List(vec![ + Some(Literal::string("qux item 1")), + Some(Literal::string("qux item 2")), + ])), + Some(Literal::Map(MapValue::from([( + Literal::string("quux key 1"), + Some(Literal::Map(MapValue::from([( + Literal::string("quux nested key 1"), + Some(Literal::int(1000)), + )]))), + )]))), + Some(Literal::List(vec![Some(Literal::Struct( + crate::spec::Struct::from_iter(vec![ + Some(Literal::float(52.509_09)), + Some(Literal::float(-1.885_249)), + ]), + ))])), + Some(Literal::Struct(crate::spec::Struct::from_iter(vec![ + Some(Literal::string("Testy McTest")), + Some(Literal::int(33)), + ]))), + ]); + + assert_eq!( + schema + .accessor_by_field_id(1) + .unwrap() + .get(&test_struct) + .unwrap(), + Some(Datum::string("foo value")) + ); + assert_eq!( + schema + .accessor_by_field_id(2) + .unwrap() + .get(&test_struct) + .unwrap(), + Some(Datum::int(1002)) + ); + assert_eq!( + schema + .accessor_by_field_id(3) + .unwrap() + .get(&test_struct) + .unwrap(), + Some(Datum::bool(true)) + ); + assert_eq!( + schema + .accessor_by_field_id(16) + .unwrap() + .get(&test_struct) + .unwrap(), + Some(Datum::string("Testy McTest")) + ); + assert_eq!( + schema + .accessor_by_field_id(17) + .unwrap() + .get(&test_struct) + .unwrap(), + Some(Datum::int(33)) + ); + } + + #[test] + fn test_schema_prune_columns_string() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 1, + "foo", + Type::Primitive(PrimitiveType::String), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([1]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_schema_prune_columns_string_full() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 1, + "foo", + Type::Primitive(PrimitiveType::String), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([1]); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_schema_prune_columns_list() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 4, + "qux", + Type::List(ListType { + element_field: NestedField::list_element( + 5, + Type::Primitive(PrimitiveType::String), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([5]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_list_itself() { + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([4]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_err()); + } + + #[test] + fn test_schema_prune_columns_list_full() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 4, + "qux", + Type::List(ListType { + element_field: NestedField::list_element( + 5, + 
Type::Primitive(PrimitiveType::String), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([5]); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_map() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 6, + "quux", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 7, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Map(MapType { + key_field: NestedField::map_key_element( + 9, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 10, + Type::Primitive(PrimitiveType::Int), + true, + ) + .into(), + }), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([9]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_map_itself() { + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([6]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_err()); + } + + #[test] + fn test_prune_columns_map_full() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 6, + "quux", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 7, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Map(MapType { + key_field: NestedField::map_key_element( + 9, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 10, + Type::Primitive(PrimitiveType::Int), + true, + ) + .into(), + }), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([9]); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_map_key() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 6, + "quux", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 7, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Map(MapType { + key_field: NestedField::map_key_element( + 9, + Type::Primitive(PrimitiveType::String), + ) + .into(), + value_field: NestedField::map_value_element( + 10, + Type::Primitive(PrimitiveType::Int), + true, + ) + .into(), + }), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([10]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_struct() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![NestedField::optional( + 16, + "name", + 
Type::Primitive(PrimitiveType::String), + ) + .into()])), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([16]); + let result = prune_columns(&schema, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_struct_full() { + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![NestedField::optional( + 16, + "name", + Type::Primitive(PrimitiveType::String), + ) + .into()])), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let schema = table_schema_nested(); + let selected: HashSet = HashSet::from([16]); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_empty_struct() { + let schema_with_empty_struct_field = Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![])), + ) + .into()]) + .build() + .unwrap(); + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![])), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let selected: HashSet = HashSet::from([15]); + let result = prune_columns(&schema_with_empty_struct_field, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_empty_struct_full() { + let schema_with_empty_struct_field = Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![])), + ) + .into()]) + .build() + .unwrap(); + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::optional( + 15, + "person", + Type::Struct(StructType::new(vec![])), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let selected: HashSet = HashSet::from([15]); + let result = prune_columns(&schema_with_empty_struct_field, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_struct_in_map() { + let schema_with_struct_in_map_field = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 6, + "id_to_person", + Type::Map(MapType { + key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int)) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Struct(StructType::new(vec![ + NestedField::optional(10, "name", Primitive(PrimitiveType::String)) + .into(), + NestedField::required(11, "age", Primitive(PrimitiveType::Int)).into(), + ])), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap(); + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 6, + "id_to_person", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 7, + Type::Primitive(PrimitiveType::Int), + ) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Struct(StructType::new(vec![NestedField::required( + 11, + "age", + Primitive(PrimitiveType::Int), + ) + .into()])), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let selected: HashSet = HashSet::from([11]); + let result = 
prune_columns(&schema_with_struct_in_map_field, selected, false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + #[test] + fn test_prune_columns_struct_in_map_full() { + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 6, + "id_to_person", + Type::Map(MapType { + key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int)) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Struct(StructType::new(vec![ + NestedField::optional(10, "name", Primitive(PrimitiveType::String)) + .into(), + NestedField::required(11, "age", Primitive(PrimitiveType::Int)).into(), + ])), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap(); + let expected_type = Type::from( + Schema::builder() + .with_fields(vec![NestedField::required( + 6, + "id_to_person", + Type::Map(MapType { + key_field: NestedField::map_key_element( + 7, + Type::Primitive(PrimitiveType::Int), + ) + .into(), + value_field: NestedField::map_value_element( + 8, + Type::Struct(StructType::new(vec![NestedField::required( + 11, + "age", + Primitive(PrimitiveType::Int), + ) + .into()])), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap() + .as_struct() + .clone(), + ); + let selected: HashSet = HashSet::from([11]); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_type); + } + + #[test] + fn test_prune_columns_select_original_schema() { + let schema = table_schema_nested(); + let selected: HashSet = (0..schema.highest_field_id() + 1).collect(); + let result = prune_columns(&schema, selected, true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Type::Struct(schema.as_struct().clone())); + } + + #[test] + fn test_highest_field_id() { + let schema = table_schema_nested(); + assert_eq!(17, schema.highest_field_id()); + + let schema = table_schema_simple().0; + assert_eq!(3, schema.highest_field_id()); + } } diff --git a/crates/iceberg/src/spec/snapshot.rs b/crates/iceberg/src/spec/snapshot.rs index c10e892bf..f42e736ea 100644 --- a/crates/iceberg/src/spec/snapshot.rs +++ b/crates/iceberg/src/spec/snapshot.rs @@ -17,15 +17,23 @@ /*! * Snapshots -*/ -use chrono::{DateTime, TimeZone, Utc}; -use serde::{Deserialize, Serialize}; + */ use std::collections::HashMap; use std::sync::Arc; + +use _serde::SnapshotV2; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use super::table_metadata::SnapshotLog; -use _serde::SnapshotV2; +use crate::error::{timestamp_ms_to_utc, Result}; +use crate::io::FileIO; +use crate::spec::{ManifestList, SchemaId, SchemaRef, StructType, TableMetadata}; +use crate::{Error, ErrorKind}; + +/// The ref name of the main branch of the table. +pub const MAIN_BRANCH: &str = "main"; /// Reference to [`Snapshot`]. pub type SnapshotRef = Arc; @@ -79,22 +87,14 @@ pub struct Snapshot { timestamp_ms: i64, /// The location of a manifest list for this snapshot that /// tracks manifest files with additional metadata. - manifest_list: ManifestListLocation, + /// Currently we only support manifest list file, and manifest files are not supported. + #[builder(setter(into))] + manifest_list: String, /// A string map that summarizes the snapshot changes, including operation. summary: Summary, /// ID of the table’s current schema when the snapshot was created. 
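+ // Aside: `Snapshot` is constructed through its `TypedBuilder`. A hedged sketch
+ // mirroring the builder calls used by the tests later in this change (values
+ // are illustrative):
+ //
+ //     let snapshot = Snapshot::builder()
+ //         .with_snapshot_id(3051729675574597004)
+ //         .with_sequence_number(0)
+ //         .with_timestamp_ms(1515100955770)
+ //         .with_manifest_list("s3://a/b/1.avro")
+ //         .with_summary(Summary { operation: Operation::Append, other: HashMap::new() })
+ //         .build();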
#[builder(setter(strip_option), default = None)] - schema_id: Option, -} - -/// Type to distinguish between a path to a manifestlist file or a vector of manifestfile locations -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -#[serde(untagged)] -pub enum ManifestListLocation { - /// Location of manifestlist file - ManifestListFile(String), - /// Manifestfile locations - ManifestFiles(Vec), + schema_id: Option, } impl Snapshot { @@ -103,6 +103,13 @@ impl Snapshot { pub fn snapshot_id(&self) -> i64 { self.snapshot_id } + + /// Get parent snapshot id. + #[inline] + pub fn parent_snapshot_id(&self) -> Option { + self.parent_snapshot_id + } + /// Get sequence_number of the snapshot. Is 0 for Iceberg V1 tables. #[inline] pub fn sequence_number(&self) -> i64 { @@ -110,9 +117,10 @@ impl Snapshot { } /// Get location of manifest_list file #[inline] - pub fn manifest_list(&self) -> &ManifestListLocation { + pub fn manifest_list(&self) -> &str { &self.manifest_list } + /// Get summary of the snapshot #[inline] pub fn summary(&self) -> &Summary { @@ -120,8 +128,69 @@ impl Snapshot { } /// Get the timestamp of when the snapshot was created #[inline] - pub fn timestamp(&self) -> DateTime { - Utc.timestamp_millis_opt(self.timestamp_ms).unwrap() + pub fn timestamp(&self) -> Result> { + timestamp_ms_to_utc(self.timestamp_ms) + } + + /// Get the timestamp of when the snapshot was created in milliseconds + #[inline] + pub fn timestamp_ms(&self) -> i64 { + self.timestamp_ms + } + + /// Get the schema id of this snapshot. + #[inline] + pub fn schema_id(&self) -> Option { + self.schema_id + } + + /// Get the schema of this snapshot. + pub fn schema(&self, table_metadata: &TableMetadata) -> Result { + Ok(match self.schema_id() { + Some(schema_id) => table_metadata + .schema_by_id(schema_id) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Schema with id {} not found", schema_id), + ) + })? + .clone(), + None => table_metadata.current_schema().clone(), + }) + } + + /// Get parent snapshot. + #[cfg(test)] + pub(crate) fn parent_snapshot(&self, table_metadata: &TableMetadata) -> Option { + match self.parent_snapshot_id { + Some(id) => table_metadata.snapshot_by_id(id).cloned(), + None => None, + } + } + + /// Load manifest list. 
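+ ///
+ /// A hedged usage sketch (the `FileIO` and `TableMetadata` values are assumed
+ /// to come from the surrounding table context; names are illustrative):
+ ///
+ /// ```no_run
+ /// # use iceberg::io::FileIO;
+ /// # use iceberg::spec::{Snapshot, TableMetadata};
+ /// # async fn example(snapshot: &Snapshot, file_io: &FileIO, metadata: &TableMetadata) -> iceberg::Result<()> {
+ /// let manifest_list = snapshot.load_manifest_list(file_io, metadata).await?;
+ /// for manifest in manifest_list.entries() {
+ ///     println!("manifest file: {}", manifest.manifest_path);
+ /// }
+ /// # Ok(())
+ /// # }
+ /// ```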
+ pub async fn load_manifest_list( + &self, + file_io: &FileIO, + table_metadata: &TableMetadata, + ) -> Result { + let manifest_list_content = file_io.new_input(&self.manifest_list)?.read().await?; + + let schema = self.schema(table_metadata)?; + + let partition_type_provider = |partition_spec_id: i32| -> Result> { + table_metadata + .partition_spec_by_id(partition_spec_id) + .map(|partition_spec| partition_spec.partition_type(&schema)) + .transpose() + }; + + ManifestList::parse_with_version( + &manifest_list_content, + table_metadata.format_version(), + partition_type_provider, + ) } pub(crate) fn log(&self) -> SnapshotLog { @@ -141,9 +210,9 @@ pub(super) mod _serde { use serde::{Deserialize, Serialize}; - use crate::{Error, ErrorKind}; - - use super::{ManifestListLocation, Operation, Snapshot, Summary}; + use super::{Operation, Snapshot, Summary}; + use crate::spec::SchemaId; + use crate::Error; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] @@ -157,7 +226,7 @@ pub(super) mod _serde { pub manifest_list: String, pub summary: Summary, #[serde(skip_serializing_if = "Option::is_none")] - pub schema_id: Option, + pub schema_id: Option, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -175,7 +244,7 @@ pub(super) mod _serde { #[serde(skip_serializing_if = "Option::is_none")] pub summary: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub schema_id: Option, + pub schema_id: Option, } impl From for Snapshot { @@ -185,7 +254,7 @@ pub(super) mod _serde { parent_snapshot_id: v2.parent_snapshot_id, sequence_number: v2.sequence_number, timestamp_ms: v2.timestamp_ms, - manifest_list: ManifestListLocation::ManifestListFile(v2.manifest_list), + manifest_list: v2.manifest_list, summary: v2.summary, schema_id: v2.schema_id, } @@ -195,17 +264,14 @@ pub(super) mod _serde { impl From for SnapshotV2 { fn from(v2: Snapshot) -> Self { SnapshotV2 { - snapshot_id: v2.snapshot_id, - parent_snapshot_id: v2.parent_snapshot_id, - sequence_number: v2.sequence_number, - timestamp_ms: v2.timestamp_ms, - manifest_list: match v2.manifest_list { - ManifestListLocation::ManifestListFile(file) => file, - ManifestListLocation::ManifestFiles(_) => panic!("Wrong table format version. 
Can't convert a list of manifest files into a location of a manifest file.") - }, - summary: v2.summary, - schema_id: v2.schema_id, - } + snapshot_id: v2.snapshot_id, + parent_snapshot_id: v2.parent_snapshot_id, + sequence_number: v2.sequence_number, + timestamp_ms: v2.timestamp_ms, + manifest_list: v2.manifest_list, + summary: v2.summary, + schema_id: v2.schema_id, + } } } @@ -219,15 +285,10 @@ pub(super) mod _serde { sequence_number: 0, timestamp_ms: v1.timestamp_ms, manifest_list: match (v1.manifest_list, v1.manifests) { - (Some(file), _) => ManifestListLocation::ManifestListFile(file), - (None, Some(files)) => ManifestListLocation::ManifestFiles(files), - (None, None) => { - return Err(Error::new( - ErrorKind::DataInvalid, - "Neither manifestlist file or manifest files are provided.", - )) - } - }, + (Some(file), None) => file, + (Some(_), Some(_)) => "Invalid v1 snapshot, when manifest list provided, manifest files should be omitted".to_string(), + (None, _) => "Unsupported v1 snapshot, only manifest list is supported".to_string() + }, summary: v1.summary.unwrap_or(Summary { operation: Operation::default(), other: HashMap::new(), @@ -239,18 +300,14 @@ pub(super) mod _serde { impl From for SnapshotV1 { fn from(v2: Snapshot) -> Self { - let (manifest_list, manifests) = match v2.manifest_list { - ManifestListLocation::ManifestListFile(file) => (Some(file), None), - ManifestListLocation::ManifestFiles(files) => (None, Some(files)), - }; SnapshotV1 { snapshot_id: v2.snapshot_id, parent_snapshot_id: v2.parent_snapshot_id, timestamp_ms: v2.timestamp_ms, - manifest_list, - manifests, + manifest_list: Some(v2.manifest_list), summary: Some(v2.summary), schema_id: v2.schema_id, + manifests: None, } } } @@ -303,18 +360,19 @@ pub enum SnapshotRetention { Tag { /// For snapshot references except the main branch, a positive number for the max age of the snapshot reference to keep while expiring snapshots. /// Defaults to table property history.expire.max-ref-age-ms. The main branch never expires. - max_ref_age_ms: i64, + #[serde(skip_serializing_if = "Option::is_none")] + max_ref_age_ms: Option, }, } #[cfg(test)] mod tests { - use chrono::{TimeZone, Utc}; use std::collections::HashMap; - use crate::spec::snapshot::{ - ManifestListLocation, Operation, Snapshot, Summary, _serde::SnapshotV1, - }; + use chrono::{TimeZone, Utc}; + + use crate::spec::snapshot::_serde::SnapshotV1; + use crate::spec::snapshot::{Operation, Snapshot, Summary}; #[test] fn schema() { @@ -337,8 +395,9 @@ mod tests { assert_eq!(3051729675574597004, result.snapshot_id()); assert_eq!( Utc.timestamp_millis_opt(1515100955770).unwrap(), - result.timestamp() + result.timestamp().unwrap() ); + assert_eq!(1515100955770, result.timestamp_ms()); assert_eq!( Summary { operation: Operation::Append, @@ -346,9 +405,6 @@ mod tests { }, *result.summary() ); - assert_eq!( - ManifestListLocation::ManifestListFile("s3://b/wh/.../s1.avro".to_string()), - *result.manifest_list() - ); + assert_eq!("s3://b/wh/.../s1.avro".to_string(), *result.manifest_list()); } } diff --git a/crates/iceberg/src/spec/sort.rs b/crates/iceberg/src/spec/sort.rs index 01a1eddea..5e50a175c 100644 --- a/crates/iceberg/src/spec/sort.rs +++ b/crates/iceberg/src/spec/sort.rs @@ -17,16 +17,22 @@ /*! 
* Sorting -*/ -use serde::{Deserialize, Serialize}; + */ +use core::fmt; +use std::fmt::Formatter; use std::sync::Arc; + +use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use super::transform::Transform; +use crate::error::Result; +use crate::spec::Schema; +use crate::{Error, ErrorKind}; /// Reference to [`SortOrder`]. pub type SortOrderRef = Arc; -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Copy, Clone)] /// Sort direction in a partition, either ascending or descending pub enum SortDirection { /// Ascending @@ -37,7 +43,16 @@ pub enum SortDirection { Descending, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +impl fmt::Display for SortDirection { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match *self { + SortDirection::Ascending => write!(f, "ascending"), + SortDirection::Descending => write!(f, "descending"), + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Copy, Clone)] /// Describes the order of null values when sorted. pub enum NullOrder { #[serde(rename = "nulls-first")] @@ -48,6 +63,15 @@ pub enum NullOrder { Last, } +impl fmt::Display for NullOrder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match *self { + NullOrder::First => write!(f, "first"), + NullOrder::Last => write!(f, "last"), + } + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, TypedBuilder)] #[serde(rename_all = "kebab-case")] /// Entry for every column that is to be sorted @@ -62,9 +86,20 @@ pub struct SortField { pub null_order: NullOrder, } +impl fmt::Display for SortField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SortField {{ source_id: {}, transform: {}, direction: {}, null_order: {} }}", + self.source_id, self.transform, self.direction, self.null_order + ) + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Builder, Default)] #[serde(rename_all = "kebab-case")] #[builder(setter(prefix = "with"))] +#[builder(build_fn(skip))] /// A sort order is defined by a sort order id and a list of sort fields. /// The order of the sort fields within the list defines the order in which the sort is applied to the data. pub struct SortOrder { @@ -77,26 +112,113 @@ pub struct SortOrder { } impl SortOrder { + const UNSORTED_ORDER_ID: i64 = 0; + /// Create sort order builder pub fn builder() -> SortOrderBuilder { SortOrderBuilder::default() } + /// Create an unbound unsorted order + pub fn unsorted_order() -> SortOrder { + SortOrder { + order_id: SortOrder::UNSORTED_ORDER_ID, + fields: Vec::new(), + } + } + /// Returns true if the sort order is unsorted. /// /// A [`SortOrder`] is unsorted if it has no sort fields. pub fn is_unsorted(&self) -> bool { self.fields.is_empty() } + + /// Set the order id for the sort order + pub fn with_order_id(self, order_id: i64) -> SortOrder { + SortOrder { + order_id, + fields: self.fields, + } + } +} + +impl SortOrderBuilder { + /// Creates a new unbound sort order. 
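+ ///
+ /// A hedged sketch of the intended behaviour (mirrors the unit tests below):
+ /// an empty builder yields the unsorted order, and a builder with fields but
+ /// no explicit order id defaults to order id 1.
+ ///
+ /// ```
+ /// use iceberg::spec::{NullOrder, SortDirection, SortField, SortOrder, Transform};
+ ///
+ /// let unsorted = SortOrder::builder().build_unbound().unwrap();
+ /// assert!(unsorted.is_unsorted());
+ ///
+ /// let order = SortOrder::builder()
+ ///     .with_sort_field(
+ ///         SortField::builder()
+ ///             .source_id(1)
+ ///             .transform(Transform::Identity)
+ ///             .direction(SortDirection::Ascending)
+ ///             .null_order(NullOrder::First)
+ ///             .build(),
+ ///     )
+ ///     .build_unbound()
+ ///     .unwrap();
+ /// assert_eq!(order.order_id, 1);
+ /// ```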
+ pub fn build_unbound(&self) -> Result { + let fields = self.fields.clone().unwrap_or_default(); + return match (self.order_id, fields.as_slice()) { + (Some(SortOrder::UNSORTED_ORDER_ID) | None, []) => Ok(SortOrder::unsorted_order()), + (_, []) => Err(Error::new( + ErrorKind::Unexpected, + format!("Unsorted order ID must be {}", SortOrder::UNSORTED_ORDER_ID), + )), + (Some(SortOrder::UNSORTED_ORDER_ID), [..]) => Err(Error::new( + ErrorKind::Unexpected, + format!( + "Sort order ID {} is reserved for unsorted order", + SortOrder::UNSORTED_ORDER_ID + ), + )), + (maybe_order_id, [..]) => Ok(SortOrder { + order_id: maybe_order_id.unwrap_or(1), + fields: fields.to_vec(), + }), + }; + } + + /// Creates a new bound sort order. + pub fn build(&self, schema: &Schema) -> Result { + let unbound_sort_order = self.build_unbound()?; + SortOrderBuilder::check_compatibility(unbound_sort_order, schema) + } + + /// Returns the given sort order if it is compatible with the given schema + fn check_compatibility(sort_order: SortOrder, schema: &Schema) -> Result { + let sort_fields = &sort_order.fields; + for sort_field in sort_fields { + match schema.field_by_id(sort_field.source_id) { + None => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Cannot find source column for sort field: {sort_field}"), + )) + } + Some(source_field) => { + let source_type = source_field.field_type.as_ref(); + + if !source_type.is_primitive() { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Cannot sort by non-primitive source field: {source_type}"), + )); + } + + let field_transform = sort_field.transform; + if field_transform.result_type(source_type).is_err() { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Invalid source type {source_type} for transform {field_transform}" + ), + )); + } + } + } + } + + Ok(sort_order) + } } #[cfg(test)] mod tests { use super::*; + use crate::spec::{ListType, NestedField, PrimitiveType, Type}; #[test] - fn sort_field() { - let sort_field = r#" + fn test_sort_field() { + let spec = r#" { "transform": "bucket[4]", "source-id": 3, @@ -105,7 +227,7 @@ mod tests { } "#; - let field: SortField = serde_json::from_str(sort_field).unwrap(); + let field: SortField = serde_json::from_str(spec).unwrap(); assert_eq!(Transform::Bucket(4), field.transform); assert_eq!(3, field.source_id); assert_eq!(SortDirection::Descending, field.direction); @@ -113,8 +235,8 @@ mod tests { } #[test] - fn sort_order() { - let sort_order = r#" + fn test_sort_order() { + let spec = r#" { "order-id": 1, "fields": [ { @@ -131,7 +253,7 @@ mod tests { } "#; - let order: SortOrder = serde_json::from_str(sort_order).unwrap(); + let order: SortOrder = serde_json::from_str(spec).unwrap(); assert_eq!(Transform::Identity, order.fields[0].transform); assert_eq!(2, order.fields[0].source_id); assert_eq!(SortDirection::Ascending, order.fields[0].direction); @@ -142,4 +264,255 @@ mod tests { assert_eq!(SortDirection::Descending, order.fields[1].direction); assert_eq!(NullOrder::Last, order.fields[1].null_order); } + + #[test] + fn test_build_unbound_should_return_err_if_unsorted_order_does_not_have_an_order_id_of_zero() { + assert_eq!( + SortOrder::builder() + .with_order_id(1) + .build_unbound() + .expect_err("Expected an Err value") + .message(), + "Unsorted order ID must be 0" + ) + } + + #[test] + fn test_build_unbound_should_return_err_if_order_id_equals_zero_is_used_for_anything_other_than_unsorted_order( + ) { + assert_eq!( + SortOrder::builder() + .with_order_id(SortOrder::UNSORTED_ORDER_ID) 
+ .with_sort_field( + SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build() + ) + .build_unbound() + .expect_err("Expected an Err value") + .message(), + "Sort order ID 0 is reserved for unsorted order" + ) + } + + #[test] + fn test_build_unbound_returns_correct_default_order_id_for_no_fields() { + assert_eq!( + SortOrder::builder() + .build_unbound() + .expect("Expected an Ok value") + .order_id, + SortOrder::UNSORTED_ORDER_ID + ) + } + + #[test] + fn test_build_unbound_returns_correct_default_order_id_for_fields() { + let sort_field = SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build(); + assert_ne!( + SortOrder::builder() + .with_sort_field(sort_field.clone()) + .build_unbound() + .expect("Expected an Ok value") + .order_id, + SortOrder::UNSORTED_ORDER_ID + ) + } + + #[test] + fn test_build_unbound_should_return_unsorted_sort_order() { + assert_eq!( + SortOrder::builder() + .with_order_id(SortOrder::UNSORTED_ORDER_ID) + .build_unbound() + .expect("Expected an Ok value"), + SortOrder::unsorted_order() + ) + } + + #[test] + fn test_build_unbound_should_return_sort_order_with_given_order_id_and_sort_fields() { + let sort_field = SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build(); + + assert_eq!( + SortOrder::builder() + .with_order_id(2) + .with_sort_field(sort_field.clone()) + .build_unbound() + .expect("Expected an Ok value"), + SortOrder { + order_id: 2, + fields: vec![sort_field] + } + ) + } + + #[test] + fn test_build_unbound_should_return_sort_order_with_given_sort_fields_and_defaults_to_1_if_missing_an_order_id( + ) { + let sort_field = SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build(); + + assert_eq!( + SortOrder::builder() + .with_sort_field(sort_field.clone()) + .build_unbound() + .expect("Expected an Ok value"), + SortOrder { + order_id: 1, + fields: vec![sort_field] + } + ) + } + + #[test] + fn test_build_should_return_err_if_sort_order_field_is_not_present_in_schema() { + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 1, + "foo", + Type::Primitive(PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap(); + + let sort_order_builder_result = SortOrder::builder() + .with_sort_field( + SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build(), + ) + .build(&schema); + + assert_eq!( + sort_order_builder_result + .expect_err("Expected an Err value") + .message(), + "Cannot find source column for sort field: SortField { source_id: 2, transform: identity, direction: ascending, null_order: first }" + ) + } + + #[test] + fn test_build_should_return_err_if_source_field_is_not_a_primitive_type() { + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 1, + "foo", + Type::List(ListType { + element_field: NestedField::list_element( + 2, + Type::Primitive(PrimitiveType::String), + true, + ) + .into(), + }), + ) + .into()]) + .build() + .unwrap(); + + let sort_order_builder_result = SortOrder::builder() + .with_sort_field( + SortField::builder() + .source_id(1) + .direction(SortDirection::Ascending) + 
.null_order(NullOrder::First) + .transform(Transform::Identity) + .build(), + ) + .build(&schema); + + assert_eq!( + sort_order_builder_result + .expect_err("Expected an Err value") + .message(), + "Cannot sort by non-primitive source field: list" + ) + } + + #[test] + fn test_build_should_return_err_if_source_field_type_is_not_supported_by_transform() { + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![NestedField::required( + 1, + "foo", + Type::Primitive(PrimitiveType::Int), + ) + .into()]) + .build() + .unwrap(); + + let sort_order_builder_result = SortOrder::builder() + .with_sort_field( + SortField::builder() + .source_id(1) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Year) + .build(), + ) + .build(&schema); + + assert_eq!( + sort_order_builder_result + .expect_err("Expected an Err value") + .message(), + "Invalid source type int for transform year" + ) + } + + #[test] + fn test_build_should_return_valid_sort_order() { + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ]) + .build() + .unwrap(); + + let sort_field = SortField::builder() + .source_id(2) + .direction(SortDirection::Ascending) + .null_order(NullOrder::First) + .transform(Transform::Identity) + .build(); + + let sort_order_builder_result = SortOrder::builder() + .with_sort_field(sort_field.clone()) + .build(&schema); + + assert_eq!( + sort_order_builder_result.expect("Expected an Ok value"), + SortOrder { + order_id: 1, + fields: vec![sort_field], + } + ) + } } diff --git a/crates/iceberg/src/spec/table_metadata.rs b/crates/iceberg/src/spec/table_metadata.rs index 905c82307..16deaac22 100644 --- a/crates/iceberg/src/spec/table_metadata.rs +++ b/crates/iceberg/src/spec/table_metadata.rs @@ -18,74 +18,82 @@ //! Defines the [table metadata](https://iceberg.apache.org/spec/#table-metadata). //! The main struct here is [TableMetadataV2] which defines the data for a table. -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::{Display, Formatter}; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; + +use _serde::TableMetadataEnum; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; use uuid::Uuid; +use super::snapshot::{Snapshot, SnapshotReference, SnapshotRetention}; use super::{ - snapshot::{Snapshot, SnapshotReference, SnapshotRetention}, - PartitionSpecRef, SchemaRef, SnapshotRef, SortOrderRef, + PartitionSpec, PartitionSpecRef, SchemaId, SchemaRef, SnapshotRef, SortOrder, SortOrderRef, + DEFAULT_PARTITION_SPEC_ID, }; - -use _serde::TableMetadataEnum; - -use chrono::{DateTime, TimeZone, Utc}; +use crate::error::{timestamp_ms_to_utc, Result}; +use crate::{Error, ErrorKind, TableCreation}; static MAIN_BRANCH: &str = "main"; -static DEFAULT_SPEC_ID: i32 = 0; static DEFAULT_SORT_ORDER_ID: i64 = 0; -#[derive(Debug, PartialEq, Serialize, Deserialize, Eq, Clone)] -#[serde(try_from = "TableMetadataEnum", into = "TableMetadataEnum")] +pub(crate) static EMPTY_SNAPSHOT_ID: i64 = -1; +pub(crate) static INITIAL_SEQUENCE_NUMBER: i64 = 0; + +/// Reference to [`TableMetadata`]. 
+pub type TableMetadataRef = Arc; + +#[derive(Debug, PartialEq, Deserialize, Eq, Clone)] +#[serde(try_from = "TableMetadataEnum")] /// Fields for the version 2 of the table metadata. /// /// We assume that this data structure is always valid, so we will panic when invalid error happens. /// We check the validity of this data structure when constructing. pub struct TableMetadata { /// Integer Version for the format. - format_version: FormatVersion, + pub(crate) format_version: FormatVersion, /// A UUID that identifies the table - table_uuid: Uuid, + pub(crate) table_uuid: Uuid, /// Location tables base location - location: String, + pub(crate) location: String, /// The tables highest sequence number - last_sequence_number: i64, + pub(crate) last_sequence_number: i64, /// Timestamp in milliseconds from the unix epoch when the table was last updated. - last_updated_ms: i64, + pub(crate) last_updated_ms: i64, /// An integer; the highest assigned column ID for the table. - last_column_id: i32, + pub(crate) last_column_id: i32, /// A list of schemas, stored as objects with schema-id. - schemas: HashMap, + pub(crate) schemas: HashMap, /// ID of the table’s current schema. - current_schema_id: i32, + pub(crate) current_schema_id: i32, /// A list of partition specs, stored as full partition spec objects. - partition_specs: HashMap, + pub(crate) partition_specs: HashMap, /// ID of the “current” spec that writers should use by default. - default_spec_id: i32, + pub(crate) default_spec_id: i32, /// An integer; the highest assigned partition field ID across all partition specs for the table. - last_partition_id: i32, + pub(crate) last_partition_id: i32, ///A string to string map of table properties. This is used to control settings that /// affect reading and writing and is not intended to be used for arbitrary metadata. /// For example, commit.retry.num-retries is used to control the number of commit retries. - properties: HashMap, + pub(crate) properties: HashMap, /// long ID of the current table snapshot; must be the same as the current /// ID of the main branch in refs. - current_snapshot_id: Option, + pub(crate) current_snapshot_id: Option, ///A list of valid snapshots. Valid snapshots are snapshots for which all /// data files exist in the file system. A data file must not be deleted /// from the file system until the last snapshot in which it was listed is /// garbage collected. - snapshots: HashMap, + pub(crate) snapshots: HashMap, /// A list (optional) of timestamp and snapshot ID pairs that encodes changes /// to the current snapshot for the table. Each time the current-snapshot-id /// is changed, a new entry should be added with the last-updated-ms /// and the new current-snapshot-id. When snapshots are expired from /// the list of valid snapshots, all entries before a snapshot that has /// expired should be removed. - snapshot_log: Vec, + pub(crate) snapshot_log: Vec, /// A list (optional) of timestamp and metadata file location pairs /// that encodes changes to the previous metadata files for the table. @@ -93,19 +101,19 @@ pub struct TableMetadata { /// previous metadata file location should be added to the list. /// Tables can be configured to remove oldest metadata log entries and /// keep a fixed-size log of the most recent entries after a commit. - metadata_log: Vec, + pub(crate) metadata_log: Vec, /// A list of sort orders, stored as full sort order objects. - sort_orders: HashMap, + pub(crate) sort_orders: HashMap, /// Default sort order id of the table. 
Note that this could be used by /// writers, but is not used when reading because reads use the specs /// stored in manifest files. - default_sort_order_id: i64, + pub(crate) default_sort_order_id: i64, ///A map of snapshot references. The map keys are the unique snapshot reference /// names in the table, and the map values are snapshot reference objects. /// There is always a main branch reference pointing to the current-snapshot-id /// even if the refs map is null. - refs: HashMap, + pub(crate) refs: HashMap, } impl TableMetadata { @@ -135,8 +143,14 @@ impl TableMetadata { /// Returns last updated time. #[inline] - pub fn last_updated_ms(&self) -> DateTime { - Utc.timestamp_millis_opt(self.last_updated_ms).unwrap() + pub fn last_updated_timestamp(&self) -> Result> { + timestamp_ms_to_utc(self.last_updated_ms) + } + + /// Returns last updated time in milliseconds. + #[inline] + pub fn last_updated_ms(&self) -> i64 { + self.last_updated_ms } /// Returns schemas @@ -147,7 +161,7 @@ impl TableMetadata { /// Lookup schema by id. #[inline] - pub fn schema_by_id(&self, schema_id: i32) -> Option<&SchemaRef> { + pub fn schema_by_id(&self, schema_id: SchemaId) -> Option<&SchemaRef> { self.schemas.get(&schema_id) } @@ -173,11 +187,11 @@ impl TableMetadata { /// Get default partition spec #[inline] pub fn default_partition_spec(&self) -> Option<&PartitionSpecRef> { - if self.default_spec_id == DEFAULT_SPEC_ID { - self.partition_spec_by_id(DEFAULT_SPEC_ID) + if self.default_spec_id == DEFAULT_PARTITION_SPEC_ID { + self.partition_spec_by_id(DEFAULT_PARTITION_SPEC_ID) } else { Some( - self.partition_spec_by_id(DEFAULT_SPEC_ID) + self.partition_spec_by_id(self.default_spec_id) .expect("Default partition spec id set, but not found in table metadata"), ) } @@ -244,7 +258,7 @@ impl TableMetadata { /// Append snapshot to table pub fn append_snapshot(&mut self, snapshot: Snapshot) { - self.last_updated_ms = snapshot.timestamp().timestamp_millis(); + self.last_updated_ms = snapshot.timestamp_ms(); self.last_sequence_number = snapshot.sequence_number(); self.refs @@ -253,14 +267,11 @@ impl TableMetadata { s.snapshot_id = snapshot.snapshot_id(); }) .or_insert_with(|| { - SnapshotReference::new( - snapshot.snapshot_id(), - SnapshotRetention::Branch { - min_snapshots_to_keep: None, - max_snapshot_age_ms: None, - max_ref_age_ms: None, - }, - ) + SnapshotReference::new(snapshot.snapshot_id(), SnapshotRetention::Branch { + min_snapshots_to_keep: None, + max_snapshot_age_ms: None, + max_ref_age_ms: None, + }) }); self.snapshot_log.push(snapshot.log()); @@ -269,31 +280,127 @@ impl TableMetadata { } } +/// Manipulating table metadata. +pub struct TableMetadataBuilder(TableMetadata); + +impl TableMetadataBuilder { + /// Creates a new table metadata builder from the given table metadata. + pub fn new(origin: TableMetadata) -> Self { + Self(origin) + } + + /// Creates a new table metadata builder from the given table creation. 
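+ ///
+ /// A hedged usage sketch (mirrors the unit test at the bottom of this file;
+ /// the location, table name and empty schema are placeholders):
+ ///
+ /// ```
+ /// use std::collections::HashMap;
+ ///
+ /// use iceberg::spec::{FormatVersion, Schema, TableMetadataBuilder};
+ /// use iceberg::TableCreation;
+ ///
+ /// let creation = TableCreation::builder()
+ ///     .name("table".to_string())
+ ///     .location("s3://db/table".to_string())
+ ///     .properties(HashMap::new())
+ ///     .schema(Schema::builder().build().unwrap())
+ ///     .build();
+ ///
+ /// let metadata = TableMetadataBuilder::from_table_creation(creation)
+ ///     .unwrap()
+ ///     .build()
+ ///     .unwrap();
+ /// assert_eq!(metadata.format_version(), FormatVersion::V2);
+ /// ```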
+ pub fn from_table_creation(table_creation: TableCreation) -> Result { + let TableCreation { + name: _, + location, + schema, + partition_spec, + sort_order, + properties, + } = table_creation; + + let partition_specs = match partition_spec { + Some(_) => { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + "Can't create table with partition spec now", + )) + } + None => HashMap::from([( + DEFAULT_PARTITION_SPEC_ID, + Arc::new(PartitionSpec { + spec_id: DEFAULT_PARTITION_SPEC_ID, + fields: vec![], + }), + )]), + }; + + let sort_orders = match sort_order { + Some(_) => { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + "Can't create table with sort order now", + )) + } + None => HashMap::from([( + DEFAULT_SORT_ORDER_ID, + Arc::new(SortOrder { + order_id: DEFAULT_SORT_ORDER_ID, + fields: vec![], + }), + )]), + }; + + let table_metadata = TableMetadata { + format_version: FormatVersion::V2, + table_uuid: Uuid::now_v7(), + location: location.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Can't create table without location", + ) + })?, + last_sequence_number: 0, + last_updated_ms: Utc::now().timestamp_millis(), + last_column_id: schema.highest_field_id(), + current_schema_id: schema.schema_id(), + schemas: HashMap::from([(schema.schema_id(), Arc::new(schema))]), + partition_specs, + default_spec_id: DEFAULT_PARTITION_SPEC_ID, + last_partition_id: 0, + properties, + current_snapshot_id: None, + snapshots: Default::default(), + snapshot_log: vec![], + sort_orders, + metadata_log: vec![], + default_sort_order_id: DEFAULT_SORT_ORDER_ID, + refs: Default::default(), + }; + + Ok(Self(table_metadata)) + } + + /// Changes uuid of table metadata. + pub fn assign_uuid(mut self, uuid: Uuid) -> Result { + self.0.table_uuid = uuid; + Ok(self) + } + + /// Returns the new table metadata after changes. + pub fn build(self) -> Result { + Ok(self.0) + } +} + pub(super) mod _serde { /// This is a helper module that defines types to help with serialization/deserialization. /// For deserialization the input first gets read into either the [TableMetadataV1] or [TableMetadataV2] struct /// and then converted into the [TableMetadata] struct. Serialization works the other way around. /// [TableMetadataV1] and [TableMetadataV2] are internal struct that are only used for serialization and deserialization. - use std::{collections::HashMap, sync::Arc}; + use std::collections::HashMap; + /// This is a helper module that defines types to help with serialization/deserialization. + /// For deserialization the input first gets read into either the [TableMetadataV1] or [TableMetadataV2] struct + /// and then converted into the [TableMetadata] struct. Serialization works the other way around. + /// [TableMetadataV1] and [TableMetadataV2] are internal struct that are only used for serialization and deserialization. 
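+ //
+ // A hedged sketch of the round trip this module enables (the `json` value is
+ // illustrative): deserialization picks [TableMetadataV1] or [TableMetadataV2]
+ // through the untagged `TableMetadataEnum`, and serialization re-emits the
+ // variant matching `format_version`.
+ //
+ //     let metadata: TableMetadata = serde_json::from_str(json)?;
+ //     let json_again = serde_json::to_string(&metadata)?;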
+ use std::sync::Arc; use itertools::Itertools; use serde::{Deserialize, Serialize}; use uuid::Uuid; - use crate::spec::Snapshot; - use crate::{ - spec::{ - schema::_serde::{SchemaV1, SchemaV2}, - snapshot::_serde::{SnapshotV1, SnapshotV2}, - PartitionField, PartitionSpec, Schema, SnapshotReference, SnapshotRetention, SortOrder, - }, - Error, ErrorKind, - }; - use super::{ - FormatVersion, MetadataLog, SnapshotLog, TableMetadata, DEFAULT_SORT_ORDER_ID, - DEFAULT_SPEC_ID, MAIN_BRANCH, + FormatVersion, MetadataLog, SnapshotLog, TableMetadata, DEFAULT_PARTITION_SPEC_ID, + DEFAULT_SORT_ORDER_ID, MAIN_BRANCH, }; + use crate::spec::schema::_serde::{SchemaV1, SchemaV2}; + use crate::spec::snapshot::_serde::{SnapshotV1, SnapshotV2}; + use crate::spec::{ + PartitionField, PartitionSpec, Schema, Snapshot, SnapshotReference, SnapshotRetention, + SortOrder, EMPTY_SNAPSHOT_ID, + }; + use crate::{Error, ErrorKind}; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] @@ -371,22 +478,29 @@ pub(super) mod _serde { /// Helper to serialize and deserialize the format version. #[derive(Debug, PartialEq, Eq)] - pub(super) struct VersionNumber; + pub(crate) struct VersionNumber; + + impl Serialize for TableMetadata { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + // we must do a clone here + let table_metadata_enum: TableMetadataEnum = + self.clone().try_into().map_err(serde::ser::Error::custom)?; + + table_metadata_enum.serialize(serializer) + } + } impl Serialize for VersionNumber { fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + where S: serde::Serializer { serializer.serialize_u8(V) } } impl<'de, const V: u8> Deserialize<'de> for VersionNumber { fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { + where D: serde::Deserializer<'de> { let value = u8::deserialize(deserializer)?; if value == V { Ok(VersionNumber::) @@ -406,12 +520,13 @@ pub(super) mod _serde { } } - impl From for TableMetadataEnum { - fn from(value: TableMetadata) -> Self { - match value.format_version { + impl TryFrom for TableMetadataEnum { + type Error = Error; + fn try_from(value: TableMetadata) -> Result { + Ok(match value.format_version { FormatVersion::V2 => TableMetadataEnum::V2(value.into()), - FormatVersion::V1 => TableMetadataEnum::V1(value.into()), - } + FormatVersion::V1 => TableMetadataEnum::V1(value.try_into()?), + }) } } @@ -453,7 +568,7 @@ pub(super) mod _serde { value .partition_specs .into_iter() - .map(|x| (x.spec_id, Arc::new(x))), + .map(|x| (x.spec_id(), Arc::new(x))), ), default_spec_id: value.default_spec_id, last_partition_id: value.last_partition_id, @@ -480,17 +595,14 @@ pub(super) mod _serde { default_sort_order_id: value.default_sort_order_id, refs: value.refs.unwrap_or_else(|| { if let Some(snapshot_id) = current_snapshot_id { - HashMap::from_iter(vec![( - MAIN_BRANCH.to_string(), - SnapshotReference { - snapshot_id, - retention: SnapshotRetention::Branch { - min_snapshots_to_keep: None, - max_snapshot_age_ms: None, - max_ref_age_ms: None, - }, + HashMap::from_iter(vec![(MAIN_BRANCH.to_string(), SnapshotReference { + snapshot_id, + retention: SnapshotRetention::Branch { + min_snapshots_to_keep: None, + max_snapshot_age_ms: None, + max_ref_age_ms: None, }, - )]) + })]) } else { HashMap::new() } @@ -531,12 +643,12 @@ pub(super) mod _serde { .partition_specs .unwrap_or_else(|| { vec![PartitionSpec { - spec_id: DEFAULT_SPEC_ID, + spec_id: DEFAULT_PARTITION_SPEC_ID, fields: value.partition_spec, 
}] }) .into_iter() - .map(|x| (x.spec_id, Arc::new(x))), + .map(|x| (x.spec_id(), Arc::new(x))), ); Ok(TableMetadata { format_version: FormatVersion::V1, @@ -558,8 +670,12 @@ pub(super) mod _serde { schemas, properties: value.properties.unwrap_or_default(), - current_snapshot_id: if let &Some(-1) = &value.current_snapshot_id { - None + current_snapshot_id: if let &Some(id) = &value.current_snapshot_id { + if id == EMPTY_SNAPSHOT_ID { + None + } else { + Some(id) + } } else { value.current_snapshot_id }, @@ -584,17 +700,14 @@ pub(super) mod _serde { None => HashMap::new(), }, default_sort_order_id: value.default_sort_order_id.unwrap_or(DEFAULT_SORT_ORDER_ID), - refs: HashMap::from_iter(vec![( - MAIN_BRANCH.to_string(), - SnapshotReference { - snapshot_id: value.current_snapshot_id.unwrap_or_default(), - retention: SnapshotRetention::Branch { - min_snapshots_to_keep: None, - max_snapshot_age_ms: None, - max_ref_age_ms: None, - }, + refs: HashMap::from_iter(vec![(MAIN_BRANCH.to_string(), SnapshotReference { + snapshot_id: value.current_snapshot_id.unwrap_or_default(), + retention: SnapshotRetention::Branch { + min_snapshots_to_keep: None, + max_snapshot_age_ms: None, + max_ref_age_ms: None, }, - )]), + })]), }) } } @@ -625,14 +738,10 @@ pub(super) mod _serde { .collect(), default_spec_id: v.default_spec_id, last_partition_id: v.last_partition_id, - properties: if v.properties.is_empty() { - None - } else { - Some(v.properties) - }, + properties: Some(v.properties), current_snapshot_id: v.current_snapshot_id.or(Some(-1)), snapshots: if v.snapshots.is_empty() { - None + Some(vec![]) } else { Some( v.snapshots @@ -666,9 +775,10 @@ pub(super) mod _serde { } } - impl From for TableMetadataV1 { - fn from(v: TableMetadata) -> Self { - TableMetadataV1 { + impl TryFrom for TableMetadataV1 { + type Error = Error; + fn try_from(v: TableMetadata) -> Result { + Ok(TableMetadataV1 { format_version: VersionNumber::<1>, table_uuid: Some(v.table_uuid), location: v.location, @@ -677,7 +787,10 @@ pub(super) mod _serde { schema: v .schemas .get(&v.current_schema_id) - .unwrap() + .ok_or(Error::new( + ErrorKind::Unexpected, + "current_schema_id not found in schemas", + ))? 
.as_ref() .clone() .into(), @@ -695,7 +808,7 @@ pub(super) mod _serde { partition_spec: v .partition_specs .get(&v.default_spec_id) - .map(|x| x.fields.clone()) + .map(|x| x.fields().to_vec()) .unwrap_or_default(), partition_specs: Some( v.partition_specs @@ -738,12 +851,12 @@ pub(super) mod _serde { .collect(), ), default_sort_order_id: Some(v.default_sort_order_id), - } + }) } } } -#[derive(Debug, Serialize_repr, Deserialize_repr, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Serialize_repr, Deserialize_repr, PartialEq, Eq, Clone, Copy, Hash)] #[repr(u8)] /// Iceberg format version pub enum FormatVersion { @@ -796,27 +909,35 @@ pub struct SnapshotLog { impl SnapshotLog { /// Returns the last updated timestamp as a DateTime with millisecond precision - pub fn timestamp(self) -> DateTime { - Utc.timestamp_millis_opt(self.timestamp_ms).unwrap() + pub fn timestamp(self) -> Result> { + timestamp_ms_to_utc(self.timestamp_ms) + } + + /// Returns the timestamp in milliseconds + #[inline] + pub fn timestamp_ms(&self) -> i64 { + self.timestamp_ms } } #[cfg(test)] mod tests { - use std::{collections::HashMap, fs, sync::Arc}; + use std::collections::HashMap; + use std::fs; + use std::sync::Arc; use anyhow::Result; - use uuid::Uuid; - use pretty_assertions::assert_eq; + use uuid::Uuid; + use super::{FormatVersion, MetadataLog, SnapshotLog, TableMetadataBuilder}; + use crate::spec::table_metadata::TableMetadata; use crate::spec::{ - table_metadata::TableMetadata, ManifestListLocation, NestedField, NullOrder, Operation, - PartitionField, PartitionSpec, PrimitiveType, Schema, Snapshot, SnapshotReference, - SnapshotRetention, SortDirection, SortField, SortOrder, Summary, Transform, Type, + NestedField, NullOrder, Operation, PartitionField, PartitionSpec, PrimitiveType, Schema, + Snapshot, SnapshotReference, SnapshotRetention, SortDirection, SortField, SortOrder, + Summary, Transform, Type, UnboundPartitionField, }; - - use super::{FormatVersion, MetadataLog, SnapshotLog}; + use crate::TableCreation; fn check_table_metadata_serde(json: &str, expected_type: TableMetadata) { let desered_type: TableMetadata = serde_json::from_str(json).unwrap(); @@ -828,6 +949,13 @@ mod tests { assert_eq!(parsed_json_value, desered_type); } + fn get_test_table_metadata(file_name: &str) -> TableMetadata { + let path = format!("testdata/table_metadata/{}", file_name); + let metadata: String = fs::read_to_string(path).unwrap(); + + serde_json::from_str(&metadata).unwrap() + } + #[test] fn test_table_data_v2() { let data = r#" @@ -857,12 +985,12 @@ mod tests { { "spec-id": 1, "fields": [ - { - "source-id": 4, - "field-id": 1000, - "name": "ts_day", + { + "source-id": 4, + "field-id": 1000, + "name": "ts_day", "transform": "day" - } + } ] } ], @@ -872,8 +1000,8 @@ mod tests { "commit.retry.num-retries": "1" }, "metadata-log": [ - { - "metadata-file": "s3://bucket/.../v1.json", + { + "metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100 } ], @@ -892,16 +1020,15 @@ mod tests { .build() .unwrap(); - let partition_spec = PartitionSpec::builder() - .with_spec_id(1) - .with_partition_field(PartitionField { + let partition_spec = PartitionSpec { + spec_id: 1, + fields: vec![PartitionField { name: "ts_day".to_string(), transform: Transform::Day, source_id: 4, field_id: 1000, - }) - .build() - .unwrap(); + }], + }; let expected = TableMetadata { format_version: FormatVersion::V2, @@ -1051,25 +1178,24 @@ mod tests { .build() .unwrap(); - let partition_spec = PartitionSpec::builder() + let partition_spec = 
PartitionSpec::builder(&schema) .with_spec_id(0) - .with_partition_field(PartitionField { - name: "vendor_id".to_string(), - transform: Transform::Identity, - source_id: 1, - field_id: 1000, - }) + .add_partition_field("vendor_id", "vendor_id", Transform::Identity) + .unwrap() .build() .unwrap(); - let sort_order = SortOrder::builder().with_order_id(0).build().unwrap(); + let sort_order = SortOrder::builder() + .with_order_id(0) + .build_unbound() + .unwrap(); let snapshot = Snapshot::builder() .with_snapshot_id(638933773299822130) .with_timestamp_ms(1662532818843) .with_sequence_number(0) .with_schema_id(0) - .with_manifest_list(ManifestListLocation::ManifestListFile("/home/iceberg/warehouse/nyc/taxis/metadata/snap-638933773299822130-1-7e6760f0-4f6c-4b23-b907-0a5a174e3863.avro".to_string())) + .with_manifest_list("/home/iceberg/warehouse/nyc/taxis/metadata/snap-638933773299822130-1-7e6760f0-4f6c-4b23-b907-0a5a174e3863.avro") .with_summary(Summary { operation: Operation::Append, other: HashMap::from_iter(vec![("spark.app.id".to_string(), "local-1662532784305".to_string()), ("added-data-files".to_string(), "4".to_string()), ("added-records".to_string(), "4".to_string()), ("added-files-size".to_string(), "6001".to_string())]) }) .build(); @@ -1161,14 +1287,15 @@ mod tests { .build() .unwrap(); - let partition_spec = PartitionSpec::builder() + let partition_spec = PartitionSpec::builder(&schema1) .with_spec_id(0) - .with_partition_field(PartitionField { + .add_unbound_field(UnboundPartitionField { name: "x".to_string(), transform: Transform::Identity, source_id: 1, - field_id: 1000, + field_id: Some(1000), }) + .unwrap() .build() .unwrap(); @@ -1186,16 +1313,14 @@ mod tests { direction: SortDirection::Descending, null_order: NullOrder::Last, }) - .build() + .build_unbound() .unwrap(); let snapshot1 = Snapshot::builder() .with_snapshot_id(3051729675574597004) .with_timestamp_ms(1515100955770) .with_sequence_number(0) - .with_manifest_list(ManifestListLocation::ManifestListFile( - "s3://a/b/1.avro".to_string(), - )) + .with_manifest_list("s3://a/b/1.avro") .with_summary(Summary { operation: Operation::Append, other: HashMap::new(), @@ -1208,9 +1333,7 @@ mod tests { .with_timestamp_ms(1555100955770) .with_sequence_number(1) .with_schema_id(1) - .with_manifest_list(ManifestListLocation::ManifestListFile( - "s3://a/b/2.avro".to_string(), - )) + .with_manifest_list("s3://a/b/2.avro") .with_summary(Summary { operation: Operation::Append, other: HashMap::new(), @@ -1248,17 +1371,14 @@ mod tests { }, ], metadata_log: Vec::new(), - refs: HashMap::from_iter(vec![( - "main".to_string(), - SnapshotReference { - snapshot_id: 3055729675574597004, - retention: SnapshotRetention::Branch { - min_snapshots_to_keep: None, - max_snapshot_age_ms: None, - max_ref_age_ms: None, - }, + refs: HashMap::from_iter(vec![("main".to_string(), SnapshotReference { + snapshot_id: 3055729675574597004, + retention: SnapshotRetention::Branch { + min_snapshots_to_keep: None, + max_snapshot_age_ms: None, + max_ref_age_ms: None, }, - )]), + })]), }; check_table_metadata_serde(&metadata, expected); @@ -1290,14 +1410,15 @@ mod tests { .build() .unwrap(); - let partition_spec = PartitionSpec::builder() + let partition_spec = PartitionSpec::builder(&schema) .with_spec_id(0) - .with_partition_field(PartitionField { + .add_unbound_field(UnboundPartitionField { name: "x".to_string(), transform: Transform::Identity, source_id: 1, - field_id: 1000, + field_id: Some(1000), }) + .unwrap() .build() .unwrap(); @@ -1315,7 +1436,7 @@ mod tests { 
direction: SortDirection::Descending, null_order: NullOrder::Last, }) - .build() + .build_unbound() .unwrap(); let expected = TableMetadata { @@ -1369,14 +1490,15 @@ mod tests { .build() .unwrap(); - let partition_spec = PartitionSpec::builder() + let partition_spec = PartitionSpec::builder(&schema) .with_spec_id(0) - .with_partition_field(PartitionField { + .add_unbound_field(UnboundPartitionField { name: "x".to_string(), transform: Transform::Identity, source_id: 1, - field_id: 1000, + field_id: Some(1000), }) + .unwrap() .build() .unwrap(); @@ -1399,17 +1521,14 @@ mod tests { properties: HashMap::new(), snapshot_log: vec![], metadata_log: Vec::new(), - refs: HashMap::from_iter(vec![( - "main".to_string(), - SnapshotReference { - snapshot_id: -1, - retention: SnapshotRetention::Branch { - min_snapshots_to_keep: None, - max_snapshot_age_ms: None, - max_ref_age_ms: None, - }, + refs: HashMap::from_iter(vec![("main".to_string(), SnapshotReference { + snapshot_id: -1, + retention: SnapshotRetention::Branch { + min_snapshots_to_keep: None, + max_snapshot_age_ms: None, + max_ref_age_ms: None, }, - )]), + })]), }; check_table_metadata_serde(&metadata, expected); @@ -1501,9 +1620,100 @@ mod tests { } #[test] - fn order_of_format_version() { + fn test_order_of_format_version() { assert!(FormatVersion::V1 < FormatVersion::V2); assert_eq!(FormatVersion::V1, FormatVersion::V1); assert_eq!(FormatVersion::V2, FormatVersion::V2); } + + #[test] + fn test_default_partition_spec() { + let default_spec_id = 1234; + let mut table_meta_data = get_test_table_metadata("TableMetadataV2Valid.json"); + table_meta_data.default_spec_id = default_spec_id; + table_meta_data + .partition_specs + .insert(default_spec_id, Arc::new(PartitionSpec::default())); + + assert_eq!( + table_meta_data.default_partition_spec(), + table_meta_data.partition_spec_by_id(default_spec_id) + ); + } + #[test] + fn test_default_sort_order() { + let default_sort_order_id = 1234; + let mut table_meta_data = get_test_table_metadata("TableMetadataV2Valid.json"); + table_meta_data.default_sort_order_id = default_sort_order_id; + table_meta_data + .sort_orders + .insert(default_sort_order_id, Arc::new(SortOrder::default())); + + assert_eq!( + table_meta_data.default_sort_order(), + table_meta_data.sort_orders.get(&default_sort_order_id) + ) + } + + #[test] + fn test_table_metadata_builder_from_table_creation() { + let table_creation = TableCreation::builder() + .location("s3://db/table".to_string()) + .name("table".to_string()) + .properties(HashMap::new()) + .schema(Schema::builder().build().unwrap()) + .build(); + let table_metadata = TableMetadataBuilder::from_table_creation(table_creation) + .unwrap() + .build() + .unwrap(); + assert_eq!(table_metadata.location, "s3://db/table"); + assert_eq!(table_metadata.schemas.len(), 1); + assert_eq!( + table_metadata + .schemas + .get(&0) + .unwrap() + .as_struct() + .fields() + .len(), + 0 + ); + assert_eq!(table_metadata.properties.len(), 0); + assert_eq!( + table_metadata.partition_specs, + HashMap::from([( + 0, + Arc::new( + PartitionSpec::builder(table_metadata.schemas.get(&0).unwrap()) + .with_spec_id(0) + .build() + .unwrap() + ) + )]) + ); + assert_eq!( + table_metadata.sort_orders, + HashMap::from([( + 0, + Arc::new(SortOrder { + order_id: 0, + fields: vec![] + }) + )]) + ); + } + + #[test] + fn test_table_builder_from_table_metadata() { + let table_metadata = get_test_table_metadata("TableMetadataV2Valid.json"); + let table_metadata_builder = TableMetadataBuilder::new(table_metadata); + let 
uuid = Uuid::new_v4(); + let table_metadata = table_metadata_builder + .assign_uuid(uuid) + .unwrap() + .build() + .unwrap(); + assert_eq!(table_metadata.uuid(), uuid); + } } diff --git a/crates/iceberg/src/spec/transform.rs b/crates/iceberg/src/spec/transform.rs index 839d582dc..6b7d03f11 100644 --- a/crates/iceberg/src/spec/transform.rs +++ b/crates/iceberg/src/spec/transform.rs @@ -17,12 +17,21 @@ //! Transforms in iceberg. +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +use fnv::FnvHashSet; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use super::{Datum, PrimitiveLiteral}; use crate::error::{Error, Result}; +use crate::expr::{ + BinaryExpression, BoundPredicate, BoundReference, Predicate, PredicateOperator, Reference, + SetExpression, UnaryExpression, +}; use crate::spec::datatypes::{PrimitiveType, Type}; +use crate::transform::{create_transform_function, BoxedTransformFunction}; use crate::ErrorKind; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::fmt::{Display, Formatter}; -use std::str::FromStr; /// Transform is used to transform predicates to partition predicates, /// in addition to transforming data values. @@ -150,6 +159,8 @@ impl Transform { | PrimitiveType::Time | PrimitiveType::Timestamp | PrimitiveType::Timestamptz + | PrimitiveType::TimestampNs + | PrimitiveType::TimestamptzNs | PrimitiveType::String | PrimitiveType::Uuid | PrimitiveType::Fixed(_) @@ -191,7 +202,9 @@ impl Transform { match p { PrimitiveType::Timestamp | PrimitiveType::Timestamptz - | PrimitiveType::Date => Ok(Type::Primitive(PrimitiveType::Int)), + | PrimitiveType::TimestampNs + | PrimitiveType::TimestamptzNs + | PrimitiveType::Date => Ok(Type::Primitive(PrimitiveType::Date)), _ => Err(Error::new( ErrorKind::DataInvalid, format!("{input_type} is not a valid input type of {self} transform",), @@ -207,9 +220,10 @@ impl Transform { Transform::Hour => { if let Type::Primitive(p) = input_type { match p { - PrimitiveType::Timestamp | PrimitiveType::Timestamptz => { - Ok(Type::Primitive(PrimitiveType::Int)) - } + PrimitiveType::Timestamp + | PrimitiveType::Timestamptz + | PrimitiveType::TimestampNs + | PrimitiveType::TimestamptzNs => Ok(Type::Primitive(PrimitiveType::Int)), _ => Err(Error::new( ErrorKind::DataInvalid, format!("{input_type} is not a valid input type of {self} transform",), @@ -248,7 +262,7 @@ impl Transform { /// result. /// /// For example, sorting by day(ts) will produce an ordering that is also by month(ts) or - // year(ts). However, sorting by day(ts) will not satisfy the order of hour(ts) or identity(ts). + /// year(ts). However, sorting by day(ts) will not satisfy the order of hour(ts) or identity(ts). pub fn satisfies_order_of(&self, other: &Self) -> bool { match self { Transform::Identity => other.preserves_order(), @@ -261,6 +275,329 @@ impl Transform { _ => self == other, } } + + /// Projects a given predicate according to the transformation + /// specified by the `Transform` instance. + /// + /// This allows predicates to be effectively applied to data + /// that has undergone transformation, enabling efficient querying + /// and filtering based on the original, untransformed data. 
+ /// + /// # Example + /// Suppose, we have row filter `a = 10`, and a partition spec + /// `bucket(a, 37) as bs`, if one row matches `a = 10`, then its partition + /// value should match `bucket(10, 37) as bs`, and we project `a = 10` to + /// `bs = bucket(10, 37)` + pub fn project(&self, name: &str, predicate: &BoundPredicate) -> Result> { + let func = create_transform_function(self)?; + + match self { + Transform::Identity => match predicate { + BoundPredicate::Unary(expr) => Self::project_unary(expr.op(), name), + BoundPredicate::Binary(expr) => Ok(Some(Predicate::Binary(BinaryExpression::new( + expr.op(), + Reference::new(name), + expr.literal().to_owned(), + )))), + BoundPredicate::Set(expr) => Ok(Some(Predicate::Set(SetExpression::new( + expr.op(), + Reference::new(name), + expr.literals().to_owned(), + )))), + _ => Ok(None), + }, + Transform::Bucket(_) => match predicate { + BoundPredicate::Unary(expr) => Self::project_unary(expr.op(), name), + BoundPredicate::Binary(expr) => self.project_eq_operator(name, expr, &func), + BoundPredicate::Set(expr) => self.project_in_operator(expr, name, &func), + _ => Ok(None), + }, + Transform::Truncate(width) => match predicate { + BoundPredicate::Unary(expr) => Self::project_unary(expr.op(), name), + BoundPredicate::Binary(expr) => { + self.project_binary_with_adjusted_boundary(name, expr, &func, Some(*width)) + } + BoundPredicate::Set(expr) => self.project_in_operator(expr, name, &func), + _ => Ok(None), + }, + Transform::Year | Transform::Month | Transform::Day | Transform::Hour => { + match predicate { + BoundPredicate::Unary(expr) => Self::project_unary(expr.op(), name), + BoundPredicate::Binary(expr) => { + self.project_binary_with_adjusted_boundary(name, expr, &func, None) + } + BoundPredicate::Set(expr) => self.project_in_operator(expr, name, &func), + _ => Ok(None), + } + } + _ => Ok(None), + } + } + + /// Check if `Transform` is applicable on datum's `PrimitiveType` + fn can_transform(&self, datum: &Datum) -> bool { + let input_type = datum.data_type().clone(); + self.result_type(&Type::Primitive(input_type)).is_ok() + } + + /// Creates a unary predicate from a given operator and a reference name. + fn project_unary(op: PredicateOperator, name: &str) -> Result> { + Ok(Some(Predicate::Unary(UnaryExpression::new( + op, + Reference::new(name), + )))) + } + + /// Attempts to create a binary predicate based on a binary expression, + /// if applicable. + /// + /// This method evaluates a given binary expression and, if the operation + /// is equality (`Eq`) and the literal can be transformed, constructs a + /// `Predicate::Binary`variant representing the binary operation. + fn project_eq_operator( + &self, + name: &str, + expr: &BinaryExpression, + func: &BoxedTransformFunction, + ) -> Result> { + if expr.op() != PredicateOperator::Eq || !self.can_transform(expr.literal()) { + return Ok(None); + } + + Ok(Some(Predicate::Binary(BinaryExpression::new( + expr.op(), + Reference::new(name), + func.transform_literal_result(expr.literal())?, + )))) + } + + /// Projects a binary expression to a predicate with an adjusted boundary. + /// + /// Checks if the literal within the given binary expression is + /// transformable. If transformable, it proceeds to potentially adjust + /// the boundary of the expression based on the comparison operator (`op`). + /// The potential adjustments involve incrementing or decrementing the + /// literal value and changing the `PredicateOperator` itself to its + /// inclusive variant. 
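(Editorial sketch, not part of this diff.) The projection logic above builds on `Transform`'s existing public surface: parsing/formatting, `result_type`, and the ordering guarantees described for `satisfies_order_of`. A minimal sketch, assuming `Transform`, `Type` and `PrimitiveType` are re-exported from `iceberg::spec` as in the doc examples elsewhere in this PR:

```rust
use iceberg::spec::{PrimitiveType, Transform, Type};

fn transform_basics() -> Result<(), Box<dyn std::error::Error>> {
    // Transforms round-trip through their textual form ("bucket[16]", "day", ...).
    let bucket: Transform = "bucket[16]".parse()?;
    assert_eq!(bucket.to_string(), "bucket[16]");

    // Bucketing any supported source type yields an int partition value.
    let source = Type::Primitive(PrimitiveType::String);
    assert_eq!(
        bucket.result_type(&source)?,
        Type::Primitive(PrimitiveType::Int)
    );

    // Ordering semantics consulted when projecting predicates: day(ts) also
    // orders rows by month(ts) and year(ts), but not by hour(ts).
    assert!(Transform::Day.satisfies_order_of(&Transform::Month));
    assert!(Transform::Day.satisfies_order_of(&Transform::Year));
    assert!(!Transform::Day.satisfies_order_of(&Transform::Hour));
    Ok(())
}
```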
+ fn project_binary_with_adjusted_boundary( + &self, + name: &str, + expr: &BinaryExpression, + func: &BoxedTransformFunction, + width: Option, + ) -> Result> { + if !self.can_transform(expr.literal()) { + return Ok(None); + } + + let op = &expr.op(); + let datum = &expr.literal(); + + if let Some(boundary) = Self::adjust_boundary(op, datum)? { + let transformed_projection = func.transform_literal_result(&boundary)?; + + let adjusted_projection = + self.adjust_time_projection(op, datum, &transformed_projection); + + let adjusted_operator = Self::adjust_operator(op, datum, width); + + if let Some(op) = adjusted_operator { + let predicate = match adjusted_projection { + None => Predicate::Binary(BinaryExpression::new( + op, + Reference::new(name), + transformed_projection, + )), + Some(AdjustedProjection::Single(d)) => { + Predicate::Binary(BinaryExpression::new(op, Reference::new(name), d)) + } + Some(AdjustedProjection::Set(d)) => Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new(name), + d, + )), + }; + return Ok(Some(predicate)); + } + }; + + Ok(None) + } + + /// Projects a set expression to a predicate, + /// applying a transformation to each literal in the set. + fn project_in_operator( + &self, + expr: &SetExpression, + name: &str, + func: &BoxedTransformFunction, + ) -> Result> { + if expr.op() != PredicateOperator::In + || expr.literals().iter().any(|d| !self.can_transform(d)) + { + return Ok(None); + } + + let mut new_set = FnvHashSet::default(); + + for lit in expr.literals() { + let datum = func.transform_literal_result(lit)?; + + if let Some(AdjustedProjection::Single(d)) = + self.adjust_time_projection(&PredicateOperator::In, lit, &datum) + { + new_set.insert(d); + }; + + new_set.insert(datum); + } + + Ok(Some(Predicate::Set(SetExpression::new( + expr.op(), + Reference::new(name), + new_set, + )))) + } + + /// Adjusts the boundary value for comparison operations + /// based on the specified `PredicateOperator` and `Datum`. + /// + /// This function modifies the boundary value for certain comparison + /// operators (`LessThan`, `GreaterThan`) by incrementing or decrementing + /// the literal value within the given `Datum`. For operators that do not + /// imply a boundary shift (`Eq`, `LessThanOrEq`, `GreaterThanOrEq`, + /// `StartsWith`, `NotStartsWith`), the original datum is returned + /// unmodified. + fn adjust_boundary(op: &PredicateOperator, datum: &Datum) -> Result> { + let adjusted_boundary = match op { + PredicateOperator::LessThan => match (datum.data_type(), datum.literal()) { + (PrimitiveType::Int, PrimitiveLiteral::Int(v)) => Some(Datum::int(v - 1)), + (PrimitiveType::Long, PrimitiveLiteral::Long(v)) => Some(Datum::long(v - 1)), + (PrimitiveType::Decimal { .. }, PrimitiveLiteral::Int128(v)) => { + Some(Datum::decimal(v - 1)?) + } + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => Some(Datum::date(v - 1)), + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => { + Some(Datum::timestamp_micros(v - 1)) + } + _ => Some(datum.to_owned()), + }, + PredicateOperator::GreaterThan => match (datum.data_type(), datum.literal()) { + (PrimitiveType::Int, PrimitiveLiteral::Int(v)) => Some(Datum::int(v + 1)), + (PrimitiveType::Long, PrimitiveLiteral::Long(v)) => Some(Datum::long(v + 1)), + (PrimitiveType::Decimal { .. }, PrimitiveLiteral::Int128(v)) => { + Some(Datum::decimal(v + 1)?) 
+ } + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => Some(Datum::date(v + 1)), + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => { + Some(Datum::timestamp_micros(v + 1)) + } + _ => Some(datum.to_owned()), + }, + PredicateOperator::Eq + | PredicateOperator::LessThanOrEq + | PredicateOperator::GreaterThanOrEq + | PredicateOperator::StartsWith + | PredicateOperator::NotStartsWith => Some(datum.to_owned()), + _ => None, + }; + + Ok(adjusted_boundary) + } + + /// Adjusts the comparison operator based on the specified datum and an + /// optional width constraint. + /// + /// This function modifies the comparison operator for `LessThan` and + /// `GreaterThan` cases to their inclusive counterparts (`LessThanOrEq`, + /// `GreaterThanOrEq`) unconditionally. For `StartsWith` and + /// `NotStartsWith` operators acting on string literals, the operator may + /// be adjusted to `Eq` or `NotEq` if the string length matches the + /// specified width, indicating a precise match rather than a prefix + /// condition. + fn adjust_operator( + op: &PredicateOperator, + datum: &Datum, + width: Option, + ) -> Option { + match op { + PredicateOperator::LessThan => Some(PredicateOperator::LessThanOrEq), + PredicateOperator::GreaterThan => Some(PredicateOperator::GreaterThanOrEq), + PredicateOperator::StartsWith => match datum.literal() { + PrimitiveLiteral::String(s) => { + if let Some(w) = width { + if s.len() == w as usize { + return Some(PredicateOperator::Eq); + }; + }; + Some(*op) + } + _ => Some(*op), + }, + PredicateOperator::NotStartsWith => match datum.literal() { + PrimitiveLiteral::String(s) => { + if let Some(w) = width { + let w = w as usize; + + if s.len() == w { + return Some(PredicateOperator::NotEq); + } + + if s.len() < w { + return Some(*op); + } + + return None; + }; + Some(*op) + } + _ => Some(*op), + }, + _ => Some(*op), + } + } + + /// Adjust projection for temporal transforms, align with Java + /// implementation: https://github.com/apache/iceberg/blob/main/api/src/main/java/org/apache/iceberg/transforms/ProjectionUtil.java#L275 + fn adjust_time_projection( + &self, + op: &PredicateOperator, + original: &Datum, + transformed: &Datum, + ) -> Option { + let should_adjust = match self { + Transform::Day => matches!(original.data_type(), PrimitiveType::Timestamp), + Transform::Year | Transform::Month => true, + _ => false, + }; + + if should_adjust { + if let &PrimitiveLiteral::Int(v) = transformed.literal() { + match op { + PredicateOperator::LessThan + | PredicateOperator::LessThanOrEq + | PredicateOperator::In => { + if v < 0 { + return Some(AdjustedProjection::Single(Datum::int(v + 1))); + }; + } + PredicateOperator::Eq => { + if v < 0 { + let new_set = FnvHashSet::from_iter(vec![ + transformed.to_owned(), + Datum::int(v + 1), + ]); + return Some(AdjustedProjection::Set(new_set)); + } + } + _ => { + return None; + } + } + }; + } + None + } } impl Display for Transform { @@ -339,523 +676,23 @@ impl FromStr for Transform { impl Serialize for Transform { fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { + where S: Serializer { serializer.serialize_str(format!("{self}").as_str()) } } impl<'de> Deserialize<'de> for Transform { fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { + where D: Deserializer<'de> { let s = String::deserialize(deserializer)?; s.parse().map_err(::custom) } } -#[cfg(test)] -mod tests { - use crate::spec::datatypes::PrimitiveType::{ - Binary, Date, Decimal, Fixed, Int, Long, String as 
StringType, Time, Timestamp, - Timestamptz, Uuid, - }; - use crate::spec::datatypes::Type::{Primitive, Struct}; - use crate::spec::datatypes::{NestedField, StructType, Type}; - use crate::spec::transform::Transform; - - struct TestParameter { - display: String, - json: String, - dedup_name: String, - preserves_order: bool, - satisfies_order_of: Vec<(Transform, bool)>, - trans_types: Vec<(Type, Option)>, - } - - fn check_transform(trans: Transform, param: TestParameter) { - assert_eq!(param.display, format!("{trans}")); - assert_eq!(param.json, serde_json::to_string(&trans).unwrap()); - assert_eq!(trans, serde_json::from_str(param.json.as_str()).unwrap()); - assert_eq!(param.dedup_name, trans.dedup_name()); - assert_eq!(param.preserves_order, trans.preserves_order()); - - for (other_trans, satisfies_order_of) in param.satisfies_order_of { - assert_eq!( - satisfies_order_of, - trans.satisfies_order_of(&other_trans), - "Failed to check satisfies order {}, {}, {}", - trans, - other_trans, - satisfies_order_of - ); - } - - for (input_type, result_type) in param.trans_types { - assert_eq!(result_type, trans.result_type(&input_type).ok()); - } - } - - #[test] - fn test_bucket_transform() { - let trans = Transform::Bucket(8); - - let test_param = TestParameter { - display: "bucket[8]".to_string(), - json: r#""bucket[8]""#.to_string(), - dedup_name: "bucket[8]".to_string(), - preserves_order: false, - satisfies_order_of: vec![ - (Transform::Bucket(8), true), - (Transform::Bucket(4), false), - (Transform::Void, false), - (Transform::Day, false), - ], - trans_types: vec![ - (Primitive(Binary), Some(Primitive(Int))), - (Primitive(Date), Some(Primitive(Int))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - Some(Primitive(Int)), - ), - (Primitive(Fixed(8)), Some(Primitive(Int))), - (Primitive(Int), Some(Primitive(Int))), - (Primitive(Long), Some(Primitive(Int))), - (Primitive(StringType), Some(Primitive(Int))), - (Primitive(Uuid), Some(Primitive(Int))), - (Primitive(Time), Some(Primitive(Int))), - (Primitive(Timestamp), Some(Primitive(Int))), - (Primitive(Timestamptz), Some(Primitive(Int))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_truncate_transform() { - let trans = Transform::Truncate(4); - - let test_param = TestParameter { - display: "truncate[4]".to_string(), - json: r#""truncate[4]""#.to_string(), - dedup_name: "truncate[4]".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Truncate(4), true), - (Transform::Truncate(2), false), - (Transform::Bucket(4), false), - (Transform::Void, false), - (Transform::Day, false), - ], - trans_types: vec![ - (Primitive(Binary), Some(Primitive(Binary))), - (Primitive(Date), None), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - Some(Primitive(Decimal { - precision: 8, - scale: 5, - })), - ), - (Primitive(Fixed(8)), None), - (Primitive(Int), Some(Primitive(Int))), - (Primitive(Long), Some(Primitive(Long))), - (Primitive(StringType), Some(Primitive(StringType))), - (Primitive(Uuid), None), - (Primitive(Time), None), - (Primitive(Timestamp), None), - (Primitive(Timestamptz), None), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_identity_transform() { - let trans = Transform::Identity; - - let test_param 
= TestParameter { - display: "identity".to_string(), - json: r#""identity""#.to_string(), - dedup_name: "identity".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Truncate(4), true), - (Transform::Truncate(2), true), - (Transform::Bucket(4), false), - (Transform::Void, false), - (Transform::Day, true), - ], - trans_types: vec![ - (Primitive(Binary), Some(Primitive(Binary))), - (Primitive(Date), Some(Primitive(Date))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - Some(Primitive(Decimal { - precision: 8, - scale: 5, - })), - ), - (Primitive(Fixed(8)), Some(Primitive(Fixed(8)))), - (Primitive(Int), Some(Primitive(Int))), - (Primitive(Long), Some(Primitive(Long))), - (Primitive(StringType), Some(Primitive(StringType))), - (Primitive(Uuid), Some(Primitive(Uuid))), - (Primitive(Time), Some(Primitive(Time))), - (Primitive(Timestamp), Some(Primitive(Timestamp))), - (Primitive(Timestamptz), Some(Primitive(Timestamptz))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_year_transform() { - let trans = Transform::Year; - - let test_param = TestParameter { - display: "year".to_string(), - json: r#""year""#.to_string(), - dedup_name: "time".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Year, true), - (Transform::Month, false), - (Transform::Day, false), - (Transform::Hour, false), - (Transform::Void, false), - (Transform::Identity, false), - ], - trans_types: vec![ - (Primitive(Binary), None), - (Primitive(Date), Some(Primitive(Int))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - None, - ), - (Primitive(Fixed(8)), None), - (Primitive(Int), None), - (Primitive(Long), None), - (Primitive(StringType), None), - (Primitive(Uuid), None), - (Primitive(Time), None), - (Primitive(Timestamp), Some(Primitive(Int))), - (Primitive(Timestamptz), Some(Primitive(Int))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_month_transform() { - let trans = Transform::Month; - - let test_param = TestParameter { - display: "month".to_string(), - json: r#""month""#.to_string(), - dedup_name: "time".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Year, true), - (Transform::Month, true), - (Transform::Day, false), - (Transform::Hour, false), - (Transform::Void, false), - (Transform::Identity, false), - ], - trans_types: vec![ - (Primitive(Binary), None), - (Primitive(Date), Some(Primitive(Int))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - None, - ), - (Primitive(Fixed(8)), None), - (Primitive(Int), None), - (Primitive(Long), None), - (Primitive(StringType), None), - (Primitive(Uuid), None), - (Primitive(Time), None), - (Primitive(Timestamp), Some(Primitive(Int))), - (Primitive(Timestamptz), Some(Primitive(Int))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_day_transform() { - let trans = Transform::Day; - - let test_param = TestParameter { - display: "day".to_string(), - json: r#""day""#.to_string(), - dedup_name: "time".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Year, true), - 
(Transform::Month, true), - (Transform::Day, true), - (Transform::Hour, false), - (Transform::Void, false), - (Transform::Identity, false), - ], - trans_types: vec![ - (Primitive(Binary), None), - (Primitive(Date), Some(Primitive(Int))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - None, - ), - (Primitive(Fixed(8)), None), - (Primitive(Int), None), - (Primitive(Long), None), - (Primitive(StringType), None), - (Primitive(Uuid), None), - (Primitive(Time), None), - (Primitive(Timestamp), Some(Primitive(Int))), - (Primitive(Timestamptz), Some(Primitive(Int))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_hour_transform() { - let trans = Transform::Hour; - - let test_param = TestParameter { - display: "hour".to_string(), - json: r#""hour""#.to_string(), - dedup_name: "time".to_string(), - preserves_order: true, - satisfies_order_of: vec![ - (Transform::Year, true), - (Transform::Month, true), - (Transform::Day, true), - (Transform::Hour, true), - (Transform::Void, false), - (Transform::Identity, false), - ], - trans_types: vec![ - (Primitive(Binary), None), - (Primitive(Date), None), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - None, - ), - (Primitive(Fixed(8)), None), - (Primitive(Int), None), - (Primitive(Long), None), - (Primitive(StringType), None), - (Primitive(Uuid), None), - (Primitive(Time), None), - (Primitive(Timestamp), Some(Primitive(Int))), - (Primitive(Timestamptz), Some(Primitive(Int))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - None, - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_void_transform() { - let trans = Transform::Void; - - let test_param = TestParameter { - display: "void".to_string(), - json: r#""void""#.to_string(), - dedup_name: "void".to_string(), - preserves_order: false, - satisfies_order_of: vec![ - (Transform::Year, false), - (Transform::Month, false), - (Transform::Day, false), - (Transform::Hour, false), - (Transform::Void, true), - (Transform::Identity, false), - ], - trans_types: vec![ - (Primitive(Binary), Some(Primitive(Binary))), - (Primitive(Date), Some(Primitive(Date))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - Some(Primitive(Decimal { - precision: 8, - scale: 5, - })), - ), - (Primitive(Fixed(8)), Some(Primitive(Fixed(8)))), - (Primitive(Int), Some(Primitive(Int))), - (Primitive(Long), Some(Primitive(Long))), - (Primitive(StringType), Some(Primitive(StringType))), - (Primitive(Uuid), Some(Primitive(Uuid))), - (Primitive(Time), Some(Primitive(Time))), - (Primitive(Timestamp), Some(Primitive(Timestamp))), - (Primitive(Timestamptz), Some(Primitive(Timestamptz))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - Some(Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()]))), - ), - ], - }; - - check_transform(trans, test_param); - } - - #[test] - fn test_known_transform() { - let trans = Transform::Unknown; - - let test_param = TestParameter { - display: "unknown".to_string(), - json: r#""unknown""#.to_string(), - dedup_name: "unknown".to_string(), - preserves_order: false, - satisfies_order_of: vec![ - (Transform::Year, false), - (Transform::Month, false), - (Transform::Day, false), - (Transform::Hour, false), - (Transform::Void, 
false), - (Transform::Identity, false), - (Transform::Unknown, true), - ], - trans_types: vec![ - (Primitive(Binary), Some(Primitive(StringType))), - (Primitive(Date), Some(Primitive(StringType))), - ( - Primitive(Decimal { - precision: 8, - scale: 5, - }), - Some(Primitive(StringType)), - ), - (Primitive(Fixed(8)), Some(Primitive(StringType))), - (Primitive(Int), Some(Primitive(StringType))), - (Primitive(Long), Some(Primitive(StringType))), - (Primitive(StringType), Some(Primitive(StringType))), - (Primitive(Uuid), Some(Primitive(StringType))), - (Primitive(Time), Some(Primitive(StringType))), - (Primitive(Timestamp), Some(Primitive(StringType))), - (Primitive(Timestamptz), Some(Primitive(StringType))), - ( - Struct(StructType::new(vec![NestedField::optional( - 1, - "a", - Primitive(Timestamp), - ) - .into()])), - Some(Primitive(StringType)), - ), - ], - }; - - check_transform(trans, test_param); - } +/// An enum representing the result of the adjusted projection. +/// Either being a single adjusted datum or a set. +#[derive(Debug)] +enum AdjustedProjection { + Single(Datum), + Set(FnvHashSet), } diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 39f870602..3568d3dcd 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -19,59 +19,1156 @@ * Value in iceberg */ +use std::any::Any; +use std::collections::HashMap; +use std::fmt::{Display, Formatter}; +use std::hash::Hash; +use std::ops::Index; use std::str::FromStr; -use std::{any::Any, collections::BTreeMap}; -use crate::error::Result; +pub use _serde::RawLiteral; use bitvec::vec::BitVec; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use ordered_float::OrderedFloat; use rust_decimal::Decimal; +use serde::de::{ + MapAccess, {self}, +}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use serde_json::{Map as JsonMap, Number, Value as JsonValue}; +use timestamp::nanoseconds_to_datetime; use uuid::Uuid; -use crate::{Error, ErrorKind}; +use super::datatypes::{PrimitiveType, Type}; +use crate::error::Result; +use crate::spec::values::date::{date_from_naive_date, days_to_date, unix_epoch}; +use crate::spec::values::time::microseconds_to_time; +use crate::spec::values::timestamp::microseconds_to_datetime; +use crate::spec::values::timestamptz::{microseconds_to_datetimetz, nanoseconds_to_datetimetz}; +use crate::spec::MAX_DECIMAL_PRECISION; +use crate::{ensure_data_valid, Error, ErrorKind}; + +/// Maximum value for [`PrimitiveType::Time`] type in microseconds, e.g. 23 hours 59 minutes 59 seconds 999999 microseconds. 
+const MAX_TIME_VALUE: i64 = 24 * 60 * 60 * 1_000_000i64 - 1; + +/// Values present in iceberg type +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum PrimitiveLiteral { + /// 0x00 for false, non-zero byte for true + Boolean(bool), + /// Stored as 4-byte little-endian + Int(i32), + /// Stored as 8-byte little-endian + Long(i64), + /// Stored as 4-byte little-endian + Float(OrderedFloat), + /// Stored as 8-byte little-endian + Double(OrderedFloat), + /// UTF-8 bytes (without length) + String(String), + /// Binary value (without length) + Binary(Vec), + /// Stored as 16-byte little-endian + Int128(i128), + /// Stored as 16-byte little-endian + UInt128(u128), +} + +impl PrimitiveLiteral { + /// Returns true if the Literal represents a primitive type + /// that can be a NaN, and that it's value is NaN + pub fn is_nan(&self) -> bool { + match self { + PrimitiveLiteral::Double(val) => val.is_nan(), + PrimitiveLiteral::Float(val) => val.is_nan(), + _ => false, + } + } +} + +/// Literal associated with its type. The value and type pair is checked when construction, so the type and value is +/// guaranteed to be correct when used. +/// +/// By default, we decouple the type and value of a literal, so we can use avoid the cost of storing extra type info +/// for each literal. But associate type with literal can be useful in some cases, for example, in unbound expression. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub struct Datum { + r#type: PrimitiveType, + literal: PrimitiveLiteral, +} + +impl Serialize for Datum { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + let mut struct_ser = serializer + .serialize_struct("Datum", 2) + .map_err(serde::ser::Error::custom)?; + struct_ser + .serialize_field("type", &self.r#type) + .map_err(serde::ser::Error::custom)?; + struct_ser + .serialize_field( + "literal", + &RawLiteral::try_from( + Literal::Primitive(self.literal.clone()), + &Type::Primitive(self.r#type.clone()), + ) + .map_err(serde::ser::Error::custom)?, + ) + .map_err(serde::ser::Error::custom)?; + struct_ser.end() + } +} + +impl<'de> Deserialize<'de> for Datum { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { + Type, + Literal, + } + + struct DatumVisitor; + + impl<'de> serde::de::Visitor<'de> for DatumVisitor { + type Value = Datum; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct Datum") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where A: serde::de::SeqAccess<'de> { + let r#type = seq + .next_element::()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + let value = seq + .next_element::()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + let Literal::Primitive(primitive) = value + .try_into(&Type::Primitive(r#type.clone())) + .map_err(serde::de::Error::custom)? + .ok_or_else(|| serde::de::Error::custom("None value"))? + else { + return Err(serde::de::Error::custom("Invalid value")); + }; + + Ok(Datum::new(r#type, primitive)) + } + + fn visit_map(self, mut map: V) -> std::result::Result + where V: MapAccess<'de> { + let mut raw_primitive: Option = None; + let mut r#type: Option = None; + while let Some(key) = map.next_key()? 
{ + match key { + Field::Type => { + if r#type.is_some() { + return Err(de::Error::duplicate_field("type")); + } + r#type = Some(map.next_value()?); + } + Field::Literal => { + if raw_primitive.is_some() { + return Err(de::Error::duplicate_field("literal")); + } + raw_primitive = Some(map.next_value()?); + } + } + } + let Some(r#type) = r#type else { + return Err(serde::de::Error::missing_field("type")); + }; + let Some(raw_primitive) = raw_primitive else { + return Err(serde::de::Error::missing_field("literal")); + }; + let Literal::Primitive(primitive) = raw_primitive + .try_into(&Type::Primitive(r#type.clone())) + .map_err(serde::de::Error::custom)? + .ok_or_else(|| serde::de::Error::custom("None value"))? + else { + return Err(serde::de::Error::custom("Invalid value")); + }; + Ok(Datum::new(r#type, primitive)) + } + } + const FIELDS: &[&str] = &["type", "literal"]; + deserializer.deserialize_struct("Datum", FIELDS, DatumVisitor) + } +} + +impl PartialOrd for Datum { + fn partial_cmp(&self, other: &Self) -> Option { + match (&self.literal, &other.literal, &self.r#type, &other.r#type) { + // generate the arm with same type and same literal + ( + PrimitiveLiteral::Boolean(val), + PrimitiveLiteral::Boolean(other_val), + PrimitiveType::Boolean, + PrimitiveType::Boolean, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Int(val), + PrimitiveLiteral::Int(other_val), + PrimitiveType::Int, + PrimitiveType::Int, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Long(val), + PrimitiveLiteral::Long(other_val), + PrimitiveType::Long, + PrimitiveType::Long, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Float(val), + PrimitiveLiteral::Float(other_val), + PrimitiveType::Float, + PrimitiveType::Float, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Double(val), + PrimitiveLiteral::Double(other_val), + PrimitiveType::Double, + PrimitiveType::Double, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Int(val), + PrimitiveLiteral::Int(other_val), + PrimitiveType::Date, + PrimitiveType::Date, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Long(val), + PrimitiveLiteral::Long(other_val), + PrimitiveType::Time, + PrimitiveType::Time, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Long(val), + PrimitiveLiteral::Long(other_val), + PrimitiveType::Timestamp, + PrimitiveType::Timestamp, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Long(val), + PrimitiveLiteral::Long(other_val), + PrimitiveType::Timestamptz, + PrimitiveType::Timestamptz, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::String(val), + PrimitiveLiteral::String(other_val), + PrimitiveType::String, + PrimitiveType::String, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::UInt128(val), + PrimitiveLiteral::UInt128(other_val), + PrimitiveType::Uuid, + PrimitiveType::Uuid, + ) => Uuid::from_u128(*val).partial_cmp(&Uuid::from_u128(*other_val)), + ( + PrimitiveLiteral::Binary(val), + PrimitiveLiteral::Binary(other_val), + PrimitiveType::Fixed(_), + PrimitiveType::Fixed(_), + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Binary(val), + PrimitiveLiteral::Binary(other_val), + PrimitiveType::Binary, + PrimitiveType::Binary, + ) => val.partial_cmp(other_val), + ( + PrimitiveLiteral::Int128(val), + PrimitiveLiteral::Int128(other_val), + PrimitiveType::Decimal { + precision: _, + scale, + }, + PrimitiveType::Decimal { + precision: _, + scale: other_scale, + }, + ) => { + let val = Decimal::from_i128_with_scale(*val, *scale); + let other_val = 
Decimal::from_i128_with_scale(*other_val, *other_scale); + val.partial_cmp(&other_val) + } + _ => None, + } + } +} + +impl Display for Datum { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match (&self.r#type, &self.literal) { + (_, PrimitiveLiteral::Boolean(val)) => write!(f, "{}", val), + (PrimitiveType::Int, PrimitiveLiteral::Int(val)) => write!(f, "{}", val), + (PrimitiveType::Long, PrimitiveLiteral::Long(val)) => write!(f, "{}", val), + (_, PrimitiveLiteral::Float(val)) => write!(f, "{}", val), + (_, PrimitiveLiteral::Double(val)) => write!(f, "{}", val), + (PrimitiveType::Date, PrimitiveLiteral::Int(val)) => { + write!(f, "{}", days_to_date(*val)) + } + (PrimitiveType::Time, PrimitiveLiteral::Long(val)) => { + write!(f, "{}", microseconds_to_time(*val)) + } + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(val)) => { + write!(f, "{}", microseconds_to_datetime(*val)) + } + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(val)) => { + write!(f, "{}", microseconds_to_datetimetz(*val)) + } + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(val)) => { + write!(f, "{}", nanoseconds_to_datetime(*val)) + } + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(val)) => { + write!(f, "{}", nanoseconds_to_datetimetz(*val)) + } + (_, PrimitiveLiteral::String(val)) => write!(f, r#""{}""#, val), + (PrimitiveType::Uuid, PrimitiveLiteral::UInt128(val)) => { + write!(f, "{}", Uuid::from_u128(*val)) + } + (_, PrimitiveLiteral::Binary(val)) => display_bytes(val, f), + ( + PrimitiveType::Decimal { + precision: _, + scale, + }, + PrimitiveLiteral::Int128(val), + ) => { + write!(f, "{}", Decimal::from_i128_with_scale(*val, *scale)) + } + (_, _) => { + unreachable!() + } + } + } +} + +fn display_bytes(bytes: &[u8], f: &mut Formatter<'_>) -> std::fmt::Result { + let mut s = String::with_capacity(bytes.len() * 2); + for b in bytes { + s.push_str(&format!("{:02X}", b)); + } + f.write_str(&s) +} + +impl From for Literal { + fn from(value: Datum) -> Self { + Literal::Primitive(value.literal) + } +} + +impl From for PrimitiveLiteral { + fn from(value: Datum) -> Self { + value.literal + } +} + +impl Datum { + /// Creates a `Datum` from a `PrimitiveType` and a `PrimitiveLiteral` + pub(crate) fn new(r#type: PrimitiveType, literal: PrimitiveLiteral) -> Self { + Datum { r#type, literal } + } + + /// Create iceberg value from bytes. + /// + /// See [this spec](https://iceberg.apache.org/spec/#binary-single-value-serialization) for reference. 
+ pub fn try_from_bytes(bytes: &[u8], data_type: PrimitiveType) -> Result { + let literal = match data_type { + PrimitiveType::Boolean => { + if bytes.len() == 1 && bytes[0] == 0u8 { + PrimitiveLiteral::Boolean(false) + } else { + PrimitiveLiteral::Boolean(true) + } + } + PrimitiveType::Int => PrimitiveLiteral::Int(i32::from_le_bytes(bytes.try_into()?)), + PrimitiveType::Long => PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)), + PrimitiveType::Float => { + PrimitiveLiteral::Float(OrderedFloat(f32::from_le_bytes(bytes.try_into()?))) + } + PrimitiveType::Double => { + PrimitiveLiteral::Double(OrderedFloat(f64::from_le_bytes(bytes.try_into()?))) + } + PrimitiveType::Date => PrimitiveLiteral::Int(i32::from_le_bytes(bytes.try_into()?)), + PrimitiveType::Time => PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)), + PrimitiveType::Timestamp => { + PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)) + } + PrimitiveType::Timestamptz => { + PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)) + } + PrimitiveType::TimestampNs => { + PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)) + } + PrimitiveType::TimestamptzNs => { + PrimitiveLiteral::Long(i64::from_le_bytes(bytes.try_into()?)) + } + PrimitiveType::String => { + PrimitiveLiteral::String(std::str::from_utf8(bytes)?.to_string()) + } + PrimitiveType::Uuid => { + PrimitiveLiteral::UInt128(u128::from_be_bytes(bytes.try_into()?)) + } + PrimitiveType::Fixed(_) => PrimitiveLiteral::Binary(Vec::from(bytes)), + PrimitiveType::Binary => PrimitiveLiteral::Binary(Vec::from(bytes)), + PrimitiveType::Decimal { + precision: _, + scale: _, + } => todo!(), + }; + Ok(Datum::new(data_type, literal)) + } + + /// Convert the value to bytes + /// + /// See [this spec](https://iceberg.apache.org/spec/#binary-single-value-serialization) for reference. + pub fn to_bytes(&self) -> ByteBuf { + match &self.literal { + PrimitiveLiteral::Boolean(val) => { + if *val { + ByteBuf::from([1u8]) + } else { + ByteBuf::from([0u8]) + } + } + PrimitiveLiteral::Int(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Long(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Float(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Double(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::String(val) => ByteBuf::from(val.as_bytes()), + PrimitiveLiteral::UInt128(val) => ByteBuf::from(val.to_be_bytes()), + PrimitiveLiteral::Binary(val) => ByteBuf::from(val.as_slice()), + PrimitiveLiteral::Int128(_) => todo!(), + } + } + + /// Creates a boolean value. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// let t = Datum::bool(true); + /// + /// assert_eq!(format!("{}", t), "true".to_string()); + /// assert_eq!( + /// Literal::from(t), + /// Literal::Primitive(PrimitiveLiteral::Boolean(true)) + /// ); + /// ``` + pub fn bool>(t: T) -> Self { + Self { + r#type: PrimitiveType::Boolean, + literal: PrimitiveLiteral::Boolean(t.into()), + } + } + + /// Creates a boolean value from string. + /// See [Parse bool from str](https://doc.rust-lang.org/stable/std/primitive.bool.html#impl-FromStr-for-bool) for reference. 
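(Editorial sketch, not part of this diff.) The `to_bytes`/`try_from_bytes` pair above implements Iceberg's single-value binary serialization; a small round trip, assuming the `iceberg::spec` re-exports used in the doc examples:

```rust
use iceberg::spec::{Datum, PrimitiveType};

fn single_value_round_trip() -> Result<(), Box<dyn std::error::Error>> {
    let original = Datum::long(3_600i64);

    // Longs are written as 8-byte little-endian, per the spec linked above.
    let bytes = original.to_bytes();
    assert_eq!(&bytes[..], &3_600i64.to_le_bytes()[..]);

    // Reading back requires the expected primitive type, since the bytes
    // alone carry no type information.
    let restored = Datum::try_from_bytes(&bytes[..], PrimitiveType::Long)?;
    assert_eq!(restored, original);
    Ok(())
}
```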
+ /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// let t = Datum::bool_from_str("false").unwrap(); + /// + /// assert_eq!(&format!("{}", t), "false"); + /// assert_eq!( + /// Literal::Primitive(PrimitiveLiteral::Boolean(false)), + /// t.into() + /// ); + /// ``` + pub fn bool_from_str>(s: S) -> Result { + let v = s.as_ref().parse::().map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Can't parse string to bool.").with_source(e) + })?; + Ok(Self::bool(v)) + } + + /// Creates an 32bit integer. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// let t = Datum::int(23i8); + /// + /// assert_eq!(&format!("{}", t), "23"); + /// assert_eq!(Literal::Primitive(PrimitiveLiteral::Int(23)), t.into()); + /// ``` + pub fn int>(t: T) -> Self { + Self { + r#type: PrimitiveType::Int, + literal: PrimitiveLiteral::Int(t.into()), + } + } + + /// Creates an 64bit integer. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// let t = Datum::long(24i8); + /// + /// assert_eq!(&format!("{t}"), "24"); + /// assert_eq!(Literal::Primitive(PrimitiveLiteral::Long(24)), t.into()); + /// ``` + pub fn long>(t: T) -> Self { + Self { + r#type: PrimitiveType::Long, + literal: PrimitiveLiteral::Long(t.into()), + } + } + + /// Creates an 32bit floating point number. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// use ordered_float::OrderedFloat; + /// let t = Datum::float(32.1f32); + /// + /// assert_eq!(&format!("{t}"), "32.1"); + /// assert_eq!( + /// Literal::Primitive(PrimitiveLiteral::Float(OrderedFloat(32.1))), + /// t.into() + /// ); + /// ``` + pub fn float>(t: T) -> Self { + Self { + r#type: PrimitiveType::Float, + literal: PrimitiveLiteral::Float(OrderedFloat(t.into())), + } + } + + /// Creates an 64bit floating point number. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// use ordered_float::OrderedFloat; + /// let t = Datum::double(32.1f64); + /// + /// assert_eq!(&format!("{t}"), "32.1"); + /// assert_eq!( + /// Literal::Primitive(PrimitiveLiteral::Double(OrderedFloat(32.1))), + /// t.into() + /// ); + /// ``` + pub fn double>(t: T) -> Self { + Self { + r#type: PrimitiveType::Double, + literal: PrimitiveLiteral::Double(OrderedFloat(t.into())), + } + } + + /// Creates date literal from number of days from unix epoch directly. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// // 2 days after 1970-01-01 + /// let t = Datum::date(2); + /// + /// assert_eq!(&format!("{t}"), "1970-01-03"); + /// assert_eq!(Literal::Primitive(PrimitiveLiteral::Int(2)), t.into()); + /// ``` + pub fn date(days: i32) -> Self { + Self { + r#type: PrimitiveType::Date, + literal: PrimitiveLiteral::Int(days), + } + } + + /// Creates date literal in `%Y-%m-%d` format, assume in utc timezone. + /// + /// See [`NaiveDate::from_str`]. 
+ /// + /// Example + /// ```rust + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::date_from_str("1970-01-05").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "1970-01-05"); + /// assert_eq!(Literal::date(4), t.into()); + /// ``` + pub fn date_from_str>(s: S) -> Result { + let t = s.as_ref().parse::().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Can't parse date from string: {}", s.as_ref()), + ) + .with_source(e) + })?; + + Ok(Self::date(date_from_naive_date(t))) + } + + /// Create date literal from calendar date (year, month and day). + /// + /// See [`NaiveDate::from_ymd_opt`]. + /// + /// Example: + /// + ///```rust + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::date_from_ymd(1970, 1, 5).unwrap(); + /// + /// assert_eq!(&format!("{t}"), "1970-01-05"); + /// assert_eq!(Literal::date(4), t.into()); + /// ``` + pub fn date_from_ymd(year: i32, month: u32, day: u32) -> Result { + let t = NaiveDate::from_ymd_opt(year, month, day).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Can't create date from year: {year}, month: {month}, day: {day}"), + ) + })?; + + Ok(Self::date(date_from_naive_date(t))) + } + + /// Creates time literal in microseconds directly. + /// + /// It will return error when it's negative or too large to fit in 24 hours. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::{Datum, Literal}; + /// let micro_secs = { + /// 1 * 3600 * 1_000_000 + // 1 hour + /// 2 * 60 * 1_000_000 + // 2 minutes + /// 1 * 1_000_000 + // 1 second + /// 888999 // microseconds + /// }; + /// + /// let t = Datum::time_micros(micro_secs).unwrap(); + /// + /// assert_eq!(&format!("{t}"), "01:02:01.888999"); + /// assert_eq!(Literal::time(micro_secs), t.into()); + /// + /// let negative_value = -100; + /// assert!(Datum::time_micros(negative_value).is_err()); + /// + /// let too_large_value = 36 * 60 * 60 * 1_000_000; // Too large to fit in 24 hours. + /// assert!(Datum::time_micros(too_large_value).is_err()); + /// ``` + pub fn time_micros(value: i64) -> Result { + ensure_data_valid!( + (0..=MAX_TIME_VALUE).contains(&value), + "Invalid value for Time type: {}", + value + ); + + Ok(Self { + r#type: PrimitiveType::Time, + literal: PrimitiveLiteral::Long(value), + }) + } + + /// Creates time literal from [`chrono::NaiveTime`]. + fn time_from_naive_time(t: NaiveTime) -> Self { + let duration = t - unix_epoch().time(); + // It's safe to unwrap here since less than 24 hours will never overflow. + let micro_secs = duration.num_microseconds().unwrap(); + + Self { + r#type: PrimitiveType::Time, + literal: PrimitiveLiteral::Long(micro_secs), + } + } + + /// Creates time literal in microseconds in `%H:%M:%S:.f` format. + /// + /// See [`NaiveTime::from_str`] for details. + /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::time_from_str("01:02:01.888999777").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "01:02:01.888999"); + /// ``` + pub fn time_from_str>(s: S) -> Result { + let t = s.as_ref().parse::().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Can't parse time from string: {}", s.as_ref()), + ) + .with_source(e) + })?; + + Ok(Self::time_from_naive_time(t)) + } + + /// Creates time literal from hour, minute, second, and microseconds. + /// + /// See [`NaiveTime::from_hms_micro_opt`]. 
+ /// + /// Example: + /// ```rust + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::time_from_hms_micro(22, 15, 33, 111).unwrap(); + /// + /// assert_eq!(&format!("{t}"), "22:15:33.000111"); + /// ``` + pub fn time_from_hms_micro(hour: u32, min: u32, sec: u32, micro: u32) -> Result { + let t = NaiveTime::from_hms_micro_opt(hour, min, sec, micro) + .ok_or_else(|| Error::new( + ErrorKind::DataInvalid, + format!("Can't create time from hour: {hour}, min: {min}, second: {sec}, microsecond: {micro}"), + ))?; + Ok(Self::time_from_naive_time(t)) + } + + /// Creates a timestamp from unix epoch in microseconds. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::timestamp_micros(1000); + /// + /// assert_eq!(&format!("{t}"), "1970-01-01 00:00:00.001"); + /// ``` + pub fn timestamp_micros(value: i64) -> Self { + Self { + r#type: PrimitiveType::Timestamp, + literal: PrimitiveLiteral::Long(value), + } + } + + /// Creates a timestamp from unix epoch in nanoseconds. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::timestamp_nanos(1000); + /// + /// assert_eq!(&format!("{t}"), "1970-01-01 00:00:00.000001"); + /// ``` + pub fn timestamp_nanos(value: i64) -> Self { + Self { + r#type: PrimitiveType::TimestampNs, + literal: PrimitiveLiteral::Long(value), + } + } -use super::datatypes::{PrimitiveType, Type}; + /// Creates a timestamp from [`DateTime`]. + /// + /// Example: + /// + /// ```rust + /// use chrono::{NaiveDate, NaiveDateTime, TimeZone, Utc}; + /// use iceberg::spec::Datum; + /// let t = Datum::timestamp_from_datetime( + /// NaiveDate::from_ymd_opt(1992, 3, 1) + /// .unwrap() + /// .and_hms_micro_opt(1, 2, 3, 88) + /// .unwrap(), + /// ); + /// + /// assert_eq!(&format!("{t}"), "1992-03-01 01:02:03.000088"); + /// ``` + pub fn timestamp_from_datetime(dt: NaiveDateTime) -> Self { + Self::timestamp_micros(dt.and_utc().timestamp_micros()) + } -pub use _serde::RawLiteral; + /// Parse a timestamp in [`%Y-%m-%dT%H:%M:%S%.f`] format. + /// + /// See [`NaiveDateTime::from_str`]. + /// + /// Example: + /// + /// ```rust + /// use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::timestamp_from_str("1992-03-01T01:02:03.000088").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "1992-03-01 01:02:03.000088"); + /// ``` + pub fn timestamp_from_str>(s: S) -> Result { + let dt = s.as_ref().parse::().map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Can't parse timestamp.").with_source(e) + })?; -/// Values present in iceberg type -#[derive(Clone, Debug, PartialEq, Hash, Eq, PartialOrd, Ord)] -pub enum PrimitiveLiteral { - /// 0x00 for false, non-zero byte for true - Boolean(bool), - /// Stored as 4-byte little-endian - Int(i32), - /// Stored as 8-byte little-endian - Long(i64), - /// Stored as 4-byte little-endian - Float(OrderedFloat), - /// Stored as 8-byte little-endian - Double(OrderedFloat), - /// Stores days from the 1970-01-01 in an 4-byte little-endian int - Date(i32), - /// Stores microseconds from midnight in an 8-byte little-endian long - Time(i64), - /// Timestamp without timezone - Timestamp(i64), - /// Timestamp with timezone - TimestampTZ(i64), - /// UTF-8 bytes (without length) - String(String), - /// 16-byte big-endian value - UUID(Uuid), - /// Binary value - Fixed(Vec), - /// Binary value (without length) - Binary(Vec), - /// Stores unscaled value as big int. 
According to iceberg spec, the precision must less than 38(`MAX_DECIMAL_PRECISION`) , so i128 is suit here. - Decimal(i128), + Ok(Self::timestamp_from_datetime(dt)) + } + + /// Creates a timestamp with timezone from unix epoch in microseconds. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::timestamptz_micros(1000); + /// + /// assert_eq!(&format!("{t}"), "1970-01-01 00:00:00.001 UTC"); + /// ``` + pub fn timestamptz_micros(value: i64) -> Self { + Self { + r#type: PrimitiveType::Timestamptz, + literal: PrimitiveLiteral::Long(value), + } + } + + /// Creates a timestamp with timezone from unix epoch in nanoseconds. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::timestamptz_nanos(1000); + /// + /// assert_eq!(&format!("{t}"), "1970-01-01 00:00:00.000001 UTC"); + /// ``` + pub fn timestamptz_nanos(value: i64) -> Self { + Self { + r#type: PrimitiveType::TimestamptzNs, + literal: PrimitiveLiteral::Long(value), + } + } + + /// Creates a timestamp with timezone from [`DateTime`]. + /// Example: + /// + /// ```rust + /// use chrono::{TimeZone, Utc}; + /// use iceberg::spec::Datum; + /// let t = Datum::timestamptz_from_datetime(Utc.timestamp_opt(1000, 0).unwrap()); + /// + /// assert_eq!(&format!("{t}"), "1970-01-01 00:16:40 UTC"); + /// ``` + pub fn timestamptz_from_datetime(dt: DateTime) -> Self { + Self::timestamptz_micros(dt.with_timezone(&Utc).timestamp_micros()) + } + + /// Parse timestamp with timezone in RFC3339 format. + /// + /// See [`DateTime::from_str`]. + /// + /// Example: + /// + /// ```rust + /// use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; + /// use iceberg::spec::{Datum, Literal}; + /// let t = Datum::timestamptz_from_str("1992-03-01T01:02:03.000088+08:00").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "1992-02-29 17:02:03.000088 UTC"); + /// ``` + pub fn timestamptz_from_str>(s: S) -> Result { + let dt = DateTime::::from_str(s.as_ref()).map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Can't parse datetime.").with_source(e) + })?; + + Ok(Self::timestamptz_from_datetime(dt)) + } + + /// Creates a string literal. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::string("ss"); + /// + /// assert_eq!(&format!("{t}"), r#""ss""#); + /// ``` + pub fn string(s: S) -> Self { + Self { + r#type: PrimitiveType::String, + literal: PrimitiveLiteral::String(s.to_string()), + } + } + + /// Creates uuid literal. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// use uuid::uuid; + /// let t = Datum::uuid(uuid!("a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8")); + /// + /// assert_eq!(&format!("{t}"), "a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8"); + /// ``` + pub fn uuid(uuid: Uuid) -> Self { + Self { + r#type: PrimitiveType::Uuid, + literal: PrimitiveLiteral::UInt128(uuid.as_u128()), + } + } + + /// Creates uuid from str. See [`Uuid::parse_str`]. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::uuid_from_str("a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8"); + /// ``` + pub fn uuid_from_str>(s: S) -> Result { + let uuid = Uuid::parse_str(s.as_ref()).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Can't parse uuid from string: {}", s.as_ref()), + ) + .with_source(e) + })?; + Ok(Self::uuid(uuid)) + } + + /// Creates a fixed literal from bytes. 
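(Editorial sketch, not part of this diff.) Because `Datum` carries its `PrimitiveType` alongside the literal, the typed constructors, `Display`, and `PartialOrd` shown above all agree on how a value is interpreted:

```rust
use iceberg::spec::Datum;

fn typed_datum_basics() -> Result<(), Box<dyn std::error::Error>> {
    // Both are stored as i32 days since the epoch, but typed as dates.
    let earlier = Datum::date_from_str("1970-01-05")?; // 4 days after the epoch
    let later = Datum::date(10);
    assert!(earlier < later); // PartialOrd only compares matching primitive types
    assert_eq!(format!("{earlier}"), "1970-01-05");

    // Timestamps are stored as i64 microseconds (or nanoseconds) since the epoch.
    let ts = Datum::timestamp_micros(1_000);
    assert_eq!(format!("{ts}"), "1970-01-01 00:00:00.001");
    Ok(())
}
```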
+ /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::{Datum, Literal, PrimitiveLiteral}; + /// let t = Datum::fixed(vec![1u8, 2u8]); + /// + /// assert_eq!(&format!("{t}"), "0102"); + /// ``` + pub fn fixed>(input: I) -> Self { + let value: Vec = input.into_iter().collect(); + Self { + r#type: PrimitiveType::Fixed(value.len() as u64), + literal: PrimitiveLiteral::Binary(value), + } + } + + /// Creates a binary literal from bytes. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// let t = Datum::binary(vec![1u8, 100u8]); + /// + /// assert_eq!(&format!("{t}"), "0164"); + /// ``` + pub fn binary>(input: I) -> Self { + Self { + r#type: PrimitiveType::Binary, + literal: PrimitiveLiteral::Binary(input.into_iter().collect()), + } + } + + /// Creates decimal literal from string. See [`Decimal::from_str_exact`]. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// use itertools::assert_equal; + /// use rust_decimal::Decimal; + /// let t = Datum::decimal_from_str("123.45").unwrap(); + /// + /// assert_eq!(&format!("{t}"), "123.45"); + /// ``` + pub fn decimal_from_str>(s: S) -> Result { + let decimal = Decimal::from_str_exact(s.as_ref()).map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Can't parse decimal.").with_source(e) + })?; + + Self::decimal(decimal) + } + + /// Try to create a decimal literal from [`Decimal`]. + /// + /// Example: + /// + /// ```rust + /// use iceberg::spec::Datum; + /// use rust_decimal::Decimal; + /// + /// let t = Datum::decimal(Decimal::new(123, 2)).unwrap(); + /// + /// assert_eq!(&format!("{t}"), "1.23"); + /// ``` + pub fn decimal(value: impl Into) -> Result { + let decimal = value.into(); + let scale = decimal.scale(); + + let r#type = Type::decimal(MAX_DECIMAL_PRECISION, scale)?; + if let Type::Primitive(p) = r#type { + Ok(Self { + r#type: p, + literal: PrimitiveLiteral::Int128(decimal.mantissa()), + }) + } else { + unreachable!("Decimal type must be primitive.") + } + } + + /// Convert the datum to `target_type`. + pub fn to(self, target_type: &Type) -> Result { + match target_type { + Type::Primitive(target_primitive_type) => { + match (&self.literal, &self.r#type, target_primitive_type) { + (PrimitiveLiteral::Int(val), _, PrimitiveType::Int) => Ok(Datum::int(*val)), + (PrimitiveLiteral::Int(val), _, PrimitiveType::Date) => Ok(Datum::date(*val)), + // TODO: implement more type conversions + (_, self_type, target_type) if self_type == target_type => Ok(self), + _ => Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Can't convert datum from {} type to {} type.", + self.r#type, target_primitive_type + ), + )), + } + } + _ => Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Can't convert datum from {} type to {} type.", + self.r#type, target_type + ), + )), + } + } + + /// Get the primitive literal from datum. + pub fn literal(&self) -> &PrimitiveLiteral { + &self.literal + } + + /// Get the primitive type from datum. + pub fn data_type(&self) -> &PrimitiveType { + &self.r#type + } + + /// Returns true if the Literal represents a primitive type + /// that can be a NaN, and that it's value is NaN + pub fn is_nan(&self) -> bool { + match self.literal { + PrimitiveLiteral::Double(val) => val.is_nan(), + PrimitiveLiteral::Float(val) => val.is_nan(), + _ => false, + } + } +} + +/// Map is a collection of key-value pairs with a key type and a value type. 
+/// It used in Literal::Map, to make it hashable, the order of key-value pairs is stored in a separate vector +/// so that we can hash the map in a deterministic way. But it also means that the order of key-value pairs is matter +/// for the hash value. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Map { + index: HashMap, + pair: Vec<(Literal, Option)>, +} + +impl Map { + /// Creates a new empty map. + pub fn new() -> Self { + Self { + index: HashMap::new(), + pair: Vec::new(), + } + } + + /// Return the number of key-value pairs in the map. + pub fn len(&self) -> usize { + self.pair.len() + } + + /// Returns true if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.pair.is_empty() + } + + /// Inserts a key-value pair into the map. + /// If the map did not have this key present, None is returned. + /// If the map did have this key present, the value is updated, and the old value is returned. + pub fn insert(&mut self, key: Literal, value: Option) -> Option> { + if let Some(index) = self.index.get(&key) { + let old_value = std::mem::replace(&mut self.pair[*index].1, value); + Some(old_value) + } else { + self.pair.push((key.clone(), value)); + self.index.insert(key, self.pair.len() - 1); + None + } + } + + /// Returns a reference to the value corresponding to the key. + /// If the key is not present in the map, None is returned. + pub fn get(&self, key: &Literal) -> Option<&Option> { + self.index.get(key).map(|index| &self.pair[*index].1) + } + + /// The order of map is matter, so this method used to compare two maps has same key-value pairs without considering the order. + pub fn has_same_content(&self, other: &Map) -> bool { + if self.len() != other.len() { + return false; + } + + for (key, value) in &self.pair { + match other.get(key) { + Some(other_value) if value == other_value => (), + _ => return false, + } + } + + true + } +} + +impl Default for Map { + fn default() -> Self { + Self::new() + } +} + +impl Hash for Map { + fn hash(&self, state: &mut H) { + for (key, value) in &self.pair { + key.hash(state); + value.hash(state); + } + } +} + +impl FromIterator<(Literal, Option)> for Map { + fn from_iter)>>(iter: T) -> Self { + let mut map = Map::new(); + for (key, value) in iter { + map.insert(key, value); + } + map + } +} + +impl IntoIterator for Map { + type Item = (Literal, Option); + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.pair.into_iter() + } +} + +impl From<[(Literal, Option); N]> for Map { + fn from(value: [(Literal, Option); N]) -> Self { + value.iter().cloned().collect() + } } /// Values present in iceberg type -#[derive(Clone, Debug, PartialEq, Hash, Eq, PartialOrd, Ord)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Literal { /// A primitive value Primitive(PrimitiveLiteral), @@ -86,7 +1183,7 @@ pub enum Literal { /// A map is a collection of key-value pairs with a key type and a value type. /// Both the key field and value field each have an integer id that is unique in the table schema. /// Map keys are required and map values can be either optional or required. Both map keys and map values may be any type, including nested types. 
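(Editorial sketch, not part of this diff.) The new `Map` literal keeps insertion order so it can be hashed deterministically, which makes `==` order-sensitive; `has_same_content` is the order-insensitive comparison. Assuming `Map` is re-exported from `iceberg::spec` like the other value types:

```rust
use iceberg::spec::{Literal, Map, PrimitiveLiteral};

fn map_literal_basics() {
    let k1 = Literal::Primitive(PrimitiveLiteral::Int(1));
    let k2 = Literal::Primitive(PrimitiveLiteral::Int(2));
    let v = Literal::Primitive(PrimitiveLiteral::String("a".to_string()));

    let forward = Map::from([(k1.clone(), Some(v.clone())), (k2.clone(), None)]);
    let reverse = Map::from([(k2.clone(), None), (k1.clone(), Some(v.clone()))]);

    assert_eq!(forward.get(&k1), Some(&Some(v)));
    assert_ne!(forward, reverse); // insertion order participates in equality/hashing
    assert!(forward.has_same_content(&reverse)); // same pairs, order ignored
}
```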
- Map(BTreeMap>), + Map(Map), } impl Literal { @@ -150,11 +1247,14 @@ impl Literal { /// /// Example: /// ```rust - /// use ordered_float::OrderedFloat; /// use iceberg::spec::{Literal, PrimitiveLiteral}; - /// let t = Literal::float( 32.1f32 ); + /// use ordered_float::OrderedFloat; + /// let t = Literal::float(32.1f32); /// - /// assert_eq!(Literal::Primitive(PrimitiveLiteral::Float(OrderedFloat(32.1))), t); + /// assert_eq!( + /// Literal::Primitive(PrimitiveLiteral::Float(OrderedFloat(32.1))), + /// t + /// ); /// ``` pub fn float>(t: T) -> Self { Self::Primitive(PrimitiveLiteral::Float(OrderedFloat(t.into()))) @@ -164,30 +1264,22 @@ impl Literal { /// /// Example: /// ```rust - /// use ordered_float::OrderedFloat; /// use iceberg::spec::{Literal, PrimitiveLiteral}; - /// let t = Literal::double( 32.1f64 ); + /// use ordered_float::OrderedFloat; + /// let t = Literal::double(32.1f64); /// - /// assert_eq!(Literal::Primitive(PrimitiveLiteral::Double(OrderedFloat(32.1))), t); + /// assert_eq!( + /// Literal::Primitive(PrimitiveLiteral::Double(OrderedFloat(32.1))), + /// t + /// ); /// ``` pub fn double>(t: T) -> Self { Self::Primitive(PrimitiveLiteral::Double(OrderedFloat(t.into()))) } - /// Returns unix epoch. - pub fn unix_epoch() -> DateTime { - Utc.timestamp_nanos(0) - } - /// Creates date literal from number of days from unix epoch directly. pub fn date(days: i32) -> Self { - Self::Primitive(PrimitiveLiteral::Date(days)) - } - - /// Creates date literal from `NaiveDate`, assuming it's utc timezone. - fn date_from_naive_date(date: NaiveDate) -> Self { - let days = (date - Self::unix_epoch().date_naive()).num_days(); - Self::date(days as i32) + Self::Primitive(PrimitiveLiteral::Int(days)) } /// Creates a date in `%Y-%m-%d` format, assume in utc timezone. @@ -210,7 +1302,7 @@ impl Literal { .with_source(e) })?; - Ok(Self::date_from_naive_date(t)) + Ok(Self::date(date_from_naive_date(t))) } /// Create a date from calendar date (year, month and day). @@ -233,17 +1325,17 @@ impl Literal { ) })?; - Ok(Self::date_from_naive_date(t)) + Ok(Self::date(date_from_naive_date(t))) } /// Creates time in microseconds directly pub fn time(value: i64) -> Self { - Self::Primitive(PrimitiveLiteral::Time(value)) + Self::Primitive(PrimitiveLiteral::Long(value)) } /// Creates time literal from [`chrono::NaiveTime`]. fn time_from_naive_time(t: NaiveTime) -> Self { - let duration = t - Self::unix_epoch().time(); + let duration = t - unix_epoch().time(); // It's safe to unwrap here since less than 24 hours will never overflow. let micro_secs = duration.num_microseconds().unwrap(); @@ -263,7 +1355,7 @@ impl Literal { /// 1 * 3600 * 1_000_000 + // 1 hour /// 2 * 60 * 1_000_000 + // 2 minutes /// 1 * 1_000_000 + // 1 second - /// 888999 // microseconds + /// 888999 // microseconds /// }; /// assert_eq!(Literal::time(micro_secs), t); /// ``` @@ -285,7 +1377,6 @@ impl Literal { /// /// Example: /// ```rust - /// /// use iceberg::spec::Literal; /// let t = Literal::time_from_hms_micro(22, 15, 33, 111).unwrap(); /// @@ -302,12 +1393,12 @@ impl Literal { /// Creates a timestamp from unix epoch in microseconds. pub fn timestamp(value: i64) -> Self { - Self::Primitive(PrimitiveLiteral::Timestamp(value)) + Self::Primitive(PrimitiveLiteral::Long(value)) } /// Creates a timestamp with timezone from unix epoch in microseconds. pub fn timestamptz(value: i64) -> Self { - Self::Primitive(PrimitiveLiteral::TimestampTZ(value)) + Self::Primitive(PrimitiveLiteral::Long(value)) } /// Creates a timestamp from [`DateTime`]. 
@@ -332,10 +1423,13 @@ impl Literal { /// let t = Literal::timestamp_from_str("2012-12-12 12:12:12.8899-04:00").unwrap(); /// /// let t2 = { - /// let date = NaiveDate::from_ymd_opt(2012, 12, 12).unwrap(); - /// let time = NaiveTime::from_hms_micro_opt(12, 12, 12, 889900).unwrap(); - /// let dt = NaiveDateTime::new(date, time); - /// Literal::timestamp_from_datetime(DateTime::::from_local(dt, FixedOffset::west_opt(4 * 3600).unwrap())) + /// let date = NaiveDate::from_ymd_opt(2012, 12, 12).unwrap(); + /// let time = NaiveTime::from_hms_micro_opt(12, 12, 12, 889900).unwrap(); + /// let dt = NaiveDateTime::new(date, time); + /// Literal::timestamp_from_datetime(DateTime::::from_local( + /// dt, + /// FixedOffset::west_opt(4 * 3600).unwrap(), + /// )) /// }; /// /// assert_eq!(t, t2); @@ -364,7 +1458,7 @@ impl Literal { /// Creates uuid literal. pub fn uuid(uuid: Uuid) -> Self { - Self::Primitive(PrimitiveLiteral::UUID(uuid)) + Self::Primitive(PrimitiveLiteral::UInt128(uuid.as_u128())) } /// Creates uuid from str. See [`Uuid::parse_str`]. @@ -372,8 +1466,8 @@ impl Literal { /// Example: /// /// ```rust - /// use uuid::Uuid; /// use iceberg::spec::Literal; + /// use uuid::Uuid; /// let t1 = Literal::uuid_from_str("a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8").unwrap(); /// let t2 = Literal::uuid(Uuid::from_u128_le(0xd8d7d6d5d4d3d2d1c2c1b2b1a4a3a2a1)); /// @@ -397,12 +1491,12 @@ impl Literal { /// ```rust /// use iceberg::spec::{Literal, PrimitiveLiteral}; /// let t1 = Literal::fixed(vec![1u8, 2u8]); - /// let t2 = Literal::Primitive(PrimitiveLiteral::Fixed(vec![1u8, 2u8])); + /// let t2 = Literal::Primitive(PrimitiveLiteral::Binary(vec![1u8, 2u8])); /// /// assert_eq!(t1, t2); /// ``` pub fn fixed>(input: I) -> Self { - Literal::Primitive(PrimitiveLiteral::Fixed(input.into_iter().collect())) + Literal::Primitive(PrimitiveLiteral::Binary(input.into_iter().collect())) } /// Creates a binary literal from bytes. @@ -422,7 +1516,7 @@ impl Literal { /// Creates a decimal literal. pub fn decimal(decimal: i128) -> Self { - Self::Primitive(PrimitiveLiteral::Decimal(decimal)) + Self::Primitive(PrimitiveLiteral::Int128(decimal)) } /// Creates decimal literal from string. See [`Decimal::from_str_exact`]. 
@@ -430,8 +1524,8 @@ impl Literal { /// Example: /// /// ```rust - /// use rust_decimal::Decimal; /// use iceberg::spec::Literal; + /// use rust_decimal::Decimal; /// let t1 = Literal::decimal(12345); /// let t2 = Literal::decimal_from_str("123.45").unwrap(); /// @@ -445,70 +1539,10 @@ impl Literal { } } -impl From for ByteBuf { - fn from(value: Literal) -> Self { - match value { - Literal::Primitive(prim) => match prim { - PrimitiveLiteral::Boolean(val) => { - if val { - ByteBuf::from([1u8]) - } else { - ByteBuf::from([0u8]) - } - } - PrimitiveLiteral::Int(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Long(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Float(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Double(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Date(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Time(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamp(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::TimestampTZ(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::String(val) => ByteBuf::from(val.as_bytes()), - PrimitiveLiteral::UUID(val) => ByteBuf::from(val.as_u128().to_be_bytes()), - PrimitiveLiteral::Fixed(val) => ByteBuf::from(val), - PrimitiveLiteral::Binary(val) => ByteBuf::from(val), - PrimitiveLiteral::Decimal(_) => todo!(), - }, - _ => unimplemented!(), - } - } -} - -impl From for Vec { - fn from(value: Literal) -> Self { - match value { - Literal::Primitive(prim) => match prim { - PrimitiveLiteral::Boolean(val) => { - if val { - Vec::from([1u8]) - } else { - Vec::from([0u8]) - } - } - PrimitiveLiteral::Int(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Long(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Float(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Double(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Date(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Time(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamp(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::TimestampTZ(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::String(val) => Vec::from(val.as_bytes()), - PrimitiveLiteral::UUID(val) => Vec::from(val.as_u128().to_be_bytes()), - PrimitiveLiteral::Fixed(val) => val, - PrimitiveLiteral::Binary(val) => val, - PrimitiveLiteral::Decimal(_) => todo!(), - }, - _ => unimplemented!(), - } - } -} - /// The partition struct stores the tuple of partition values for each file. /// Its type is derived from the partition fields of the partition spec used to write the manifest file. /// In v2, the partition struct’s field ids must match the ids from the partition spec. -#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Struct { /// Vector to store the field values fields: Vec, @@ -537,6 +1571,19 @@ impl Struct { }, ) } + + /// returns true if the field at position `index` is null + pub fn is_null_at_index(&self, index: usize) -> bool { + self.null_bitmap[index] + } +} + +impl Index for Struct { + type Output = Literal; + + fn index(&self, idx: usize) -> &Self::Output { + &self.fields[idx] + } } /// An iterator that moves out of a struct. 
@@ -569,16 +1616,12 @@ impl IntoIterator for Struct { } } -impl FromIterator<(i32, Option, String)> for Struct { - fn from_iter, String)>>(iter: I) -> Self { +impl FromIterator> for Struct { + fn from_iter>>(iter: I) -> Self { let mut fields = Vec::new(); - let mut field_ids = Vec::new(); - let mut field_names = Vec::new(); let mut null_bitmap = BitVec::new(); - for (id, value, name) in iter.into_iter() { - field_ids.push(id); - field_names.push(name); + for value in iter.into_iter() { match value { Some(value) => { fields.push(value); @@ -598,66 +1641,9 @@ impl FromIterator<(i32, Option, String)> for Struct { } impl Literal { - /// Create iceberg value from bytes - pub fn try_from_bytes(bytes: &[u8], data_type: &Type) -> Result { - match data_type { - Type::Primitive(primitive) => match primitive { - PrimitiveType::Boolean => { - if bytes.len() == 1 && bytes[0] == 0u8 { - Ok(Literal::Primitive(PrimitiveLiteral::Boolean(false))) - } else { - Ok(Literal::Primitive(PrimitiveLiteral::Boolean(true))) - } - } - PrimitiveType::Int => Ok(Literal::Primitive(PrimitiveLiteral::Int( - i32::from_le_bytes(bytes.try_into()?), - ))), - PrimitiveType::Long => Ok(Literal::Primitive(PrimitiveLiteral::Long( - i64::from_le_bytes(bytes.try_into()?), - ))), - PrimitiveType::Float => Ok(Literal::Primitive(PrimitiveLiteral::Float( - OrderedFloat(f32::from_le_bytes(bytes.try_into()?)), - ))), - PrimitiveType::Double => Ok(Literal::Primitive(PrimitiveLiteral::Double( - OrderedFloat(f64::from_le_bytes(bytes.try_into()?)), - ))), - PrimitiveType::Date => Ok(Literal::Primitive(PrimitiveLiteral::Date( - i32::from_le_bytes(bytes.try_into()?), - ))), - PrimitiveType::Time => Ok(Literal::Primitive(PrimitiveLiteral::Time( - i64::from_le_bytes(bytes.try_into()?), - ))), - PrimitiveType::Timestamp => Ok(Literal::Primitive(PrimitiveLiteral::Timestamp( - i64::from_le_bytes(bytes.try_into()?), - ))), - PrimitiveType::Timestamptz => Ok(Literal::Primitive( - PrimitiveLiteral::TimestampTZ(i64::from_le_bytes(bytes.try_into()?)), - )), - PrimitiveType::String => Ok(Literal::Primitive(PrimitiveLiteral::String( - std::str::from_utf8(bytes)?.to_string(), - ))), - PrimitiveType::Uuid => Ok(Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::from_u128(u128::from_be_bytes(bytes.try_into()?)), - ))), - PrimitiveType::Fixed(_) => Ok(Literal::Primitive(PrimitiveLiteral::Fixed( - Vec::from(bytes), - ))), - PrimitiveType::Binary => Ok(Literal::Primitive(PrimitiveLiteral::Binary( - Vec::from(bytes), - ))), - PrimitiveType::Decimal { - precision: _, - scale: _, - } => todo!(), - }, - _ => Err(Error::new( - crate::ErrorKind::DataInvalid, - "Converting bytes to non-primitive types is not supported.", - )), - } - } - /// Create iceberg value from a json value + /// + /// See [this spec](https://iceberg.apache.org/spec/#json-single-value-serialization) for reference. 
pub fn try_from_json(value: JsonValue, data_type: &Type) -> Result> { match data_type { Type::Primitive(primitive) => match (primitive, value) { @@ -694,22 +1680,22 @@ impl Literal { ))?)), ))), (PrimitiveType::Date, JsonValue::String(s)) => { - Ok(Some(Literal::Primitive(PrimitiveLiteral::Date( + Ok(Some(Literal::Primitive(PrimitiveLiteral::Int( date::date_to_days(&NaiveDate::parse_from_str(&s, "%Y-%m-%d")?), )))) } (PrimitiveType::Time, JsonValue::String(s)) => { - Ok(Some(Literal::Primitive(PrimitiveLiteral::Time( + Ok(Some(Literal::Primitive(PrimitiveLiteral::Long( time::time_to_microseconds(&NaiveTime::parse_from_str(&s, "%H:%M:%S%.f")?), )))) } (PrimitiveType::Timestamp, JsonValue::String(s)) => Ok(Some(Literal::Primitive( - PrimitiveLiteral::Timestamp(timestamp::datetime_to_microseconds( + PrimitiveLiteral::Long(timestamp::datetime_to_microseconds( &NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.f")?, )), ))), (PrimitiveType::Timestamptz, JsonValue::String(s)) => { - Ok(Some(Literal::Primitive(PrimitiveLiteral::TimestampTZ( + Ok(Some(Literal::Primitive(PrimitiveLiteral::Long( timestamptz::datetimetz_to_microseconds(&Utc.from_utc_datetime( &NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.f+00:00")?, )), @@ -719,7 +1705,7 @@ impl Literal { Ok(Some(Literal::Primitive(PrimitiveLiteral::String(s)))) } (PrimitiveType::Uuid, JsonValue::String(s)) => Ok(Some(Literal::Primitive( - PrimitiveLiteral::UUID(Uuid::parse_str(&s)?), + PrimitiveLiteral::UInt128(Uuid::parse_str(&s)?.as_u128()), ))), (PrimitiveType::Fixed(_), JsonValue::String(_)) => todo!(), (PrimitiveType::Binary, JsonValue::String(_)) => todo!(), @@ -732,7 +1718,7 @@ impl Literal { ) => { let mut decimal = Decimal::from_str_exact(&s)?; decimal.rescale(*scale); - Ok(Some(Literal::Primitive(PrimitiveLiteral::Decimal( + Ok(Some(Literal::Primitive(PrimitiveLiteral::Int128( decimal.mantissa(), )))) } @@ -749,20 +1735,16 @@ impl Literal { if let JsonValue::Object(mut object) = value { Ok(Some(Literal::Struct(Struct::from_iter( schema.fields().iter().map(|field| { - ( - field.id, - object.remove(&field.id.to_string()).and_then(|value| { - Literal::try_from_json(value, &field.field_type) - .and_then(|value| { - value.ok_or(Error::new( - ErrorKind::DataInvalid, - "Key of map cannot be null", - )) - }) - .ok() - }), - field.name.clone(), - ) + object.remove(&field.id.to_string()).and_then(|value| { + Literal::try_from_json(value, &field.field_type) + .and_then(|value| { + value.ok_or(Error::new( + ErrorKind::DataInvalid, + "Key of map cannot be null", + )) + }) + .ok() + }) }), )))) } else { @@ -794,7 +1776,7 @@ impl Literal { if let (Some(JsonValue::Array(keys)), Some(JsonValue::Array(values))) = (object.remove("keys"), object.remove("values")) { - Ok(Some(Literal::Map(BTreeMap::from_iter( + Ok(Some(Literal::Map(Map::from_iter( keys.into_iter() .zip(values.into_iter()) .map(|(key, value)| { @@ -832,51 +1814,70 @@ impl Literal { /// See [this spec](https://iceberg.apache.org/spec/#json-single-value-serialization) for reference. 
pub fn try_into_json(self, r#type: &Type) -> Result { match (self, r#type) { - (Literal::Primitive(prim), _) => match prim { - PrimitiveLiteral::Boolean(val) => Ok(JsonValue::Bool(val)), - PrimitiveLiteral::Int(val) => Ok(JsonValue::Number((val).into())), - PrimitiveLiteral::Long(val) => Ok(JsonValue::Number((val).into())), - PrimitiveLiteral::Float(val) => match Number::from_f64(val.0 as f64) { - Some(number) => Ok(JsonValue::Number(number)), - None => Ok(JsonValue::Null), - }, - PrimitiveLiteral::Double(val) => match Number::from_f64(val.0) { - Some(number) => Ok(JsonValue::Number(number)), - None => Ok(JsonValue::Null), - }, - PrimitiveLiteral::Date(val) => { + (Literal::Primitive(prim), Type::Primitive(prim_type)) => match (prim_type, prim) { + (PrimitiveType::Boolean, PrimitiveLiteral::Boolean(val)) => { + Ok(JsonValue::Bool(val)) + } + (PrimitiveType::Int, PrimitiveLiteral::Int(val)) => { + Ok(JsonValue::Number((val).into())) + } + (PrimitiveType::Long, PrimitiveLiteral::Long(val)) => { + Ok(JsonValue::Number((val).into())) + } + (PrimitiveType::Float, PrimitiveLiteral::Float(val)) => { + match Number::from_f64(val.0 as f64) { + Some(number) => Ok(JsonValue::Number(number)), + None => Ok(JsonValue::Null), + } + } + (PrimitiveType::Double, PrimitiveLiteral::Double(val)) => { + match Number::from_f64(val.0) { + Some(number) => Ok(JsonValue::Number(number)), + None => Ok(JsonValue::Null), + } + } + (PrimitiveType::Date, PrimitiveLiteral::Int(val)) => { Ok(JsonValue::String(date::days_to_date(val).to_string())) } - PrimitiveLiteral::Time(val) => Ok(JsonValue::String( + (PrimitiveType::Time, PrimitiveLiteral::Long(val)) => Ok(JsonValue::String( time::microseconds_to_time(val).to_string(), )), - PrimitiveLiteral::Timestamp(val) => Ok(JsonValue::String( + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(val)) => Ok(JsonValue::String( timestamp::microseconds_to_datetime(val) .format("%Y-%m-%dT%H:%M:%S%.f") .to_string(), )), - PrimitiveLiteral::TimestampTZ(val) => Ok(JsonValue::String( + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(val)) => Ok(JsonValue::String( timestamptz::microseconds_to_datetimetz(val) .format("%Y-%m-%dT%H:%M:%S%.f+00:00") .to_string(), )), - PrimitiveLiteral::String(val) => Ok(JsonValue::String(val.clone())), - PrimitiveLiteral::UUID(val) => Ok(JsonValue::String(val.to_string())), - PrimitiveLiteral::Fixed(val) => Ok(JsonValue::String(val.iter().fold( - String::new(), - |mut acc, x| { - acc.push_str(&format!("{:x}", x)); - acc - }, - ))), - PrimitiveLiteral::Binary(val) => Ok(JsonValue::String(val.iter().fold( + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(val)) => Ok(JsonValue::String( + timestamp::nanoseconds_to_datetime(val) + .format("%Y-%m-%dT%H:%M:%S%.f") + .to_string(), + )), + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(val)) => { + Ok(JsonValue::String( + timestamptz::nanoseconds_to_datetimetz(val) + .format("%Y-%m-%dT%H:%M:%S%.f+00:00") + .to_string(), + )) + } + (PrimitiveType::String, PrimitiveLiteral::String(val)) => { + Ok(JsonValue::String(val.clone())) + } + (_, PrimitiveLiteral::UInt128(val)) => { + Ok(JsonValue::String(Uuid::from_u128(val).to_string())) + } + (_, PrimitiveLiteral::Binary(val)) => Ok(JsonValue::String(val.iter().fold( String::new(), |mut acc, x| { acc.push_str(&format!("{:x}", x)); acc }, ))), - PrimitiveLiteral::Decimal(val) => match r#type { + (_, PrimitiveLiteral::Int128(val)) => match r#type { Type::Primitive(PrimitiveType::Decimal { precision: _precision, scale, @@ -889,6 +1890,10 @@ impl Literal { "The 
iceberg type for decimal literal must be decimal.", ))?, }, + _ => Err(Error::new( + ErrorKind::DataInvalid, + "The iceberg value doesn't fit to the iceberg type.", + )), }, (Literal::Struct(s), Type::Struct(struct_type)) => { let mut id_and_value = Vec::with_capacity(struct_type.fields().len()); @@ -943,15 +1948,10 @@ impl Literal { PrimitiveLiteral::Long(any) => Box::new(any), PrimitiveLiteral::Float(any) => Box::new(any), PrimitiveLiteral::Double(any) => Box::new(any), - PrimitiveLiteral::Date(any) => Box::new(any), - PrimitiveLiteral::Time(any) => Box::new(any), - PrimitiveLiteral::Timestamp(any) => Box::new(any), - PrimitiveLiteral::TimestampTZ(any) => Box::new(any), - PrimitiveLiteral::Fixed(any) => Box::new(any), PrimitiveLiteral::Binary(any) => Box::new(any), PrimitiveLiteral::String(any) => Box::new(any), - PrimitiveLiteral::UUID(any) => Box::new(any), - PrimitiveLiteral::Decimal(any) => Box::new(any), + PrimitiveLiteral::UInt128(any) => Box::new(any), + PrimitiveLiteral::Int128(any) => Box::new(any), }, _ => unimplemented!(), } @@ -959,7 +1959,7 @@ impl Literal { } mod date { - use chrono::{NaiveDate, NaiveDateTime}; + use chrono::{DateTime, NaiveDate, TimeDelta, TimeZone, Utc}; pub(crate) fn date_to_days(date: &NaiveDate) -> i32 { date.signed_duration_since( @@ -971,10 +1971,20 @@ mod date { pub(crate) fn days_to_date(days: i32) -> NaiveDate { // This shouldn't fail until the year 262000 - NaiveDateTime::from_timestamp_opt(days as i64 * 86_400, 0) - .unwrap() + (chrono::DateTime::UNIX_EPOCH + TimeDelta::try_days(days as i64).unwrap()) + .naive_utc() .date() } + + /// Returns unix epoch. + pub(crate) fn unix_epoch() -> DateTime { + Utc.timestamp_nanos(0) + } + + /// Creates date literal from `NaiveDate`, assuming it's utc timezone. + pub(crate) fn date_from_naive_date(date: NaiveDate) -> i32 { + (date - unix_epoch().date_naive()).num_days() as i32 + } } mod time { @@ -997,22 +2007,24 @@ mod time { } mod timestamp { - use chrono::NaiveDateTime; + use chrono::{DateTime, NaiveDateTime}; pub(crate) fn datetime_to_microseconds(time: &NaiveDateTime) -> i64 { - time.timestamp_micros() + time.and_utc().timestamp_micros() } pub(crate) fn microseconds_to_datetime(micros: i64) -> NaiveDateTime { - let (secs, rem) = (micros / 1_000_000, micros % 1_000_000); - // This shouldn't fail until the year 262000 - NaiveDateTime::from_timestamp_opt(secs, rem as u32 * 1_000).unwrap() + DateTime::from_timestamp_micros(micros).unwrap().naive_utc() + } + + pub(crate) fn nanoseconds_to_datetime(nanos: i64) -> NaiveDateTime { + DateTime::from_timestamp_nanos(nanos).naive_utc() } } mod timestamptz { - use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; + use chrono::{DateTime, Utc}; pub(crate) fn datetimetz_to_microseconds(time: &DateTime) -> i64 { time.timestamp_micros() @@ -1021,30 +2033,26 @@ mod timestamptz { pub(crate) fn microseconds_to_datetimetz(micros: i64) -> DateTime { let (secs, rem) = (micros / 1_000_000, micros % 1_000_000); - Utc.from_utc_datetime( - // This shouldn't fail until the year 262000 - &NaiveDateTime::from_timestamp_opt(secs, rem as u32 * 1_000).unwrap(), - ) + DateTime::from_timestamp(secs, rem as u32 * 1_000).unwrap() + } + + pub(crate) fn nanoseconds_to_datetimetz(nanos: i64) -> DateTime { + let (secs, rem) = (nanos / 1_000_000_000, nanos % 1_000_000_000); + + DateTime::from_timestamp(secs, rem as u32).unwrap() } } mod _serde { - use std::collections::BTreeMap; - - use crate::{ - spec::{PrimitiveType, Type, MAP_KEY_FIELD_NAME, MAP_VALUE_FIELD_NAME}, - Error, ErrorKind, - }; - - 
use super::{Literal, PrimitiveLiteral}; - use serde::{ - de::Visitor, - ser::{SerializeMap, SerializeSeq, SerializeStruct}, - Deserialize, Serialize, - }; + use serde::de::Visitor; + use serde::ser::{SerializeMap, SerializeSeq, SerializeStruct}; + use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; - use serde_derive::Deserialize as DeserializeDerive; - use serde_derive::Serialize as SerializeDerive; + use serde_derive::{Deserialize as DeserializeDerive, Serialize as SerializeDerive}; + + use super::{Literal, Map, PrimitiveLiteral}; + use crate::spec::{PrimitiveType, Type, MAP_KEY_FIELD_NAME, MAP_VALUE_FIELD_NAME}; + use crate::{Error, ErrorKind}; #[derive(SerializeDerive, DeserializeDerive, Debug)] #[serde(transparent)] @@ -1087,9 +2095,7 @@ mod _serde { impl Serialize for Record { fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + where S: serde::Serializer { let len = self.required.len() + self.optional.len(); let mut record = serializer.serialize_struct("", len)?; for (k, v) in &self.required { @@ -1110,9 +2116,7 @@ mod _serde { impl Serialize for List { fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + where S: serde::Serializer { let mut seq = serializer.serialize_seq(Some(self.list.len()))?; for value in &self.list { if self.required { @@ -1137,9 +2141,7 @@ mod _serde { impl Serialize for StringMap { fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + where S: serde::Serializer { let mut map = serializer.serialize_map(Some(self.raw.len()))?; for (k, v) in &self.raw { if self.required { @@ -1161,9 +2163,7 @@ mod _serde { impl<'de> Deserialize<'de> for RawLiteralEnum { fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { + where D: serde::Deserializer<'de> { struct RawLiteralVisitor; impl<'de> Visitor<'de> for RawLiteralVisitor { type Value = RawLiteralEnum; @@ -1173,80 +2173,58 @@ mod _serde { } fn visit_bool(self, v: bool) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Boolean(v)) } fn visit_i32(self, v: i32) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Int(v)) } fn visit_i64(self, v: i64) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Long(v)) } /// Used in json fn visit_u64(self, v: u64) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Long(v as i64)) } fn visit_f32(self, v: f32) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Float(v)) } fn visit_f64(self, v: f64) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Double(v)) } fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::String(v.to_string())) } fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Bytes(ByteBuf::from(v))) } fn visit_borrowed_str(self, v: &'de str) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::String(v.to_string())) } fn visit_unit(self) -> Result - where - E: serde::de::Error, - { + where E: serde::de::Error { Ok(RawLiteralEnum::Null) } fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { + where A: serde::de::MapAccess<'de> { let mut required = Vec::new(); while let Some(key) = 
map.next_key::()? { let value = map.next_value::()?; @@ -1259,9 +2237,7 @@ mod _serde { } fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { + where A: serde::de::SeqAccess<'de> { let mut list = Vec::new(); while let Some(value) = seq.next_element::()? { list.push(Some(value)); @@ -1286,17 +2262,12 @@ mod _serde { super::PrimitiveLiteral::Long(v) => RawLiteralEnum::Long(v), super::PrimitiveLiteral::Float(v) => RawLiteralEnum::Float(v.0), super::PrimitiveLiteral::Double(v) => RawLiteralEnum::Double(v.0), - super::PrimitiveLiteral::Date(v) => RawLiteralEnum::Int(v), - super::PrimitiveLiteral::Time(v) => RawLiteralEnum::Long(v), - super::PrimitiveLiteral::Timestamp(v) => RawLiteralEnum::Long(v), - super::PrimitiveLiteral::TimestampTZ(v) => RawLiteralEnum::Long(v), super::PrimitiveLiteral::String(v) => RawLiteralEnum::String(v), - super::PrimitiveLiteral::UUID(v) => { - RawLiteralEnum::Bytes(ByteBuf::from(v.as_u128().to_be_bytes())) + super::PrimitiveLiteral::UInt128(v) => { + RawLiteralEnum::Bytes(ByteBuf::from(v.to_be_bytes())) } - super::PrimitiveLiteral::Fixed(v) => RawLiteralEnum::Bytes(ByteBuf::from(v)), super::PrimitiveLiteral::Binary(v) => RawLiteralEnum::Bytes(ByteBuf::from(v)), - super::PrimitiveLiteral::Decimal(v) => { + super::PrimitiveLiteral::Int128(v) => { RawLiteralEnum::Bytes(ByteBuf::from(v.to_be_bytes())) } }, @@ -1452,10 +2423,17 @@ mod _serde { RawLiteralEnum::Boolean(v) => Ok(Some(Literal::bool(v))), RawLiteralEnum::Int(v) => match ty { Type::Primitive(PrimitiveType::Int) => Ok(Some(Literal::int(v))), + Type::Primitive(PrimitiveType::Long) => Ok(Some(Literal::long(i64::from(v)))), Type::Primitive(PrimitiveType::Date) => Ok(Some(Literal::date(v))), _ => Err(invalid_err("int")), }, RawLiteralEnum::Long(v) => match ty { + Type::Primitive(PrimitiveType::Int) => Ok(Some(Literal::int( + i32::try_from(v).map_err(|_| invalid_err("long"))?, + ))), + Type::Primitive(PrimitiveType::Date) => Ok(Some(Literal::date( + i32::try_from(v).map_err(|_| invalid_err("long"))?, + ))), Type::Primitive(PrimitiveType::Long) => Ok(Some(Literal::long(v))), Type::Primitive(PrimitiveType::Time) => Ok(Some(Literal::time(v))), Type::Primitive(PrimitiveType::Timestamp) => Ok(Some(Literal::timestamp(v))), @@ -1466,9 +2444,23 @@ mod _serde { }, RawLiteralEnum::Float(v) => match ty { Type::Primitive(PrimitiveType::Float) => Ok(Some(Literal::float(v))), + Type::Primitive(PrimitiveType::Double) => { + Ok(Some(Literal::double(f64::from(v)))) + } _ => Err(invalid_err("float")), }, RawLiteralEnum::Double(v) => match ty { + Type::Primitive(PrimitiveType::Float) => { + let v_32 = v as f32; + if v_32.is_finite() { + let v_64 = f64::from(v_32); + if (v_64 - v).abs() > f32::EPSILON as f64 { + // there is a precision loss + return Err(invalid_err("double")); + } + } + Ok(Some(Literal::float(v_32))) + } Type::Primitive(PrimitiveType::Double) => Ok(Some(Literal::double(v))), _ => Err(invalid_err("double")), }, @@ -1499,7 +2491,7 @@ mod _serde { Type::Map(map_ty) => { let key_ty = map_ty.key_field.field_type.as_ref(); let value_ty = map_ty.value_field.field_type.as_ref(); - let mut map = BTreeMap::new(); + let mut map = Map::new(); for k_v in v.list { let k_v = k_v.ok_or_else(|| invalid_err_with_reason("list","In deserialize, None will be represented as Some(RawLiteral::Null), all element in list must be valid"))?; if let RawLiteralEnum::Record(Record { @@ -1550,6 +2542,89 @@ mod _serde { } Ok(Some(Literal::Map(map))) } + Type::Primitive(PrimitiveType::Uuid) => { + if v.list.len() != 16 
{ + return Err(invalid_err_with_reason( + "list", + "The length of list should be 16", + )); + } + let mut bytes = [0u8; 16]; + for (i, v) in v.list.iter().enumerate() { + if let Some(RawLiteralEnum::Long(v)) = v { + bytes[i] = *v as u8; + } else { + return Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )); + } + } + Ok(Some(Literal::uuid(uuid::Uuid::from_bytes(bytes)))) + } + Type::Primitive(PrimitiveType::Decimal { + precision: _, + scale: _, + }) => { + if v.list.len() != 16 { + return Err(invalid_err_with_reason( + "list", + "The length of list should be 16", + )); + } + let mut bytes = [0u8; 16]; + for (i, v) in v.list.iter().enumerate() { + if let Some(RawLiteralEnum::Long(v)) = v { + bytes[i] = *v as u8; + } else { + return Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )); + } + } + Ok(Some(Literal::decimal(i128::from_be_bytes(bytes)))) + } + Type::Primitive(PrimitiveType::Binary) => { + let bytes = v + .list + .into_iter() + .map(|v| { + if let Some(RawLiteralEnum::Long(v)) = v { + Ok(v as u8) + } else { + Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )) + } + }) + .collect::, Error>>()?; + Ok(Some(Literal::binary(bytes))) + } + Type::Primitive(PrimitiveType::Fixed(size)) => { + if v.list.len() != *size as usize { + return Err(invalid_err_with_reason( + "list", + "The length of list should be equal to size", + )); + } + let bytes = v + .list + .into_iter() + .map(|v| { + if let Some(RawLiteralEnum::Long(v)) = v { + Ok(v as u8) + } else { + Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )) + } + }) + .collect::, Error>>()?; + Ok(Some(Literal::fixed(bytes))) + } _ => Err(invalid_err("list")), } } @@ -1558,7 +2633,7 @@ mod _serde { optional: _, }) => match ty { Type::Struct(struct_ty) => { - let iters: Vec<(i32, Option, String)> = required + let iters: Vec> = required .into_iter() .map(|(field_name, value)| { let field = struct_ty @@ -1570,7 +2645,7 @@ mod _serde { ) })?; let value = value.try_into(&field.field_type)?; - Ok((field.id, value, field.name.clone())) + Ok(value) }) .collect::>()?; Ok(Some(Literal::Struct(super::Struct::from_iter(iters)))) @@ -1582,7 +2657,7 @@ mod _serde { "Map key must be string", )); } - let mut map = BTreeMap::new(); + let mut map = Map::new(); for (k, v) in required { let value = v.try_into(&map_ty.value_field.field_type)?; if map_ty.value_field.required && value.is_none() { @@ -1605,18 +2680,14 @@ mod _serde { #[cfg(test)] mod tests { - - use apache_avro::{to_value, types::Value}; - - use crate::{ - avro::schema_to_avro_schema, - spec::{ - datatypes::{ListType, MapType, NestedField, StructType}, - Schema, - }, - }; + use apache_avro::to_value; + use apache_avro::types::Value; use super::*; + use crate::avro::schema_to_avro_schema; + use crate::spec::datatypes::{ListType, MapType, NestedField, StructType}; + use crate::spec::Schema; + use crate::spec::Type::Primitive; fn check_json_serde(json: &str, expected_literal: Literal, expected_type: &Type) { let raw_json_value = serde_json::from_str::(json).unwrap(); @@ -1631,23 +2702,27 @@ mod tests { assert_eq!(parsed_json_value, raw_json_value); } - fn check_avro_bytes_serde(input: Vec, expected_literal: Literal, expected_type: &Type) { + fn check_avro_bytes_serde( + input: Vec, + expected_datum: Datum, + expected_type: &PrimitiveType, + ) { let raw_schema = r#""bytes""#; let schema = apache_avro::Schema::parse_str(raw_schema).unwrap(); let bytes = ByteBuf::from(input); - let 
literal = Literal::try_from_bytes(&bytes, expected_type).unwrap(); - assert_eq!(literal, expected_literal); + let datum = Datum::try_from_bytes(&bytes, expected_type.clone()).unwrap(); + assert_eq!(datum, expected_datum); let mut writer = apache_avro::Writer::new(&schema, Vec::new()); - writer.append_ser(ByteBuf::from(literal)).unwrap(); + writer.append_ser(datum.to_bytes()).unwrap(); let encoded = writer.into_inner().unwrap(); let reader = apache_avro::Reader::with_schema(&schema, &*encoded).unwrap(); for record in reader { let result = apache_avro::from_value::(&record.unwrap()).unwrap(); - let desered_literal = Literal::try_from_bytes(&result, expected_type).unwrap(); - assert_eq!(desered_literal, expected_literal); + let desered_datum = Datum::try_from_bytes(&result, expected_type.clone()).unwrap(); + assert_eq!(desered_datum, expected_datum); } } @@ -1659,11 +2734,8 @@ mod tests { .unwrap(); let avro_schema = schema_to_avro_schema("test", &schema).unwrap(); let struct_type = Type::Struct(StructType::new(fields)); - let struct_literal = Literal::Struct(Struct::from_iter(vec![( - 1, - Some(expected_literal.clone()), - "col".to_string(), - )])); + let struct_literal = + Literal::Struct(Struct::from_iter(vec![Some(expected_literal.clone())])); let mut writer = apache_avro::Writer::new(&avro_schema, Vec::new()); let raw_literal = RawLiteral::try_from(struct_literal.clone(), &struct_type).unwrap(); @@ -1688,11 +2760,7 @@ mod tests { .unwrap(); let avro_schema = schema_to_avro_schema("test", &schema).unwrap(); let struct_type = Type::Struct(StructType::new(fields)); - let struct_literal = Literal::Struct(Struct::from_iter(vec![( - 1, - Some(literal.clone()), - "col".to_string(), - )])); + let struct_literal = Literal::Struct(Struct::from_iter(vec![Some(literal.clone())])); let mut writer = apache_avro::Writer::new(&avro_schema, Vec::new()); let raw_literal = RawLiteral::try_from(struct_literal.clone(), &struct_type).unwrap(); let value = to_value(raw_literal) @@ -1769,7 +2837,7 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::Date(17486)), + Literal::Primitive(PrimitiveLiteral::Int(17486)), &Type::Primitive(PrimitiveType::Date), ); } @@ -1780,7 +2848,7 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::Time(81068123456)), + Literal::Primitive(PrimitiveLiteral::Long(81068123456)), &Type::Primitive(PrimitiveType::Time), ); } @@ -1791,7 +2859,7 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::Timestamp(1510871468123456)), + Literal::Primitive(PrimitiveLiteral::Long(1510871468123456)), &Type::Primitive(PrimitiveType::Timestamp), ); } @@ -1802,7 +2870,7 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::TimestampTZ(1510871468123456)), + Literal::Primitive(PrimitiveLiteral::Long(1510871468123456)), &Type::Primitive(PrimitiveType::Timestamptz), ); } @@ -1824,8 +2892,10 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::UUID( - Uuid::parse_str("f79c3e09-677c-4bbd-a479-3f349cb785e7").unwrap(), + Literal::Primitive(PrimitiveLiteral::UInt128( + Uuid::parse_str("f79c3e09-677c-4bbd-a479-3f349cb785e7") + .unwrap() + .as_u128(), )), &Type::Primitive(PrimitiveType::Uuid), ); @@ -1837,7 +2907,7 @@ mod tests { check_json_serde( record, - Literal::Primitive(PrimitiveLiteral::Decimal(1420)), + Literal::Primitive(PrimitiveLiteral::Int128(1420)), &Type::decimal(28, 2).unwrap(), ); } @@ -1849,19 +2919,11 @@ mod tests { check_json_serde( record, 
Literal::Struct(Struct::from_iter(vec![ - ( - 1, - Some(Literal::Primitive(PrimitiveLiteral::Int(1))), - "id".to_string(), - ), - ( - 2, - Some(Literal::Primitive(PrimitiveLiteral::String( - "bar".to_string(), - ))), - "name".to_string(), - ), - (3, None, "address".to_string()), + Some(Literal::Primitive(PrimitiveLiteral::Int(1))), + Some(Literal::Primitive(PrimitiveLiteral::String( + "bar".to_string(), + ))), + None, ])), &Type::Struct(StructType::new(vec![ NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), @@ -1900,7 +2962,7 @@ mod tests { check_json_serde( record, - Literal::Map(BTreeMap::from([ + Literal::Map(Map::from([ ( Literal::Primitive(PrimitiveLiteral::String("a".to_string())), Some(Literal::Primitive(PrimitiveLiteral::Int(1))), @@ -1931,66 +2993,42 @@ mod tests { fn avro_bytes_boolean() { let bytes = vec![1u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Boolean(true)), - &Type::Primitive(PrimitiveType::Boolean), - ); + check_avro_bytes_serde(bytes, Datum::bool(true), &PrimitiveType::Boolean); } #[test] fn avro_bytes_int() { let bytes = vec![32u8, 0u8, 0u8, 0u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Int(32)), - &Type::Primitive(PrimitiveType::Int), - ); + check_avro_bytes_serde(bytes, Datum::int(32), &PrimitiveType::Int); } #[test] fn avro_bytes_long() { let bytes = vec![32u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Long(32)), - &Type::Primitive(PrimitiveType::Long), - ); + check_avro_bytes_serde(bytes, Datum::long(32), &PrimitiveType::Long); } #[test] fn avro_bytes_float() { let bytes = vec![0u8, 0u8, 128u8, 63u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Float(OrderedFloat(1.0))), - &Type::Primitive(PrimitiveType::Float), - ); + check_avro_bytes_serde(bytes, Datum::float(1.0), &PrimitiveType::Float); } #[test] fn avro_bytes_double() { let bytes = vec![0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 240u8, 63u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Double(OrderedFloat(1.0))), - &Type::Primitive(PrimitiveType::Double), - ); + check_avro_bytes_serde(bytes, Datum::double(1.0), &PrimitiveType::Double); } #[test] fn avro_bytes_string() { let bytes = vec![105u8, 99u8, 101u8, 98u8, 101u8, 114u8, 103u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::String("iceberg".to_string())), - &Type::Primitive(PrimitiveType::String), - ); + check_avro_bytes_serde(bytes, Datum::string("iceberg"), &PrimitiveType::String); } #[test] @@ -2036,7 +3074,7 @@ mod tests { #[test] fn avro_convert_test_date() { check_convert_with_avro( - Literal::Primitive(PrimitiveLiteral::Date(17486)), + Literal::Primitive(PrimitiveLiteral::Int(17486)), &Type::Primitive(PrimitiveType::Date), ); } @@ -2044,7 +3082,7 @@ mod tests { #[test] fn avro_convert_test_time() { check_convert_with_avro( - Literal::Primitive(PrimitiveLiteral::Time(81068123456)), + Literal::Primitive(PrimitiveLiteral::Long(81068123456)), &Type::Primitive(PrimitiveType::Time), ); } @@ -2052,7 +3090,7 @@ mod tests { #[test] fn avro_convert_test_timestamp() { check_convert_with_avro( - Literal::Primitive(PrimitiveLiteral::Timestamp(1510871468123456)), + Literal::Primitive(PrimitiveLiteral::Long(1510871468123456)), &Type::Primitive(PrimitiveType::Timestamp), ); } @@ -2060,7 +3098,7 @@ mod tests { #[test] fn avro_convert_test_timestamptz() { check_convert_with_avro( - 
Literal::Primitive(PrimitiveLiteral::TimestampTZ(1510871468123456)), + Literal::Primitive(PrimitiveLiteral::Long(1510871468123456)), &Type::Primitive(PrimitiveType::Timestamptz), ); } @@ -2101,10 +3139,48 @@ mod tests { ); } + fn check_convert_with_avro_map(expected_literal: Literal, expected_type: &Type) { + let fields = vec![NestedField::required(1, "col", expected_type.clone()).into()]; + let schema = Schema::builder() + .with_fields(fields.clone()) + .build() + .unwrap(); + let avro_schema = schema_to_avro_schema("test", &schema).unwrap(); + let struct_type = Type::Struct(StructType::new(fields)); + let struct_literal = + Literal::Struct(Struct::from_iter(vec![Some(expected_literal.clone())])); + + let mut writer = apache_avro::Writer::new(&avro_schema, Vec::new()); + let raw_literal = RawLiteral::try_from(struct_literal.clone(), &struct_type).unwrap(); + writer.append_ser(raw_literal).unwrap(); + let encoded = writer.into_inner().unwrap(); + + let reader = apache_avro::Reader::new(&*encoded).unwrap(); + for record in reader { + let result = apache_avro::from_value::(&record.unwrap()).unwrap(); + let desered_literal = result.try_into(&struct_type).unwrap().unwrap(); + match (&desered_literal, &struct_literal) { + (Literal::Struct(desered), Literal::Struct(expected)) => { + match (&desered.fields[0], &expected.fields[0]) { + (Literal::Map(desered), Literal::Map(expected)) => { + assert!(desered.has_same_content(expected)) + } + _ => { + unreachable!() + } + } + } + _ => { + panic!("unexpected literal type"); + } + } + } + } + #[test] fn avro_convert_test_map() { - check_convert_with_avro( - Literal::Map(BTreeMap::from([ + check_convert_with_avro_map( + Literal::Map(Map::from([ ( Literal::Primitive(PrimitiveLiteral::Int(1)), Some(Literal::Primitive(PrimitiveLiteral::Long(1))), @@ -2127,8 +3203,8 @@ mod tests { }), ); - check_convert_with_avro( - Literal::Map(BTreeMap::from([ + check_convert_with_avro_map( + Literal::Map(Map::from([ ( Literal::Primitive(PrimitiveLiteral::Int(1)), Some(Literal::Primitive(PrimitiveLiteral::Long(1))), @@ -2157,8 +3233,8 @@ mod tests { #[test] fn avro_convert_test_string_map() { - check_convert_with_avro( - Literal::Map(BTreeMap::from([ + check_convert_with_avro_map( + Literal::Map(Map::from([ ( Literal::Primitive(PrimitiveLiteral::String("a".to_string())), Some(Literal::Primitive(PrimitiveLiteral::Int(1))), @@ -2184,8 +3260,8 @@ mod tests { }), ); - check_convert_with_avro( - Literal::Map(BTreeMap::from([ + check_convert_with_avro_map( + Literal::Map(Map::from([ ( Literal::Primitive(PrimitiveLiteral::String("a".to_string())), Some(Literal::Primitive(PrimitiveLiteral::Int(1))), @@ -2216,19 +3292,11 @@ mod tests { fn avro_convert_test_record() { check_convert_with_avro( Literal::Struct(Struct::from_iter(vec![ - ( - 1, - Some(Literal::Primitive(PrimitiveLiteral::Int(1))), - "id".to_string(), - ), - ( - 2, - Some(Literal::Primitive(PrimitiveLiteral::String( - "bar".to_string(), - ))), - "name".to_string(), - ), - (3, None, "address".to_string()), + Some(Literal::Primitive(PrimitiveLiteral::Int(1))), + Some(Literal::Primitive(PrimitiveLiteral::String( + "bar".to_string(), + ))), + None, ])), &Type::Struct(StructType::new(vec![ NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), @@ -2265,4 +3333,133 @@ mod tests { // rust avro can't support to convert any byte-like type to fixed in avro now. 
// - uuid ser/de // - fixed ser/de + + #[test] + fn test_parse_timestamp() { + let value = Datum::timestamp_from_str("2021-08-01T01:09:00.0899").unwrap(); + assert_eq!(&format!("{value}"), "2021-08-01 01:09:00.089900"); + + let value = Datum::timestamp_from_str("2021-08-01T01:09:00.0899+0800"); + assert!(value.is_err(), "Parse timestamp with timezone should fail!"); + + let value = Datum::timestamp_from_str("dfa"); + assert!( + value.is_err(), + "Parse timestamp with invalid input should fail!" + ); + } + + #[test] + fn test_parse_timestamptz() { + let value = Datum::timestamptz_from_str("2021-08-01T09:09:00.0899+0800").unwrap(); + assert_eq!(&format!("{value}"), "2021-08-01 01:09:00.089900 UTC"); + + let value = Datum::timestamptz_from_str("2021-08-01T01:09:00.0899"); + assert!( + value.is_err(), + "Parse timestamptz without timezone should fail!" + ); + + let value = Datum::timestamptz_from_str("dfa"); + assert!( + value.is_err(), + "Parse timestamptz with invalid input should fail!" + ); + } + + #[test] + fn test_datum_ser_deser() { + let test_fn = |datum: Datum| { + let json = serde_json::to_value(&datum).unwrap(); + let desered_datum: Datum = serde_json::from_value(json).unwrap(); + assert_eq!(datum, desered_datum); + }; + let datum = Datum::int(1); + test_fn(datum); + let datum = Datum::long(1); + test_fn(datum); + + let datum = Datum::float(1.0); + test_fn(datum); + let datum = Datum::float(0_f32); + test_fn(datum); + let datum = Datum::float(-0_f32); + test_fn(datum); + let datum = Datum::float(f32::MAX); + test_fn(datum); + let datum = Datum::float(f32::MIN); + test_fn(datum); + + // serde_json can't serialize f32::INFINITY, f32::NEG_INFINITY, f32::NAN + let datum = Datum::float(f32::INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::float(f32::NEG_INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::float(f32::NAN); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + + let datum = Datum::double(1.0); + test_fn(datum); + let datum = Datum::double(f64::MAX); + test_fn(datum); + let datum = Datum::double(f64::MIN); + test_fn(datum); + + // serde_json can't serialize f32::INFINITY, f32::NEG_INFINITY, f32::NAN + let datum = Datum::double(f64::INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::double(f64::NEG_INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::double(f64::NAN); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + + let datum = Datum::string("iceberg"); + test_fn(datum); + let datum = Datum::bool(true); + test_fn(datum); + let datum = Datum::date(17486); + test_fn(datum); + let datum = Datum::time_from_hms_micro(22, 15, 33, 111).unwrap(); + test_fn(datum); + let datum = Datum::timestamp_micros(1510871468123456); + test_fn(datum); + let datum = Datum::timestamptz_micros(1510871468123456); + test_fn(datum); + let datum = Datum::uuid(Uuid::parse_str("f79c3e09-677c-4bbd-a479-3f349cb785e7").unwrap()); + test_fn(datum); + let datum = Datum::decimal(1420).unwrap(); + test_fn(datum); + let datum = Datum::binary(vec![1, 2, 3, 4, 5]); + test_fn(datum); + let datum = Datum::fixed(vec![1, 2, 3, 4, 5]); + test_fn(datum); + } + + #[test] + fn 
test_datum_date_convert_to_int() { + let datum_date = Datum::date(12345); + + let result = datum_date.to(&Primitive(PrimitiveType::Int)).unwrap(); + + let expected = Datum::int(12345); + + assert_eq!(result, expected); + } + + #[test] + fn test_datum_int_convert_to_date() { + let datum_int = Datum::int(12345); + + let result = datum_int.to(&Primitive(PrimitiveType::Date)).unwrap(); + + let expected = Datum::date(12345); + + assert_eq!(result, expected); + } } diff --git a/crates/iceberg/src/spec/view_metadata.rs b/crates/iceberg/src/spec/view_metadata.rs new file mode 100644 index 000000000..741e38649 --- /dev/null +++ b/crates/iceberg/src/spec/view_metadata.rs @@ -0,0 +1,728 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the [view metadata](https://iceberg.apache.org/view-spec/#view-metadata). +//! The main struct here is [ViewMetadata] which defines the data for a view. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use _serde::ViewMetadataEnum; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use uuid::Uuid; + +use super::view_version::{ViewVersion, ViewVersionId, ViewVersionRef}; +use super::{SchemaId, SchemaRef}; +use crate::catalog::ViewCreation; +use crate::error::{timestamp_ms_to_utc, Result}; + +/// Reference to [`ViewMetadata`]. +pub type ViewMetadataRef = Arc; + +pub(crate) static INITIAL_VIEW_VERSION_ID: i32 = 1; + +#[derive(Debug, PartialEq, Deserialize, Eq, Clone)] +#[serde(try_from = "ViewMetadataEnum", into = "ViewMetadataEnum")] +/// Fields for the version 1 of the view metadata. +/// +/// We assume that this data structure is always valid, so we will panic when invalid error happens. +/// We check the validity of this data structure when constructing. +pub struct ViewMetadata { + /// Integer Version for the format. + pub(crate) format_version: ViewFormatVersion, + /// A UUID that identifies the view, generated when the view is created. + pub(crate) view_uuid: Uuid, + /// The view's base location; used to create metadata file locations + pub(crate) location: String, + /// ID of the current version of the view (version-id) + pub(crate) current_version_id: ViewVersionId, + /// A list of known versions of the view + pub(crate) versions: HashMap, + /// A list of version log entries with the timestamp and version-id for every + /// change to current-version-id + pub(crate) version_log: Vec, + /// A list of schemas, stored as objects with schema-id. + pub(crate) schemas: HashMap, + /// A string to string map of view properties. + /// Properties are used for metadata such as comment and for settings that + /// affect view maintenance. 
This is not intended to be used for arbitrary metadata. + pub(crate) properties: HashMap, +} + +impl ViewMetadata { + /// Returns format version of this metadata. + #[inline] + pub fn format_version(&self) -> ViewFormatVersion { + self.format_version + } + + /// Returns uuid of current view. + #[inline] + pub fn uuid(&self) -> Uuid { + self.view_uuid + } + + /// Returns view location. + #[inline] + pub fn location(&self) -> &str { + self.location.as_str() + } + + /// Returns the current version id. + #[inline] + pub fn current_version_id(&self) -> ViewVersionId { + self.current_version_id + } + + /// Returns all view versions. + #[inline] + pub fn versions(&self) -> impl Iterator { + self.versions.values() + } + + /// Lookup a view version by id. + #[inline] + pub fn version_by_id(&self, version_id: ViewVersionId) -> Option<&ViewVersionRef> { + self.versions.get(&version_id) + } + + /// Returns the current view version. + #[inline] + pub fn current_version(&self) -> &ViewVersionRef { + self.versions + .get(&self.current_version_id) + .expect("Current version id set, but not found in view versions") + } + + /// Returns schemas + #[inline] + pub fn schemas_iter(&self) -> impl Iterator { + self.schemas.values() + } + + /// Lookup schema by id. + #[inline] + pub fn schema_by_id(&self, schema_id: SchemaId) -> Option<&SchemaRef> { + self.schemas.get(&schema_id) + } + + /// Get current schema + #[inline] + pub fn current_schema(&self) -> &SchemaRef { + let schema_id = self.current_version().schema_id(); + self.schema_by_id(schema_id) + .expect("Current schema id set, but not found in view metadata") + } + + /// Returns properties of the view. + #[inline] + pub fn properties(&self) -> &HashMap { + &self.properties + } + + /// Returns view history. + #[inline] + pub fn history(&self) -> &[ViewVersionLog] { + &self.version_log + } +} + +/// Manipulating view metadata. +pub struct ViewMetadataBuilder(ViewMetadata); + +impl ViewMetadataBuilder { + /// Creates a new view metadata builder from the given view metadata. + pub fn new(origin: ViewMetadata) -> Self { + Self(origin) + } + + /// Creates a new view metadata builder from the given view creation. + pub fn from_view_creation(view_creation: ViewCreation) -> Result { + let ViewCreation { + location, + schema, + properties, + name: _, + representations, + default_catalog, + default_namespace, + summary, + } = view_creation; + let initial_version_id = super::INITIAL_VIEW_VERSION_ID; + let version = ViewVersion::builder() + .with_default_catalog(default_catalog) + .with_default_namespace(default_namespace) + .with_representations(representations) + .with_schema_id(schema.schema_id()) + .with_summary(summary) + .with_timestamp_ms(Utc::now().timestamp_millis()) + .with_version_id(initial_version_id) + .build(); + + let versions = HashMap::from_iter(vec![(initial_version_id, version.into())]); + + let view_metadata = ViewMetadata { + format_version: ViewFormatVersion::V1, + view_uuid: Uuid::now_v7(), + location, + current_version_id: initial_version_id, + versions, + version_log: Vec::new(), + schemas: HashMap::from_iter(vec![(schema.schema_id(), Arc::new(schema))]), + properties, + }; + + Ok(Self(view_metadata)) + } + + /// Changes uuid of view metadata. + pub fn assign_uuid(mut self, uuid: Uuid) -> Result { + self.0.view_uuid = uuid; + Ok(self) + } + + /// Returns the new view metadata after changes. 
+ pub fn build(self) -> Result { + Ok(self.0) + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[serde(rename_all = "kebab-case")] +/// A log of when each snapshot was made. +pub struct ViewVersionLog { + /// ID that current-version-id was set to + version_id: ViewVersionId, + /// Timestamp when the view's current-version-id was updated (ms from epoch) + timestamp_ms: i64, +} + +impl ViewVersionLog { + #[inline] + /// Creates a new view version log. + pub fn new(version_id: ViewVersionId, timestamp: i64) -> Self { + Self { + version_id, + timestamp_ms: timestamp, + } + } + + /// Returns the version id. + #[inline] + pub fn version_id(&self) -> ViewVersionId { + self.version_id + } + + /// Returns the timestamp in milliseconds from epoch. + #[inline] + pub fn timestamp_ms(&self) -> i64 { + self.timestamp_ms + } + + /// Returns the last updated timestamp as a DateTime with millisecond precision. + pub fn timestamp(self) -> Result> { + timestamp_ms_to_utc(self.timestamp_ms) + } +} + +pub(super) mod _serde { + /// This is a helper module that defines types to help with serialization/deserialization. + /// For deserialization the input first gets read into either the [ViewMetadataV1] struct + /// and then converted into the [ViewMetadata] struct. Serialization works the other way around. + /// [ViewMetadataV1] is an internal struct that are only used for serialization and deserialization. + use std::{collections::HashMap, sync::Arc}; + + use serde::{Deserialize, Serialize}; + use uuid::Uuid; + + use super::{ViewFormatVersion, ViewVersionId, ViewVersionLog}; + use crate::spec::schema::_serde::SchemaV2; + use crate::spec::table_metadata::_serde::VersionNumber; + use crate::spec::view_version::_serde::ViewVersionV1; + use crate::spec::{ViewMetadata, ViewVersion}; + use crate::{Error, ErrorKind}; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] + #[serde(untagged)] + pub(super) enum ViewMetadataEnum { + V1(ViewMetadataV1), + } + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] + #[serde(rename_all = "kebab-case")] + /// Defines the structure of a v1 view metadata for serialization/deserialization + pub(super) struct ViewMetadataV1 { + pub format_version: VersionNumber<1>, + pub(super) view_uuid: Uuid, + pub(super) location: String, + pub(super) current_version_id: ViewVersionId, + pub(super) versions: Vec, + pub(super) version_log: Vec, + pub(super) schemas: Vec, + pub(super) properties: Option>, + } + + impl Serialize for ViewMetadata { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + // we must do a clone here + let metadata_enum: ViewMetadataEnum = + self.clone().try_into().map_err(serde::ser::Error::custom)?; + + metadata_enum.serialize(serializer) + } + } + + impl TryFrom for ViewMetadata { + type Error = Error; + fn try_from(value: ViewMetadataEnum) -> Result { + match value { + ViewMetadataEnum::V1(value) => value.try_into(), + } + } + } + + impl TryFrom for ViewMetadataEnum { + type Error = Error; + fn try_from(value: ViewMetadata) -> Result { + Ok(match value.format_version { + ViewFormatVersion::V1 => ViewMetadataEnum::V1(value.into()), + }) + } + } + + impl TryFrom for ViewMetadata { + type Error = Error; + fn try_from(value: ViewMetadataV1) -> Result { + let schemas = HashMap::from_iter( + value + .schemas + .into_iter() + .map(|schema| Ok((schema.schema_id, Arc::new(schema.try_into()?)))) + .collect::, Error>>()?, + ); + let versions = HashMap::from_iter( + value + .versions + .into_iter() + .map(|x| 
Ok((x.version_id, Arc::new(ViewVersion::from(x))))) + .collect::, Error>>()?, + ); + // Make sure at least the current schema exists + let current_version = + versions + .get(&value.current_version_id) + .ok_or(self::Error::new( + ErrorKind::DataInvalid, + format!( + "No version exists with the current version id {}.", + value.current_version_id + ), + ))?; + if !schemas.contains_key(¤t_version.schema_id()) { + return Err(self::Error::new( + ErrorKind::DataInvalid, + format!( + "No schema exists with the schema id {}.", + current_version.schema_id() + ), + )); + } + + Ok(ViewMetadata { + format_version: ViewFormatVersion::V1, + view_uuid: value.view_uuid, + location: value.location, + schemas, + properties: value.properties.unwrap_or_default(), + current_version_id: value.current_version_id, + versions, + version_log: value.version_log, + }) + } + } + + impl From for ViewMetadataV1 { + fn from(v: ViewMetadata) -> Self { + let schemas = v + .schemas + .into_values() + .map(|x| { + Arc::try_unwrap(x) + .unwrap_or_else(|schema| schema.as_ref().clone()) + .into() + }) + .collect(); + let versions = v + .versions + .into_values() + .map(|x| { + Arc::try_unwrap(x) + .unwrap_or_else(|version| version.as_ref().clone()) + .into() + }) + .collect(); + ViewMetadataV1 { + format_version: VersionNumber::<1>, + view_uuid: v.view_uuid, + location: v.location, + schemas, + properties: Some(v.properties), + current_version_id: v.current_version_id, + versions, + version_log: v.version_log, + } + } + } +} + +#[derive(Debug, Serialize_repr, Deserialize_repr, PartialEq, Eq, Clone, Copy)] +#[repr(u8)] +/// Iceberg format version +pub enum ViewFormatVersion { + /// Iceberg view spec version 1 + V1 = 1u8, +} + +impl PartialOrd for ViewFormatVersion { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ViewFormatVersion { + fn cmp(&self, other: &Self) -> Ordering { + (*self as u8).cmp(&(*other as u8)) + } +} + +impl Display for ViewFormatVersion { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ViewFormatVersion::V1 => write!(f, "v1"), + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::fs; + use std::sync::Arc; + + use anyhow::Result; + use pretty_assertions::assert_eq; + use uuid::Uuid; + + use super::{ViewFormatVersion, ViewMetadataBuilder, ViewVersionLog}; + use crate::spec::{ + NestedField, PrimitiveType, Schema, SqlViewRepresentation, Type, ViewMetadata, + ViewRepresentations, ViewVersion, + }; + use crate::{NamespaceIdent, ViewCreation}; + + fn check_view_metadata_serde(json: &str, expected_type: ViewMetadata) { + let desered_type: ViewMetadata = serde_json::from_str(json).unwrap(); + assert_eq!(desered_type, expected_type); + + let sered_json = serde_json::to_string(&expected_type).unwrap(); + let parsed_json_value = serde_json::from_str::(&sered_json).unwrap(); + + assert_eq!(parsed_json_value, desered_type); + } + + fn get_test_view_metadata(file_name: &str) -> ViewMetadata { + let path = format!("testdata/view_metadata/{}", file_name); + let metadata: String = fs::read_to_string(path).unwrap(); + + serde_json::from_str(&metadata).unwrap() + } + + #[test] + fn test_view_data_v1() { + let data = r#" + { + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version" : 1, + "location" : "s3://bucket/warehouse/default.db/event_agg", + "current-version-id" : 1, + "properties" : { + "comment" : "Daily event counts" + }, + "versions" : [ { + "version-id" : 1, + "timestamp-ms" : 1573518431292, + 
"schema-id" : 1, + "default-catalog" : "prod", + "default-namespace" : [ "default" ], + "summary" : { + "engine-name" : "Spark", + "engineVersion" : "3.3.2" + }, + "representations" : [ { + "type" : "sql", + "sql" : "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect" : "spark" + } ] + } ], + "schemas": [ { + "schema-id": 1, + "type" : "struct", + "fields" : [ { + "id" : 1, + "name" : "event_count", + "required" : false, + "type" : "int", + "doc" : "Count of events" + } ] + } ], + "version-log" : [ { + "timestamp-ms" : 1573518431292, + "version-id" : 1 + } ] + } + "#; + + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![Arc::new( + NestedField::optional(1, "event_count", Type::Primitive(PrimitiveType::Int)) + .with_doc("Count of events"), + )]) + .build() + .unwrap(); + let version = ViewVersion::builder() + .with_version_id(1) + .with_timestamp_ms(1573518431292) + .with_schema_id(1) + .with_default_catalog("prod".to_string().into()) + .with_default_namespace(NamespaceIdent::from_vec(vec!["default".to_string()]).unwrap()) + .with_summary(HashMap::from_iter(vec![ + ("engineVersion".to_string(), "3.3.2".to_string()), + ("engine-name".to_string(), "Spark".to_string()), + ])) + .with_representations(ViewRepresentations(vec![SqlViewRepresentation { + sql: "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2" + .to_string(), + dialect: "spark".to_string(), + } + .into()])) + .build(); + + let expected = ViewMetadata { + format_version: ViewFormatVersion::V1, + view_uuid: Uuid::parse_str("fa6506c3-7681-40c8-86dc-e36561f83385").unwrap(), + location: "s3://bucket/warehouse/default.db/event_agg".to_string(), + current_version_id: 1, + versions: HashMap::from_iter(vec![(1, Arc::new(version))]), + version_log: vec![ViewVersionLog { + timestamp_ms: 1573518431292, + version_id: 1, + }], + schemas: HashMap::from_iter(vec![(1, Arc::new(schema))]), + properties: HashMap::from_iter(vec![( + "comment".to_string(), + "Daily event counts".to_string(), + )]), + }; + + check_view_metadata_serde(data, expected); + } + + #[test] + fn test_invalid_view_uuid() -> Result<()> { + let data = r#" + { + "format-version" : 1, + "view-uuid": "xxxx" + } + "#; + assert!(serde_json::from_str::(data).is_err()); + Ok(()) + } + + #[test] + fn test_view_builder_from_view_creation() { + let representations = ViewRepresentations(vec![SqlViewRepresentation { + sql: "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2" + .to_string(), + dialect: "spark".to_string(), + } + .into()]); + let creation = ViewCreation::builder() + .location("s3://bucket/warehouse/default.db/event_agg".to_string()) + .name("view".to_string()) + .schema(Schema::builder().build().unwrap()) + .default_namespace(NamespaceIdent::from_vec(vec!["default".to_string()]).unwrap()) + .representations(representations) + .build(); + + let metadata = ViewMetadataBuilder::from_view_creation(creation) + .unwrap() + .build() + .unwrap(); + + assert_eq!( + metadata.location(), + "s3://bucket/warehouse/default.db/event_agg" + ); + assert_eq!(metadata.current_version_id(), 1); + assert_eq!(metadata.versions().count(), 1); + assert_eq!(metadata.schemas_iter().count(), 1); + assert_eq!(metadata.properties().len(), 0); + } + + #[test] + fn test_view_metadata_v1_file_valid() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataV1Valid.json").unwrap(); + + let schema = Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + Arc::new( + NestedField::optional(1, 
"event_count", Type::Primitive(PrimitiveType::Int)) + .with_doc("Count of events"), + ), + Arc::new(NestedField::optional( + 2, + "event_date", + Type::Primitive(PrimitiveType::Date), + )), + ]) + .build() + .unwrap(); + + let version = ViewVersion::builder() + .with_version_id(1) + .with_timestamp_ms(1573518431292) + .with_schema_id(1) + .with_default_catalog("prod".to_string().into()) + .with_default_namespace(NamespaceIdent::from_vec(vec!["default".to_string()]).unwrap()) + .with_summary(HashMap::from_iter(vec![ + ("engineVersion".to_string(), "3.3.2".to_string()), + ("engine-name".to_string(), "Spark".to_string()), + ])) + .with_representations(ViewRepresentations(vec![SqlViewRepresentation { + sql: "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2" + .to_string(), + dialect: "spark".to_string(), + } + .into()])) + .build(); + + let expected = ViewMetadata { + format_version: ViewFormatVersion::V1, + view_uuid: Uuid::parse_str("fa6506c3-7681-40c8-86dc-e36561f83385").unwrap(), + location: "s3://bucket/warehouse/default.db/event_agg".to_string(), + current_version_id: 1, + versions: HashMap::from_iter(vec![(1, Arc::new(version))]), + version_log: vec![ViewVersionLog { + timestamp_ms: 1573518431292, + version_id: 1, + }], + schemas: HashMap::from_iter(vec![(1, Arc::new(schema))]), + properties: HashMap::from_iter(vec![( + "comment".to_string(), + "Daily event counts".to_string(), + )]), + }; + + check_view_metadata_serde(&metadata, expected); + } + + #[test] + fn test_view_builder_assign_uuid() { + let metadata = get_test_view_metadata("ViewMetadataV1Valid.json"); + let metadata_builder = ViewMetadataBuilder::new(metadata); + let uuid = Uuid::new_v4(); + let metadata = metadata_builder.assign_uuid(uuid).unwrap().build().unwrap(); + assert_eq!(metadata.uuid(), uuid); + } + + #[test] + fn test_view_metadata_v1_unsupported_version() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataUnsupportedVersion.json") + .unwrap(); + + let desered: Result = serde_json::from_str(&metadata); + + assert_eq!( + desered.unwrap_err().to_string(), + "data did not match any variant of untagged enum ViewMetadataEnum" + ) + } + + #[test] + fn test_view_metadata_v1_version_not_found() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataV1CurrentVersionNotFound.json") + .unwrap(); + + let desered: Result = serde_json::from_str(&metadata); + + assert_eq!( + desered.unwrap_err().to_string(), + "DataInvalid => No version exists with the current version id 2." + ) + } + + #[test] + fn test_view_metadata_v1_schema_not_found() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataV1SchemaNotFound.json").unwrap(); + + let desered: Result = serde_json::from_str(&metadata); + + assert_eq!( + desered.unwrap_err().to_string(), + "DataInvalid => No schema exists with the schema id 2." 
+ ) + } + + #[test] + fn test_view_metadata_v1_missing_schema_for_version() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataV1MissingSchema.json").unwrap(); + + let desered: Result = serde_json::from_str(&metadata); + + assert_eq!( + desered.unwrap_err().to_string(), + "data did not match any variant of untagged enum ViewMetadataEnum" + ) + } + + #[test] + fn test_view_metadata_v1_missing_current_version() { + let metadata = + fs::read_to_string("testdata/view_metadata/ViewMetadataV1MissingCurrentVersion.json") + .unwrap(); + + let desered: Result = serde_json::from_str(&metadata); + + assert_eq!( + desered.unwrap_err().to_string(), + "data did not match any variant of untagged enum ViewMetadataEnum" + ) + } +} diff --git a/crates/iceberg/src/spec/view_version.rs b/crates/iceberg/src/spec/view_version.rs new file mode 100644 index 000000000..30686b5a4 --- /dev/null +++ b/crates/iceberg/src/spec/view_version.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/*! + * View Versions! + */ +use std::collections::HashMap; +use std::sync::Arc; + +use _serde::ViewVersionV1; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; + +use super::view_metadata::ViewVersionLog; +use crate::catalog::NamespaceIdent; +use crate::error::{timestamp_ms_to_utc, Result}; +use crate::spec::{SchemaId, SchemaRef, ViewMetadata}; +use crate::{Error, ErrorKind}; + +/// Reference to [`ViewVersion`]. +pub type ViewVersionRef = Arc; + +/// Alias for the integer type used for view version ids. +pub type ViewVersionId = i32; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TypedBuilder)] +#[serde(from = "ViewVersionV1", into = "ViewVersionV1")] +#[builder(field_defaults(setter(prefix = "with_")))] +/// A view versions represents the definition of a view at a specific point in time. +pub struct ViewVersion { + /// A unique long ID + version_id: ViewVersionId, + /// ID of the schema for the view version + schema_id: SchemaId, + /// Timestamp when the version was created (ms from epoch) + timestamp_ms: i64, + /// A string to string map of summary metadata about the version + summary: HashMap, + /// A list of representations for the view definition. + representations: ViewRepresentations, + /// Catalog name to use when a reference in the SELECT does not contain a catalog + #[builder(default = None)] + default_catalog: Option, + /// Namespace to use when a reference in the SELECT is a single identifier + default_namespace: NamespaceIdent, +} + +impl ViewVersion { + /// Get the version id of this view version. + #[inline] + pub fn version_id(&self) -> ViewVersionId { + self.version_id + } + + /// Get the schema id of this view version. 
+ #[inline] + pub fn schema_id(&self) -> SchemaId { + self.schema_id + } + + /// Get the timestamp of when the view version was created + #[inline] + pub fn timestamp(&self) -> Result> { + timestamp_ms_to_utc(self.timestamp_ms) + } + + /// Get the timestamp of when the view version was created in milliseconds since epoch + #[inline] + pub fn timestamp_ms(&self) -> i64 { + self.timestamp_ms + } + + /// Get summary of the view version + #[inline] + pub fn summary(&self) -> &HashMap { + &self.summary + } + + /// Get this views representations + #[inline] + pub fn representations(&self) -> &ViewRepresentations { + &self.representations + } + + /// Get the default catalog for this view version + #[inline] + pub fn default_catalog(&self) -> Option<&String> { + self.default_catalog.as_ref() + } + + /// Get the default namespace to use when a reference in the SELECT is a single identifier + #[inline] + pub fn default_namespace(&self) -> &NamespaceIdent { + &self.default_namespace + } + + /// Get the schema of this snapshot. + pub fn schema(&self, view_metadata: &ViewMetadata) -> Result { + let r = view_metadata + .schema_by_id(self.schema_id()) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Schema with id {} not found", self.schema_id()), + ) + }) + .cloned(); + r + } + + /// Retrieve the history log entry for this view version. + #[allow(dead_code)] + pub(crate) fn log(&self) -> ViewVersionLog { + ViewVersionLog::new(self.version_id, self.timestamp_ms) + } +} + +/// A list of view representations. +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +pub struct ViewRepresentations(pub(crate) Vec); + +impl ViewRepresentations { + #[inline] + /// Get the number of representations + pub fn len(&self) -> usize { + self.0.len() + } + + #[inline] + /// Check if there are no representations + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Get an iterator over the representations + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } +} + +// Iterator for ViewRepresentations +impl IntoIterator for ViewRepresentations { + type Item = ViewRepresentation; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[serde(tag = "type")] +/// View definitions can be represented in multiple ways. +/// Representations are documented ways to express a view definition. +// ToDo: Make unique per Dialect +pub enum ViewRepresentation { + #[serde(rename = "sql")] + /// The SQL representation stores the view definition as a SQL SELECT, + Sql(SqlViewRepresentation), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[serde(rename_all = "kebab-case")] +/// The SQL representation stores the view definition as a SQL SELECT, +/// with metadata such as the SQL dialect. +pub struct SqlViewRepresentation { + #[serde(rename = "sql")] + /// The SQL SELECT statement that defines the view. + pub sql: String, + #[serde(rename = "dialect")] + /// The dialect of the sql SELECT statement (e.g., "trino" or "spark") + pub dialect: String, +} + +pub(super) mod _serde { + /// This is a helper module that defines types to help with serialization/deserialization. + /// For deserialization the input first gets read into the [`ViewVersionV1`] struct. + /// and then converted into the [Snapshot] struct. Serialization works the other way around. + /// [ViewVersionV1] are internal struct that are only used for serialization and deserialization. 
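
Because `ViewVersion` derives `TypedBuilder` with `with_`-prefixed setters, the struct above can be constructed and inspected as in the sketch below. It mirrors the test data elsewhere in this diff; the SQL text and the empty summary map are placeholders.

```rust
#[test]
fn sketch_view_version_builder() {
    let version = ViewVersion::builder()
        .with_version_id(1)
        .with_schema_id(1)
        .with_timestamp_ms(1573518431292)
        .with_summary(HashMap::new())
        .with_default_catalog(Some("prod".to_string()))
        .with_default_namespace(NamespaceIdent::from_vec(vec!["default".to_string()]).unwrap())
        .with_representations(ViewRepresentations(vec![SqlViewRepresentation {
            sql: "SELECT 1".to_string(),
            dialect: "spark".to_string(),
        }
        .into()]))
        .build();

    assert_eq!(version.version_id(), 1);
    assert_eq!(version.schema_id(), 1);
    assert_eq!(version.representations().len(), 1);
    assert_eq!(version.default_catalog(), Some(&"prod".to_string()));
}
```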
+ use serde::{Deserialize, Serialize}; + + use super::{ViewRepresentation, ViewRepresentations, ViewVersion}; + use crate::catalog::NamespaceIdent; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] + #[serde(rename_all = "kebab-case")] + /// Defines the structure of a v1 view version for serialization/deserialization + pub(crate) struct ViewVersionV1 { + pub version_id: i32, + pub schema_id: i32, + pub timestamp_ms: i64, + pub summary: std::collections::HashMap, + pub representations: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_catalog: Option, + pub default_namespace: NamespaceIdent, + } + + impl From for ViewVersion { + fn from(v1: ViewVersionV1) -> Self { + ViewVersion { + version_id: v1.version_id, + schema_id: v1.schema_id, + timestamp_ms: v1.timestamp_ms, + summary: v1.summary, + representations: ViewRepresentations(v1.representations), + default_catalog: v1.default_catalog, + default_namespace: v1.default_namespace, + } + } + } + + impl From for ViewVersionV1 { + fn from(v1: ViewVersion) -> Self { + ViewVersionV1 { + version_id: v1.version_id, + schema_id: v1.schema_id, + timestamp_ms: v1.timestamp_ms, + summary: v1.summary, + representations: v1.representations.0, + default_catalog: v1.default_catalog, + default_namespace: v1.default_namespace, + } + } + } +} + +impl From for ViewRepresentation { + fn from(sql: SqlViewRepresentation) -> Self { + ViewRepresentation::Sql(sql) + } +} + +#[cfg(test)] +mod tests { + use chrono::{TimeZone, Utc}; + + use crate::spec::view_version::ViewVersion; + use crate::spec::view_version::_serde::ViewVersionV1; + use crate::spec::ViewRepresentations; + + #[test] + fn view_version() { + let record = serde_json::json!( + { + "version-id" : 1, + "timestamp-ms" : 1573518431292i64, + "schema-id" : 1, + "default-catalog" : "prod", + "default-namespace" : [ "default" ], + "summary" : { + "engine-name" : "Spark", + "engineVersion" : "3.3.2" + }, + "representations" : [ { + "type" : "sql", + "sql" : "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect" : "spark" + } ] + } + ); + + let result: ViewVersion = serde_json::from_value::(record.clone()) + .unwrap() + .into(); + + // Roundtrip + assert_eq!(serde_json::to_value(result.clone()).unwrap(), record); + + assert_eq!(result.version_id(), 1); + assert_eq!( + result.timestamp().unwrap(), + Utc.timestamp_millis_opt(1573518431292).unwrap() + ); + assert_eq!(result.schema_id(), 1); + assert_eq!(result.default_catalog, Some("prod".to_string())); + assert_eq!(result.summary(), &{ + let mut map = std::collections::HashMap::new(); + map.insert("engine-name".to_string(), "Spark".to_string()); + map.insert("engineVersion".to_string(), "3.3.2".to_string()); + map + }); + assert_eq!( + result.representations().to_owned(), + ViewRepresentations(vec![super::ViewRepresentation::Sql( + super::SqlViewRepresentation { + sql: "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2" + .to_string(), + dialect: "spark".to_string(), + }, + )]) + ); + assert_eq!( + result.default_namespace.inner(), + vec!["default".to_string()] + ); + } +} diff --git a/crates/iceberg/src/table.rs b/crates/iceberg/src/table.rs index fad91394c..406f9dd65 100644 --- a/crates/iceberg/src/table.rs +++ b/crates/iceberg/src/table.rs @@ -17,22 +17,155 @@ //! 
Table API for Apache Iceberg +use std::sync::Arc; + +use crate::arrow::ArrowReaderBuilder; +use crate::io::object_cache::ObjectCache; use crate::io::FileIO; -use crate::spec::TableMetadata; -use crate::TableIdent; -use typed_builder::TypedBuilder; +use crate::scan::TableScanBuilder; +use crate::spec::{TableMetadata, TableMetadataRef}; +use crate::{Error, ErrorKind, Result, TableIdent}; + +/// Builder to create table scan. +pub struct TableBuilder { + file_io: Option, + metadata_location: Option, + metadata: Option, + identifier: Option, + readonly: bool, + disable_cache: bool, + cache_size_bytes: Option, +} + +impl TableBuilder { + pub(crate) fn new() -> Self { + Self { + file_io: None, + metadata_location: None, + metadata: None, + identifier: None, + readonly: false, + disable_cache: false, + cache_size_bytes: None, + } + } + + /// required - sets the necessary FileIO to use for the table + pub fn file_io(mut self, file_io: FileIO) -> Self { + self.file_io = Some(file_io); + self + } + + /// optional - sets the tables metadata location + pub fn metadata_location>(mut self, metadata_location: T) -> Self { + self.metadata_location = Some(metadata_location.into()); + self + } + + /// required - passes in the TableMetadata to use for the Table + pub fn metadata>(mut self, metadata: T) -> Self { + self.metadata = Some(metadata.into()); + self + } + + /// required - passes in the TableIdent to use for the Table + pub fn identifier(mut self, identifier: TableIdent) -> Self { + self.identifier = Some(identifier); + self + } + + /// specifies if the Table is readonly or not (default not) + pub fn readonly(mut self, readonly: bool) -> Self { + self.readonly = readonly; + self + } + + /// specifies if the Table's metadata cache will be disabled, + /// so that reads of Manifests and ManifestLists will never + /// get cached. + pub fn disable_cache(mut self) -> Self { + self.disable_cache = true; + self + } + + /// optionally set a non-default metadata cache size + pub fn cache_size_bytes(mut self, cache_size_bytes: u64) -> Self { + self.cache_size_bytes = Some(cache_size_bytes); + self + } + + /// build the Table + pub fn build(self) -> Result
{ + let Self { + file_io, + metadata_location, + metadata, + identifier, + readonly, + disable_cache, + cache_size_bytes, + } = self; + + let Some(file_io) = file_io else { + return Err(Error::new( + ErrorKind::DataInvalid, + "FileIO must be provided with TableBuilder.file_io()", + )); + }; + + let Some(metadata) = metadata else { + return Err(Error::new( + ErrorKind::DataInvalid, + "TableMetadataRef must be provided with TableBuilder.metadata()", + )); + }; + + let Some(identifier) = identifier else { + return Err(Error::new( + ErrorKind::DataInvalid, + "TableIdent must be provided with TableBuilder.identifier()", + )); + }; + + let object_cache = if disable_cache { + Arc::new(ObjectCache::with_disabled_cache(file_io.clone())) + } else if let Some(cache_size_bytes) = cache_size_bytes { + Arc::new(ObjectCache::new_with_capacity( + file_io.clone(), + cache_size_bytes, + )) + } else { + Arc::new(ObjectCache::new(file_io.clone())) + }; + + Ok(Table { + file_io, + metadata_location, + metadata, + identifier, + readonly, + object_cache, + }) + } +} /// Table represents a table in the catalog. -#[derive(TypedBuilder, Debug)] +#[derive(Debug, Clone)] pub struct Table { file_io: FileIO, - #[builder(default, setter(strip_option))] metadata_location: Option, - metadata: TableMetadata, + metadata: TableMetadataRef, identifier: TableIdent, + readonly: bool, + object_cache: Arc, } impl Table { + /// Returns a TableBuilder to build a table + pub fn builder() -> TableBuilder { + TableBuilder::new() + } + /// Returns table identifier. pub fn identifier(&self) -> &TableIdent { &self.identifier @@ -42,8 +175,200 @@ impl Table { &self.metadata } + /// Returns current metadata ref. + pub fn metadata_ref(&self) -> TableMetadataRef { + self.metadata.clone() + } + /// Returns current metadata location. pub fn metadata_location(&self) -> Option<&str> { self.metadata_location.as_deref() } + + /// Returns file io used in this table. + pub fn file_io(&self) -> &FileIO { + &self.file_io + } + + /// Returns this table's object cache + pub(crate) fn object_cache(&self) -> Arc { + self.object_cache.clone() + } + + /// Creates a table scan. + pub fn scan(&self) -> TableScanBuilder<'_> { + TableScanBuilder::new(self) + } + + /// Returns the flag indicating whether the `Table` is readonly or not + pub fn readonly(&self) -> bool { + self.readonly + } + + /// Create a reader for the table. + pub fn reader_builder(&self) -> ArrowReaderBuilder { + ArrowReaderBuilder::new(self.file_io.clone()) + } +} + +/// `StaticTable` is a read-only table struct that can be created from a metadata file or from `TableMetaData` without a catalog. +/// It can only be used to read metadata and for table scan. 
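
The hand-rolled `TableBuilder` above replaces the previous `TypedBuilder` derive, and the tests later in this diff assemble tables roughly as sketched here. The identifier `["ns", "example"]` is illustrative, and the metadata is assumed to be loaded elsewhere (for instance from a metadata JSON file).

```rust
// Sketch: assemble a Table from already-loaded metadata.
fn build_example_table(metadata: TableMetadata, file_io: FileIO) -> Result<Table> {
    Table::builder()
        .metadata(metadata) // also accepts a TableMetadataRef via Into
        .identifier(TableIdent::from_strs(["ns", "example"])?)
        .file_io(file_io)
        .readonly(false) // writable table; StaticTable sets this to true
        .build()
}
```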
+/// # Examples +/// +/// ```rust, no_run +/// # use iceberg::io::FileIO; +/// # use iceberg::table::StaticTable; +/// # use iceberg::TableIdent; +/// # async fn example() { +/// let metadata_file_location = "s3://bucket_name/path/to/metadata.json"; +/// let file_io = FileIO::from_path(&metadata_file_location) +/// .unwrap() +/// .build() +/// .unwrap(); +/// let static_identifier = TableIdent::from_strs(["static_ns", "static_table"]).unwrap(); +/// let static_table = +/// StaticTable::from_metadata_file(&metadata_file_location, static_identifier, file_io) +/// .await +/// .unwrap(); +/// let snapshot_id = static_table +/// .metadata() +/// .current_snapshot() +/// .unwrap() +/// .snapshot_id(); +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct StaticTable(Table); + +impl StaticTable { + /// Creates a static table from a given `TableMetadata` and `FileIO` + pub async fn from_metadata( + metadata: TableMetadata, + table_ident: TableIdent, + file_io: FileIO, + ) -> Result { + let table = Table::builder() + .metadata(metadata) + .identifier(table_ident) + .file_io(file_io.clone()) + .readonly(true) + .build(); + + Ok(Self(table?)) + } + /// Creates a static table directly from metadata file and `FileIO` + pub async fn from_metadata_file( + metadata_file_path: &str, + table_ident: TableIdent, + file_io: FileIO, + ) -> Result { + let metadata_file = file_io.new_input(metadata_file_path)?; + let metadata_file_content = metadata_file.read().await?; + let table_metadata = serde_json::from_slice::(&metadata_file_content)?; + Self::from_metadata(table_metadata, table_ident, file_io).await + } + + /// Create a TableScanBuilder for the static table. + pub fn scan(&self) -> TableScanBuilder<'_> { + self.0.scan() + } + + /// Get TableMetadataRef for the static table + pub fn metadata(&self) -> TableMetadataRef { + self.0.metadata_ref() + } + + /// Consumes the `StaticTable` and return it as a `Table` + /// Please use this method carefully as the Table it returns remains detached from a catalog + /// and can't be used to perform modifications on the table. + pub fn into_table(self) -> Table { + self.0 + } + + /// Create a reader for the table. 
+ pub fn reader_builder(&self) -> ArrowReaderBuilder { + ArrowReaderBuilder::new(self.0.file_io.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_static_table_from_file() { + let metadata_file_name = "TableMetadataV2Valid.json"; + let metadata_file_path = format!( + "{}/testdata/table_metadata/{}", + env!("CARGO_MANIFEST_DIR"), + metadata_file_name + ); + let file_io = FileIO::from_path(&metadata_file_path) + .unwrap() + .build() + .unwrap(); + let static_identifier = TableIdent::from_strs(["static_ns", "static_table"]).unwrap(); + let static_table = + StaticTable::from_metadata_file(&metadata_file_path, static_identifier, file_io) + .await + .unwrap(); + let snapshot_id = static_table + .metadata() + .current_snapshot() + .unwrap() + .snapshot_id(); + assert_eq!( + snapshot_id, 3055729675574597004, + "snapshot id from metadata don't match" + ); + } + + #[tokio::test] + async fn test_static_into_table() { + let metadata_file_name = "TableMetadataV2Valid.json"; + let metadata_file_path = format!( + "{}/testdata/table_metadata/{}", + env!("CARGO_MANIFEST_DIR"), + metadata_file_name + ); + let file_io = FileIO::from_path(&metadata_file_path) + .unwrap() + .build() + .unwrap(); + let static_identifier = TableIdent::from_strs(["static_ns", "static_table"]).unwrap(); + let static_table = + StaticTable::from_metadata_file(&metadata_file_path, static_identifier, file_io) + .await + .unwrap(); + let table = static_table.into_table(); + assert!(table.readonly()); + assert_eq!(table.identifier.name(), "static_table"); + } + + #[tokio::test] + async fn test_table_readonly_flag() { + let metadata_file_name = "TableMetadataV2Valid.json"; + let metadata_file_path = format!( + "{}/testdata/table_metadata/{}", + env!("CARGO_MANIFEST_DIR"), + metadata_file_name + ); + let file_io = FileIO::from_path(&metadata_file_path) + .unwrap() + .build() + .unwrap(); + let metadata_file = file_io.new_input(metadata_file_path).unwrap(); + let metadata_file_content = metadata_file.read().await.unwrap(); + let table_metadata = + serde_json::from_slice::(&metadata_file_content).unwrap(); + let static_identifier = TableIdent::from_strs(["ns", "table"]).unwrap(); + let table = Table::builder() + .metadata(table_metadata) + .identifier(static_identifier) + .file_io(file_io.clone()) + .build() + .unwrap(); + assert!(!table.readonly()); + assert_eq!(table.identifier.name(), "table"); + } } diff --git a/crates/iceberg/src/transaction.rs b/crates/iceberg/src/transaction.rs index 4ea89a297..d416383d7 100644 --- a/crates/iceberg/src/transaction.rs +++ b/crates/iceberg/src/transaction.rs @@ -17,14 +17,15 @@ //! This module contains transaction api. +use std::cmp::Ordering; +use std::collections::HashMap; +use std::mem::discriminant; + use crate::error::Result; use crate::spec::{FormatVersion, NullOrder, SortDirection, SortField, SortOrder, Transform}; use crate::table::Table; use crate::TableUpdate::UpgradeFormatVersion; use crate::{Catalog, Error, ErrorKind, TableCommit, TableRequirement, TableUpdate}; -use std::cmp::Ordering; -use std::collections::HashMap; -use std::mem::discriminant; /// Table transaction. pub struct Transaction<'a> { @@ -140,12 +141,13 @@ impl<'a> ReplaceSortOrderAction<'a> { /// Finished building the action and apply it to the transaction. 
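
The `apply()` implementation that follows boils down to staging two table updates. A condensed sketch of that step, where `sort_fields` stands in for the `Vec<SortField>` the action has accumulated (a hypothetical local variable) and the surrounding function returns `crate::Result`:

```rust
// Build an unbound sort order from the collected fields ...
let sort_order = SortOrder::builder()
    .with_fields(sort_fields)
    .build_unbound()?;

// ... and stage it together with a "make it the default" update.
// By Iceberg convention, sort_order_id -1 refers to the order added in the same commit.
let updates = vec![
    TableUpdate::AddSortOrder { sort_order },
    TableUpdate::SetDefaultSortOrder { sort_order_id: -1 },
];
```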
pub fn apply(mut self) -> Result> { + let unbound_sort_order = SortOrder::builder() + .with_fields(self.sort_fields) + .build_unbound()?; + let updates = vec![ TableUpdate::AddSortOrder { - sort_order: SortOrder { - fields: self.sort_fields, - ..SortOrder::default() - }, + sort_order: unbound_sort_order, }, TableUpdate::SetDefaultSortOrder { sort_order_id: -1 }, ]; @@ -160,7 +162,10 @@ impl<'a> ReplaceSortOrderAction<'a> { .table .metadata() .default_sort_order() - .unwrap() + .ok_or(Error::new( + ErrorKind::Unexpected, + "default sort order impossible to be none", + ))? .order_id, }, ]; @@ -203,14 +208,15 @@ impl<'a> ReplaceSortOrderAction<'a> { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::fs::File; + use std::io::BufReader; + use crate::io::FileIO; use crate::spec::{FormatVersion, TableMetadata}; use crate::table::Table; use crate::transaction::Transaction; use crate::{TableIdent, TableRequirement, TableUpdate}; - use std::collections::HashMap; - use std::fs::File; - use std::io::BufReader; fn make_v1_table() -> Table { let file = File::open(format!( @@ -228,6 +234,7 @@ mod tests { .identifier(TableIdent::from_strs(["ns1", "test1"]).unwrap()) .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap()) .build() + .unwrap() } fn make_v2_table() -> Table { @@ -246,6 +253,7 @@ mod tests { .identifier(TableIdent::from_strs(["ns1", "test1"]).unwrap()) .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap()) .build() + .unwrap() } #[test] diff --git a/crates/iceberg/src/transform/bucket.rs b/crates/iceberg/src/transform/bucket.rs index beff0be96..ce39826bb 100644 --- a/crates/iceberg/src/transform/bucket.rs +++ b/crates/iceberg/src/transform/bucket.rs @@ -21,6 +21,7 @@ use arrow_array::ArrayRef; use arrow_schema::{DataType, TimeUnit}; use super::TransformFunction; +use crate::spec::{Datum, PrimitiveLiteral, PrimitiveType}; #[derive(Debug)] pub struct Bucket { @@ -35,39 +36,47 @@ impl Bucket { impl Bucket { /// When switch the hash function, we only need to change this function. 
+ #[inline] fn hash_bytes(mut v: &[u8]) -> i32 { murmur3::murmur3_32(&mut v, 0).unwrap() as i32 } + #[inline] fn hash_int(v: i32) -> i32 { Self::hash_long(v as i64) } + #[inline] fn hash_long(v: i64) -> i32 { Self::hash_bytes(v.to_le_bytes().as_slice()) } /// v is days from unix epoch + #[inline] fn hash_date(v: i32) -> i32 { Self::hash_int(v) } /// v is microseconds from midnight + #[inline] fn hash_time(v: i64) -> i32 { Self::hash_long(v) } /// v is microseconds from unix epoch + #[inline] fn hash_timestamp(v: i64) -> i32 { Self::hash_long(v) } + #[inline] fn hash_str(s: &str) -> i32 { Self::hash_bytes(s.as_bytes()) } /// Decimal values are hashed using the minimum number of bytes required to hold the unscaled value as a two’s complement big-endian /// ref: https://iceberg.apache.org/spec/#appendix-b-32-bit-hash-requirements + #[inline] fn hash_decimal(v: i128) -> i32 { let bytes = v.to_be_bytes(); if let Some(start) = bytes.iter().position(|&x| x != 0) { @@ -79,9 +88,50 @@ impl Bucket { /// def bucket_N(x) = (murmur3_x86_32_hash(x) & Integer.MAX_VALUE) % N /// ref: https://iceberg.apache.org/spec/#partitioning + #[inline] fn bucket_n(&self, v: i32) -> i32 { (v & i32::MAX) % (self.mod_n as i32) } + + #[inline] + fn bucket_int(&self, v: i32) -> i32 { + self.bucket_n(Self::hash_int(v)) + } + + #[inline] + fn bucket_long(&self, v: i64) -> i32 { + self.bucket_n(Self::hash_long(v)) + } + + #[inline] + fn bucket_decimal(&self, v: i128) -> i32 { + self.bucket_n(Self::hash_decimal(v)) + } + + #[inline] + fn bucket_date(&self, v: i32) -> i32 { + self.bucket_n(Self::hash_date(v)) + } + + #[inline] + fn bucket_time(&self, v: i64) -> i32 { + self.bucket_n(Self::hash_time(v)) + } + + #[inline] + fn bucket_timestamp(&self, v: i64) -> i32 { + self.bucket_n(Self::hash_timestamp(v)) + } + + #[inline] + fn bucket_str(&self, v: &str) -> i32 { + self.bucket_n(Self::hash_str(v)) + } + + #[inline] + fn bucket_bytes(&self, v: &[u8]) -> i32 { + self.bucket_n(Self::hash_bytes(v)) + } } impl TransformFunction for Bucket { @@ -91,39 +141,39 @@ impl TransformFunction for Bucket { .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_int(v))), + .unary(|v| self.bucket_int(v)), DataType::Int64 => input .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_long(v))), + .unary(|v| self.bucket_long(v)), DataType::Decimal128(_, _) => input .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_decimal(v))), + .unary(|v| self.bucket_decimal(v)), DataType::Date32 => input .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_date(v))), + .unary(|v| self.bucket_date(v)), DataType::Time64(TimeUnit::Microsecond) => input .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_time(v))), + .unary(|v| self.bucket_time(v)), DataType::Timestamp(TimeUnit::Microsecond, _) => input .as_any() .downcast_ref::() .unwrap() - .unary(|v| self.bucket_n(Self::hash_timestamp(v))), + .unary(|v| self.bucket_timestamp(v)), DataType::Utf8 => arrow_array::Int32Array::from_iter( input .as_any() .downcast_ref::() .unwrap() .iter() - .map(|v| self.bucket_n(Self::hash_str(v.unwrap()))), + .map(|v| v.map(|v| self.bucket_str(v))), ), DataType::LargeUtf8 => arrow_array::Int32Array::from_iter( input @@ -131,7 +181,7 @@ impl TransformFunction for Bucket { .downcast_ref::() .unwrap() .iter() - .map(|v| self.bucket_n(Self::hash_str(v.unwrap()))), + .map(|v| v.map(|v| self.bucket_str(v))), ), DataType::Binary => arrow_array::Int32Array::from_iter( input 
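
To make the `bucket_n` formula above concrete, here is the spec calculation written out on its own, using the same `murmur3` crate as the implementation. The helper name is illustrative, and the expected value is consistent with `test_long_literal` later in this file.

```rust
/// Sketch of the spec formula: bucket_N(x) = (murmur3_x86_32_hash(x) & Integer.MAX_VALUE) % N.
fn bucket_long_sketch(v: i64, n: u32) -> i32 {
    // Long values are hashed via their little-endian bytes, matching hash_long above.
    let mut bytes = &v.to_le_bytes()[..];
    let hash = murmur3::murmur3_32(&mut bytes, 0).unwrap() as i32;
    (hash & i32::MAX) % (n as i32)
}

#[test]
fn sketch_bucket_long() {
    // murmur3_32 of long 34 is 2017239379, so with 10 buckets the result is 9.
    assert_eq!(bucket_long_sketch(34, 10), 9);
}
```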
@@ -139,7 +189,7 @@ impl TransformFunction for Bucket { .downcast_ref::() .unwrap() .iter() - .map(|v| self.bucket_n(Self::hash_bytes(v.unwrap()))), + .map(|v| v.map(|v| self.bucket_bytes(v))), ), DataType::LargeBinary => arrow_array::Int32Array::from_iter( input @@ -147,7 +197,7 @@ impl TransformFunction for Bucket { .downcast_ref::() .unwrap() .iter() - .map(|v| self.bucket_n(Self::hash_bytes(v.unwrap()))), + .map(|v| v.map(|v| self.bucket_bytes(v))), ), DataType::FixedSizeBinary(_) => arrow_array::Int32Array::from_iter( input @@ -155,12 +205,47 @@ impl TransformFunction for Bucket { .downcast_ref::() .unwrap() .iter() - .map(|v| self.bucket_n(Self::hash_bytes(v.unwrap()))), + .map(|v| v.map(|v| self.bucket_bytes(v))), ), - _ => unreachable!("Unsupported data type: {:?}", input.data_type()), + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for bucket transform: {:?}", + input.data_type() + ), + )) + } }; Ok(Arc::new(res)) } + + fn transform_literal(&self, input: &Datum) -> crate::Result> { + let val = match (input.data_type(), input.literal()) { + (PrimitiveType::Int, PrimitiveLiteral::Int(v)) => self.bucket_int(*v), + (PrimitiveType::Long, PrimitiveLiteral::Long(v)) => self.bucket_long(*v), + (PrimitiveType::Decimal { .. }, PrimitiveLiteral::Int128(v)) => self.bucket_decimal(*v), + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => self.bucket_date(*v), + (PrimitiveType::Time, PrimitiveLiteral::Long(v)) => self.bucket_time(*v), + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => self.bucket_timestamp(*v), + (PrimitiveType::String, PrimitiveLiteral::String(v)) => self.bucket_str(v.as_str()), + (PrimitiveType::Uuid, PrimitiveLiteral::UInt128(v)) => { + self.bucket_bytes(uuid::Uuid::from_u128(*v).as_ref()) + } + (PrimitiveType::Binary, PrimitiveLiteral::Binary(v)) => self.bucket_bytes(v.as_ref()), + (PrimitiveType::Fixed(_), PrimitiveLiteral::Binary(v)) => self.bucket_bytes(v.as_ref()), + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for bucket transform: {:?}", + input.data_type() + ), + )) + } + }; + Ok(Some(Datum::int(val))) + } } #[cfg(test)] @@ -168,6 +253,451 @@ mod test { use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; use super::Bucket; + use crate::expr::PredicateOperator; + use crate::spec::PrimitiveType::{ + Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, + TimestampNs, Timestamptz, TimestamptzNs, Uuid, + }; + use crate::spec::Type::{Primitive, Struct}; + use crate::spec::{Datum, NestedField, PrimitiveType, StructType, Transform, Type}; + use crate::transform::test::{TestProjectionFixture, TestTransformFixture}; + use crate::transform::TransformFunction; + use crate::Result; + + #[test] + fn test_bucket_transform() { + let trans = Transform::Bucket(8); + + let fixture = TestTransformFixture { + display: "bucket[8]".to_string(), + json: r#""bucket[8]""#.to_string(), + dedup_name: "bucket[8]".to_string(), + preserves_order: false, + satisfies_order_of: vec![ + (Transform::Bucket(8), true), + (Transform::Bucket(4), false), + (Transform::Void, false), + (Transform::Day, false), + ], + trans_types: vec![ + (Primitive(Binary), Some(Primitive(Int))), + (Primitive(Date), Some(Primitive(Int))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + Some(Primitive(Int)), + ), + (Primitive(Fixed(8)), Some(Primitive(Int))), + (Primitive(Int), Some(Primitive(Int))), + (Primitive(Long), 
Some(Primitive(Int))), + (Primitive(StringType), Some(Primitive(Int))), + (Primitive(Uuid), Some(Primitive(Int))), + (Primitive(Time), Some(Primitive(Int))), + (Primitive(Timestamp), Some(Primitive(Int))), + (Primitive(Timestamptz), Some(Primitive(Int))), + (Primitive(TimestampNs), Some(Primitive(Int))), + (Primitive(TimestamptzNs), Some(Primitive(Int))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_projection_bucket_uuid() -> Result<()> { + let value = uuid::Uuid::from_u64_pair(123, 456); + let another = uuid::Uuid::from_u64_pair(456, 123); + + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Uuid)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::uuid(value)), + Some("name = 4"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::uuid(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::uuid(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::uuid(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::uuid(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::uuid(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::uuid(value), + Datum::uuid(another), + ]), + Some("name IN (4, 6)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::uuid(value), + Datum::uuid(another), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_bucket_fixed() -> Result<()> { + let value = "abcdefg".as_bytes().to_vec(); + let another = "abcdehij".as_bytes().to_vec(); + + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required( + 1, + "value", + Type::Primitive(PrimitiveType::Fixed(value.len() as u64)), + ), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::fixed(value.clone())), + Some("name = 4"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::fixed(value.clone())), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::fixed(value.clone())), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::fixed(value.clone())), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::fixed(value.clone())), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::fixed(value.clone()), + ), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::fixed(value.clone()), + Datum::fixed(another.clone()), + ]), + Some("name IN (4, 6)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::fixed(value.clone()), + Datum::fixed(another.clone()), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_bucket_string() -> 
Result<()> { + let value = "abcdefg"; + let another = "abcdefgabc"; + + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::String)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::string(value)), + Some("name = 4"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::string(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::string(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::string(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::string(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::string(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::string(value), + Datum::string(another), + ]), + Some("name IN (9, 4)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::string(value), + Datum::string(another), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_bucket_decimal() -> Result<()> { + let prev = "99.00"; + let curr = "100.00"; + let next = "101.00"; + + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required( + 1, + "value", + Type::Primitive(PrimitiveType::Decimal { + precision: 9, + scale: 2, + }), + ), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::decimal_from_str(curr)?), + Some("name = 2"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::decimal_from_str(curr)?), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::decimal_from_str(curr)?), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::decimal_from_str(curr)?, + ), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::decimal_from_str(curr)?, + ), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::decimal_from_str(curr)?, + ), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::decimal_from_str(next)?, + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(prev)?, + ]), + Some("name IN (2, 6)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(next)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_bucket_long() -> Result<()> { + let value = 100; + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Long)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::long(value)), + Some("name = 6"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, 
Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::long(value - 1), + Datum::long(value), + Datum::long(value + 1), + ]), + Some("name IN (8, 7, 6)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::long(value), + Datum::long(value + 1), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_bucket_integer() -> Result<()> { + let value = 100; + + let fixture = TestProjectionFixture::new( + Transform::Bucket(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Int)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::int(value)), + Some("name = 6"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::int(value - 1), + Datum::int(value), + Datum::int(value + 1), + ]), + Some("name IN (8, 7, 6)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::int(value), + Datum::int(value + 1), + ]), + None, + )?; + + Ok(()) + } + #[test] fn test_hash() { // test int @@ -242,4 +772,120 @@ mod test { -188683207 ); } + + #[test] + fn test_int_literal() { + let bucket = Bucket::new(10); + assert_eq!( + bucket.transform_literal(&Datum::int(34)).unwrap().unwrap(), + Datum::int(9) + ); + } + + #[test] + fn test_long_literal() { + let bucket = Bucket::new(10); + assert_eq!( + bucket.transform_literal(&Datum::long(34)).unwrap().unwrap(), + Datum::int(9) + ); + } + + #[test] + fn test_decimal_literal() { + let bucket = Bucket::new(10); + assert_eq!( + bucket + .transform_literal(&Datum::decimal(1420).unwrap()) + .unwrap() + .unwrap(), + Datum::int(9) + ); + } + + #[test] + fn test_date_literal() { + let bucket = Bucket::new(100); + assert_eq!( + bucket + .transform_literal(&Datum::date(17486)) + .unwrap() + .unwrap(), + Datum::int(26) + ); + } + + #[test] + fn test_time_literal() { + let bucket = Bucket::new(100); + assert_eq!( + bucket + .transform_literal(&Datum::time_micros(81068000000).unwrap()) + .unwrap() + .unwrap(), + Datum::int(59) + ); + } + + #[test] + fn test_timestamp_literal() { + let bucket = Bucket::new(100); + assert_eq!( + bucket + .transform_literal(&Datum::timestamp_micros(1510871468000000)) + .unwrap() + .unwrap(), + Datum::int(7) + ); + } + + #[test] + fn test_str_literal() { + let bucket = Bucket::new(100); + assert_eq!( + bucket + .transform_literal(&Datum::string("iceberg")) + .unwrap() + 
.unwrap(), + Datum::int(89) + ); + } + + #[test] + fn test_uuid_literal() { + let bucket = Bucket::new(100); + assert_eq!( + bucket + .transform_literal(&Datum::uuid( + "F79C3E09-677C-4BBD-A479-3F349CB785E7".parse().unwrap() + )) + .unwrap() + .unwrap(), + Datum::int(40) + ); + } + + #[test] + fn test_binary_literal() { + let bucket = Bucket::new(128); + assert_eq!( + bucket + .transform_literal(&Datum::binary(b"\x00\x01\x02\x03".to_vec())) + .unwrap() + .unwrap(), + Datum::int(57) + ); + } + + #[test] + fn test_fixed_literal() { + let bucket = Bucket::new(128); + assert_eq!( + bucket + .transform_literal(&Datum::fixed(b"foo".to_vec())) + .unwrap() + .unwrap(), + Datum::int(32) + ); + } } diff --git a/crates/iceberg/src/transform/identity.rs b/crates/iceberg/src/transform/identity.rs index d22c28fde..68e5a0b1a 100644 --- a/crates/iceberg/src/transform/identity.rs +++ b/crates/iceberg/src/transform/identity.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::Result; use arrow_array::ArrayRef; use super::TransformFunction; +use crate::Result; /// Return identity array. #[derive(Debug)] @@ -28,4 +28,73 @@ impl TransformFunction for Identity { fn transform(&self, input: ArrayRef) -> Result { Ok(input) } + + fn transform_literal(&self, input: &crate::spec::Datum) -> Result> { + Ok(Some(input.clone())) + } +} + +#[cfg(test)] +mod test { + use crate::spec::PrimitiveType::{ + Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, + TimestampNs, Timestamptz, TimestamptzNs, Uuid, + }; + use crate::spec::Type::{Primitive, Struct}; + use crate::spec::{NestedField, StructType, Transform}; + use crate::transform::test::TestTransformFixture; + + #[test] + fn test_identity_transform() { + let trans = Transform::Identity; + + let fixture = TestTransformFixture { + display: "identity".to_string(), + json: r#""identity""#.to_string(), + dedup_name: "identity".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Truncate(4), true), + (Transform::Truncate(2), true), + (Transform::Bucket(4), false), + (Transform::Void, false), + (Transform::Day, true), + ], + trans_types: vec![ + (Primitive(Binary), Some(Primitive(Binary))), + (Primitive(Date), Some(Primitive(Date))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + Some(Primitive(Decimal { + precision: 8, + scale: 5, + })), + ), + (Primitive(Fixed(8)), Some(Primitive(Fixed(8)))), + (Primitive(Int), Some(Primitive(Int))), + (Primitive(Long), Some(Primitive(Long))), + (Primitive(StringType), Some(Primitive(StringType))), + (Primitive(Uuid), Some(Primitive(Uuid))), + (Primitive(Time), Some(Primitive(Time))), + (Primitive(Timestamp), Some(Primitive(Timestamp))), + (Primitive(Timestamptz), Some(Primitive(Timestamptz))), + (Primitive(TimestampNs), Some(Primitive(TimestampNs))), + (Primitive(TimestamptzNs), Some(Primitive(TimestamptzNs))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } } diff --git a/crates/iceberg/src/transform/mod.rs b/crates/iceberg/src/transform/mod.rs index dead9db89..72b179754 100644 --- a/crates/iceberg/src/transform/mod.rs +++ b/crates/iceberg/src/transform/mod.rs @@ -16,9 +16,12 @@ // under the License. //! Transform function used to compute partition values. 
-use crate::{spec::Transform, Result}; + use arrow_array::ArrayRef; +use crate::spec::{Datum, Transform}; +use crate::{Error, ErrorKind, Result}; + mod bucket; mod identity; mod temporal; @@ -31,6 +34,18 @@ pub trait TransformFunction: Send { /// The implementation of this function will need to check and downcast the input to specific /// type. fn transform(&self, input: ArrayRef) -> Result; + /// transform_literal will take an input literal and transform it into a new literal. + fn transform_literal(&self, input: &Datum) -> Result>; + /// A thin wrapper around `transform_literal` + /// to return an error even when it's `None`. + fn transform_literal_result(&self, input: &Datum) -> Result { + self.transform_literal(input)?.ok_or_else(|| { + Error::new( + ErrorKind::Unexpected, + format!("Returns 'None' for literal {}", input), + ) + }) + } } /// BoxedTransformFunction is a boxed trait object of TransformFunction. @@ -53,3 +68,116 @@ pub fn create_transform_function(transform: &Transform) -> Result, + field: NestedField, + ) -> Self { + TestProjectionFixture { + transform, + name: name.into(), + field: Arc::new(field), + } + } + pub(crate) fn binary_predicate( + &self, + op: PredicateOperator, + literal: Datum, + ) -> BoundPredicate { + BoundPredicate::Binary(BinaryExpression::new( + op, + BoundReference::new( + self.name.clone(), + self.field.clone(), + Arc::new(StructAccessor::new(1, PrimitiveType::Boolean)), + ), + literal, + )) + } + pub(crate) fn set_predicate( + &self, + op: PredicateOperator, + literals: Vec, + ) -> BoundPredicate { + BoundPredicate::Set(SetExpression::new( + op, + BoundReference::new( + self.name.clone(), + self.field.clone(), + Arc::new(StructAccessor::new(1, PrimitiveType::Boolean)), + ), + HashSet::from_iter(literals), + )) + } + pub(crate) fn assert_projection( + &self, + predicate: &BoundPredicate, + expected: Option<&str>, + ) -> Result<()> { + let result = self.transform.project(&self.name, predicate)?; + match expected { + Some(exp) => assert_eq!(format!("{}", result.unwrap()), exp), + None => assert!(result.is_none()), + } + Ok(()) + } + } + + /// A utility struct, test fixture + /// used for testing the transform on `Transform` + pub(crate) struct TestTransformFixture { + pub display: String, + pub json: String, + pub dedup_name: String, + pub preserves_order: bool, + pub satisfies_order_of: Vec<(Transform, bool)>, + pub trans_types: Vec<(Type, Option)>, + } + + impl TestTransformFixture { + pub(crate) fn assert_transform(&self, trans: Transform) { + assert_eq!(self.display, format!("{trans}")); + assert_eq!(self.json, serde_json::to_string(&trans).unwrap()); + assert_eq!(trans, serde_json::from_str(self.json.as_str()).unwrap()); + assert_eq!(self.dedup_name, trans.dedup_name()); + assert_eq!(self.preserves_order, trans.preserves_order()); + + for (other_trans, satisfies_order_of) in &self.satisfies_order_of { + assert_eq!( + satisfies_order_of, + &trans.satisfies_order_of(other_trans), + "Failed to check satisfies order {}, {}, {}", + trans, + other_trans, + satisfies_order_of + ); + } + + for (input_type, result_type) in &self.trans_types { + assert_eq!(result_type, &trans.result_type(input_type).ok()); + } + } + } +} diff --git a/crates/iceberg/src/transform/temporal.rs b/crates/iceberg/src/transform/temporal.rs index 7b8deb17d..f326cfed6 100644 --- a/crates/iceberg/src/transform/temporal.rs +++ b/crates/iceberg/src/transform/temporal.rs @@ -15,36 +15,58 @@ // specific language governing permissions and limitations // under the License. 
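
Putting the pieces of this module together: callers obtain a `BoxedTransformFunction` from `create_transform_function` and go through `transform_literal_result` when a value is required rather than optional. A crate-internal, test-style sketch, consistent with the bucket literal tests earlier in this diff:

```rust
use crate::spec::{Datum, Transform};
use crate::transform::{create_transform_function, TransformFunction};
use crate::Result;

#[test]
fn sketch_boxed_transform_function() -> Result<()> {
    let func = create_transform_function(&Transform::Bucket(10))?;

    // `transform_literal` may return Ok(None); the `_result` wrapper turns that into an error.
    let bucketed = func.transform_literal_result(&Datum::long(34))?;
    assert_eq!(bucketed, Datum::int(9));

    Ok(())
}
```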
-use super::TransformFunction; -use crate::{Error, ErrorKind, Result}; -use arrow_arith::{ - arity::binary, - temporal::{month_dyn, year_dyn}, -}; +use std::sync::Arc; + +use arrow_arith::arity::binary; +use arrow_arith::temporal::{date_part, DatePart}; +use arrow_array::types::Date32Type; use arrow_array::{ - types::Date32Type, Array, ArrayRef, Date32Array, Int32Array, TimestampMicrosecondArray, + Array, ArrayRef, Date32Array, Int32Array, TimestampMicrosecondArray, TimestampNanosecondArray, }; use arrow_schema::{DataType, TimeUnit}; -use chrono::Datelike; -use std::sync::Arc; +use chrono::{DateTime, Datelike, Duration}; + +use super::TransformFunction; +use crate::spec::{Datum, PrimitiveLiteral, PrimitiveType}; +use crate::{Error, ErrorKind, Result}; -/// The number of days since unix epoch. -const DAY_SINCE_UNIX_EPOCH: i32 = 719163; /// Hour in one second. const HOUR_PER_SECOND: f64 = 1.0_f64 / 3600.0_f64; -/// Day in one second. -const DAY_PER_SECOND: f64 = 1.0_f64 / 24.0_f64 / 3600.0_f64; /// Year of unix epoch. const UNIX_EPOCH_YEAR: i32 = 1970; +/// One second in micros. +const MICROS_PER_SECOND: i64 = 1_000_000; +/// One second in nanos. +const NANOS_PER_SECOND: i64 = 1_000_000_000; /// Extract a date or timestamp year, as years from 1970 #[derive(Debug)] pub struct Year; +impl Year { + #[inline] + fn timestamp_to_year_micros(timestamp: i64) -> Result { + Ok(DateTime::from_timestamp_micros(timestamp) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Fail to convert timestamp to date in year transform", + ) + })? + .year() + - UNIX_EPOCH_YEAR) + } + + #[inline] + fn timestamp_to_year_nanos(timestamp: i64) -> Result { + Ok(DateTime::from_timestamp_nanos(timestamp).year() - UNIX_EPOCH_YEAR) + } +} + impl TransformFunction for Year { fn transform(&self, input: ArrayRef) -> Result { - let array = - year_dyn(&input).map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; + let array = date_part(&input, DatePart::Year) + .map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; Ok(Arc::::new( array .as_any() @@ -53,23 +75,93 @@ impl TransformFunction for Year { .unary(|v| v - UNIX_EPOCH_YEAR), )) } + + fn transform_literal(&self, input: &crate::spec::Datum) -> Result> { + let val = match (input.data_type(), input.literal()) { + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => { + Date32Type::to_naive_date(*v).year() - UNIX_EPOCH_YEAR + } + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_year_micros(*v)? + } + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_year_micros(*v)? + } + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_year_nanos(*v)? + } + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_year_nanos(*v)? 
+ } + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for year transform: {:?}", + input.data_type() + ), + )) + } + }; + Ok(Some(Datum::int(val))) + } } /// Extract a date or timestamp month, as months from 1970-01-01 #[derive(Debug)] pub struct Month; +impl Month { + #[inline] + fn timestamp_to_month_micros(timestamp: i64) -> Result { + // date: aaaa-aa-aa + // unix epoch date: 1970-01-01 + // if date > unix epoch date, delta month = (aa - 1) + 12 * (aaaa-1970) + // if date < unix epoch date, delta month = (12 - (aa - 1)) + 12 * (1970-aaaa-1) + let date = DateTime::from_timestamp_micros(timestamp).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Fail to convert timestamp to date in month transform", + ) + })?; + let unix_epoch_date = DateTime::from_timestamp_micros(0) + .expect("0 timestamp from unix epoch should be valid"); + if date > unix_epoch_date { + Ok((date.month0() as i32) + 12 * (date.year() - UNIX_EPOCH_YEAR)) + } else { + let delta = (12 - date.month0() as i32) + 12 * (UNIX_EPOCH_YEAR - date.year() - 1); + Ok(-delta) + } + } + + #[inline] + fn timestamp_to_month_nanos(timestamp: i64) -> Result { + // date: aaaa-aa-aa + // unix epoch date: 1970-01-01 + // if date > unix epoch date, delta month = (aa - 1) + 12 * (aaaa-1970) + // if date < unix epoch date, delta month = (12 - (aa - 1)) + 12 * (1970-aaaa-1) + let date = DateTime::from_timestamp_nanos(timestamp); + let unix_epoch_date = DateTime::from_timestamp_nanos(0); + if date > unix_epoch_date { + Ok((date.month0() as i32) + 12 * (date.year() - UNIX_EPOCH_YEAR)) + } else { + let delta = (12 - date.month0() as i32) + 12 * (UNIX_EPOCH_YEAR - date.year() - 1); + Ok(-delta) + } + } +} + impl TransformFunction for Month { fn transform(&self, input: ArrayRef) -> Result { - let year_array = - year_dyn(&input).map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; + let year_array = date_part(&input, DatePart::Year) + .map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; let year_array: Int32Array = year_array .as_any() .downcast_ref::() .unwrap() .unary(|v| 12 * (v - UNIX_EPOCH_YEAR)); - let month_array = - month_dyn(&input).map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; + let month_array = date_part(&input, DatePart::Month) + .map_err(|err| Error::new(ErrorKind::Unexpected, format!("{err}")))?; Ok(Arc::::new( binary( month_array.as_any().downcast_ref::().unwrap(), @@ -80,12 +172,104 @@ impl TransformFunction for Month { .unwrap(), )) } + + fn transform_literal(&self, input: &crate::spec::Datum) -> Result> { + let val = match (input.data_type(), input.literal()) { + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => { + (Date32Type::to_naive_date(*v).year() - UNIX_EPOCH_YEAR) * 12 + + Date32Type::to_naive_date(*v).month0() as i32 + } + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_month_micros(*v)? + } + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_month_micros(*v)? + } + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_month_nanos(*v)? + } + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(v)) => { + Self::timestamp_to_month_nanos(*v)? 
+ } + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for month transform: {:?}", + input.data_type() + ), + )) + } + }; + Ok(Some(Datum::int(val))) + } } /// Extract a date or timestamp day, as days from 1970-01-01 #[derive(Debug)] pub struct Day; +impl Day { + #[inline] + fn day_timestamp_micro(v: i64) -> Result<i32> { + let secs = v / MICROS_PER_SECOND; + + // For negative timestamps, compute the remainder on `v + 1` and subtract one extra day below, so the result floors toward the previous day. + let (nanos, offset) = if v >= 0 { + let nanos = (v.rem_euclid(MICROS_PER_SECOND) * 1_000) as u32; + let offset = 0i64; + (nanos, offset) + } else { + let v = v + 1; + let nanos = (v.rem_euclid(MICROS_PER_SECOND) * 1_000) as u32; + let offset = 1i64; + (nanos, offset) + }; + + let delta = Duration::new(secs, nanos).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Failed to create 'TimeDelta' from seconds {} and nanos {}", + secs, nanos + ), + ) + })?; + + let days = (delta.num_days() - offset) as i32; + + Ok(days) + } + + fn day_timestamp_nano(v: i64) -> Result<i32> { + let secs = v / NANOS_PER_SECOND; + + // Same flooring logic as `day_timestamp_micro`, but in nanosecond precision. + let (nanos, offset) = if v >= 0 { + let nanos = (v.rem_euclid(NANOS_PER_SECOND)) as u32; + let offset = 0i64; + (nanos, offset) + } else { + let v = v + 1; + let nanos = (v.rem_euclid(NANOS_PER_SECOND)) as u32; + let offset = 1i64; + (nanos, offset) + }; + + let delta = Duration::new(secs, nanos).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Failed to create 'TimeDelta' from seconds {} and nanos {}", + secs, nanos + ), + ) + })?; + + let days = (delta.num_days() - offset) as i32; + + Ok(days) + } +} + impl TransformFunction for Day { fn transform(&self, input: ArrayRef) -> Result<ArrayRef> { let res: Int32Array = match input.data_type() { @@ -93,16 +277,17 @@ .as_any() .downcast_ref::<TimestampMicrosecondArray>() .unwrap() - .unary(|v| -> i32 { (v as f64 / 1000.0 / 1000.0 * DAY_PER_SECOND) as i32 }), - DataType::Date32 => { - input - .as_any() - .downcast_ref::<Date32Array>() - .unwrap() - .unary(|v| -> i32 { - Date32Type::to_naive_date(v).num_days_from_ce() - DAY_SINCE_UNIX_EPOCH - }) - } + .try_unary(|v| -> Result<i32> { Self::day_timestamp_micro(v) })?, + DataType::Timestamp(TimeUnit::Nanosecond, _) => input + .as_any() + .downcast_ref::<TimestampNanosecondArray>() + .unwrap() + .try_unary(|v| -> Result<i32> { Self::day_timestamp_nano(v) })?, + DataType::Date32 => input + .as_any() + .downcast_ref::<Date32Array>() + .unwrap() + .unary(|v| -> i32 { v }), _ => { return Err(Error::new( ErrorKind::Unexpected, @@ -115,12 +300,50 @@ impl TransformFunction for Day { }; Ok(Arc::new(res)) } + + fn transform_literal(&self, input: &crate::spec::Datum) -> Result<Option<Datum>> { + let val = match (input.data_type(), input.literal()) { + (PrimitiveType::Date, PrimitiveLiteral::Int(v)) => *v, + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => Self::day_timestamp_micro(*v)?, + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(v)) => { + Self::day_timestamp_micro(*v)? + } + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(v)) => { + Self::day_timestamp_nano(*v)? + } + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(v)) => { + Self::day_timestamp_nano(*v)?
+ } + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for day transform: {:?}", + input.data_type() + ), + )) + } + }; + Ok(Some(Datum::int(val))) + } } /// Extract a timestamp hour, as hours from 1970-01-01 00:00:00 #[derive(Debug)] pub struct Hour; +impl Hour { + #[inline] + fn hour_timestamp_micro(v: i64) -> i32 { + (v as f64 / 1000.0 / 1000.0 * HOUR_PER_SECOND) as i32 + } + + #[inline] + fn hour_timestamp_nano(v: i64) -> i32 { + (v as f64 / 1_000_000.0 / 1000.0 * HOUR_PER_SECOND) as i32 + } +} + impl TransformFunction for Hour { fn transform(&self, input: ArrayRef) -> Result<ArrayRef> { let res: Int32Array = match input.data_type() { @@ -128,28 +351,1967 @@ impl TransformFunction for Hour { .as_any() .downcast_ref::<TimestampMicrosecondArray>() .unwrap() - .unary(|v| -> i32 { (v as f64 * HOUR_PER_SECOND / 1000.0 / 1000.0) as i32 }), + .unary(|v| -> i32 { Self::hour_timestamp_micro(v) }), _ => { - return Err(Error::new( - ErrorKind::Unexpected, + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, format!( - "Should not call internally for unsupported data type {:?}", + "Unsupported data type for hour transform: {:?}", input.data_type() ), - )) + )); } }; Ok(Arc::new(res)) } + + fn transform_literal(&self, input: &crate::spec::Datum) -> Result<Option<Datum>> { + let val = match (input.data_type(), input.literal()) { + (PrimitiveType::Timestamp, PrimitiveLiteral::Long(v)) => Self::hour_timestamp_micro(*v), + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(v)) => { + Self::hour_timestamp_micro(*v) + } + (PrimitiveType::TimestampNs, PrimitiveLiteral::Long(v)) => { + Self::hour_timestamp_nano(*v) + } + (PrimitiveType::TimestamptzNs, PrimitiveLiteral::Long(v)) => { + Self::hour_timestamp_nano(*v) + } + _ => { + return Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for hour transform: {:?}", + input.data_type() + ), + )) + } + }; + Ok(Some(Datum::int(val))) + } } #[cfg(test)] mod test { + use std::sync::Arc; + use arrow_array::{ArrayRef, Date32Array, Int32Array, TimestampMicrosecondArray}; use chrono::{NaiveDate, NaiveDateTime}; - use std::sync::Arc; - use crate::transform::TransformFunction; + use crate::expr::PredicateOperator; + use crate::spec::PrimitiveType::{ + Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, + TimestampNs, Timestamptz, TimestamptzNs, Uuid, + }; + use crate::spec::Type::{Primitive, Struct}; + use crate::spec::{Datum, NestedField, PrimitiveType, StructType, Transform, Type}; + use crate::transform::test::{TestProjectionFixture, TestTransformFixture}; + use crate::transform::{BoxedTransformFunction, TransformFunction}; + use crate::Result; + + #[test] + fn test_year_transform() { + let trans = Transform::Year; + + let fixture = TestTransformFixture { + display: "year".to_string(), + json: r#""year""#.to_string(), + dedup_name: "time".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Year, true), + (Transform::Month, false), + (Transform::Day, false), + (Transform::Hour, false), + (Transform::Void, false), + (Transform::Identity, false), + ], + trans_types: vec![ + (Primitive(Binary), None), + (Primitive(Date), Some(Primitive(Date))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + None, + ), + (Primitive(Fixed(8)), None), + (Primitive(Int), None), + (Primitive(Long), None), + (Primitive(StringType), None), + (Primitive(Uuid), None), + (Primitive(Time), None), + (Primitive(Timestamp), Some(Primitive(Date))),
(Primitive(Timestamptz), Some(Primitive(Date))), + (Primitive(TimestampNs), Some(Primitive(Date))), + (Primitive(TimestamptzNs), Some(Primitive(Date))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_month_transform() { + let trans = Transform::Month; + + let fixture = TestTransformFixture { + display: "month".to_string(), + json: r#""month""#.to_string(), + dedup_name: "time".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Year, true), + (Transform::Month, true), + (Transform::Day, false), + (Transform::Hour, false), + (Transform::Void, false), + (Transform::Identity, false), + ], + trans_types: vec![ + (Primitive(Binary), None), + (Primitive(Date), Some(Primitive(Date))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + None, + ), + (Primitive(Fixed(8)), None), + (Primitive(Int), None), + (Primitive(Long), None), + (Primitive(StringType), None), + (Primitive(Uuid), None), + (Primitive(Time), None), + (Primitive(Timestamp), Some(Primitive(Date))), + (Primitive(Timestamptz), Some(Primitive(Date))), + (Primitive(TimestampNs), Some(Primitive(Date))), + (Primitive(TimestamptzNs), Some(Primitive(Date))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_day_transform() { + let trans = Transform::Day; + + let fixture = TestTransformFixture { + display: "day".to_string(), + json: r#""day""#.to_string(), + dedup_name: "time".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Year, true), + (Transform::Month, true), + (Transform::Day, true), + (Transform::Hour, false), + (Transform::Void, false), + (Transform::Identity, false), + ], + trans_types: vec![ + (Primitive(Binary), None), + (Primitive(Date), Some(Primitive(Date))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + None, + ), + (Primitive(Fixed(8)), None), + (Primitive(Int), None), + (Primitive(Long), None), + (Primitive(StringType), None), + (Primitive(Uuid), None), + (Primitive(Time), None), + (Primitive(Timestamp), Some(Primitive(Date))), + (Primitive(Timestamptz), Some(Primitive(Date))), + (Primitive(TimestampNs), Some(Primitive(Date))), + (Primitive(TimestamptzNs), Some(Primitive(Date))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_hour_transform() { + let trans = Transform::Hour; + + let fixture = TestTransformFixture { + display: "hour".to_string(), + json: r#""hour""#.to_string(), + dedup_name: "time".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Year, true), + (Transform::Month, true), + (Transform::Day, true), + (Transform::Hour, true), + (Transform::Void, false), + (Transform::Identity, false), + ], + trans_types: vec![ + (Primitive(Binary), None), + (Primitive(Date), None), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + None, + ), + (Primitive(Fixed(8)), None), + (Primitive(Int), None), + (Primitive(Long), None), + (Primitive(StringType), None), + (Primitive(Uuid), None), + (Primitive(Time), None), + (Primitive(Timestamp), Some(Primitive(Int))), + (Primitive(Timestamptz), Some(Primitive(Int))), + (Primitive(TimestampNs), Some(Primitive(Int))), + 
(Primitive(TimestamptzNs), Some(Primitive(Int))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_projection_timestamp_hour_upper_bound() -> Result<()> { + // 420034 + let value = "2017-12-01T10:59:59.999999"; + // 412007 + let another = "2016-12-31T23:59:59.999999"; + + let fixture = TestProjectionFixture::new( + Transform::Hour, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 420035"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (420034, 412007)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_hour_lower_bound() -> Result<()> { + // 420034 + let value = "2017-12-01T10:00:00.000000"; + // 411288 + let another = "2016-12-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Hour, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 420033"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 420034"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (411288, 420034)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + 
Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_year_upper_bound() -> Result<()> { + let value = "2017-12-31T23:59:59.999999"; + let another = "2016-12-31T23:59:59.999999"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 48"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (47, 46)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_year_lower_bound() -> Result<()> { + let value = "2017-01-01T00:00:00.000000"; + let another = "2016-12-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 46"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (47, 46)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_month_negative_upper_bound() -> Result<()> { + let value = 
"1969-12-31T23:59:59.999999"; + let another = "1970-01-01T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -1"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name IN (-1, 0)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (0, -1)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_month_upper_bound() -> Result<()> { + let value = "2017-12-01T23:59:59.999999"; + let another = "2017-11-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (575, 574)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + Ok(()) + } + + #[test] + fn test_projection_timestamp_month_negative_lower_bound() -> Result<()> { + let value = "1969-01-01T00:00:00.000000"; + let another = "1969-03-01T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", 
Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= -11"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name IN (-12, -11)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (-10, -9, -12, -11)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_month_lower_bound() -> Result<()> { + let value = "2017-12-01T00:00:00.000000"; + let another = "2017-12-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 574"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (575)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_day_negative_upper_bound() -> Result<()> { + // -1 + let value = "1969-12-31T23:59:59.999999"; + // 0 + let another = "1970-01-01T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + 
Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -1"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name IN (-1, 0)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (0, -1)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_day_upper_bound() -> Result<()> { + // 17501 + let value = "2017-12-01T23:59:59.999999"; + // 17502 + let another = "2017-12-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 17502"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (17501, 17502)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_day_negative_lower_bound() -> Result<()> { + // -365 + let value = "1969-01-01T00:00:00.000000"; + // -364 + let another = "1969-01-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= -365"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + 
PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= -364"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -365"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= -365"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name IN (-364, -365)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (-363, -365, -364)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_day_lower_bound() -> Result<()> { + // 17501 + let value = "2017-12-01T00:00:00.000000"; + // 17502 + let another = "2017-12-02T00:00:00.000000"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 17500"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 17501"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (17501, 17502)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_timestamp_day_epoch() -> Result<()> { + // 0 + let value = "1970-01-01T00:00:00.00000"; + // 1 + let another = "1970-01-02T00:00:00.00000"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Timestamp)), + ); + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThan, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + 
&fixture.binary_predicate( + PredicateOperator::GreaterThan, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::timestamp_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::timestamp_from_str(value)?), + Some("name = 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::timestamp_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + Some("name IN (1, 0)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::timestamp_from_str(value)?, + Datum::timestamp_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_day_negative() -> Result<()> { + // -2 + let value = "1969-12-30"; + // -4 + let another = "1969-12-28"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= -3"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= -2"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= -1"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= -2"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = -2"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (-2, -4)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_day() -> Result<()> { + // 17167 + let value = "2017-01-01"; + // 17531 + let another = "2017-12-31"; + + let fixture = TestProjectionFixture::new( + Transform::Day, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 17166"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 17167"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 17168"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 17167"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, 
Datum::date_from_str(value)?), + Some("name = 17167"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (17531, 17167)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_month_negative_upper_bound() -> Result<()> { + // -1 => 1969-12 + let value = "1969-12-31"; + // -12 => 1969-01 + let another = "1969-01-01"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= -1"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name IN (-1, 0)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (-1, -12, -11, 0)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_month_upper_bound() -> Result<()> { + // 575 => 2017-12 + let value = "2017-12-31"; + // 564 => 2017-01 + let another = "2017-01-01"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 576"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + 
Datum::date_from_str(another)?, + ]), + Some("name IN (575, 564)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_month_negative_lower_bound() -> Result<()> { + // -12 => 1969-01 + let value = "1969-01-01"; + // -1 => 1969-12 + let another = "1969-12-31"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= -11"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= -12"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name IN (-12, -11)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (-1, -12, -11, 0)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_month_lower_bound() -> Result<()> { + // 575 => 2017-12 + let value = "2017-12-01"; + // 564 => 2017-01 + let another = "2017-01-01"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 574"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 575"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (575, 564)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_month_epoch() -> 
Result<()> { + // 0 => 1970-01 + let value = "1970-01-01"; + // -1 => 1969-12 + let another = "1969-12-31"; + + let fixture = TestProjectionFixture::new( + Transform::Month, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (0, -1)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_year_negative_upper_bound() -> Result<()> { + // -1 => 1969 + let value = "1969-12-31"; + let another = "1969-01-01"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= -1"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name IN (-1, 0)"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (0, -1)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_year_upper_bound() -> Result<()> { + // 47 => 2017 + let value = "2017-12-31"; + // 46 => 2016 + let another = "2016-01-01"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + 
Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 48"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (47, 46)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_year_negative_lower_bound() -> Result<()> { + // 0 => 1970 + let value = "1970-01-01"; + // -1 => 1969 + let another = "1969-12-31"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 0"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (0, -1)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_date_year_lower_bound() -> Result<()> { + // 47 => 2017 + let value = "2017-01-01"; + // 46 => 2016 + let another = "2016-12-31"; + + let fixture = TestProjectionFixture::new( + Transform::Year, + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Date)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::date_from_str(value)?), + Some("name <= 46"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name <= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::date_from_str(value)?), + Some("name >= 47"), + )?; + + fixture.assert_projection( + 
&fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::date_from_str(value)?, + ), + Some("name >= 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::date_from_str(value)?), + Some("name = 47"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::date_from_str(value)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + Some("name IN (47, 46)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::date_from_str(value)?, + Datum::date_from_str(another)?, + ]), + None, + )?; + + Ok(()) + } #[test] fn test_transform_years() { @@ -161,6 +2323,7 @@ mod test { NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), NaiveDate::from_ymd_opt(2030, 1, 1).unwrap(), NaiveDate::from_ymd_opt(2060, 1, 1).unwrap(), + NaiveDate::from_ymd_opt(1969, 1, 1).unwrap(), ]; let date_array: ArrayRef = Arc::new(Date32Array::from( ori_date @@ -173,11 +2336,12 @@ mod test { )); let res = year.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), 0); assert_eq!(res.value(1), 30); assert_eq!(res.value(2), 60); assert_eq!(res.value(3), 90); + assert_eq!(res.value(4), -1); // Test TimestampMicrosecond let ori_timestamp = vec![ @@ -189,6 +2353,8 @@ mod test { .unwrap(), NaiveDateTime::parse_from_str("2060-01-01 11:30:42.123", "%Y-%m-%d %H:%M:%S.%f") .unwrap(), + NaiveDateTime::parse_from_str("1969-01-01 00:00:00.00", "%Y-%m-%d %H:%M:%S.%f") + .unwrap(), ]; let date_array: ArrayRef = Arc::new(TimestampMicrosecondArray::from( ori_timestamp @@ -209,11 +2375,123 @@ mod test { )); let res = year.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), 0); assert_eq!(res.value(1), 30); assert_eq!(res.value(2), 60); assert_eq!(res.value(3), 90); + assert_eq!(res.value(4), -1); + } + + fn test_timestamp_and_tz_transform( + time: &str, + transform: &BoxedTransformFunction, + expect: Datum, + ) { + let timestamp = Datum::timestamp_micros( + NaiveDateTime::parse_from_str(time, "%Y-%m-%d %H:%M:%S.%f") + .unwrap() + .and_utc() + .timestamp_micros(), + ); + let timestamp_tz = Datum::timestamptz_micros( + NaiveDateTime::parse_from_str(time, "%Y-%m-%d %H:%M:%S.%f") + .unwrap() + .and_utc() + .timestamp_micros(), + ); + let res = transform.transform_literal(×tamp).unwrap().unwrap(); + assert_eq!(res, expect); + let res = transform.transform_literal(×tamp_tz).unwrap().unwrap(); + assert_eq!(res, expect); + } + + fn test_timestamp_and_tz_transform_using_i64( + time: i64, + transform: &BoxedTransformFunction, + expect: Datum, + ) { + let timestamp = Datum::timestamp_micros(time); + let timestamp_tz = Datum::timestamptz_micros(time); + let res = transform.transform_literal(×tamp).unwrap().unwrap(); + assert_eq!(res, expect); + let res = transform.transform_literal(×tamp_tz).unwrap().unwrap(); + assert_eq!(res, expect); + } + + fn test_date(date: i32, transform: &BoxedTransformFunction, expect: Datum) { + let date = Datum::date(date); + let res = transform.transform_literal(&date).unwrap().unwrap(); + assert_eq!(res, expect); + } + + fn test_timestamp_ns_and_tz_transform( + time: &str, + transform: &BoxedTransformFunction, + expect: Datum, + ) { + let 
timestamp_ns = Datum::timestamp_nanos( + NaiveDateTime::parse_from_str(time, "%Y-%m-%d %H:%M:%S.%f") + .unwrap() + .and_utc() + .timestamp_nanos_opt() + .unwrap(), + ); + let timestamptz_ns = Datum::timestamptz_nanos( + NaiveDateTime::parse_from_str(time, "%Y-%m-%d %H:%M:%S.%f") + .unwrap() + .and_utc() + .timestamp_nanos_opt() + .unwrap(), + ); + let res = transform.transform_literal(×tamp_ns).unwrap().unwrap(); + assert_eq!(res, expect); + let res = transform + .transform_literal(×tamptz_ns) + .unwrap() + .unwrap(); + assert_eq!(res, expect); + } + + fn test_timestamp_ns_and_tz_transform_using_i64( + time: i64, + transform: &BoxedTransformFunction, + expect: Datum, + ) { + let timestamp_ns = Datum::timestamp_nanos(time); + let timestamptz_ns = Datum::timestamptz_nanos(time); + let res = transform.transform_literal(×tamp_ns).unwrap().unwrap(); + assert_eq!(res, expect); + let res = transform + .transform_literal(×tamptz_ns) + .unwrap() + .unwrap(); + assert_eq!(res, expect); + } + + #[test] + fn test_transform_year_literal() { + let year = Box::new(super::Year) as BoxedTransformFunction; + + // Test Date32 + test_date(18628, &year, Datum::int(2021 - super::UNIX_EPOCH_YEAR)); + test_date(-365, &year, Datum::int(-1)); + + // Test TimestampMicrosecond + test_timestamp_and_tz_transform_using_i64( + 186280000000, + &year, + Datum::int(1970 - super::UNIX_EPOCH_YEAR), + ); + test_timestamp_and_tz_transform("1969-01-01 00:00:00.00", &year, Datum::int(-1)); + + // Test TimestampNanosecond + test_timestamp_ns_and_tz_transform_using_i64( + 186280000000, + &year, + Datum::int(1970 - super::UNIX_EPOCH_YEAR), + ); + test_timestamp_ns_and_tz_transform("1969-01-01 00:00:00.00", &year, Datum::int(-1)); } #[test] @@ -226,6 +2504,7 @@ mod test { NaiveDate::from_ymd_opt(2000, 4, 1).unwrap(), NaiveDate::from_ymd_opt(2030, 7, 1).unwrap(), NaiveDate::from_ymd_opt(2060, 10, 1).unwrap(), + NaiveDate::from_ymd_opt(1969, 12, 1).unwrap(), ]; let date_array: ArrayRef = Arc::new(Date32Array::from( ori_date @@ -238,11 +2517,12 @@ mod test { )); let res = month.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), 0); assert_eq!(res.value(1), 30 * 12 + 3); assert_eq!(res.value(2), 60 * 12 + 6); assert_eq!(res.value(3), 90 * 12 + 9); + assert_eq!(res.value(4), -1); // Test TimestampMicrosecond let ori_timestamp = vec![ @@ -254,6 +2534,8 @@ mod test { .unwrap(), NaiveDateTime::parse_from_str("2060-10-01 11:30:42.123", "%Y-%m-%d %H:%M:%S.%f") .unwrap(), + NaiveDateTime::parse_from_str("1969-12-01 00:00:00.00", "%Y-%m-%d %H:%M:%S.%f") + .unwrap(), ]; let date_array: ArrayRef = Arc::new(TimestampMicrosecondArray::from( ori_timestamp @@ -274,11 +2556,47 @@ mod test { )); let res = month.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), 0); assert_eq!(res.value(1), 30 * 12 + 3); assert_eq!(res.value(2), 60 * 12 + 6); assert_eq!(res.value(3), 90 * 12 + 9); + assert_eq!(res.value(4), -1); + } + + #[test] + fn test_transform_month_literal() { + let month = Box::new(super::Month) as BoxedTransformFunction; + + // Test Date32 + test_date( + 18628, + &month, + Datum::int((2021 - super::UNIX_EPOCH_YEAR) * 12), + ); + test_date(-31, &month, Datum::int(-1)); + + // Test TimestampMicrosecond + test_timestamp_and_tz_transform_using_i64( + 186280000000, + &month, + Datum::int((1970 - super::UNIX_EPOCH_YEAR) * 12), + ); + 
test_timestamp_and_tz_transform("1969-12-01 23:00:00.00", &month, Datum::int(-1)); + test_timestamp_and_tz_transform("2017-12-01 00:00:00.00", &month, Datum::int(575)); + test_timestamp_and_tz_transform("1970-01-01 00:00:00.00", &month, Datum::int(0)); + test_timestamp_and_tz_transform("1969-12-31 00:00:00.00", &month, Datum::int(-1)); + + // Test TimestampNanosecond + test_timestamp_ns_and_tz_transform_using_i64( + 186280000000, + &month, + Datum::int((1970 - super::UNIX_EPOCH_YEAR) * 12), + ); + test_timestamp_ns_and_tz_transform("1969-12-01 23:00:00.00", &month, Datum::int(-1)); + test_timestamp_ns_and_tz_transform("2017-12-01 00:00:00.00", &month, Datum::int(575)); + test_timestamp_ns_and_tz_transform("1970-01-01 00:00:00.00", &month, Datum::int(0)); + test_timestamp_ns_and_tz_transform("1969-12-31 00:00:00.00", &month, Datum::int(-1)); } #[test] @@ -289,6 +2607,7 @@ mod test { NaiveDate::from_ymd_opt(2000, 4, 1).unwrap(), NaiveDate::from_ymd_opt(2030, 7, 1).unwrap(), NaiveDate::from_ymd_opt(2060, 10, 1).unwrap(), + NaiveDate::from_ymd_opt(1969, 12, 31).unwrap(), ]; let expect_day = ori_date .clone() @@ -311,11 +2630,12 @@ mod test { )); let res = day.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), expect_day[0]); assert_eq!(res.value(1), expect_day[1]); assert_eq!(res.value(2), expect_day[2]); assert_eq!(res.value(3), expect_day[3]); + assert_eq!(res.value(4), -1); // Test TimestampMicrosecond let ori_timestamp = vec![ @@ -327,6 +2647,8 @@ mod test { .unwrap(), NaiveDateTime::parse_from_str("2060-10-01 11:30:42.123", "%Y-%m-%d %H:%M:%S.%f") .unwrap(), + NaiveDateTime::parse_from_str("1969-12-31 00:00:00.00", "%Y-%m-%d %H:%M:%S.%f") + .unwrap(), ]; let date_array: ArrayRef = Arc::new(TimestampMicrosecondArray::from( ori_timestamp @@ -347,11 +2669,30 @@ mod test { )); let res = day.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); assert_eq!(res.value(0), expect_day[0]); assert_eq!(res.value(1), expect_day[1]); assert_eq!(res.value(2), expect_day[2]); assert_eq!(res.value(3), expect_day[3]); + assert_eq!(res.value(4), -1); + } + + #[test] + fn test_transform_days_literal() { + let day = Box::new(super::Day) as BoxedTransformFunction; + // Test Date32 + test_date(18628, &day, Datum::int(18628)); + test_date(-31, &day, Datum::int(-31)); + + // Test TimestampMicrosecond + test_timestamp_and_tz_transform_using_i64(1512151975038194, &day, Datum::int(17501)); + test_timestamp_and_tz_transform_using_i64(-115200000000, &day, Datum::int(-2)); + test_timestamp_and_tz_transform("2017-12-01 10:30:42.123", &day, Datum::int(17501)); + + // Test TimestampNanosecond + test_timestamp_ns_and_tz_transform_using_i64(1512151975038194, &day, Datum::int(17)); + test_timestamp_ns_and_tz_transform_using_i64(-115200000000, &day, Datum::int(-1)); + test_timestamp_ns_and_tz_transform("2017-12-01 10:30:42.123", &day, Datum::int(17501)); } #[test] @@ -366,6 +2707,8 @@ mod test { .unwrap(), NaiveDateTime::parse_from_str("2060-09-01 05:03:23.123", "%Y-%m-%d %H:%M:%S.%f") .unwrap(), + NaiveDateTime::parse_from_str("1969-12-31 23:00:00.00", "%Y-%m-%d %H:%M:%S.%f") + .unwrap(), ]; let expect_hour = ori_timestamp .clone() @@ -403,10 +2746,24 @@ mod test { )); let res = hour.transform(date_array).unwrap(); let res = res.as_any().downcast_ref::().unwrap(); - assert_eq!(res.len(), 4); + assert_eq!(res.len(), 5); 
assert_eq!(res.value(0), expect_hour[0]); assert_eq!(res.value(1), expect_hour[1]); assert_eq!(res.value(2), expect_hour[2]); assert_eq!(res.value(3), expect_hour[3]); + assert_eq!(res.value(4), -1); + } + + #[test] + fn test_transform_hours_literal() { + let hour = Box::new(super::Hour) as BoxedTransformFunction; + + // Test TimestampMicrosecond + test_timestamp_and_tz_transform("2017-12-01 18:00:00.00", &hour, Datum::int(420042)); + test_timestamp_and_tz_transform("1969-12-31 23:00:00.00", &hour, Datum::int(-1)); + + // Test TimestampNanosecond + test_timestamp_ns_and_tz_transform("2017-12-01 18:00:00.00", &hour, Datum::int(420042)); + test_timestamp_ns_and_tz_transform("1969-12-31 23:00:00.00", &hour, Datum::int(-1)); } } diff --git a/crates/iceberg/src/transform/truncate.rs b/crates/iceberg/src/transform/truncate.rs index a8ebda8aa..83f769e27 100644 --- a/crates/iceberg/src/transform/truncate.rs +++ b/crates/iceberg/src/transform/truncate.rs @@ -20,9 +20,9 @@ use std::sync::Arc; use arrow_array::ArrayRef; use arrow_schema::DataType; -use crate::Error; - use super::TransformFunction; +use crate::spec::{Datum, PrimitiveLiteral}; +use crate::Error; #[derive(Debug)] pub struct Truncate { @@ -34,12 +34,28 @@ impl Truncate { Self { width } } - fn truncate_str_by_char(s: &str, max_chars: usize) -> &str { - match s.char_indices().nth(max_chars) { + #[inline] + fn truncate_str(s: &str, width: usize) -> &str { + match s.char_indices().nth(width) { None => s, Some((idx, _)) => &s[..idx], } } + + #[inline] + fn truncate_i32(v: i32, width: i32) -> i32 { + v - v.rem_euclid(width) + } + + #[inline] + fn truncate_i64(v: i64, width: i64) -> i64 { + v - (((v % width) + width) % width) + } + + #[inline] + fn truncate_decimal_i128(v: i128, width: i128) -> i128 { + v - (((v % width) + width) % width) + } } impl TransformFunction for Truncate { @@ -56,7 +72,7 @@ impl TransformFunction for Truncate { .as_any() .downcast_ref::() .unwrap() - .unary(|v| v - v.rem_euclid(width)); + .unary(|v| Self::truncate_i32(v, width)); Ok(Arc::new(res)) } DataType::Int64 => { @@ -65,7 +81,7 @@ impl TransformFunction for Truncate { .as_any() .downcast_ref::() .unwrap() - .unary(|v| v - (((v % width) + width) % width)); + .unary(|v| Self::truncate_i64(v, width)); Ok(Arc::new(res)) } DataType::Decimal128(precision, scale) => { @@ -74,7 +90,7 @@ impl TransformFunction for Truncate { .as_any() .downcast_ref::() .unwrap() - .unary(|v| v - (((v % width) + width) % width)) + .unary(|v| Self::truncate_decimal_i128(v, width)) .with_precision_and_scale(*precision, *scale) .map_err(|err| Error::new(crate::ErrorKind::Unexpected, format!("{err}")))?; Ok(Arc::new(res)) @@ -87,7 +103,7 @@ impl TransformFunction for Truncate { .downcast_ref::() .unwrap() .iter() - .map(|v| v.map(|v| Self::truncate_str_by_char(v, len))), + .map(|v| v.map(|v| Self::truncate_str(v, len))), ); Ok(Arc::new(res)) } @@ -99,11 +115,50 @@ impl TransformFunction for Truncate { .downcast_ref::() .unwrap() .iter() - .map(|v| v.map(|v| Self::truncate_str_by_char(v, len))), + .map(|v| v.map(|v| Self::truncate_str(v, len))), ); Ok(Arc::new(res)) } - _ => unreachable!("Truncate transform only supports (int,long,decimal,string) types"), + _ => Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for truncate transform: {:?}", + input.data_type() + ), + )), + } + } + + fn transform_literal(&self, input: &Datum) -> crate::Result> { + match input.literal() { + PrimitiveLiteral::Int(v) => Ok(Some({ + let width: i32 = 
self.width.try_into().map_err(|_| { + Error::new( + crate::ErrorKind::DataInvalid, + "width is failed to convert to i32 when truncate Int32Array", + ) + })?; + Datum::int(Self::truncate_i32(*v, width)) + })), + PrimitiveLiteral::Long(v) => Ok(Some({ + let width = self.width as i64; + Datum::long(Self::truncate_i64(*v, width)) + })), + PrimitiveLiteral::Int128(v) => Ok(Some({ + let width = self.width as i128; + Datum::decimal(Self::truncate_decimal_i128(*v, width))? + })), + PrimitiveLiteral::String(v) => Ok(Some({ + let len = self.width as usize; + Datum::string(Self::truncate_str(v, len).to_string()) + })), + _ => Err(crate::Error::new( + crate::ErrorKind::FeatureUnsupported, + format!( + "Unsupported data type for truncate transform: {:?}", + input.data_type() + ), + )), } } } @@ -112,11 +167,522 @@ impl TransformFunction for Truncate { mod test { use std::sync::Arc; - use arrow_array::{ - builder::PrimitiveBuilder, types::Decimal128Type, Decimal128Array, Int32Array, Int64Array, - }; + use arrow_array::builder::PrimitiveBuilder; + use arrow_array::types::Decimal128Type; + use arrow_array::{Decimal128Array, Int32Array, Int64Array}; + use crate::expr::PredicateOperator; + use crate::spec::PrimitiveType::{ + Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, + TimestampNs, Timestamptz, TimestamptzNs, Uuid, + }; + use crate::spec::Type::{Primitive, Struct}; + use crate::spec::{Datum, NestedField, PrimitiveType, StructType, Transform, Type}; + use crate::transform::test::{TestProjectionFixture, TestTransformFixture}; use crate::transform::TransformFunction; + use crate::Result; + + #[test] + fn test_truncate_transform() { + let trans = Transform::Truncate(4); + + let fixture = TestTransformFixture { + display: "truncate[4]".to_string(), + json: r#""truncate[4]""#.to_string(), + dedup_name: "truncate[4]".to_string(), + preserves_order: true, + satisfies_order_of: vec![ + (Transform::Truncate(4), true), + (Transform::Truncate(2), false), + (Transform::Bucket(4), false), + (Transform::Void, false), + (Transform::Day, false), + ], + trans_types: vec![ + (Primitive(Binary), Some(Primitive(Binary))), + (Primitive(Date), None), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + Some(Primitive(Decimal { + precision: 8, + scale: 5, + })), + ), + (Primitive(Fixed(8)), None), + (Primitive(Int), Some(Primitive(Int))), + (Primitive(Long), Some(Primitive(Long))), + (Primitive(StringType), Some(Primitive(StringType))), + (Primitive(Uuid), None), + (Primitive(Time), None), + (Primitive(Timestamp), None), + (Primitive(Timestamptz), None), + (Primitive(TimestampNs), None), + (Primitive(TimestamptzNs), None), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + None, + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_projection_truncate_string_rewrite_op() -> Result<()> { + let value = "abcde"; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(5), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::String)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::StartsWith, Datum::string(value)), + Some(r#"name = "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotStartsWith, Datum::string(value)), + Some(r#"name != "abcde""#), + )?; + + let value = "abcdefg"; + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::StartsWith, Datum::string(value)), + 
Some(r#"name STARTS WITH "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotStartsWith, Datum::string(value)), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_string() -> Result<()> { + let value = "abcdefg"; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(5), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::String)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::string(value)), + Some(r#"name <= "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::string(value)), + Some(r#"name <= "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThan, Datum::string(value)), + Some(r#"name >= "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::string(value)), + Some(r#"name >= "abcde""#), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::string(value)), + Some(r#"name = "abcde""#), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::string(value), + Datum::string(format!("{}abc", value)), + ]), + Some(r#"name IN ("abcde")"#), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::string(value), + Datum::string(format!("{}abc", value)), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_upper_bound_decimal() -> Result<()> { + let prev = "98.99"; + let curr = "99.99"; + let next = "100.99"; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required( + 1, + "value", + Type::Primitive(PrimitiveType::Decimal { + precision: 9, + scale: 2, + }), + ), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::decimal_from_str(curr)?), + Some("name <= 9990"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::decimal_from_str(curr)?, + ), + Some("name <= 9990"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::decimal_from_str(curr)?, + ), + Some("name >= 9990"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::decimal_from_str(curr)?), + Some("name = 9990"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::decimal_from_str(curr)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::decimal_from_str(prev)?, + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(next)?, + ]), + Some("name IN (9890, 9990, 10090)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(next)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_lower_bound_decimal() -> Result<()> { + let prev = "99.00"; + let curr = "100.00"; + let next = "101.00"; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required( + 1, + "value", + Type::Primitive(PrimitiveType::Decimal { + precision: 9, + scale: 2, + }), + ), + ); + + fixture.assert_projection( + 
&fixture.binary_predicate(PredicateOperator::LessThan, Datum::decimal_from_str(curr)?), + Some("name <= 9990"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::LessThanOrEq, + Datum::decimal_from_str(curr)?, + ), + Some("name <= 10000"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate( + PredicateOperator::GreaterThanOrEq, + Datum::decimal_from_str(curr)?, + ), + Some("name >= 10000"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::decimal_from_str(curr)?), + Some("name = 10000"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::decimal_from_str(curr)?), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::decimal_from_str(prev)?, + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(next)?, + ]), + Some("name IN (10000, 10100, 9900)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::decimal_from_str(curr)?, + Datum::decimal_from_str(next)?, + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_upper_bound_long() -> Result<()> { + let value = 99i64; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Long)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::long(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::long(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::long(value)), + Some("name >= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::long(value)), + Some("name = 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::long(value - 1), + Datum::long(value), + Datum::long(value + 1), + ]), + Some("name IN (100, 90)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::long(value), + Datum::long(value + 1), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_lower_bound_long() -> Result<()> { + let value = 100i64; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Long)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::long(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::long(value)), + Some("name <= 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::long(value)), + Some("name >= 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::long(value)), + Some("name = 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::long(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::long(value - 1), + Datum::long(value), + Datum::long(value + 
1), + ]), + Some("name IN (100, 90)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::long(value), + Datum::long(value + 1), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_upper_bound_integer() -> Result<()> { + let value = 99; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Int)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::int(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::int(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::int(value)), + Some("name >= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::int(value)), + Some("name = 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::int(value - 1), + Datum::int(value), + Datum::int(value + 1), + ]), + Some("name IN (100, 90)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::int(value), + Datum::int(value + 1), + ]), + None, + )?; + + Ok(()) + } + + #[test] + fn test_projection_truncate_lower_bound_integer() -> Result<()> { + let value = 100; + + let fixture = TestProjectionFixture::new( + Transform::Truncate(10), + "name", + NestedField::required(1, "value", Type::Primitive(PrimitiveType::Int)), + ); + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThan, Datum::int(value)), + Some("name <= 90"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::LessThanOrEq, Datum::int(value)), + Some("name <= 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::GreaterThanOrEq, Datum::int(value)), + Some("name >= 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::Eq, Datum::int(value)), + Some("name = 100"), + )?; + + fixture.assert_projection( + &fixture.binary_predicate(PredicateOperator::NotEq, Datum::int(value)), + None, + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::In, vec![ + Datum::int(value - 1), + Datum::int(value), + Datum::int(value + 1), + ]), + Some("name IN (100, 90)"), + )?; + + fixture.assert_projection( + &fixture.set_predicate(PredicateOperator::NotIn, vec![ + Datum::int(value), + Datum::int(value + 1), + ]), + None, + )?; + + Ok(()) + } // Test case ref from: https://iceberg.apache.org/spec/#truncate-transform-details #[test] @@ -187,32 +753,74 @@ mod test { fn test_string_truncate() { let test1 = "イロハニホヘト"; let test1_2_expected = "イロ"; - assert_eq!( - super::Truncate::truncate_str_by_char(test1, 2), - test1_2_expected - ); + assert_eq!(super::Truncate::truncate_str(test1, 2), test1_2_expected); let test1_3_expected = "イロハ"; - assert_eq!( - super::Truncate::truncate_str_by_char(test1, 3), - test1_3_expected - ); + assert_eq!(super::Truncate::truncate_str(test1, 3), test1_3_expected); let test2 = "щщаεはчωいにπάほхεろへσκζ"; let test2_7_expected = "щщаεはчω"; - assert_eq!( - super::Truncate::truncate_str_by_char(test2, 7), - test2_7_expected - ); + 
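+        // truncate_str counts Unicode scalar values (Rust chars), not bytes: a width of 7
+        // keeps exactly seven code points of this mixed Cyrillic/Greek/Japanese string even
+        // though they occupy far more than seven bytes, and the cases below confirm the cut
+        // never lands inside a multi-byte character.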
assert_eq!(super::Truncate::truncate_str(test2, 7), test2_7_expected); let test3 = "\u{FFFF}\u{FFFF}"; - assert_eq!(super::Truncate::truncate_str_by_char(test3, 2), test3); + assert_eq!(super::Truncate::truncate_str(test3, 2), test3); let test4 = "\u{10000}\u{10000}"; let test4_1_expected = "\u{10000}"; - assert_eq!( - super::Truncate::truncate_str_by_char(test4, 1), - test4_1_expected - ); + assert_eq!(super::Truncate::truncate_str(test4, 1), test4_1_expected); + } + + #[test] + fn test_literal_int() { + let input = Datum::int(1); + let res = super::Truncate::new(10) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::int(0),); + + let input = Datum::int(-1); + let res = super::Truncate::new(10) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::int(-10),); + } + + #[test] + fn test_literal_long() { + let input = Datum::long(1); + let res = super::Truncate::new(10) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::long(0),); + + let input = Datum::long(-1); + let res = super::Truncate::new(10) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::long(-10),); + } + + #[test] + fn test_decimal_literal() { + let input = Datum::decimal(1065).unwrap(); + let res = super::Truncate::new(50) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::decimal(1050).unwrap(),); + } + + #[test] + fn test_string_literal() { + let input = Datum::string("iceberg".to_string()); + let res = super::Truncate::new(3) + .transform_literal(&input) + .unwrap() + .unwrap(); + assert_eq!(res, Datum::string("ice".to_string()),); } } diff --git a/crates/iceberg/src/transform/void.rs b/crates/iceberg/src/transform/void.rs index d419430ba..5d429a593 100644 --- a/crates/iceberg/src/transform/void.rs +++ b/crates/iceberg/src/transform/void.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-use crate::Result; use arrow_array::{new_null_array, ArrayRef}; use super::TransformFunction; +use crate::Result; #[derive(Debug)] pub struct Void {} @@ -27,4 +27,130 @@ impl TransformFunction for Void { fn transform(&self, input: ArrayRef) -> Result { Ok(new_null_array(input.data_type(), input.len())) } + + fn transform_literal(&self, _input: &crate::spec::Datum) -> Result> { + Ok(None) + } +} + +#[cfg(test)] +mod test { + use crate::spec::PrimitiveType::{ + Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, + TimestampNs, Timestamptz, TimestamptzNs, Uuid, + }; + use crate::spec::Type::{Primitive, Struct}; + use crate::spec::{NestedField, StructType, Transform}; + use crate::transform::test::TestTransformFixture; + + #[test] + fn test_void_transform() { + let trans = Transform::Void; + + let fixture = TestTransformFixture { + display: "void".to_string(), + json: r#""void""#.to_string(), + dedup_name: "void".to_string(), + preserves_order: false, + satisfies_order_of: vec![ + (Transform::Year, false), + (Transform::Month, false), + (Transform::Day, false), + (Transform::Hour, false), + (Transform::Void, true), + (Transform::Identity, false), + ], + trans_types: vec![ + (Primitive(Binary), Some(Primitive(Binary))), + (Primitive(Date), Some(Primitive(Date))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + Some(Primitive(Decimal { + precision: 8, + scale: 5, + })), + ), + (Primitive(Fixed(8)), Some(Primitive(Fixed(8)))), + (Primitive(Int), Some(Primitive(Int))), + (Primitive(Long), Some(Primitive(Long))), + (Primitive(StringType), Some(Primitive(StringType))), + (Primitive(Uuid), Some(Primitive(Uuid))), + (Primitive(Time), Some(Primitive(Time))), + (Primitive(Timestamp), Some(Primitive(Timestamp))), + (Primitive(Timestamptz), Some(Primitive(Timestamptz))), + (Primitive(TimestampNs), Some(Primitive(TimestampNs))), + (Primitive(TimestamptzNs), Some(Primitive(TimestamptzNs))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + Some(Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()]))), + ), + ], + }; + + fixture.assert_transform(trans); + } + + #[test] + fn test_known_transform() { + let trans = Transform::Unknown; + + let fixture = TestTransformFixture { + display: "unknown".to_string(), + json: r#""unknown""#.to_string(), + dedup_name: "unknown".to_string(), + preserves_order: false, + satisfies_order_of: vec![ + (Transform::Year, false), + (Transform::Month, false), + (Transform::Day, false), + (Transform::Hour, false), + (Transform::Void, false), + (Transform::Identity, false), + (Transform::Unknown, true), + ], + trans_types: vec![ + (Primitive(Binary), Some(Primitive(StringType))), + (Primitive(Date), Some(Primitive(StringType))), + ( + Primitive(Decimal { + precision: 8, + scale: 5, + }), + Some(Primitive(StringType)), + ), + (Primitive(Fixed(8)), Some(Primitive(StringType))), + (Primitive(Int), Some(Primitive(StringType))), + (Primitive(Long), Some(Primitive(StringType))), + (Primitive(StringType), Some(Primitive(StringType))), + (Primitive(Uuid), Some(Primitive(StringType))), + (Primitive(Time), Some(Primitive(StringType))), + (Primitive(Timestamp), Some(Primitive(StringType))), + (Primitive(Timestamptz), Some(Primitive(StringType))), + ( + Struct(StructType::new(vec![NestedField::optional( + 1, + "a", + Primitive(Timestamp), + ) + .into()])), + Some(Primitive(StringType)), + ), + ], + }; + + fixture.assert_transform(trans); + } } 
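The `Void` changes above have two entry points: the array path returns an all-null array with the input's type and length, and the new `transform_literal` path reports that no value is produced. A minimal sketch of both behaviours, written as if it sat in the same test module (the `super::Void` path, the `crate::` imports, and the function name are assumptions for illustration, not part of the patch):

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array};

use crate::spec::Datum;
use crate::transform::TransformFunction;

fn void_behaviour_sketch() -> crate::Result<()> {
    let void = super::Void {};

    // Array path: same length and type as the input, but every slot is null.
    let input: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
    let out = void.transform(input)?;
    assert_eq!(out.len(), 3);
    assert_eq!(out.null_count(), 3);

    // Literal path: a void-transformed literal carries no value at all.
    assert!(void.transform_literal(&Datum::int(42))?.is_none());

    Ok(())
}
```

Either path keeps the partition result type intact, which is why the `trans_types` table in the fixture above maps every primitive type to itself for `Transform::Void`.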
diff --git a/crates/iceberg/src/utils.rs b/crates/iceberg/src/utils.rs
new file mode 100644
index 000000000..70514cccb
--- /dev/null
+++ b/crates/iceberg/src/utils.rs
@@ -0,0 +1,42 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::num::NonZero;
+
+// Use a default value of 1 as the safest option.
+// See https://doc.rust-lang.org/std/thread/fn.available_parallelism.html#limitations
+// for more details.
+const DEFAULT_PARALLELISM: usize = 1;
+
+/// Uses [`std::thread::available_parallelism`] in order to
+/// retrieve an estimate of the default amount of parallelism
+/// that should be used. Note that [`std::thread::available_parallelism`]
+/// returns a `Result` as it can fail, so here we use
+/// a default value instead.
+/// Note: we don't use a `OnceCell` or `LazyCell` here as there
+/// are circumstances where the level of available
+/// parallelism can change during the lifetime of an executing
+/// process, but this should not be called in a hot loop.
+pub(crate) fn available_parallelism() -> NonZero<usize> {
+    std::thread::available_parallelism().unwrap_or_else(|_err| {
+        // Failed to get the level of parallelism.
+        // TODO: log/trace when this fallback occurs.
+
+        // Using a default value.
+        NonZero::new(DEFAULT_PARALLELISM).unwrap()
+    })
+}
diff --git a/crates/iceberg/src/writer/base_writer/data_file_writer.rs b/crates/iceberg/src/writer/base_writer/data_file_writer.rs
new file mode 100644
index 000000000..c32c98bbc
--- /dev/null
+++ b/crates/iceberg/src/writer/base_writer/data_file_writer.rs
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module provides `DataFileWriter`.
+
+use arrow_array::RecordBatch;
+use itertools::Itertools;
+
+use crate::spec::{DataContentType, DataFile, Struct};
+use crate::writer::file_writer::{FileWriter, FileWriterBuilder};
+use crate::writer::{CurrentFileStatus, IcebergWriter, IcebergWriterBuilder};
+use crate::Result;
+
+/// Builder for `DataFileWriter`.
+#[derive(Clone)] +pub struct DataFileWriterBuilder { + inner: B, +} + +impl DataFileWriterBuilder { + /// Create a new `DataFileWriterBuilder` using a `FileWriterBuilder`. + pub fn new(inner: B) -> Self { + Self { inner } + } +} + +/// Config for `DataFileWriter`. +pub struct DataFileWriterConfig { + partition_value: Struct, +} + +impl DataFileWriterConfig { + /// Create a new `DataFileWriterConfig` with partition value. + pub fn new(partition_value: Option) -> Self { + Self { + partition_value: partition_value.unwrap_or(Struct::empty()), + } + } +} + +#[async_trait::async_trait] +impl IcebergWriterBuilder for DataFileWriterBuilder { + type R = DataFileWriter; + type C = DataFileWriterConfig; + + async fn build(self, config: Self::C) -> Result { + Ok(DataFileWriter { + inner_writer: Some(self.inner.clone().build().await?), + partition_value: config.partition_value, + }) + } +} + +/// A writer write data is within one spec/partition. +pub struct DataFileWriter { + inner_writer: Option, + partition_value: Struct, +} + +#[async_trait::async_trait] +impl IcebergWriter for DataFileWriter { + async fn write(&mut self, batch: RecordBatch) -> Result<()> { + self.inner_writer.as_mut().unwrap().write(&batch).await + } + + async fn close(&mut self) -> Result> { + let writer = self.inner_writer.take().unwrap(); + Ok(writer + .close() + .await? + .into_iter() + .map(|mut res| { + res.content(DataContentType::Data); + res.partition(self.partition_value.clone()); + res.build().expect("Guaranteed to be valid") + }) + .collect_vec()) + } +} + +impl CurrentFileStatus for DataFileWriter { + fn current_file_path(&self) -> String { + self.inner_writer.as_ref().unwrap().current_file_path() + } + + fn current_row_num(&self) -> usize { + self.inner_writer.as_ref().unwrap().current_row_num() + } + + fn current_written_size(&self) -> usize { + self.inner_writer.as_ref().unwrap().current_written_size() + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use parquet::file::properties::WriterProperties; + use tempfile::TempDir; + + use crate::io::FileIOBuilder; + use crate::spec::{DataContentType, DataFileFormat, Schema, Struct}; + use crate::writer::base_writer::data_file_writer::{ + DataFileWriterBuilder, DataFileWriterConfig, + }; + use crate::writer::file_writer::location_generator::test::MockLocationGenerator; + use crate::writer::file_writer::location_generator::DefaultFileNameGenerator; + use crate::writer::file_writer::ParquetWriterBuilder; + use crate::writer::{IcebergWriter, IcebergWriterBuilder}; + use crate::Result; + + #[tokio::test] + async fn test_parquet_writer() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + let pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(Schema::builder().build().unwrap()), + file_io.clone(), + location_gen, + file_name_gen, + ); + let mut data_file_writer = DataFileWriterBuilder::new(pw) + .build(DataFileWriterConfig::new(None)) + .await?; + + let data_file = data_file_writer.close().await.unwrap(); + assert_eq!(data_file.len(), 1); + assert_eq!(data_file[0].file_format, DataFileFormat::Parquet); + assert_eq!(data_file[0].content, DataContentType::Data); + assert_eq!(data_file[0].partition, Struct::empty()); + + Ok(()) + } +} diff --git 
a/crates/iceberg/src/writer/base_writer/mod.rs b/crates/iceberg/src/writer/base_writer/mod.rs
new file mode 100644
index 000000000..37da2ab81
--- /dev/null
+++ b/crates/iceberg/src/writer/base_writer/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! The base writer module contains the basic writers provided by Iceberg: `DataFileWriter`, `PositionDeleteFileWriter`, `EqualityDeleteFileWriter`.
+
+pub mod data_file_writer;
diff --git a/crates/iceberg/src/writer/file_writer/location_generator.rs b/crates/iceberg/src/writer/file_writer/location_generator.rs
new file mode 100644
index 000000000..44326190d
--- /dev/null
+++ b/crates/iceberg/src/writer/file_writer/location_generator.rs
@@ -0,0 +1,230 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module contains the location generator and file name generator used to generate the paths of data files.
+
+use std::sync::atomic::AtomicU64;
+use std::sync::Arc;
+
+use crate::spec::{DataFileFormat, TableMetadata};
+use crate::{Error, ErrorKind, Result};
+
+/// `LocationGenerator` is used to generate the location of a data file.
+pub trait LocationGenerator: Clone + Send + 'static {
+    /// Generate an absolute path for the given file name.
+    /// e.g.
+    /// For the file name "part-00000.parquet", the generated location may be "/table/data/part-00000.parquet".
+    fn generate_location(&self, file_name: &str) -> String;
+}
+
+const WRITE_DATA_LOCATION: &str = "write.data.path";
+const WRITE_FOLDER_STORAGE_LOCATION: &str = "write.folder-storage.path";
+const DEFAULT_DATA_DIR: &str = "/data";
+
+#[derive(Clone)]
+/// `DefaultLocationGenerator` is used to generate the data directory location of a data file.
+/// The location is generated based on the table location and the data location in table properties.
+pub struct DefaultLocationGenerator {
+    dir_path: String,
+}
+
+impl DefaultLocationGenerator {
+    /// Create a new `DefaultLocationGenerator`.
+ pub fn new(table_metadata: TableMetadata) -> Result { + let table_location = table_metadata.location(); + let rel_dir_path = { + let prop = table_metadata.properties(); + let data_location = prop + .get(WRITE_DATA_LOCATION) + .or(prop.get(WRITE_FOLDER_STORAGE_LOCATION)); + if let Some(data_location) = data_location { + data_location.strip_prefix(table_location).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "data location {} is not a subpath of table location {}", + data_location, table_location + ), + ) + })? + } else { + DEFAULT_DATA_DIR + } + }; + + Ok(Self { + dir_path: format!("{}{}", table_location, rel_dir_path), + }) + } +} + +impl LocationGenerator for DefaultLocationGenerator { + fn generate_location(&self, file_name: &str) -> String { + format!("{}/{}", self.dir_path, file_name) + } +} + +/// `FileNameGeneratorTrait` used to generate file name for data file. The file name can be passed to `LocationGenerator` to generate the location of the file. +pub trait FileNameGenerator: Clone + Send + 'static { + /// Generate a file name. + fn generate_file_name(&self) -> String; +} + +/// `DefaultFileNameGenerator` used to generate file name for data file. The file name can be +/// passed to `LocationGenerator` to generate the location of the file. +/// The file name format is "{prefix}-{file_count}[-{suffix}].{file_format}". +#[derive(Clone)] +pub struct DefaultFileNameGenerator { + prefix: String, + suffix: String, + format: String, + file_count: Arc, +} + +impl DefaultFileNameGenerator { + /// Create a new `FileNameGenerator`. + pub fn new(prefix: String, suffix: Option, format: DataFileFormat) -> Self { + let suffix = if let Some(suffix) = suffix { + format!("-{}", suffix) + } else { + "".to_string() + }; + + Self { + prefix, + suffix, + format: format.to_string(), + file_count: Arc::new(AtomicU64::new(0)), + } + } +} + +impl FileNameGenerator for DefaultFileNameGenerator { + fn generate_file_name(&self) -> String { + let file_id = self + .file_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + format!( + "{}-{:05}{}.{}", + self.prefix, file_id, self.suffix, self.format + ) + } +} + +#[cfg(test)] +pub(crate) mod test { + use std::collections::HashMap; + + use uuid::Uuid; + + use super::LocationGenerator; + use crate::spec::{FormatVersion, TableMetadata}; + use crate::writer::file_writer::location_generator::{ + FileNameGenerator, WRITE_DATA_LOCATION, WRITE_FOLDER_STORAGE_LOCATION, + }; + + #[derive(Clone)] + pub(crate) struct MockLocationGenerator { + root: String, + } + + impl MockLocationGenerator { + pub(crate) fn new(root: String) -> Self { + Self { root } + } + } + + impl LocationGenerator for MockLocationGenerator { + fn generate_location(&self, file_name: &str) -> String { + format!("{}/{}", self.root, file_name) + } + } + + #[test] + fn test_default_location_generate() { + let mut table_metadata = TableMetadata { + format_version: FormatVersion::V2, + table_uuid: Uuid::parse_str("fb072c92-a02b-11e9-ae9c-1bb7bc9eca94").unwrap(), + location: "s3://data.db/table".to_string(), + last_updated_ms: 1515100955770, + last_column_id: 1, + schemas: HashMap::new(), + current_schema_id: 1, + partition_specs: HashMap::new(), + default_spec_id: 1, + last_partition_id: 1000, + default_sort_order_id: 0, + sort_orders: HashMap::from_iter(vec![]), + snapshots: HashMap::default(), + current_snapshot_id: None, + last_sequence_number: 1, + properties: HashMap::new(), + snapshot_log: Vec::new(), + metadata_log: vec![], + refs: HashMap::new(), + }; + + let 
file_name_genertaor = super::DefaultFileNameGenerator::new( + "part".to_string(), + Some("test".to_string()), + crate::spec::DataFileFormat::Parquet, + ); + + // test default data location + let location_generator = + super::DefaultLocationGenerator::new(table_metadata.clone()).unwrap(); + let location = + location_generator.generate_location(&file_name_genertaor.generate_file_name()); + assert_eq!(location, "s3://data.db/table/data/part-00000-test.parquet"); + + // test custom data location + table_metadata.properties.insert( + WRITE_FOLDER_STORAGE_LOCATION.to_string(), + "s3://data.db/table/data_1".to_string(), + ); + let location_generator = + super::DefaultLocationGenerator::new(table_metadata.clone()).unwrap(); + let location = + location_generator.generate_location(&file_name_genertaor.generate_file_name()); + assert_eq!( + location, + "s3://data.db/table/data_1/part-00001-test.parquet" + ); + + table_metadata.properties.insert( + WRITE_DATA_LOCATION.to_string(), + "s3://data.db/table/data_2".to_string(), + ); + let location_generator = + super::DefaultLocationGenerator::new(table_metadata.clone()).unwrap(); + let location = + location_generator.generate_location(&file_name_genertaor.generate_file_name()); + assert_eq!( + location, + "s3://data.db/table/data_2/part-00002-test.parquet" + ); + + // test invalid data location + table_metadata.properties.insert( + WRITE_DATA_LOCATION.to_string(), + // invalid table location + "s3://data.db/data_3".to_string(), + ); + let location_generator = super::DefaultLocationGenerator::new(table_metadata.clone()); + assert!(location_generator.is_err()); + } +} diff --git a/crates/iceberg/src/writer/file_writer/mod.rs b/crates/iceberg/src/writer/file_writer/mod.rs new file mode 100644 index 000000000..4a0fffcc1 --- /dev/null +++ b/crates/iceberg/src/writer/file_writer/mod.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains the writer for data file format supported by iceberg: parquet, orc. + +use arrow_array::RecordBatch; +use futures::Future; + +use super::CurrentFileStatus; +use crate::spec::DataFileBuilder; +use crate::Result; + +mod parquet_writer; +pub use parquet_writer::{ParquetWriter, ParquetWriterBuilder}; +mod track_writer; + +pub mod location_generator; + +type DefaultOutput = Vec; + +/// File writer builder trait. +pub trait FileWriterBuilder: Send + Clone + 'static { + /// The associated file writer type. + type R: FileWriter; + /// Build file writer. + fn build(self) -> impl Future> + Send; +} + +/// File writer focus on writing record batch to different physical file format.(Such as parquet. orc) +pub trait FileWriter: Send + CurrentFileStatus + 'static { + /// Write record batch to file. 
+ fn write(&mut self, batch: &RecordBatch) -> impl Future> + Send; + /// Close file writer. + fn close(self) -> impl Future> + Send; +} diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs new file mode 100644 index 000000000..3e2db5855 --- /dev/null +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs @@ -0,0 +1,1136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The module contains the file writer for parquet file format. + +use std::collections::HashMap; +use std::sync::atomic::AtomicI64; +use std::sync::Arc; + +use arrow_schema::SchemaRef as ArrowSchemaRef; +use bytes::Bytes; +use futures::future::BoxFuture; +use itertools::Itertools; +use parquet::arrow::async_writer::AsyncFileWriter as ArrowAsyncFileWriter; +use parquet::arrow::AsyncArrowWriter; +use parquet::file::properties::WriterProperties; +use parquet::file::statistics::{from_thrift, Statistics}; +use parquet::format::FileMetaData; + +use super::location_generator::{FileNameGenerator, LocationGenerator}; +use super::track_writer::TrackWriter; +use super::{FileWriter, FileWriterBuilder}; +use crate::arrow::{ + get_parquet_stat_max_as_datum, get_parquet_stat_min_as_datum, DEFAULT_MAP_FIELD_NAME, +}; +use crate::io::{FileIO, FileWrite, OutputFile}; +use crate::spec::{ + visit_schema, DataFileBuilder, DataFileFormat, Datum, ListType, MapType, NestedFieldRef, + PrimitiveType, Schema, SchemaRef, SchemaVisitor, StructType, Type, +}; +use crate::writer::CurrentFileStatus; +use crate::{Error, ErrorKind, Result}; + +/// ParquetWriterBuilder is used to builder a [`ParquetWriter`] +#[derive(Clone)] +pub struct ParquetWriterBuilder { + props: WriterProperties, + schema: SchemaRef, + + file_io: FileIO, + location_generator: T, + file_name_generator: F, +} + +impl ParquetWriterBuilder { + /// Create a new `ParquetWriterBuilder` + /// To construct the write result, the schema should contain the `PARQUET_FIELD_ID_META_KEY` metadata for each field. 
+ pub fn new( + props: WriterProperties, + schema: SchemaRef, + file_io: FileIO, + location_generator: T, + file_name_generator: F, + ) -> Self { + Self { + props, + schema, + file_io, + location_generator, + file_name_generator, + } + } +} + +impl FileWriterBuilder for ParquetWriterBuilder { + type R = ParquetWriter; + + async fn build(self) -> crate::Result { + let arrow_schema: ArrowSchemaRef = Arc::new(self.schema.as_ref().try_into()?); + let written_size = Arc::new(AtomicI64::new(0)); + let out_file = self.file_io.new_output( + self.location_generator + .generate_location(&self.file_name_generator.generate_file_name()), + )?; + let inner_writer = TrackWriter::new(out_file.writer().await?, written_size.clone()); + let async_writer = AsyncFileWriter::new(inner_writer); + let writer = + AsyncArrowWriter::try_new(async_writer, arrow_schema.clone(), Some(self.props)) + .map_err(|err| { + Error::new(ErrorKind::Unexpected, "Failed to build parquet writer.") + .with_source(err) + })?; + + Ok(ParquetWriter { + schema: self.schema.clone(), + writer, + written_size, + current_row_num: 0, + out_file, + }) + } +} + +struct IndexByParquetPathName { + name_to_id: HashMap, + + field_names: Vec, + + field_id: i32, +} + +impl IndexByParquetPathName { + pub fn new() -> Self { + Self { + name_to_id: HashMap::new(), + field_names: Vec::new(), + field_id: 0, + } + } + + pub fn get(&self, name: &str) -> Option<&i32> { + self.name_to_id.get(name) + } +} + +impl SchemaVisitor for IndexByParquetPathName { + type T = (); + + fn before_struct_field(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names.push(field.name.to_string()); + self.field_id = field.id; + Ok(()) + } + + fn after_struct_field(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_list_element(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names.push(format!("list.{}", field.name)); + self.field_id = field.id; + Ok(()) + } + + fn after_list_element(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_map_key(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names + .push(format!("{DEFAULT_MAP_FIELD_NAME}.key")); + self.field_id = field.id; + Ok(()) + } + + fn after_map_key(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_map_value(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names + .push(format!("{DEFAULT_MAP_FIELD_NAME}.value")); + self.field_id = field.id; + Ok(()) + } + + fn after_map_value(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn schema(&mut self, _schema: &Schema, _value: Self::T) -> Result { + Ok(()) + } + + fn field(&mut self, _field: &NestedFieldRef, _value: Self::T) -> Result { + Ok(()) + } + + fn r#struct(&mut self, _struct: &StructType, _results: Vec) -> Result { + Ok(()) + } + + fn list(&mut self, _list: &ListType, _value: Self::T) -> Result { + Ok(()) + } + + fn map(&mut self, _map: &MapType, _key_value: Self::T, _value: Self::T) -> Result { + Ok(()) + } + + fn primitive(&mut self, _p: &PrimitiveType) -> Result { + let full_name = self.field_names.iter().map(String::as_str).join("."); + let field_id = self.field_id; + if let Some(existing_field_id) = self.name_to_id.get(full_name.as_str()) { + return Err(Error::new(ErrorKind::DataInvalid, format!("Invalid schema: multiple fields for name {full_name}: {field_id} and {existing_field_id}"))); + } else { + 
self.name_to_id.insert(full_name, field_id); + } + + Ok(()) + } +} + +/// `ParquetWriter`` is used to write arrow data into parquet file on storage. +pub struct ParquetWriter { + schema: SchemaRef, + out_file: OutputFile, + writer: AsyncArrowWriter>, + written_size: Arc, + current_row_num: usize, +} + +/// Used to aggregate min and max value of each column. +struct MinMaxColAggregator { + lower_bounds: HashMap, + upper_bounds: HashMap, + schema: SchemaRef, +} + +impl MinMaxColAggregator { + fn new(schema: SchemaRef) -> Self { + Self { + lower_bounds: HashMap::new(), + upper_bounds: HashMap::new(), + schema, + } + } + + fn update_state_min(&mut self, field_id: i32, datum: Datum) { + self.lower_bounds + .entry(field_id) + .and_modify(|e| { + if *e > datum { + *e = datum.clone() + } + }) + .or_insert(datum); + } + + fn update_state_max(&mut self, field_id: i32, datum: Datum) { + self.upper_bounds + .entry(field_id) + .and_modify(|e| { + if *e > datum { + *e = datum.clone() + } + }) + .or_insert(datum); + } + + fn update(&mut self, field_id: i32, value: Statistics) -> Result<()> { + let Some(ty) = self + .schema + .field_by_id(field_id) + .map(|f| f.field_type.as_ref()) + else { + // Following java implementation: https://github.com/apache/iceberg/blob/29a2c456353a6120b8c882ed2ab544975b168d7b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L163 + // Ignore the field if it is not in schema. + return Ok(()); + }; + let Type::Primitive(ty) = ty.clone() else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Composed type {} is not supported for min max aggregation.", + ty + ), + )); + }; + + if value.min_is_exact() { + let Some(min_datum) = get_parquet_stat_min_as_datum(&ty, &value)? else { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Statistics {} is not match with field type {}.", value, ty), + )); + }; + + self.update_state_min(field_id, min_datum); + } + + if value.max_is_exact() { + let Some(max_datum) = get_parquet_stat_max_as_datum(&ty, &value)? else { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Statistics {} is not match with field type {}.", value, ty), + )); + }; + + self.update_state_max(field_id, max_datum); + } + + Ok(()) + } + + fn produce(self) -> (HashMap, HashMap) { + (self.lower_bounds, self.upper_bounds) + } +} + +impl ParquetWriter { + fn to_data_file_builder( + schema: SchemaRef, + metadata: FileMetaData, + written_size: usize, + file_path: String, + ) -> Result { + let index_by_parquet_path = { + let mut visitor = IndexByParquetPathName::new(); + visit_schema(&schema, &mut visitor)?; + visitor + }; + + let (column_sizes, value_counts, null_value_counts, (lower_bounds, upper_bounds)) = { + let mut per_col_size: HashMap = HashMap::new(); + let mut per_col_val_num: HashMap = HashMap::new(); + let mut per_col_null_val_num: HashMap = HashMap::new(); + let mut min_max_agg = MinMaxColAggregator::new(schema); + + for row_group in &metadata.row_groups { + for column_chunk in row_group.columns.iter() { + let Some(column_chunk_metadata) = &column_chunk.meta_data else { + continue; + }; + let physical_type = column_chunk_metadata.type_; + let Some(&field_id) = + index_by_parquet_path.get(&column_chunk_metadata.path_in_schema.join(".")) + else { + // Following java implementation: https://github.com/apache/iceberg/blob/29a2c456353a6120b8c882ed2ab544975b168d7b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L163 + // Ignore the field if it is not in schema. 
+ continue; + }; + *per_col_size.entry(field_id).or_insert(0) += + column_chunk_metadata.total_compressed_size as u64; + *per_col_val_num.entry(field_id).or_insert(0) += + column_chunk_metadata.num_values as u64; + if let Some(null_count) = column_chunk_metadata + .statistics + .as_ref() + .and_then(|s| s.null_count) + { + *per_col_null_val_num.entry(field_id).or_insert(0_u64) += null_count as u64; + } + if let Some(statistics) = &column_chunk_metadata.statistics { + min_max_agg.update( + field_id, + from_thrift(physical_type.try_into()?, Some(statistics.clone()))? + .unwrap(), + )?; + } + } + } + + ( + per_col_size, + per_col_val_num, + per_col_null_val_num, + min_max_agg.produce(), + ) + }; + + let mut builder = DataFileBuilder::default(); + builder + .file_path(file_path) + .file_format(DataFileFormat::Parquet) + .record_count(metadata.num_rows as u64) + .file_size_in_bytes(written_size as u64) + .column_sizes(column_sizes) + .value_counts(value_counts) + .null_value_counts(null_value_counts) + .lower_bounds(lower_bounds) + .upper_bounds(upper_bounds) + // # TODO(#417) + // - nan_value_counts + // - distinct_counts + .key_metadata(metadata.footer_signing_key_metadata.unwrap_or_default()) + .split_offsets( + metadata + .row_groups + .iter() + .filter_map(|group| group.file_offset) + .collect(), + ); + Ok(builder) + } +} + +impl FileWriter for ParquetWriter { + async fn write(&mut self, batch: &arrow_array::RecordBatch) -> crate::Result<()> { + self.current_row_num += batch.num_rows(); + self.writer.write(batch).await.map_err(|err| { + Error::new( + ErrorKind::Unexpected, + "Failed to write using parquet writer.", + ) + .with_source(err) + })?; + Ok(()) + } + + async fn close(self) -> crate::Result> { + let metadata = self.writer.close().await.map_err(|err| { + Error::new(ErrorKind::Unexpected, "Failed to close parquet writer.").with_source(err) + })?; + + let written_size = self.written_size.load(std::sync::atomic::Ordering::Relaxed); + + Ok(vec![Self::to_data_file_builder( + self.schema, + metadata, + written_size as usize, + self.out_file.location().to_string(), + )?]) + } +} + +impl CurrentFileStatus for ParquetWriter { + fn current_file_path(&self) -> String { + self.out_file.location().to_string() + } + + fn current_row_num(&self) -> usize { + self.current_row_num + } + + fn current_written_size(&self) -> usize { + self.written_size.load(std::sync::atomic::Ordering::Relaxed) as usize + } +} + +/// AsyncFileWriter is a wrapper of FileWrite to make it compatible with tokio::io::AsyncWrite. +/// +/// # NOTES +/// +/// We keep this wrapper been used inside only. +struct AsyncFileWriter(W); + +impl AsyncFileWriter { + /// Create a new `AsyncFileWriter` with the given writer. 
+ pub fn new(writer: W) -> Self { + Self(writer) + } +} + +impl ArrowAsyncFileWriter for AsyncFileWriter { + fn write(&mut self, bs: Bytes) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .write(bs) + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) + } + + fn complete(&mut self) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .close() + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use anyhow::Result; + use arrow_array::types::Int64Type; + use arrow_array::{ + Array, ArrayRef, BooleanArray, Int32Array, Int64Array, ListArray, RecordBatch, StructArray, + }; + use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef}; + use arrow_select::concat::concat_batches; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; + use tempfile::TempDir; + use uuid::Uuid; + + use super::*; + use crate::io::FileIOBuilder; + use crate::spec::{PrimitiveLiteral, Struct, *}; + use crate::writer::file_writer::location_generator::test::MockLocationGenerator; + use crate::writer::file_writer::location_generator::DefaultFileNameGenerator; + use crate::writer::tests::check_parquet_data_file; + + fn schema_for_all_type() -> Schema { + Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::optional(0, "boolean", Type::Primitive(PrimitiveType::Boolean)).into(), + NestedField::optional(1, "int", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(2, "long", Type::Primitive(PrimitiveType::Long)).into(), + NestedField::optional(3, "float", Type::Primitive(PrimitiveType::Float)).into(), + NestedField::optional(4, "double", Type::Primitive(PrimitiveType::Double)).into(), + NestedField::optional(5, "string", Type::Primitive(PrimitiveType::String)).into(), + NestedField::optional(6, "binary", Type::Primitive(PrimitiveType::Binary)).into(), + NestedField::optional(7, "date", Type::Primitive(PrimitiveType::Date)).into(), + NestedField::optional(8, "time", Type::Primitive(PrimitiveType::Time)).into(), + NestedField::optional(9, "timestamp", Type::Primitive(PrimitiveType::Timestamp)) + .into(), + NestedField::optional( + 10, + "timestamptz", + Type::Primitive(PrimitiveType::Timestamptz), + ) + .into(), + NestedField::optional( + 11, + "timestamp_ns", + Type::Primitive(PrimitiveType::TimestampNs), + ) + .into(), + NestedField::optional( + 12, + "timestamptz_ns", + Type::Primitive(PrimitiveType::TimestamptzNs), + ) + .into(), + NestedField::optional( + 13, + "decimal", + Type::Primitive(PrimitiveType::Decimal { + precision: 10, + scale: 5, + }), + ) + .into(), + NestedField::optional(14, "uuid", Type::Primitive(PrimitiveType::Uuid)).into(), + NestedField::optional(15, "fixed", Type::Primitive(PrimitiveType::Fixed(10))) + .into(), + ]) + .build() + .unwrap() + } + + fn nested_schema_for_test() -> Schema { + // Int, Struct(Int,Int), String, List(Int), Struct(Struct(Int)), Map(String, List(Int)) + Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(0, "col0", Type::Primitive(PrimitiveType::Long)).into(), + NestedField::required( + 1, + "col1", + Type::Struct(StructType::new(vec![ + NestedField::required(5, "col_1_5", Type::Primitive(PrimitiveType::Long)) + .into(), + NestedField::required(6, "col_1_6", Type::Primitive(PrimitiveType::Long)) + .into(), + ])), + ) + .into(), + NestedField::required(2, "col2", 
Type::Primitive(PrimitiveType::String)).into(), + NestedField::required( + 3, + "col3", + Type::List(ListType::new( + NestedField::required(7, "element", Type::Primitive(PrimitiveType::Long)) + .into(), + )), + ) + .into(), + NestedField::required( + 4, + "col4", + Type::Struct(StructType::new(vec![NestedField::required( + 8, + "col_4_8", + Type::Struct(StructType::new(vec![NestedField::required( + 9, + "col_4_8_9", + Type::Primitive(PrimitiveType::Long), + ) + .into()])), + ) + .into()])), + ) + .into(), + NestedField::required( + 10, + "col5", + Type::Map(MapType::new( + NestedField::required(11, "key", Type::Primitive(PrimitiveType::String)) + .into(), + NestedField::required( + 12, + "value", + Type::List(ListType::new( + NestedField::required( + 13, + "item", + Type::Primitive(PrimitiveType::Long), + ) + .into(), + )), + ) + .into(), + )), + ) + .into(), + ]) + .build() + .unwrap() + } + + #[tokio::test] + async fn test_index_by_parquet_path() { + let expect = HashMap::from([ + ("col0".to_string(), 0), + ("col1.col_1_5".to_string(), 5), + ("col1.col_1_6".to_string(), 6), + ("col2".to_string(), 2), + ("col3.list.element".to_string(), 7), + ("col4.col_4_8.col_4_8_9".to_string(), 9), + ("col5.key_value.key".to_string(), 11), + ("col5.key_value.value.list.item".to_string(), 13), + ]); + let mut visitor = IndexByParquetPathName::new(); + visit_schema(&nested_schema_for_test(), &mut visitor).unwrap(); + assert_eq!(visitor.name_to_id, expect); + } + + #[tokio::test] + async fn test_parquet_writer() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let loccation_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + // prepare data + let schema = { + let fields = vec![ + arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]), + ), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef; + let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef; + let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap(); + let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap(); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(to_write.schema().as_ref().try_into().unwrap()), + file_io.clone(), + loccation_gen, + file_name_gen, + ) + .build() + .await?; + pw.write(&to_write).await?; + pw.write(&to_write_null).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. 
+ .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 2048); + assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)])); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([(0, Datum::long(0))]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([(0, Datum::long(1023))]) + ); + assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)])); + + // check the written file + let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap(); + check_parquet_data_file(&file_io, &data_file, &expect_batch).await; + + Ok(()) + } + + #[tokio::test] + async fn test_parquet_writer_with_complex_schema() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + // prepare data + let schema = nested_schema_for_test(); + let arrow_schema: ArrowSchemaRef = Arc::new((&schema).try_into().unwrap()); + let col0 = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef; + let col1 = Arc::new(StructArray::new( + { + if let DataType::Struct(fields) = arrow_schema.field(1).data_type() { + fields.clone() + } else { + unreachable!() + } + }, + vec![ + Arc::new(Int64Array::from_iter_values(0..1024)), + Arc::new(Int64Array::from_iter_values(0..1024)), + ], + None, + )); + let col2 = Arc::new(arrow_array::StringArray::from_iter_values( + (0..1024).map(|n| n.to_string()), + )) as ArrayRef; + let col3 = Arc::new({ + let list_parts = arrow_array::ListArray::from_iter_primitive::( + (0..1024).map(|n| Some(vec![Some(n)])), + ) + .into_parts(); + arrow_array::ListArray::new( + { + if let DataType::List(field) = arrow_schema.field(3).data_type() { + field.clone() + } else { + unreachable!() + } + }, + list_parts.1, + list_parts.2, + list_parts.3, + ) + }) as ArrayRef; + let col4 = Arc::new(StructArray::new( + { + if let DataType::Struct(fields) = arrow_schema.field(4).data_type() { + fields.clone() + } else { + unreachable!() + } + }, + vec![Arc::new(StructArray::new( + { + if let DataType::Struct(fields) = arrow_schema.field(4).data_type() { + if let DataType::Struct(fields) = fields[0].data_type() { + fields.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + vec![Arc::new(Int64Array::from_iter_values(0..1024))], + None, + ))], + None, + )); + let col5 = Arc::new({ + let mut map_array_builder = arrow_array::builder::MapBuilder::new( + None, + arrow_array::builder::StringBuilder::new(), + arrow_array::builder::ListBuilder::new(arrow_array::builder::PrimitiveBuilder::< + Int64Type, + >::new()), + ); + for i in 0..1024 { + map_array_builder.keys().append_value(i.to_string()); + map_array_builder + .values() + .append_value(vec![Some(i as i64); i + 1]); + map_array_builder.append(true)?; + } + let (_, offset_buffer, struct_array, null_buffer, ordered) = + map_array_builder.finish().into_parts(); + let struct_array = { + let (_, mut arrays, nulls) = struct_array.into_parts(); + let list_array = { + let list_array = arrays[1] + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let (_, offsets, array, nulls) = list_array.into_parts(); + let list_field = { + if let DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + if let DataType::Struct(fields) = 
map_field.data_type() { + if let DataType::List(list_field) = fields[1].data_type() { + list_field.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + } else { + unreachable!() + } + }; + ListArray::new(list_field, offsets, array, nulls) + }; + arrays[1] = Arc::new(list_array) as ArrayRef; + StructArray::new( + { + if let DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + if let DataType::Struct(fields) = map_field.data_type() { + fields.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + arrays, + nulls, + ) + }; + arrow_array::MapArray::new( + { + if let DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + map_field.clone() + } else { + unreachable!() + } + }, + offset_buffer, + struct_array, + null_buffer, + ordered, + ) + }) as ArrayRef; + let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![ + col0, col1, col2, col3, col4, col5, + ]) + .unwrap(); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(schema), + file_io.clone(), + location_gen, + file_name_gen, + ) + .build() + .await?; + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 1024); + assert_eq!( + *data_file.value_counts(), + HashMap::from([ + (0, 1024), + (5, 1024), + (6, 1024), + (2, 1024), + (7, 1024), + (9, 1024), + (11, 1024), + (13, (1..1025).sum()), + ]) + ); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([ + (0, Datum::long(0)), + (5, Datum::long(0)), + (6, Datum::long(0)), + (2, Datum::string("0")), + (7, Datum::long(0)), + (9, Datum::long(0)), + (11, Datum::string("0")), + (13, Datum::long(0)) + ]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([ + (0, Datum::long(1023)), + (5, Datum::long(1023)), + (6, Datum::long(1023)), + (2, Datum::string("999")), + (7, Datum::long(1023)), + (9, Datum::long(1023)), + (11, Datum::string("999")), + (13, Datum::long(1023)) + ]) + ); + + // check the written file + check_parquet_data_file(&file_io, &data_file, &to_write).await; + + Ok(()) + } + + #[tokio::test] + async fn test_all_type_for_write() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let loccation_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + // prepare data + // generate iceberg schema for all type + let schema = schema_for_all_type(); + let arrow_schema: ArrowSchemaRef = Arc::new((&schema).try_into().unwrap()); + let col0 = Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + ])) as ArrayRef; + let col1 = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])) as ArrayRef; + let col2 = Arc::new(Int64Array::from(vec![Some(1), Some(2), None, Some(4)])) as ArrayRef; + let col3 = Arc::new(arrow_array::Float32Array::from(vec![ + Some(0.5), + Some(2.0), + None, + Some(3.5), + ])) as ArrayRef; + let col4 = Arc::new(arrow_array::Float64Array::from(vec![ + Some(0.5), + Some(2.0), + None, + Some(3.5), + ])) as ArrayRef; + let col5 = Arc::new(arrow_array::StringArray::from(vec![ + Some("a"), + 
Some("b"), + None, + Some("d"), + ])) as ArrayRef; + let col6 = Arc::new(arrow_array::LargeBinaryArray::from_opt_vec(vec![ + Some(b"one"), + None, + Some(b""), + Some(b"zzzz"), + ])) as ArrayRef; + let col7 = Arc::new(arrow_array::Date32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col8 = Arc::new(arrow_array::Time64MicrosecondArray::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col9 = Arc::new(arrow_array::TimestampMicrosecondArray::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col10 = Arc::new( + arrow_array::TimestampMicrosecondArray::from(vec![Some(0), Some(1), None, Some(3)]) + .with_timezone_utc(), + ) as ArrayRef; + let col11 = Arc::new(arrow_array::TimestampNanosecondArray::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col12 = Arc::new( + arrow_array::TimestampNanosecondArray::from(vec![Some(0), Some(1), None, Some(3)]) + .with_timezone_utc(), + ) as ArrayRef; + let col13 = Arc::new( + arrow_array::Decimal128Array::from(vec![Some(1), Some(2), None, Some(100)]) + .with_precision_and_scale(10, 5) + .unwrap(), + ) as ArrayRef; + let col14 = Arc::new( + arrow_array::FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![ + Some(Uuid::from_u128(0).as_bytes().to_vec()), + Some(Uuid::from_u128(1).as_bytes().to_vec()), + None, + Some(Uuid::from_u128(3).as_bytes().to_vec()), + ] + .into_iter(), + 16, + ) + .unwrap(), + ) as ArrayRef; + let col15 = Arc::new( + arrow_array::FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![ + Some(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + Some(vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), + None, + Some(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30]), + ] + .into_iter(), + 10, + ) + .unwrap(), + ) as ArrayRef; + let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![ + col0, col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13, + col14, col15, + ]) + .unwrap(); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(schema), + file_io.clone(), + loccation_gen, + file_name_gen, + ) + .build() + .await?; + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. 
+ .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 4); + assert!(data_file.value_counts().iter().all(|(_, &v)| { v == 4 })); + assert!(data_file + .null_value_counts() + .iter() + .all(|(_, &v)| { v == 1 })); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([ + (0, Datum::bool(false)), + (1, Datum::int(1)), + (2, Datum::long(1)), + (3, Datum::float(0.5)), + (4, Datum::double(0.5)), + (5, Datum::string("a")), + (6, Datum::binary(vec![])), + (7, Datum::date(0)), + (8, Datum::time_micros(0).unwrap()), + (9, Datum::timestamp_micros(0)), + (10, Datum::timestamptz_micros(0)), + (11, Datum::timestamp_nanos(0)), + (12, Datum::timestamptz_nanos(0)), + ( + 13, + Datum::new( + PrimitiveType::Decimal { + precision: 10, + scale: 5 + }, + PrimitiveLiteral::Int128(1) + ) + ), + (14, Datum::uuid(Uuid::from_u128(0))), + (15, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])), + ]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([ + (0, Datum::bool(true)), + (1, Datum::int(4)), + (2, Datum::long(4)), + (3, Datum::float(3.5)), + (4, Datum::double(3.5)), + (5, Datum::string("d")), + (6, Datum::binary(vec![122, 122, 122, 122])), + (7, Datum::date(3)), + (8, Datum::time_micros(3).unwrap()), + (9, Datum::timestamp_micros(3)), + (10, Datum::timestamptz_micros(3)), + (11, Datum::timestamp_nanos(3)), + (12, Datum::timestamptz_nanos(3)), + ( + 13, + Datum::new( + PrimitiveType::Decimal { + precision: 10, + scale: 5 + }, + PrimitiveLiteral::Int128(100) + ) + ), + (14, Datum::uuid(Uuid::from_u128(3))), + ( + 15, + Datum::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30]) + ), + ]) + ); + + // check the written file + check_parquet_data_file(&file_io, &data_file, &to_write).await; + + Ok(()) + } +} diff --git a/crates/iceberg/src/writer/file_writer/track_writer.rs b/crates/iceberg/src/writer/file_writer/track_writer.rs new file mode 100644 index 000000000..6c60a1aa7 --- /dev/null +++ b/crates/iceberg/src/writer/file_writer/track_writer.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::atomic::AtomicI64; +use std::sync::Arc; + +use bytes::Bytes; + +use crate::io::FileWrite; +use crate::Result; + +/// `TrackWriter` is used to track the written size. 
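The `TrackWriter` defined just below wraps any `FileWrite` implementation and adds each successfully written buffer's length to a shared atomic counter, which is how the Parquet writer can later report its written size without consulting the underlying store. A minimal standalone sketch of the same pattern, using `std::io::Write` instead of the crate's async `FileWrite` trait (the names here are illustrative, not the crate's API):

```rust
use std::io::Write;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;

/// Counts every byte the inner writer actually accepts.
struct CountingWriter<W: Write> {
    inner: W,
    written: Arc<AtomicI64>,
}

impl<W: Write> Write for CountingWriter<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let n = self.inner.write(buf)?;
        self.written.fetch_add(n as i64, Ordering::Relaxed);
        Ok(n)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}

fn main() -> std::io::Result<()> {
    let written = Arc::new(AtomicI64::new(0));
    let mut writer = CountingWriter {
        inner: Vec::<u8>::new(),
        written: written.clone(),
    };
    writer.write_all(b"hello iceberg")?;
    // The shared counter can be read elsewhere, e.g. by a CurrentFileStatus impl.
    assert_eq!(written.load(Ordering::Relaxed), 13);
    Ok(())
}
```

Because the counter lives behind an `Arc`, whoever hands out the wrapped writer can keep a clone of the counter and answer size queries while writes are still in flight.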
+pub(crate) struct TrackWriter { + inner: Box, + written_size: Arc, +} + +impl TrackWriter { + pub fn new(writer: Box, written_size: Arc) -> Self { + Self { + inner: writer, + written_size, + } + } +} + +#[async_trait::async_trait] +impl FileWrite for TrackWriter { + async fn write(&mut self, bs: Bytes) -> Result<()> { + let size = bs.len(); + self.inner.write(bs).await.map(|v| { + self.written_size + .fetch_add(size as i64, std::sync::atomic::Ordering::Relaxed); + v + }) + } + + async fn close(&mut self) -> Result<()> { + self.inner.close().await + } +} diff --git a/crates/iceberg/src/writer/mod.rs b/crates/iceberg/src/writer/mod.rs new file mode 100644 index 000000000..6cb9aaee6 --- /dev/null +++ b/crates/iceberg/src/writer/mod.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Iceberg writer module. +//! +//! The writer API is designed to be extensible and flexible. Each writer is decoupled and can be create and config independently. User can: +//! 1.Customize the writer using the writer trait. +//! 2.Combine different writer to build a writer which have complex write logic. +//! +//! There are two kinds of writer: +//! 1. FileWriter: Focus on writing record batch to different physical file format.(Such as parquet. orc) +//! 2. IcebergWriter: Focus on the logical format of iceberg table. It will write the data using the FileWriter finally. +//! +//! # Simple example for data file writer: +//! ```ignore +//! // Create a parquet file writer builder. The parameter can get from table. +//! let file_writer_builder = ParquetWriterBuilder::new( +//! 0, +//! WriterProperties::builder().build(), +//! schema, +//! file_io.clone(), +//! loccation_gen, +//! file_name_gen, +//! ) +//! // Create a data file writer using parquet file writer builder. +//! let data_file_builder = DataFileBuilder::new(file_writer_builder); +//! // Build the data file writer. +//! let data_file_writer = data_file_builder.build().await.unwrap(); +//! +//! data_file_writer.write(&record_batch).await.unwrap(); +//! let data_files = data_file_writer.flush().await.unwrap(); +//! ``` + +pub mod base_writer; +pub mod file_writer; + +use arrow_array::RecordBatch; + +use crate::spec::DataFile; +use crate::Result; + +type DefaultInput = RecordBatch; +type DefaultOutput = Vec; + +/// The builder for iceberg writer. +#[async_trait::async_trait] +pub trait IcebergWriterBuilder: + Send + Clone + 'static +{ + /// The associated writer type. + type R: IcebergWriter; + /// The associated writer config type used to build the writer. + type C; + /// Build the iceberg writer. + async fn build(self, config: Self::C) -> Result; +} + +/// The iceberg writer used to write data to iceberg table. 
+#[async_trait::async_trait] +pub trait IcebergWriter: Send + 'static { + /// Write data to iceberg table. + async fn write(&mut self, input: I) -> Result<()>; + /// Close the writer and return the written data files. + /// If close failed, the data written before maybe be lost. User may need to recreate the writer and rewrite the data again. + /// # NOTE + /// After close, regardless of success or failure, the writer should never be used again, otherwise the writer will panic. + async fn close(&mut self) -> Result; +} + +/// The current file status of iceberg writer. It implement for the writer which write a single +/// file. +pub trait CurrentFileStatus { + /// Get the current file path. + fn current_file_path(&self) -> String; + /// Get the current file row number. + fn current_row_num(&self) -> usize; + /// Get the current file written size. + fn current_written_size(&self) -> usize; +} + +#[cfg(test)] +mod tests { + use arrow_array::RecordBatch; + use arrow_schema::Schema; + use arrow_select::concat::concat_batches; + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + + use super::IcebergWriter; + use crate::io::FileIO; + use crate::spec::{DataFile, DataFileFormat}; + + // This function is used to guarantee the trait can be used as a object safe trait. + async fn _guarantee_object_safe(mut w: Box) { + let _ = w + .write(RecordBatch::new_empty(Schema::empty().into())) + .await; + let _ = w.close().await; + } + + // This function check: + // The data of the written parquet file is correct. + // The metadata of the data file is consistent with the written parquet file. + pub(crate) async fn check_parquet_data_file( + file_io: &FileIO, + data_file: &DataFile, + batch: &RecordBatch, + ) { + assert_eq!(data_file.file_format, DataFileFormat::Parquet); + + let input_file = file_io.new_input(data_file.file_path.clone()).unwrap(); + // read the written file + let input_content = input_file.read().await.unwrap(); + let reader_builder = + ParquetRecordBatchReaderBuilder::try_new(input_content.clone()).unwrap(); + + // check data + let reader = reader_builder.build().unwrap(); + let batches = reader.map(|batch| batch.unwrap()).collect::>(); + let res = concat_batches(&batch.schema(), &batches).unwrap(); + assert_eq!(*batch, res); + } +} diff --git a/crates/iceberg/testdata/example_table_metadata_v2.json b/crates/iceberg/testdata/example_table_metadata_v2.json new file mode 100644 index 000000000..cf9fef96d --- /dev/null +++ b/crates/iceberg/testdata/example_table_metadata_v2.json @@ -0,0 +1,62 @@ +{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "{{ table_location }}", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"}, + {"id": 4, "name": "a", "required": true, "type": "string"} + ] + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "default-sort-order-id": 3, + "sort-orders": [ + { + "order-id": 3, + "fields": [ + {"transform": 
"identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} + ] + } + ], + "properties": {"read.split.target.size": "134217728"}, + "current-snapshot-id": 3055729675574597004, + "snapshots": [ + { + "snapshot-id": 3051729675574597004, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "{{ manifest_list_1_location }}" + }, + { + "snapshot-id": 3055729675574597004, + "parent-snapshot-id": 3051729675574597004, + "timestamp-ms": 1555100955770, + "sequence-number": 1, + "summary": {"operation": "append"}, + "manifest-list": "{{ manifest_list_2_location }}", + "schema-id": 1 + } + ], + "snapshot-log": [ + {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, + {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770} + ], + "metadata-log": [{"metadata-file": "{{ table_metadata_1_location }}", "timestamp-ms": 1515100}], + "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}} +} \ No newline at end of file diff --git a/crates/iceberg/testdata/file_io_gcs/docker-compose.yaml b/crates/iceberg/testdata/file_io_gcs/docker-compose.yaml new file mode 100644 index 000000000..6935a0864 --- /dev/null +++ b/crates/iceberg/testdata/file_io_gcs/docker-compose.yaml @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +services: + gcs-server: + image: fsouza/fake-gcs-server@sha256:36b0116fae5236e8def76ccb07761a9ca323e476f366a5f4bf449cac19deaf2d + expose: + - 4443 + command: --scheme http diff --git a/crates/iceberg/testdata/file_io_s3/docker-compose.yaml b/crates/iceberg/testdata/file_io_s3/docker-compose.yaml new file mode 100644 index 000000000..cbce31864 --- /dev/null +++ b/crates/iceberg/testdata/file_io_s3/docker-compose.yaml @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +services: + minio: + image: minio/minio:RELEASE.2024-02-26T09-33-48Z + expose: + - 9000 + - 9001 + environment: + MINIO_ROOT_USER: 'admin' + MINIO_ROOT_PASSWORD: 'password' + MINIO_ADDRESS: ':9000' + MINIO_CONSOLE_ADDRESS: ':9001' + entrypoint: sh + command: -c 'mkdir -p /data/bucket1 && /usr/bin/minio server /data' diff --git a/crates/iceberg/testdata/manifests_lists/manifest-list-v2-1.avro b/crates/iceberg/testdata/manifests_lists/manifest-list-v2-1.avro new file mode 100644 index 000000000..5c5cdb1ad Binary files /dev/null and b/crates/iceberg/testdata/manifests_lists/manifest-list-v2-1.avro differ diff --git a/crates/iceberg/testdata/manifests_lists/manifest-list-v2-2.avro b/crates/iceberg/testdata/manifests_lists/manifest-list-v2-2.avro new file mode 100644 index 000000000..00784ff1d Binary files /dev/null and b/crates/iceberg/testdata/manifests_lists/manifest-list-v2-2.avro differ diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataUnsupportedVersion.json b/crates/iceberg/testdata/view_metadata/ViewMetadataUnsupportedVersion.json new file mode 100644 index 000000000..c5627b8af --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataUnsupportedVersion.json @@ -0,0 +1,58 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 2, + "location": "s3://bucket/warehouse/default.db/event_agg", + "current-version-id": 1, + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "schema-id": 1, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + ], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ No newline at end of file diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataV1CurrentVersionNotFound.json b/crates/iceberg/testdata/view_metadata/ViewMetadataV1CurrentVersionNotFound.json new file mode 100644 index 000000000..4ba94ca4c --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataV1CurrentVersionNotFound.json @@ -0,0 +1,58 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 1, + "location": "s3://bucket/warehouse/default.db/event_agg", + "current-version-id": 2, + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "schema-id": 1, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + ], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ 
No newline at end of file diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingCurrentVersion.json b/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingCurrentVersion.json new file mode 100644 index 000000000..c21088176 --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingCurrentVersion.json @@ -0,0 +1,57 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 1, + "location": "s3://bucket/warehouse/default.db/event_agg", + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "schema-id": 1, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + ], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ No newline at end of file diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingSchema.json b/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingSchema.json new file mode 100644 index 000000000..b5b454bca --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataV1MissingSchema.json @@ -0,0 +1,56 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 1, + "location": "s3://bucket/warehouse/default.db/event_agg", + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + ], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ No newline at end of file diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataV1SchemaNotFound.json b/crates/iceberg/testdata/view_metadata/ViewMetadataV1SchemaNotFound.json new file mode 100644 index 000000000..0026d223e --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataV1SchemaNotFound.json @@ -0,0 +1,58 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 1, + "location": "s3://bucket/warehouse/default.db/event_agg", + "current-version-id": 1, + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "schema-id": 2, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + 
], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ No newline at end of file diff --git a/crates/iceberg/testdata/view_metadata/ViewMetadataV1Valid.json b/crates/iceberg/testdata/view_metadata/ViewMetadataV1Valid.json new file mode 100644 index 000000000..5011a804f --- /dev/null +++ b/crates/iceberg/testdata/view_metadata/ViewMetadataV1Valid.json @@ -0,0 +1,58 @@ +{ + "view-uuid": "fa6506c3-7681-40c8-86dc-e36561f83385", + "format-version": 1, + "location": "s3://bucket/warehouse/default.db/event_agg", + "current-version-id": 1, + "properties": { + "comment": "Daily event counts" + }, + "versions": [ + { + "version-id": 1, + "timestamp-ms": 1573518431292, + "schema-id": 1, + "default-catalog": "prod", + "default-namespace": [ + "default" + ], + "summary": { + "engine-name": "Spark", + "engineVersion": "3.3.2" + }, + "representations": [ + { + "type": "sql", + "sql": "SELECT\n COUNT(1), CAST(event_ts AS DATE)\nFROM events\nGROUP BY 2", + "dialect": "spark" + } + ] + } + ], + "schemas": [ + { + "schema-id": 1, + "type": "struct", + "fields": [ + { + "id": 1, + "name": "event_count", + "required": false, + "type": "int", + "doc": "Count of events" + }, + { + "id": 2, + "name": "event_date", + "required": false, + "type": "date" + } + ] + } + ], + "version-log": [ + { + "timestamp-ms": 1573518431292, + "version-id": 1 + } + ] +} \ No newline at end of file diff --git a/crates/iceberg/tests/file_io_gcs_test.rs b/crates/iceberg/tests/file_io_gcs_test.rs new file mode 100644 index 000000000..540cd9d99 --- /dev/null +++ b/crates/iceberg/tests/file_io_gcs_test.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for FileIO Google Cloud Storage (GCS). 
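Before the GCS tests below, one note on the view-metadata fixtures above: each failing fixture violates a specific invariant (an unsupported `format-version`, a `current-version-id` with no matching entry in `versions`, or a version whose `schema-id` has no matching entry in `schemas`). A standalone sketch of that kind of consistency check, written against plain `serde_json` rather than the crate's real view-metadata deserializer:

```rust
use serde_json::Value;

/// Returns an error string if the view metadata's cross-references don't resolve.
fn check_view_metadata(doc: &Value) -> Result<(), String> {
    let fv = doc["format-version"].as_i64().ok_or("missing format-version")?;
    if fv != 1 {
        return Err(format!("unsupported format-version {fv}"));
    }

    let current = doc["current-version-id"]
        .as_i64()
        .ok_or("missing current-version-id")?;

    let versions = doc["versions"].as_array().ok_or("missing versions")?;
    let version = versions
        .iter()
        .find(|v| v["version-id"].as_i64() == Some(current))
        .ok_or(format!("current-version-id {current} not found"))?;

    let schema_id = version["schema-id"].as_i64().ok_or("missing schema-id")?;
    let schemas = doc["schemas"].as_array().ok_or("missing schemas")?;
    schemas
        .iter()
        .find(|s| s["schema-id"].as_i64() == Some(schema_id))
        .map(|_| ())
        .ok_or(format!("schema-id {schema_id} not found"))
}

fn main() {
    let valid: Value = serde_json::from_str(
        r#"{"format-version": 1, "current-version-id": 1,
            "versions": [{"version-id": 1, "schema-id": 1}],
            "schemas": [{"schema-id": 1}]}"#,
    )
    .unwrap();
    assert!(check_view_metadata(&valid).is_ok());

    // Mirrors ViewMetadataV1CurrentVersionNotFound: current-version-id has no match.
    let broken: Value = serde_json::from_str(
        r#"{"format-version": 1, "current-version-id": 2,
            "versions": [{"version-id": 1, "schema-id": 1}],
            "schemas": [{"schema-id": 1}]}"#,
    )
    .unwrap();
    assert!(check_view_metadata(&broken).is_err());
}
```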
+ +#[cfg(all(test, feature = "storage-gcs"))] +mod tests { + use std::collections::HashMap; + use std::net::SocketAddr; + use std::sync::RwLock; + + use bytes::Bytes; + use ctor::{ctor, dtor}; + use iceberg::io::{FileIO, FileIOBuilder, GCS_NO_AUTH, GCS_SERVICE_PATH}; + use iceberg_test_utils::docker::DockerCompose; + use iceberg_test_utils::{normalize_test_name, set_up}; + + static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); + static FAKE_GCS_PORT: u16 = 4443; + static FAKE_GCS_BUCKET: &str = "test-bucket"; + + #[ctor] + fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + let docker_compose = DockerCompose::new( + normalize_test_name(module_path!()), + format!("{}/testdata/file_io_gcs", env!("CARGO_MANIFEST_DIR")), + ); + docker_compose.run(); + guard.replace(docker_compose); + } + + #[dtor] + fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); + } + + async fn get_file_io_gcs() -> FileIO { + set_up(); + + let ip = DOCKER_COMPOSE_ENV + .read() + .unwrap() + .as_ref() + .unwrap() + .get_container_ip("gcs-server"); + let addr = SocketAddr::new(ip, FAKE_GCS_PORT); + + // A bucket must exist for FileIO + create_bucket(FAKE_GCS_BUCKET, addr.to_string()) + .await + .unwrap(); + + FileIOBuilder::new("gcs") + .with_props(vec![ + (GCS_SERVICE_PATH, format!("http://{}", addr)), + (GCS_NO_AUTH, "true".to_string()), + ]) + .build() + .unwrap() + } + + // Create a bucket against the emulated GCS storage server. + async fn create_bucket(name: &str, server_addr: String) -> anyhow::Result<()> { + let mut bucket_data = HashMap::new(); + bucket_data.insert("name", name); + + let client = reqwest::Client::new(); + let endpoint = format!("http://{}/storage/v1/b", server_addr); + client.post(endpoint).json(&bucket_data).send().await?; + Ok(()) + } + + fn get_gs_path() -> String { + format!("gs://{}", FAKE_GCS_BUCKET) + } + + #[tokio::test] + async fn gcs_exists() { + let file_io = get_file_io_gcs().await; + assert!(file_io + .is_exist(format!("{}/", get_gs_path())) + .await + .unwrap()); + } + + #[tokio::test] + async fn gcs_write() { + let gs_file = format!("{}/write-file", get_gs_path()); + let file_io = get_file_io_gcs().await; + let output = file_io.new_output(&gs_file).unwrap(); + output + .write(bytes::Bytes::from_static(b"iceberg-gcs!")) + .await + .expect("Write to test output file"); + assert!(file_io.is_exist(gs_file).await.unwrap()) + } + + #[tokio::test] + async fn gcs_read() { + let gs_file = format!("{}/read-gcs", get_gs_path()); + let file_io = get_file_io_gcs().await; + let output = file_io.new_output(&gs_file).unwrap(); + output + .write(bytes::Bytes::from_static(b"iceberg!")) + .await + .expect("Write to test output file"); + assert!(file_io.is_exist(&gs_file).await.unwrap()); + + let input = file_io.new_input(gs_file).unwrap(); + assert_eq!(input.read().await.unwrap(), Bytes::from_static(b"iceberg!")); + } +} diff --git a/crates/iceberg/tests/file_io_s3_test.rs b/crates/iceberg/tests/file_io_s3_test.rs new file mode 100644 index 000000000..32e2d12a4 --- /dev/null +++ b/crates/iceberg/tests/file_io_s3_test.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for FileIO S3. +#[cfg(all(test, feature = "storage-s3"))] +mod tests { + use std::net::SocketAddr; + use std::sync::RwLock; + + use ctor::{ctor, dtor}; + use iceberg::io::{ + FileIO, FileIOBuilder, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY, + }; + use iceberg_test_utils::docker::DockerCompose; + use iceberg_test_utils::{normalize_test_name, set_up}; + + const MINIO_PORT: u16 = 9000; + static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); + + #[ctor] + fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + let docker_compose = DockerCompose::new( + normalize_test_name(module_path!()), + format!("{}/testdata/file_io_s3", env!("CARGO_MANIFEST_DIR")), + ); + docker_compose.run(); + guard.replace(docker_compose); + } + + #[dtor] + fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); + } + + async fn get_file_io() -> FileIO { + set_up(); + + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + let container_ip = docker_compose.get_container_ip("minio"); + let minio_socket_addr = SocketAddr::new(container_ip, MINIO_PORT); + + FileIOBuilder::new("s3") + .with_props(vec![ + (S3_ENDPOINT, format!("http://{}", minio_socket_addr)), + (S3_ACCESS_KEY_ID, "admin".to_string()), + (S3_SECRET_ACCESS_KEY, "password".to_string()), + (S3_REGION, "us-east-1".to_string()), + ]) + .build() + .unwrap() + } + + #[tokio::test] + async fn test_file_io_s3_is_exist() { + let file_io = get_file_io().await; + assert!(!file_io.is_exist("s3://bucket2/any").await.unwrap()); + assert!(file_io.is_exist("s3://bucket1/").await.unwrap()); + } + + #[tokio::test] + async fn test_file_io_s3_output() { + let file_io = get_file_io().await; + assert!(!file_io.is_exist("s3://bucket1/test_output").await.unwrap()); + let output_file = file_io.new_output("s3://bucket1/test_output").unwrap(); + { + output_file.write("123".into()).await.unwrap(); + } + assert!(file_io.is_exist("s3://bucket1/test_output").await.unwrap()); + } + + #[tokio::test] + async fn test_file_io_s3_input() { + let file_io = get_file_io().await; + let output_file = file_io.new_output("s3://bucket1/test_input").unwrap(); + { + output_file.write("test_input".into()).await.unwrap(); + } + + let input_file = file_io.new_input("s3://bucket1/test_input").unwrap(); + + { + let buffer = input_file.read().await.unwrap(); + assert_eq!(buffer, "test_input".as_bytes()); + } + } +} diff --git a/crates/integrations/datafusion/Cargo.toml b/crates/integrations/datafusion/Cargo.toml new file mode 100644 index 000000000..87e809cec --- /dev/null +++ b/crates/integrations/datafusion/Cargo.toml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "iceberg-datafusion" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } + +categories = ["database"] +description = "Apache Iceberg Datafusion Integration" +repository = { workspace = true } +license = { workspace = true } +keywords = ["iceberg", "integrations", "datafusion"] + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +datafusion = { version = "41.0.0" } +futures = { workspace = true } +iceberg = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +iceberg-catalog-memory = { workspace = true } +tempfile = { workspace = true } diff --git a/crates/integrations/datafusion/DEPENDENCIES.rust.tsv b/crates/integrations/datafusion/DEPENDENCIES.rust.tsv new file mode 100644 index 000000000..2f6a5649e --- /dev/null +++ b/crates/integrations/datafusion/DEPENDENCIES.rust.tsv @@ -0,0 +1,391 @@ +crate 0BSD Apache-2.0 Apache-2.0 WITH LLVM-exception BSD-2-Clause BSD-3-Clause BSL-1.0 CC0-1.0 ISC MIT MIT-0 MPL-2.0 OpenSSL Unicode-DFS-2016 Unlicense Zlib +addr2line@0.22.0 X X +adler@1.0.2 X X X +adler32@1.2.0 X +ahash@0.8.11 X X +aho-corasick@1.1.3 X X +alloc-no-stdlib@2.0.4 X +alloc-stdlib@0.2.2 X +allocator-api2@0.2.18 X X +android-tzdata@0.1.1 X X +android_system_properties@0.1.5 X X +anstream@0.6.15 X X +anstyle@1.0.8 X X +anstyle-parse@0.2.5 X X +anstyle-query@1.1.1 X X +anstyle-wincon@3.0.4 X X +anyhow@1.0.86 X X +apache-avro@0.17.0 X +array-init@2.1.0 X X +arrayref@0.3.8 X +arrayvec@0.7.4 X X +arrow@52.2.0 X +arrow-arith@52.2.0 X +arrow-array@52.2.0 X +arrow-buffer@52.2.0 X +arrow-cast@52.2.0 X +arrow-csv@52.2.0 X +arrow-data@52.2.0 X +arrow-ipc@52.2.0 X +arrow-json@52.2.0 X +arrow-ord@52.2.0 X +arrow-row@52.2.0 X +arrow-schema@52.2.0 X +arrow-select@52.2.0 X +arrow-string@52.2.0 X +async-broadcast@0.7.1 X X +async-compression@0.4.12 X X +async-recursion@1.1.1 X X +async-trait@0.1.81 X X +atoi@2.0.0 X +autocfg@1.3.0 X X +backon@0.4.4 X +backtrace@0.3.73 X X +base64@0.22.1 X X +bigdecimal@0.4.5 X X +bimap@0.6.3 X X +bitflags@1.3.2 X X +bitflags@2.6.0 X X +bitvec@1.0.1 X +blake2@0.10.6 X X +blake3@1.5.3 X X X +block-buffer@0.10.4 X X +brotli@6.0.0 X X +brotli-decompressor@4.0.1 X X +bumpalo@3.16.0 X X +byteorder@1.5.0 X X +bytes@1.7.1 X +bzip2@0.4.4 X X +bzip2-sys@0.1.11+1.0.8 X X +cc@1.1.11 X X +cfg-if@1.0.0 X X +cfg_aliases@0.1.1 X +chrono@0.4.38 X X +chrono-tz@0.9.0 X X +chrono-tz-build@0.3.0 X X +colorchoice@1.0.2 X X +comfy-table@7.1.1 X +concurrent-queue@2.5.0 X X +const-oid@0.9.6 X X +const-random@0.1.18 X X +const-random-macro@0.1.16 X X +constant_time_eq@0.3.0 X X X +core-foundation-sys@0.8.7 X X +core2@0.4.0 X X +cpufeatures@0.2.13 X X +crc32c@0.6.8 X X +crc32fast@1.4.2 X X +crossbeam-utils@0.8.20 X X +crunchy@0.2.2 X +crypto-common@0.1.6 X X +csv@1.3.0 X X +csv-core@0.1.11 X X +darling@0.20.10 X +darling_core@0.20.10 X +darling_macro@0.20.10 X +dary_heap@0.3.6 X X +dashmap@5.5.3 X +dashmap@6.0.1 X +datafusion@41.0.0 X +datafusion-catalog@41.0.0 X +datafusion-common@41.0.0 X +datafusion-common-runtime@41.0.0 X 
+datafusion-execution@41.0.0 X +datafusion-expr@41.0.0 X +datafusion-functions@41.0.0 X +datafusion-functions-aggregate@41.0.0 X +datafusion-functions-nested@41.0.0 X +datafusion-optimizer@41.0.0 X +datafusion-physical-expr@41.0.0 X +datafusion-physical-expr-common@41.0.0 X +datafusion-physical-optimizer@41.0.0 X +datafusion-physical-plan@41.0.0 X +datafusion-sql@41.0.0 X +derivative@2.2.0 X X +derive_builder@0.20.0 X X +derive_builder_core@0.20.0 X X +derive_builder_macro@0.20.0 X X +digest@0.10.7 X X +doc-comment@0.3.3 X +either@1.13.0 X X +env_filter@0.1.2 X X +env_logger@0.11.5 X X +equivalent@1.0.1 X X +errno@0.3.9 X X +event-listener@5.3.1 X X +event-listener-strategy@0.5.2 X X +fastrand@2.1.0 X X +faststr@0.2.21 X X +fixedbitset@0.4.2 X X +flagset@0.4.6 X +flatbuffers@24.3.25 X +flate2@1.0.31 X X +fnv@1.0.7 X X +form_urlencoded@1.2.1 X X +funty@2.0.0 X +futures@0.3.30 X X +futures-channel@0.3.30 X X +futures-core@0.3.30 X X +futures-executor@0.3.30 X X +futures-io@0.3.30 X X +futures-macro@0.3.30 X X +futures-sink@0.3.30 X X +futures-task@0.3.30 X X +futures-util@0.3.30 X X +generic-array@0.14.7 X +getrandom@0.2.15 X X +gimli@0.29.0 X X +glob@0.3.1 X X +half@2.4.1 X X +hashbrown@0.14.5 X X +heck@0.4.1 X X +heck@0.5.0 X X +hermit-abi@0.3.9 X X +hex@0.4.3 X X +hive_metastore@0.1.0 X +hmac@0.12.1 X X +home@0.5.9 X X +http@1.1.0 X X +http-body@1.0.1 X +http-body-util@0.1.2 X +httparse@1.9.4 X X +humantime@2.1.0 X X +hyper@1.4.1 X +hyper-rustls@0.27.2 X X X +hyper-util@0.1.7 X +iana-time-zone@0.1.60 X X +iana-time-zone-haiku@0.1.2 X X +iceberg@0.3.0 X +iceberg-catalog-hms@0.3.0 X +iceberg-catalog-memory@0.3.0 X +iceberg-datafusion@0.3.0 X +iceberg_test_utils@0.3.0 X +ident_case@1.0.1 X X +idna@0.5.0 X X +indexmap@2.4.0 X X +instant@0.1.13 X +integer-encoding@3.0.4 X +integer-encoding@4.0.2 X +ipnet@2.9.0 X X +is_terminal_polyfill@1.70.1 X X +itertools@0.12.1 X X +itertools@0.13.0 X X +itoa@1.0.11 X X +jobserver@0.1.32 X X +js-sys@0.3.70 X X +lazy_static@1.5.0 X X +lexical-core@0.8.5 X X +lexical-parse-float@0.8.5 X X +lexical-parse-integer@0.8.6 X X +lexical-util@0.8.5 X X +lexical-write-float@0.8.5 X X +lexical-write-integer@0.8.5 X X +libc@0.2.155 X X +libflate@2.1.0 X +libflate_lz77@2.1.0 X +libm@0.2.8 X X +linked-hash-map@0.5.6 X X +linkedbytes@0.1.8 X X +linux-raw-sys@0.4.14 X X X +lock_api@0.4.12 X X +log@0.4.22 X X +lz4_flex@0.11.3 X +lzma-sys@0.1.20 X X +md-5@0.10.6 X X +memchr@2.7.4 X X +memoffset@0.9.1 X +metainfo@0.7.12 X X +mime@0.3.17 X X +miniz_oxide@0.7.4 X X X +mio@1.0.2 X +motore@0.4.1 X X +motore-macros@0.4.1 X X +mur3@0.1.0 X +murmur3@0.5.2 X X +nix@0.28.0 X +num@0.4.3 X X +num-bigint@0.4.6 X X +num-complex@0.4.6 X X +num-integer@0.1.46 X X +num-iter@0.1.45 X X +num-rational@0.4.2 X X +num-traits@0.2.19 X X +num_cpus@1.16.0 X X +num_enum@0.7.3 X X X +num_enum_derive@0.7.3 X X X +object@0.36.3 X X +object_store@0.10.2 X X +once_cell@1.19.0 X X +opendal@0.49.0 X +ordered-float@2.10.1 X +ordered-float@4.2.2 X +page_size@0.6.0 X X +parking@2.2.0 X X +parking_lot@0.12.3 X X +parking_lot_core@0.9.10 X X +parquet@52.2.0 X +parse-zoneinfo@0.3.1 X +paste@1.0.15 X X +percent-encoding@2.3.1 X X +petgraph@0.6.5 X X +phf@0.11.2 X +phf_codegen@0.11.2 X +phf_generator@0.11.2 X +phf_shared@0.11.2 X +pilota@0.11.3 X X +pin-project@1.1.5 X X +pin-project-internal@1.1.5 X X +pin-project-lite@0.2.14 X X +pin-utils@0.1.0 X X +pkg-config@0.3.30 X X +ppv-lite86@0.2.20 X X +proc-macro-crate@3.1.0 X X +proc-macro2@1.0.86 X X +quad-rand@0.2.1 X +quick-xml@0.36.1 X +quote@1.0.36 X X 
+radium@0.7.0 X +rand@0.8.5 X X +rand_chacha@0.3.1 X X +rand_core@0.6.4 X X +redox_syscall@0.5.3 X +regex@1.10.6 X X +regex-automata@0.4.7 X X +regex-lite@0.1.6 X X +regex-syntax@0.8.4 X X +reqsign@0.16.0 X +reqwest@0.12.5 X X +ring@0.17.8 X +rle-decode-fast@1.0.3 X X +rust_decimal@1.35.0 X +rustc-demangle@0.1.24 X X +rustc-hash@2.0.0 X X +rustc_version@0.4.0 X X +rustix@0.38.34 X X X +rustls@0.23.12 X X X +rustls-pemfile@2.1.3 X X X +rustls-pki-types@1.8.0 X X +rustls-webpki@0.102.6 X +rustversion@1.0.17 X X +ryu@1.0.18 X X +same-file@1.0.6 X X +scopeguard@1.2.0 X X +semver@1.0.23 X X +seq-macro@0.3.5 X X +serde@1.0.207 X X +serde_bytes@0.11.15 X X +serde_derive@1.0.207 X X +serde_json@1.0.124 X X +serde_repr@0.1.19 X X +serde_urlencoded@0.7.1 X X +serde_with@3.9.0 X X +serde_with_macros@3.9.0 X X +sha1@0.10.6 X X +sha2@0.10.8 X X +shlex@1.3.0 X X +signal-hook-registry@1.4.2 X X +simdutf8@0.1.4 X X +siphasher@0.3.11 X X +slab@0.4.9 X +smallvec@1.13.2 X X +snafu@0.7.5 X X +snafu-derive@0.7.5 X X +snap@1.1.1 X +socket2@0.5.7 X X +sonic-rs@0.3.10 X +spin@0.9.8 X +sqlparser@0.49.0 X +sqlparser_derive@0.2.2 X +static_assertions@1.1.0 X X +strsim@0.11.1 X +strum@0.26.3 X +strum_macros@0.26.4 X +subtle@2.6.1 X +syn@1.0.109 X X +syn@2.0.74 X X +sync_wrapper@1.0.1 X +tap@1.0.1 X +tempfile@3.12.0 X X +thiserror@1.0.63 X X +thiserror-impl@1.0.63 X X +thrift@0.17.0 X +tiny-keccak@2.0.2 X +tinyvec@1.8.0 X X X +tinyvec_macros@0.1.1 X X X +tokio@1.39.2 X +tokio-macros@2.4.0 X +tokio-rustls@0.26.0 X X +tokio-stream@0.1.15 X +tokio-util@0.7.11 X +toml_datetime@0.6.8 X X +toml_edit@0.21.1 X X +tower@0.4.13 X +tower-layer@0.3.3 X +tower-service@0.3.3 X +tracing@0.1.40 X +tracing-attributes@0.1.27 X +tracing-core@0.1.32 X +try-lock@0.2.5 X +twox-hash@1.6.3 X +typed-builder@0.19.1 X X +typed-builder-macro@0.19.1 X X +typenum@1.17.0 X X +unicode-bidi@0.3.15 X X +unicode-ident@1.0.12 X X X +unicode-normalization@0.1.23 X X +unicode-segmentation@1.11.0 X X +unicode-width@0.1.13 X X +untrusted@0.9.0 X +url@2.5.2 X X +utf8parse@0.2.2 X X +uuid@1.10.0 X X +version_check@0.9.5 X X +volo@0.10.1 X X +volo-thrift@0.10.2 X X +walkdir@2.5.0 X X +want@0.3.1 X +wasi@0.11.0+wasi-snapshot-preview1 X X X +wasm-bindgen@0.2.93 X X +wasm-bindgen-backend@0.2.93 X X +wasm-bindgen-futures@0.4.43 X X +wasm-bindgen-macro@0.2.93 X X +wasm-bindgen-macro-support@0.2.93 X X +wasm-bindgen-shared@0.2.93 X X +wasm-streams@0.4.0 X X +web-sys@0.3.70 X X +webpki-roots@0.26.3 X +winapi@0.3.9 X X +winapi-i686-pc-windows-gnu@0.4.0 X X +winapi-util@0.1.9 X X +winapi-x86_64-pc-windows-gnu@0.4.0 X X +windows-core@0.52.0 X X +windows-sys@0.48.0 X X +windows-sys@0.52.0 X X +windows-sys@0.59.0 X X +windows-targets@0.48.5 X X +windows-targets@0.52.6 X X +windows_aarch64_gnullvm@0.48.5 X X +windows_aarch64_gnullvm@0.52.6 X X +windows_aarch64_msvc@0.48.5 X X +windows_aarch64_msvc@0.52.6 X X +windows_i686_gnu@0.48.5 X X +windows_i686_gnu@0.52.6 X X +windows_i686_gnullvm@0.52.6 X X +windows_i686_msvc@0.48.5 X X +windows_i686_msvc@0.52.6 X X +windows_x86_64_gnu@0.48.5 X X +windows_x86_64_gnu@0.52.6 X X +windows_x86_64_gnullvm@0.48.5 X X +windows_x86_64_gnullvm@0.52.6 X X +windows_x86_64_msvc@0.48.5 X X +windows_x86_64_msvc@0.52.6 X X +winnow@0.5.40 X +winreg@0.52.0 X +wyz@0.5.1 X +xz2@0.1.7 X X +zerocopy@0.7.35 X X X +zerocopy-derive@0.7.35 X X X +zeroize@1.8.1 X X +zstd@0.13.2 X +zstd-safe@7.2.1 X X +zstd-sys@2.0.12+zstd.1.5.6 X X diff --git a/crates/integrations/datafusion/README.md b/crates/integrations/datafusion/README.md new file mode 100644 index 
000000000..134a8eff4 --- /dev/null +++ b/crates/integrations/datafusion/README.md @@ -0,0 +1,22 @@ + + +# Apache Iceberg DataFusion Integration + +This crate contains the integration of Apache DataFusion and Apache Iceberg. diff --git a/crates/integrations/datafusion/src/catalog.rs b/crates/integrations/datafusion/src/catalog.rs new file mode 100644 index 000000000..ab6ebdccc --- /dev/null +++ b/crates/integrations/datafusion/src/catalog.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::catalog::{CatalogProvider, SchemaProvider}; +use futures::future::try_join_all; +use iceberg::{Catalog, NamespaceIdent, Result}; + +use crate::schema::IcebergSchemaProvider; + +/// Provides an interface to manage and access multiple schemas +/// within an Iceberg [`Catalog`]. +/// +/// Acts as a centralized catalog provider that aggregates +/// multiple [`SchemaProvider`], each associated with distinct namespaces. +pub struct IcebergCatalogProvider { + /// A `HashMap` where keys are namespace names + /// and values are dynamic references to objects implementing the + /// [`SchemaProvider`] trait. + schemas: HashMap>, +} + +impl IcebergCatalogProvider { + /// Asynchronously tries to construct a new [`IcebergCatalogProvider`] + /// using the given client to fetch and initialize schema providers for + /// each namespace in the Iceberg [`Catalog`]. + /// + /// This method retrieves the list of namespace names + /// attempts to create a schema provider for each namespace, and + /// collects these providers into a `HashMap`. + pub async fn try_new(client: Arc) -> Result { + // TODO: + // Schemas and providers should be cached and evicted based on time + // As of right now; schemas might become stale. + let schema_names: Vec<_> = client + .list_namespaces(None) + .await? 
+ .iter() + .flat_map(|ns| ns.as_ref().clone()) + .collect(); + + let providers = try_join_all( + schema_names + .iter() + .map(|name| { + IcebergSchemaProvider::try_new( + client.clone(), + NamespaceIdent::new(name.clone()), + ) + }) + .collect::>(), + ) + .await?; + + let schemas: HashMap> = schema_names + .into_iter() + .zip(providers.into_iter()) + .map(|(name, provider)| { + let provider = Arc::new(provider) as Arc; + (name, provider) + }) + .collect(); + + Ok(IcebergCatalogProvider { schemas }) + } +} + +impl CatalogProvider for IcebergCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.schemas.keys().cloned().collect() + } + + fn schema(&self, name: &str) -> Option> { + self.schemas.get(name).cloned() + } +} diff --git a/crates/integrations/datafusion/src/error.rs b/crates/integrations/datafusion/src/error.rs new file mode 100644 index 000000000..273d92fa6 --- /dev/null +++ b/crates/integrations/datafusion/src/error.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use anyhow::anyhow; +use iceberg::{Error, ErrorKind}; + +/// Converts a datafusion error into an iceberg error. +pub fn from_datafusion_error(error: datafusion::error::DataFusionError) -> Error { + Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting datafusion error".to_string(), + ) + .with_source(anyhow!("datafusion error: {:?}", error)) +} +/// Converts an iceberg error into a datafusion error. +pub fn to_datafusion_error(error: Error) -> datafusion::error::DataFusionError { + datafusion::error::DataFusionError::External(error.into()) +} diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs new file mode 100644 index 000000000..c40290116 --- /dev/null +++ b/crates/integrations/datafusion/src/lib.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
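`IcebergCatalogProvider` implements DataFusion's `CatalogProvider`, so once it has been built from an Iceberg `Catalog` it can be registered on a `SessionContext` and queried with SQL. A hedged usage sketch (the catalog handle, the `"iceberg"` catalog name, and the namespace/table identifiers are placeholders; error handling is simplified via `anyhow`):

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use iceberg::Catalog;
use iceberg_datafusion::IcebergCatalogProvider;

async fn query_iceberg(client: Arc<dyn Catalog>) -> anyhow::Result<()> {
    // Snapshot the catalog's namespaces and tables into DataFusion providers.
    let provider = Arc::new(IcebergCatalogProvider::try_new(client).await?);

    let ctx = SessionContext::new();
    ctx.register_catalog("iceberg", provider);

    // "my_namespace" and "my_table" stand in for real identifiers.
    let df = ctx
        .sql("SELECT * FROM iceberg.my_namespace.my_table LIMIT 10")
        .await?;
    df.show().await?;
    Ok(())
}
```

Because `try_new` lists namespaces and tables eagerly, the provider reflects the catalog as of construction time; the TODO comments above note that caching and refresh are still open work.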
+ +mod catalog; +pub use catalog::*; + +mod error; +pub use error::*; + +mod physical_plan; +mod schema; +mod table; diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs new file mode 100644 index 000000000..5ae586a0a --- /dev/null +++ b/crates/integrations/datafusion/src/physical_plan/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod scan; diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs new file mode 100644 index 000000000..c50b32efb --- /dev/null +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; + +use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; +use datafusion::error::Result as DFResult; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, +}; +use futures::{Stream, TryStreamExt}; +use iceberg::table::Table; + +use crate::to_datafusion_error; + +/// Manages the scanning process of an Iceberg [`Table`], encapsulating the +/// necessary details and computed properties required for execution planning. +#[derive(Debug)] +pub(crate) struct IcebergTableScan { + /// A table in the catalog. + table: Table, + /// A reference-counted arrow `Schema`. + schema: ArrowSchemaRef, + /// Stores certain, often expensive to compute, + /// plan properties used in query optimization. + plan_properties: PlanProperties, +} + +impl IcebergTableScan { + /// Creates a new [`IcebergTableScan`] object. 
+ pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self { + let plan_properties = Self::compute_properties(schema.clone()); + + Self { + table, + schema, + plan_properties, + } + } + + /// Computes [`PlanProperties`] used in query optimization. + fn compute_properties(schema: ArrowSchemaRef) -> PlanProperties { + // TODO: + // This is more or less a placeholder, to be replaced + // once we support output-partitioning + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ) + } +} + +impl ExecutionPlan for IcebergTableScan { + fn name(&self) -> &str { + "IcebergTableScan" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn children(&self) -> Vec<&Arc<(dyn ExecutionPlan + 'static)>> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DFResult> { + Ok(self) + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DFResult { + let fut = get_batch_stream(self.table.clone()); + let stream = futures::stream::once(fut).try_flatten(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + stream, + ))) + } +} + +impl DisplayAs for IcebergTableScan { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "IcebergTableScan") + } +} + +/// Asynchronously retrieves a stream of [`RecordBatch`] instances +/// from a given table. +/// +/// This function initializes a [`TableScan`], builds it, +/// and then converts it into a stream of Arrow [`RecordBatch`]es. +async fn get_batch_stream( + table: Table, +) -> DFResult> + Send>>> { + let table_scan = table.scan().build().map_err(to_datafusion_error)?; + + let stream = table_scan + .to_arrow() + .await + .map_err(to_datafusion_error)? + .map_err(to_datafusion_error); + + Ok(Box::pin(stream)) +} diff --git a/crates/integrations/datafusion/src/schema.rs b/crates/integrations/datafusion/src/schema.rs new file mode 100644 index 000000000..8133b3746 --- /dev/null +++ b/crates/integrations/datafusion/src/schema.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::catalog::SchemaProvider; +use datafusion::datasource::TableProvider; +use datafusion::error::Result as DFResult; +use futures::future::try_join_all; +use iceberg::{Catalog, NamespaceIdent, Result}; + +use crate::table::IcebergTableProvider; + +/// Represents a [`SchemaProvider`] for the Iceberg [`Catalog`], managing +/// access to table providers within a specific namespace. 
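
> Editor's note (not part of this patch): once an `IcebergCatalogProvider` is registered, this schema provider is what DataFusion consults for table lookups. A minimal sketch of that lookup path follows, written the same way the integration tests in this diff exercise it; the `iceberg`, `ns`, and `my_table` names are placeholders that must already be registered on the `SessionContext`.

```rust
use std::sync::Arc;

use datafusion::catalog::SchemaProvider;
use datafusion::error::Result as DFResult;
use datafusion::execution::context::SessionContext;

// Hypothetical helper: walks catalog -> schema (namespace) -> table through
// DataFusion's provider hierarchy.
async fn describe_table(ctx: &SessionContext) -> DFResult<()> {
    let catalog = ctx.catalog("iceberg").expect("catalog is registered");
    let schema: Arc<dyn SchemaProvider> = catalog.schema("ns").expect("namespace exists");

    // Table names come straight from the HashMap built by `try_new` below.
    println!("tables in ns: {:?}", schema.table_names());
    if let Some(table) = schema.table("my_table").await? {
        println!("arrow schema: {:?}", table.schema());
    }
    Ok(())
}
```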
+pub(crate) struct IcebergSchemaProvider { + /// A `HashMap` where keys are table names + /// and values are dynamic references to objects implementing the + /// [`TableProvider`] trait. + tables: HashMap>, +} + +impl IcebergSchemaProvider { + /// Asynchronously tries to construct a new [`IcebergSchemaProvider`] + /// using the given client to fetch and initialize table providers for + /// the provided namespace in the Iceberg [`Catalog`]. + /// + /// This method retrieves a list of table names + /// attempts to create a table provider for each table name, and + /// collects these providers into a `HashMap`. + pub(crate) async fn try_new( + client: Arc, + namespace: NamespaceIdent, + ) -> Result { + // TODO: + // Tables and providers should be cached based on table_name + // if we have a cache miss; we update our internal cache & check again + // As of right now; tables might become stale. + let table_names: Vec<_> = client + .list_tables(&namespace) + .await? + .iter() + .map(|tbl| tbl.name().to_string()) + .collect(); + + let providers = try_join_all( + table_names + .iter() + .map(|name| IcebergTableProvider::try_new(client.clone(), namespace.clone(), name)) + .collect::>(), + ) + .await?; + + let tables: HashMap> = table_names + .into_iter() + .zip(providers.into_iter()) + .map(|(name, provider)| { + let provider = Arc::new(provider) as Arc; + (name, provider) + }) + .collect(); + + Ok(IcebergSchemaProvider { tables }) + } +} + +#[async_trait] +impl SchemaProvider for IcebergSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.tables.keys().cloned().collect() + } + + fn table_exist(&self, name: &str) -> bool { + self.tables.contains_key(name) + } + + async fn table(&self, name: &str) -> DFResult>> { + Ok(self.tables.get(name).cloned()) + } +} diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs new file mode 100644 index 000000000..7ff7b2211 --- /dev/null +++ b/crates/integrations/datafusion/src/table.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; +use datafusion::catalog::Session; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result as DFResult; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::ExecutionPlan; +use iceberg::arrow::schema_to_arrow_schema; +use iceberg::table::Table; +use iceberg::{Catalog, NamespaceIdent, Result, TableIdent}; + +use crate::physical_plan::scan::IcebergTableScan; + +/// Represents a [`TableProvider`] for the Iceberg [`Catalog`], +/// managing access to a [`Table`]. 
+pub(crate) struct IcebergTableProvider { + /// A table in the catalog. + table: Table, + /// A reference-counted arrow `Schema`. + schema: ArrowSchemaRef, +} + +impl IcebergTableProvider { + /// Asynchronously tries to construct a new [`IcebergTableProvider`] + /// using the given client and table name to fetch an actual [`Table`] + /// in the provided namespace. + pub(crate) async fn try_new( + client: Arc, + namespace: NamespaceIdent, + name: impl Into, + ) -> Result { + let ident = TableIdent::new(namespace, name.into()); + let table = client.load_table(&ident).await?; + + let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?); + + Ok(IcebergTableProvider { table, schema }) + } +} + +#[async_trait] +impl TableProvider for IcebergTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> ArrowSchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> DFResult> { + Ok(Arc::new(IcebergTableScan::new( + self.table.clone(), + self.schema.clone(), + ))) + } +} diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs new file mode 100644 index 000000000..9e62930fd --- /dev/null +++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for Iceberg Datafusion with Hive Metastore. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::arrow::datatypes::DataType; +use datafusion::execution::context::SessionContext; +use iceberg::io::FileIOBuilder; +use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; +use iceberg::{Catalog, NamespaceIdent, Result, TableCreation}; +use iceberg_catalog_memory::MemoryCatalog; +use iceberg_datafusion::IcebergCatalogProvider; +use tempfile::TempDir; + +fn temp_path() -> String { + let temp_dir = TempDir::new().unwrap(); + temp_dir.path().to_str().unwrap().to_string() +} + +fn get_iceberg_catalog() -> MemoryCatalog { + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + MemoryCatalog::new(file_io, Some(temp_path())) +} + +async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent) -> Result<()> { + let properties = HashMap::new(); + + catalog.create_namespace(namespace, properties).await?; + + Ok(()) +} + +fn set_table_creation(location: impl ToString, name: impl ToString) -> Result { + let schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?; + + let creation = TableCreation::builder() + .location(location.to_string()) + .name(name.to_string()) + .properties(HashMap::new()) + .schema(schema) + .build(); + + Ok(creation) +} + +#[tokio::test] +async fn test_provider_get_table_schema() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = NamespaceIdent::new("test_provider_get_table_schema".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + + let creation = set_table_creation(temp_path(), "my_table")?; + iceberg_catalog.create_table(&namespace, creation).await?; + + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + + let ctx = SessionContext::new(); + ctx.register_catalog("catalog", catalog); + + let provider = ctx.catalog("catalog").unwrap(); + let schema = provider.schema("test_provider_get_table_schema").unwrap(); + + let table = schema.table("my_table").await.unwrap().unwrap(); + let table_schema = table.schema(); + + let expected = [("foo", &DataType::Int32), ("bar", &DataType::Utf8)]; + + for (field, exp) in table_schema.fields().iter().zip(expected.iter()) { + assert_eq!(field.name(), exp.0); + assert_eq!(field.data_type(), exp.1); + assert!(!field.is_nullable()) + } + + Ok(()) +} + +#[tokio::test] +async fn test_provider_list_table_names() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + + let creation = set_table_creation(temp_path(), "my_table")?; + iceberg_catalog.create_table(&namespace, creation).await?; + + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + + let ctx = SessionContext::new(); + ctx.register_catalog("catalog", catalog); + + let provider = ctx.catalog("catalog").unwrap(); + let schema = provider.schema("test_provider_list_table_names").unwrap(); + + let expected = vec!["my_table"]; + let result = schema.table_names(); + + assert_eq!(result, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_provider_list_schema_names() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = 
NamespaceIdent::new("test_provider_list_schema_names".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + + set_table_creation("test_provider_list_schema_names", "my_table")?; + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + + let ctx = SessionContext::new(); + ctx.register_catalog("catalog", catalog); + + let provider = ctx.catalog("catalog").unwrap(); + + let expected = ["test_provider_list_schema_names"]; + let result = provider.schema_names(); + + assert!(expected + .iter() + .all(|item| result.contains(&item.to_string()))); + Ok(()) +} diff --git a/crates/test_utils/Cargo.toml b/crates/test_utils/Cargo.toml index 91210c50f..d4f6e1696 100644 --- a/crates/test_utils/Cargo.toml +++ b/crates/test_utils/Cargo.toml @@ -17,8 +17,13 @@ [package] name = "iceberg_test_utils" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +rust-version = { workspace = true } + +repository = { workspace = true } +license = { workspace = true } [dependencies] env_logger = { workspace = true } diff --git a/crates/test_utils/src/cmd.rs b/crates/test_utils/src/cmd.rs index 604d4a14d..503d63d15 100644 --- a/crates/test_utils/src/cmd.rs +++ b/crates/test_utils/src/cmd.rs @@ -28,14 +28,27 @@ pub fn run_command(mut cmd: Command, desc: impl ToString) { } } -pub fn get_cmd_output(mut cmd: Command, desc: impl ToString) -> String { +pub fn get_cmd_output_result(mut cmd: Command, desc: impl ToString) -> Result { let desc = desc.to_string(); log::info!("Starting to {}, command: {:?}", &desc, cmd); - let output = cmd.output().unwrap(); - if output.status.success() { - log::info!("{} succeed!", desc); - String::from_utf8(output.stdout).unwrap() - } else { - panic!("{} failed: {:?}", desc, output.status); + let result = cmd.output(); + match result { + Ok(output) => { + if output.status.success() { + log::info!("{} succeed!", desc); + Ok(String::from_utf8(output.stdout).unwrap()) + } else { + Err(format!("{} failed with rc: {:?}", desc, output.status)) + } + } + Err(err) => Err(format!("{} failed with error: {}", desc, { err })), + } +} + +pub fn get_cmd_output(cmd: Command, desc: impl ToString) -> String { + let result = get_cmd_output_result(cmd, desc); + match result { + Ok(output_str) => output_str, + Err(err) => panic!("{}", err), } } diff --git a/crates/test_utils/src/docker.rs b/crates/test_utils/src/docker.rs index 6c5fbef1e..bde9737b1 100644 --- a/crates/test_utils/src/docker.rs +++ b/crates/test_utils/src/docker.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::cmd::{get_cmd_output, run_command}; +use std::net::IpAddr; use std::process::Command; -/// A utility to manage lifecycle of docker compose. +use crate::cmd::{get_cmd_output, get_cmd_output_result, run_command}; + +/// A utility to manage the lifecycle of `docker compose`. /// -/// It's will start docker compose when calling `run` method, and will be stopped when dropped. +/// It will start `docker compose` when calling the `run` method and will be stopped via [`Drop`]. 
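
> Editor's note (not part of this patch): since this hunk reworks the utility's output handling and IP parsing, here is a minimal usage sketch of the lifecycle the doc comment describes. `DockerCompose::new` and the `minio` service name are assumptions based on how the existing integration tests use this helper, not something introduced by this diff.

```rust
use iceberg_test_utils::docker::DockerCompose;

fn main() {
    // Assumed constructor: a project name plus the directory containing docker-compose.yaml.
    let compose = DockerCompose::new("my_project", "testdata/my_fixture");

    // Starts `docker compose up` for the fixture.
    compose.run();

    // With this patch the container address comes back as a typed `IpAddr`
    // instead of a raw String.
    let ip = compose.get_container_ip("minio");
    println!("minio is reachable at {ip}");
} // `compose` goes out of scope here and the containers are torn down via `Drop`.
```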
#[derive(Debug)] pub struct DockerCompose { project_name: String, @@ -39,10 +41,36 @@ impl DockerCompose { self.project_name.as_str() } + fn get_os_arch() -> String { + let mut cmd = Command::new("docker"); + cmd.arg("info") + .arg("--format") + .arg("{{.OSType}}/{{.Architecture}}"); + + let result = get_cmd_output_result(cmd, "Get os arch".to_string()); + match result { + Ok(value) => value.trim().to_string(), + Err(_err) => { + // docker/podman do not consistently place OSArch info in the same json path across OS and versions + // Below tries an alternative path if the above path fails + let mut alt_cmd = Command::new("docker"); + alt_cmd + .arg("info") + .arg("--format") + .arg("{{.Version.OsArch}}"); + get_cmd_output(alt_cmd, "Get os arch".to_string()) + .trim() + .to_string() + } + } + } + pub fn run(&self) { let mut cmd = Command::new("docker"); cmd.current_dir(&self.docker_compose_dir); + cmd.env("DOCKER_DEFAULT_PLATFORM", Self::get_os_arch()); + cmd.args(vec![ "compose", "-p", @@ -63,7 +91,7 @@ impl DockerCompose { ) } - pub fn get_container_ip(&self, service_name: impl AsRef) -> String { + pub fn get_container_ip(&self, service_name: impl AsRef) -> IpAddr { let container_name = format!("{}-{}-1", self.project_name, service_name.as_ref()); let mut cmd = Command::new("docker"); cmd.arg("inspect") @@ -71,9 +99,16 @@ impl DockerCompose { .arg("{{range.NetworkSettings.Networks}}{{.IPAddress}}{{end}}") .arg(&container_name); - get_cmd_output(cmd, format!("Get container ip of {container_name}")) + let ip_result = get_cmd_output(cmd, format!("Get container ip of {container_name}")) .trim() - .to_string() + .parse::(); + match ip_result { + Ok(ip) => ip, + Err(e) => { + log::error!("Invalid IP, {e}"); + panic!("Failed to parse IP for {container_name}") + } + } } } diff --git a/deny.toml b/deny.toml new file mode 100644 index 000000000..9c62e0d68 --- /dev/null +++ b/deny.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[licenses] +allow = [ + "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", + "MIT", + "BSD-3-Clause", + "ISC", + "CC0-1.0", + "Unicode-DFS-2016", + "Zlib" +] +exceptions = [ + { allow = [ + "OpenSSL", + ], name = "ring" } +] + +[[licenses.clarify]] +name = "ring" +# SPDX considers OpenSSL to encompass both the OpenSSL and SSLeay licenses +# https://spdx.org/licenses/OpenSSL.html +# ISC - Both BoringSSL and ring use this for their new files +# MIT - "Files in third_party/ have their own licenses, as described therein. The MIT +# license, for third_party/fiat, which, unlike other third_party directories, is +# compiled into non-test libraries, is included below." 
+# OpenSSL - Obviously +expression = "ISC AND MIT AND OpenSSL" +license-files = [{ path = "LICENSE", hash = 0xbd0eed23 }] diff --git a/docs/contributing/podman.md b/docs/contributing/podman.md new file mode 100644 index 000000000..3281ad4da --- /dev/null +++ b/docs/contributing/podman.md @@ -0,0 +1,85 @@ + + +# Using Podman instead of Docker + +Iceberg-rust does not require containerization, except for integration tests, where "docker" and "docker-compose" are used to start containers for minio and various catalogs. Below instructions setup "rootful podman" and docker's official docker-compose plugin to run integration tests as an alternative to docker or Orbstack. + +1. Have podman v4 or newer. + ```console + $ podman --version + podman version 4.9.4-rhel + ``` + +2. Open file `/usr/bin/docker` and add the below contents: + ```bash + #!/bin/sh + [ -e /etc/containers/nodocker ] || \ + echo "Emulate Docker CLI using podman. Create /etc/containers/nodocker to quiet msg." >&2 + exec sudo /usr/bin/podman "$@" + ``` + +3. Install the [docker compose plugin](https://docs.docker.com/compose/install/linux). Check for successful installation. + ```console + $ docker compose version + Docker Compose version v2.28.1 + ``` + +4. Append the below to `~/.bashrc` or equivalent shell config: + ```bash + export DOCKER_HOST=unix:///run/podman/podman.sock + ``` + +5. Start the "rootful" podman socket. + ```shell + sudo systemctl start podman.socket + sudo systemctl status podman.socket + ``` + +6. Check that the following symlink exists. + ```console + $ ls -al /var/run/docker.sock + lrwxrwxrwx 1 root root 27 Jul 24 12:18 /var/run/docker.sock -> /var/run/podman/podman.sock + ``` + If the symlink does not exist, create it. + ```shell + sudo ln -s /var/run/podman/podman.sock /var/run/docker.sock + ``` + +7. Check that the docker socket is working. + ```shell + sudo curl -H "Content-Type: application/json" --unix-socket /var/run/docker.sock http://localhost/_ping + ``` + +8. Try some integration tests! + ```shell + cargo test -p iceberg --test file_io_s3_test + ``` + +# References + +* +* + +# Note on rootless containers + +As of podman v4, ["To be succinct and simple, when running rootless containers, the container itself does not have an IP address"](https://www.redhat.com/sysadmin/container-ip-address-podman) This causes issues with iceberg-rust's integration tests, which rely upon ip-addressable containers via docker-compose. As a result, podman "rootful" containers are required throughout to ensure containers have IP addresses. Perhaps as a future work or with updates to default podman networking, the need for "rootful" podman containers can be eliminated. + +* +* diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a5a7402a5..7b10a8692 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -16,5 +16,5 @@ # under the License. 
[toolchain] -channel = "1.72.1" +channel = "nightly-2024-06-10" components = ["rustfmt", "clippy"] diff --git a/rustfmt.toml b/rustfmt.toml index 39be343c6..91d924daf 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -17,3 +17,10 @@ edition = "2021" reorder_imports = true + +format_code_in_doc_comments = true +group_imports = "StdExternalCrate" +imports_granularity = "Module" +overflow_delimited_expr = true +trailing_comma = "Vertical" +where_single_line = true diff --git a/scripts/dependencies.py b/scripts/dependencies.py new file mode 100644 index 000000000..1c8db96b9 --- /dev/null +++ b/scripts/dependencies.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, REMAINDER +import subprocess +import os + +DIRS = [ + "crates/iceberg", + + "crates/catalog/glue", "crates/catalog/hms", + "crates/catalog/memory", "crates/catalog/rest", + "crates/catalog/sql", + + "crates/integrations/datafusion" + +] + + +def check_deps(): + cargo_dirs = DIRS + for root in cargo_dirs: + print(f"Checking dependencies of {root}") + subprocess.run(["cargo", "deny", "check", "license"], cwd=root) + + +def generate_deps(): + cargo_dirs = DIRS + for root in cargo_dirs: + print(f"Generating dependencies {root}") + result = subprocess.run( + ["cargo", "deny", "list", "-f", "tsv", "-t", "0.6"], + cwd=root, + capture_output=True, + text=True, + ) + with open(f"{root}/DEPENDENCIES.rust.tsv", "w") as f: + f.write(result.stdout) + + +if __name__ == "__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.set_defaults(func=parser.print_help) + subparsers = parser.add_subparsers() + + parser_check = subparsers.add_parser('check', + description="Check dependencies", + help="Check dependencies") + parser_check.set_defaults(func=check_deps) + + parser_generate = subparsers.add_parser( + 'generate', + description="Generate dependencies", + help="Generate dependencies") + parser_generate.set_defaults(func=generate_deps) + + args = parser.parse_args() + arg_dict = dict(vars(args)) + del arg_dict['func'] + args.func(**arg_dict) diff --git a/scripts/release.sh b/scripts/release.sh new file mode 100755 index 000000000..884ed8a0a --- /dev/null +++ b/scripts/release.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +if [ -z "${ICEBERG_VERSION}" ]; then + echo "ICEBERG_VERSION is unset" + exit 1 +else + echo "var is set to '$ICEBERG_VERSION'" +fi + +# tar source code +release_version=${ICEBERG_VERSION} +# rc versions +rc_version="${ICEBERG_VERSION_RC:-rc.1}" +# Corresponding git repository branch +git_branch=release-${release_version}-${rc_version} + +rm -rf dist +mkdir -p dist/ + +echo "> Checkout version branch" +git checkout -B "${git_branch}" + +echo "> Start package" +git archive --format=tar.gz --output="dist/apache-iceberg-rust-$release_version-src.tar.gz" --prefix="apache-iceberg-rust-$release_version-src/" --add-file=Cargo.toml "$git_branch" + +cd dist +echo "> Generate signature" +for i in *.tar.gz; do + echo "$i" + gpg --armor --output "$i.asc" --detach-sig "$i" +done +echo "> Check signature" +for i in *.tar.gz; do + echo "$i" + gpg --verify "$i.asc" "$i" +done +echo "> Generate sha512sum" +for i in *.tar.gz; do + echo "$i" + sha512sum "$i" >"$i.sha512" +done +echo "> Check sha512sum" +for i in *.tar.gz; do + echo "$i" + sha512sum --check "$i.sha512" +done + +cd .. +echo "> Check license" +docker run -it --rm -v $(pwd):/github/workspace apache/skywalking-eyes header check diff --git a/scripts/verify.py b/scripts/verify.py new file mode 100644 index 000000000..415bb243a --- /dev/null +++ b/scripts/verify.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import subprocess +import sys +import os + +BASE_DIR = os.getcwd() + + +def check_rust(): + try: + subprocess.run(["cargo", "--version"], check=True) + return True + except FileNotFoundError: + return False + except Exception as e: + raise Exception("Check rust met unexpected error", e) + +def build_core(): + print("Start building iceberg rust") + + subprocess.run(["cargo", "build", "--release"], check=True) + +def main(): + if not check_rust(): + print( + "Cargo is not found, please check if rust development has been setup correctly" + ) + print("Visit https://www.rust-lang.org/tools/install for more information") + sys.exit(1) + + build_core() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/website/.gitignore b/website/.gitignore new file mode 100644 index 000000000..6155ce079 --- /dev/null +++ b/website/.gitignore @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +book diff --git a/website/README.md b/website/README.md new file mode 100644 index 000000000..9914b8222 --- /dev/null +++ b/website/README.md @@ -0,0 +1,40 @@ + + +# Iceberg Rust Website + +## Setup + +Install mdbook first + +```shell +cargo install mdbook +``` + +## Preview + +```shell +mdbook serve +``` + +## Build + +```shell +mdbook build +``` diff --git a/website/book.toml b/website/book.toml new file mode 100644 index 000000000..e8a90c721 --- /dev/null +++ b/website/book.toml @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[book] +authors = ["Iceberg Community"] +language = "en" +multilingual = false +src = "src" +title = "Iceberg Rust" + +[output.html] +git-repository-url = "https://github.com/apache/iceberg-rust" +git-repository-icon = "fa-github" +edit-url-template = "https://github.com/apache/iceberg-rust/edit/main/website/{path}" +cname = "rust.iceberg.apache.org" +no-section-label = true diff --git a/website/src/CONTRIBUTING.md b/website/src/CONTRIBUTING.md new file mode 120000 index 000000000..f939e75f2 --- /dev/null +++ b/website/src/CONTRIBUTING.md @@ -0,0 +1 @@ +../../CONTRIBUTING.md \ No newline at end of file diff --git a/website/src/SUMMARY.md b/website/src/SUMMARY.md new file mode 100644 index 000000000..e2a07ba9c --- /dev/null +++ b/website/src/SUMMARY.md @@ -0,0 +1,36 @@ + + +- [Introduction](./introduction.md) + +# User Guide + +- [Install](./install.md) +- [Download](./download.md) +- [API](./api.md) + +# Developer Guide + +- [Contributing](./CONTRIBUTING.md) +- [Release](./release.md) + +# Reference + +- [Using Podman instead of Docker](./reference/podman.md) +- [Setup GPG key](./reference/setup_gpg.md) \ No newline at end of file diff --git a/website/src/api.md b/website/src/api.md new file mode 100644 index 000000000..eaf81ee66 --- /dev/null +++ b/website/src/api.md @@ -0,0 +1,59 @@ + + +# Catalog + +`Catalog` is the entry point for accessing iceberg tables. You can use a catalog to: + +* Create and list namespaces. +* Create, load, and drop tables + +Currently only rest catalog has been implemented, and other catalogs are under active development. Here is an +example of how to create a `RestCatalog`: + +```rust,no_run,noplayground +{{#rustdoc_include ../../crates/examples/src/rest_catalog_namespace.rs:create_catalog}} +``` + +You can run following code to list all root namespaces: + +```rust,no_run,noplayground +{{#rustdoc_include ../../crates/examples/src/rest_catalog_namespace.rs:list_all_namespace}} +``` + +Then you can run following code to create namespace: +```rust,no_run,noplayground +{{#rustdoc_include ../../crates/examples/src/rest_catalog_namespace.rs:create_namespace}} +``` + +# Table + +After creating `Catalog`, we can manipulate tables through `Catalog`. + +You can use following code to create a table: + +```rust,no_run,noplayground +{{#rustdoc_include ../../crates/examples/src/rest_catalog_table.rs:create_table}} +``` + +Also, you can load a table directly: + +```rust,no_run,noplayground +{{#rustdoc_include ../../crates/examples/src/rest_catalog_table.rs:load_table}} +``` diff --git a/website/src/download.md b/website/src/download.md new file mode 100644 index 000000000..8ec9e6ea3 --- /dev/null +++ b/website/src/download.md @@ -0,0 +1,60 @@ + + +# Apache Iceberg™ Rust Downloads + +The official Apache Iceberg-Rust releases are provided as source artifacts. + +## Releases + +The latest source release is [0.3.0](https://www.apache.org/dyn/closer.lua/iceberg/iceberg-rust-0.3.0/apache-iceberg-rust-0.3.0-src.tar.gz?action=download) ([asc](https://downloads.apache.org/iceberg/iceberg-rust-0.3.0/apache-iceberg-rust-0.3.0-src.tar.gz.asc), +[sha512](https://downloads.apache.org/iceberg/iceberg-rust-0.3.0/apache-iceberg-rust-0.3.0-src.tar.gz.sha512)). + +For older releases, please check the [archive](https://archive.apache.org/dist/iceberg/). + +## Notes + +* When downloading a release, please verify the OpenPGP compatible signature (or failing that, check the SHA-512); these should be fetched from the main Apache site. 
+* The KEYS file contains the public keys used for signing release. It is recommended that (when possible) a web of trust is used to confirm the identity of these keys. +* Please download the [KEYS](https://downloads.apache.org/iceberg/KEYS) as well as the .asc signature files. + +### To verify the signature of the release artifact + +You will need to download both the release artifact and the .asc signature file for that artifact. Then verify the signature by: + +* Download the KEYS file and the .asc signature files for the relevant release artifacts. +* Import the KEYS file to your GPG keyring: + + ```shell + gpg --import KEYS + ``` + +* Verify the signature of the release artifact using the following command: + + ```shell + gpg --verify .asc + ``` + +### To verify the checksum of the release artifact + +You will need to download both the release artifact and the .sha512 checksum file for that artifact. Then verify the checksum by: + +```shell +shasum -a 512 -c .sha512 +``` diff --git a/website/src/install.md b/website/src/install.md new file mode 100644 index 000000000..38b76a972 --- /dev/null +++ b/website/src/install.md @@ -0,0 +1,36 @@ + + +# Install + +
+Cargo 1.75.0 or later is required to build. +
+ +Add `iceberg` into `Cargo.toml` dependencies: + +```toml +iceberg = "0.2.0" +``` + +iceberg is under active development, you may want to use the git version instead: + +```toml +iceberg = { git = "https://github.com/apache/iceberg-rust", rev = "commit-hash" } +``` diff --git a/website/src/introduction.md b/website/src/introduction.md new file mode 100644 index 000000000..260ec690e --- /dev/null +++ b/website/src/introduction.md @@ -0,0 +1,22 @@ + + +# Iceberg Rust + +Iceberg Rust is a rust implementation for accessing iceberg tables. diff --git a/website/src/reference/setup_gpg.md b/website/src/reference/setup_gpg.md new file mode 100644 index 000000000..562113d0e --- /dev/null +++ b/website/src/reference/setup_gpg.md @@ -0,0 +1,161 @@ + + +# Setup GPG key + +> This section is a brief from the [Cryptography with OpenPGP](https://infra.apache.org/openpgp.html) guideline. + + +## Install GPG + +For more details, please refer to [GPG official website](https://www.gnupg.org/download/index.html). Here shows one approach to install GPG with `apt`: + +```shell +sudo apt install gnupg2 +``` + +## Generate GPG Key + +Attentions: + +- Name is best to keep consistent with your full name of Apache ID; +- Email should be the Apache email; +- Name is best to only use English to avoid garbled. + +Run `gpg --full-gen-key` and complete the generation interactively: + +```shell +gpg (GnuPG) 2.2.20; Copyright (C) 2020 Free Software Foundation, Inc. +This is free software: you are free to change and redistribute it. +There is NO WARRANTY, to the extent permitted by law. + +Please select what kind of key you want: + (1) RSA and RSA (default) + (2) DSA and Elgamal + (3) DSA (sign only) + (4) RSA (sign only) + (14) Existing key from card +Your selection? 1 # input 1 +RSA keys may be between 1024 and 4096 bits long. +What keysize do you want? (2048) 4096 # input 4096 +Requested keysize is 4096 bits +Please specify how long the key should be valid. + 0 = key does not expire + = key expires in n days + w = key expires in n weeks + m = key expires in n months + y = key expires in n years +Key is valid for? (0) 0 # input 0 +Key does not expire at all +Is this correct? (y/N) y # input y + +GnuPG needs to construct a user ID to identify your key. + +Real name: Hulk Lin # input your name +Email address: hulk@apache.org # input your email +Comment: # input some annotations, optional +You selected this USER-ID: + "Hulk " + +Change (N)ame, (C)omment, (E)mail or (O)kay/(Q)uit? O # input O +We need to generate a lot of random bytes. It is a good idea to perform +some other action (type on the keyboard, move the mouse, utilize the +disks) during the prime generation; this gives the random number +generator a better chance to gain enough entropy. +We need to generate a lot of random bytes. It is a good idea to perform +some other action (type on the keyboard, move the mouse, utilize the +disks) during the prime generation; this gives the random number +generator a better chance to gain enough entropy. 
+ +# Input the security key +┌──────────────────────────────────────────────────────┐ +│ Please enter this passphrase │ +│ │ +│ Passphrase: _______________________________ │ +│ │ +│ │ +└──────────────────────────────────────────────────────┘ +# key generation will be done after your inputting the key with the following output +gpg: key E49B00F626B marked as ultimately trusted +gpg: revocation certificate stored as '/Users/hulk/.gnupg/openpgp-revocs.d/F77B887A4F25A9468C513E9AA3008E49B00F626B.rev' +public and secret key created and signed. + +pub rsa4096 2022-07-12 [SC] + F77B887A4F25A9468C513E9AA3008E49B00F626B +uid [ultimate] hulk +sub rsa4096 2022-07-12 [E] +``` + +## Upload your key to public GPG keyserver + +Firstly, list your key: + +```shell +gpg --list-keys +``` + +The output is like: + +```shell +------------------------------- +pub rsa4096 2022-07-12 [SC] + F77B887A4F25A9468C513E9AA3008E49B00F626B +uid [ultimate] hulk +sub rsa4096 2022-07-12 [E] +``` + +Then, send your key id to key server: + +```shell +gpg --keyserver keys.openpgp.org --send-key # e.g., F77B887A4F25A9468C513E9AA3008E49B00F626B +``` + +Among them, `keys.openpgp.org` is a randomly selected keyserver, you can use `keyserver.ubuntu.com` or any other full-featured keyserver. + +## Check whether the key is created successfully + +Uploading takes about one minute; after that, you can check by your email at the corresponding keyserver. + +Uploading keys to the keyserver is mainly for joining a [Web of Trust](https://infra.apache.org/release-signing.html#web-of-trust). + +## Add your GPG public key to the KEYS document + +:::info + +`SVN` is required for this step. + +::: + +The svn repository of the release branch is: https://dist.apache.org/repos/dist/release/iceberg + +Please always add the public key to KEYS in the release branch: + +```shell +svn co https://dist.apache.org/repos/dist/release/iceberg iceberg-dist +# As this step will copy all the versions, it will take some time. If the network is broken, please use svn cleanup to delete the lock before re-execute it. +cd iceberg-dist +(gpg --list-sigs YOUR_NAME@apache.org && gpg --export --armor YOUR_NAME@apache.org) >> KEYS # Append your key to the KEYS file +svn add . # It is not needed if the KEYS document exists before. +svn ci -m "add gpg key for YOUR_NAME" # Later on, if you are asked to enter a username and password, just use your apache username and password. +``` + +## Upload the GPG public key to your GitHub account + +- Enter https://github.com/settings/keys to add your GPG key. +- Please remember to bind the email address used in the GPG key to your GitHub account (https://github.com/settings/emails) if you find "unverified" after adding it. diff --git a/website/src/release.md b/website/src/release.md new file mode 100644 index 000000000..a822b1e72 --- /dev/null +++ b/website/src/release.md @@ -0,0 +1,384 @@ + + +This document mainly introduces how the release manager releases a new version in accordance with the Apache requirements. + +## Introduction + +`Source Release` is the key point which Apache values, and is also necessary for an ASF release. + +Please remember that publishing software has legal consequences. 
+ +This guide complements the foundation-wide policies and guides: + +- [Release Policy](https://www.apache.org/legal/release-policy.html) +- [Release Distribution Policy](https://infra.apache.org/release-distribution) +- [Release Creation Process](https://infra.apache.org/release-publishing.html) + +## Some Terminology of release + +In the context of our release, we use several terms to describe different stages of the release process. + +Here's an explanation of these terms: + +- `iceberg_version`: the version of Iceberg to be released, like `0.2.0`. +- `release_version`: the version of release candidate, like `0.2.0-rc.1`. +- `rc_version`: the minor version for voting round, like `rc.1`. + +## Preparation + +
+ +This section covers the requirements for individuals who are new to the role of release manager. + +
+ +Refer to [Setup GPG Key](reference/setup_gpg.md) to make sure the GPG key has been set up. + +## Start a tracking issue about the next release + +Start a tracking issue on GitHub for the upcoming release to track all tasks that need to be completed. + +Title: + +``` +Tracking issues of Iceberg Rust ${iceberg_version} Release +``` + +Content: + +```markdown +This issue is used to track tasks of the iceberg rust ${iceberg_version} release. + +## Tasks + +### Blockers + +> Blockers are the tasks that must be completed before the release. + +### Build Release + +#### GitHub Side + +- [ ] Bump version in project +- [ ] Update docs +- [ ] Generate dependencies list +- [ ] Push release candidate tag to GitHub + +#### ASF Side + +- [ ] Create an ASF Release +- [ ] Upload artifacts to the SVN dist repo + +### Voting + +- [ ] Start VOTE at iceberg community + +### Official Release + +- [ ] Push the release git tag +- [ ] Publish artifacts to SVN RELEASE branch +- [ ] Change Iceberg Rust Website download link +- [ ] Send the announcement + +For details of each step, please refer to: https://rust.iceberg.apache.org/release +``` + +## GitHub Side + +### Bump version in project + +Bump all components' version in the project to the new iceberg version. +Please note that this version is the exact version of the release, not the release candidate version. + +- rust core: bump version in `Cargo.toml` + +### Update docs + +- Update `CHANGELOG.md`, refer to [Generate Release Note](reference/generate_release_note.md) for more information. + +### Generate dependencies list + +Download and setup `cargo-deny`. You can refer to [cargo-deny](https://embarkstudios.github.io/cargo-deny/cli/index.html). + +Running `python3 ./scripts/dependencies.py generate` to update the dependencies list of every package. + +### Push release candidate tag + +After bump version PR gets merged, we can create a GitHub release for the release candidate: + +- Create a tag at `main` branch on the `Bump Version` / `Patch up version` commit: `git tag -s "v0.2.0-rc.1"`, please correctly check out the corresponding commit instead of directly tagging on the main branch. +- Push tags to GitHub: `git push --tags`. + +## ASF Side + +If any step in the ASF Release process fails and requires code changes, +we will abandon that version and prepare for the next one. +Our release page will only display ASF releases instead of GitHub Releases. + +### Create an ASF Release + +After GitHub Release has been created, we can start to create ASF Release. + +- Checkout to released tag. (e.g. `git checkout v0.2.0-rc.1`, tag is created in the previous step) +- Use the release script to create a new release: `ICEBERG_VERSION= ICEBERG_VERSION_RC= ./scripts/release.sh`(e.g. `ICEBERG_VERSION=0.2.0 ICEBERG_VERSION_RC=rc.1 ./scripts/release.sh`) + - This script will do the following things: + - Create a new branch named by `release-${release_version}` from the tag + - Generate the release candidate artifacts under `dist`, including: + - `apache-iceberg-rust-${release_version}-src.tar.gz` + - `apache-iceberg-rust-${release_version}-src.tar.gz.asc` + - `apache-iceberg-rust-${release_version}-src.tar.gz.sha512` + - Check the header of the source code. This step needs docker to run. +- Push the newly created branch to GitHub + +This script will create a new release under `dist`. 
+ +For example: + +```shell +> tree dist +dist +├── apache-iceberg-rust-0.2.0-src.tar.gz +├── apache-iceberg-rust-0.2.0-src.tar.gz.asc +└── apache-iceberg-rust-0.2.0-src.tar.gz.sha512 +``` + +### Upload artifacts to the SVN dist repo + +SVN is required for this step. + +The svn repository of the dev branch is: + +First, checkout Iceberg to local directory: + +```shell +# As this step will copy all the versions, it will take some time. If the network is broken, please use svn cleanup to delete the lock before re-execute it. +svn co https://dist.apache.org/repos/dist/dev/iceberg/ /tmp/iceberg-dist-dev +``` + +Then, upload the artifacts: + +> The `${release_version}` here should be like `0.2.0-rc.1` + +```shell +# create a directory named by version +mkdir /tmp/iceberg-dist-dev/${release_version} +# copy source code and signature package to the versioned directory +cp ${repo_dir}/dist/* /tmp/iceberg-dist-dev/iceberg-rust-${release_version}/ +# change dir to the svn folder +cd /tmp/iceberg-dist-dev/ +# check svn status +svn status +# add to svn +svn add ${release_version} +# check svn status +svn status +# commit to SVN remote server +svn commit -m "Prepare for ${release_version}" +``` + +Visit to make sure the artifacts are uploaded correctly. + +### Rescue + +If you accidentally published wrong or unexpected artifacts, like wrong signature files, wrong sha256 files, +please cancel the release for the current `release_version`, +_increase th RC counting_ and re-initiate a release with the new `release_version`. +And remember to delete the wrong artifacts from the SVN dist repo. + +## Voting + +Iceberg Community Vote should send email to: : + +Title: + +``` +[VOTE] Release Apache Iceberg Rust ${release_version} RC1 +``` + +Content: + +``` +Hello, Apache Iceberg Rust Community, + +This is a call for a vote to release Apache Iceberg rust version ${iceberg_version}. + +The tag to be voted on is ${iceberg_version}. + +The release candidate: + +https://dist.apache.org/repos/dist/dev/iceberg/iceberg-rust-${release_version}/ + +Keys to verify the release candidate: + +https://downloads.apache.org/iceberg/KEYS + +Git tag for the release: + +https://github.com/apache/iceberg-rust/releases/tag/${release_version} + +Please download, verify, and test. + +The VOTE will be open for at least 72 hours and until the necessary +number of votes are reached. + +[ ] +1 approve +[ ] +0 no opinion +[ ] -1 disapprove with the reason + +To learn more about Apache Iceberg, please see https://rust.iceberg.apache.org/ + +Checklist for reference: + +[ ] Download links are valid. +[ ] Checksums and signatures. +[ ] LICENSE/NOTICE files exist +[ ] No unexpected binary files +[ ] All source files have ASF headers +[ ] Can compile from source + +More detailed checklist please refer to: +https://github.com/apache/iceberg-rust/tree/main/scripts + +To compile from source, please refer to: +https://github.com/apache/iceberg-rust/blob/main/CONTRIBUTING.md + +Here is a Python script in release to help you verify the release candidate: + +./scripts/verify.py + +Thanks + +${name} +``` + +Example: + +After at least 3 `+1` binding vote (from Iceberg PMC member), claim the vote result: + +Title: + +``` +[RESULT][VOTE] Release Apache Iceberg Rust ${release_version} RC1 +``` + +Content: + +``` +Hello, Apache Iceberg Rust Community, + +The vote to release Apache Iceberg Rust ${release_version} has passed. 
+ +The vote PASSED with 3 +1 binding and 1 +1 non-binding votes, no +0 or -1 votes: + +Binding votes: + +- xxx +- yyy +- zzz + +Non-Binding votes: + +- aaa + +Vote thread: ${vote_thread_url} + +Thanks + +${name} +``` + +Example: + +## Official Release + +### Push the release git tag + +```shell +# Checkout the tags that passed VOTE +git checkout ${release_version} +# Tag with the iceberg version +git tag -s ${iceberg_version} +# Push tags to github to trigger releases +git push origin ${iceberg_version} +``` + +### Publish artifacts to SVN RELEASE branch + +```shell +svn mv https://dist.apache.org/repos/dist/dev/iceberg/iceberg-rust-${release_version} https://dist.apache.org/repos/dist/release/iceberg/iceberg-rust-${iceberg_version} -m "Release Apache Iceberg Rust ${iceberg_version}" +``` + +### Change Iceberg Rust Website download link + +Update the download link in `website/src/download.md` to the new release version. + +### Create a GitHub Release + +- Click [here](https://github.com/apache/iceberg-rust/releases/new) to create a new release. +- Pick the git tag of this release version from the dropdown menu. +- Make sure the branch target is `main`. +- Generate the release note by clicking the `Generate release notes` button. +- Add the release note from every component's `upgrade.md` if there are breaking changes before the content generated by GitHub. Check them carefully. +- Publish the release. + +### Send the announcement + +Send the release announcement to `dev@iceberg.apache.org` and CC `announce@apache.org`. + +Instead of adding breaking changes, let's include the new features as "notable changes" in the announcement. + +Title: + +``` +[ANNOUNCE] Release Apache Iceberg Rust ${iceberg_version} +``` + +Content: + +``` +Hi all, + +The Apache Iceberg Rust community is pleased to announce +that Apache Iceberg Rust ${iceberg_version} has been released! + +Iceberg is a data access layer that allows users to easily and efficiently +retrieve data from various storage services in a unified way. + +The notable changes since ${iceberg_version} include: +1. xxxxx +2. yyyyyy +3. zzzzzz + +Please refer to the change log for the complete list of changes: +https://github.com/apache/iceberg-rust/releases/tag/v${iceberg_version} + +Apache Iceberg Rust website: https://rust.iceberg.apache.org/ + +Download Links: https://rust.iceberg.apache.org/download + +Iceberg Resources: +- Issue: https://github.com/apache/iceberg-rust/issues +- Mailing list: dev@iceberg.apache.org + +Thanks +On behalf of Apache Iceberg Community +``` + +Example: