diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl index d5b83e3f..572b5049 100644 --- a/.codegen/__init__.py.tmpl +++ b/.codegen/__init__.py.tmpl @@ -5,8 +5,8 @@ from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.workspace import WorkspaceExt -{{- range .Services}} {{if not .IsDataPlane}} -from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}{{end}} +{{- range .Services}} +from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}} from databricks.sdk.service.provisioning import Workspace from databricks.sdk import azure @@ -61,8 +61,20 @@ class WorkspaceClient: self._dbutils = _make_dbutils(self._config) self._api_client = client.ApiClient(self._config) - {{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) (not .IsDataPlane)}} - self._{{.SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}} + {{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) .HasDataPlaneAPI (not .IsDataPlane)}} + {{.SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}} + + {{- range .Services}} + {{- if and (not .IsAccounts) (not .HasParent)}} + {{- if .IsDataPlane}} + self._{{.SnakeName}} = {{template "api" .}}(self._api_client, {{.ControlPlaneService.SnakeName}}) + {{- else if .HasDataPlaneAPI}} + self._{{.SnakeName}} = {{.SnakeName}} + {{- else}} + self._{{.SnakeName}} = {{template "api" .}}(self._api_client) + {{- end -}} + {{- end -}} + {{end}} @property def config(self) -> client.Config: @@ -76,7 +88,7 @@ class WorkspaceClient: def dbutils(self) -> dbutils.RemoteDbUtils: return self._dbutils - {{- range .Services}}{{if and (not .IsAccounts) (not .HasParent) (not .IsDataPlane)}} + {{- range .Services}}{{if and (not .IsAccounts) (not .HasParent)}} @property def {{.SnakeName}}(self) -> {{template "api" .}}: {{if .Description}}"""{{.Summary}}"""{{end}} @@ -117,8 +129,20 @@ class AccountClient: self._config = config.copy() self._api_client = client.ApiClient(self._config) - {{- range .Services}}{{if and .IsAccounts (not .HasParent) (not .IsDataPlane)}} - self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}} + {{- range .Services}}{{if and .IsAccounts (not .HasParent) .HasDataPlaneAPI (not .IsDataPlane)}} + {{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client){{end -}}{{end}} + + {{- range .Services}} + {{- if and .IsAccounts (not .HasParent)}} + {{- if .IsDataPlane}} + self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client, {{.ControlPlaneService.SnakeName}}) + {{- else if .HasDataPlaneAPI}} + self._{{(.TrimPrefix "account").SnakeName}} = {{(.TrimPrefix "account").SnakeName}} + {{- else}} + self._{{(.TrimPrefix "account").SnakeName}} = {{template "api" .}}(self._api_client) + {{- end -}} + {{- end -}} + {{end}} @property def config(self) -> client.Config: @@ -128,7 +152,7 @@ class AccountClient: def api_client(self) -> client.ApiClient: return self._api_client - {{- range .Services}}{{if and .IsAccounts (not .HasParent) (not .IsDataPlane)}} + {{- range .Services}}{{if and .IsAccounts (not .HasParent)}} @property def {{(.TrimPrefix "account").SnakeName}}(self) -> {{template "api" .}}:{{if .Description}} """{{.Summary}}"""{{end}} diff --git a/.codegen/service.py.tmpl b/.codegen/service.py.tmpl index 39892b43..643b1f33 100644 --- a/.codegen/service.py.tmpl +++ b/.codegen/service.py.tmpl @@ -8,8 +8,12 @@ from typing import Dict, List, Any, Iterator, Type, Callable, Optional, BinaryIO import time import random import logging +import requests + +from ..data_plane import DataPlaneService from ..errors import OperationTimeout, OperationFailed from ._internal import _enum, _from_dict, _repeated_dict, _repeated_enum, Wait, _escape_multi_segment_path_parameter +from ..oauth import Token _LOG = logging.getLogger('databricks.sdk') @@ -100,12 +104,16 @@ class {{.PascalName}}{{if eq "List" .PascalName}}Request{{end}}:{{if .Descriptio {{- end -}} {{- end -}} -{{range .Services}} {{if not .IsDataPlane}} +{{range .Services}} class {{.PascalName}}API:{{if .Description}} """{{.Comment " " 110}}""" {{end}} - def __init__(self, api_client): + def __init__(self, api_client{{if .IsDataPlane}}, control_plane{{end}}): self._api = api_client + {{if .IsDataPlane -}} + self._control_plane = control_plane + self._data_plane_service = DataPlaneService() + {{end -}} {{range .Subservices}} self._{{.SnakeName}} = {{.PascalName}}API(self._api){{end}} @@ -183,6 +191,9 @@ class {{.PascalName}}API:{{if .Description}} {{if .Request -}} {{template "method-serialize" .}} {{- end}} + {{- if .Service.IsDataPlane}} + {{template "data-plane" .}} + {{- end}} {{template "method-headers" . }} {{if .Response.HasHeaderField -}} {{template "method-response-headers" . }} @@ -195,7 +206,27 @@ class {{.PascalName}}API:{{if .Description}} return self.{{template "safe-snake-name" .}}({{range $i, $x := .Request.Fields}}{{if $i}}, {{end}}{{template "safe-snake-name" .}}={{template "safe-snake-name" .}}{{end}}).result(timeout=timeout) {{end}} {{end -}} -{{end}} +{{- end}} + +{{define "data-plane" -}} + def info_getter(): + response = self._control_plane.{{.Service.DataPlaneInfoMethod.SnakeName}}( + {{- range .Service.DataPlaneInfoMethod.Request.Fields }} + {{.SnakeName}} = {{.SnakeName}}, + {{- end}} + ) + if response.{{(index .DataPlaneInfoFields 0).SnakeName}} is None: + raise Exception("Resource does not support direct Data Plane access") + return response{{range .DataPlaneInfoFields}}.{{.SnakeName}}{{end}} + + get_params = [{{- range .Service.DataPlaneInfoMethod.Request.Fields }}{{.SnakeName}},{{end}}] + data_plane_details = self._data_plane_service.get_data_plane_details('{{.SnakeName}}', get_params, info_getter, self._api.get_oauth_token) + token = data_plane_details.token + + def auth(r: requests.PreparedRequest) -> requests.PreparedRequest: + authorization = f"{token.token_type} {token.access_token}" + r.headers["Authorization"] = authorization + return r {{- end}} {{define "method-parameters" -}} @@ -325,19 +356,26 @@ class {{.PascalName}}API:{{if .Description}} {{- end}} {{define "method-do" -}} -self._api.do('{{.Verb}}', - {{ template "path" . }} - {{if .Request}} - {{- if .Request.HasQueryField}}, query=query{{end}} - {{- if .Request.MapValue}}, body=contents - {{- else if .Request.HasJsonField}}, body=body{{end}} - {{end}} - , headers=headers - {{if .Response.HasHeaderField -}} - , response_headers=response_headers - {{- end}} - {{- if and .IsRequestByteStream .RequestBodyField }}, data={{template "safe-snake-name" .RequestBodyField}}{{ end }} - {{- if .IsResponseByteStream }}, raw=True{{ end }}) + self._api.do('{{.Verb}}', + {{- if .Service.IsDataPlane -}} + url=data_plane_details.endpoint_url + {{- else -}} + {{ template "path" . }} + {{- end -}} + {{if .Request}} + {{- if .Request.HasQueryField}}, query=query{{end}} + {{- if .Request.MapValue}}, body=contents + {{- else if .Request.HasJsonField}}, body=body{{end}} + {{end}} + , headers=headers + {{if .Response.HasHeaderField -}} + , response_headers=response_headers + {{- end}} + {{- if and .IsRequestByteStream .RequestBodyField }}, data={{template "safe-snake-name" .RequestBodyField}}{{ end }} + {{- if .Service.IsDataPlane -}} + ,auth=auth + {{- end -}} + {{- if .IsResponseByteStream }}, raw=True{{ end }}) {{- end}} {{- define "path" -}} diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 05c95fb6..8485efba 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -55,7 +55,8 @@ NetworksAPI, PrivateAccessAPI, StorageAPI, VpcEndpointsAPI, Workspace, WorkspacesAPI) -from databricks.sdk.service.serving import AppsAPI, ServingEndpointsAPI +from databricks.sdk.service.serving import (AppsAPI, ServingEndpointsAPI, + ServingEndpointsDataPlaneAPI) from databricks.sdk.service.settings import (AccountIpAccessListsAPI, AccountSettingsAPI, AutomaticClusterUpdateAPI, @@ -162,6 +163,7 @@ def __init__(self, self._config = config.copy() self._dbutils = _make_dbutils(self._config) self._api_client = client.ApiClient(self._config) + serving_endpoints = ServingEndpointsAPI(self._api_client) self._account_access_control_proxy = AccountAccessControlProxyAPI(self._api_client) self._alerts = AlertsAPI(self._api_client) self._apps = AppsAPI(self._api_client) @@ -226,7 +228,8 @@ def __init__(self, self._schemas = SchemasAPI(self._api_client) self._secrets = SecretsAPI(self._api_client) self._service_principals = ServicePrincipalsAPI(self._api_client) - self._serving_endpoints = ServingEndpointsAPI(self._api_client) + self._serving_endpoints = serving_endpoints + self._serving_endpoints_data_plane = ServingEndpointsDataPlaneAPI(self._api_client, serving_endpoints) self._settings = SettingsAPI(self._api_client) self._shares = SharesAPI(self._api_client) self._statement_execution = StatementExecutionAPI(self._api_client) @@ -577,6 +580,11 @@ def serving_endpoints(self) -> ServingEndpointsAPI: """The Serving Endpoints API allows you to create, update, and delete model serving endpoints.""" return self._serving_endpoints + @property + def serving_endpoints_data_plane(self) -> ServingEndpointsDataPlaneAPI: + """Serving endpoints DataPlane provides a set of operations to interact with data plane endpoints for Serving endpoints service.""" + return self._serving_endpoints_data_plane + @property def settings(self) -> SettingsAPI: """Workspace Settings API allows users to manage settings at the workspace level.""" diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index cacbad90..b686bd7f 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -133,31 +133,36 @@ def get_oauth_token(self, auth_details: str) -> Token: def do(self, method: str, - path: str, + path: str = None, + url: str = None, query: dict = None, headers: dict = None, body: dict = None, raw: bool = False, files=None, data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, response_headers: List[str] = None) -> Union[dict, BinaryIO]: - # Remove extra `/` from path for Files API - # Once we've fixed the OpenAPI spec, we can remove this - path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path) if headers is None: headers = {} + if url is None: + # Remove extra `/` from path for Files API + # Once we've fixed the OpenAPI spec, we can remove this + path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path) + url = f"{self._cfg.host}{path}" headers['User-Agent'] = self._user_agent_base retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), is_retryable=self._is_retryable, clock=self._cfg.clock) response = retryable(self._perform)(method, - path, + url, query=query, headers=headers, body=body, raw=raw, files=files, - data=data) + data=data, + auth=auth) resp = dict() for header in response_headers if response_headers else []: @@ -239,20 +244,22 @@ def _parse_retry_after(cls, response: requests.Response) -> Optional[int]: def _perform(self, method: str, - path: str, + url: str, query: dict = None, headers: dict = None, body: dict = None, raw: bool = False, files=None, - data=None): + data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): response = self._session.request(method, - f"{self._cfg.host}{path}", + url, params=self._fix_query_string(query), json=body, headers=headers, files=files, data=data, + auth=auth, stream=raw, timeout=self._http_timeout_seconds) try: diff --git a/databricks/sdk/data_plane.py b/databricks/sdk/data_plane.py new file mode 100644 index 00000000..6f6ddf80 --- /dev/null +++ b/databricks/sdk/data_plane.py @@ -0,0 +1,65 @@ +import threading +from dataclasses import dataclass +from typing import Callable, List + +from databricks.sdk.oauth import Token +from databricks.sdk.service.oauth2 import DataPlaneInfo + + +@dataclass +class DataPlaneDetails: + """ + Contains details required to query a DataPlane endpoint. + """ + endpoint_url: str + """URL used to query the endpoint through the DataPlane.""" + token: Token + """Token to query the DataPlane endpoint.""" + + +class DataPlaneService: + """Helper class to fetch and manage DataPlane details.""" + + def __init__(self): + self._data_plane_info = {} + self._tokens = {} + self._lock = threading.Lock() + + def get_data_plane_details(self, method: str, params: List[str], info_getter: Callable[[], DataPlaneInfo], + refresh: Callable[[str], Token]): + """Get and cache information required to query a Data Plane endpoint using the provided methods. + + Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid. + If not, it uses the provided functions to fetch the details. + + :param method: method name. Used to construct a unique key for the cache. + :param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache. + :param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache. + :param refresh: function to refresh the token. It will only be called if the token is missing or expired. + """ + all_elements = params.copy() + all_elements.insert(0, method) + map_key = "/".join(all_elements) + info = self._data_plane_info.get(map_key) + if not info: + self._lock.acquire() + try: + info = self._data_plane_info.get(map_key) + if not info: + info = info_getter() + self._data_plane_info[map_key] = info + finally: + self._lock.release() + + token = self._tokens.get(map_key) + if not token or not token.valid: + self._lock.acquire() + token = self._tokens.get(map_key) + try: + if not token or not token.valid: + token = refresh(info.authorization_details) + self._tokens[map_key] = token + finally: + self._lock.release() + + return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token) diff --git a/databricks/sdk/service/serving.py b/databricks/sdk/service/serving.py index 6c39c598..0f3d00de 100755 --- a/databricks/sdk/service/serving.py +++ b/databricks/sdk/service/serving.py @@ -10,6 +10,9 @@ from enum import Enum from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Optional +import requests + +from ..data_plane import DataPlaneService from ..errors import OperationFailed from ._internal import Wait, _enum, _from_dict, _repeated_dict @@ -3335,3 +3338,118 @@ def update_permissions( body=body, headers=headers) return ServingEndpointPermissions.from_dict(res) + + +class ServingEndpointsDataPlaneAPI: + """Serving endpoints DataPlane provides a set of operations to interact with data plane endpoints for Serving + endpoints service.""" + + def __init__(self, api_client, control_plane): + self._api = api_client + self._control_plane = control_plane + self._data_plane_service = DataPlaneService() + + def query(self, + name: str, + *, + dataframe_records: Optional[List[Any]] = None, + dataframe_split: Optional[DataframeSplitInput] = None, + extra_params: Optional[Dict[str, str]] = None, + input: Optional[Any] = None, + inputs: Optional[Any] = None, + instances: Optional[List[Any]] = None, + max_tokens: Optional[int] = None, + messages: Optional[List[ChatMessage]] = None, + n: Optional[int] = None, + prompt: Optional[Any] = None, + stop: Optional[List[str]] = None, + stream: Optional[bool] = None, + temperature: Optional[float] = None) -> QueryEndpointResponse: + """Query a serving endpoint. + + :param name: str + The name of the serving endpoint. This field is required. + :param dataframe_records: List[Any] (optional) + Pandas Dataframe input in the records orientation. + :param dataframe_split: :class:`DataframeSplitInput` (optional) + Pandas Dataframe input in the split orientation. + :param extra_params: Dict[str,str] (optional) + The extra parameters field used ONLY for __completions, chat,__ and __embeddings external & + foundation model__ serving endpoints. This is a map of strings and should only be used with other + external/foundation model query fields. + :param input: Any (optional) + The input string (or array of strings) field used ONLY for __embeddings external & foundation + model__ serving endpoints and is the only field (along with extra_params if needed) used by + embeddings queries. + :param inputs: Any (optional) + Tensor-based input in columnar format. + :param instances: List[Any] (optional) + Tensor-based input in row format. + :param max_tokens: int (optional) + The max tokens field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is an integer and should only be used with other chat/completions query fields. + :param messages: List[:class:`ChatMessage`] (optional) + The messages field used ONLY for __chat external & foundation model__ serving endpoints. This is a + map of strings and should only be used with other chat query fields. + :param n: int (optional) + The n (number of candidates) field used ONLY for __completions__ and __chat external & foundation + model__ serving endpoints. This is an integer between 1 and 5 with a default of 1 and should only be + used with other chat/completions query fields. + :param prompt: Any (optional) + The prompt string (or array of strings) field used ONLY for __completions external & foundation + model__ serving endpoints and should only be used with other completions query fields. + :param stop: List[str] (optional) + The stop sequences field used ONLY for __completions__ and __chat external & foundation model__ + serving endpoints. This is a list of strings and should only be used with other chat/completions + query fields. + :param stream: bool (optional) + The stream field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is a boolean defaulting to false and should only be used with other chat/completions + query fields. + :param temperature: float (optional) + The temperature field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is a float between 0.0 and 2.0 with a default of 1.0 and should only be used with + other chat/completions query fields. + + :returns: :class:`QueryEndpointResponse` + """ + body = {} + if dataframe_records is not None: body['dataframe_records'] = [v for v in dataframe_records] + if dataframe_split is not None: body['dataframe_split'] = dataframe_split.as_dict() + if extra_params is not None: body['extra_params'] = extra_params + if input is not None: body['input'] = input + if inputs is not None: body['inputs'] = inputs + if instances is not None: body['instances'] = [v for v in instances] + if max_tokens is not None: body['max_tokens'] = max_tokens + if messages is not None: body['messages'] = [v.as_dict() for v in messages] + if n is not None: body['n'] = n + if prompt is not None: body['prompt'] = prompt + if stop is not None: body['stop'] = [v for v in stop] + if stream is not None: body['stream'] = stream + if temperature is not None: body['temperature'] = temperature + + def info_getter(): + response = self._control_plane.get(name=name, ) + if response.data_plane_info is None: + raise Exception("Resource does not support direct Data Plane access") + return response.data_plane_info.query_info + + get_params = [name, ] + data_plane_details = self._data_plane_service.get_data_plane_details('query', get_params, info_getter, + self._api.get_oauth_token) + token = data_plane_details.token + + def auth(r: requests.PreparedRequest) -> requests.PreparedRequest: + authorization = f"{token.token_type} {token.access_token}" + r.headers["Authorization"] = authorization + return r + + headers = {'Accept': 'application/json', 'Content-Type': 'application/json', } + response_headers = ['served-model-name', ] + res = self._api.do('POST', + url=data_plane_details.endpoint_url, + body=body, + headers=headers, + response_headers=response_headers, + auth=auth) + return QueryEndpointResponse.from_dict(res) diff --git a/databricks/sdk/service/sql.py b/databricks/sdk/service/sql.py index fa7f93f6..b363ab7d 100755 --- a/databricks/sdk/service/sql.py +++ b/databricks/sdk/service/sql.py @@ -360,6 +360,7 @@ def from_dict(cls, d: Dict[str, any]) -> ChannelInfo: class ChannelName(Enum): + """Name of the channel""" CHANNEL_NAME_CURRENT = 'CHANNEL_NAME_CURRENT' CHANNEL_NAME_CUSTOM = 'CHANNEL_NAME_CUSTOM' diff --git a/docs/dbdataclasses/sql.rst b/docs/dbdataclasses/sql.rst index adf3ced5..fe1469a3 100644 --- a/docs/dbdataclasses/sql.rst +++ b/docs/dbdataclasses/sql.rst @@ -64,6 +64,8 @@ These dataclasses are used in the SDK to represent API requests and responses fo .. py:class:: ChannelName + Name of the channel + .. py:attribute:: CHANNEL_NAME_CURRENT :value: "CHANNEL_NAME_CURRENT" diff --git a/docs/workspace/catalog/endpoints.rst b/docs/workspace/catalog/endpoints.rst new file mode 100644 index 00000000..8c6efba4 --- /dev/null +++ b/docs/workspace/catalog/endpoints.rst @@ -0,0 +1,35 @@ +``w.endpoints``: Online Endpoints +================================= +.. currentmodule:: databricks.sdk.service.catalog + +.. py:class:: EndpointsAPI + + Endpoints are used to connect to PG clusters. + + .. py:method:: create( [, endpoint: Optional[Endpoint]]) -> Endpoint + + Create an Endpoint. + + :param endpoint: :class:`Endpoint` (optional) + Endpoint + + :returns: :class:`Endpoint` + + + .. py:method:: delete(name: str) + + Delete an Endpoint. + + :param name: str + + + + + .. py:method:: get(name: str) -> Endpoint + + Get an Endpoint. + + :param name: str + + :returns: :class:`Endpoint` + \ No newline at end of file diff --git a/docs/workspace/serving/index.rst b/docs/workspace/serving/index.rst index ce3d216f..1d0bdf7f 100644 --- a/docs/workspace/serving/index.rst +++ b/docs/workspace/serving/index.rst @@ -8,4 +8,5 @@ Use real-time inference for machine learning :maxdepth: 1 apps - serving_endpoints \ No newline at end of file + serving_endpoints + serving_endpoints_data_plane \ No newline at end of file diff --git a/docs/workspace/serving/serving_endpoints_data_plane.rst b/docs/workspace/serving/serving_endpoints_data_plane.rst new file mode 100644 index 00000000..8fb09e7f --- /dev/null +++ b/docs/workspace/serving/serving_endpoints_data_plane.rst @@ -0,0 +1,59 @@ +``w.serving_endpoints_data_plane``: Serving endpoints DataPlane +=============================================================== +.. currentmodule:: databricks.sdk.service.serving + +.. py:class:: ServingEndpointsDataPlaneAPI + + Serving endpoints DataPlane provides a set of operations to interact with data plane endpoints for Serving + endpoints service. + + .. py:method:: query(name: str [, dataframe_records: Optional[List[Any]], dataframe_split: Optional[DataframeSplitInput], extra_params: Optional[Dict[str, str]], input: Optional[Any], inputs: Optional[Any], instances: Optional[List[Any]], max_tokens: Optional[int], messages: Optional[List[ChatMessage]], n: Optional[int], prompt: Optional[Any], stop: Optional[List[str]], stream: Optional[bool], temperature: Optional[float]]) -> QueryEndpointResponse + + Query a serving endpoint. + + :param name: str + The name of the serving endpoint. This field is required. + :param dataframe_records: List[Any] (optional) + Pandas Dataframe input in the records orientation. + :param dataframe_split: :class:`DataframeSplitInput` (optional) + Pandas Dataframe input in the split orientation. + :param extra_params: Dict[str,str] (optional) + The extra parameters field used ONLY for __completions, chat,__ and __embeddings external & + foundation model__ serving endpoints. This is a map of strings and should only be used with other + external/foundation model query fields. + :param input: Any (optional) + The input string (or array of strings) field used ONLY for __embeddings external & foundation + model__ serving endpoints and is the only field (along with extra_params if needed) used by + embeddings queries. + :param inputs: Any (optional) + Tensor-based input in columnar format. + :param instances: List[Any] (optional) + Tensor-based input in row format. + :param max_tokens: int (optional) + The max tokens field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is an integer and should only be used with other chat/completions query fields. + :param messages: List[:class:`ChatMessage`] (optional) + The messages field used ONLY for __chat external & foundation model__ serving endpoints. This is a + map of strings and should only be used with other chat query fields. + :param n: int (optional) + The n (number of candidates) field used ONLY for __completions__ and __chat external & foundation + model__ serving endpoints. This is an integer between 1 and 5 with a default of 1 and should only be + used with other chat/completions query fields. + :param prompt: Any (optional) + The prompt string (or array of strings) field used ONLY for __completions external & foundation + model__ serving endpoints and should only be used with other completions query fields. + :param stop: List[str] (optional) + The stop sequences field used ONLY for __completions__ and __chat external & foundation model__ + serving endpoints. This is a list of strings and should only be used with other chat/completions + query fields. + :param stream: bool (optional) + The stream field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is a boolean defaulting to false and should only be used with other chat/completions + query fields. + :param temperature: float (optional) + The temperature field used ONLY for __completions__ and __chat external & foundation model__ serving + endpoints. This is a float between 0.0 and 2.0 with a default of 1.0 and should only be used with + other chat/completions query fields. + + :returns: :class:`QueryEndpointResponse` + \ No newline at end of file diff --git a/tests/test_data_plane.py b/tests/test_data_plane.py new file mode 100644 index 00000000..a7465896 --- /dev/null +++ b/tests/test_data_plane.py @@ -0,0 +1,59 @@ +from datetime import datetime, timedelta + +from databricks.sdk.data_plane import DataPlaneService +from databricks.sdk.oauth import Token +from databricks.sdk.service.oauth2 import DataPlaneInfo + +info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url") + +token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) + + +class MockRefresher: + + def __init__(self, expected: str): + self._expected = expected + + def __call__(self, auth_details: str) -> Token: + assert self._expected == auth_details + return token + + +def throw_exception(): + raise Exception("Expected value to be cached") + + +def test_not_cached(): + data_plane = DataPlaneService() + res = data_plane.get_data_plane_details("method", ["params"], lambda: info, + lambda a: MockRefresher(info.authorization_details).__call__(a)) + assert res.endpoint_url == info.endpoint_url + assert res.token == token + + +def test_token_expired(): + expired = Token(access_token="expired", token_type="type", expiry=datetime.now() + timedelta(hours=-1)) + data_plane = DataPlaneService() + data_plane._tokens["method/params"] = expired + res = data_plane.get_data_plane_details("method", ["params"], lambda: info, + lambda a: MockRefresher(info.authorization_details).__call__(a)) + assert res.endpoint_url == info.endpoint_url + assert res.token == token + + +def test_info_cached(): + data_plane = DataPlaneService() + data_plane._data_plane_info["method/params"] = info + res = data_plane.get_data_plane_details("method", ["params"], throw_exception, + lambda a: MockRefresher(info.authorization_details).__call__(a)) + assert res.endpoint_url == info.endpoint_url + assert res.token == token + + +def test_token_cached(): + data_plane = DataPlaneService() + data_plane._data_plane_info["method/params"] = info + data_plane._tokens["method/params"] = token + res = data_plane.get_data_plane_details("method", ["params"], throw_exception, throw_exception) + assert res.endpoint_url == info.endpoint_url + assert res.token == token