Added baseline for Table ACL migration (#78)
Baseline implementation for scanning tables and grants in the Hive Metastore and transforming them into their Unity Catalog equivalents.
nfx authored Aug 18, 2023
1 parent 4c6a2ed commit 7a5254b
Showing 14 changed files with 1,194 additions and 6 deletions.
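The transformation described above can be pictured with a small sketch. This is an illustrative outline only; the Grant dataclass and its uc_grant_sql helper below are hypothetical stand-ins for the commit's actual table and grant scanning code, not code from this changeset:

from dataclasses import dataclass


@dataclass
class Grant:
    # Hypothetical shape of a grant scanned from the Hive Metastore.
    principal: str
    action_type: str  # e.g. "SELECT" or "MODIFY"
    catalog: str
    database: str
    table: str

    def uc_grant_sql(self) -> str:
        # Render the Unity Catalog equivalent of this Hive Metastore grant.
        return (
            f"GRANT {self.action_type} "
            f"ON TABLE {self.catalog}.{self.database}.{self.table} "
            f"TO `{self.principal}`"
        )


# Example: Grant("analysts", "SELECT", "main", "sales", "orders").uc_grant_sql()
# -> "GRANT SELECT ON TABLE main.sales.orders TO `analysts`"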
4 changes: 3 additions & 1 deletion .gitignore
@@ -143,4 +143,6 @@ cython_debug/
 /scratch
 
 # dev files and scratches
-dev/cleanup.py
+dev/cleanup.py
+
+Support
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -13,15 +13,15 @@ authors = [
   { name = "renardeinside", email = "polarpersonal@gmail.com" },
 ]
 classifiers = [
-  "Development Status :: 4 - Beta",
+  "Development Status :: 3 - Alpha",
   "Programming Language :: Python",
   "Programming Language :: Python :: 3.10",
   "Programming Language :: Python :: 3.11",
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
 dependencies = [
-  "databricks-sdk>=0.2.1",
+  "databricks-sdk~=0.5.0",
   "typer[all]>=0.9.0,<0.10.0",
   "pyhocon>=0.3.60,<0.4.0",
   "pydantic>=2.0.3, <3.0.0",
@@ -157,8 +157,8 @@ ignore = [
   "B027",
   # Allow boolean positional values in function calls, like `dict.get(... True)`
   "FBT003",
-  # Ignore checks for possible passwords
-  "S105", "S106", "S107",
+  # Ignore checks for possible passwords and SQL statement construction
+  "S105", "S106", "S107", "S608",
   # Allow print statements
   "T201",
   # Allow asserts
Empty file.
188 changes: 188 additions & 0 deletions src/uc_migration_toolkit/providers/mixins/sql.py
@@ -0,0 +1,188 @@
import json
import logging
import random
import time
from collections.abc import Iterator
from datetime import timedelta
from typing import Any

from databricks.sdk.service.sql import (
    ColumnInfoTypeName,
    Disposition,
    ExecuteStatementResponse,
    Format,
    ResultData,
    StatementExecutionAPI,
    StatementState,
    StatementStatus,
)

MAX_SLEEP_PER_ATTEMPT = 10

MAX_PLATFORM_TIMEOUT = 50

MIN_PLATFORM_TIMEOUT = 5

_LOG = logging.getLogger("databricks.sdk")


class _RowCreator(tuple):
    def __new__(cls, fields):
        instance = super().__new__(cls, fields)
        return instance

    def __repr__(self):
        field_values = ", ".join(f"{field}={getattr(self, field)}" for field in self)
        return f"{self.__class__.__name__}({field_values})"


class Row(tuple):
    def as_dict(self) -> dict[str, Any]:
        return dict(zip(self.__columns__, self, strict=True))

    def __getattr__(self, col):
        idx = self.__columns__.index(col)
        return self[idx]

    def __getitem__(self, col):
        if isinstance(col, int | slice):
            return super().__getitem__(col)
        # columns can be named like `2 + 2`, so fall back to lookup by name
        return self.__getattr__(col)

    def __repr__(self):
        return f"Row({', '.join(f'{k}={v}' for (k, v) in zip(self.__columns__, self, strict=True))})"


class StatementExecutionExt(StatementExecutionAPI):
    def __init__(self, api_client):
        super().__init__(api_client)
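        # Map Databricks SQL column types to Python converters for results
        # returned INLINE as JSON_ARRAY; the commented-out entries are types
        # this mixin does not convert yet.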
        self.type_converters = {
            ColumnInfoTypeName.ARRAY: json.loads,
            # ColumnInfoTypeName.BINARY: not_supported(ColumnInfoTypeName.BINARY),
            # JSON_ARRAY encodes every value as a string, and bool("false") is
            # truthy, so parse booleans explicitly
            ColumnInfoTypeName.BOOLEAN: lambda value: value.lower() == "true",
            # ColumnInfoTypeName.BYTE: not_supported(ColumnInfoTypeName.BYTE),
            ColumnInfoTypeName.CHAR: str,
            # ColumnInfoTypeName.DATE: not_supported(ColumnInfoTypeName.DATE),
            ColumnInfoTypeName.DOUBLE: float,
            ColumnInfoTypeName.FLOAT: float,
            ColumnInfoTypeName.INT: int,
            # ColumnInfoTypeName.INTERVAL: not_supported(ColumnInfoTypeName.INTERVAL),
            ColumnInfoTypeName.LONG: int,
            ColumnInfoTypeName.MAP: json.loads,
            ColumnInfoTypeName.NULL: lambda _: None,
            ColumnInfoTypeName.SHORT: int,
            ColumnInfoTypeName.STRING: str,
            ColumnInfoTypeName.STRUCT: json.loads,
            # ColumnInfoTypeName.TIMESTAMP: not_supported(ColumnInfoTypeName.TIMESTAMP),
            # ColumnInfoTypeName.USER_DEFINED_TYPE: not_supported(ColumnInfoTypeName.USER_DEFINED_TYPE),
        }

    @staticmethod
    def _raise_if_needed(status: StatementStatus):
        if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]:
            return
        msg = status.state.value
        if status.error is not None:
            msg = f"{msg}: {status.error.error_code.value} {status.error.message}"
        raise RuntimeError(msg)

    def execute(
        self,
        warehouse_id: str,
        statement: str,
        *,
        byte_limit: int | None = None,
        catalog: str | None = None,
        schema: str | None = None,
        timeout: timedelta = timedelta(minutes=20),
    ) -> ExecuteStatementResponse:
        # The wait_timeout field must be 0 seconds (disables wait),
        # or between 5 seconds and 50 seconds.
        wait_timeout = None
        if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT:
            # set server-side timeout; the API expects whole seconds, e.g. "50s"
            wait_timeout = f"{int(timeout.total_seconds())}s"

        _LOG.debug(f"Executing SQL statement: {statement}")

        # technically, we can do Disposition.EXTERNAL_LINKS, but let's push it further away.
        # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM.
        immediate_response = self.execute_statement(
            warehouse_id=warehouse_id,
            statement=statement,
            catalog=catalog,
            schema=schema,
            disposition=Disposition.INLINE,
            format=Format.JSON_ARRAY,
            byte_limit=byte_limit,
            wait_timeout=wait_timeout,
        )

        if immediate_response.status.state == StatementState.SUCCEEDED:
            return immediate_response

        self._raise_if_needed(immediate_response.status)

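        # The statement is still running: poll client-side, backing off
        # linearly per attempt (capped at MAX_SLEEP_PER_ATTEMPT seconds),
        # with random jitter added to each sleep.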
        attempt = 1
        status_message = "polling..."
        deadline = time.time() + timeout.total_seconds()
        while time.time() < deadline:
            res = self.get_statement(immediate_response.statement_id)
            if res.status.state == StatementState.SUCCEEDED:
                return ExecuteStatementResponse(
                    manifest=res.manifest, result=res.result, statement_id=res.statement_id, status=res.status
                )
            status_message = f"current status: {res.status.state.value}"
            self._raise_if_needed(res.status)
            sleep = attempt
            if sleep > MAX_SLEEP_PER_ATTEMPT:
                # sleep 10s max per attempt
                sleep = MAX_SLEEP_PER_ATTEMPT
            _LOG.debug(f"SQL statement {res.statement_id}: {status_message} (sleeping ~{sleep}s)")
            time.sleep(sleep + random.random())
            attempt += 1
        self.cancel_execution(immediate_response.statement_id)
        msg = f"timed out after {timeout}: {status_message}"
        raise TimeoutError(msg)

    def execute_fetch_all(
        self,
        warehouse_id: str,
        statement: str,
        *,
        byte_limit: int | None = None,
        catalog: str | None = None,
        schema: str | None = None,
        timeout: timedelta = timedelta(minutes=20),
    ) -> Iterator[Row]:
        execute_response = self.execute(
            warehouse_id, statement, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout
        )
        col_names = []
        col_conv = []
        for col in execute_response.manifest.schema.columns:
            col_names.append(col.name)
            conv = self.type_converters.get(col.type_name, None)
            if conv is None:
                msg = f"{col.name} has no {col.type_name.value} converter"
                raise ValueError(msg)
            col_conv.append(conv)
        row_factory = type("Row", (Row,), {"__columns__": col_names})
        result_data = execute_response.result
        if result_data is None:
            return
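        # INLINE results arrive in chunks; after draining each chunk, follow
        # next_chunk_internal_link until next_chunk_index is None.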
        while True:
            for data in result_data.data_array:
                # enumerate() + iterator + tuple constructor makes it more performant
                # on a larger number of records for Python, even though it's less
                # readable code.
                row = []
                for i, value in enumerate(data):
                    if value is None:
                        row.append(None)
                    else:
                        row.append(col_conv[i](value))
                yield row_factory(row)
            if result_data.next_chunk_index is None:
                return
            # TODO: replace once ES-828324 is fixed
            json_response = self._api.do("GET", result_data.next_chunk_internal_link)
            result_data = ResultData.from_dict(json_response)
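
A minimal usage sketch for the mixin above, assuming a databricks-sdk WorkspaceClient and a running SQL warehouse (the warehouse ID and statement below are placeholders):

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
see = StatementExecutionExt(w.api_client)

# Rows are converted and yielded lazily; chunked pagination is followed
# transparently by execute_fetch_all.
for row in see.execute_fetch_all("<warehouse-id>", "SHOW DATABASES"):
    print(row.as_dict())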