Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry with new Access Token on 419 response #340

Merged
merged 10 commits into from
Feb 20, 2024
28 changes: 21 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from pydantic import Field, ValidationError
from requests import HTTPError, Session
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt

from pyiceberg import __version__
from pyiceberg.catalog import (
Expand Down Expand Up @@ -211,6 +212,11 @@ def __init__(self, name: str, **properties: str):
self._fetch_config()
self._session = self._create_session()

@staticmethod
def _retry_hook(retry_state: RetryCallState) -> None:
rest_catalog: RestCatalog = retry_state.args[0]
rest_catalog._refresh_token() # pylint: disable=protected-access

def _create_session(self) -> Session:
"""Create a request session with provided catalog configuration."""
session = Session()
Expand Down Expand Up @@ -438,6 +444,16 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
catalog=self,
)

def _refresh_token(self) -> None:
anupam-saini marked this conversation as resolved.
Show resolved Hide resolved
session: Session = self._session
# If we have credentials, fetch a new token
if CREDENTIAL in self.properties:
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
# Set Auth token for subsequent calls in the session
if token := self.properties.get(TOKEN):
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def create_table(
self,
identifier: Union[str, Identifier],
Expand Down Expand Up @@ -472,6 +488,7 @@ def create_table(
table_response = TableResponse(**response.json())
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
anupam-saini marked this conversation as resolved.
Show resolved Hide resolved
def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
"""Register a new table using existing metadata.

Expand Down Expand Up @@ -503,6 +520,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
table_response = TableResponse(**response.json())
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -513,6 +531,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
return [(*table.namespace, table.name) for table in ListTablesResponse(**response.json()).identifiers]

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def load_table(self, identifier: Union[str, Identifier]) -> Table:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
response = self._session.get(
Expand All @@ -526,6 +545,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
table_response = TableResponse(**response.json())
return self._response_to_table(identifier_tuple, table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
response = self._session.delete(
Expand All @@ -538,9 +558,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchTableError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def purge_table(self, identifier: Union[str, Identifier]) -> None:
self.drop_table(identifier=identifier, purge_requested=True)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
from_identifier_tuple = self.identifier_to_tuple_without_catalog(from_identifier)
payload = {
Expand Down Expand Up @@ -585,6 +607,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
)
return CommitTableResponse(**response.json())

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
payload = {"namespace": namespace_tuple, "properties": properties}
Expand All @@ -594,6 +617,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceAlreadyExistsError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -603,6 +627,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
namespace_tuple = self.identifier_to_tuple(namespace)
response = self._session.get(
Expand All @@ -620,6 +645,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
namespaces = ListNamespaceResponse(**response.json())
return [namespace_tuple + child_namespace for child_namespace in namespaces.namespaces]

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -631,6 +657,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper

return NamespaceResponse(**response.json()).properties

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
def update_namespace_properties(
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
) -> PropertiesUpdateSummary:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ sortedcontainers = "2.4.0"
fsspec = ">=2023.1.0,<2024.1.0"
pyparsing = ">=3.1.0,<4.0.0"
zstandard = ">=0.13.0,<1.0.0"
tenacity = ">=8.2.3,<9.0.0"
pyarrow = { version = ">=9.0.0,<16.0.0", optional = true }
pandas = { version = ">=1.0.0,<3.0.0", optional = true }
duckdb = { version = ">=0.5.0,<1.0.0", optional = true }
Expand Down Expand Up @@ -295,6 +296,10 @@ ignore_missing_imports = true
module = "setuptools.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "tenacity.*"
ignore_missing_imports = true

[tool.coverage.run]
source = ['pyiceberg/']

Expand Down
72 changes: 72 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyiceberg.catalog import PropertiesUpdateSummary, Table, load_catalog
from pyiceberg.catalog.rest import AUTH_URL, RestCatalog
from pyiceberg.exceptions import (
AuthorizationExpiredError,
NamespaceAlreadyExistsError,
NoSuchNamespaceError,
NoSuchTableError,
Expand Down Expand Up @@ -266,6 +267,48 @@ def test_list_namespace_with_parent_200(rest_mock: Mocker) -> None:
]


def test_list_namespaces_419(rest_mock: Mocker) -> None:
new_token = "new_jwt_token"
new_header = dict(TEST_HEADERS)
new_header["Authorization"] = f"Bearer {new_token}"

rest_mock.post(
f"{TEST_URI}v1/namespaces",
json={
"error": {
"message": "Authorization expired.",
"type": "AuthorizationExpiredError",
"code": 419,
}
},
status_code=419,
request_headers=TEST_HEADERS,
)
rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": new_token,
"token_type": "Bearer",
"expires_in": 86400,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
},
status_code=200,
)
rest_mock.get(
f"{TEST_URI}v1/namespaces",
json={"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]},
status_code=200,
request_headers=new_header,
)
catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS)
assert catalog.list_namespaces() == [
("default",),
("examples",),
("fokko",),
("system",),
]


def test_create_namespace_200(rest_mock: Mocker) -> None:
namespace = "leden"
rest_mock.post(
Expand Down Expand Up @@ -517,6 +560,35 @@ def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> Non
assert "Table already exists" in str(e.value)


def test_create_table_419(rest_mock: Mocker, table_schema_simple: Schema) -> None:
rest_mock.post(
f"{TEST_URI}v1/namespaces/fokko/tables",
json={
"error": {
"message": "Authorization expired.",
"type": "AuthorizationExpiredError",
"code": 419,
}
},
status_code=419,
request_headers=TEST_HEADERS,
)

with pytest.raises(AuthorizationExpiredError) as e:
RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_table(
identifier=("fokko", "fokko2"),
schema=table_schema_simple,
location=None,
partition_spec=PartitionSpec(
PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id")
),
sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())),
properties={"owner": "fokko"},
)
assert "Authorization expired" in str(e.value)
assert rest_mock.call_count == 3


def test_register_table_200(
rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
) -> None:
Expand Down
Loading