Skip to content

Commit

Permalink
Merge pull request #3029 from plotly/feat/hooks
Browse files Browse the repository at this point in the history
Add hooks
  • Loading branch information
T4rk1n authored Nov 8, 2024
2 parents 9705ae5 + 7e2f37e commit 5fe20c7
Show file tree
Hide file tree
Showing 10 changed files with 541 additions and 37 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ disable=fixme,
unnecessary-lambda-assignment,
broad-exception-raised,
consider-using-generator,
too-many-ancestors


# Enable the message, report, category or checker with the given id(s). You can
Expand Down
2 changes: 2 additions & 0 deletions dash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from ._patch import Patch # noqa: F401,E402
from ._jupyter import jupyter_dash # noqa: F401,E402

from ._hooks import hooks # noqa: F401,E402

ctx = callback_context


Expand Down
12 changes: 9 additions & 3 deletions dash/_callback.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import collections
import hashlib
from functools import wraps
from typing import Callable, Optional, Any, List, Tuple

from typing import Callable, Optional, Any, List, Tuple, Union


import flask

from .dependencies import (
handle_callback_args,
handle_grouped_callback_args,
Output,
ClientsideFunction,
Input,
)
from .development.base_component import ComponentRegistry
Expand Down Expand Up @@ -210,7 +213,10 @@ def validate_long_inputs(deps):
)


def clientside_callback(clientside_function, *args, **kwargs):
ClientsideFuncType = Union[str, ClientsideFunction]


def clientside_callback(clientside_function: ClientsideFuncType, *args, **kwargs):
return register_clientside_callback(
GLOBAL_CALLBACK_LIST,
GLOBAL_CALLBACK_MAP,
Expand Down Expand Up @@ -597,7 +603,7 @@ def register_clientside_callback(
callback_map,
config_prevent_initial_callbacks,
inline_scripts,
clientside_function,
clientside_function: ClientsideFuncType,
*args,
**kwargs,
):
Expand Down
231 changes: 231 additions & 0 deletions dash/_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import typing as _t

from importlib import metadata as _importlib_metadata

import typing_extensions as _tx
import flask as _f

from .exceptions import HookError
from .resources import ResourceType
from ._callback import ClientsideFuncType

if _t.TYPE_CHECKING:
from .dash import Dash
from .development.base_component import Component

ComponentType = _t.TypeVar("ComponentType", bound=Component)
LayoutType = _t.Union[ComponentType, _t.List[ComponentType]]
else:
LayoutType = None
ComponentType = None
Dash = None


HookDataType = _tx.TypeVar("HookDataType")


# pylint: disable=too-few-public-methods
class _Hook(_tx.Generic[HookDataType]):
def __init__(self, func, priority=0, final=False, data: HookDataType = None):
self.func = func
self.final = final
self.data = data
self.priority = priority

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)


class _Hooks:
def __init__(self) -> None:
self._ns = {
"setup": [],
"layout": [],
"routes": [],
"error": [],
"callback": [],
"index": [],
}
self._js_dist = []
self._css_dist = []
self._clientside_callbacks: _t.List[
_t.Tuple[ClientsideFuncType, _t.Any, _t.Any]
] = []

# final hooks are a single hook added to the end of regular hooks.
self._finals = {}

def add_hook(
self,
hook: str,
func: _t.Callable,
priority: _t.Optional[int] = None,
final=False,
data=None,
):
if final:
existing = self._finals.get(hook)
if existing:
raise HookError("Final hook already present")
self._finals[hook] = _Hook(func, final, data=data)
return
hks = self._ns.get(hook, [])

p = 0
if not priority and len(hks):
priority_max = max(h.priority for h in hks)
p = priority_max - 1

hks.append(_Hook(func, priority=p, data=data))
self._ns[hook] = sorted(hks, reverse=True, key=lambda h: h.priority)

def get_hooks(self, hook: str) -> _t.List[_Hook]:
final = self._finals.get(hook, None)
if final:
final = [final]
else:
final = []
return self._ns.get(hook, []) + final

def layout(self, priority: _t.Optional[int] = None, final: bool = False):
"""
Run a function when serving the layout, the return value
will be used as the layout.
"""

def _wrap(func: _t.Callable[[LayoutType], LayoutType]):
self.add_hook("layout", func, priority=priority, final=final)
return func

return _wrap

def setup(self, priority: _t.Optional[int] = None, final: bool = False):
"""
Can be used to get a reference to the app after it is instantiated.
"""

def _setup(func: _t.Callable[[Dash], None]):
self.add_hook("setup", func, priority=priority, final=final)
return func

return _setup

def route(
self,
name: _t.Optional[str] = None,
methods: _t.Sequence[str] = ("GET",),
priority: _t.Optional[int] = None,
final=False,
):
"""
Add a route to the Dash server.
"""

def wrap(func: _t.Callable[[], _f.Response]):
_name = name or func.__name__
self.add_hook(
"routes",
func,
priority=priority,
final=final,
data=dict(name=_name, methods=methods),
)
return func

return wrap

def error(self, priority: _t.Optional[int] = None, final=False):
"""Automatically add an error handler to the dash app."""

def _error(func: _t.Callable[[Exception], _t.Any]):
self.add_hook("error", func, priority=priority, final=final)
return func

return _error

