Skip to content

Commit

Permalink
Allow connecting to external servers enrolled as ZenML Pro tenants
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Dec 16, 2024
1 parent 1e51454 commit 2d20c2e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
18 changes: 10 additions & 8 deletions src/zenml/cli/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def connect_to_server(
api_key: Optional[str] = None,
verify_ssl: Union[str, bool] = True,
refresh: bool = False,
pro_server: bool = False,
) -> None:
"""Connect the client to a ZenML server or a SQL database.
Expand All @@ -154,6 +155,7 @@ def connect_to_server(
verify_ssl: Whether to verify the server's TLS certificate. If a string
is passed, it is interpreted as the path to a CA bundle file.
refresh: Whether to force a new login flow with the ZenML server.
pro_server: Whether the server is a ZenML Pro server.
"""
from zenml.login.credentials_store import get_credentials_store
from zenml.zen_stores.base_zen_store import BaseZenStore
Expand All @@ -170,7 +172,12 @@ def connect_to_server(
f"Authenticating to ZenML server '{url}' using an API key..."
)
credentials_store.set_api_key(url, api_key)
elif not is_zenml_pro_server_url(url):
elif pro_server:
# We don't have to do anything here assuming the user has already
# logged in to the ZenML Pro server using the ZenML Pro web login
# flow.
cli_utils.declare(f"Authenticating to ZenML server '{url}'...")
else:
if refresh or not credentials_store.has_valid_authentication(url):
cli_utils.declare(
f"Authenticating to ZenML server '{url}' using the web "
Expand All @@ -179,11 +186,6 @@ def connect_to_server(
web_login(url=url, verify_ssl=verify_ssl)
else:
cli_utils.declare(f"Connecting to ZenML server '{url}'...")
else:
# We don't have to do anything here assuming the user has already
# logged in to the ZenML Pro server using the ZenML Pro web login
# flow.
cli_utils.declare(f"Authenticating to ZenML server '{url}'...")

rest_store_config = RestZenStoreConfiguration(
url=url,
Expand Down Expand Up @@ -277,7 +279,7 @@ def connect_to_pro_server(
# server to connect to.
if api_key:
if server_url:
connect_to_server(server_url, api_key=api_key)
connect_to_server(server_url, api_key=api_key, pro_server=True)
return
else:
raise ValueError(
Expand Down Expand Up @@ -405,7 +407,7 @@ def connect_to_pro_server(
f"Connecting to ZenML Pro server: {server.name} [{str(server.id)}] "
)

connect_to_server(server.url, api_key=api_key)
connect_to_server(server.url, api_key=api_key, pro_server=True)

# Update the stored server info with more accurate data taken from the
# ZenML Pro tenant object.
Expand Down
8 changes: 8 additions & 0 deletions src/zenml/models/v2/misc/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ def is_local(self) -> bool:
# server ID is the same as the local client (user) ID.
return self.id == GlobalConfiguration().user_id

def is_pro_server(self) -> bool:
"""Return whether the server is a ZenML Pro server.
Returns:
True if the server is a ZenML Pro server, False otherwise.
"""
return self.deployment_type == ServerDeploymentType.CLOUD


class ServerLoadInfo(BaseModel):
"""Domain model for ZenML server load information."""
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/zen_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,12 +1003,12 @@ async def __call__(self, request: Request) -> Optional[str]:


def oauth2_authentication(
request: Request,
token: str = Depends(
CookieOAuth2TokenBearer(
tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
)
),
request: Request = Depends(),
) -> AuthContext:
"""Authenticates any request to the ZenML server with OAuth2 JWT tokens.
Expand Down
34 changes: 17 additions & 17 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ class RestZenStore(BaseZenStore):
CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration
_api_token: Optional[APIToken] = None
_session: Optional[requests.Session] = None
_server_info: Optional[ServerModel] = None

# ====================================
# ZenML Store interface implementation
Expand All @@ -469,7 +470,7 @@ def _initialize(self) -> None:
"""
try:
client_version = zenml.__version__
server_version = self.get_store_info().version
server_version = self.server_info.version

# Handle cases where the ZenML server is not available
except ConnectionError as e:
Expand Down Expand Up @@ -522,22 +523,34 @@ def _initialize(self) -> None:
ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING,
)

@property
def server_info(self) -> ServerModel:
"""Get cached information about the server.
Returns:
Cached information about the server.
"""
if self._server_info is None:
return self.get_store_info()
return self._server_info

def get_store_info(self) -> ServerModel:
"""Get information about the server.
Returns:
Information about the server.
"""
body = self.get(INFO)
return ServerModel.model_validate(body)
self._server_info = ServerModel.model_validate(body)
return self._server_info

def get_deployment_id(self) -> UUID:
"""Get the ID of the deployment.
Returns:
The ID of the deployment.
"""
return self.get_store_info().id
return self.server_info.id

# -------------------- Server Settings --------------------

Expand Down Expand Up @@ -4028,19 +4041,6 @@ def get_or_generate_api_token(self) -> str:
token = credentials.api_token if credentials else None
if credentials and token and not token.expired:
self._api_token = token

# Populate the server info in the credentials store if it is
# not already present
if not credentials.server_id:
try:
server_info = self.get_store_info()
except Exception as e:
logger.warning(f"Failed to get server info: {e}.")
else:
credentials_store.update_server_info(
self.url, server_info
)

return self._api_token.access_token

# Token is expired or not found in the cache. Time to get a new one.
Expand Down Expand Up @@ -4084,7 +4084,7 @@ def get_or_generate_api_token(self) -> str:
"username": username,
"password": password,
}
elif is_zenml_pro_server_url(self.url):
elif self.server_info.is_pro_server():
# ZenML Pro tenants use a proprietary authorization grant
# where the ZenML Pro API session token is exchanged for a
# regular ZenML server access token.
Expand Down

0 comments on commit 2d20c2e

Please sign in to comment.