diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 83271d0b278a..4ac3085b5d69 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -128,7 +128,7 @@ from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray.widgets import Template -from ray.widgets.util import ensure_notebook_deps +from ray.widgets.util import ensure_notebook_deps, fallback_if_colab if sys.version_info >= (3, 8): from typing import Literal @@ -4321,6 +4321,7 @@ def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U: @ensure_notebook_deps( ["ipywidgets", "8"], ) + @fallback_if_colab def _ipython_display_(self): from ipywidgets import HTML, VBox, Layout from IPython.display import display diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index 4169f52aacf1..4f0b95f17ae5 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -21,7 +21,7 @@ from ray.train.trainer import BaseTrainer, GenDataset from ray.util.annotations import DeveloperAPI from ray.widgets import Template -from ray.widgets.util import ensure_notebook_deps +from ray.widgets.util import ensure_notebook_deps, fallback_if_colab if TYPE_CHECKING: from ray.data.preprocessor import Preprocessor @@ -447,6 +447,7 @@ def get_dataset_config(self) -> Dict[str, DatasetConfig]: ["tabulate", None], ["ipywidgets", "8"], ) + @fallback_if_colab def _ipython_display_(self): from ipywidgets import HTML, VBox, Tab, Layout from IPython.display import display diff --git a/python/ray/widgets/util.py b/python/ray/widgets/util.py index 212fb1c39c96..0c6e406a189e 100644 --- a/python/ray/widgets/util.py +++ b/python/ray/widgets/util.py @@ -167,3 +167,22 @@ def _has_outdated( logger.warning(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3) return outdated + + +@DeveloperAPI +def fallback_if_colab(func: F) -> Callable[[F], F]: + try: + ipython = get_ipython() + except NameError: + ipython = None + + @wraps(func) + def wrapped(self, *args, **kwargs): + if ipython and "google.colab" not in str(ipython): + return func(self, *args, **kwargs) + elif hasattr(self, "__repr__"): + return print(self.__repr__(*args, **kwargs)) + else: + return None + + return wrapped