Skip to content

Commit

Permalink
[Feature] Add DataPlane support (#700)
Browse files Browse the repository at this point in the history
## Changes
Add DataPlane support

## Tests
- [X] `make test` run locally
- [X] `make fmt` applied
- [ ] relevant integration tests applied
- [X] Manual test against staging workspace (prod workspaces don't
support DataPlane APIs)
  • Loading branch information
hectorcast-db committed Jul 16, 2024
1 parent 6462912 commit 3009a6b
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 36 deletions.
40 changes: 32 additions & 8 deletions .codegen/__init__.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ from databricks.sdk.credentials_provider import CredentialsStrategy
from databricks.sdk.mixins.files import DbfsExt
from databricks.sdk.mixins.compute import ClustersExt
from databricks.sdk.mixins.workspace import WorkspaceExt
{{- range .Services}} {{if not .IsDataPlane}}
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}{{end}}
{{- range .Services}}
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
from databricks.sdk.service.provisioning import Workspace
from databricks.sdk import azure

Expand Down Expand Up @@ -61,8 +61,20 @@ class WorkspaceClient:
self._dbutils = _make_dbutils(self._config)
self._api_client = client.ApiClient(self._config)

{{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) (not .IsDataPlane)}}
self._{{.SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}}
{{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) .HasDataPlaneAPI (not .IsDataPlane)}}
{{.SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}}

{{- range .Services}}
{{- if and (not .IsAccounts) (not .HasParent)}}
{{- if .IsDataPlane}}
self._{{.SnakeName}} = {{template "api" .}}(self._api_client, {{.ControlPlaneService.SnakeName}})
{{- else if .HasDataPlaneAPI}}
self._{{.SnakeName}} = {{.SnakeName}}
{{- else}}
self._{{.SnakeName}} = {{template "api" .}}(self._api_client)
{{- end -}}
{{- end -}}
{{end}}

@property
def config(self) -> client.Config:
Expand All @@ -76,7 +88,7 @@ class WorkspaceClient:
def dbutils(self) -> dbutils.RemoteDbUtils:
return self._dbutils

{{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) (not .IsDataPlane)}}
{{- range .Services}}{{if and (not .IsAccounts) (not .HasParent)}}
@property
def {{.SnakeName}}(self) -> {{template "api" .}}:
{{if .Description}}"""{{.Summary}}"""{{end}}
Expand Down Expand Up @@ -117,8 +129,20 @@ class AccountClient:
self._config = config.copy()
self._api_client = client.ApiClient(self._config)

{{- range .Services}}{{if and .IsAccounts (not .HasParent) (not .IsDataPlane)}}
self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}}
{{- range .Services}}{{if and .IsAccounts (not .HasParent) .HasDataPlaneAPI (not .IsDataPlane)}}
{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}}

{{- range .Services}}
{{- if and .IsAccounts (not .HasParent)}}
{{- if .IsDataPlane}}
self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client, {{.ControlPlaneService.SnakeName}})
{{- else if .HasDataPlaneAPI}}
self._{{(.TrimPrefix "account").SnakeName}} = {{(.TrimPrefix "account").SnakeName}}
{{- else}}
self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client)
{{- end -}}
{{- end -}}
{{end}}

@property
def config(self) -> client.Config:
Expand All @@ -128,7 +152,7 @@ class AccountClient:
def api_client(self) -> client.ApiClient:
return self._api_client

{{- range .Services}}{{if and .IsAccounts (not .HasParent) (not .IsDataPlane)}}
{{- range .Services}}{{if and .IsAccounts (not .HasParent)}}
@property
def {{(.TrimPrefix "account").SnakeName}}(self) -> {{template "api" .}}:{{if .Description}}
"""{{.Summary}}"""{{end}}
Expand Down
70 changes: 54 additions & 16 deletions .codegen/service.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ from typing import Dict, List, Any, Iterator, Type, Callable, Optional, BinaryIO
import time
import random
import logging
import requests

from ..data_plane import DataPlaneService
from ..errors import OperationTimeout, OperationFailed
from ._internal import _enum, _from_dict, _repeated_dict, _repeated_enum, Wait, _escape_multi_segment_path_parameter
from ..oauth import Token

