diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 4e1f4982a86..284acbca1d2 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -13,7 +13,7 @@ from functools import partial from logging import Logger -from typing import Any, TypeVar +from typing import Any, Optional, TypeVar import ax.service.utils.early_stopping as early_stopping_utils import numpy as np @@ -179,7 +179,7 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase): def __init__( self, generation_strategy: GenerationStrategy | None = None, - db_settings: DBSettings | None = None, + db_settings: Optional[DBSettings] = None, enforce_sequential_optimization: bool = True, random_seed: int | None = None, torch_device: torch.device | None = None, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index c61548dbcc1..3ec81ebd82b 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -14,7 +14,7 @@ from datetime import datetime from logging import LoggerAdapter from time import sleep -from typing import Any, cast, NamedTuple +from typing import Any, cast, NamedTuple, Optional import ax.service.utils.early_stopping as early_stopping_utils from ax.analysis.analysis import Analysis, AnalysisCard @@ -180,7 +180,7 @@ def __init__( experiment: Experiment, generation_strategy: GenerationStrategyInterface, options: SchedulerOptions, - db_settings: DBSettings | None = None, + db_settings: Optional[DBSettings] = None, _skip_experiment_save: bool = False, ) -> None: self.experiment = experiment @@ -242,7 +242,7 @@ def from_stored_experiment( cls, experiment_name: str, options: SchedulerOptions, - db_settings: DBSettings | None = None, + db_settings: Optional[DBSettings] = None, generation_strategy: GenerationStrategy | None = None, reduced_state: bool = True, **kwargs: Any, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 52d3202840f..77af969c44f 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -14,7 +14,7 @@ from math import ceil from random import randint from tempfile import NamedTemporaryFile -from typing import Any, cast +from typing import Any, cast, Optional from unittest.mock import call, Mock, patch, PropertyMock import pandas as pd @@ -404,7 +404,7 @@ def db_settings(self) -> DBSettings: return DBSettings(encoder=encoder, decoder=decoder) @property - def db_settings_if_always_needed(self) -> DBSettings | None: + def db_settings_if_always_needed(self) -> Optional[DBSettings]: if self.ALWAYS_USE_DB: return self.db_settings return None diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index 2163cd4b1a3..8d15f5998c8 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -11,7 +11,7 @@ from collections.abc import Iterable from logging import INFO, Logger -from typing import Any +from typing import Any, Optional from ax.analysis.analysis import AnalysisCard @@ -93,7 +93,7 @@ class WithDBSettingsBase: if `db_settings` property is set to a non-None value on the instance. """ - _db_settings: DBSettings | None = None + _db_settings: Optional[DBSettings] = None # Mapping of object types to mapping of fields to override values # loaded objects will all be instantiated with fields set to @@ -103,7 +103,7 @@ class WithDBSettingsBase: def __init__( self, - db_settings: DBSettings | None = None, + db_settings: Optional[DBSettings] = None, logging_level: int = INFO, suppress_all_errors: bool = False, ) -> None: @@ -123,7 +123,7 @@ def __init__( logger.setLevel(logging_level) @staticmethod - def _get_default_db_settings() -> DBSettings | None: + def _get_default_db_settings() -> Optional[DBSettings]: """Overridable method to get default db_settings if none are passed in __init__ """