diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 7bbaa97b29..58fe95a554 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -6,15 +6,18 @@ on:
   push:
     branches: [main]
 
+env:
+  HATCH_VERSION: 1.7.0
+
 jobs:
   ci:
     strategy:
       matrix:
-        pyVersion: [ '3.9' ]
+        pyVersion: [ '3.10' ]
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
 
       - name: Unshallow
         run: git fetch --prune --unshallow
@@ -22,29 +25,17 @@ jobs:
       - name: Install Python
         uses: actions/setup-python@v4
         with:
+          cache: 'pip'
+          cache-dependency-path: '**/pyproject.toml'
           python-version: ${{ matrix.pyVersion }}
-
-      - name: Install Poetry
-        uses: snok/install-poetry@v1
-        with:
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-          installer-parallel: true
-
-      - name: Load cache
-        id: cached-poetry-dependencies
-        uses: actions/cache@v3
-        with:
-          path: .venv
-          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
-
-      - name: Install project dependencies
-        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
-        run: poetry install --no-interaction --no-root --with=dev
-
+      - name: Install hatch
+        run: pip install hatch==$HATCH_VERSION
       - name: Verify linting
-        run: make verify
\ No newline at end of file
+        run: |
+          hatch run lint:verify
+
+      - name: Run unit tests
+        run: |
+          hatch run unit:test
diff --git a/.gitignore b/.gitignore
index fd0175eca7..5494ed79a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -93,8 +93,9 @@ celerybeat.pid
 *.sage.py
 
 # Environments
-.env
+.env.admin
 .venv
+.env.*
 env/
 venv/
 ENV/
@@ -134,3 +135,7 @@ cython_debug/
 
 # ruff
 .ruff_cache
+/scratch
+
+# dev files and scratches
+dev/cleanup.py
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 934a3c0ff8..e69de29bb2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,11 +0,0 @@
-lint:
-	@echo "Linting the project code"
-	poetry run black .
-	poetry run isort .
-	poetry run ruff . --fix
-
-verify:
-	@echo "Verifying the project code"
-	poetry run black . --check
-	poetry run isort . --check
-	poetry run ruff .
diff --git a/README.md b/README.md
index 0fa85782ed..70280c05f8 100644
--- a/README.md
+++ b/README.md
@@ -2,14 +2,17 @@
 
 This repo contains various functions and utilities for UC Upgrade.
 
-
 ## Latest working version and how-to
 
 Please note that the current project status is 🏗️ **WIP**, but we have a minimal set of already working utilities.
+
 To run the notebooks, please use the latest LTS Databricks Runtime (non-ML), without Photon, in single-user cluster mode.
-If you have Table ACL Clusters or SQL Warehouse where ACL have been defined, you should create a TableACL cluster to run this notebook
-Please note that script is executed only on the driver node, therefore you'll need to use a Single Node Cluster with sufficient amount of cores (e.g. 16 cores).
+> If you have Table ACL clusters or SQL Warehouses where ACLs have been defined, you should create a Table ACL cluster to
+> run this notebook.
+
+Please note that the script is executed **only** on the driver node; therefore you'll need to use a Single Node cluster with
+a sufficient number of cores (e.g. 16 cores).
 
 Recommended VM types are:
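For reference, a cluster satisfying the constraints above (single-user access mode, single node, enough driver cores) could be provisioned via the Databricks SDK. This is a hedged sketch only: the runtime version string, node type, and autotermination value are illustrative assumptions, not values mandated by this repo:

```python
# Sketch: create a single-user, single-node cluster suitable for running the notebooks.
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import DataSecurityMode

ws = WorkspaceClient()
cluster = ws.clusters.create(
    cluster_name="uc-upgrade-runner",
    spark_version="12.2.x-scala2.12",  # assumption: pick the latest LTS, non-ML, non-Photon runtime
    node_type_id="m5d.4xlarge",  # assumption: any 16-core node type; see the recommendations below
    num_workers=0,  # single-node: everything runs on the driver
    data_security_mode=DataSecurityMode.SINGLE_USER,
    spark_conf={
        "spark.databricks.cluster.profile": "singleNode",
        "spark.master": "local[*]",
    },
    custom_tags={"ResourceClass": "SingleNode"},
    autotermination_minutes=60,
).result()  # block until the cluster is running
```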
@@ -18,27 +21,131 @@ Recommended VM types are:
 - GCP: `c2-standard-16`
 
 **For now, please switch to the `v0.0.1` tag on GitHub to get the latest working version.**
+**All instructions below are currently in WIP mode.**
+
+## Group migration
+
+During UC adoption, it's critical to move groups from the workspace level to the account level.
+
+To deliver this migration, the following steps are performed (see the sketch after this table):
+
+| Step description | Relevant API method |
+|------------------|---------------------|
+| A set of groups to be migrated is identified (either via the `groups.selected` config property, or automatically).<br>Group existence is verified against the account level.<br>**If there is no group on the account level, an error is thrown.**<br>Backup groups are created on the workspace level. | `toolkit.prepare_groups_in_environment()` |
+| Inventory table is cleaned up. | `toolkit.cleanup_inventory_table()` |
+| Workspace local group permissions are inventorized and saved into a Delta table. | `toolkit.inventorize_permissions()` |
+| Backup groups are entitled with permissions from the inventory table. | `toolkit.apply_permissions_to_backup_groups()` |
+| Workspace-level groups are deleted. Account-level groups are granted access to the workspace.<br>Workspace-level entitlements are synced from the backup groups to the newly added account-level groups. | `toolkit.replace_workspace_groups_with_account_groups()` |
+| Account-level groups are entitled with workspace-level permissions from the inventory table. | `toolkit.apply_permissions_to_account_groups()` |
+| Backup groups are deleted. | `toolkit.delete_backup_groups()` |
+| Inventory table is cleaned up. This step is optional. | `toolkit.cleanup_inventory_table()` |
+
+## Permissions and entitlements that we inventorize
+
+> Please note that inherited permissions will not be inventorized / migrated.
+> We only cover direct permissions.
+
+Group-level:
+
+- [x] Entitlements (one of `workspace-access`, `databricks-sql-access`, `allow-cluster-create`, `allow-instance-pool-create`)
+- [x] Roles (AWS only, represents Instance Profile access)
+
+Compute infrastructure:
+
+- [x] Clusters
+- [ ] Cluster policies
+- [ ] Pools
+- [ ] Instance Profile (for AWS)
+
+Workflows:
+
+- [ ] Delta Live Tables
+- [ ] Jobs
+
+ML:
+
+- [ ] MLflow experiments
+- [ ] MLflow registry
+- [ ] Legacy MLflow model endpoints (?)
+
+SQL:
+
+- [ ] Databricks SQL warehouses
+- [ ] Dashboards
+- [ ] Queries
+- [ ] Alerts
+
+Security:
+
+- [ ] Tokens
+- [ ] Passwords (for AWS)
+- [ ] Secrets
+
+Workspace:
+
+- [ ] Notebooks in the Workspace FS
+- [ ] Directories in the Workspace FS
+- [ ] Files in the Workspace FS
+
+Repos:
+- [ ] User-level Repos
+- [ ] Org-level Repos
 
-## Local setup and development process
+Data access:
 
-- Install [poetry](https://python-poetry.org/)
-- Run `poetry install` in the project directory
-- Pin your IDE to use the newly created poetry environment
+- [ ] Table ACLs
 
-> Please note that you **don't** need to use `poetry` inside notebooks or in the Databricks workspace.
-> It's only introduced to simplify local development.
+## Development
 
-Before running `git push`, don't forget to link your code with:
+This section describes the setup and development process for the project.
+
+### Local setup
+
+- Install [hatch](https://github.com/pypa/hatch):
+
+```shell
+pip install hatch
+```
+
+- Create the environment:
+
+```shell
+hatch env create
+```
+
+- Install the dev dependencies:
+
+```shell
+hatch run pip install -e '.[dbconnect]'
+```
+
+- Pin your IDE to use the newly created virtual environment. You can get the Python path with:
+
+```shell
+hatch run python -c "import sys; print(sys.executable)"
+```
+
+- You're good to go! 🎉
+
+### Development process
+
+Please note that you **don't** need to use `hatch` inside notebooks or in the Databricks workspace.
+It's only introduced to simplify local development.
+
+Write your code in the IDE. Please keep all relevant files under the `src/uc_migration_toolkit` directory.
+
+Don't forget to test your code via:
 
 ```shell
-make lint
+hatch run test
 ```
 
-### Details of package installation
+Please note that all commits go through the CI process, which verifies linting. You can run the linters locally via:
 
-Since the package itself is managed with `poetry`, to re-use it inside the notebooks we're doing the following:
+```shell
+hatch run lint:fmt
+```
 
-1. Installing the package dependencies via poetry export
-2. Adding the package itself to the notebook via `sys.path`
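Taken together, the README's table translates into roughly the following driver-side sequence. This is an illustrative sketch only: the method names are the ones listed in the table (the CLI entrypoint in `src/uc_migration_toolkit/cli/app.py` exposes a slightly different set of names), and the config file path is an assumption:

```python
# Sketch: the full group-migration flow, step by step as in the table above.
from pathlib import Path

from uc_migration_toolkit.cli.utils import get_migration_config
from uc_migration_toolkit.toolkits.group_migration import GroupMigrationToolkit

config = get_migration_config(Path("migration_config.yml"))
toolkit = GroupMigrationToolkit(config)

toolkit.prepare_groups_in_environment()  # select groups, verify them on the account level, create backups
toolkit.cleanup_inventory_table()  # start from a clean inventory
toolkit.inventorize_permissions()  # save workspace-local permissions into a Delta table
toolkit.apply_permissions_to_backup_groups()  # entitle backup groups from the inventory
toolkit.replace_workspace_groups_with_account_groups()  # swap workspace-level groups for account-level ones
toolkit.apply_permissions_to_account_groups()  # entitle account-level groups from the inventory
toolkit.delete_backup_groups()  # drop the temporary backup groups
toolkit.cleanup_inventory_table()  # optional final cleanup
```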
diff --git a/dev/init_setup.py b/dev/init_setup.py
new file mode 100644
index 0000000000..395e3bdb32
--- /dev/null
+++ b/dev/init_setup.py
@@ -0,0 +1,45 @@
+from functools import partial
+from pathlib import Path
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.iam import ComplexValue
+from dotenv import load_dotenv
+
+from uc_migration_toolkit.config import RateLimitConfig
+from uc_migration_toolkit.providers.logger import logger
+from uc_migration_toolkit.utils import ThreadedExecution
+
+Threader = partial(ThreadedExecution, num_threads=40, rate_limit=RateLimitConfig())
+
+
+def _create_user(_ws: WorkspaceClient, uid: int):
+    user_name = f"test-user-{uid}@example.com"
+    potential_user = list(_ws.users.list(filter=f"userName eq '{user_name}'"))
+    if potential_user:
+        logger.debug(f"User {user_name} already exists, skipping its creation")
+    else:
+        _ws.users.create(
+            active=True,
+            user_name=user_name,
+            display_name=f"test-user-{uid}",
+            emails=[ComplexValue(display=None, primary=True, value=f"test-user-{uid}@example.com")],
+        )
+
+
+def _create_users(_ws: WorkspaceClient):
+    executables = [partial(_create_user, _ws, uid) for uid in range(200)]
+    Threader(executables).run()
+
+
+if __name__ == "__main__":
+    principal_env = Path(__file__).parent.parent / ".env.principal"
+    if principal_env.exists():
+        logger.info("Using credentials provided in .env.principal")
+        load_dotenv(dotenv_path=principal_env)
+
+    logger.debug("setting up the workspace client")
+    ws = WorkspaceClient()
+    user_info = ws.current_user.me()  # smoke-check that authentication works
+    logger.debug("workspace client is set up")
+
+    _create_users(ws)
diff --git a/examples/migration_config.yml b/examples/migration_config.yml
new file mode 100644
index 0000000000..d7c51cf7c3
--- /dev/null
+++ b/examples/migration_config.yml
@@ -0,0 +1,14 @@
+inventory:
+  table:
+    catalog: main
+    database: default
+    name: uc_migration_inventory
+
+
+with_table_acls: False
+
+groups:
+  selected: [ "analyst" ]
+
+num_threads: 80
+
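The example file above maps onto the `MigrationConfig` model introduced later in this diff (`src/uc_migration_toolkit/config.py`). A minimal loading sketch, assuming it is run from the repository root:

```python
# Sketch: parse examples/migration_config.yml into the typed MigrationConfig model.
from pathlib import Path

from uc_migration_toolkit.cli.utils import get_migration_config

config = get_migration_config(Path("examples/migration_config.yml"))
print(config.inventory.table.to_spark())  # "main.default.uc_migration_inventory"
print(config.groups.selected)  # ["analyst"]
print(config.num_threads)  # 80
```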
diff --git a/notebooks/GroupMigration/Workspace_Group_Migration_Notebook.py b/notebooks/GroupMigration/Workspace_Group_Migration_Notebook.py
deleted file mode 100644
index 96cc5c4f04..0000000000
--- a/notebooks/GroupMigration/Workspace_Group_Migration_Notebook.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# Databricks notebook source
-# MAGIC %md
-# MAGIC # Workspace Group Migration
-# MAGIC
-# MAGIC **Objective**<br>
-# MAGIC Customers who have groups created at the workspace level, when they integrate with Unity Catalog and want to enable identity federation for users, groups, and service principals at the account level, face problems with group federation. While users and service principals are synced up with account-level identities, groups are not. As a result, customers cannot add an account-level group to a workspace if a workspace group with the same name exists, which limits true identity federation.
-# MAGIC This notebook and the associated script are designed to help customers migrate workspace-level groups to account-level groups.
-# MAGIC
-# MAGIC **How it works**<br>
-# MAGIC The script essentially performs the following major steps:
-# MAGIC - Initiate the run by providing a list of workspace groups to be migrated for a given workspace
-# MAGIC - The script performs an inventory of all the ACL permissions for the given workspace groups
-# MAGIC - Create backup workspace groups of the same name, but with the prefix "db-temp-", and apply the same ACLs on them
-# MAGIC - Delete the original workspace groups
-# MAGIC - Add account-level groups to the workspace
-# MAGIC - Migrate the ACLs from the temp workspace groups to the new account-level groups
-# MAGIC - Delete the temp workspace groups
-# MAGIC - Save the details of the inventory in a Delta table
-# MAGIC
-# MAGIC **Scope of ACL**<br>
-# MAGIC Following objects are covered as part of the ACL migration: -# MAGIC - Clusters -# MAGIC - Cluster policies -# MAGIC - Delta Live Tables pipelines -# MAGIC - Directories -# MAGIC - Jobs -# MAGIC - MLflow experiments -# MAGIC - MLflow registered models -# MAGIC - Notebooks -# MAGIC - Files -# MAGIC - Pools -# MAGIC - Repos -# MAGIC - Databricks SQL warehouses -# MAGIC - Dashboard -# MAGIC - Query -# MAGIC - Alerts -# MAGIC - Tokens -# MAGIC - Password (for AWS) -# MAGIC - Instance Profile (for AWS) -# MAGIC - Secrets -# MAGIC - Table ACL (Non UC Cluster) - -# COMMAND ---------- - -# MAGIC %md -# MAGIC ## Pre-requisite -# MAGIC -# MAGIC Before running the script, please make sure you have the following checks -# MAGIC 1. Ensure you have equivalent account level group created for the workspace group to be migrated -# MAGIC 2. create a PAT token for the workspace which has admin access -# MAGIC 3. Ensure SCIM integration at workspace group is disabled -# MAGIC 4. Ensure no jobs or process is running the workspace using an user/service principal which is member of the workspace group -# MAGIC 5. Confirm if Table ACL is defined in the workspace and ACL defined for groups, if not Table ACL check can be skipped as it takes time to capture ACL for tables if the list is huge - -# COMMAND ---------- - -# MAGIC %md -# MAGIC ## How to Run -# MAGIC - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 1: Initialize the class -# MAGIC Import the module WSGroupMigration and initialize the class by passing following attributes: -# MAGIC - list of workspace group to be migrated (make sure these are workspace groups and not account level groups) -# MAGIC - if the workspace is AWS or Azure -# MAGIC - workspace url -# MAGIC - name of the table to persist inventory data -# MAGIC - pat token of the admin to the workspace -# MAGIC - user name of the user whose pat token is generated -# MAGIC - confirm if Table ACL are used and access permission set for workspace groups - -# COMMAND ---------- - -from uc_upgrade.group_migration import GroupMigration - -# COMMAND ---------- - -# If autoGenerateList=True then groupL will be ignored and all eliglbe groups will be migrated. -autoGenerateList = False - -# please provide groups here, e.g. analyst. -# please provide group names and not ids -groupL = ["groupA", "groupB"] - - -# Find this in the account console -inventoryTableName = "WorkspaceInventory" -# the script will create two table -# WorkspaceInventory - to store all the ACL permission -# WorkspaceInventoryTableACL - to store the table acl permission specifically - -# Pull from your browser URL bar. Should start with "https://" and end with ".com" or ".net" -workspace_url = "https://" - - -# Personal Access Token. Create one in "User Settings" -token = "" - -# Should the migration Check the ACL on tables/views as well? -checkTableACL = False - -# What cloud provider? Acceptable values are "AWS" or anything other value. -cloud = "AWS" - -# Your databricks user email. -userName = "" - -# Number of threads to issue Databricks API requests with. If you get a lot of errors during the inventory, lower this value. -numThreads = 10 - -# The notebook will populate data in the WorkspaceInventory and WorkspaceInventoryTableACL(If applicable). -# if the notebook is run second time, it will retrieve the data from the table if already captured. -# Users have the option to do a fresh inventory in which case it will recreate the tables and start again. 
-# default set to False -freshInventory = False -# Initialize GroupMigration Class with values supplied above -gm = GroupMigration( - groupL=groupL, - cloud=cloud, - inventoryTableName=inventoryTableName, - workspace_url=workspace_url, - pat=token, - spark=spark, - userName=userName, - checkTableACL=checkTableACL, - autoGenerateList=autoGenerateList, - numThreads=numThreads, - freshInventory=freshInventory, -) - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 2: Perform Dry run -# MAGIC This steps performs a dry run to verify the current ACL on the supplied workspace groups and print outs the permission. -# MAGIC Please verify if all the permissions are covered -# MAGIC If the inventory was run previously and stored in the table for either Workspace or Account then it will use the same and save time, else it will do a fresh inventory -# MAGIC If the inventory data in the table is present for only few workspace objects , the dryRun will do the fresh inventory of objects not present in the table - -# COMMAND ---------- - -gm.dryRun("Workspace") - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Adhoc Step: Selective Inventory -# MAGIC This is a adhoc step for troubleshooting purpose. Once dryRun is complete and data stored in tables, if the acl of any object is changed in the workspace -# MAGIC Ex new notebook permission added, User can force a fresh inventory of the selected object instead of doing a full cleanup and running the dryRun -# MAGIC To save time call gm.performInventory with 3 parameters: -# MAGIC - mode: Workpace("for workspace local group") or Account ("for workspace back up group") -# MAGIC - force: setting to True will force fresh inventory capture and updates to the tables -# MAGIC - objectType: select the list of object for which to do the fresh inventory, options are -# MAGIC -# MAGIC "Group"(will do members, group list, entitlement, roles), "Password","Cluster","ClusterPolicy","Warehouse","Dashboard","Query","Job","Folder"(Will do folders, notebook and files),"TableACL","Alert","Pool","Experiment","Model","DLT","Repo","Token","Secret" -# MAGIC Ex: gm.performInventory('Workspace',force=True,objectType='Cluster') will do: -# MAGIC - fresh inventory of all cluster objects and updated the data the inventory table -# MAGIC - run printInventory() to verify all the permission again (including clusters). - -# COMMAND ---------- - -gm.performInventory("Workspace", force=True, objectType="Cluster") -gm.printInventory() - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 3: Create Back up group -# MAGIC This steps creates the back up groups, applies the ACL on the new temp group from the original workspace group. 
-# MAGIC - Verify the temp groups are created in the workspace admin console -# MAGIC - check randomly if all the ACL are applied correctly -# MAGIC - there should be one temp group for every workspace group (Ex: db-temp-analysts and analysts with same ACLs) - -# COMMAND ---------- - -gm.createBackupGroup() - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 3 Verification: Verify backup groups -# MAGIC This steps runs the permission inventory, tracking the new temp groups -# MAGIC - Verify the temp group permissions are as seen in the initial dry run -# MAGIC - check randomly if all the ACL are applied correctly -# MAGIC - there should be one temp group for every workspace group (Ex: db-temp-analysts and analysts with same ACLs) -# MAGIC - Similar to dryRun("workspace"), this will also capture inventory for first run and store it in tables, subsequent times inventory will be retrived from the table to save time. -# MAGIC - if inventory table contains partial workspace objects(ex cluster acl is missing), it will do fresh inventory for the missing object and update table - -# COMMAND ---------- - -gm.dryRun("Account") - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 4: Delete original workspace group -# MAGIC This steps deletes the original workspace group. -# MAGIC - Verify original workspace groups are deleted in the workspace admin console -# MAGIC - end user permissions shouldnt be impacted as ACL permission from temp workspace group should be in effect - -# COMMAND ---------- - -gm.deleteWorkspaceLocalGroups() - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 5: Create account level groups -# MAGIC This steps adds the account level groups to the workspace and applies the same ACL from the back workspace group to the account level group. -# MAGIC - Ensure account level groups are created upfront before -# MAGIC - verify account level groups are added to the workspace now -# MAGIC - check randomly if all the ACL are applied correctly to the account level groups -# MAGIC - there should be one temp group and account level group present (Ex: db-temp-analysts and analysts (account level group) with same ACLs) - -# COMMAND ---------- - -gm.createAccountGroup() - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Step 6: Delete temp workspace group -# MAGIC This steps deletes the temp workspace group. 
-# MAGIC - Verify temp workspace groups are deleted in the workspace admin console -# MAGIC - end user permissions shouldnt be impacted as ACL permission from account level group should be in effect - -# COMMAND ---------- - -gm.deleteTempGroups() - -# COMMAND ---------- - -# MAGIC %md -# MAGIC #### Complete -# MAGIC - Repeat the steps for other workspace group in the same workspace -# MAGIC - Repeat the steps for other workspace that require migration diff --git a/notebooks/common.py b/notebooks/common.py deleted file mode 100644 index 42a7d179f4..0000000000 --- a/notebooks/common.py +++ /dev/null @@ -1,46 +0,0 @@ -import sys -from pathlib import Path -from tempfile import NamedTemporaryFile - -# from databricks.sdk.runtime import * # noqa: F403 - - -def install_uc_upgrade_package(): - ipython = get_ipython() # noqa: F405, F821 - - print("Installing poetry for package management") - ipython.run_line_magic("pip", "install poetry -I") - print("Poetry successfully installed") - print("Installing the uc-upgrade package and it's dependencies") - - with NamedTemporaryFile(suffix="-uc-upgrade-requirements.txt") as requirements_file: - print(f"Writing requirements to file {requirements_file.name}") - ipython.run_cell_magic("sh", "", f"poetry export --output={requirements_file.name} --without-hashes") - print("Saved the requirements to a provided file, installing them with pip") - ipython.run_line_magic("pip", f"install -r {requirements_file.name} -I") - print("Requirements installed successfully, restarting Python interpreter") - dbutils.library.restartPython() # noqa: F405, F821 - print("Python interpreter restarted successfully") - - print("Reloading the path-based modules") - ipython.run_line_magic("load_ext", "autoreload") - ipython.run_line_magic("autoreload", 2) - print("Path-based modules successfully reloaded") - - project_root = Path(".").absolute().parent - print(f"appending the uc-upgrade library from {project_root}") - sys.path.append(project_root) - - print("Verifying that package can be properly loaded") - try: - from uc_upgrade.group_migration import GroupMigration # noqa: F401 - - print("Successfully loaded the uc-upgrade package") - except Exception as e: - print( - "Unable to import the UC migration utilities package from source. " - "Please check that you've imported the whole repository and not just copied one file." - ) - print("Also check that you have the Files in Repos activated, e.g. use DBR 11.X+") - print("Original exception:") - print(e) diff --git a/poetry.lock b/poetry.lock deleted file mode 100644 index 233cc24bf6..0000000000 --- a/poetry.lock +++ /dev/null @@ -1,431 +0,0 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. - -[[package]] -name = "black" -version = "23.7.0" -description = "The uncompromising code formatter." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, - {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, - {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, - {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, - {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, - {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, - {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, - {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, - {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, - {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, - {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} - 
-[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - -[[package]] -name = "certifi" -version = "2023.5.7" -description = "Python package for providing Mozilla's CA Bundle." -optional = false -python-versions = ">=3.6" -files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, -] - -[[package]] -name = "charset-normalizer" -version = "3.1.0" -description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-win32.whl", hash = "sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = 
"sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-win32.whl", hash = "sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d"}, - {file = 
"charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-win32.whl", hash = "sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-win32.whl", hash = "sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59"}, - {file = 
"charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-win32.whl", hash = "sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b"}, - {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, -] - -[[package]] -name = "click" -version = "8.1.3" -description = "Composable command line interface toolkit" -optional = false -python-versions = ">=3.7" -files = [ - {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, - {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[[package]] -name = "colorama" -version = "0.4.6" -description = "Cross-platform colored terminal text." 
-optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -files = [ - {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, - {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, -] - -[[package]] -name = "databricks-sdk" -version = "0.2.1" -description = "Databricks SDK for Python (Beta)" -optional = false -python-versions = ">=3.7" -files = [ - {file = "databricks-sdk-0.2.1.tar.gz", hash = "sha256:0db4919c0f5b54a831a4756dcc464cd62813560026fdf8e3f8ef01c9385f4061"}, - {file = "databricks_sdk-0.2.1-py3-none-any.whl", hash = "sha256:b466539da8dac45f3947ff209ff915a338e1b0a2b1ec8827aed79de3c6ecca5a"}, -] - -[package.dependencies] -requests = ">=2.28.1,<3" - -[package.extras] -dev = ["autoflake", "ipython", "ipywidgets", "isort", "pycodestyle", "pytest", "pytest-cov", "pytest-mock", "pytest-xdist", "wheel", "yapf"] -notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"] - -[[package]] -name = "delta-spark" -version = "2.4.0" -description = "Python APIs for using Delta Lake with Apache Spark" -optional = false -python-versions = ">=3.6" -files = [ - {file = "delta-spark-2.4.0.tar.gz", hash = "sha256:ef776e325e80d98e3920cab982c747b094acc46599d62dfcdc9035fb112ba6a9"}, - {file = "delta_spark-2.4.0-py3-none-any.whl", hash = "sha256:7204142a97ef16367403b020d810d0c37f4ae8275b4997de4056423cf69b3a4b"}, -] - -[package.dependencies] -importlib-metadata = ">=1.0.0" -pyspark = ">=3.4.0,<3.5.0" - -[[package]] -name = "idna" -version = "3.4" -description = "Internationalized Domain Names in Applications (IDNA)" -optional = false -python-versions = ">=3.5" -files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, -] - -[[package]] -name = "importlib-metadata" -version = "6.6.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "importlib_metadata-6.6.0-py3-none-any.whl", hash = "sha256:43dd286a2cd8995d5eaef7fee2066340423b818ed3fd70adf0bad5f1fac53fed"}, - {file = "importlib_metadata-6.6.0.tar.gz", hash = "sha256:92501cdf9cc66ebd3e612f1b4f0c0765dfa42f0fa38ffb319b6bd84dd675d705"}, -] - -[package.dependencies] -zipp = ">=0.5" - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] - -[[package]] -name = "isort" -version = "5.12.0" -description = "A Python utility / library to sort Python imports." 
-optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, - {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, -] - -[package.extras] -colors = ["colorama (>=0.4.3)"] -pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] -plugins = ["setuptools"] -requirements-deprecated-finder = ["pip-api", "pipreqs"] - -[[package]] -name = "mypy-extensions" -version = "1.0.0" -description = "Type system extensions for programs checked with the mypy type checker." -optional = false -python-versions = ">=3.5" -files = [ - {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, - {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, -] - -[[package]] -name = "packaging" -version = "23.1" -description = "Core utilities for Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, -] - -[[package]] -name = "pathspec" -version = "0.11.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.7" -files = [ - {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, - {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, -] - -[[package]] -name = "platformdirs" -version = "3.5.1" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
-optional = false -python-versions = ">=3.7" -files = [ - {file = "platformdirs-3.5.1-py3-none-any.whl", hash = "sha256:e2378146f1964972c03c085bb5662ae80b2b8c06226c54b2ff4aa9483e8a13a5"}, - {file = "platformdirs-3.5.1.tar.gz", hash = "sha256:412dae91f52a6f84830f39a8078cecd0e866cb72294a5c66808e74d5e88d251f"}, -] - -[package.extras] -docs = ["furo (>=2023.3.27)", "proselint (>=0.13)", "sphinx (>=6.2.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] - -[[package]] -name = "py4j" -version = "0.10.9.7" -description = "Enables Python programs to dynamically access arbitrary Java objects" -optional = false -python-versions = "*" -files = [ - {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"}, - {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"}, -] - -[[package]] -name = "pyspark" -version = "3.4.1" -description = "Apache Spark Python API" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pyspark-3.4.1.tar.gz", hash = "sha256:72cd66ab8cf61a75854e5a753f75bea35ee075c3a96f9de4e2a66d02ec7fc652"}, -] - -[package.dependencies] -py4j = "0.10.9.7" - -[package.extras] -connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] -ml = ["numpy (>=1.15)"] -mllib = ["numpy (>=1.15)"] -pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] -sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] - -[[package]] -name = "requests" -version = "2.28.2" -description = "Python HTTP for Humans." -optional = false -python-versions = ">=3.7, <4" -files = [ - {file = "requests-2.28.2-py3-none-any.whl", hash = "sha256:64299f4909223da747622c030b781c0d7811e359c37124b4bd368fb8c6518baa"}, - {file = "requests-2.28.2.tar.gz", hash = "sha256:98b1b2782e3c6c4904938b84c0eb932721069dfdb9134313beff7c83c2df24bf"}, -] - -[package.dependencies] -certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" -idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<1.27" - -[package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] - -[[package]] -name = "ruff" -version = "0.0.278" -description = "An extremely fast Python linter, written in Rust." 
-optional = false -python-versions = ">=3.7" -files = [ - {file = "ruff-0.0.278-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:1a90ebd8f2a554db1ee8d12b2f3aa575acbd310a02cd1a9295b3511a4874cf98"}, - {file = "ruff-0.0.278-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:38ca1c0c8c1221fe64c0a66784c91501d09a8ed02a4dbfdc117c0ce32a81eefc"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c62a0bde4d20d087cabce2fa8b012d74c2e985da86d00fb3359880469b90e31"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7545bb037823cd63dca19280f75a523a68bd3e78e003de74609320d6822b5a52"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb380d2d6fdb60656a0b5fa78305535db513fc72ce11f4532cc1641204ef380"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d11149c7b186f224f2055e437a030cd83b164a43cc0211314c33ad1553ed9c4c"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:666e739fb2685277b879d493848afe6933e3be30d40f41fe0e571ad479d57d77"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ec8b0469b54315803aaf1fbf9a37162a3849424cab6182496f972ad56e0ea702"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c25b96602695a147d62a572865b753ef56aff1524abab13b9436724df30f9bd7"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a48621f5f372d5019662db5b3dbfc5f1450f927683d75f1153fe0ebf20eb9698"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1078125123a3c68e92463afacedb7e41b15ccafc09e510c6c755a23087afc8de"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3ce0d620e257b4cad16e2f0c103b2f43a07981668a3763380542e8a131d11537"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1cae4c07d334eb588f171f1363fa89a8911047eb93184276be11a24dbbc996c7"}, - {file = "ruff-0.0.278-py3-none-win32.whl", hash = "sha256:70d39f5599d8449082ab8ce542fa98e16413145eb411dd1dc16575b44565d52d"}, - {file = "ruff-0.0.278-py3-none-win_amd64.whl", hash = "sha256:e131595ab7f4ce61a1650463bd2fe304b49e7d0deb0dfa664b92817c97cdba5f"}, - {file = "ruff-0.0.278-py3-none-win_arm64.whl", hash = "sha256:737a0cfb6c36aaa92d97a46957dfd5e55329299074ad06ed12663b98e0c6fc82"}, - {file = "ruff-0.0.278.tar.gz", hash = "sha256:1a9f1d925204cfba81b18368b7ac943befcfccc3a41e170c91353b674c6b7a66"}, -] - -[[package]] -name = "tomli" -version = "2.0.1" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, -] - -[[package]] -name = "typing-extensions" -version = "4.6.3" -description = "Backported and Experimental Type Hints for Python 3.7+" -optional = false -python-versions = ">=3.7" -files = [ - {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, - {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, -] - -[[package]] -name = "urllib3" -version = "1.26.16" -description = "HTTP library with thread-safe 
connection pooling, file post, and more." -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -files = [ - {file = "urllib3-1.26.16-py2.py3-none-any.whl", hash = "sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f"}, - {file = "urllib3-1.26.16.tar.gz", hash = "sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14"}, -] - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] - -[[package]] -name = "zipp" -version = "3.15.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.7" -files = [ - {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"}, - {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] - -[metadata] -lock-version = "2.0" -python-versions = "^3.9" -content-hash = "29369e89b76066a79173f9b971afb8a29faf5fdeac8d0a4b8019b08f99d8d6cd" diff --git a/pyproject.toml b/pyproject.toml index 74eff1a8b8..fbb6aca218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,34 +1,194 @@ -[tool.poetry] -name = "uc-upgrade" -version = "0.1.0" -description = "" -authors = ["Your Name "] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "uc-migration-toolkit" +dynamic = ["version"] +description = '' readme = "README.md" -packages = [{ include = "uc_upgrade" }] +requires-python = ">=3.10.6" # latest available in DBR 13.2 +keywords = [] +authors = [ + { name = "renardeinside", email = "polarpersonal@gmail.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "databricks-sdk>=0.2.1, <=0.3.0", + "typer[all]>=0.9.0,<0.10.0", + "pyhocon>=0.3.60,<0.4.0", + "pydantic>=2.0.3, <3.0.0", + "loguru>=0.7.0, <1.0.0", + "PyYAML>=6.0.0,<7.0.0", + "ratelimit>=2.2.1,<3.0.0", + "pandas>=2.0.3,<3.0.0", + "python-dotenv>=1.0.0,<=2.0.0" +] + +[project.optional-dependencies] +dbconnect = [ + "databricks-connect>=13.2.0,<=14.0.0" +] +test = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.0.0,<4.0.0", +] + +[project.scripts] +ucx = "uc_migration_toolkit.__main__:entrypoint" + +[project.urls] +Issues = "https://github.com/databricks/UC-Upgrade/issues" +Source = "https://github.com/databricks/UC-Upgrade" + +[tool.hatch.version] +path = "src/uc_migration_toolkit/__about__.py" -[tool.poetry.dependencies] -python = "^3.9" -databricks-sdk = "^0.2.1" +[tool.hatch.envs.default] +dependencies = [ + "uc-migration-toolkit[test]", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" 
+test-cov = "pytest --cov uc_migration_toolkit {args:tests}"
+test-cov-report = "pytest {args:tests} --cov uc_migration_toolkit --cov-report=html"
+
+[tool.hatch.envs.unit]
+dependencies = [
+    "uc-migration-toolkit[test]",
+    "pyspark>=3.4.0,<=3.5.0",
+    "delta-spark>=2.4.0,<3.0.0"
+]
+[tool.hatch.envs.unit.scripts]
+test = "pytest tests/unit"
+test-cov = "pytest tests/unit --cov uc_migration_toolkit"
+test-cov-report = "pytest tests/unit --cov uc_migration_toolkit --cov-report=html"
+
+
+[[tool.hatch.envs.all.matrix]]
+python = ["3.10", "3.11"]
+
+[tool.hatch.envs.lint]
+detached = true
+dependencies = [
+    "black>=23.1.0",
+    "ruff>=0.0.243",
+    "isort>=2.5.0"
+]
+[tool.hatch.envs.lint.scripts]
+fmt = [
+    "black .",
+    "ruff --fix .",
+    "isort ."
+]
+verify = [
+    "black --check .",
+    "ruff .",
+    "isort . --check-only"
+]
+
+[tool.isort]
+skip_glob = [
+    "notebooks/*.py"
+]
+profile = "black"
+
+[tool.pytest.ini_options]
+addopts = "-s -p no:warnings"
+log_cli = true
+filterwarnings = [
+    "ignore:::.*pyspark.broadcast*",
+    "ignore:::.*pyspark.sql.pandas.utils*"
+]
 
-[tool.poetry.group.dev.dependencies]
-black = "^23.7.0"
-isort = "^5.12.0"
-ruff = "^0.0.278"
-pyspark = "^3.4.1"
-delta-spark = "^2.4.0"
 
 [tool.black]
+target-version = ["py310"]
 line-length = 120
-#exclude = "notebooks"
+skip-string-normalization = true
 
 [tool.ruff]
+target-version = "py310"
 line-length = 120
-extend-exclude = ["notebooks/*/*.py"]
+select = [
+    "A",
+    "ARG",
+    "B",
+    "C",
+    "E",
+    "EM",
+    "F",
+    "FBT",
+    "I",
+    "ICN",
+    "ISC",
+    "N",
+    "PLC",
+    "PLE",
+    "PLR",
+    "PLW",
+    "Q",
+    "RUF",
+    "S",
+    "T",
+    "TID",
+    "UP",
+    "W",
+    "YTT",
+]
+ignore = [
+    # Allow non-abstract empty methods in abstract base classes
+    "B027",
+    # Allow boolean positional values in function calls, like `dict.get(... 
+    "FBT003",
+    # Ignore checks for possible passwords
+    "S105", "S106", "S107",
+    # Allow print statements
+    "T201",
+    # Allow asserts
+    "S101",
+    # Allow standard random generators
+    "S311",
+    # Ignore complexity
+    "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
+]

-[tool.isort]
-profile = "black"
-#skip_glob = ["notebooks/*"]
+extend-exclude = [
+    "notebooks/*.py"
+]

-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+[tool.ruff.isort]
+known-first-party = ["uc_migration_toolkit"]
+
+[tool.ruff.flake8-tidy-imports]
+ban-relative-imports = "all"
+
+[tool.ruff.per-file-ignores]
+# Tests can use magic values, assertions, and relative imports
+"tests/**/*" = ["PLR2004", "S101", "TID252"]
+
+[tool.coverage.run]
+source = ["uc_migration_toolkit"]
+branch = true
+parallel = true
+omit = [
+    "src/uc_migration_toolkit/__about__.py",
+]
+
+
+[tool.coverage.report]
+exclude_lines = [
+    "no cov",
+    "if __name__ == .__main__.:",
+    "if TYPE_CHECKING:",
+]
diff --git a/src/uc_migration_toolkit/__about__.py b/src/uc_migration_toolkit/__about__.py
new file mode 100644
index 0000000000..5d13a610b2
--- /dev/null
+++ b/src/uc_migration_toolkit/__about__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: 2023-present renardeinside
+#
+# SPDX-License-Identifier: MIT
+__version__ = "0.0.1"
diff --git a/src/uc_migration_toolkit/__init__.py b/src/uc_migration_toolkit/__init__.py
new file mode 100644
index 0000000000..9393bfa907
--- /dev/null
+++ b/src/uc_migration_toolkit/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2023-present renardeinside
+#
+# SPDX-License-Identifier: MIT
diff --git a/src/uc_migration_toolkit/__main__.py b/src/uc_migration_toolkit/__main__.py
new file mode 100644
index 0000000000..a7e5c78dba
--- /dev/null
+++ b/src/uc_migration_toolkit/__main__.py
@@ -0,0 +1,9 @@
+from uc_migration_toolkit.cli.app import app
+
+
+def entrypoint():
+    app()
+
+
+if __name__ == "__main__":
+    entrypoint()
diff --git a/src/uc_migration_toolkit/cli/app.py b/src/uc_migration_toolkit/cli/app.py
new file mode 100644
index 0000000000..2c279dc226
--- /dev/null
+++ b/src/uc_migration_toolkit/cli/app.py
@@ -0,0 +1,28 @@
+from pathlib import Path
+from typing import Annotated
+
+import typer
+from typer import Typer
+
+app = Typer(name="UC Migration Toolkit", pretty_exceptions_show_locals=True)
+
+
+@app.command()
+def migrate_groups(
+    config_file: Annotated[Path, typer.Argument(help="Path to config file")] = Path("migration_config.yml")
+):
+    from uc_migration_toolkit.cli.utils import get_migration_config
+    from uc_migration_toolkit.toolkits.group_migration import GroupMigrationToolkit
+
+    config = get_migration_config(config_file)
+    toolkit = GroupMigrationToolkit(config)
+
+    toolkit.prepare_groups_in_environment()
+    toolkit.cleanup_inventory_table()
+    toolkit.inventorize_permissions()
+    toolkit.apply_permissions_to_backup_groups()
+    toolkit.replace_workspace_groups_with_account_groups()
+    # the remaining steps can be enabled once the steps above have been verified:
+    # toolkit.apply_permissions_to_account_groups()
+    # toolkit.delete_backup_groups()
+    # toolkit.cleanup_inventory_table()
diff --git a/src/uc_migration_toolkit/cli/utils.py b/src/uc_migration_toolkit/cli/utils.py
new file mode 100644
index 0000000000..63499312fb
--- /dev/null
+++ b/src/uc_migration_toolkit/cli/utils.py
@@ -0,0 +1,12 @@
+from pathlib import Path
+
+from yaml import safe_load
+
+from uc_migration_toolkit.config import MigrationConfig
+
+
+def get_migration_config(config_file: Path) -> MigrationConfig:
+    # an empty or missing YAML document deserializes to None, hence the fallback below
+    _raw_config = safe_load(config_file.read_text())
+    _raw_config = {} if not _raw_config else _raw_config
+    return MigrationConfig(**_raw_config)
diff --git a/src/uc_migration_toolkit/config.py b/src/uc_migration_toolkit/config.py
new file mode 100644
index 0000000000..277bcd84c9
--- /dev/null
+++ b/src/uc_migration_toolkit/config.py
@@ -0,0 +1,75 @@
+from pydantic import Field, RootModel
+from pydantic.dataclasses import dataclass
+
+
+@dataclass
+class InventoryTable:
+    catalog: str
+    database: str
+    name: str
+
+    def __repr__(self):
+        return f"{self.catalog}.{self.database}.{self.name}"
+
+    def to_spark(self):
+        return self.__repr__()
+
+
+@dataclass
+class GroupsConfig:
+    selected: list[str] | None = None
+    auto: bool | None = None
+    backup_group_prefix: str | None = "db-temp-"
+
+    def __post_init__(self):
+        if not self.selected and self.auto is None:
+            msg = "Either selected or auto must be set"
+            raise ValueError(msg)
+        if not self.selected and self.auto is False:
+            msg = "No selected groups provided, but auto-collection is disabled"
+            raise ValueError(msg)
+
+
+@dataclass
+class WorkspaceAuthConfig:
+    token: str | None = None
+    host: str | None = None
+    client_id: str | None = None
+    client_secret: str | None = None
+
+
+@dataclass
+class AuthConfig:
+    workspace: WorkspaceAuthConfig | None = None
+
+    class Config:
+        frozen = True
+
+
+@dataclass
+class InventoryConfig:
+    table: InventoryTable
+
+
+@dataclass
+class RateLimitConfig:
+    max_requests_per_period: int | None = 100
+    period_in_seconds: int | None = 1
+
+
+@dataclass
+class MigrationConfig:
+    inventory: InventoryConfig
+    with_table_acls: bool
+    groups: GroupsConfig
+    auth: AuthConfig | None = None
+    rate_limit: RateLimitConfig | None = Field(default_factory=lambda: RateLimitConfig())
+    num_threads: int | None = 4
+
+    def __post_init__(self):
+        if self.with_table_acls:
+            msg = "Table ACLs are not yet implemented"
+            raise NotImplementedError(msg)
+
+    def to_json(self) -> str:
+        return RootModel[MigrationConfig](self).model_dump_json(indent=4)
diff --git a/uc_upgrade/__init__.py b/src/uc_migration_toolkit/managers/__init__.py
similarity index 100%
rename from uc_upgrade/__init__.py
rename to src/uc_migration_toolkit/managers/__init__.py
diff --git a/src/uc_migration_toolkit/managers/group.py b/src/uc_migration_toolkit/managers/group.py
new file mode 100644
index 0000000000..bc688ced4d
--- /dev/null
+++ b/src/uc_migration_toolkit/managers/group.py
@@ -0,0 +1,224 @@
+import json
+import typing
+from dataclasses import dataclass
+from functools import partial
+
+from databricks.sdk.service.iam import Group
+
+from uc_migration_toolkit.providers.client import provider
+from uc_migration_toolkit.providers.config import provider as config_provider
+from uc_migration_toolkit.providers.logger import logger
+from uc_migration_toolkit.utils import StrEnum, ThreadedExecution
+
+
+@dataclass
+class MigrationGroupInfo:
+    workspace: Group
+    backup: Group
+    account: Group
+
+
+class GroupLevel(StrEnum):
+    WORKSPACE = "workspace"
+    ACCOUNT = "account"
+
+
+class MigrationGroupsProvider:
+    def __init__(self):
+        self.groups: list[MigrationGroupInfo] = []
+
+    def add(self, group: MigrationGroupInfo):
+        self.groups.append(group)
+
+    def get_by_workspace_group_name(self, workspace_group_name: str) -> MigrationGroupInfo | None:
+        found = [g for g in self.groups if g.workspace.display_name == workspace_group_name]
+        if len(found) == 0:
+            return None
+        else:
+            return found[0]
+
+
+class GroupManager:
+    SYSTEM_GROUPS: typing.ClassVar[list[str]] = ["users", "admins", "account users"]
+
+    def __init__(self):
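+        # note: both providers are module-level singletons, so a GroupManager has
+        # to be created after config_provider.set_config(...) has been called
+        # (GroupMigrationToolkit enforces this ordering in its constructor)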
+ self.config = config_provider.config.groups + self._migration_groups_provider: MigrationGroupsProvider = MigrationGroupsProvider() + + # please keep the internal methods below this line + + @staticmethod + def _find_eligible_groups() -> list[str]: + logger.info("Finding eligible groups automatically") + _display_name_filter = " and ".join([f'displayName ne "{group}"' for group in GroupManager.SYSTEM_GROUPS]) + ws_groups = list(provider.ws.groups.list(attributes="displayName,meta", filter=_display_name_filter)) + eligible_groups = [g for g in ws_groups if g.meta.resource_type == "WorkspaceGroup"] + logger.info(f"Found {len(eligible_groups)} eligible groups") + return [g.display_name for g in eligible_groups] + + @staticmethod + def _get_clean_group_info(group: Group, cleanup_keys: list[str] | None = None) -> dict: + """ + Returns a dictionary with group information, excluding some keys + :param group: Group object from SDK + :param cleanup_keys: default (with None) ["id", "externalId", "displayName"] + :return: dictionary with group information + """ + + cleanup_keys = cleanup_keys or ["id", "externalId", "displayName"] + group_info = group.as_dict() + + for key in cleanup_keys: + if key in group_info: + group_info.pop(key) + + return group_info + + @staticmethod + def _get_group(group_name, level: GroupLevel) -> Group | None: + method = provider.ws.groups.list if level == GroupLevel.WORKSPACE else provider.ws.list_account_level_groups + query_filter = f"displayName eq '{group_name}'" + attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles"]) + + group = next( + iter(method(filter=query_filter, attributes=attributes)), + None, + ) + + return group + + def _get_or_create_backup_group(self, source_group_name: str, source_group: Group) -> Group: + backup_group_name = f"{self.config.backup_group_prefix}{source_group_name}" + backup_group = self._get_group(backup_group_name, GroupLevel.WORKSPACE) + + if backup_group: + logger.info(f"Backup group {backup_group_name} already exists, updating it") + else: + logger.info(f"Creating backup group {backup_group_name}") + new_group_payload = self._get_clean_group_info(source_group) + new_group_payload["displayName"] = backup_group_name + backup_group = provider.ws.groups.create(request=Group.from_dict(new_group_payload)) + logger.info(f"Backup group {backup_group_name} successfully created") + + self._apply_roles_and_entitlements(source_group, backup_group) + + return backup_group + + def _set_migration_groups(self, groups_names: list[str]): + def get_group_info(name: str): + ws_group = self._get_group(name, GroupLevel.WORKSPACE) + assert ws_group, f"Group {name} not found on the workspace level" + acc_group = self._get_group(name, GroupLevel.ACCOUNT) + assert acc_group, f"Group {name} not found on the account level" + backup_group = self._get_or_create_backup_group(source_group_name=name, source_group=ws_group) + return MigrationGroupInfo(workspace=ws_group, backup=backup_group, account=acc_group) + + executables = [partial(get_group_info, group_name) for group_name in groups_names] + + collected_groups = ThreadedExecution[MigrationGroupInfo](executables).run() + + self._migration_groups_provider.groups = collected_groups + + logger.info(f"Prepared {len(self._migration_groups_provider.groups)} groups for migration") + + def _replace_group(self, migration_info: MigrationGroupInfo): + ws_group = migration_info.workspace + acc_group = migration_info.account + backup_group = migration_info.backup + + if 
self._get_group(ws_group.display_name, GroupLevel.WORKSPACE):
+            logger.info(f"Deleting the workspace-level group {ws_group.display_name} with id {ws_group.id}")
+            provider.ws.groups.delete(ws_group.id)
+            logger.info(f"Workspace-level group {ws_group.display_name} with id {ws_group.id} was deleted")
+        else:
+            logger.warning(f"Workspace-level group {ws_group.display_name} does not exist, skipping")
+
+        provider.ws.reflect_account_group_to_workspace(acc_group)
+
+        logger.info("Updating group-level entitlements for account-level group from backup group")
+
+        # tbd: raise this as an issue for SDK team
+        self._apply_roles_and_entitlements(backup_group, acc_group)
+        logger.info("Updated group-level entitlements and roles for account-level group from backup group")
+
+    @staticmethod
+    def _apply_roles_and_entitlements(source: Group, destination: Group):
+        op_schema = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
+        entitlements = (
+            {
+                "op": "add",
+                "path": "entitlements",
+                "value": [{"value": e.value} for e in source.entitlements],
+            }
+            if source.entitlements
+            else None
+        )
+
+        roles = (
+            {
+                "op": "add",
+                "path": "roles",
+                "value": [{"value": r.value} for r in source.roles],
+            }
+            if source.roles
+            else None
+        )
+
+        # drop empty operations - every SCIM PatchOp operation requires an "op" key
+        operations = [op for op in (entitlements, roles) if op]
+        request = {
+            "schemas": [op_schema],
+            "Operations": operations,
+        }
+        provider.ws.api_client.do(
+            "PATCH", f"/api/2.0/preview/scim/v2/Groups/{destination.id}", data=json.dumps(request)
+        )
+
+    # please keep the public methods below this line
+
+    def prepare_groups_in_environment(self):
+        logger.info("Preparing groups in the current environment")
+        logger.info("At this step we'll verify that all groups exist and are of the correct type")
+        logger.info("If some temporary groups are missing, they'll be created")
+        if self.config.selected:
+            logger.info("Using the provided group listing")
+
+            for g in self.config.selected:
+                assert g not in self.SYSTEM_GROUPS, f"Cannot migrate system group {g}"
+
+            self._set_migration_groups(self.config.selected)
+        else:
+            logger.info("No group listing provided, finding eligible groups automatically")
+            self._set_migration_groups(groups_names=self._find_eligible_groups())
+        logger.info("Environment prepared successfully")
+
+    @property
+    def migration_groups_provider(self) -> MigrationGroupsProvider:
+        assert len(self._migration_groups_provider.groups) > 0, "Migration groups were not loaded or initialized"
+        return self._migration_groups_provider
+
+    def replace_workspace_groups_with_account_groups(self):
+        logger.info("Replacing the workspace groups with account-level groups")
+        logger.info(f"In total, {len(self.migration_groups_provider.groups)} group(s) to be replaced")
+
+        executables = [
+            partial(self._replace_group, migration_info) for migration_info in self.migration_groups_provider.groups
+        ]
+        ThreadedExecution(executables).run()
+        logger.info("Workspace groups were successfully replaced with account-level groups")
+
+    def delete_backup_groups(self):
+        logger.info("Deleting the workspace-level backup groups")
+        logger.info(f"In total, {len(self.migration_groups_provider.groups)} group(s) to be deleted")
+
+        for migration_info in self.migration_groups_provider.groups:
+            try:
+                provider.ws.groups.delete(id=migration_info.backup.id)
+            except Exception as e:
+                logger.warning(
+                    f"Failed to delete backup group {migration_info.backup.display_name} "
+                    f"with id {migration_info.backup.id}"
+                )
+                logger.warning(f"Original exception {e}")
+
+        logger.info("Backup groups were 
successfully deleted") diff --git a/src/uc_migration_toolkit/managers/inventory/__init__.py b/src/uc_migration_toolkit/managers/inventory/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/uc_migration_toolkit/managers/inventory/__init__.py @@ -0,0 +1 @@ + diff --git a/src/uc_migration_toolkit/managers/inventory/inventorizer.py b/src/uc_migration_toolkit/managers/inventory/inventorizer.py new file mode 100644 index 0000000000..4c79caf654 --- /dev/null +++ b/src/uc_migration_toolkit/managers/inventory/inventorizer.py @@ -0,0 +1,69 @@ +from collections.abc import Callable, Iterator +from functools import partial +from typing import Generic, TypeVar + +from databricks.sdk.service.iam import ObjectPermissions + +from uc_migration_toolkit.managers.inventory.types import ( + LogicalObjectType, + PermissionsInventoryItem, + RequestObjectType, +) +from uc_migration_toolkit.providers.client import provider +from uc_migration_toolkit.providers.config import provider as config_provider +from uc_migration_toolkit.providers.logger import logger +from uc_migration_toolkit.utils import ThreadedExecution + +InventoryObject = TypeVar("InventoryObject") + + +class StandardInventorizer(Generic[InventoryObject]): + """ + Standard means that it can collect using the default listing/permissions function without any additional logic. + """ + + def __init__( + self, + logical_object_type: LogicalObjectType, + request_object_type: RequestObjectType, + listing_function: Callable[..., Iterator[InventoryObject]], + id_attribute: str, + permissions_function: Callable[..., ObjectPermissions] | None = None, + ): + self._config = config_provider.config.rate_limit + self._logical_object_type = logical_object_type + self._request_object_type = request_object_type + self._listing_function = listing_function + self._id_attribute = id_attribute + self._permissions_function = permissions_function if permissions_function else provider.ws.permissions.get + self._objects: list[InventoryObject] = [] + + @property + def logical_object_type(self) -> LogicalObjectType: + return self._logical_object_type + + def preload(self): + logger.info(f"Listing objects with type {self._request_object_type}...") + self._objects = list(self._listing_function()) + logger.info(f"Object metadata prepared for {len(self._objects)} objects.") + + def _process_single_object(self, _object: InventoryObject) -> PermissionsInventoryItem: + permissions = self._permissions_function( + self._request_object_type, _object.__getattribute__(self._id_attribute) + ) + inventory_item = PermissionsInventoryItem( + object_id=_object.__getattribute__(self._id_attribute), + logical_object_type=self._logical_object_type, + request_object_type=self._request_object_type, + object_permissions=permissions.as_dict(), + ) + return inventory_item + + def inventorize(self): + logger.info(f"Fetching permissions for {len(self._objects)} objects...") + + executables = [partial(self._process_single_object, _object) for _object in self._objects] + threaded_execution = ThreadedExecution[PermissionsInventoryItem](executables) + collected = threaded_execution.run() + logger.info(f"Permissions fetched for {len(collected)} objects of type {self._request_object_type}") + return collected diff --git a/src/uc_migration_toolkit/managers/inventory/permissions.py b/src/uc_migration_toolkit/managers/inventory/permissions.py new file mode 100644 index 0000000000..2a539c4f81 --- /dev/null +++ b/src/uc_migration_toolkit/managers/inventory/permissions.py @@ -0,0 
+1,128 @@ +from dataclasses import dataclass +from functools import partial +from typing import Literal + +from databricks.sdk.service.iam import AccessControlRequest, Group + +from uc_migration_toolkit.managers.group import MigrationGroupsProvider +from uc_migration_toolkit.managers.inventory.inventorizer import StandardInventorizer +from uc_migration_toolkit.managers.inventory.table import InventoryTableManager +from uc_migration_toolkit.managers.inventory.types import ( + LogicalObjectType, + PermissionsInventoryItem, + RequestObjectType, +) +from uc_migration_toolkit.providers.client import provider +from uc_migration_toolkit.providers.config import provider as config_provider +from uc_migration_toolkit.providers.logger import logger +from uc_migration_toolkit.utils import ThreadedExecution + + +@dataclass +class PermissionRequestPayload: + request_object_type: RequestObjectType + object_id: str + access_control_list: list[AccessControlRequest] + + def as_dict(self): + return { + "request_object_type": self.request_object_type, + "object_id": self.object_id, + "access_control_list": self.access_control_list, + } + + +class PermissionManager: + def __init__(self, inventory_table_manager: InventoryTableManager): + self.config = config_provider.config + self.inventory_table_manager = inventory_table_manager + + @staticmethod + def get_inventorizers(): + return [ + StandardInventorizer( + logical_object_type=LogicalObjectType.CLUSTER, + request_object_type=RequestObjectType.CLUSTERS, + listing_function=provider.ws.clusters.list, + id_attribute="cluster_id", + ) + ] + + def inventorize_permissions(self): + logger.info("Inventorizing the permissions") + + for inventorizer in self.get_inventorizers(): + inventorizer.preload() + collected = inventorizer.inventorize() + if collected: + self.inventory_table_manager.save(collected) + else: + logger.warning(f"No objects of type {inventorizer.logical_object_type} were found") + + logger.info("Permissions were inventorized and saved") + + @staticmethod + def _prepare_new_permission_request( + item: PermissionsInventoryItem, + migration_groups_provider: MigrationGroupsProvider, + destination: Literal["backup", "account"], + ) -> PermissionRequestPayload: + new_acls: list[AccessControlRequest] = [] + + logger.info("Attempting to build the new ACLs, verifying if there are any relevant groups") + for acl in item.typed_object_permissions.access_control_list: + if acl.group_name in [g.workspace.display_name for g in migration_groups_provider.groups]: + migration_info = migration_groups_provider.get_by_workspace_group_name(acl.group_name) + assert migration_info is not None, f"Group {acl.group_name} is not in the migration groups provider" + destination_group: Group = getattr(migration_info, destination) + for permission in acl.all_permissions: + if permission.inherited: + continue + new_acls.append( + AccessControlRequest( + group_name=destination_group.display_name, + permission_level=permission.permission_level, + ) + ) + else: + continue + + if new_acls: + return PermissionRequestPayload( + request_object_type=item.request_object_type, + object_id=item.object_id, + access_control_list=new_acls, + ) + + @staticmethod + def _apply_permissions_in_parallel(requests: list[PermissionRequestPayload]): + executables = [ + partial( + provider.ws.permissions.update, + request_object_type=payload.request_object_type, + request_object_id=payload.object_id, + access_control_list=payload.access_control_list, + ) + for payload in requests + ] + execution = 
ThreadedExecution[None](executables) + execution.run() + + def apply_group_permissions( + self, migration_groups_provider: MigrationGroupsProvider, destination: Literal["backup", "account"] + ): + logger.info(f"Applying the permissions to {destination} groups") + logger.info(f"Total groups to apply permissions: {len(migration_groups_provider.groups)}") + + permissions_on_source = self.inventory_table_manager.load_for_groups( + groups=[g.workspace for g in migration_groups_provider.groups] + ) + applicable_permissions = filter( + lambda item: item is not None, + [ + self._prepare_new_permission_request(item, migration_groups_provider, destination=destination) + for item in permissions_on_source + ], + ) + self._apply_permissions_in_parallel(requests=list(applicable_permissions)) + logger.info("All permissions were applied") diff --git a/src/uc_migration_toolkit/managers/inventory/table.py b/src/uc_migration_toolkit/managers/inventory/table.py new file mode 100644 index 0000000000..4d5710c469 --- /dev/null +++ b/src/uc_migration_toolkit/managers/inventory/table.py @@ -0,0 +1,126 @@ +import pandas as pd +import pyspark.sql.functions as F # noqa: N812 +from databricks.sdk.service.iam import Group +from pyspark.sql import DataFrame +from pyspark.sql.types import ( + ArrayType, + BooleanType, + StringType, + StructField, + StructType, +) + +from uc_migration_toolkit.managers.inventory.types import PermissionsInventoryItem +from uc_migration_toolkit.providers.config import provider as config_provider +from uc_migration_toolkit.providers.logger import logger +from uc_migration_toolkit.providers.spark import SparkMixin + + +class InventoryTableManager(SparkMixin): + def __init__(self): + super().__init__() + self.config = config_provider.config.inventory + + @property + def _table_schema(self) -> StructType: + return StructType( + [ + StructField("object_id", StringType(), True), + StructField("logical_object_type", StringType(), True), + StructField("request_object_type", StringType(), True), + StructField( + "object_permissions", + StructType( + [ + StructField( + "access_control_list", + ArrayType( + StructType( + [ + StructField( + "all_permissions", + ArrayType( + StructType( + [ + StructField("inherited", BooleanType(), True), + StructField( + "inherited_from_object", + ArrayType(StringType(), True), + True, + ), + StructField("permission_level", StringType(), True), + ] + ), + True, + ), + True, + ), + StructField("group_name", StringType(), True), + StructField("service_principal_name", StringType(), True), + StructField("user_name", StringType(), True), + ] + ), + True, + ), + True, + ), + StructField("object_id", StringType(), True), + StructField("object_type", StringType(), True), + ] + ), + True, + ), + ] + ) + + @property + def _table(self) -> DataFrame: + assert self.config.table, "Inventory table name is not set" + return self.spark.table(self.config.table.to_spark()) + + def cleanup(self): + logger.info(f"Cleaning up inventory table {self.config.table}") + self.spark.sql(f"DROP TABLE IF EXISTS {self.config.table.to_spark()}") + logger.info("Inventory table cleanup complete") + + def save(self, items: list[PermissionsInventoryItem]): + logger.info(f"Saving {len(items)} items to inventory table {self.config.table}") + serialized_items = pd.DataFrame([item.model_dump(mode="json") for item in items]) + df = self.spark.createDataFrame(serialized_items, schema=self._table_schema) + df.write.mode("append").format("delta").saveAsTable(self.config.table.to_spark()) + 
logger.info("Successfully saved the items to inventory table") + + def load_all(self) -> list[PermissionsInventoryItem]: + logger.info(f"Loading inventory table {self.config.table}") + df = ( + self._table.withColumn("plain_permissions", F.to_json("object_permissions")) + .drop("object_permissions") + .toPandas() + ) + logger.info("Successfully loaded the inventory table") + return PermissionsInventoryItem.from_pandas(df) + + def load_for_groups(self, groups: list[Group]) -> list[PermissionsInventoryItem]: + logger.info(f"Scanning inventory table {self.config.table} for {len(groups)} groups") + group_names = [g.display_name for g in groups] + group_names_sql_argument = ",".join([f'"{name}"' for name in group_names]) + df = ( + self._table.where( + f""" + size( + array_intersect( + array_distinct(transform(object_permissions.access_control_list, item -> item.group_name)) + , array({group_names_sql_argument}) + ) + ) > 0 + """ + ) + .withColumn("plain_permissions", F.to_json("object_permissions")) + .drop("object_permissions") + .toPandas() + ) + + logger.info( + f"Successfully scanned the inventory table, loaded {len(df)} relevant objects for {len(groups)} groups" + ) + return PermissionsInventoryItem.from_pandas(df) diff --git a/src/uc_migration_toolkit/managers/inventory/types.py b/src/uc_migration_toolkit/managers/inventory/types.py new file mode 100644 index 0000000000..9c237641e1 --- /dev/null +++ b/src/uc_migration_toolkit/managers/inventory/types.py @@ -0,0 +1,66 @@ +import json + +import pandas as pd +from databricks.sdk.service.iam import ObjectPermissions +from pydantic import BaseModel + +from uc_migration_toolkit.utils import StrEnum + + +class RequestObjectType(StrEnum): + AUTHORIZATION = "authorization" # tokens and passwords are here too! 
+ CLUSTERS = "clusters" + CLUSTER_POLICIES = "cluster-policies" + DIRECTORIES = "directories" + EXPERIMENTS = "experiments" + FILES = "files" + INSTANCE_POOLS = "instance-pools" + JOBS = "jobs" + NOTEBOOKS = "notebooks" + PIPELINES = "pipelines" + REGISTERED_MODELS = "registered-models" + REPOS = "repos" + SERVING_ENDPOINTS = "serving-endpoints" + SQL_WAREHOUSES = "sql-warehouses" + TOKENS = "tokens" + + def __repr__(self): + return self.value + + +class SqlRequestObjectType(StrEnum): + ALERTS = "alerts" + DASHBOARDS = "dashboards" + DATA_SOURCES = "data-sources" + QUERIES = "queries" + + def __repr__(self): + return self.value + + +class LogicalObjectType(StrEnum): + CLUSTER = "CLUSTER" + + def __repr__(self): + return self.value + + +class PermissionsInventoryItem(BaseModel): + object_id: str + logical_object_type: LogicalObjectType + request_object_type: RequestObjectType | SqlRequestObjectType + object_permissions: dict + + @property + def typed_object_permissions(self) -> ObjectPermissions: + return ObjectPermissions.from_dict(self.object_permissions) + + @staticmethod + def from_pandas(source: pd.DataFrame) -> list["PermissionsInventoryItem"]: + items = source.to_dict(orient="records") + + for item in items: + item["object_permissions"] = json.loads(item["plain_permissions"]) + item.pop("plain_permissions") + + return [PermissionsInventoryItem(**item) for item in items] diff --git a/src/uc_migration_toolkit/providers/__init__.py b/src/uc_migration_toolkit/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/uc_migration_toolkit/providers/client.py b/src/uc_migration_toolkit/providers/client.py new file mode 100644 index 0000000000..cb17224ac2 --- /dev/null +++ b/src/uc_migration_toolkit/providers/client.py @@ -0,0 +1,93 @@ +import json +from dataclasses import asdict + +import requests +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.iam import Group +from requests.adapters import HTTPAdapter +from urllib3 import Retry + +from uc_migration_toolkit.config import AuthConfig +from uc_migration_toolkit.providers.config import provider as config_provider +from uc_migration_toolkit.providers.logger import logger + + +class ImprovedWorkspaceClient(WorkspaceClient): + def assign_permissions(self, principal_id: str, permissions: list[str]): + request_string = f"/api/2.0/preview/permissionassignments/principals/{principal_id}" + + self.api_client.do("put", request_string, data=json.dumps({"permissions": permissions})) + + def list_account_level_groups( + self, filter: str, attributes: str | None = None, excluded_attributes: str | None = None # noqa: A002 + ) -> list[Group]: + query = {"filter": filter, "attributes": attributes, "excludedAttributes": excluded_attributes} + response = self.api_client.do("get", "/api/2.0/account/scim/v2/Groups", query=query) + return [Group.from_dict(v) for v in response.get("Resources", [])] + + def reflect_account_group_to_workspace(self, acc_group: Group) -> None: + logger.info(f"Reflecting group {acc_group.display_name} to workspace") + self.assign_permissions(principal_id=acc_group.id, permissions=["USER"]) + logger.info(f"Group {acc_group.display_name} successfully reflected to workspace") + + +class ClientProvider: + def __init__(self): + self._ws_client: ImprovedWorkspaceClient | None = None + + @staticmethod + def _verify_ws_client(w: ImprovedWorkspaceClient): + assert w.current_user.me(), "Cannot authenticate with the workspace client" + _me = w.current_user.me() + is_workspace_admin = 
any(g.display == "admins" for g in _me.groups) + if not is_workspace_admin: + msg = "Current user is not a workspace admin" + raise RuntimeError(msg) + + @staticmethod + def __get_retry_strategy(): + retry_strategy = Retry( + total=10, + backoff_factor=0.5, + status_forcelist=[429], + respect_retry_after_header=True, + raise_on_status=False, # return original response when retries have been exhausted + # adjusted from the default values + allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "PATCH", "POST"], + ) + return retry_strategy + + def _adjust_session(self, client: ImprovedWorkspaceClient, pool_size: int | None = None): + pool_size = pool_size if pool_size else config_provider.config.num_threads + logger.debug(f"Adjusting the session to fully utilize {pool_size} threads") + _existing_session = client.api_client._session + _session = requests.Session() + _session.auth = _existing_session.auth + _session.mount("https://", HTTPAdapter(max_retries=self.__get_retry_strategy(), pool_maxsize=pool_size)) + client.api_client._session = _session + logger.debug("Session adjusted") + + def set_ws_client(self, auth_config: AuthConfig | None = None, pool_size: int | None = None): + if self._ws_client: + logger.warning("Workspace client already initialized, skipping") + return + + logger.info("Initializing the workspace client") + if auth_config and auth_config.workspace: + logger.info("Using the provided workspace client credentials") + _client = ImprovedWorkspaceClient(**asdict(auth_config.workspace)) + else: + logger.info("Trying standard workspace auth mechanisms") + _client = ImprovedWorkspaceClient() + + self._verify_ws_client(_client) + self._adjust_session(_client, pool_size) + self._ws_client = _client + + @property + def ws(self) -> ImprovedWorkspaceClient: + assert self._ws_client, "Workspace client not initialized" + return self._ws_client + + +provider = ClientProvider() diff --git a/src/uc_migration_toolkit/providers/config.py b/src/uc_migration_toolkit/providers/config.py new file mode 100644 index 0000000000..29fdd0d33b --- /dev/null +++ b/src/uc_migration_toolkit/providers/config.py @@ -0,0 +1,17 @@ +from uc_migration_toolkit.config import MigrationConfig + + +class ConfigProvider: + def __init__(self): + self._config: MigrationConfig | None = None + + def set_config(self, config: MigrationConfig): + self._config = config + + @property + def config(self) -> MigrationConfig: + assert self._config, "Config is not set" + return self._config + + +provider = ConfigProvider() diff --git a/src/uc_migration_toolkit/providers/logger.py b/src/uc_migration_toolkit/providers/logger.py new file mode 100644 index 0000000000..5e4079f423 --- /dev/null +++ b/src/uc_migration_toolkit/providers/logger.py @@ -0,0 +1,5 @@ +from loguru import logger as _loguru_logger + +# reassigning the logger to the loguru logger +# for flexibility and simple dependency injection +logger = _loguru_logger diff --git a/src/uc_migration_toolkit/providers/spark.py b/src/uc_migration_toolkit/providers/spark.py new file mode 100644 index 0000000000..34bf6d1197 --- /dev/null +++ b/src/uc_migration_toolkit/providers/spark.py @@ -0,0 +1,50 @@ +import functools +import os +import time + +from databricks.sdk.service.compute import State +from pyspark.sql import SparkSession + +from uc_migration_toolkit.providers.client import provider +from uc_migration_toolkit.providers.logger import logger + + +class SparkMixin: + def __init__(self): + super().__init__() + self._spark = self._initialize_spark() + + 
@staticmethod
+    @functools.lru_cache(maxsize=10_000)
+    def _initialize_spark() -> SparkSession:
+        logger.info("Initializing Spark session")
+        # `locals()` can never contain the runtime-injected `spark` object here, so probe the active session
+        if active_session := SparkSession.getActiveSession():
+            logger.info("Using the Spark session from runtime")
+            return active_session
+        else:
+            logger.info("Using DB Connect")
+            from databricks.connect import DatabricksSession
+
+            if "DATABRICKS_CLUSTER_ID" not in os.environ:
+                msg = "DATABRICKS_CLUSTER_ID environment variable is not set, cannot use DB Connect"
+                raise RuntimeError(msg)
+            cluster_id = os.environ["DATABRICKS_CLUSTER_ID"]
+            cluster_info = provider.ws.clusters.get(cluster_id)
+
+            logger.info(f"Using cluster {cluster_id} with name {cluster_info.cluster_name}")
+
+            if cluster_info.state not in [State.RUNNING, State.PENDING, State.RESTARTING]:
+                logger.info("Cluster is not running, starting it")
+                provider.ws.clusters.start(cluster_id)
+                time.sleep(2)
+
+            logger.info("Waiting for the cluster to get running")
+            provider.ws.clusters.wait_get_cluster_running(cluster_id)
+            logger.info("Cluster is ready, creating the DBConnect session")
+            provider.ws.config.cluster_id = cluster_id
+            spark = DatabricksSession.builder.sdkConfig(provider.ws.config).getOrCreate()
+            return spark
+
+    @property
+    def spark(self):
+        return self._spark
diff --git a/src/uc_migration_toolkit/toolkits/__init__.py b/src/uc_migration_toolkit/toolkits/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/uc_migration_toolkit/toolkits/group_migration.py b/src/uc_migration_toolkit/toolkits/group_migration.py
new file mode 100644
index 0000000000..3827f0a166
--- /dev/null
+++ b/src/uc_migration_toolkit/toolkits/group_migration.py
@@ -0,0 +1,41 @@
+from uc_migration_toolkit.config import MigrationConfig
+from uc_migration_toolkit.managers.group import GroupManager
+from uc_migration_toolkit.managers.inventory.permissions import PermissionManager
+from uc_migration_toolkit.managers.inventory.table import InventoryTableManager
+from uc_migration_toolkit.providers.client import provider
+from uc_migration_toolkit.providers.config import provider as config_provider
+
+
+class GroupMigrationToolkit:
+    def __init__(self, config: MigrationConfig):
+        # order matters: the config must be set before the client and managers are created
+        config_provider.set_config(config)
+        provider.set_ws_client(config.auth)
+        self.group_manager = GroupManager()
+        self.table_manager = InventoryTableManager()
+        self.permissions_manager = PermissionManager(self.table_manager)
+
+    def prepare_groups_in_environment(self):
+        self.group_manager.prepare_groups_in_environment()
+
+    def cleanup_inventory_table(self):
+        self.table_manager.cleanup()
+
+    def inventorize_permissions(self):
+        self.permissions_manager.inventorize_permissions()
+
+    def apply_permissions_to_backup_groups(self):
+        self.permissions_manager.apply_group_permissions(
+            self.group_manager.migration_groups_provider, destination="backup"
+        )
+
+    def replace_workspace_groups_with_account_groups(self):
+        self.group_manager.replace_workspace_groups_with_account_groups()
+
+    def apply_permissions_to_account_groups(self):
+        self.permissions_manager.apply_group_permissions(
+            self.group_manager.migration_groups_provider, destination="account"
+        )
+
+    def delete_backup_groups(self):
+        self.group_manager.delete_backup_groups()
diff --git a/src/uc_migration_toolkit/utils.py b/src/uc_migration_toolkit/utils.py
new file mode 100644
index 0000000000..8a7a08af43
--- /dev/null
+++ b/src/uc_migration_toolkit/utils.py
@@ -0,0 +1,106 @@
+import concurrent
+import datetime as dt
+import enum
+from 
collections.abc import Callable +from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor +from typing import Generic, TypeVar + +from ratelimit import limits, sleep_and_retry + +from uc_migration_toolkit.config import RateLimitConfig +from uc_migration_toolkit.providers.config import provider as config_provider +from uc_migration_toolkit.providers.logger import logger + +ExecutableResult = TypeVar("ExecutableResult") +ExecutableFunction = Callable[..., ExecutableResult] + + +class ProgressReporter: + def __init__(self, total_executables: int, message_prefix: str | None = "threaded execution - processed: "): + self.counter = 0 + self._total_executables = total_executables + self.start_time = dt.datetime.now() + self._message_prefix = message_prefix + + def progress_report(self, _): + self.counter += 1 + measuring_time = dt.datetime.now() + delta_from_start = measuring_time - self.start_time + rps = self.counter / delta_from_start.total_seconds() + offset = len(str(self._total_executables)) + if self.counter % 10 == 0 or self.counter == self._total_executables: + logger.info( + f"{self._message_prefix}{self.counter:>{offset}d}/{self._total_executables}, rps: {rps:.3f}/sec" + ) + + +class ThreadedExecution(Generic[ExecutableResult]): + def __init__( + self, + executables: list[ExecutableFunction], + num_threads: int | None = None, + rate_limit: RateLimitConfig | None = None, + done_callback: Callable[..., None] | None = None, + ): + self._num_threads = num_threads if num_threads else config_provider.config.num_threads + self._rate_limit = rate_limit if rate_limit else config_provider.config.rate_limit + self._executables = executables + self._futures = [] + self._done_callback = ( + done_callback if done_callback else self._prepare_default_done_callback(len(self._executables)) + ) + + @staticmethod + def _prepare_default_done_callback(total_executables: int): + progress_reporter = ProgressReporter(total_executables) + return progress_reporter.progress_report + + def run(self) -> list[ExecutableResult]: + logger.info("Starting threaded execution") + + @sleep_and_retry + @limits(calls=self._rate_limit.max_requests_per_period, period=self._rate_limit.period_in_seconds) + def rate_limited_wrapper(func: ExecutableFunction) -> ExecutableResult: + return func() + + with ThreadPoolExecutor(self._num_threads) as executor: + for executable in self._executables: + future = executor.submit(rate_limited_wrapper, executable) + if self._done_callback: + future.add_done_callback(self._done_callback) + self._futures.append(future) + + results = concurrent.futures.wait(self._futures, return_when=ALL_COMPLETED) + + logger.info("Collecting the results from threaded execution") + collected = [future.result() for future in results.done] + return collected + + +class Request: + def __init__(self, req: dict): + self.request = req + + def as_dict(self) -> dict: + return self.request + + +class StrEnum(str, enum.Enum): # re-exported for compatability with older python versions + def __new__(cls, value, *args, **kwargs): + if not isinstance(value, str | enum.auto): + msg = f"Values of StrEnums must be strings: {value!r} is a {type(value)}" + raise TypeError(msg) + return super().__new__(cls, value, *args, **kwargs) + + def __str__(self): + return str(self.value) + + def _generate_next_value_(name, *_): # noqa: N805 + return name + + +class WorkspaceLevelEntitlement(StrEnum): + WORKSPACE_ACCESS = "workspace-access" + DATABRICKS_SQL_ACCESS = "databricks-sql-access" + ALLOW_CLUSTER_CREATE = "allow-cluster-create" 
+    ALLOW_INSTANCE_POOL_CREATE = "allow-instance-pool-create"
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
new file mode 100644
index 0000000000..d4beb81d73
--- /dev/null
+++ b/tests/integration/conftest.py
@@ -0,0 +1,312 @@
+import json
+import os
+import random
+import uuid
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+
+import pytest
+from _pytest.fixtures import SubRequest
+from databricks.sdk import AccountClient, WorkspaceClient
+from databricks.sdk.service.compute import ClusterDetails
+from databricks.sdk.service.iam import (
+    AccessControlRequest,
+    ComplexValue,
+    Group,
+    PermissionLevel,
+    User,
+)
+from dotenv import load_dotenv
+
+from uc_migration_toolkit.config import (
+    AuthConfig,
+    InventoryTable,
+    RateLimitConfig,
+    WorkspaceAuthConfig,
+)
+from uc_migration_toolkit.managers.inventory.types import RequestObjectType
+from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient, provider
+from uc_migration_toolkit.providers.logger import logger
+from uc_migration_toolkit.utils import (
+    Request,
+    ThreadedExecution,
+    WorkspaceLevelEntitlement,
+)
+
+
+def initialize_env() -> None:
+    principal_env = Path(__file__).parent.parent.parent / ".env.principal"
+
+    if principal_env.exists():
+        logger.debug("Using credentials provided in .env.principal")
+        load_dotenv(dotenv_path=principal_env)
+    else:
+        logger.debug(f"No .env.principal found at {principal_env.absolute()}, using environment variables")
+
+
+initialize_env()
+
+# cast to int - os.environ.get returns strings when the variables are set
+NUM_TEST_GROUPS = int(os.environ.get("NUM_TEST_GROUPS", 5))
+
+NUM_TEST_INSTANCE_PROFILES = int(os.environ.get("NUM_TEST_INSTANCE_PROFILES", 3))
+NUM_TEST_CLUSTERS = int(os.environ.get("NUM_TEST_CLUSTERS", 3))
+
+NUM_THREADS = int(os.environ.get("NUM_TEST_THREADS", 20))
+DB_CONNECT_CLUSTER_NAME = os.environ.get("DB_CONNECT_CLUSTER_NAME", "ucx-integration-testing")
+UCX_TESTING_PREFIX = os.environ.get("UCX_TESTING_PREFIX", "ucx")
+
+Threader = partial(ThreadedExecution, num_threads=NUM_THREADS, rate_limit=RateLimitConfig())
+
+
+@dataclass
+class InstanceProfile:
+    instance_profile_arn: str
+    iam_role_arn: str
+
+
+@dataclass
+class EnvironmentInfo:
+    test_uid: str
+    groups: list[tuple[Group, Group]]
+
+
+def generate_group_by_id(
+    _ws: WorkspaceClient, _acc: AccountClient, group_name: str, users_sample: list[User]
+) -> tuple[Group, Group]:
+    entities = [ComplexValue(display=user.display_name, value=user.id) for user in users_sample]
+    logger.debug(f"Creating group with name {group_name}")
+
+    def get_random_entitlements():
+        chosen: list[WorkspaceLevelEntitlement] = random.choices(
+            list(WorkspaceLevelEntitlement),
+            k=random.randint(1, 3),
+        )
+        entitlements = [ComplexValue(display=None, primary=None, type=None, value=value) for value in chosen]
+        return entitlements
+
+    ws_group = _ws.groups.create(display_name=group_name, members=entities, entitlements=get_random_entitlements())
+    acc_group = _acc.groups.create(display_name=group_name, members=entities)
+    return ws_group, acc_group
+
+
+def _create_groups(_ws: ImprovedWorkspaceClient, _acc: AccountClient, prefix: str) -> list[tuple[Group, Group]]:
+    logger.debug("Listing users to create sample groups")
+    test_users = list(_ws.users.list(filter="displayName sw 'test-user-'", attributes="id, userName, displayName"))
+    logger.debug(f"Total of test users {len(test_users)}")
+    user_samples: dict[str, list[User]] = {
+        f"{prefix}-test-group-{gid}": random.choices(test_users, k=random.randint(1, 40))
+        for gid in range(NUM_TEST_GROUPS)
+   
} + executables = [ + partial(generate_group_by_id, _ws, _acc, group_name, users_sample) + for group_name, users_sample in user_samples.items() + ] + return Threader(executables).run() + + +@pytest.fixture(scope="session") +def ws() -> ImprovedWorkspaceClient: + auth_config = AuthConfig( + workspace=WorkspaceAuthConfig( + host=os.environ["DATABRICKS_WS_HOST"], + client_id=os.environ["DATABRICKS_COMMON_CLIENT_ID"], + client_secret=os.environ["DATABRICKS_COMMON_CLIENT_SECRET"], + ) + ) + provider.set_ws_client(auth_config, pool_size=NUM_THREADS) + yield provider.ws + + +@pytest.fixture(scope="session", autouse=True) +def acc() -> AccountClient: + acc_client = AccountClient( + host=os.environ["DATABRICKS_ACC_HOST"], + client_id=os.environ["DATABRICKS_COMMON_CLIENT_ID"], + client_secret=os.environ["DATABRICKS_COMMON_CLIENT_SECRET"], + account_id=os.environ["DATABRICKS_ACC_ACCOUNT_ID"], + ) + yield acc_client + + +@pytest.fixture(scope="session", autouse=True) +def dbconnect(ws: ImprovedWorkspaceClient): + dbc_cluster = next(filter(lambda c: c.cluster_name == DB_CONNECT_CLUSTER_NAME, ws.clusters.list()), None) + + if dbc_cluster: + logger.debug(f"Integration testing cluster {DB_CONNECT_CLUSTER_NAME} already exists, skipping it's creation") + else: + logger.debug("Creating a cluster for integration testing") + request = { + "cluster_name": DB_CONNECT_CLUSTER_NAME, + "spark_version": "13.2.x-scala2.12", + "instance_pool_id": os.environ["TEST_POOL_ID"], + "driver_instance_pool_id": os.environ["TEST_POOL_ID"], + "num_workers": 0, + "spark_conf": {"spark.master": "local[*, 4]", "spark.databricks.cluster.profile": "singleNode"}, + "custom_tags": { + "ResourceClass": "SingleNode", + }, + "data_security_mode": "SINGLE_USER", + "autotermination_minutes": 180, + "runtime_engine": "PHOTON", + } + + dbc_cluster = ws.clusters.create(spark_version="13.2.x-scala2.12", request=Request(request)) + + logger.debug(f"Cluster {dbc_cluster.cluster_id} created") + + os.environ["DATABRICKS_CLUSTER_ID"] = dbc_cluster.cluster_id + yield + + +@pytest.fixture(scope="session", autouse=True) +def env(ws: ImprovedWorkspaceClient, acc: AccountClient, request: SubRequest) -> EnvironmentInfo: + # prepare environment + test_uid = f"{UCX_TESTING_PREFIX}_{str(uuid.uuid4())[:8]}" + logger.debug(f"Creating environment with uid {test_uid}") + groups = _create_groups(ws, acc, test_uid) + + def _cleanup_groups(_ws: WorkspaceClient, _acc: AccountClient, _groups: tuple[Group, Group]): + ws_g, acc_g = _groups + logger.debug(f"Deleting groups {ws_g.display_name} [ws-level] and {acc_g.display_name} [acc-level]") + + try: + ws.groups.delete(ws_g.id) + except Exception as e: + logger.warning(f"Cannot delete ws-level group {ws_g.display_name}, skipping it. Original exception {e}") + + try: + g = next(iter(acc.groups.list(filter=f"displayName eq '{acc_g.display_name}'")), None) + if g: + acc.groups.delete(g.id) + except Exception as e: + logger.warning(f"Cannot delete acc-level group {acc_g.display_name}, skipping it. Original exception {e}") + + def post_cleanup(): + print("\n") + logger.debug("Cleaning up the environment") + logger.debug("Deleting test groups") + cleanups = [partial(_cleanup_groups, ws, acc, g) for g in groups] + + def error_silencer(func): + def _wrapped(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception as e: + logger.warning(f"Cannot delete temp group, skipping it. 
Original exception {e}") + + return _wrapped + + silent_delete = error_silencer(ws.groups.delete) + + temp_cleanups = [ + partial(silent_delete, g.id) for g in ws.groups.list(filter=f"displayName sw 'db-temp-{test_uid}'") + ] + new_ws_groups_cleanups = [ + partial(silent_delete, g.id) for g in ws.groups.list(filter=f"displayName sw '{test_uid}'") + ] + + all_cleanups = cleanups + temp_cleanups + new_ws_groups_cleanups + Threader(all_cleanups).run() + logger.debug(f"Finished cleanup for the environment {test_uid}") + + request.addfinalizer(post_cleanup) + yield EnvironmentInfo(test_uid=test_uid, groups=groups) + + +@pytest.fixture(scope="session", autouse=True) +def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[InstanceProfile]: + logger.debug("Adding test instance profiles") + profiles: list[InstanceProfile] = [] + + for i in range(NUM_TEST_INSTANCE_PROFILES): + profile_arn = f"arn:aws:iam::123456789:instance-profile/{env.test_uid}-test-{i}" + iam_role_arn = f"arn:aws:iam::123456789:role/{env.test_uid}-test-{i}" + ws.instance_profiles.add(instance_profile_arn=profile_arn, iam_role_arn=iam_role_arn, skip_validation=True) + profiles.append(InstanceProfile(instance_profile_arn=profile_arn, iam_role_arn=iam_role_arn)) + + for ws_group, _ in env.groups: + roles = { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], + "Operations": [ + { + "op": "add", + "path": "roles", + "value": [{"value": p.instance_profile_arn} for p in random.choices(profiles, k=2)], + } + ], + } + provider.ws.api_client.do("PATCH", f"/api/2.0/preview/scim/v2/Groups/{ws_group.id}", data=json.dumps(roles)) + + yield profiles + + logger.debug("Deleting test instance profiles") + for profile in profiles: + ws.instance_profiles.remove(profile.instance_profile_arn) + logger.debug("Test instance profiles deleted") + + +@pytest.fixture(scope="session", autouse=True) +def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterDetails]: + logger.debug("Creating test clusters") + + test_clusters = [ + ws.clusters.create( + spark_version="13.2.x-scala2.12", + instance_pool_id=os.environ["TEST_POOL_ID"], + driver_instance_pool_id=os.environ["TEST_POOL_ID"], + cluster_name=f"{env.test_uid}-test-{i}", + num_workers=1, + ) + for i in range(NUM_TEST_CLUSTERS) + ] + + for cluster in test_clusters: + + def get_random_ws_group() -> Group: + return random.choice([g[0] for g in env.groups]) + + def get_random_permission_level() -> PermissionLevel: + return random.choice( + [PermissionLevel.CAN_MANAGE, PermissionLevel.CAN_RESTART, PermissionLevel.CAN_ATTACH_TO] + ) + + acl_req = [ + AccessControlRequest( + group_name=get_random_ws_group().display_name, permission_level=get_random_permission_level() + ) + for _ in range(3) + ] + + ws.permissions.set( + request_object_type=RequestObjectType.CLUSTERS, + request_object_id=cluster.cluster_id, + access_control_list=acl_req, + ) + + yield test_clusters + + logger.debug("Deleting test clusters") + executables = [partial(ws.clusters.permanent_delete, c.cluster_id) for c in test_clusters] + Threader(executables).run() + logger.debug("Test clusters deleted") + + +@pytest.fixture() +def inventory_table(env: EnvironmentInfo) -> InventoryTable: + table = InventoryTable( + catalog="main", + database="default", + name=f"test_inventory_{env.test_uid}", + ) + + yield table + + logger.debug(f"Cleaning up inventory table {table}") + try: + provider.ws.tables.delete(table.to_spark()) + logger.debug(f"Inventory table {table} deleted") + except Exception 
as e:
+        logger.warning(f"Cannot delete inventory table, skipping it. Original exception {e}")
diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py
new file mode 100644
index 0000000000..cf03eab4e2
--- /dev/null
+++ b/tests/integration/test_e2e.py
@@ -0,0 +1,141 @@
+from typing import Literal
+
+import pytest
+from conftest import EnvironmentInfo
+from databricks.sdk.service.compute import ClusterDetails
+from pyspark.errors import AnalysisException
+
+from uc_migration_toolkit.config import (
+    GroupsConfig,
+    InventoryConfig,
+    InventoryTable,
+    MigrationConfig,
+)
+from uc_migration_toolkit.managers.group import MigrationGroupInfo
+from uc_migration_toolkit.managers.inventory.types import RequestObjectType
+from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient
+from uc_migration_toolkit.providers.logger import logger
+from uc_migration_toolkit.toolkits.group_migration import GroupMigrationToolkit
+
+
+def _verify_group_permissions(
+    clusters: list[ClusterDetails],
+    ws: ImprovedWorkspaceClient,
+    toolkit: GroupMigrationToolkit,
+    target: Literal["backup", "account"],
+):
+    logger.info(f"Verifying that the permissions were applied to {target} groups")
+    for cluster in clusters:
+        cluster_permissions = ws.permissions.get(RequestObjectType.CLUSTERS, cluster.cluster_id)
+        for migration_info in toolkit.group_manager.migration_groups_provider.groups:
+            target_permissions = sorted(
+                [
+                    p
+                    for p in cluster_permissions.access_control_list
+                    if p.group_name == getattr(migration_info, target).display_name
+                ],
+                key=lambda p: p.group_name,
+            )
+
+            source_permissions = sorted(
+                [
+                    p
+                    for p in cluster_permissions.access_control_list
+                    if p.group_name == migration_info.workspace.display_name
+                ],
+                key=lambda p: p.group_name,
+            )
+
+            assert len(target_permissions) == len(
+                source_permissions
+            ), f"Target permissions were not applied correctly for cluster {cluster.cluster_id}"
+
+            assert [t.all_permissions for t in target_permissions] == [
+                s.all_permissions for s in source_permissions
+            ], f"Target permissions were not applied correctly for cluster {cluster.cluster_id}"
+
+
+def _verify_roles_and_entitlements(
+    groups: list[MigrationGroupInfo], ws: ImprovedWorkspaceClient, target: Literal["backup", "account"]
+):
+    for migration_info in groups:
+        workspace_group = migration_info.workspace
+        target_group = ws.groups.get(getattr(migration_info, target).id)
+
+        assert workspace_group.roles == target_group.roles
+        assert workspace_group.entitlements == target_group.entitlements
+
+
+def test_e2e(
+    env: EnvironmentInfo, inventory_table: InventoryTable, ws: ImprovedWorkspaceClient, clusters: list[ClusterDetails]
+):
+    logger.info(f"Test environment: {env.test_uid}")
+
+    config = MigrationConfig(
+        with_table_acls=False,
+        inventory=InventoryConfig(table=inventory_table),
+        groups=GroupsConfig(selected=[g[0].display_name for g in env.groups]),
+        auth=None,
+    )
+    logger.info(f"Starting e2e with config: {config.to_json()}")
+    toolkit = GroupMigrationToolkit(config)
+    toolkit.prepare_groups_in_environment()
+
+    logger.info("Verifying that the groups were created")
+    _verify_roles_and_entitlements(toolkit.group_manager.migration_groups_provider.groups, ws, "backup")
+
+    # groups.list returns a lazy iterator, so materialize it before counting
+    assert len(
+        list(ws.groups.list(filter=f"displayName sw '{config.groups.backup_group_prefix}{env.test_uid}'"))
+    ) == len(toolkit.group_manager.migration_groups_provider.groups)
+
+    assert len(list(ws.groups.list(filter=f"displayName sw '{env.test_uid}'"))) == len(
+        toolkit.group_manager.migration_groups_provider.groups
+    )
+
+    assert len(ws.list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len(
+        toolkit.group_manager.migration_groups_provider.groups
+    )
+
+    logger.info("Verifying that the groups were created - done")
+
+    toolkit.cleanup_inventory_table()
+
+    with pytest.raises(AnalysisException):
+        toolkit.table_manager.spark.catalog.getTable(toolkit.table_manager.config.table.to_spark())
+
+    toolkit.inventorize_permissions()
+
+    logger.info("Verifying that the permissions were inventorized correctly")
+    saved_permissions = toolkit.table_manager.load_all()
+
+    for cluster in clusters:
+        cluster_permissions = ws.permissions.get(RequestObjectType.CLUSTERS, cluster.cluster_id)
+        relevant_permission = next(filter(lambda p: p.object_id == cluster.cluster_id, saved_permissions), None)
+        assert relevant_permission is not None, f"Cluster {cluster.cluster_id} permissions were not inventorized"
+        assert relevant_permission.typed_object_permissions == cluster_permissions
+
+    logger.info("Permissions were inventorized properly")
+
+    toolkit.apply_permissions_to_backup_groups()
+
+    _verify_group_permissions(clusters, ws, toolkit, "backup")
+    toolkit.replace_workspace_groups_with_account_groups()
+
+    new_groups = list(ws.groups.list(filter=f"displayName sw '{env.test_uid}'", attributes="displayName,meta"))
+    assert len(new_groups) == len(toolkit.group_manager.migration_groups_provider.groups)
+    assert all(g.meta.resource_type == "Group" for g in new_groups)
+    _verify_roles_and_entitlements(toolkit.group_manager.migration_groups_provider.groups, ws, "account")
+
+    toolkit.apply_permissions_to_account_groups()
+    _verify_group_permissions(clusters, ws, toolkit, "account")
+
+    toolkit.delete_backup_groups()
+
+    backup_groups = list(
+        ws.groups.list(
+            filter=f"displayName sw '{config.groups.backup_group_prefix}{env.test_uid}'", attributes="displayName,meta"
+        )
+    )
+    assert len(backup_groups) == 0
+
+    toolkit.cleanup_inventory_table()
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000000..aeac516c8d
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,36 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import pytest
+from delta import configure_spark_with_delta_pip
+from pyspark.sql import SparkSession
+
+from uc_migration_toolkit.providers.logger import logger
+
+
+@pytest.fixture(scope="session")
+def spark() -> SparkSession:
+    """
+    This fixture provides a preconfigured SparkSession with Hive and Delta support.
+    After the test session, the temporary warehouse directory is deleted.
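+
+    Delta support comes from `configure_spark_with_delta_pip`, which adds the
+    delta-core artifact matching the installed `delta-spark` version to
+    `spark.jars.packages`.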
+    :return: SparkSession
+    """
+    logger.info("Configuring Spark session for testing environment")
+    warehouse_dir = tempfile.TemporaryDirectory().name
+    _builder = (
+        SparkSession.builder.master("local[1]")
+        .config("spark.hive.metastore.warehouse.dir", Path(warehouse_dir).as_uri())
+        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+        .config(
+            "spark.sql.catalog.spark_catalog",
+            "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+        )
+    )
+    spark: SparkSession = configure_spark_with_delta_pip(_builder).getOrCreate()
+    logger.info("Spark session configured")
+    yield spark
+    logger.info("Shutting down Spark session")
+    spark.stop()
+    if Path(warehouse_dir).exists():
+        shutil.rmtree(warehouse_dir)
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
new file mode 100644
index 0000000000..df47bbc02e
--- /dev/null
+++ b/tests/unit/test_config.py
@@ -0,0 +1,68 @@
+import os
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+
+import pytest
+import yaml
+from pydantic import RootModel
+
+from uc_migration_toolkit.cli.utils import get_migration_config
+from uc_migration_toolkit.config import (
+    GroupsConfig,
+    InventoryConfig,
+    InventoryTable,
+    MigrationConfig,
+)
+
+
+def test_initialization():
+    mc = partial(
+        MigrationConfig,
+        inventory=InventoryConfig(table=InventoryTable(catalog="catalog", database="database", name="name")),
+        groups=GroupsConfig(auto=True),
+    )
+
+    with pytest.raises(NotImplementedError):
+        mc(with_table_acls=True)
+
+    mc(with_table_acls=False)
+
+
+# Path context manager: changes the current directory to the given path,
+# then changes back to the previous directory on exit.
+@contextmanager
+def set_directory(path: Path):
+    """Sets the cwd within the context
+
+    Args:
+        path (Path): The path to the cwd
+
+    Yields:
+        None
+    """
+
+    origin = Path().absolute()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(origin)
+
+
+def test_reader(tmp_path: Path):
+    with set_directory(tmp_path):
+        mc = partial(
+            MigrationConfig,
+            inventory=InventoryConfig(table=InventoryTable(catalog="catalog", database="database", name="name")),
+            groups=GroupsConfig(auto=True),
+        )
+
+        config: MigrationConfig = mc(with_table_acls=False)
+        config_file = tmp_path / "config.yml"
+
+        with config_file.open("w") as writable:
+            yaml.safe_dump(RootModel[MigrationConfig](config).model_dump(), writable)
+
+        loaded = get_migration_config(config_file)
+        assert loaded == config
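+
+
+# A hypothetical extra check, sketched for illustration only (not part of the
+# original suite): set_directory should restore the previous working directory
+# even when the body of the `with` block raises.
+def test_set_directory_restores_cwd_on_error(tmp_path: Path):
+    origin = Path().absolute()
+    with pytest.raises(ValueError), set_directory(tmp_path):
+        raise ValueError("kaboom")
+    assert Path().absolute() == origin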
diff --git a/uc_upgrade/group_migration.py b/uc_upgrade/group_migration.py
deleted file mode 100644
index 6707f3edee..0000000000
--- a/uc_upgrade/group_migration.py
+++ /dev/null
@@ -1,2330 +0,0 @@
-import concurrent.futures
-import json
-import logging
-import math
-import time
-from typing import List
-
-import requests
-from pyspark.sql import session
-from pyspark.sql.functions import (
-    array_contains,
-    col,
-    collect_set,
-    lit,
-    regexp_replace,
-    split,
-)
-from pyspark.sql.types import MapType, StringType, StructField, StructType
-
-# Initialize logger
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
-
-
-class GroupMigration:
-    def __init__(
-        self,
-        groupL: List[str],
-        cloud: str,
-        inventoryTableName: str,
-        workspace_url: str,
-        pat: str,
-        spark: session.SparkSession,
-        userName: str,
-        checkTableACL: bool = False,
-        numThreads: int = 10,
-        autoGenerateList: bool = False,
-        verbose: bool = False,
-        freshInventory: bool = False,
-    ):
-        self.groupL = groupL
-        self.cloud = cloud
-        self.workspace_url = workspace_url.rstrip("/")
-        self.inventoryTableName = inventoryTableName
-        self.token = pat
-        self.headers = {"Authorization": "Bearer %s" % self.token}
-        self.groupIdDict = {}  # map: group id => group name
-        self.groupNameDict = {}  # map: group name => group id
-        self.accountGroups_lower = {}
-        self.groupMembers = {}  # map: group id => list[tuple[member name, memberid]]
-        self.groupEntitlements = {}
-        self.groupGroupList = []
-        self.groupUserList = []
-        self.groupSPList = []
-        self.groupWSGIdDict = {}  # map : temp group id => temp group name
-        self.groupWSGNameDict = {}  # map : temp group name => temp group id
-        self.groupRoles = {}
-        self.passwordPerm = {}
-        self.clusterPerm = {}
-        self.clusterPolicyPerm = {}
-        self.warehousePerm = {}
-        self.dashboardPerm = {}
-        self.queryPerm = {}
-        self.alertPerm = {}
-        self.instancePoolPerm = {}
-        self.jobPerm = {}
-        self.expPerm = {}
-        self.modelPerm = {}
-        self.dltPerm = {}
-        self.folderPerm = {}
-        self.notebookPerm = {}
-        self.filePerm = {}
-        self.repoPerm = {}
-        self.tokenPerm = {}
-        self.secretScopePerm = {}
-        self.dataObjectsPerm = []
-        self.folderList = {}
-        self.notebookList = {}
-        self.fileList = {}
-        self.spark = spark
-        self.userName = userName
-        self.checkTableACL = checkTableACL
-        self.verbose = verbose
-        self.numThreads = numThreads
-
-        self.lastInventoryRun = None
-        self.checkAllDB = False
-        if freshInventory:
-            logger.info(f"Clearing inventory table {self.inventoryTableName}")
-            spark.sql(f"drop table if exists {self.inventoryTableName}")
-            spark.sql(f"drop table if exists {self.inventoryTableName+'TableACL'}")
-            createSQL = f"create table {self.inventoryTableName} (GroupType string, WorkspaceObject string, \
-            Permission MAP)"
-            spark.sql(createSQL)
-            createSQL = f"create table {self.inventoryTableName}TableACL (GroupType string, Database string, Principal \
-            string, ActionTypes string, ObjectType string, ObjectKey string)"
-            spark.sql(createSQL)
-            logger.info("recreated tables...")
-
-        # Check if we should automatically generate list, and do it immediately.
-        # Implementers Note: Could change this section to a lazy calculation
-        # by setting groupL to nil or some sentinel value and adding checks before use.
-        res = requests.get(f"{self.workspace_url}/api/2.0/preview/scim/v2/Me", headers=self.headers)
-        # logger.info(res.text)
-        if res.status_code == 403:
-            logger.error("token not valid.")
-            return
-        if autoGenerateList:
-            logger.info(
-                "autoGenerateList parameter is set to TRUE. "
-                "Ignoring groupL parameter and instead will automatically generate list of migration groups."
-            )
-            self.groupL = self.findMigrationEligibleGroups()
-
-        # Finish setting some params that depend on groupL
-        if len(self.groupL) == 0:
-            raise Exception("Migration group list (groupL) is empty!")
-
-        self.TempGroupNames = ["db-temp-" + g for g in self.groupL]
-        self.WorkspaceGroupNames = self.groupL
-
-        logger.info(
-            f"Successfully initialized GroupMigration class "
-            f"with {len(self.groupL)} workspace-local groups to migrate. Groups to migrate:"
-        )
-        for i, group in enumerate(self.groupL, start=1):
-            logger.info(f"{i}. {group}")
{group}") - logger.info(f"Done listing {len(self.groupL)} groups to migrate.") - - def findMigrationEligibleGroups(self): - logger.info("Begin automatic generation of all migration eligible groups.") - # Get all workspace-local groups - try: - logger.info("Executing request to list workspace groups") - res = requests.get( - f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups", - headers=self.headers, - ) - if res.status_code != 200: - raise Exception(f"Bad status code. Expected: 200. Got: {res.status_code}") - - resJson = res.json() - - allWsLocalGroups = [ - o["displayName"] for o in resJson["Resources"] if o["meta"]["resourceType"] == "WorkspaceGroup" - ] - - # Prune special groups. - prune_groups = ["admins", "users"] - allWsLocalGroups = [g for g in allWsLocalGroups if g not in prune_groups] - allWsLocalGroups_lower = [x.casefold() for x in allWsLocalGroups] - allWsLocalGroups.sort() - logger.info( - f"\nFound {len(allWsLocalGroups)} workspace local groups. Listing (alphabetical order): \n" - + "\n".join(f"{i+1}. {name}" for i, name in enumerate(allWsLocalGroups)) - ) - - except Exception as e: - logger.error(f"ERROR in retrieving workspace group list: {e}") - raise - - # Now match against account groups. - try: - logger.info("\nExecuting request to list account groups") - res = requests.get( - f"{self.workspace_url}/api/2.0/account/scim/v2/Groups", - headers=self.headers, - ) - if res.status_code != 200: - raise Exception(f"Bad status code. Expected: 200. Got: {res.status_code}") - resJson2 = res.json() - allAccountGroups_lower = [r["displayName"].casefold() for r in resJson2["Resources"]] - allAccountGroups_lower.sort() - - # Get set intersection of both lists - migration_eligible_lower = list(set(allWsLocalGroups_lower) & set(allAccountGroups_lower)) - migration_eligible = [wsl for wsl in allWsLocalGroups if wsl.casefold() in migration_eligible_lower] - migration_eligible.sort() - - # Get list of items in allWsLocalGroups that are not in allAccountGroups - not_in_account_groups = [ - group for group in allWsLocalGroups if group.casefold() not in allAccountGroups_lower - ] - not_in_account_groups.sort() - - # logger.info count and membership of not_in_account_groups - logger.info( - f"Unable to match {len(not_in_account_groups)} current workspace-local groups. " - f"No matching account level group with the same name found. These groups WILL NOT MIGRATE:" - ) - for i, group in enumerate(not_in_account_groups, start=1): - logger.info(f"{i}. {group} (WON'T MIGRATE)") - - if len(migration_eligible) > 0: - # logger.info count and membership of intersection - logger.info( - f"\nFound {len(migration_eligible)} current workspace-local groups to account level groups. " - f"These groups WILL BE MIGRATED." - ) - for i, group in enumerate(migration_eligible, start=1): - logger.info(f"{i}. {group} (WILL MIGRATE)") - logger.info("") - - return migration_eligible - else: - logger.info( - "There are no migration eligible groups. " - "All existing workspace-local groups do not exist at the account level." - "\nNO MIGRATION WILL BE PERFORMED." 
-                )
-                return []
-        except Exception as e:
-            logger.error(f"ERROR in retrieving account group list : {e}")
-            raise
-
-    def validateWSGroup(self) -> bool:
-        try:
-            res = requests.get(
-                f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups",
-                headers=self.headers,
-            )
-            resJson = res.json()
-            for e in resJson["Resources"]:
-                if e["meta"]["resourceType"] == "Group" and e["displayName"] in self.groupL:
-                    logger.info(f"{e['displayName']} is an account-level group, please provide a workspace group")
-                    return False
-            return True
-        except Exception as e:
-            logger.error(f"error in retrieving group objects : {e}")
-
-    def getGroupObjects(self, groupFilterKeeplist) -> tuple:
-        try:
-            groupIdDict = {}
-            groupMembers = {}
-            groupEntitlements = {}
-            groupRoles = {}
-            res = requests.get(
-                f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups?attributes=id", headers=self.headers
-            )
-            resJson = res.json()
-            totalGroups = resJson["totalResults"]
-            pages = totalGroups // 100
-            # normalize case
-            groupFilterKeeplist = [x.casefold() for x in groupFilterKeeplist]
-            logger.info(f"Total groups: {totalGroups}. Retrieving group details in chunks of 100")
-            for i in range(0, pages + 1):
-                logger.info(f"Retrieving the next 100 items from {str(i*100+1)}")
-
-                res = requests.get(
-                    f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups?startIndex={str(i*100+1)}&count=100",
-                    headers=self.headers,
-                )
-                resJson = res.json()
-                # Iterate over workspace groups, extracting useful info to vars above
-                for e in resJson["Resources"]:
-                    if e["displayName"].casefold() not in groupFilterKeeplist:
-                        continue
-
-                    groupIdDict[e["id"]] = e["displayName"]
-
-                    # Get Group Members
-                    members = []
-                    try:
-                        for mem in e["members"]:
-                            members.append(list([mem["display"], mem["value"], mem["$ref"]]))
-                    except KeyError:
-                        pass
-                    groupMembers[e["id"]] = members
-
-                    # Get entitlements
-                    entms = []
-                    try:
-                        for ent in e["entitlements"]:
-                            entms.append(ent["value"])
-                    except:  # noqa: E722
-                        pass
-
-                    if len(entms) > 0:
-                        groupEntitlements[e["id"]] = entms
-
-                    # Get Roles (AWS only)
-                    if self.cloud == "AWS":
-                        entms = []
-                        try:
-                            for ent in e["roles"]:
-                                entms.append(ent["value"])
-                        except:  # noqa: E722
-                            continue
-                        if len(entms) == 0:
-                            continue
-                        groupRoles[e["id"]] = entms
-
-            # Finally assign to self (Now that exception hasn't been thrown)
-            # self.groupIdDict = groupIdDict
-            # self.groupMembers = groupMembers
-            # self.groupEntitlements = groupEntitlements
-            # self.groupRoles = groupRoles
-            # Create reverse of groupIdDict
-            groupNameDict = {}
-            for k, v in groupIdDict.items():
-                groupNameDict[v] = k
-            return groupIdDict, groupMembers, groupEntitlements, groupRoles, groupNameDict
-
-        except Exception as e:
-            logger.error(f"error in retrieving group objects : {e}")
-
-    # get list of users and service principals recursively for groups and nested groups
-    def getRecursiveGroupMember(self, groupM: dict):
-        groupPrincipalList = []
-        for key, value in groupM.items():
-            try:
-                res = requests.get(
-                    f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups/{key}",
-                    headers=self.headers,
-                )
-                resJson = res.json()
-                groupPrincipalList.append(resJson["displayName"])
-            except Exception as e:
-                logger.info(f"error in retrieving group names : {e}")
-        self.groupGroupList.extend(groupPrincipalList)
-        for key, value in groupM.items():
-            userList = [u[1] for u in value if u[2].startswith("User")]
-            spList = [u[1] for u in value if u[2].startswith("ServicePrincipal")]
-
-            groupList = [u[1] for u in value if u[2].startswith("Group")]
-            userPrincipalList = []
-            spPrincipalList 
= [] - try: - groupMembers = {} - for g in groupList: - res = requests.get( - f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups/{g}", - headers=self.headers, - ) - resJson = res.json() - members = [] - try: - for mem in resJson["members"]: - members.append(list([mem["display"], mem["value"], mem["$ref"]])) - except KeyError: - pass - groupMembers[resJson["id"]] = members - except Exception as e: - logger.error(f"error in retrieving nested group members : {e}") - if len(groupMembers) > 0: - self.getRecursiveGroupMember(groupMembers) - - for userid in userList: - try: - res = requests.get( - f"{self.workspace_url}/api/2.0/preview/scim/v2/Users/{userid}", - headers=self.headers, - ) - resJson = res.json() - userPrincipalList.append(resJson["userName"]) - except Exception as e: - logger.error(f"error in retrieving user details : {e}") - - for spid in spList: - try: - res = requests.get( - f"{self.workspace_url}/api/2.0/preview/scim/v2/ServicePrincipals/{spid}", - headers=self.headers, - ) - resJson = res.json() - spPrincipalList.append(resJson["applicationId"]) - except Exception as e: - logger.error(f"error in retrieving SP details : {e}") - self.groupUserList.extend(userPrincipalList) - self.groupSPList.extend(spPrincipalList) - - # getACL[n] family of functions extract the ACL - # from the converted json response into a standard format, filtering by groupL - def getACL(self, acls: dict) -> list: - aclList = [] - for acl in acls: - try: - if acl["all_permissions"][0]["inherited"] is True: - continue - aclList.append( - list( - [ - acl["group_name"], - acl["all_permissions"][0]["permission_level"], - ] - ) - ) - except KeyError: - continue - aclList = [acl for acl in aclList if acl[0] in self.groupL] - return aclList - - def getACL3(self, acls: dict) -> list: - aclList = [] - for acl in acls: - try: - aclList.append( - list( - [ - acl["group_name"], - acl["all_permissions"][0]["permission_level"], - ] - ) - ) - except KeyError: - continue - aclList = [acl for acl in aclList if acl[0] in self.groupL] - return aclList - - def getACL2(self, acls: dict) -> list: - aclList = [] - for acl in acls: - try: - acls_items = [] - for k, v in acl.items(): - acls_items.append(v) - aclList.append(acls_items) - except KeyError: - continue - for acl in aclList: - if acl[0] in self.groupL: - return aclList - return {} - - def getSingleClusterACL(self, clusterId): - if self.verbose: - logger.info(f"[Verbose] Getting cluster permissions for cluster {clusterId}") - resCPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/clusters/{clusterId}", - headers=self.headers, - ) - if resCPerm.status_code == 404: - logger.error(f"Error: cluster ACL not enabled for the cluster: {clusterId}") - return None - resCPermJson = resCPerm.json() - aclList = self.getACL(resCPermJson["access_control_list"]) - if len(aclList) == 0: - return None - return (clusterId, aclList) - - def getAllClustersACL(self) -> dict: - logger.info("Performing cluster inventory...") - try: - resC = requests.get(f"{self.workspace_url}/api/2.0/clusters/list", headers=self.headers) - resCJson = resC.json() - clusterPerm = {} - if len(resCJson) == 0: - return {} - logger.info(f"Scanning permissions of {len(resCJson['clusters'])} clusters.") - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_to_cluster = [ - executor.submit(self.getSingleClusterACL, c["cluster_id"]) for c in resCJson["clusters"] - ] - for future in concurrent.futures.as_completed(future_to_cluster): - result = 
future.result() - if result is not None: - clusterPerm[result[0]] = result[1] - return clusterPerm - except Exception as e: - logger.error(f"error in retrieving cluster permission: {e}") - - def getSingleClusterPolicyACL(self, policyId): - if self.verbose: - logger.info(f"[Verbose] Getting policy permissions for {policyId}") - resCPPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/cluster-policies/{policyId}", - headers=self.headers, - ) - if resCPPerm.status_code == 404: - logger.error(f"Error: cluster policy feature is not enabled for policy: {policyId}") - return None - resCPPermJson = resCPPerm.json() - aclList = self.getACL(resCPPermJson["access_control_list"]) - if len(aclList) == 0: - return None - return (policyId, aclList) - - def getAllClusterPolicyACL(self) -> dict: - logger.info("Performing cluster policy inventory...") - try: - resCP = requests.get( - f"{self.workspace_url}/api/2.0/policies/clusters/list", - headers=self.headers, - ) - resCPJson = resCP.json() - if resCPJson["total_count"] == 0: - logger.info("No cluster policies defined.") - return {} - logger.info(f"Scanning permissions of {len(resCPJson['policies'])} cluster policies.") - clusterPolicyPerm = {} - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_to_cluster = [ - executor.submit(self.getSingleClusterPolicyACL, c["policy_id"]) for c in resCPJson["policies"] - ] - for future in concurrent.futures.as_completed(future_to_cluster): - result = future.result() - if result is not None: - clusterPolicyPerm[result[0]] = result[1] - return clusterPolicyPerm - except Exception as e: - logger.error(f"Error in retrieving cluster policy permission: {e}") - - def getSingleWarehouseACL(self, warehouseId): - if self.verbose: - logger.info(f"[Verbose] Getting warehouse permissions for warehouse {warehouseId}") - resWPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/sql/warehouses/{warehouseId}", - headers=self.headers, - ) - if resWPerm.status_code == 404: - logger.error(f"Error: warehouse ACL not enabled for the warehouse: {warehouseId}") - return None - resWPermJson = resWPerm.json() - aclList = self.getACL(resWPermJson["access_control_list"]) - if len(aclList) == 0: - return None - return (warehouseId, aclList) - - def getAllWarehouseACL(self) -> dict: - logger.info("Performing warehouse inventory ...") - try: - resW = requests.get(f"{self.workspace_url}/api/2.0/sql/warehouses", headers=self.headers) - resWJson = resW.json() - warehousePerm = {} - if len(resWJson) == 0: - return {} - logger.info(f"Scanning permissions of {len(resWJson['warehouses'])} warehouses.") - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_to_warehouse = [ - executor.submit(self.getSingleWarehouseACL, w["id"]) for w in resWJson["warehouses"] - ] - for future in concurrent.futures.as_completed(future_to_warehouse): - result = future.result() - if result is not None: - warehousePerm[result[0]] = result[1] - return warehousePerm - except Exception as e: - logger.error(f"error in retrieving warehouse permission: {e}") - - def getAllDashboardACL(self, verbose=False) -> dict: - logger.info("Performing dashboard inventory ...") - try: - resD = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/dashboards", - headers=self.headers, - ) - resDJson = resD.json() - pages = math.ceil(resDJson["count"] / resDJson["page_size"]) - - dashboardPerm = {} - for pg in range(1, pages + 1): - if self.verbose: - 
logger.info(f"[Verbose] Requesting dashboard page {pg}...") - resD = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/dashboards?page={str(pg)}", - headers=self.headers, - ) - resDJson = resD.json() - results = resDJson["results"] - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_dashboard_perms = { - executor.submit(self.getSingleDashboardACL, dashboard["id"]): dashboard["id"] - for dashboard in results - } - for future in concurrent.futures.as_completed(future_dashboard_perms): - dashboard_id = future_dashboard_perms[future] - try: - result = future.result() - if len(result) > 0: - dashboardPerm[dashboard_id] = result - except Exception as e: - logger.error(f"Error in retrieving dashboard permission for dashboard {dashboard_id}: {e}") - return dashboardPerm - - except Exception as e: - logger.error(f"Error in retrieving dashboard permission: {e}") - raise e - - # this request sometimes fails so we wrap in retry loop - def getSingleDashboardACL(self, dashboardId) -> list: - RETRY_LIMIT = 3 - RETRY_DELAY = 500 / 1000 # 500 ms - retry_count = 0 - while retry_count < RETRY_LIMIT: - if retry_count > 0: - time.sleep(RETRY_DELAY) - if self.verbose: - logger.info(f"[Verbose] Requesting dashboard id {dashboardId}. retry_count={retry_count}") - resDPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/permissions/dashboards/{dashboardId}", - headers=self.headers, - ) - if resDPerm.status_code != 200: - retry_count += 1 - continue - try: - resDPermJson = resDPerm.json() - aclList = resDPermJson["access_control_list"] - dashboard_acl = [] - if len(aclList) > 0: - for acl in aclList: - try: - if acl["group_name"] in self.groupL: - dashboard_acl = aclList - break - except KeyError: - continue - return dashboard_acl - except KeyError: - retry_count += 1 - continue - logger.error(f"ERROR: Retry limit of {RETRY_LIMIT} exceeded requesting dashboard id {dashboardId}") - return [] # if retry limit exceeded, return empty list - - def getAllQueriesACL(self, verbose=False) -> dict: - logger.info("Performing query inventory ...") - try: - resQ = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/queries", - headers=self.headers, - ) - resQJson = resQ.json() - pages = math.ceil(resQJson["count"] / resQJson["page_size"]) - - queryPerm = {} - for pg in range(1, pages + 1): - if self.verbose: - logger.info(f"[Verbose] Requesting query page {pg}...") - resQ = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/queries?page={str(pg)}", - headers=self.headers, - ) - resQJson = resQ.json() - results = resQJson["results"] - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_query_perms = { - executor.submit(self.getSingleQueryACL, query["id"]): query["id"] for query in results - } - for future in concurrent.futures.as_completed(future_query_perms): - query_id = future_query_perms[future] - try: - result = future.result() - if len(result) > 0: - queryPerm[query_id] = result - except Exception as e: - logger.error(f"Error in retrieving query permission for query {query_id}: {e}") - return queryPerm - - except Exception as e: - logger.error(f"Error in retrieving query permission: {e}") - raise e - - # this request sometimes fails so we wrap in retry loop - def getSingleQueryACL(self, queryId) -> list: - RETRY_LIMIT = 3 - RETRY_DELAY = 500 / 1000 # 500 ms - retry_count = 0 - while retry_count < RETRY_LIMIT: - if retry_count > 0: - time.sleep(RETRY_DELAY) - if self.verbose: - logger.info(f"[Verbose] 
Requesting query id {queryId}. retry_count={retry_count}") - resQPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/permissions/queries/{queryId}", - headers=self.headers, - ) - if resQPerm.status_code != 200: - retry_count += 1 - continue - try: - resQPermJson = resQPerm.json() - aclList = resQPermJson["access_control_list"] - query_acl = [] - if len(aclList) > 0: - for acl in aclList: - try: - if acl["group_name"] in self.groupL: - query_acl = aclList - break - except KeyError: - continue - return query_acl - except KeyError: - retry_count += 1 - continue - logger.error(f"ERROR: Retry limit of {RETRY_LIMIT} exceeded requesting query id {queryId}") - return [] # if retry limit exceeded, return empty list - - def getAlertsACL(self) -> dict: - try: - resA = requests.get(f"{self.workspace_url}/api/2.0/preview/sql/alerts", headers=self.headers) - resAJson = resA.json() - alertPerm = {} - for c in resAJson: - alertId = c["id"] - resAPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/sql/permissions/alerts/{alertId}", - headers=self.headers, - ) - if resAPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - resAPermJson = resAPerm.json() - aclList = resAPermJson["access_control_list"] - if len(aclList) == 0: - continue - for acl in aclList: - try: - if acl["group_name"] in self.groupL: - alertPerm[alertId] = aclList - break - except KeyError: - continue - return alertPerm - - except Exception as e: - logger.error(f"error in retrieving alerts permission: {e}") - - def getPasswordACL(self) -> dict: - try: - if self.cloud != "AWS": - return - resP = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/authorization/passwords", - headers=self.headers, - ) - resPJson = resP.json() - if len(resPJson) < 3: - logger.info("No password acls defined.") - return {} - - passwordPerm = {} - passwordPerm["passwords"] = self.getACL(resPJson["access_control_list"]) - return passwordPerm - except Exception as e: - logger.error(f"error in retrieving password permission: {e}") - - def getPoolACL(self) -> dict: - try: - resIP = requests.get( - f"{self.workspace_url}/api/2.0/instance-pools/list", - headers=self.headers, - ) - resIPJson = resIP.json() - if len(resIPJson) == 0: - logger.info("No Instance Pools defined.") - return {} - instancePoolPerm = {} - for c in resIPJson["instance_pools"]: - instancePID = c["instance_pool_id"] - resIPPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/instance-pools/{instancePID}", - headers=self.headers, - ) - if resIPPerm.status_code == 404: - logger.info("feature not enabled for this tier") - continue - resIPPermJson = resIPPerm.json() - aclList = self.getACL(resIPPermJson["access_control_list"]) - if len(aclList) == 0: - continue - instancePoolPerm[instancePID] = aclList - return instancePoolPerm - except Exception as e: - logger.error(f"error in retrieving Instance Pool permission: {e}") - - def getAllJobACL(self) -> dict: - logger.info("Running job ACL inventory ...") - try: - jobPerm = {} - offset = 0 - limit = 25 # 25 is the max before the api complains - while True: - # Query next page - if self.verbose: - logger.info(f"[Verbose] Retrieving jobs page offset={offset}, limit={limit}") - resJob = requests.get( - f"{self.workspace_url}/api/2.1/jobs/list?limit={str(limit)}&offset={str(offset)}", - headers=self.headers, - ) - resJobJson = resJob.json() - - if resJobJson["has_more"] is False and len(resJobJson) == 1: - logger.info("Finished listing jobs") - break - - # Grab job 
IDs and parallel map over them to get all ACLs - jobIDs = [c["job_id"] for c in resJobJson["jobs"]] - with concurrent.futures.ThreadPoolExecutor() as executor: - results = executor.map(self.getSingleJobACL, jobIDs) - for result in results: - if result is not None: - jobPerm[result[0]] = result[1] - # Check for finish? - if not resJobJson["has_more"]: - break - offset += limit - return jobPerm - except Exception as e: - logger.error(f"error in retrieving job permissions: {e}") - - def getSingleJobACL(self, jobID): - try: - resJobPerm = requests.get( - f"{self.workspace_url}/api/2.0/permissions/jobs/{jobID}", - headers=self.headers, - ) - if resJobPerm.status_code == 404: - logger.error("feature not enabled for this tier") - return None - resJobPermJson = resJobPerm.json() - aclList = self.getACL(resJobPermJson["access_control_list"]) - if len(aclList) == 0: - return None - return (jobID, aclList) - except Exception as e: - logger.error(f"error in retrieving permission for job {jobID}: {e}") - return None - - def getExperimentACL(self) -> dict: - try: - nextPageToken = "" - expPerm = {} - while True: - data = {} - data = {"max_results": 100} - if nextPageToken != "": - data = {"page_token": nextPageToken, "max_results": "100"} - - resExp = requests.get( - f"{self.workspace_url}/api/2.0/mlflow/experiments/list", - headers=self.headers, - data=json.dumps(data), - ) - resExpJson = resExp.json() - if len(resExpJson) == 0: - logger.info("No experiments available") - return {} - for c in resExpJson["experiments"]: - expID = c["experiment_id"] - # logger.info(c) - - for k in c["tags"]: - if k["key"] == "mlflow.experimentType": - if k["value"] == "NOTEBOOK": - # logger.info('notebook') - resExpPerm = requests.get( - f"{self.workspace_url}/api/2.0/permissions/notebooks/{expID}", - headers=self.headers, - ) - else: - # logger.info('experiment') - resExpPerm = requests.get( - f"{self.workspace_url}/api/2.0/permissions/experiments/{expID}", - headers=self.headers, - ) - # resExpPerm=requests.get( - # f"{self.workspace_url}/api/2.0/permissions/experiments/{expID}", headers=self.headers - # ) - if resExpPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - resExpPermJson = resExpPerm.json() - if resExpPerm.status_code != 200: - logger.error(f"unable to get permission for experiment {expID}") - continue - aclList = self.getACL(resExpPermJson["access_control_list"]) - if len(aclList) == 0: - continue - - expPerm[expID] = aclList - try: - nextPageToken = resExpJson["next_page_token"] - # break - except KeyError: - break - return expPerm - except Exception as e: - logger.error(f"error in retrieving experiment permission: {e}") - - def getModelACL(self) -> dict: - try: - nextPageToken = "" - while True: - data = {} - data = {"max_results": 20} - if nextPageToken != "": - data = {"page_token": nextPageToken} - resModel = requests.get( - f"{self.workspace_url}/api/2.0/mlflow/registered-models/list", - headers=self.headers, - data=json.dumps(data), - ) - resModelJson = resModel.json() - if len(resModelJson) == 0: - logger.info("No models available") - return {} - modelPerm = {} - for c in resModelJson["registered_models"]: - modelName = c["name"] - param = {"name": modelName} - modIDRes = requests.get( - f"{self.workspace_url}/api/2.0/mlflow/databricks/registered-models/get", - headers=self.headers, - data=json.dumps(param), - ) - modelID = modIDRes.json()["registered_model_databricks"]["id"] - resModelPerm = requests.get( - 
f"{self.workspace_url}/api/2.0/permissions/registered-models/{modelID}", - headers=self.headers, - ) - if resModelPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - resModelPermJson = resModelPerm.json() - aclList = self.getACL(resModelPermJson["access_control_list"]) - if len(aclList) == 0: - continue - modelPerm[modelID] = aclList - try: - nextPageToken = resModelJson["next_page_token"] - # break - except KeyError: - break - return modelPerm - except Exception as e: - logger.error(f"error in retrieving model permission: {e}") - - def getDLTACL(self) -> dict: - try: - nextPageToken = "" - dltPerm = {} - while True: - data = {} - data = {"max_results": 20} - if nextPageToken != "": - data = {"page_token": nextPageToken} - resDlt = requests.get( - f"{self.workspace_url}/api/2.0/pipelines", - headers=self.headers, - data=json.dumps(data), - ) - resDltJson = resDlt.json() - if len(resDltJson) == 0: - logger.info("No dlt pipelines available") - return {} - for c in resDltJson["statuses"]: - dltID = c["pipeline_id"] - resDltPerm = requests.get( - f"{self.workspace_url}/api/2.0/permissions/pipelines/{dltID}", - headers=self.headers, - ) - if resDltPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - resDltPermJson = resDltPerm.json() - aclList = self.getACL(resDltPermJson["access_control_list"]) - if len(aclList) == 0: - continue - dltPerm[dltID] = aclList - try: - nextPageToken = resDltJson["next_page_token"] - # break - except KeyError: - break - - return dltPerm - except Exception as e: - logger.error(f"error in retrieving dlt pipelines permission: {e}") - - def getRecursiveFolderList(self, path: str) -> dict: - logger.info(f"Getting directory structure starting with root path: {path} ...") - - self.folderList.clear() - self.notebookList.clear() - self.fileList.clear() - remaining_dirs = [path] - depth = 0 - while remaining_dirs: - if self.verbose: - logger.info(f"[Verbose] Requesting file list for Depth {depth} Path: {path}") - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - futuresMap = { - executor.submit(self.getSingleFolderList, dir_path, depth): dir_path for dir_path in remaining_dirs - } - for future in concurrent.futures.as_completed(futuresMap): - dir_path = futuresMap[future] - res = future.result() - if res: - dir_path2, folders, notebooks, files = res - if dir_path2 != dir_path: - logger.error(f"ERROR: got WRONG RESULT from future: sent: {dir_path} recieved: {dir_path2}") - remaining_dirs.remove(dir_path2) - # todo: what?? 
-                        else:
-                            self.folderList.update(folders)
-                            self.notebookList.update(notebooks)
-                            self.fileList.update(files)
-                            remaining_dirs.extend(dir_path for dir_path in folders.values())
-                    else:
-                        logger.error(f"ERROR: one of the future results was None: {dir_path}")
-                        remaining_dirs.remove(dir_path)
-            depth = depth + 1
-
-        return (self.folderList, self.notebookList, self.fileList)
-
-    def getSingleFolderList(self, path: str, depth: int) -> tuple:
-        MAX_RETRY = 5
-        RETRY_DELAY = 500 / 1000
-        retry_count = 0
-        lastError = ""
-        while retry_count < MAX_RETRY:
-            # Give some time for the server to recover
-            if retry_count > 0:
-                time.sleep(RETRY_DELAY)
-                # logger.info(f'[ERROR] retrying folder list for folder {path}.')
-            if self.verbose:
-                logger.info(f"[Verbose] Requesting file list for Depth {depth} Retry {retry_count} Path: {path}")
-            retry_count = retry_count + 1
-            try:
-                data = {"path": path}
-                resFolder = requests.get(
-                    f"{self.workspace_url}/api/2.0/workspace/list",
-                    headers=self.headers,
-                    data=json.dumps(data),
-                )
-                if resFolder.status_code == 403:
-                    logger.error(f"[ERROR] status code 403 permission denied to read folder {path}.")
-                    return (path, {}, {}, {})
-                if resFolder.status_code != 200:
-                    logger.error(f"[ERROR] bad status code for folder {path}. code: {resFolder.status_code}")
-                    continue
-                resFolderJson = resFolder.json()
-
-                subFolders = {}
-                notebooks = {}
-                files = {}
-                if len(resFolderJson) == 0:
-                    return (path, subFolders, notebooks, files)
-
-                for c in resFolderJson["objects"]:
-                    if (
-                        c["object_type"] == "DIRECTORY"
-                        and c["path"].startswith("/Shared") is False
-                        and c["path"].endswith("/Trash") is False
-                    ):
-                        subFolders[c["object_id"]] = c["path"]
-                    elif (
-                        c["object_type"] == "NOTEBOOK"
-                        and c["path"].startswith("/Repos") is False
-                        and c["path"].startswith("/Shared") is False
-                    ):
-                        notebooks[c["object_id"]] = c["path"]
-                    elif (
-                        c["object_type"] == "FILE"
-                        and c["path"].startswith("/Repos") is False
-                        and c["path"].startswith("/Shared") is False
-                    ):
-                        files[c["object_id"]] = c["path"]
-                return (path, subFolders, notebooks, files)
-
-            except Exception as e:
-                lastError = e
-                continue
-        logger.error(
-            f"[ERROR] retry limit ({MAX_RETRY}) exceeded while retrieving path {path}. last err: {lastError}."
-        )
-        return (path, {}, {}, {})
-
-    def getFoldersNotebookACL(self, rootPath="/") -> tuple:
-        logger.info("Performing folders and notebook inventory ...")
-        try:
-            # Get folder list
-            self.getRecursiveFolderList(rootPath)
-
-            # Collect folder IDs, ignoring suffix /Trash to avoid useless errors.
- # /Repos and /Shared are ignored at the folder list level - folder_ids = [ - folder_id for folder_id in self.folderList.keys() if not self.folderList[folder_id].endswith("/Trash") - ] - - # Get folder permissions in parallel - folder_results = {} - currentFolderCount = 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - folder_futures = { - executor.submit( - requests.get, - f"{self.workspace_url}/api/2.0/permissions/directories/{folder_id}", - headers=self.headers, - ): folder_id - for folder_id in folder_ids - } - logger.info(f"Awaiting parallel permission requests for {len(folder_futures)} folders ...") - for future in concurrent.futures.as_completed(folder_futures): - folder_id = folder_futures[future] - try: - resFolderPerm = future.result() - currentFolderCount += 1 - if resFolderPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - if resFolderPerm.status_code == 403: - logger.error( - "Error retrieving permission for " - + self.folderList[folder_id] - + " " - + resFolderPerm.json()["message"] - ) - continue - resFolderPermJson = resFolderPerm.json() - try: - aclList = self.getACL(resFolderPermJson["access_control_list"]) - except Exception as e: - logger.error(f"error in retrieving folder details: {e}") - if currentFolderCount % 1000 == 0: - logger.info(f"Completed ACL for {currentFolderCount} folders") - if len(aclList) == 0: - continue - folder_results[folder_id] = aclList - - except Exception as e: - logger.error(f"error in retrieving folder permission: {e}") - - # Get notebook permissions in parallel - notebook_results = {} - currentNotebookCount = 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - notebook_futures = { - executor.submit( - requests.get, - f"{self.workspace_url}/api/2.0/permissions/notebooks/{notebook_id}", - headers=self.headers, - ): notebook_id - for notebook_id in self.notebookList.keys() - } - logger.info(f"Awaiting parallel permission requests for {len(notebook_futures)} notebooks ...") - for future in concurrent.futures.as_completed(notebook_futures): - notebook_id = notebook_futures[future] - try: - resNotebookPerm = future.result() - currentNotebookCount += 1 - if resNotebookPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - if resNotebookPerm.status_code == 403: - logger.error( - "Error retrieving permission for " - + self.notebookList[notebook_id] - + " " - + resNotebookPerm.json()["message"] - ) - continue - resNotebookPermJson = resNotebookPerm.json() - try: - aclList = self.getACL(resNotebookPermJson["access_control_list"]) - except Exception as e: - logger.error(f"error in retrieving notebook details: {e}") - if currentNotebookCount % 1000 == 0: - logger.info(f"Completed ACL for {currentNotebookCount} notebooks") - if len(aclList) == 0: - continue - notebook_results[notebook_id] = aclList - - except Exception as e: - logger.error(f"error in retrieving notebook permission: {e}") - - # Get file permissions in parallel - file_results = {} - currentFileCount = 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - file_futures = { - executor.submit( - requests.get, - f"{self.workspace_url}/api/2.0/permissions/files/{file_id}", - headers=self.headers, - ): file_id - for file_id in self.fileList.keys() - } - logger.info(f"Awaiting parallel permission requests for {len(file_futures)} files ...") - for future in concurrent.futures.as_completed(file_futures): - 
file_id = file_futures[future] - try: - resFilePerm = future.result() - currentFileCount += 1 - if resFilePerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - if resFilePerm.status_code == 403: - logger.error( - "Error retrieving permission for " - + self.fileList[file_id] - + " " - + resFilePerm.json()["message"] - ) - continue - resFilePermJson = resFilePerm.json() - try: - aclList = self.getACL(resFilePermJson["access_control_list"]) - except Exception as e: - logger.error(f"error in retrieving file details: {e}") - if currentFileCount % 1000 == 0: - logger.info(f"Completed ACL for {currentFileCount} notebooks") - if len(aclList) == 0: - continue - file_results[file_id] = aclList - - except Exception as e: - logger.error(f"error in retrieving file permission: {e}") - - return folder_results, notebook_results, file_results - except Exception as e: - logger.error(f"error in retrieving folder and notebook permissions: {e}") - - def getRepoACL(self) -> dict: - try: - nextPageToken = "" - repoPerm = {} - while True: - data = {} - data = {"max_results": 20} - if nextPageToken != "": - data = {"next_page_token": nextPageToken} - resRepo = requests.get( - f"{self.workspace_url}/api/2.0/repos", - headers=self.headers, - data=json.dumps(data), - ) - resRepoJson = resRepo.json() - if len(resRepoJson) == 0: - logger.info("No repos available") - return {} - for c in resRepoJson["repos"]: - repoID = c["id"] - resRepoPerm = requests.get( - f"{self.workspace_url}/api/2.0/permissions/repos/{repoID}", - headers=self.headers, - ) - if resRepoPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - resRepoPermJson = resRepoPerm.json() - aclList = self.getACL3(resRepoPermJson["access_control_list"]) - if len(aclList) == 0: - continue - repoPerm[repoID] = aclList - try: - nextPageToken = resRepoJson["next_page_token"] - except KeyError: - break - - return repoPerm - except Exception as e: - logger.error(f"error in retrieving repos permission: {e}") - - def getTokenACL(self) -> dict: - try: - tokenPerm = {} - resTokenPerm = requests.get( - f"{self.workspace_url}/api/2.0/preview/permissions/authorization/tokens", - headers=self.headers, - ) - if resTokenPerm.status_code == 404: - logger.error("feature not enabled for this tier") - return {} - resTokenPermJson = resTokenPerm.json() - aclList = [] - for acl in resTokenPermJson["access_control_list"]: - try: - if acl["all_permissions"][0]["inherited"] is True: - continue - aclList.append( - list( - [ - acl["group_name"], - acl["all_permissions"][0]["permission_level"], - ] - ) - ) - except KeyError: - continue - aclList = [acl for acl in aclList if acl[0] in self.groupL] - tokenPerm["tokens"] = aclList - return tokenPerm - except Exception as e: - logger.error(f"error in retrieving Token permission: {e}") - return {} - - def getSecretScoppeACL(self) -> dict: - try: - resSScope = requests.get( - f"{self.workspace_url}/api/2.0/secrets/scopes/list", - headers=self.headers, - ) - resSScopeJson = resSScope.json() - if len(resSScopeJson) == 0: - logger.info("No secret scopes defined.") - return {} - - secretScopePerm = {} - for c in resSScopeJson["scopes"]: - scopeName = c["name"] - data = {"scope": scopeName} - resSSPerm = requests.get( - f"{self.workspace_url}/api/2.0/secrets/acls/list/", - headers=self.headers, - data=json.dumps(data), - ) - - if resSSPerm.status_code == 404: - logger.error("feature not enabled for this tier") - continue - if resSSPerm.status_code != 200: - logger.error( - f"Error 
retrieving ACL for Secret Scope: {scopeName}. HTTP Status Code {resSSPerm.status_code}" - ) - continue - - resSSPermJson = resSSPerm.json() - if "items" not in resSSPermJson: - # logger.info( - # f'ACL for Secret Scope {scopeName} missing "items" key. Contents:\n{resSSPermJson}\nSkipping...' - # ) - # This seems to be expected behaviour if there are no ACLs, silently ignore - continue - - aclList = [] - for acl in resSSPermJson["items"]: - try: - if acl["principal"] in self.groupL: - aclList.append(list([acl["principal"], acl["permission"]])) - except KeyError: - continue - if len(aclList) == 0: - continue - secretScopePerm[scopeName] = aclList - - return secretScopePerm - except Exception as e: - logger.error(f"error in retrieving Secret Scope permission: {e}") - - def updateGroupEntitlements(self, groupEntitlements: dict, level: str): - try: - for group_id, etl in groupEntitlements.items(): - entitlementList = [] - if level == "Workspace": - groupId = self.groupWSGNameDict["db-temp-" + self.groupIdDict[group_id]] - else: # Account, aka temp group, must discard db-temp- (8 chars) - groupId = self.accountGroups_lower[self.groupIdDict[group_id][8:].casefold()] - # logger.info(groupId) - for e in etl: - entitlementList.append({"value": e}) - entitlements = { - "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], - "Operations": [{"op": "add", "path": "entitlements", "value": entitlementList}], - } - requests.patch( - f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups/{groupId}", - headers=self.headers, - data=json.dumps(entitlements), - ) - except Exception as e: - logger.error(f"error applying entitiement for group id: {group_id}.") - - def updateGroupRoles(self, level: str): - try: - for group_id, roles in self.groupRoles.items(): - roleList = [] - if level == "Workspace": - groupId = self.groupWSGNameDict["db-temp-" + self.groupIdDict[group_id]] - else: # Account, aka temp group, must discard db-temp- (8 chars) - groupId = self.accountGroups_lower[self.groupIdDict[group_id][8:].casefold()] - for e in roles: - roleList.append({"value": e}) - instanceProfileRoles = { - "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], - "Operations": [{"op": "add", "path": "roles", "value": roleList}], - } - requests.patch( - f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups/{groupId}", - headers=self.headers, - data=json.dumps(instanceProfileRoles), - ) - except Exception as e: - logger.error(f"error applying role for group id: {group_id}.{e}") - - def updateGroupPermission(self, object: str, groupPermission: dict, level: str): - try: - for object_id, aclList in groupPermission.items(): - dataAcl = [] - for acl in aclList: - if level == "Workspace": - gName = "db-temp-" + acl[0] - elif level == "Account": - gName = acl[0][8:] - dataAcl.append({"group_name": gName, "permission_level": acl[1]}) - data = {"access_control_list": dataAcl} - requests.patch( - f"{self.workspace_url}/api/2.0/preview/permissions/{object}/{object_id}", - headers=self.headers, - data=json.dumps(data), - ) - except Exception as e: - logger.error(f"Error setting permission for {object} {object_id}. 
{e} ") - - def updateGroup2Permission(self, object: str, groupPermission: dict, level: str): - try: - for object_id, aclList in groupPermission.items(): - dataAcl = [] - for acl in aclList: - try: - gName = acl["group_name"] - if gName == "ADMIN" and acl["permission_level"] != "CAN_MANAGE": - dataAcl.append({"group_name": gName, "permission_level": "CAN_MANAGE"}) - if level == "Workspace": - if acl["group_name"] in self.WorkspaceGroupNames: - gName = "db-temp-" + acl["group_name"] - dataAcl.append({"group_name": gName, "permission_level": acl["permission_level"]}) - elif level == "Account": - if acl["group_name"] in self.TempGroupNames: - gName = acl["group_name"][8:] - else: - gName = acl["group_name"] - dataAcl.append(acl) - except KeyError: - dataAcl.append(acl) - continue - data = {"access_control_list": dataAcl} - requests.post( - f"{self.workspace_url}/api/2.0/preview/sql/permissions/{object}/{object_id}", - headers=self.headers, - data=json.dumps(data), - ) - except Exception as e: - logger.error(f"Error setting permission for {object} {object_id}. {e} ") - - def updateSecretPermission(self, secretPermission: dict, level: str): - try: - for object_id, aclList in secretPermission.items(): - for acl in aclList: - if level == "Workspace": - gName = "db-temp-" + acl[0] - elif level == "Account": - gName = acl[0][8:] - data = {"scope": object_id, "principal": gName, "permission": acl[1]} - requests.post( - f"{self.workspace_url}/api/2.0/secrets/acls/put", - headers=self.headers, - data=json.dumps(data), - ) - except Exception as e: - logger.error(f"Error setting permission for scope {object_id}. {e} ") - - def runVerboseSql(self, queryString): - if self.verbose: - logger.info(f"[Verbose] SQL: {queryString}") - return self.spark.sql(queryString) - - def getGrantsOnObjects(self, database_name: str, object_type: str, object_key: str): - try: - if object_type in [ - "CATALOG", - "ANY FILE", - "ANONYMOUS FUNCTION", - ]: # without object key - grants_df = ( - self.spark.sql(f"SHOW GRANT ON {object_type}") - .groupBy("ObjectType", "ObjectKey", "Principal") - .agg(collect_set("ActionType").alias("ActionTypes")) - .selectExpr( - "CAST(NULL AS STRING) AS Database", - "Principal", - "ActionTypes", - "ObjectType", - "ObjectKey", - ) - ) - else: - grants_df = ( - self.spark.sql(f"SHOW GRANT ON {object_type} {object_key}") - .filter(col("ObjectType") == f"{object_type}") - .groupBy("ObjectType", "ObjectKey", "Principal") - .agg(collect_set("ActionType").alias("ActionTypes")) - .selectExpr( - f"'{database_name}' AS Database", - "Principal", - "ActionTypes", - "ObjectType", - "ObjectKey", - ) - ) - except Exception as e: - logger.error(f"Error retrieving grants on object {object_key}. {e}") - return - - return grants_df - - def getDBACL(self, db: str): - try: - aclList = [] - dbdf = self.getGrantsOnObjects(db, "DATABASE", db) - aclList += dbdf.collect() - if not self.checkAllDB: - userListCollect = ( - dbdf.filter(col("ObjectType") == "DATABASE") - .filter((array_contains(col("ActionTypes"), "USAGE") | array_contains(col("ActionTypes"), "OWN"))) - .select(col("Principal")) - .collect() - ) - userList = [p.Principal for p in userListCollect] - userList = list(set(userList)) - if not self.checkPrincipalInGroupOrMember(userList, db): - # logger.info( - # f'selected groups or members of the groups have no USAGE or OWN permission on database level.' - # 'Skipping object level permission check for database {db}.' 
- # ) - return [] - - tables = self.runVerboseSql("show tables in spark_catalog.{}".format(db)).filter( - col("isTemporary") == False # noqa: E712 - ) - for table in tables.collect(): - try: - tbldf = self.getGrantsOnObjects(db, "TABLE", f"`{table.database}`.`{table.tableName}`") - aclList += tbldf.collect() - except Exception as e: - logger.error(f"error retrieving acl for table {table.tableName}. {e}") - - functions = self.runVerboseSql("show functions in {}".format(db)).filter( - col("function").startswith("spark_catalog." + db + ".") - ) - for function in functions.collect(): - try: - funcdf = self.getGrantsOnObjects(db, "FUNCTION", f"{function.function}") - aclList += funcdf.collect() - except Exception as e: - logger.error(f"error retrieving acl for function {function.function}. {e}") - # filter for required groups - return aclList - - except Exception as e: - logger.info(f"Error retrieving ACL for database {db}. {e}") - - # check principal given usage permission is group or a member of the group (user or sp) - def checkPrincipalInGroupOrMember(self, principalList: str, name: str) -> bool: - for p in principalList: - if p in self.groupGroupList: - logger.info(f"Group {p} is given USAGE or OWN permission for {name}.") - return True - for p in principalList: - if p in self.groupUserList: - logger.info(f"User {p} is given USAGE or OWN permission for {name}.") - return True - for p in principalList: - if p in self.groupSPList: - logger.info(f"SP {p} is given USAGE or OWN permission for {name}.") - return True - return False - - def getTableACLs(self) -> list: - self.groupUserList = [] - self.groupSPList = [] - self.getRecursiveGroupMember(self.groupMembers) - self.groupUserList = list(set(self.groupUserList)) - self.groupSPList = list(set(self.groupSPList)) - # ANONYMOUS FUNCTION - common_df = self.getGrantsOnObjects(None, "ANONYMOUS FUNCTION", None) - # ANY FILE - common_df = common_df.unionAll(self.getGrantsOnObjects(None, "ANY FILE", None)) - # CATALOG - common_df = common_df.unionAll(self.getGrantsOnObjects(None, "CATALOG", None)) - aclList = [] - aclList = common_df.collect() - # check if any group is given permission at catalog level - userListCollect = ( - common_df.filter(col("ObjectType") == "CATALOG$") - .filter(array_contains(col("ActionTypes"), "USAGE")) - .select(col("Principal")) - .collect() - ) - userList = [p.Principal for p in userListCollect] - userList = list(set(userList)) - if self.checkPrincipalInGroupOrMember(userList, "CATALOG"): - logger.info( - "some groups or members of the group given " - "permission at catalog level, running permission for all databases" - ) - self.checkAllDB = True - database_names = [] - dbs = self.spark.sql("show databases").collect() - len(database_names) - for db in dbs: - database_names.append(db.databaseName) - logger.info(f"Got {len(database_names)} dbs to query") - # database_names=['aaron_binns','hsdb'] - currentCount = 0 - try: - # aclList = [] - aclFinalList = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor: - future_db = [executor.submit(self.getDBACL, f"`{databaseName}`") for databaseName in database_names] - for future in concurrent.futures.as_completed(future_db): - result = future.result() - if result is not None: - aclList += result - currentCount += 1 - if currentCount % 100 == 0: - logger.info(f"Completed ACL for {currentCount} databases") - aclFinalList = [acl for acl in aclList if acl.Principal in self.groupL] - except Exception as e: - logger.error(f"Error retrieving table acl 
object permission {e}") - return aclFinalList - - def generate_table_acls_command(self, action_types, object_type, object_key, groupName): - lines = [] - grant_privs = [x for x in action_types if not x.startswith("DENIED_") and x != "OWN"] - deny_privs = [x[len("DENIED_") :] for x in action_types if x.startswith("DENIED_") and x != "OWN"] - if grant_privs: - lines.append(f"GRANT {', '.join(grant_privs)} ON {object_type} {object_key} TO `{groupName}`;") - if deny_privs: - lines.append(f"DENY {', '.join(deny_privs)} ON {object_type} {object_key} TO `{groupName}`;") - if "OWN" in action_types: - lines.append(f"ALTER {object_type} {object_key} OWNER TO `{groupName}`;") - return lines - - def updateDataObjectsPermission(self, aclList: List, level: str): - try: - lines = [] - for acl in aclList: - # if acl.ObjectType!="DATABASE" and acl.ActionType=="USAGE": continue - if level == "Workspace": - gName = "db-temp-" + acl.Principal - elif level == "Account": - gName = acl.Principal[8:] - if acl.ObjectType == "ANONYMOUS_FUNCTION": - lines.extend(self.generate_table_acls_command(acl.ActionTypes, "ANONYMOUS FUNCTION", "", gName)) - elif acl.ObjectType == "ANY_FILE": - lines.extend(self.generate_table_acls_command(acl.ActionTypes, "ANY FILE", "", gName)) - elif acl.ObjectType == "CATALOG$": - lines.extend(self.generate_table_acls_command(acl.ActionTypes, "CATALOG", "", gName)) - elif acl.ObjectType in ["DATABASE", "TABLE"]: - # DATABASE, TABLE, VIEW (view's seem to show up as tables) - lines.extend( - self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName) - ) - # lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName)) - for aclQuery in lines: - # logger.info(aclQuery) - self.runVerboseSql(aclQuery) - except Exception as e: - logger.error(f"Error setting permission, {e} ") - - def setGroupListForMode(self, mode: str): - logger.info(f"Retrieving group metadata for mode: {mode}") - if mode == "Workspace": - self.groupL = self.WorkspaceGroupNames - ( - self.groupIdDict, - self.groupMembers, - self.groupEntitlements, - self.groupRoles, - self.groupNameDict, - ) = self.getGroupObjects(self.groupL) - elif mode == "Account": - self.groupL = self.TempGroupNames - ( - self.groupIdDict, - self.groupMembers, - self.groupEntitlements, - self.groupRoles, - self.groupNameDict, - ) = self.getGroupObjects(self.groupL) - else: - raise ValueError(f"mode {mode} not supported. 
Valid values are 'Workspace' and 'Account'")
-
-    def clearInventoryCache(self):
-        self.lastInventoryRun = None
-
-    def isObjectInventoryPresent(self, mode: str, objectType: str, tableACL: bool = False):
-        try:
-            groupType = ""
-            if mode == "Workspace":
-                # print(f'Saving data for workspace groups for {objectType} in {self.inventoryTableName} table.')
-                groupType = "WorkspaceLocal"
-            else:
-                # print(f'Saving data for workspace temp groups for {objectType} in {self.inventoryTableName} table.')
-                groupType = "WorkspaceTemp"
-            if tableACL:
-                checkSQL = f"select count(*) from {self.inventoryTableName}TableACL where groupType='{groupType}' "
-                # logger.info(checkSQL)
-            else:
-                checkSQL = f"select count(*) from {self.inventoryTableName} where groupType='{groupType}' \
-                and WorkspaceObject='{objectType}'"
-            if self.spark.sql(checkSQL).collect()[0][0] > 0:
-                return True
-            else:
-                return False
-
-        except Exception as e:
-            logger.error(f"Error doing inventory check : {e}")
-
-    # retrieves the permission for an object type from the delta table;
-    # separate logic for workspace objects and table ACLs if used
-    def getObjectInventory(self, mode: str, objectType: str, tableACL: bool = False):
-        try:
-            groupType = ""
-            if mode == "Workspace":
-                groupType = "WorkspaceLocal"
-            else:
-                groupType = "WorkspaceTemp"
-            if tableACL:
-                checkPermSQL = f"select Database, Principal, ActionTypes, ObjectType, ObjectKey from \
-                {self.inventoryTableName}TableACL where groupType='{groupType}'"
-                perm = (
-                    self.spark.sql(checkPermSQL)
-                    .withColumn("ActionTypes", regexp_replace(col("ActionTypes"), "\[", ""))
-                    .withColumn("ActionTypes", regexp_replace(col("ActionTypes"), "\]", ""))
-                    .withColumn("ActionTypes", split(col("ActionTypes"), ","))
-                    .collect()
-                )
-                return perm
-            else:
-                checkPermSQL = f"select Permission from {self.inventoryTableName} where groupType='{groupType}' and \
-                WorkspaceObject='{objectType}'"
-                perm = self.spark.sql(checkPermSQL).collect()[0][0]
-                return perm
-        except Exception as e:
-            logger.error(f"Error retrieving inventory for {objectType} : {e}")
-
-    # performs inventory of the objects
-    # if objectType is All, do inventory for all
-    # if objectType is specific, do inventory for that object alone
-    # if force is set to True, do a fresh inventory irrespective of whether data is present in the inventory table
-    # if force is set to False and data is present in the inventory table, retrieve the data
-    def performInventory(self, mode: str, force: bool = False, objectType: str = "All"):
-        # check if all this should already be cached
-        # if self.lastInventoryRun == mode:
-        #     self.setGroupListForMode(mode)
-        #     print(f'Skipping inventory for mode = {mode} since already performed.')
-        #     return
-
-        logger.info(
-            f"Performing inventory of workspace object permissions. Filtering results by group list for mode: {mode}."
-    # performs an inventory of the workspace objects
-    # if objectType is "All", inventory every object type
-    # if objectType is specific, inventory that object alone
-    # if force is True, do a fresh inventory irrespective of data already present in the inventory table
-    # if force is False and data is present in the inventory table, retrieve that data instead
-    def performInventory(self, mode: str, force: bool = False, objectType: str = "All"):
-        # check if all this should already be cached
-        # if self.lastInventoryRun == mode:
-        #    self.setGroupListForMode(mode)
-        #    print(f'Skipping inventory for mode = {mode} since already performed.')
-        #    return
-
-        logger.info(
-            f"Performing inventory of workspace object permissions. Filtering results by group list for mode: {mode}."
-        )
-        try:
-            # check that a valid objectType was passed; otherwise log the valid values and return
-            validTypes = [
-                "All", "Group", "Password", "Cluster", "ClusterPolicy", "Warehouse", "Dashboard", "Query",
-                "Job", "Folder", "TableACL", "Alert", "Pool", "Experiment", "Model", "DLT", "Repo",
-                "Token", "Secret",
-            ]
-            if objectType not in validTypes:
-                logger.info(f"Enter a valid object type from: {', '.join(validTypes)}")
-                return
-            if objectType in ("Group", "All"):
-                logger.info("performing groups inventory")
-                if force or not self.isObjectInventoryPresent(mode, "GroupDict"):
-                    logger.info("performing fresh inventory")
-                    self.setGroupListForMode(mode)
-                    self.persistInventory(mode, "GroupDict", self.groupIdDict)
-                    self.persistInventory(mode, "GroupMember", self.groupMembers)
-                    self.persistInventory(mode, "GroupEntitlement", self.groupEntitlements)
-                    self.persistInventory(mode, "GroupRole", self.groupRoles)
-                    self.persistInventory(mode, "GroupName", self.groupNameDict)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.groupIdDict = self.getObjectInventory(mode, "GroupDict")
-                    self.groupMembers = self.fixList(self.getObjectInventory(mode, "GroupMember"))
-                    self.groupEntitlements = self.fixListv2(self.getObjectInventory(mode, "GroupEntitlement"))
-                    self.groupRoles = self.fixListv2(self.getObjectInventory(mode, "GroupRole"))
-                    self.groupNameDict = self.getObjectInventory(mode, "GroupName")
-            if self.cloud == "AWS":
-                if objectType in ("Password", "All"):
-                    logger.info("performing password inventory")
-                    if force or not self.isObjectInventoryPresent(mode, "Password"):
-                        logger.info("performing fresh inventory")
-                        self.passwordPerm = self.getPasswordACL()
-                        self.persistInventory(mode, "Password", self.passwordPerm)
-                    else:
-                        logger.info("retrieving inventory from table")
-                        self.passwordPerm = self.fixList(self.getObjectInventory(mode, "Password"))
-            # These ACL fetches are parallel
-            if objectType in ("Cluster", "All"):
-                logger.info("performing cluster inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Cluster"):
-                    logger.info("performing fresh inventory")
-                    self.clusterPerm = self.getAllClustersACL()
-                    self.persistInventory(mode, "Cluster", self.clusterPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.clusterPerm = self.fixList(self.getObjectInventory(mode, "Cluster"))
-            if objectType in ("ClusterPolicy", "All"):
-                logger.info("performing cluster policy inventory")
-                if force or not self.isObjectInventoryPresent(mode, "ClusterPolicy"):
-                    logger.info("performing fresh inventory")
-                    self.clusterPolicyPerm = self.getAllClusterPolicyACL()
-                    self.persistInventory(mode, "ClusterPolicy", self.clusterPolicyPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.clusterPolicyPerm = self.fixList(self.getObjectInventory(mode, "ClusterPolicy"))
-            if objectType in ("Warehouse", "All"):
-                logger.info("performing warehouse inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Warehouse"):
-                    logger.info("performing fresh inventory")
-                    self.warehousePerm = self.getAllWarehouseACL()
-                    self.persistInventory(mode, "Warehouse", self.warehousePerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.warehousePerm = self.fixList(self.getObjectInventory(mode, "Warehouse"))
-            if objectType in ("Dashboard", "All"):  # takes ~5 mins
-                logger.info("performing dashboards inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Dashboard"):
-                    logger.info("performing fresh inventory")
-                    self.dashboardPerm = self.getAllDashboardACL()
-                    self.persistInventory(mode, "Dashboard", self.dashboardPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.dashboardPerm = self.fixListv3(self.getObjectInventory(mode, "Dashboard"))
-            if objectType in ("Query", "All"):
-                logger.info("performing query inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Query"):
-                    logger.info("performing fresh inventory")
-                    self.queryPerm = self.getAllQueriesACL()
-                    self.persistInventory(mode, "Query", self.queryPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.queryPerm = self.fixListv3(self.getObjectInventory(mode, "Query"))
-            if objectType in ("Job", "All"):  # takes ~33 mins
-                logger.info("performing job inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Job"):
-                    logger.info("performing fresh inventory")
-                    self.jobPerm = self.getAllJobACL()
-                    self.persistInventory(mode, "Job", self.jobPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.jobPerm = self.fixList(self.getObjectInventory(mode, "Job"))
-            if objectType in ("Folder", "All"):
-                logger.info("performing folders, notebooks and files inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Folder"):
-                    logger.info("performing fresh inventory")
-                    self.folderPerm, self.notebookPerm, self.filePerm = self.getFoldersNotebookACL()
-                    self.persistInventory(mode, "Folder", self.folderPerm)
-                    self.persistInventory(mode, "Notebook", self.notebookPerm)
-                    self.persistInventory(mode, "File", self.filePerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.folderPerm = self.fixList(self.getObjectInventory(mode, "Folder"))
-                    self.notebookPerm = self.fixList(self.getObjectInventory(mode, "Notebook"))
-                    self.filePerm = self.fixList(self.getObjectInventory(mode, "File"))
-
-            # These have yet to be parallelized:
-            if self.checkTableACL is True:
-                if objectType in ("TableACL", "All"):
-                    logger.info("performing Table ACL object inventory")
-                    if force or not self.isObjectInventoryPresent(mode, "TableACL", True):
-                        logger.info("performing fresh inventory")
-                        self.dataObjectsPerm = self.getTableACLs()
-                        self.persistInventory(mode, "TableACL", self.dataObjectsPerm, True)
-                    else:
-                        logger.info("retrieving inventory from table")
-                        self.dataObjectsPerm = self.getObjectInventory(mode, "TableACL", True)
-            if objectType in ("Alert", "All"):
-                logger.info("performing alerts inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Alert"):
-                    logger.info("performing fresh inventory")
-                    self.alertPerm = self.getAlertsACL()
-                    self.persistInventory(mode, "Alert", self.alertPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.alertPerm = self.fixListv3(self.getObjectInventory(mode, "Alert"))
-            if objectType in ("Pool", "All"):
-                logger.info("performing instance pools inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Pool"):
-                    logger.info("performing fresh inventory")
-                    self.instancePoolPerm = self.getPoolACL()
-                    self.persistInventory(mode, "Pool", self.instancePoolPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.instancePoolPerm = self.fixList(self.getObjectInventory(mode, "Pool"))
-            if objectType in ("Experiment", "All"):
-                logger.info("performing experiments inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Experiment"):
-                    logger.info("performing fresh inventory")
-                    self.expPerm = self.getExperimentACL()
-                    self.persistInventory(mode, "Experiment", self.expPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.expPerm = self.fixList(self.getObjectInventory(mode, "Experiment"))
-            if objectType in ("Model", "All"):
-                logger.info("performing registered models inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Model"):
-                    logger.info("performing fresh inventory")
-                    self.modelPerm = self.getModelACL()
-                    self.persistInventory(mode, "Model", self.modelPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.modelPerm = self.fixList(self.getObjectInventory(mode, "Model"))
-            if objectType in ("DLT", "All"):
-                logger.info("performing DLT inventory")
-                if force or not self.isObjectInventoryPresent(mode, "DLT"):
-                    logger.info("performing fresh inventory")
-                    self.dltPerm = self.getDLTACL()
-                    self.persistInventory(mode, "DLT", self.dltPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.dltPerm = self.fixList(self.getObjectInventory(mode, "DLT"))
-            if objectType in ("Repo", "All"):
-                logger.info("performing repos inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Repo"):
-                    logger.info("performing fresh inventory")
-                    self.repoPerm = self.getRepoACL()
-                    self.persistInventory(mode, "Repo", self.repoPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.repoPerm = self.fixList(self.getObjectInventory(mode, "Repo"))
-            if objectType in ("Token", "All"):
-                logger.info("performing token inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Token"):
-                    logger.info("performing fresh inventory")
-                    self.tokenPerm = self.getTokenACL()
-                    self.persistInventory(mode, "Token", self.tokenPerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.tokenPerm = self.fixList(self.getObjectInventory(mode, "Token"))
-            if objectType in ("Secret", "All"):
-                logger.info("performing secret scope inventory")
-                if force or not self.isObjectInventoryPresent(mode, "Secret"):
-                    logger.info("performing fresh inventory")
-                    self.secretScopePerm = self.getSecretScoppeACL()
-                    self.persistInventory(mode, "Secret", self.secretScopePerm)
-                else:
-                    logger.info("retrieving inventory from table")
-                    self.secretScopePerm = self.fixList(self.getObjectInventory(mode, "Secret"))
-
-            self.lastInventoryRun = mode
-        except Exception as e:
-            logger.error(f"Error creating group inventory, {e}")
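
Every block in `performInventory` repeats the same cache-or-refresh decision; a condensed sketch of that pattern, where the callables stand in for the real fetch/persist/restore methods:

```python
# Condensed sketch of the per-object-type pattern used throughout performInventory.
def load_or_refresh(mode, object_type, fetch, is_present, persist, restore, force=False):
    if force or not is_present(mode, object_type):
        perm = fetch()                    # fresh inventory via the workspace APIs
        persist(mode, object_type, perm)  # snapshot into the Delta inventory table
        return perm
    return restore(mode, object_type)     # reuse the stored snapshot
```
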
-    # parses a stringified list of [key, value] pairs back into nested lists,
-    # e.g. "[[a, b], [c, d]]" -> [["a", "b"], ["c", "d"]]
-    def fixList(self, perm: dict) -> dict:
-        for mem in perm:
-            if perm[mem] == "[]":
-                perm[mem] = []
-            else:
-                perm[mem] = [x.split(", ") for x in perm[mem].replace("[[", "").replace("]]", "").split("], [")]
-        return perm
-
-    # parses a stringified flat list back into a list of strings,
-    # e.g. "[a, b]" -> ["a", "b"]
-    def fixListv2(self, perm: dict) -> dict:
-        for mem in perm:
-            if perm[mem] == "[]":
-                perm[mem] = []
-            else:
-                perm[mem] = perm[mem].replace("[", "").replace("]", "").split(", ")
-        return perm
-
-    # parses a stringified list of two-key dicts back into dicts,
-    # e.g. "[{k1=v1, k2=v2}]" -> [{"k1": "v1", "k2": "v2"}]
-    def fixListv3(self, perm: dict) -> dict:
-        for mem in perm:
-            if perm[mem] == "[]":
-                perm[mem] = []
-            else:
-                ss = [
-                    x.replace("=", ",").split(",") for x in perm[mem].replace("[{", "").replace("}]", "").split("}, {")
-                ]
-                perm[mem] = [{s[0].strip(): s[1].strip(), s[2].strip(): s[3].strip()} for s in ss]
-        return perm
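
A quick round trip of the flat-list case that `fixListv2` reverses; the sample entitlement value is hypothetical:

```python
# The stored rendering has no quotes around elements, so str.split(", ") restores the list.
stored = "[allow-cluster-create, databricks-sql-access]"   # hypothetical stored value
parsed = stored.replace("[", "").replace("]", "").split(", ")
print(parsed)  # ['allow-cluster-create', 'databricks-sql-access']
```
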
-    # prints one titled two-column section of the inventory
-    def printPermSection(self, title: str, keyHeader: str, valueHeader: str, perm: dict):
-        logger.info(f"{title}:")
-        logger.info("{:<20} {:<100}".format(keyHeader, valueHeader))
-        for key, value in perm.items():
-            logger.info("{:<20} {:<100}".format(key, str(value)))
-
-    def printInventory(self, printMembers: bool = False):
-        logger.info("Displaying Inventory Results -- ACLs of selected groups:")
-        logger.info("Group List:")
-        logger.info("{:<20} {:<10}".format("Group ID", "Group Name"))
-        for key, value in self.groupIdDict.items():
-            logger.info("{:<20} {:<10}".format(key, value))
-
-        if printMembers:
-            self.printPermSection("Group Members", "Group ID", "Group Member", self.groupMembers)
-
-        self.printPermSection("Group Entitlements", "Group ID", "Group Entitlements", self.groupEntitlements)
-        if self.cloud == "AWS":
-            self.printPermSection("Group Roles", "Group ID", "Group Roles", self.groupRoles)
-            self.printPermSection("Group Passwords", "Password", "Group Names", self.passwordPerm)
-        self.printPermSection("Cluster Permission", "Cluster ID", "Group Permission", self.clusterPerm)
-        self.printPermSection("Cluster Policy Permission", "Cluster Policy ID", "Group Permission", self.clusterPolicyPerm)
-        self.printPermSection("Warehouse Permission", "SQL Warehouse ID", "Group Permission", self.warehousePerm)
-        self.printPermSection("Dashboard Permission", "Dashboard ID", "Group Permission", self.dashboardPerm)
-        self.printPermSection("Query Permission", "Query ID", "Group Permission", self.queryPerm)
-        self.printPermSection("Alerts Permission", "Alerts ID", "Group Permission", self.alertPerm)
-        self.printPermSection("Instance Pool Permission", "InstancePool ID", "Group Permission", self.instancePoolPerm)
-        self.printPermSection("Jobs Permission", "Job ID", "Group Permission", self.jobPerm)
-        self.printPermSection("Experiments Permission", "Experiment ID", "Group Permission", self.expPerm)
-        self.printPermSection("Models Permission", "Model ID", "Group Permission", self.modelPerm)
-        self.printPermSection("Delta Live Tables Permission", "Pipeline ID", "Group Permission", self.dltPerm)
-        self.printPermSection("Repos Permission", "Repo ID", "Group Permission", self.repoPerm)
-        self.printPermSection("Tokens Permission", "Token ID", "Group Permission", self.tokenPerm)
-        self.printPermSection("Secret Scopes Permission", "SecretScope ID", "Group Permission", self.secretScopePerm)
-        self.printPermSection("Folder Permission", "Folder ID", "Group Permission", self.folderPerm)
-        self.printPermSection("Notebook Permission", "Notebook ID", "Group Permission", self.notebookPerm)
-        self.printPermSection("File Permission", "File ID", "Group Permission", self.filePerm)
-
-        if self.checkTableACL is True:
-            logger.info("TableACL Permission:")
-            for item in self.dataObjectsPerm:
-                logger.info(item)
-
-    def dryRun(self, mode: str = "Workspace", printMembers: bool = False):
-        self.performInventory(mode)
-        self.printInventory(printMembers)
-
-    def applyGroupPermission(self, level: str):
-        try:
-            logger.info("applying group entitlement permissions")
-            self.updateGroupEntitlements(self.groupEntitlements, level)
-            logger.info("applying cluster permissions")
-            self.updateGroupPermission("clusters", self.clusterPerm, level)
-            logger.info("applying cluster policy permissions")
-            self.updateGroupPermission("cluster-policies", self.clusterPolicyPerm, level)
-            logger.info("applying warehouse permissions")
-            self.updateGroupPermission("sql/warehouses", self.warehousePerm, level)
-            logger.info("applying instance pool permissions")
-            self.updateGroupPermission("instance-pools", self.instancePoolPerm, level)
-            logger.info("applying jobs permissions")
-            self.updateGroupPermission("jobs", self.jobPerm, level)
-            logger.info("applying experiments permissions")
-            self.updateGroupPermission("experiments", self.expPerm, level)
-            logger.info("applying model permissions")
-            self.updateGroupPermission("registered-models", self.modelPerm, level)
-            logger.info("applying DLT permissions")
-            self.updateGroupPermission("pipelines", self.dltPerm, level)
-            logger.info("applying folders permissions")
-            self.updateGroupPermission("directories", self.folderPerm, level)
-            logger.info("applying notebooks permissions")
-            self.updateGroupPermission("notebooks", self.notebookPerm, level)
-            logger.info("applying files permissions")
-            self.updateGroupPermission("files", self.filePerm, level)
-            logger.info("applying repos permissions")
-            self.updateGroupPermission("repos", self.repoPerm, level)
-            logger.info("applying token permissions")
-            self.updateGroupPermission("authorization", self.tokenPerm, level)
-            logger.info("applying secret scope permissions")
-            self.updateSecretPermission(self.secretScopePerm, level)
-            logger.info("applying dashboard permissions")
-            self.updateGroup2Permission("dashboards", self.dashboardPerm, level)
-            logger.info("applying query permissions")
-            self.updateGroup2Permission("queries", self.queryPerm, level)
-            logger.info("applying alerts permissions")
-            self.updateGroup2Permission("alerts", self.alertPerm, level)
-            if self.cloud == "AWS":
-                logger.info("applying password permissions")
-                self.updateGroupPermission("authorization", self.passwordPerm, level)
-                logger.info("applying instance profile permissions")
-                self.updateGroupRoles(level)
-            if self.checkTableACL is True:
-                logger.info("applying table acl object permissions")
-                self.updateDataObjectsPermission(self.dataObjectsPerm, level)
-        except Exception as e:
-            logger.error(f"Error applying group permission, {e}")
-    def validateTempWSGroup(self) -> int:
-        try:
-            res = requests.get(
-                f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups",
-                headers=self.headers,
-            )
-            resJson = res.json()
-            WSGGroup = [e["displayName"] for e in resJson["Resources"] if e["meta"]["resourceType"] == "WorkspaceGroup"]
-            for g in self.groupL:
-                if "db-temp-" + g not in WSGGroup:
-                    logger.info(f"temp workspace group db-temp-{g} not present, please check")
-                    return 0
-            return 1
-        except Exception as e:
-            logger.error(f"error validating WS group objects : {e}")
-            return 0
-
-    def bulkTryDelete(self, deleteList):
-        for g in deleteList:
-            gID = self.groupNameDict[g]
-            logger.info(f"Attempting to delete group [{gID}] - {g}")
-            try:
-                requests.delete(
-                    f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups/{gID}",
-                    headers=self.headers,
-                )
-            except Exception as e:
-                logger.error(f"ERROR - Failed to delete group [{gID}] - {g}. ErrorMessage: {e}")
-            else:
-                logger.info(f"SUCCESS - Deleted group [{gID}] - {g}")
-    def persistInventory(self, mode: str, objectType: str, objectPerm: dict, tableACL: bool = False):
-        try:
-            if mode == "Workspace":
-                logger.info(f"Saving data for workspace groups for {objectType} in {self.inventoryTableName} table.")
-                groupType = "WorkspaceLocal"
-            else:
-                logger.info(
-                    f"Saving data for workspace temp groups for {objectType} in {self.inventoryTableName} table."
-                )
-                groupType = "WorkspaceTemp"
-
-            if tableACL:
-                deleteSQL = f"delete from {self.inventoryTableName}TableACL where groupType='{groupType}';"
-                self.spark.sql(deleteSQL)
-                tableACLCol = StructType(
-                    [
-                        StructField("Database", StringType(), True),
-                        StructField("Principal", StringType(), True),
-                        StructField("ActionTypes", StringType(), True),
-                        StructField("ObjectType", StringType(), True),
-                        StructField("ObjectKey", StringType(), True),
-                    ]
-                )
-                tableACLDF = self.spark.createDataFrame(data=objectPerm, schema=tableACLCol).withColumn(
-                    "GroupType", lit(groupType)
-                )
-                tableACLDF.write.format("delta").mode("append").saveAsTable(self.inventoryTableName + "TableACL")
-                logger.info(f"Saved data in {self.inventoryTableName}TableACL table for {objectType}")
-            else:
-                deleteSQL = f"delete from {self.inventoryTableName} where GroupType='{groupType}' and \
-                WorkspaceObject='{objectType}';"
-                self.spark.sql(deleteSQL)
-                persistList = [[groupType, objectType, objectPerm]]
-                persistColumns = StructType(
-                    [
-                        StructField("GroupType", StringType(), True),
-                        StructField("WorkspaceObject", StringType(), True),
-                        StructField("Permission", MapType(StringType(), StringType()), True),
-                    ]
-                )
-                persistDF = self.spark.createDataFrame(data=persistList, schema=persistColumns)
-                persistDF.write.format("delta").mode("append").saveAsTable(self.inventoryTableName)
-                logger.info(f"Saved data in {self.inventoryTableName} table for {objectType}")
-        except Exception as e:
-            logger.error(f"Error creating delta table to store inventory : {e}")
-
-    def deleteWorkspaceLocalGroups(self):
-        try:
-            self.setGroupListForMode("Workspace")
-            if self.validateTempWSGroup() == 0:
-                logger.info("temp group validation failed, aborting deletion")
-                return
-            self.bulkTryDelete(self.groupL)
-        except Exception as e:
-            logger.error(f"Error deleting groups : {e}")
-
-    def deleteTempGroups(self):
-        self.setGroupListForMode("Account")
-        try:
-            self.bulkTryDelete(self.groupL)
-        except Exception as e:
-            logger.error(f"Error deleting temp groups : {e}")
-
-    def createBackupGroup(self):
-        try:
-            if self.validateWSGroup() == 0:
-                return
-            self.performInventory("Workspace")
-            self.printInventory()
-
-            for g in self.groupL:
-                memberList = []
-                if self.groupNameDict[g] in self.groupMembers:
-                    for mem in self.groupMembers[self.groupNameDict[g]]:
-                        memberList.append({"value": mem[1]})
-                data = {
-                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
-                    "displayName": "db-temp-" + g,
-                    "members": memberList,
-                }
-                res = requests.post(
-                    f"{self.workspace_url}/api/2.0/preview/scim/v2/Groups",
-                    headers=self.headers,
-                    data=json.dumps(data),
-                )
-                if res.status_code == 409:
-                    logger.info(f'group with name "db-temp-{g}" already present, please delete it and try again.')
-                    return
-                self.groupWSGIdDict[res.json()["id"]] = "db-temp-" + g
-                self.groupWSGNameDict["db-temp-" + g] = res.json()["id"]
-            self.applyGroupPermission("Workspace")
-            # self.persistInventory("Workspace")
-        except Exception as e:
-            logger.error(f"Error creating backup groups, {e}")
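
The backup-group naming used in `createBackupGroup` is what makes the `Principal[8:]` slice in `updateDataObjectsPermission` work; a short snippet makes the invariant explicit (the group name is hypothetical):

```python
PREFIX = "db-temp-"            # prefix added to every backup group
assert len(PREFIX) == 8        # why Account mode strips Principal[8:]
backup = PREFIX + "data-eng"   # hypothetical workspace group -> 'db-temp-data-eng'
assert backup[8:] == "data-eng"
```
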
-    def validateAccountGroup(self):
-        try:
-            res = requests.get(
-                f"{self.workspace_url}/api/2.0/account/scim/v2/Groups",
-                headers=self.headers,
-            )
-            for grp in res.json()["Resources"]:
-                self.accountGroups_lower[grp["displayName"].casefold()] = grp["id"]
-            for g in self.WorkspaceGroupNames:
-                if g.casefold() not in self.accountGroups_lower:
-                    logger.info(f"group {g} is not present at the account level, please add the correct group and try again")
-                    return 1
-            return 0
-        except Exception as e:
-            logger.error(f"Error validating account level group, {e}")
-
-    def createAccountGroup(self):
-        try:
-            if self.validateAccountGroup() == 1:
-                return
-            if self.validateTempWSGroup() == 0:
-                return
-            self.performInventory("Account")
-            self.printInventory()
-            data = {"permissions": ["USER"]}
-            for g in self.WorkspaceGroupNames:
-                requests.put(
-                    f"{self.workspace_url}/api/2.0/preview/permissionassignments/principals/{self.accountGroups_lower[g.casefold()]}",
-                    headers=self.headers,
-                    data=json.dumps(data),
-                )
-            self.applyGroupPermission("Account")
-            # self.persistInventory("Account")
-        except Exception as e:
-            logger.error(f"Error creating account level group, {e}")