_LOG = logging.getLogger('databricks.sdk')

Expand Down Expand Up @@ -100,12 +104,16 @@ class {{.PascalName}}{{if eq "List" .PascalName}}Request{{end}}:{{if .Descriptio
{{- end -}}
{{- end -}}

{{range .Services}} {{if not .IsDataPlane}}
{{range .Services}}
class {{.PascalName}}API:{{if .Description}}
"""{{.Comment " " 110}}"""
{{end}}
def __init__(self, api_client):
def __init__(self, api_client{{if .IsDataPlane}}, control_plane{{end}}):
self._api = api_client
{{if .IsDataPlane -}}
self._control_plane = control_plane
self._data_plane_service = DataPlaneService()
{{end -}}
{{range .Subservices}}
self._{{.SnakeName}} = {{.PascalName}}API(self._api){{end}}

Expand Down Expand Up @@ -183,6 +191,9 @@ class {{.PascalName}}API:{{if .Description}}
{{if .Request -}}
{{template "method-serialize" .}}
{{- end}}
{{- if .Service.IsDataPlane}}
{{template "data-plane" .}}
{{- end}}
{{template "method-headers" . }}
{{if .Response.HasHeaderField -}}
{{template "method-response-headers" . }}
Expand All @@ -195,7 +206,27 @@ class {{.PascalName}}API:{{if .Description}}
return self.{{template "safe-snake-name" .}}({{range $i, $x := .Request.Fields}}{{if $i}}, {{end}}{{template "safe-snake-name" .}}={{template "safe-snake-name" .}}{{end}}).result(timeout=timeout)
{{end}}
{{end -}}
{{end}}
{{- end}}

{{define "data-plane" -}}
def info_getter():
response = self._control_plane.{{.Service.DataPlaneInfoMethod.SnakeName}}(
{{- range .Service.DataPlaneInfoMethod.Request.Fields }}
{{.SnakeName}} = {{.SnakeName}},
{{- end}}
)
if response.{{(index .DataPlaneInfoFields 0).SnakeName}} is None:
raise Exception("Resource does not support direct Data Plane access")
return response{{range .DataPlaneInfoFields}}.{{.SnakeName}}{{end}}

get_params = [{{- range .Service.DataPlaneInfoMethod.Request.Fields }}{{.SnakeName}},{{end}}]
data_plane_details = self._data_plane_service.get_data_plane_details('{{.SnakeName}}', get_params, info_getter, self._api.get_oauth_token)
token = data_plane_details.token

def auth(r: requests.PreparedRequest) -> requests.PreparedRequest:
authorization = f"{token.token_type} {token.access_token}"
r.headers["Authorization"] = authorization
return r
{{- end}}

{{define "method-parameters" -}}
Expand Down Expand Up @@ -325,19 +356,26 @@ class {{.PascalName}}API:{{if .Description}}
{{- end}}

{{define "method-do" -}}
self._api.do('{{.Verb}}',
{{ template "path" . }}
{{if .Request}}
{{- if .Request.HasQueryField}}, query=query{{end}}
{{- if .Request.MapValue}}, body=contents
{{- else if .Request.HasJsonField}}, body=body{{end}}
{{end}}
, headers=headers
{{if .Response.HasHeaderField -}}
, response_headers=response_headers
{{- end}}
{{- if and .IsRequestByteStream .RequestBodyField }}, data={{template "safe-snake-name" .RequestBodyField}}{{ end }}
{{- if .IsResponseByteStream }}, raw=True{{ end }})
self._api.do('{{.Verb}}',
{{- if .Service.IsDataPlane -}}
url=data_plane_details.endpoint_url
{{- else -}}
{{ template "path" . }}
{{- end -}}
{{if .Request}}
{{- if .Request.HasQueryField}}, query=query{{end}}
{{- if .Request.MapValue}}, body=contents
{{- else if .Request.HasJsonField}}, body=body{{end}}
{{end}}
, headers=headers
{{if .Response.HasHeaderField -}}
, response_headers=response_headers
{{- end}}
{{- if and .IsRequestByteStream .RequestBodyField }}, data={{template "safe-snake-name" .RequestBodyField}}{{ end }}
{{- if .Service.IsDataPlane -}}
,auth=auth
{{- end -}}
{{- if .IsResponseByteStream }}, raw=True{{ end }})
{{- end}}

{{- define "path" -}}
Expand Down
12 changes: 10 additions & 2 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 16 additions & 9 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,31 +133,36 @@ def get_oauth_token(self, auth_details: str) -> Token:

def do(self,
method: str,
path: str,
path: str = None,
url: str = None,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
response_headers: List[str] = None) -> Union[dict, BinaryIO]:
# Remove extra `/` from path for Files API
# Once we've fixed the OpenAPI spec, we can remove this
path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path)
if headers is None:
headers = {}
if url is None:
# Remove extra `/` from path for Files API
# Once we've fixed the OpenAPI spec, we can remove this
path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path)
url = f"{self._cfg.host}{path}"
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._cfg.clock)
response = retryable(self._perform)(method,
path,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data)
data=data,
auth=auth)