def callback(self, *args, priority: _t.Optional[int] = None, final=False, **kwargs):
"""
Add a callback to all the apps with the hook installed.
"""

def wrap(func):
self.add_hook(
"callback",
func,
priority=priority,
final=final,
data=(list(args), dict(kwargs)),
)
return func

return wrap

def clientside_callback(
self, clientside_function: ClientsideFuncType, *args, **kwargs
):
"""
Add a callback to all the apps with the hook installed.
"""
self._clientside_callbacks.append((clientside_function, args, kwargs))

def script(self, distribution: _t.List[ResourceType]):
"""Add js scripts to the page."""
self._js_dist.extend(distribution)

def stylesheet(self, distribution: _t.List[ResourceType]):
"""Add stylesheets to the page."""
self._css_dist.extend(distribution)

def index(self, priority: _t.Optional[int] = None, final=False):
"""Modify the index of the apps."""

def wrap(func):
self.add_hook(
"index",
func,
priority=priority,
final=final,
)
return func

return wrap


hooks = _Hooks()


class HooksManager:
# Flag to only run `register_setuptools` once
_registered = False
hooks = hooks

# pylint: disable=too-few-public-methods
class HookErrorHandler:
def __init__(self, original):
self.original = original

def __call__(self, err: Exception):
result = None
if self.original:
result = self.original(err)
hook_result = None
for hook in HooksManager.get_hooks("error"):
hook_result = hook(err)
return result or hook_result

@classmethod
def get_hooks(cls, hook: str):
return cls.hooks.get_hooks(hook)

@classmethod
def register_setuptools(cls):
if cls._registered:
# Only have to register once.
return

for dist in _importlib_metadata.distributions():
for entry in dist.entry_points:
# Look for setup.py entry points named `dash-hooks`
if entry.group != "dash-hooks":
continue
entry.load()
52 changes: 51 additions & 1 deletion dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ def __init__( # pylint: disable=too-many-statements
for plugin in plugins:
plugin.plug(self)

self._setup_hooks()

# tracks internally if a function already handled at least one request.
self._got_first_request = {"pages": False, "setup_server": False}

Expand All @@ -582,6 +584,38 @@ def __init__( # pylint: disable=too-many-statements
)
self.setup_startup_routes()

def _setup_hooks(self):
# pylint: disable=import-outside-toplevel,protected-access
from ._hooks import HooksManager

self._hooks = HooksManager
self._hooks.register_setuptools()

for setup in self._hooks.get_hooks("setup"):
setup(self)

for hook in self._hooks.get_hooks("callback"):
callback_args, callback_kwargs = hook.data
self.callback(*callback_args, **callback_kwargs)(hook.func)

for (
clientside_function,
args,
kwargs,
) in self._hooks.hooks._clientside_callbacks:
_callback.register_clientside_callback(
self._callback_list,
self.callback_map,
self.config.prevent_initial_callbacks,
self._inline_scripts,
clientside_function,
*args,
**kwargs,
)

if self._hooks.get_hooks("error"):
self._on_error = self._hooks.HookErrorHandler(self._on_error)

def init_app(self, app=None, **kwargs):
"""Initialize the parts of Dash that require a flask app."""

Expand Down Expand Up @@ -682,6 +716,9 @@ def _setup_routes(self):
"_alive_" + jupyter_dash.alive_token, jupyter_dash.serve_alive
)

for hook in self._hooks.get_hooks("routes"):
self._add_url(hook.data["name"], hook.func, hook.data["methods"])

# catch-all for front-end routes, used by dcc.Location
self._add_url("<path:path>", self.index)

Expand Down Expand Up @@ -748,6 +785,9 @@ def index_string(self, value):
def serve_layout(self):
layout = self._layout_value()

for hook in self._hooks.get_hooks("layout"):
layout = hook(layout)

# TODO - Set browser cache limit - pass hash into frontend
return flask.Response(
to_json(layout),
Expand Down Expand Up @@ -890,9 +930,13 @@ def _relative_url_path(relative_package_path="", namespace=""):

return srcs

# pylint: disable=protected-access
def _generate_css_dist_html(self):
external_links = self.config.external_stylesheets
links = self._collect_and_register_resources(self.css.get_all_css())
links = self._collect_and_register_resources(
self.css.get_all_css()
+ self.css._resources._filter_resources(self._hooks.hooks._css_dist)
)

return "\n".join(
[
Expand Down Expand Up @@ -941,6 +985,9 @@ def _generate_scripts_html(self):
+ self.scripts._resources._filter_resources(
dash_table._js_dist, dev_bundles=dev
)
+ self.scripts._resources._filter_resources(
self._hooks.hooks._js_dist, dev_bundles=dev
)
)
)

Expand Down Expand Up @@ -1064,6 +1111,9 @@ def index(self, *args, **kwargs): # pylint: disable=unused-argument
renderer=renderer,
)

for hook in self._hooks.get_hooks("index"):
index = hook(index)

checks = (
_re_index_entry_id,
_re_index_config_id,
Expand Down
4 changes: 4 additions & 0 deletions dash/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,7 @@ class PageError(DashException):

class ImportedInsideCallbackError(DashException):
pass


class HookError(DashException):
pass
Loading

0 comments on commit 5fe20c7

Please sign in to comment.