
Added crawler for Azure Service principals used for direct storage access

Fixes #249
dipankarkush-db committed Sep 28, 2023
1 parent efd30c8 commit 543e2a0
Showing 2 changed files with 59 additions and 35 deletions.
54 changes: 20 additions & 34 deletions src/databricks/labs/ucx/assessment/crawlers.py
@@ -53,12 +53,12 @@ class PipelineInfo:
failures: str


-def azure_sp_conf_usage_check(config: str) -> bool:
-    sp_conf_present = False
-    for conf in _AZURE_SP_CONF:
-        if re.search(conf, config):
-            sp_conf_present = True
-    return sp_conf_present
+def _azure_sp_conf_present_check(config: dict) -> bool:
+    for key in config.keys():
+        for conf in _AZURE_SP_CONF:
+            if re.search(conf, key):
+                return True
+    return False


def spark_version_compatibility(spark_version: str) -> str:
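Two things about this rewrite are easy to miss. The old helper never short-circuited, and its call sites passed str(key) where key was really a (key, value) tuple from .items(), so values were scanned along with key names. The new helper takes the config dict itself, scans key names only, and returns on the first hit. A minimal, self-contained sketch of that behavior, assuming _AZURE_SP_CONF holds regex fragments like the two below (the real list is defined elsewhere in crawlers.py and is not part of this diff):

import re

# Stand-ins for the real _AZURE_SP_CONF entries defined elsewhere in crawlers.py.
_AZURE_SP_CONF = [
    "fs.azure.account.auth.type",
    "fs.azure.account.oauth2.client.id",
]

def _azure_sp_conf_present_check(config: dict) -> bool:  # as added above
    for key in config.keys():
        for conf in _AZURE_SP_CONF:
            if re.search(conf, key):
                return True
    return False

# Matches on key names, never on values:
assert _azure_sp_conf_present_check(
    {"spark.hadoop.fs.azure.account.oauth2.client.id.abc.dfs.core.windows.net": "some-app-id"}
)
assert not _azure_sp_conf_present_check({"spark.sql.shuffle.partitions": "200"})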
@@ -95,10 +95,8 @@ def _assess_pipelines(self, all_pipelines):
            failures = []
            pipeline_config = self._ws.pipelines.get(pipeline.pipeline_id).spec.configuration
            if pipeline_config:
-                for key in pipeline_config.items():
-                    if azure_sp_conf_usage_check(str(key)):
-                        failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")
-                        break
+                if _azure_sp_conf_present_check(pipeline_config):
+                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")

            pipeline_info.failures = json.dumps(failures)
            if len(failures) > 0:
@@ -142,22 +140,16 @@ def _assess_clusters(self, all_clusters):
                        failures.append(f"using DBFS mount in configuration: {value}")

                # Checking if Azure cluster config is present in spark config
-                for key in cluster.spark_conf.items():
-                    if azure_sp_conf_usage_check(str(key)):
-                        failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
-                        break
+                if _azure_sp_conf_present_check(cluster.spark_conf):
+                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

            # Checking if Azure cluster config is present in cluster policies
            if cluster.policy_id:
-                cluster_policy_definition = self._ws.cluster_policies.get(cluster.policy_id).definition
-                if azure_sp_conf_usage_check(cluster_policy_definition):
+                policy = self._ws.cluster_policies.get(cluster.policy_id)
+                if _azure_sp_conf_present_check(json.loads(policy.definition)):
                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

-                cluster_family_definition = self._ws.cluster_policies.get(
-                    cluster.policy_id
-                ).policy_family_definition_overrides
-                if cluster_family_definition:
-                    if azure_sp_conf_usage_check(cluster_family_definition):
+                if policy.policy_family_definition_overrides:
+                    if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
                        failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

            cluster_info.failures = json.dumps(failures)
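The policy branch also changes shape: the policy is fetched once and reused, and its definition, which the Cluster Policies API stores as a JSON string, is parsed with json.loads before the key scan (likewise policy_family_definition_overrides). A small illustration, reusing the helper from the sketch above with a hypothetical policy definition:

import json

# Hypothetical policy definition; real ones are JSON strings of the same shape.
definition = (
    '{"spark_conf.fs.azure.account.auth.type": '
    '{"type": "fixed", "value": "OAuth", "hidden": true}}'
)
assert _azure_sp_conf_present_check(json.loads(definition))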
@@ -223,22 +215,16 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> list[JobInfo]:
                        job_assessment[job.job_id].add(f"using DBFS mount in configuration: {value}")

                # Checking if Azure cluster config is present in spark config
-                for key in cluster_config.spark_conf.items():
-                    if azure_sp_conf_usage_check(str(key)):
-                        job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")
-                        break
+                if _azure_sp_conf_present_check(cluster_config.spark_conf):
+                    job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")

            # Checking if Azure cluster config is present in cluster policies
            if cluster_config.policy_id:
-                job_cluster_policy_definition = self._ws.cluster_policies.get(cluster_config.policy_id).definition
-                if azure_sp_conf_usage_check(job_cluster_policy_definition):
+                policy = self._ws.cluster_policies.get(cluster_config.policy_id)
+                if _azure_sp_conf_present_check(json.loads(policy.definition)):
                    job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")

-                job_cluster_family_definition = self._ws.cluster_policies.get(
-                    cluster_config.policy_id
-                ).policy_family_definition_overrides
-                if job_cluster_family_definition:
-                    if azure_sp_conf_usage_check(job_cluster_family_definition):
+                if policy.policy_family_definition_overrides:
+                    if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
                        job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")

        for job_key in job_details.keys():
40 changes: 39 additions & 1 deletion tests/unit/assessment/test_assessment.py
@@ -7,6 +7,7 @@
 from databricks.labs.ucx.assessment.crawlers import (
     ClustersCrawler,
     JobsCrawler,
+    PipelineInfo,
     PipelinesCrawler,
 )
 from databricks.labs.ucx.hive_metastore.data_objects import ExternalLocationCrawler
@@ -372,7 +373,21 @@ def test_cluster_assessment_cluster_policy_no_spark_conf(mocker):
'"autotermination_minutes":{"type":"unlimited","defaultValue":4320,"isOptional":true}}'
)

ws.cluster_policies.get().policy_family_definition_overrides = "family_definition"
ws.cluster_policies.get().policy_family_definition_overrides = (
'{\n "not.spark.conf": {\n '
'"type": "fixed",\n "value": "OAuth",\n '
' "hidden": true\n },\n "not.a.type": {\n '
' "type": "fixed",\n "value": '
'"not.a.matching.type",\n '
'"hidden": true\n },\n "not.a.matching.type": {\n '
'"type": "fixed",\n "value": "fsfsfsfsffsfsf",\n "hidden": true\n },\n '
'"not.a.matching.type": {\n "type": "fixed",\n '
'"value": "gfgfgfgfggfggfgfdds",\n "hidden": true\n },\n '
'"not.a.matching.type": {\n '
'"type": "fixed",\n '
'"value": "https://login.microsoftonline.com/1234ededed/oauth2/token",\n '
'"hidden": true\n }\n}'
)

crawler = ClustersCrawler(ws, MockBackend(), "ucx")._assess_clusters(sample_clusters1)
result_set1 = list(crawler)
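The replacement fixture looks deliberately adversarial: none of its keys match an Azure service-principal pattern, yet one value is an OAuth token endpoint under login.microsoftonline.com. Under the old tuple-stringifying check such a value could have produced a false positive; the new key-only check should report nothing for this policy.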
@@ -429,3 +444,26 @@ def test_pipeline_assessment_without_config(mocker):

    assert len(result_set) == 1
    assert result_set[0].success == 1
+
+
+def test_pipeline_snapshot_with_config():
+    sample_pipelines = [
+        PipelineInfo(
+            creator_name="abcde.defgh@databricks.com",
+            pipeline_name="New DLT Pipeline",
+            pipeline_id="0112eae7-9d11-4b40-a2b8-6c83cb3c7497",
+            success=1,
+            failures="",
+        )
+    ]
+    mock_ws = Mock()
+
+    crawler = PipelinesCrawler(mock_ws, MockBackend(), "ucx")
+
+    crawler._try_fetch = Mock(return_value=[])
+    crawler._crawl = Mock(return_value=sample_pipelines)
+
+    result_set = crawler.snapshot()
+
+    assert len(result_set) == 1
+    assert result_set[0].success == 1
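Note the shape of this new test: mock_ws is never consulted. Stubbing _try_fetch to return an empty list presumably makes snapshot() fall through to the stubbed _crawl; either way, what is asserted is the crawler's fetch-or-crawl snapshot plumbing rather than any workspace API call.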