resp = dict()
for header in response_headers if response_headers else []:
Expand Down Expand Up @@ -239,20 +244,22 @@ def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:

def _perform(self,
method: str,
path: str,
url: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None):
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
response = self._session.request(method,
f"{self._cfg.host}{path}",
url,
params=self._fix_query_string(query),
json=body,
headers=headers,
files=files,
data=data,
auth=auth,
stream=raw,
timeout=self._http_timeout_seconds)
try:
Expand Down
65 changes: 65 additions & 0 deletions databricks/sdk/data_plane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import threading
from dataclasses import dataclass
from typing import Callable, List

from databricks.sdk.oauth import Token
from databricks.sdk.service.oauth2 import DataPlaneInfo


@dataclass
class DataPlaneDetails:
    """Details required to query a DataPlane endpoint.

    Attributes:
        endpoint_url: URL used to query the endpoint through the DataPlane.
        token: Token to query the DataPlane endpoint.
    """

    endpoint_url: str
    token: Token


class DataPlaneService:
    """Helper class to fetch and manage DataPlane details.

    DataPlane info and tokens are cached in two dictionaries that share the
    same key, built from the method name and its path parameters. A single
    lock guards the population of both caches.
    """

    def __init__(self):
        # Both caches are keyed by "<method>/<param>/<param>/...".
        self._data_plane_info = {}
        self._tokens = {}
        self._lock = threading.Lock()

    def get_data_plane_details(self, method: str, params: List[str],
                               info_getter: Callable[[], "DataPlaneInfo"],
                               refresh: Callable[[str], "Token"]) -> "DataPlaneDetails":
        """Get and cache the information required to query a Data Plane endpoint.

        Returns details built from the cache when the info has already been
        fetched and the cached token is still valid. Otherwise the provided
        callables are used to fetch whatever is missing.

        :param method: method name. Used to construct a unique key for the cache.
        :param params: path params used in the "get" operation which uniquely
            determine the object. Used to construct a unique key for the cache.
            Elements are converted with ``str()`` before being joined.
        :param info_getter: function which returns the DataPlaneInfo. It is only
            called if the information is not already present in the cache.
        :param refresh: function to refresh the token. It receives the
            ``authorization_details`` of the DataPlaneInfo and is only called if
            the cached token is missing or no longer valid.
        :returns: a DataPlaneDetails holding the endpoint URL and the token.
        """
        # str() keeps the key identical for string params while avoiding a
        # TypeError from str.join when a path parameter is not a string.
        map_key = "/".join(str(element) for element in [method, *params])

        info = self._data_plane_info.get(map_key)
        if not info:
            with self._lock:
                # Look again while holding the lock: another thread may have
                # stored the info between the first lookup and the acquire.
                info = self._data_plane_info.get(map_key)
                if not info:
                    info = info_getter()
                    self._data_plane_info[map_key] = info

        token = self._tokens.get(map_key)
        if not token or not token.valid:
            # ``with`` guarantees the lock is released even if the lookup or
            # the refresh callable raises.
            with self._lock:
                token = self._tokens.get(map_key)
                if not token or not token.valid:
                    token = refresh(info.authorization_details)
                    self._tokens[map_key] = token

        return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token)
Loading

0 comments on commit 3009a6b

Please sign in to comment.