diff --git a/docs/source/dev/types.rst b/docs/source/dev/types.rst index f9314f14..ca39eab5 100644 --- a/docs/source/dev/types.rst +++ b/docs/source/dev/types.rst @@ -67,3 +67,8 @@ Timeout ======= .. autoclass:: uplink.Timeout + +Context +======= + +.. autoclass:: uplink.Context diff --git a/tests/integration/test_handlers.py b/tests/integration/test_handlers.py index 968ad376..d0bdfe77 100644 --- a/tests/integration/test_handlers.py +++ b/tests/integration/test_handlers.py @@ -35,22 +35,28 @@ def handle_error(exc_type, exc_value, exc_tb): class Calendar(uplink.Consumer): @handle_response_with_consumer - @uplink.get("/calendar/{todo_id}") + @uplink.get("todos/{todo_id}") def get_todo(self, todo_id): pass @handle_response - @uplink.get("/calendar/{name}") + @uplink.get("months/{name}/todos") def get_month(self, name): pass + @handle_response_with_consumer + @handle_response + @uplink.get("months/{month}/days/{day}/todos") + def get_day(self, month, day): + pass + @handle_error_with_consumer - @uplink.get("/calendar/{user_id}") + @uplink.get("users/{user_id}") def get_user(self, user_id): pass @handle_error - @uplink.get("/calendar/{event_id}") + @uplink.get("events/{event_id}") def get_event(self, event_id): pass @@ -76,6 +82,17 @@ def test_response_handler(mock_client): assert response.flagged is True +def test_multiple_response_handlers(mock_client): + calendar = Calendar(base_url=BASE_URL, client=mock_client) + + # Run + response = calendar.get_day("September", 2) + + # Verify + assert response.flagged + assert calendar.flagged + + def test_error_handler_with_consumer(mock_client): # Setup: raise specific exception expected_error = IOError() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9224f419..1851a32f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,12 +15,6 @@ def http_client_mock(mocker): return client -@pytest.fixture -def request_mock(mocker): - # TODO: Remove - return None - - @pytest.fixture def transaction_hook_mock(mocker): return mocker.Mock(spec=hooks.TransactionHook) @@ -70,6 +64,7 @@ def uplink_builder_mock(mocker): def request_builder(mocker): builder = mocker.MagicMock(spec=helpers.RequestBuilder) builder.info = collections.defaultdict(dict) + builder.context = {} builder.get_converter.return_value = lambda x: x builder.client.exceptions = Exceptions() return builder diff --git a/tests/unit/test_arguments.py b/tests/unit/test_arguments.py index 826fcd62..c8e9be00 100644 --- a/tests/unit/test_arguments.py +++ b/tests/unit/test_arguments.py @@ -276,9 +276,7 @@ def test_modify_request_with_mismatched_encoding(self, request_builder): ) def test_skip_none(self, request_builder): - arguments.Query("name").modify_request( - request_builder, None - ) + arguments.Query("name").modify_request(request_builder, None) assert request_builder.info["params"] == {} def test_encode_none(self, request_builder): @@ -433,3 +431,25 @@ class TestTimeout(ArgumentTestCase, FuncDecoratorTestCase): def test_modify_request(self, request_builder): arguments.Timeout().modify_request(request_builder, 10) assert request_builder.info["timeout"] == 10 + + +class TestContext(ArgumentTestCase, FuncDecoratorTestCase): + type_cls = arguments.Context + expected_converter_key = keys.Identity() + + def test_modify_request(self, request_builder): + arguments.Context("key").modify_request(request_builder, "value") + assert request_builder.context["key"] == "value" + + +class TestContextMap(ArgumentTestCase, FuncDecoratorTestCase): + type_cls = arguments.ContextMap + expected_converter_key = keys.Identity() + + def test_modify_request(self, request_builder): + arguments.ContextMap().modify_request(request_builder, {"key": "value"}) + assert request_builder.context == {"key": "value"} + + def test_modify_request_not_mapping(self, request_builder): + with pytest.raises(TypeError): + arguments.ContextMap().modify_request(request_builder, "value") diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index 355cbca7..d4652f11 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -48,7 +48,7 @@ def test_prepare_request_with_transaction_hook( request_builder.url = "/example/path" request_builder.request_template = "request_template" uplink_builder.base_url = "https://example.com" - uplink_builder.add_hook(transaction_hook_mock) + request_builder.transaction_hooks = [transaction_hook_mock] request_preparer = builder.RequestPreparer(uplink_builder) execution_builder = mocker.Mock(spec=io.RequestExecutionBuilder) request_preparer.prepare_request(request_builder, execution_builder) @@ -61,10 +61,25 @@ def test_prepare_request_with_transaction_hook( execution_builder.with_io.assert_called_with(uplink_builder.client.io()) execution_builder.with_template(request_builder.request_template) - def test_create_request_builder(self, uplink_builder, request_definition): + def test_create_request_builder(self, mocker, request_definition): + uplink_builder = mocker.Mock(spec=builder.Builder) + uplink_builder.converters = () + uplink_builder.hooks = () + request_definition.make_converter_registry.return_value = {} + request_preparer = builder.RequestPreparer(uplink_builder) + request = request_preparer.create_request_builder(request_definition) + assert isinstance(request, helpers.RequestBuilder) + + def test_create_request_builder_with_session_hooks( + self, mocker, request_definition, transaction_hook_mock + ): + uplink_builder = mocker.Mock(spec=builder.Builder) + uplink_builder.converters = () + uplink_builder.hooks = (transaction_hook_mock,) request_definition.make_converter_registry.return_value = {} request_preparer = builder.RequestPreparer(uplink_builder) request = request_preparer.create_request_builder(request_definition) + assert transaction_hook_mock.audit_request.called assert isinstance(request, helpers.RequestBuilder) diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 6720292a..7f8f7fe9 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -43,3 +43,13 @@ def test_add_transaction_hook(self, transaction_hook_mock): # Verify assert list(builder.transaction_hooks) == [transaction_hook_mock] + + def test_context(self): + # Setup + builder = helpers.RequestBuilder(None, {}, "base_url") + + # Run + builder.context["key"] = "value" + + # Verify + assert builder.context["key"] == "value" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 16036b2b..eb6f8751 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -53,3 +53,15 @@ def test_auth_set(uplink_builder_mock): # Verify assert ("username", "password") == uplink_builder_mock.auth + + +def test_context(uplink_builder_mock): + # Setup + sess = session.Session(uplink_builder_mock) + + # Run + sess.context["key"] = "value" + + # Verify + assert uplink_builder_mock.add_hook.called + assert sess.context == {"key": "value"} diff --git a/uplink/__init__.py b/uplink/__init__.py index 9e8ec9cb..4c90a818 100644 --- a/uplink/__init__.py +++ b/uplink/__init__.py @@ -41,6 +41,7 @@ Body, Url, Timeout, + Context, ) from uplink.ratelimit import ratelimit from uplink.retry import retry @@ -90,6 +91,7 @@ "Body", "Url", "Timeout", + "Context", "retry", "ratelimit", ] diff --git a/uplink/arguments.py b/uplink/arguments.py index 979ec53d..a888c10e 100644 --- a/uplink/arguments.py +++ b/uplink/arguments.py @@ -24,6 +24,7 @@ "Body", "Url", "Timeout", + "Context", ] @@ -683,7 +684,7 @@ def _modify_request(cls, request_builder, value): class Timeout(FuncDecoratorMixin, ArgumentAnnotation): """ - Pass a timeout as a method argument at runtime. + Passes a timeout as a method argument at runtime. While :py:class:`uplink.timeout` attaches static timeout to all requests sent from a consumer method, this class turns a method argument into a @@ -696,7 +697,6 @@ class Timeout(FuncDecoratorMixin, ArgumentAnnotation): def get_posts(self, timeout: Timeout() = 60): \"""Fetch all posts for the current users giving up after given number of seconds.\""" - """ @property @@ -711,3 +711,78 @@ def converter_key(self): def _modify_request(self, request_builder, value): """Modifies request timeout.""" request_builder.info["timeout"] = value + + +class Context(FuncDecoratorMixin, NamedArgument): + """ + Defines a name-value pair that is accessible to middleware at + runtime. + + Request middleware can leverage this annotation to give users + control over the middleware's behavior. + + Example: + Consider a custom decorator :obj:`@cache` (this would be a + subclass of :class:`uplink.decorators.MethodAnnotation`): + + .. code-block:: python + + @cache(hours=3) + @get("users/user_id") + def get_user(self, user_id) + \"""Retrieves a single user.\""" + + As its name suggests, the :obj:`@cache` decorator enables + caching server responses so that, once a request is cached, + subsequent identical requests can be served by the cache + provider rather than adding load to the upstream service. + + Importantly, the writers of the :obj:`@cache` decorators can + allow users to pass their own cache provider implementation + through an argument annotated with :class:`Context `: + + .. code-block:: python + + @cache(hours=3) + @get("users/user_id") + def get_user(self, user_id, cache_provider: Context) + \"""Retrieves a single user.\""" + + To add a name-value pair to the context of any request made from + a :class:`Consumer ` instance, you + can use the :attr:`Consumer.session.context + ` property. Alternatively, you + can annotate a constructor argument of a :class:`Consumer + ` subclass with :class:`Context + `, as explained + :ref:`here `. + """ + + @property + def converter_key(self): + """Do not convert passed argument.""" + return keys.Identity() + + def _modify_request(self, request_builder, value): + """Sets the name-value pair in the context.""" + request_builder.context[self.name] = value + + +class ContextMap(FuncDecoratorMixin, ArgumentAnnotation): + """ + Defines a mapping of name-value pairs that are accessible to + middleware at runtime. + """ + + @property + def converter_key(self): + """Do not convert passed argument.""" + return keys.Identity() + + def _modify_request(self, request_builder, value): + """Updates the context with the given name-value pairs.""" + if not isinstance(value, collections.Mapping): + raise TypeError( + "ContextMap requires a mapping; got %s instead.", type(value) + ) + request_builder.context.update(value) diff --git a/uplink/builder.py b/uplink/builder.py index 79343197..12afdf54 100644 --- a/uplink/builder.py +++ b/uplink/builder.py @@ -22,41 +22,55 @@ class RequestPreparer(object): def __init__(self, builder, consumer=None): - self._hooks = list(builder.hooks) self._client = builder.client self._base_url = str(builder.base_url) self._converters = list(builder.converters) self._auth = builder.auth self._consumer = consumer + if builder.hooks: + self._session_chain = hooks_.TransactionHookChain(*builder.hooks) + else: + self._session_chain = None + def _join_url_with_base(self, url): return utils.urlparse.urljoin(self._base_url, url) - def _get_hook_chain(self, contract): + @staticmethod + def _get_request_hooks(contract): chain = list(contract.transaction_hooks) if callable(contract.return_type): chain.append(hooks_.ResponseHandler(contract.return_type)) - chain.extend(self._hooks) return chain def _wrap_hook(self, func): return functools.partial(func, self._consumer) - def apply_hooks(self, execution_builder, chain, request_builder): - hook = hooks_.TransactionHookChain(*chain) - hook.audit_request(self._consumer, request_builder) - if hook.handle_response is not None: + def apply_hooks(self, execution_builder, chain): + # TODO: + # Instead of creating a TransactionChain, we could simply + # add each response and error handler in the chain to the + # execution builder. This would allow heterogenous response + # and error handlers. Right now, the TransactionChain + # enforces that all response/error handlers are blocking if + # any response/error handler is blocking, which is + # unnecessary now that we delegate execution to an IO layer. + if chain.handle_response is not None: execution_builder.with_callbacks( - self._wrap_hook(hook.handle_response) + self._wrap_hook(chain.handle_response) ) - execution_builder.with_errbacks(self._wrap_hook(hook.handle_exception)) + execution_builder.with_errbacks(self._wrap_hook(chain.handle_exception)) def prepare_request(self, request_builder, execution_builder): request_builder.url = self._join_url_with_base(request_builder.url) self._auth(request_builder) - chain = self._get_hook_chain(request_builder) - if chain: - self.apply_hooks(execution_builder, chain, request_builder) + request_hooks = self._get_request_hooks(request_builder) + if request_hooks: + chain = hooks_.TransactionHookChain(*request_hooks) + chain.audit_request(self._consumer, request_builder) + self.apply_hooks(execution_builder, chain) + if self._session_chain: + self.apply_hooks(execution_builder, self._session_chain) execution_builder.with_client(self._client) execution_builder.with_io(self._client.io()) @@ -64,7 +78,10 @@ def prepare_request(self, request_builder, execution_builder): def create_request_builder(self, definition): registry = definition.make_converter_registry(self._converters) - return helpers.RequestBuilder(self._client, registry, self._base_url) + req = helpers.RequestBuilder(self._client, registry, self._base_url) + if self._session_chain: + self._session_chain.audit_request(self._consumer, req) + return req class CallFactory(object): diff --git a/uplink/clients/io/execution.py b/uplink/clients/io/execution.py index b65270b3..6396babd 100644 --- a/uplink/clients/io/execution.py +++ b/uplink/clients/io/execution.py @@ -26,11 +26,11 @@ def with_io(self, io): return self def with_callbacks(self, *callbacks): - self._callbacks = list(callbacks) + self._callbacks.extend(callbacks) return self def with_errbacks(self, *errbacks): - self._errbacks = list(errbacks) + self._errbacks.extend(errbacks) return self def build(self): diff --git a/uplink/helpers.py b/uplink/helpers.py index 2929ec36..82405107 100644 --- a/uplink/helpers.py +++ b/uplink/helpers.py @@ -47,6 +47,7 @@ def __init__(self, client, converter_registry, base_url): # TODO: Pass this in as constructor parameter # TODO: Delegate instantiations to uplink.HTTPClientAdapter self._info = collections.defaultdict(dict) + self._context = {} self._converter_registry = converter_registry self._transaction_hooks = [] @@ -80,6 +81,10 @@ def url(self, url): def info(self): return self._info + @property + def context(self): + return self._context + @property def transaction_hooks(self): return iter(self._transaction_hooks) diff --git a/uplink/session.py b/uplink/session.py index 5797fce3..e67d4d05 100644 --- a/uplink/session.py +++ b/uplink/session.py @@ -15,6 +15,7 @@ def __init__(self, builder): self.__builder = builder self.__params = None self.__headers = None + self.__context = None def create(self, consumer, definition): return self.__builder.build(definition, consumer) @@ -48,6 +49,17 @@ def params(self): self.inject(arguments.QueryMap().with_value(self.__params)) return self.__params + @property + def context(self): + """ + A dictionary of name-value pairs that are made available to + request middleware. + """ + if self.__context is None: + self.__context = {} + self.inject(arguments.ContextMap().with_value(self.__context)) + return self.__context + @property def auth(self): """The authentication object for this consumer instance."""