From 08674603731f35e7bf6eee96014c52942974e7db Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Wed, 2 Oct 2024 12:18:52 -0700
Subject: [PATCH] Apply PEP 604 union type syntax codemod (#2808)

Summary: This codemods all `Optional[X]` type definitions to use the PEP 604 syntax `X | None` instead.

Pull Request resolved: https://github.com/facebook/Ax/pull/2808

Reviewed By: mpolson64

Differential Revision: D63764095

Pulled By: Balandat
---
 ax/analysis/analysis.py | 11 +-
 .../healthcheck/healthcheck_analysis.py | 5 +-
 ax/analysis/markdown/markdown_analysis.py | 5 +-
 ax/analysis/old/analysis_report.py | 11 +-
 ax/analysis/old/base_analysis.py | 5 +-
 ax/analysis/old/base_plotly_visualization.py | 5 +-
 ax/analysis/old/cross_validation_plot.py | 6 +-
 .../old/helpers/cross_validation_helpers.py | 4 +-
 ax/analysis/old/helpers/plot_helpers.py | 2 +-
 ax/analysis/old/helpers/scatter_helpers.py | 8 +-
 .../old/predicted_outcomes_dot_plot.py | 4 +-
 ax/analysis/plotly/parallel_coordinates.py | 8 +-
 ax/analysis/plotly/plotly_analysis.py | 5 +-
 ax/analysis/plotly/predicted_effects.py | 12 +-
 ax/benchmark/benchmark_metric.py | 4 +-
 ax/benchmark/benchmark_problem.py | 6 +-
 ax/benchmark/benchmark_result.py | 5 +-
 ax/benchmark/methods/modular_botorch.py | 10 +-
 ax/benchmark/methods/sobol.py | 3 +-
 ax/benchmark/problems/hpo/torchvision.py | 2 +-
 ax/benchmark/problems/registry.py | 3 +-
 .../synthetic/discretized/mixed_integer.py | 12 +-
 .../problems/synthetic/hss/jenatton.py | 21 +-
 ax/benchmark/runners/base.py | 4 +-
 ax/benchmark/runners/botorch_test.py | 26 +--
 ax/benchmark/runners/surrogate.py | 20 +-
 ax/benchmark/tests/test_benchmark_problem.py | 5 +-
 ax/core/__init__.py | 1 +
 ax/core/arm.py | 5 +-
 ax/core/auxiliary.py | 4 +-
 ax/core/base_trial.py | 82 ++++---
 ax/core/batch_trial.py | 48 ++---
 ax/core/data.py | 56 ++---
 ax/core/experiment.py | 116 +++++-----
 ax/core/formatting_utils.py | 6 +-
 ax/core/generation_strategy_interface.py | 8 +-
 ax/core/generator_run.py | 62 +++---
 ax/core/map_data.py | 40 ++--
 ax/core/metric.py | 18 +-
 ax/core/multi_type_experiment.py | 39 ++--
 ax/core/objective.py | 8 +-
 ax/core/observation.py | 43 ++--
 ax/core/optimization_config.py | 69 +++---
 ax/core/outcome_constraint.py | 17 +-
 ax/core/parameter.py | 26 +--
 ax/core/parameter_constraint.py | 10 +-
 ax/core/parameter_distribution.py | 6 +-
 ax/core/risk_measures.py | 3 +-
 ax/core/runner.py | 4 +-
 ax/core/search_space.py | 29 ++-
 .../test_generation_strategy_interface.py | 5 +-
 ax/core/trial.py | 22 +-
 ax/core/types.py | 4 +-
 ax/core/utils.py | 16 +-
 ax/early_stopping/strategies/base.py | 47 ++--
 ax/early_stopping/strategies/logical.py | 6 +-
 ax/early_stopping/strategies/percentile.py | 15 +-
 ax/early_stopping/strategies/threshold.py | 15 +-
 ax/early_stopping/tests/test_strategies.py | 10 +-
 ax/early_stopping/utils.py | 3 +-
 ax/exceptions/generation_strategy.py | 7 +-
 ax/global_stopping/strategies/improvement.py | 7 +-
 ax/health_check/search_space.py | 2 +-
 ax/metrics/branin_map.py | 8 +-
 ax/metrics/factorial.py | 6 +-
 ax/metrics/noisy_function.py | 12 +-
 ax/metrics/noisy_function_map.py | 4 +-
 ax/metrics/tensorboard.py | 4 +-
 ax/modelbridge/base.py | 102 ++++-----
 ax/modelbridge/best_model_selector.py | 5 +-
 ax/modelbridge/cross_validation.py | 11 +-
 ax/modelbridge/discrete.py | 21 +-
 ax/modelbridge/dispatch_utils.py | 76 ++++---
 ax/modelbridge/external_generation_node.py | 12 +-
 ax/modelbridge/factory.py | 21 +-
 ax/modelbridge/generation_node.py | 60 +++---
 .../generation_node_input_constructors.py | 18 +-
 ax/modelbridge/generation_strategy.py | 50 ++---
 ax/modelbridge/map_torch.py | 38 ++--
 ax/modelbridge/model_spec.py | 35 +--
 ax/modelbridge/modelbridge_utils.py | 112 +++++-----
 ax/modelbridge/pairwise.py | 6 +-
 ax/modelbridge/prediction_utils.py | 4 +-
 ax/modelbridge/random.py | 15 +-
 ax/modelbridge/registry.py | 20 +-
 .../tests/test_external_generation_node.py | 3 +-
 ...test_generation_node_input_constructors.py | 8 +-
 .../tests/test_modelbridge_utils.py | 3 +-
 .../tests/test_robust_modelbridge.py | 7 +-
 .../tests/test_torch_modelbridge.py | 10 +-
 .../tests/test_torch_moo_modelbridge.py | 3 +-
 ax/modelbridge/torch.py | 93 ++++----
 ax/modelbridge/transforms/base.py | 18 +-
 ax/modelbridge/transforms/cap_parameter.py | 6 +-
 ax/modelbridge/transforms/cast.py | 6 +-
 ax/modelbridge/transforms/choice_encode.py | 8 +-
 .../transforms/convert_metric_names.py | 6 +-
 ax/modelbridge/transforms/derelativize.py | 4 +-
 .../transforms/int_range_to_choice.py | 6 +-
 ax/modelbridge/transforms/int_to_float.py | 6 +-
 .../transforms/inverse_gaussian_cdf_y.py | 6 +-
 ax/modelbridge/transforms/log.py | 6 +-
 ax/modelbridge/transforms/log_y.py | 18 +-
 ax/modelbridge/transforms/logit.py | 6 +-
 ax/modelbridge/transforms/map_unit_x.py | 10 +-
 .../transforms/merge_repeated_measurements.py | 9 +-
 ax/modelbridge/transforms/metrics_as_task.py | 6 +-
 ax/modelbridge/transforms/one_hot.py | 9 +-
 ax/modelbridge/transforms/percentile_y.py | 6 +-
 .../transforms/power_transform_y.py | 16 +-
 ax/modelbridge/transforms/relativize.py | 19 +-
 ax/modelbridge/transforms/remove_fixed.py | 12 +-
 .../transforms/search_space_to_choice.py | 6 +-
 ax/modelbridge/transforms/standardize_y.py | 18 +-
 .../transforms/stratified_standardize_y.py | 14 +-
 ax/modelbridge/transforms/task_encode.py | 6 +-
 .../tests/test_relativize_transform.py | 5 +-
 .../tests/test_transform_to_new_sq.py | 3 +-
 .../tests/test_winsorize_transform.py | 5 +-
 ax/modelbridge/transforms/time_as_feature.py | 10 +-
 .../transforms/transform_to_new_sq.py | 20 +-
 ax/modelbridge/transforms/trial_as_task.py | 16 +-
 ax/modelbridge/transforms/unit_x.py | 6 +-
 ax/modelbridge/transforms/utils.py | 11 +-
 ax/modelbridge/transforms/winsorize.py | 16 +-
 ax/modelbridge/transition_criterion.py | 173 ++++++++-------
 ax/models/discrete/full_factorial.py | 11 +-
 ax/models/discrete/thompson.py | 17 +-
 ax/models/discrete_base.py | 23 +-
 ax/models/model_utils.py | 72 +++----
 ax/models/random/base.py | 29 +--
 ax/models/random/sobol.py | 18 +-
 ax/models/random/uniform.py | 5 +-
 ax/models/torch/botorch.py | 25 +--
 ax/models/torch/botorch_defaults.py | 97 ++++-----
 ax/models/torch/botorch_kg.py | 39 ++--
 .../torch/botorch_modular/acquisition.py | 40 ++--
 .../input_constructors/covar_modules.py | 18 +-
 .../input_constructors/input_transforms.py | 24 +--
 .../input_constructors/outcome_transform.py | 8 +-
 ax/models/torch/botorch_modular/kernels.py | 36 ++--
 ax/models/torch/botorch_modular/model.py | 48 ++---
 .../botorch_modular/optimizer_argparse.py | 6 +-
 ax/models/torch/botorch_modular/sebo.py | 19 +-
 ax/models/torch/botorch_modular/surrogate.py | 52 ++---
 ax/models/torch/botorch_modular/utils.py | 24 +--
 ax/models/torch/botorch_moo.py | 11 +-
 ax/models/torch/botorch_moo_defaults.py | 118 +++++------
 ax/models/torch/cbo_lcea.py | 30 +--
 ax/models/torch/cbo_lcem.py | 12 +-
 ax/models/torch/cbo_sac.py | 8 +-
 ax/models/torch/randomforest.py | 10 +-
 ax/models/torch/tests/test_acquisition.py | 4 +-
 ax/models/torch/tests/test_sebo.py | 6 +-
 ax/models/torch/utils.py | 77 +++---
 ax/models/torch_base.py | 34 +--
 ax/models/winsorization_config.py | 5 +-
 ax/plot/bandit_rollout.py | 2 +-
 ax/plot/base.py | 12 +-
 ax/plot/contour.py | 34 +--
 ax/plot/diagnostic.py | 26 +--
 ax/plot/feature_importances.py | 16 +-
 ax/plot/helper.py | 29 +--
 ax/plot/marginal_effects.py | 2 +-
 ax/plot/parallel_coordinates.py | 7 +-
 ax/plot/pareto_frontier.py | 55 +++--
 ax/plot/pareto_utils.py | 30 +--
 ax/plot/render.py | 2 +-
 ax/plot/scatter.py | 132 ++++++------
 ax/plot/slice.py | 30 +--
 ax/plot/table_view.py | 5 +-
 ax/plot/tests/test_tile_fitted.py | 3 +-
 ax/plot/trace.py | 44 ++--
 ax/runners/simulated_backend.py | 8 +-
 ax/runners/synthetic.py | 4 +-
 ax/runners/torchx.py | 14 +-
 ax/service/ax_client.py | 164 +++++++-------
 ax/service/interactive_loop.py | 9 +-
 ax/service/managed_loop.py | 55 +++--
 ax/service/scheduler.py | 72 +++----
 ax/service/tests/scheduler_test_utils.py | 18 +-
 ax/service/tests/test_ax_client.py | 14 +-
 ax/service/tests/test_early_stopping.py | 3 +-
 ax/service/tests/test_interactive_loop.py | 5 +-
 ax/service/tests/test_managed_loop.py | 7 +-
 ax/service/utils/best_point.py | 59 +++---
 ax/service/utils/best_point_mixin.py | 53 +++--
 ax/service/utils/early_stopping.py | 7 +-
 ax/service/utils/instantiation.py | 70 +++---
 ax/service/utils/report_utils.py | 88 ++++----
 ax/service/utils/scheduler_options.py | 20 +-
 ax/service/utils/with_db_settings_base.py | 19 +-
 ax/storage/json_store/decoder.py | 14 +-
 ax/storage/json_store/decoders.py | 60 +++---
 ax/storage/json_store/encoder.py | 3 +-
 ax/storage/json_store/load.py | 5 +-
 ax/storage/json_store/registry.py | 4 +-
 ax/storage/json_store/save.py | 3 +-
 ax/storage/metric_registry.py | 5 +-
 ax/storage/registry_bundle.py | 16 +-
 ax/storage/runner_registry.py | 7 +-
 ax/storage/sqa_store/db.py | 18 +-
 ax/storage/sqa_store/decoder.py | 26 +--
 ax/storage/sqa_store/delete.py | 5 +-
 ax/storage/sqa_store/encoder.py | 24 +--
 ax/storage/sqa_store/json.py | 4 +-
 ax/storage/sqa_store/load.py | 42 ++--
 ax/storage/sqa_store/save.py | 58 +++--
 ax/storage/sqa_store/sqa_classes.py | 200 ++++++++----------
 ax/storage/sqa_store/sqa_config.py | 7 +-
 ax/storage/sqa_store/structs.py | 8 +-
 ax/storage/sqa_store/timestamp.py | 9 +-
 ax/storage/sqa_store/utils.py | 4 +-
 ax/storage/sqa_store/validation.py | 8 +-
 ax/telemetry/ax_client.py | 8 +-
 ax/telemetry/common.py | 4 +-
 ax/telemetry/experiment.py | 5 +-
 ax/telemetry/generation_strategy.py | 11 +-
 ax/telemetry/optimization.py | 51 +++--
 ax/telemetry/scheduler.py | 18 +-
 ax/telemetry/tests/test_ax_client.py | 3 +-
 ax/utils/common/base.py | 5 +-
 ax/utils/common/decorator.py | 3 +-
 ax/utils/common/deprecation.py | 5 +-
 ax/utils/common/docutils.py | 3 +-
 ax/utils/common/equality.py | 7 +-
 ax/utils/common/executils.py | 30 +--
 ax/utils/common/kwargs.py | 12 +-
 ax/utils/common/logger.py | 8 +-
 ax/utils/common/mock.py | 3 +-
 ax/utils/common/random.py | 3 +-
 ax/utils/common/result.py | 11 +-
 ax/utils/common/serialization.py | 11 +-
 ax/utils/common/testutils.py | 41 ++--
 ax/utils/common/typeutils.py | 8 +-
 ax/utils/common/typeutils_torch.py | 5 +-
 ax/utils/flake8_plugins/docstring_checker.py | 3 +-
 ax/utils/measurement/synthetic_functions.py | 22 +-
 ax/utils/report/render.py | 15 +-
 ax/utils/sensitivity/derivative_measures.py | 29 +--
 ax/utils/sensitivity/sobol_measures.py | 55 ++---
 ax/utils/stats/model_fit_stats.py | 4 +-
 ax/utils/stats/statstools.py | 20 +-
 ax/utils/testing/backend_simulator.py | 33 ++-
 ax/utils/testing/benchmark_stubs.py | 6 +-
 ax/utils/testing/core_stubs.py | 50 +++--
 .../testing/metrics/branin_backend_map.py | 5 +-
 ax/utils/testing/mock.py | 8 +-
 ax/utils/testing/modeling_stubs.py | 12 +-
 ax/utils/testing/preference_stubs.py | 7 +-
 ax/utils/testing/torch_stubs.py | 4 +-
 ax/utils/tutorials/cnn_utils.py | 5 +-
 scripts/make_tutorials.py | 43 ++--
 scripts/parse_sphinx.py | 4 +-
 scripts/patch_site_config.py | 4 +-
 scripts/update_versions_html.py | 6 +-
 scripts/validate_sphinx.py | 5 +-
 setup.py | 2 +-
 sphinx/source/conf.py | 13 +-
 259 files changed, 2686 insertions(+), 2828 deletions(-)

diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index a5515d23ce4..ab9e94ab9d7 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -5,9 +5,10 @@ # pyre-strict +from collections.abc import Iterable from enum import Enum from logging import Logger -from typing import Iterable, Optional, Protocol +from typing import Protocol import pandas as pd from ax.core.experiment import Experiment @@ -105,8 +106,8 @@ class Analysis(Protocol): def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> AnalysisCard: # Note: when implementing compute always prefer experiment.lookup_data() to # experiment.fetch_data() to avoid unintential data fetching within the report @@ -115,8 +116,8 @@ def compute_result( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> Result[AnalysisCard, ExceptionE]: """ Utility method to compute an AnalysisCard as a Result. This can be useful for
diff --git a/ax/analysis/healthcheck/healthcheck_analysis.py b/ax/analysis/healthcheck/healthcheck_analysis.py index c7fd74d025e..58b036f1372 100644 --- a/ax/analysis/healthcheck/healthcheck_analysis.py +++ b/ax/analysis/healthcheck/healthcheck_analysis.py @@ -6,7 +6,6 @@ # pyre-strict import json from enum import IntEnum -from typing import Optional from ax.analysis.analysis import AnalysisCard from ax.core.experiment import Experiment @@ -33,6 +32,6 @@ class HealthcheckAnalysis: def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategy] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: ...
diff --git a/ax/analysis/markdown/markdown_analysis.py b/ax/analysis/markdown/markdown_analysis.py index 8dd8302a059..972b8590735 100644 --- a/ax/analysis/markdown/markdown_analysis.py +++ b/ax/analysis/markdown/markdown_analysis.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.analysis.analysis import Analysis, AnalysisCard from ax.core.experiment import Experiment @@ -34,6 +33,6 @@ class MarkdownAnalysis(Analysis): def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> MarkdownAnalysisCard: ...
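
Note: the change is purely syntactic. On Python 3.10+, `X | None` evaluates to the same union object as `typing.Optional[X]`, so hunks like the one in `Analysis.compute` above alter spelling, not behavior. A minimal sketch (illustrative, not part of the diff; `fetch` is a made-up name):

    from typing import Optional, Union

    # PEP 604 unions compare equal to their typing-module spellings...
    assert (int | None) == Optional[int]
    assert (int | str | None) == Union[int, str, None]

    # ...so annotations rewritten by the codemod keep their meaning.
    def fetch(timeout: float | None = None) -> str | None:
        return None if timeout is None else f"waited {timeout}s"

On Python 3.7-3.9 the bare `X | None` syntax is only legal inside annotations when `from __future__ import annotations` (PEP 563) is in effect; since the hunks here do not add that import, the patch presupposes a Python >= 3.10 floor.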
diff --git a/ax/analysis/old/analysis_report.py b/ax/analysis/old/analysis_report.py index c3f6303c206..350213a0e44 100644 --- a/ax/analysis/old/analysis_report.py +++ b/ax/analysis/old/analysis_report.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional import pandas as pd import plotly.graph_objects as go @@ -27,15 +26,15 @@ class AnalysisReport: analyses: list[BaseAnalysis] = [] experiment: Experiment - time_started: Optional[int] = None - time_completed: Optional[int] = None + time_started: int | None = None + time_completed: int | None = None def __init__( self, experiment: Experiment, analyses: list[BaseAnalysis], - time_started: Optional[int] = None, - time_completed: Optional[int] = None, + time_started: int | None = None, + time_completed: int | None = None, ) -> None: """ This class is a collection of AnalysisReport. @@ -65,7 +64,7 @@ def run_analysis_report( tuple[ BaseAnalysis, pd.DataFrame, - Optional[go.Figure], + go.Figure | None, ] ]: """ diff --git a/ax/analysis/old/base_analysis.py b/ax/analysis/old/base_analysis.py index 7dd56c9e186..d7a69187879 100644 --- a/ax/analysis/old/base_analysis.py +++ b/ax/analysis/old/base_analysis.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional import pandas as pd from ax.core.experiment import Experiment @@ -21,7 +20,7 @@ class BaseAnalysis: def __init__( self, experiment: Experiment, - df_input: Optional[pd.DataFrame] = None, + df_input: pd.DataFrame | None = None, # TODO: add support for passing in experiment name, and markdown message ) -> None: """ @@ -30,7 +29,7 @@ def __init__( we can pass the dataframe as an input. """ self._experiment = experiment - self._df: Optional[pd.DataFrame] = df_input + self._df: pd.DataFrame | None = df_input @property def experiment(self) -> Experiment: diff --git a/ax/analysis/old/base_plotly_visualization.py b/ax/analysis/old/base_plotly_visualization.py index 578a98e1041..ab65dc0d0e6 100644 --- a/ax/analysis/old/base_plotly_visualization.py +++ b/ax/analysis/old/base_plotly_visualization.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional import pandas as pd @@ -25,8 +24,8 @@ class BasePlotlyVisualization(BaseAnalysis): def __init__( self, experiment: Experiment, - df_input: Optional[pd.DataFrame] = None, - fig_input: Optional[go.Figure] = None, + df_input: pd.DataFrame | None = None, + fig_input: go.Figure | None = None, ) -> None: """ Initialize the analysis with the experiment object. 
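
A related mechanical change accompanies the union rewrites in several files (e.g. the `from collections.abc import Iterable` hunk in ax/analysis/analysis.py above): abstract container and callable types are imported from `collections.abc` rather than `typing`, whose aliases were deprecated by PEP 585. A minimal sketch (illustrative, not part of the diff; `total` and `apply_twice` are made-up names):

    from collections.abc import Callable, Iterable

    def total(xs: Iterable[float]) -> float:
        return sum(xs)

    def apply_twice(f: Callable[[int], int], x: int) -> int:
        return f(f(x))

    print(total([1.0, 2.5]), apply_twice(lambda v: v + 1, 40))  # 3.5 42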
diff --git a/ax/analysis/old/cross_validation_plot.py b/ax/analysis/old/cross_validation_plot.py index b0636c19c95..3f3ac5646a3 100644 --- a/ax/analysis/old/cross_validation_plot.py +++ b/ax/analysis/old/cross_validation_plot.py @@ -6,7 +6,7 @@ # pyre-strict from copy import deepcopy -from typing import Any, Optional +from typing import Any import pandas as pd @@ -43,7 +43,7 @@ def __init__( self, experiment: Experiment, model: ModelBridge, - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, caption: str = CROSS_VALIDATION_CAPTION, ) -> None: """ @@ -56,7 +56,7 @@ def __init__( self.model = model self.cv: list[CVResult] = cross_validate(model=model) - self.label_dict: Optional[dict[str, str]] = label_dict + self.label_dict: dict[str, str] | None = label_dict if self.label_dict: self.cv = self.remap_label(cv_results=self.cv, label_dict=self.label_dict) diff --git a/ax/analysis/old/helpers/cross_validation_helpers.py b/ax/analysis/old/helpers/cross_validation_helpers.py index 71e3c514fa7..a40bd358586 100644 --- a/ax/analysis/old/helpers/cross_validation_helpers.py +++ b/ax/analysis/old/helpers/cross_validation_helpers.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -186,7 +186,7 @@ def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str, ) -def default_value_se_raw(se_raw: Optional[list[float]], out_length: int) -> list[float]: +def default_value_se_raw(se_raw: list[float] | None, out_length: int) -> list[float]: """ Takes a list of standard errors and maps edge cases to default list of floats. diff --git a/ax/analysis/old/helpers/plot_helpers.py b/ax/analysis/old/helpers/plot_helpers.py index e464ec31e02..5ed1a7dfeee 100644 --- a/ax/analysis/old/helpers/plot_helpers.py +++ b/ax/analysis/old/helpers/plot_helpers.py @@ -43,7 +43,7 @@ def _format_dict(param_dict: TParameterization, name: str = "Parameterization") ) else: blob = "
\n{}:\n{}".format( - name, "\n".join("{}: {}".format(n, v) for n, v in param_dict.items()) + name, "\n
".join(f"{n}: {v}" for n, v in param_dict.items()) ) return blob diff --git a/ax/analysis/old/helpers/scatter_helpers.py b/ax/analysis/old/helpers/scatter_helpers.py index 3bf0e8b7fa0..13fd9144ac8 100644 --- a/ax/analysis/old/helpers/scatter_helpers.py +++ b/ax/analysis/old/helpers/scatter_helpers.py @@ -8,7 +8,7 @@ import numbers -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -119,7 +119,7 @@ def extract_mean_and_error_from_df( def make_label( arm_name: str, - x_axis_values: Optional[tuple[str, float, float]], + x_axis_values: tuple[str, float, float] | None, y_axis_values: tuple[str, float, float], param_blob: TParameterization, rel: bool, @@ -240,8 +240,8 @@ def error_scatter_trace_from_df( df: pd.DataFrame, show_CI: bool = True, visible: bool = True, - y_axis_label: Optional[str] = None, - x_axis_label: Optional[str] = None, + y_axis_label: str | None = None, + x_axis_label: str | None = None, ) -> dict[str, Any]: """Plot scatterplot with error bars. diff --git a/ax/analysis/old/predicted_outcomes_dot_plot.py b/ax/analysis/old/predicted_outcomes_dot_plot.py index 3c2b01e4936..6022134616d 100644 --- a/ax/analysis/old/predicted_outcomes_dot_plot.py +++ b/ax/analysis/old/predicted_outcomes_dot_plot.py @@ -101,12 +101,12 @@ def get_fig( reverse=True, ) - name_order_axes["xaxis{}".format(i + 1)] = { + name_order_axes[f"xaxis{i + 1}"] = { "categoryorder": "array", "categoryarray": names_by_arm, "type": "category", } - name_order_axes["yaxis{}".format(i + 1)] = { + name_order_axes[f"yaxis{i + 1}"] = { "ticksuffix": "%", "zerolinecolor": "red", } diff --git a/ax/analysis/plotly/parallel_coordinates.py b/ax/analysis/plotly/parallel_coordinates.py index afd4bea51fa..fda71e3ce36 100644 --- a/ax/analysis/plotly/parallel_coordinates.py +++ b/ax/analysis/plotly/parallel_coordinates.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -32,7 +32,7 @@ class ParallelCoordinatesPlot(PlotlyAnalysis): - **PARAMETER_NAME: The value of said parameter for the arm, for each parameter """ - def __init__(self, metric_name: Optional[str] = None) -> None: + def __init__(self, metric_name: str | None = None) -> None: """ Args: metric_name: The name of the metric to plot. If not specified the objective @@ -44,8 +44,8 @@ def __init__(self, metric_name: Optional[str] = None) -> None: def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ParallelCoordinatesPlot requires an Experiment") diff --git a/ax/analysis/plotly/plotly_analysis.py b/ax/analysis/plotly/plotly_analysis.py index 1248ed9dda6..cc3bf86a54e 100644 --- a/ax/analysis/plotly/plotly_analysis.py +++ b/ax/analysis/plotly/plotly_analysis.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.analysis.analysis import Analysis, AnalysisCard from ax.core.experiment import Experiment @@ -35,6 +34,6 @@ class PlotlyAnalysis(Analysis): def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> PlotlyAnalysisCard: ... 
diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py index 1ae82d2cea8..7a2f8a86ebd 100644 --- a/ax/analysis/plotly/predicted_effects.py +++ b/ax/analysis/plotly/predicted_effects.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from itertools import chain -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -45,8 +45,8 @@ def __init__(self, metric_name: str) -> None: def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("PredictedEffectsPlot requires an Experiment.") @@ -223,8 +223,8 @@ def _get_predictions( model: ModelBridge, metric_name: str, outcome_constraints: list[OutcomeConstraint], - gr: Optional[GeneratorRun] = None, - trial_index: Optional[int] = None, + gr: GeneratorRun | None = None, + trial_index: int | None = None, ) -> list[dict[str, Any]]: if gr is None: observations = model.get_training_data() @@ -294,7 +294,7 @@ def _get_predictions( ] -def _get_max_observed_trial_index(model: ModelBridge) -> Optional[int]: +def _get_max_observed_trial_index(model: ModelBridge) -> int | None: """Returns the max observed trial index to appease multitask models for prediction by giving fixed features. This is not necessarily accurate and should eventually come from the generation strategy. diff --git a/ax/benchmark/benchmark_metric.py b/ax/benchmark/benchmark_metric.py index 139f1e2e64e..ca2ac6bd153 100644 --- a/ax/benchmark/benchmark_metric.py +++ b/ax/benchmark/benchmark_metric.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any import pandas as pd from ax.core.base_trial import BaseTrial @@ -26,7 +26,7 @@ def __init__( name: str, lower_is_better: bool, # TODO: Do we need to define this here? 
observe_noise_sd: bool = True, - outcome_index: Optional[int] = None, + outcome_index: int | None = None, ) -> None: """ Args: diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py index a39d90c8406..b3cd547115a 100644 --- a/ax/benchmark/benchmark_problem.py +++ b/ax/benchmark/benchmark_problem.py @@ -7,7 +7,7 @@ from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any import pandas as pd @@ -38,7 +38,7 @@ def _get_name( test_problem: BaseTestProblem, observe_noise_sd: bool, - dim: Optional[int] = None, + dim: int | None = None, ) -> str: """ Get a string name describing the problem, in a format such as @@ -86,7 +86,7 @@ class BenchmarkProblem(Base): name: str optimization_config: OptimizationConfig num_trials: int - observe_noise_stds: Union[bool, dict[str, bool]] = False + observe_noise_stds: bool | dict[str, bool] = False optimal_value: float search_space: SearchSpace = field(repr=False) diff --git a/ax/benchmark/benchmark_result.py b/ax/benchmark/benchmark_result.py index 52bcf06f994..e1c5b2cda97 100644 --- a/ax/benchmark/benchmark_result.py +++ b/ax/benchmark/benchmark_result.py @@ -12,7 +12,6 @@ from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional import numpy as np from ax.core.experiment import Experiment @@ -87,8 +86,8 @@ class BenchmarkResult(Base): fit_time: float gen_time: float - experiment: Optional[Experiment] = None - experiment_storage_id: Optional[str] = None + experiment: Experiment | None = None + experiment_storage_id: str | None = None def __post_init__(self) -> None: if self.experiment is not None and self.experiment_storage_id is not None: diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index ea0f0e944fd..87ad638abb8 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, Optional, Union +from typing import Any from ax.benchmark.benchmark_method import ( BenchmarkMethod, @@ -44,10 +44,10 @@ def get_sobol_botorch_modular_acquisition( model_cls: type[Model], acquisition_cls: type[AcquisitionFunction], distribute_replications: bool, - scheduler_options: Optional[SchedulerOptions] = None, - name: Optional[str] = None, + scheduler_options: SchedulerOptions | None = None, + name: str | None = None, num_sobol_trials: int = 5, - model_gen_kwargs: Optional[dict[str, Any]] = None, + model_gen_kwargs: dict[str, Any] | None = None, use_model_predictions_for_best_point: bool = False, ) -> BenchmarkMethod: """Get a `BenchmarkMethod` that uses Sobol followed by MBM. @@ -96,7 +96,7 @@ def get_sobol_botorch_modular_acquisition( ... 
) """ model_kwargs: dict[ - str, Union[type[AcquisitionFunction], dict[str, SurrogateSpec], bool] + str, type[AcquisitionFunction] | dict[str, SurrogateSpec] | bool ] = { "botorch_acqf_class": acquisition_cls, "surrogate_specs": {"BoTorch": SurrogateSpec(botorch_model_class=model_cls)}, diff --git a/ax/benchmark/methods/sobol.py b/ax/benchmark/methods/sobol.py index 6185ab1cc7c..78e007232ff 100644 --- a/ax/benchmark/methods/sobol.py +++ b/ax/benchmark/methods/sobol.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.benchmark.benchmark_method import ( BenchmarkMethod, @@ -18,7 +17,7 @@ def get_sobol_benchmark_method( distribute_replications: bool, - scheduler_options: Optional[SchedulerOptions] = None, + scheduler_options: SchedulerOptions | None = None, ) -> BenchmarkMethod: generation_strategy = GenerationStrategy( name="Sobol", diff --git a/ax/benchmark/problems/hpo/torchvision.py b/ax/benchmark/problems/hpo/torchvision.py index d17507071ad..360b31901f4 100644 --- a/ax/benchmark/problems/hpo/torchvision.py +++ b/ax/benchmark/problems/hpo/torchvision.py @@ -5,9 +5,9 @@ # pyre-strict +from collections.abc import Mapping from dataclasses import dataclass, field, InitVar from functools import lru_cache -from typing import Mapping import torch from ax.benchmark.benchmark_problem import ( diff --git a/ax/benchmark/problems/registry.py b/ax/benchmark/problems/registry.py index 992dec980e4..6d7f5bd0fa2 100644 --- a/ax/benchmark/problems/registry.py +++ b/ax/benchmark/problems/registry.py @@ -6,8 +6,9 @@ # pyre-strict import copy +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch from ax.benchmark.problems.hd_embedding import embed_higher_dimension diff --git a/ax/benchmark/problems/synthetic/discretized/mixed_integer.py b/ax/benchmark/problems/synthetic/discretized/mixed_integer.py index 451bd91d880..753346488ec 100644 --- a/ax/benchmark/problems/synthetic/discretized/mixed_integer.py +++ b/ax/benchmark/problems/synthetic/discretized/mixed_integer.py @@ -18,8 +18,6 @@ 35, 2022. """ -from typing import Optional, Union - from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem @@ -47,7 +45,7 @@ def _get_problem_from_common_inputs( benchmark_name: str, num_trials: int, optimal_value: float, - test_problem_bounds: Optional[list[tuple[float, float]]] = None, + test_problem_bounds: list[tuple[float, float]] | None = None, ) -> BenchmarkProblem: """This is a helper that deduplicates common bits of the below problems. 
@@ -103,7 +101,7 @@ def _get_problem_from_common_inputs( minimize=lower_is_better, ) ) - test_problem_kwargs: dict[str, Union[int, list[tuple[float, float]]]] = {"dim": dim} + test_problem_kwargs: dict[str, int | list[tuple[float, float]]] = {"dim": dim} if test_problem_bounds is not None: test_problem_kwargs["bounds"] = test_problem_bounds runner = BotorchTestProblemRunner( @@ -126,7 +124,7 @@ def _get_problem_from_common_inputs( def get_discrete_hartmann( num_trials: int = 50, observe_noise_sd: bool = False, - bounds: Optional[list[tuple[float, float]]] = None, + bounds: list[tuple[float, float]] | None = None, ) -> BenchmarkProblem: """6D Hartmann problem where first 4 dimensions are discretized.""" dim_int = 4 @@ -158,7 +156,7 @@ def get_discrete_hartmann( def get_discrete_ackley( num_trials: int = 50, observe_noise_sd: bool = False, - bounds: Optional[list[tuple[float, float]]] = None, + bounds: list[tuple[float, float]] | None = None, ) -> BenchmarkProblem: """13D Ackley problem where first 10 dimensions are discretized. @@ -191,7 +189,7 @@ def get_discrete_ackley( def get_discrete_rosenbrock( num_trials: int = 50, observe_noise_sd: bool = False, - bounds: Optional[list[tuple[float, float]]] = None, + bounds: list[tuple[float, float]] | None = None, ) -> BenchmarkProblem: """10D Rosenbrock problem where first 6 dimensions are discretized.""" dim_int = 6 diff --git a/ax/benchmark/problems/synthetic/hss/jenatton.py b/ax/benchmark/problems/synthetic/hss/jenatton.py index 30272a7babe..e3c9be2bed9 100644 --- a/ax/benchmark/problems/synthetic/hss/jenatton.py +++ b/ax/benchmark/problems/synthetic/hss/jenatton.py @@ -7,7 +7,6 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Optional import torch from ax.benchmark.benchmark_metric import BenchmarkMetric @@ -24,15 +23,15 @@ def jenatton_test_function( - x1: Optional[int] = None, - x2: Optional[int] = None, - x3: Optional[int] = None, - x4: Optional[float] = None, - x5: Optional[float] = None, - x6: Optional[float] = None, - x7: Optional[float] = None, - r8: Optional[float] = None, - r9: Optional[float] = None, + x1: int | None = None, + x2: int | None = None, + x3: int | None = None, + x4: float | None = None, + x5: float | None = None, + x6: float | None = None, + x7: float | None = None, + r8: float | None = None, + r9: float | None = None, ) -> float: """Jenatton test function for hierarchical search spaces. @@ -54,7 +53,7 @@ def jenatton_test_function( class Jenatton(ParamBasedTestProblem): """Jenatton test function for hierarchical search spaces.""" - noise_std: Optional[float] = None + noise_std: float | None = None negate: bool = False num_objectives: int = 1 optimal_value: float = 0.1 diff --git a/ax/benchmark/runners/base.py b/ax/benchmark/runners/base.py index 39eedeac4b5..2b571ab2b07 100644 --- a/ax/benchmark/runners/base.py +++ b/ax/benchmark/runners/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping from math import sqrt -from typing import Any, Union +from typing import Any import torch @@ -81,7 +81,7 @@ def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray: return self.get_Y_true(params=params).numpy() @abstractmethod - def get_noise_stds(self) -> Union[None, float, dict[str, float]]: + def get_noise_stds(self) -> None | float | dict[str, float]: """ Return the standard errors for the synthetic noise to be applied to the observed values. 
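
Signatures like `get_noise_stds() -> None | float | dict[str, float]` above illustrate that PEP 604 unions may have any number of members, with `None` in any position. On Python 3.10+ such a union is a real runtime object that can be passed straight to isinstance(). A minimal sketch (illustrative, not part of the diff; `describe` is a made-up helper):

    def describe(noise_std: None | float | dict[str, float]) -> str:
        if isinstance(noise_std, int | float):  # PEP 604 unions work here
            return f"shared noise std of {noise_std}"
        if isinstance(noise_std, dict):
            return f"per-metric noise stds for {sorted(noise_std)}"
        return "noiseless"

    print(describe(None), describe(0.1), describe({"branin": 0.5}))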
diff --git a/ax/benchmark/runners/botorch_test.py b/ax/benchmark/runners/botorch_test.py index 58ac259b1ad..501eb048c65 100644 --- a/ax/benchmark/runners/botorch_test.py +++ b/ax/benchmark/runners/botorch_test.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import torch from ax.benchmark.runners.base import BenchmarkRunner @@ -36,8 +36,8 @@ class ParamBasedTestProblem(ABC): # Constraints could easily be supported similar to BoTorch test problems, # but haven't been hooked up. _is_constrained: bool = False - constraint_noise_std: Optional[Union[float, list[float]]] = None - noise_std: Optional[Union[float, list[float]]] = None + constraint_noise_std: float | list[float] | None = None + noise_std: float | list[float] | None = None negate: bool = False @abstractmethod @@ -65,18 +65,18 @@ class SyntheticProblemRunner(BenchmarkRunner, ABC): problem such as the noise_std. """ - test_problem: Union[BaseTestProblem, ParamBasedTestProblem] + test_problem: BaseTestProblem | ParamBasedTestProblem _is_constrained: bool - _test_problem_class: type[Union[BaseTestProblem, ParamBasedTestProblem]] - _test_problem_kwargs: Optional[dict[str, Any]] + _test_problem_class: type[BaseTestProblem | ParamBasedTestProblem] + _test_problem_kwargs: dict[str, Any] | None def __init__( self, *, - test_problem_class: type[Union[BaseTestProblem, ParamBasedTestProblem]], + test_problem_class: type[BaseTestProblem | ParamBasedTestProblem], test_problem_kwargs: dict[str, Any], outcome_names: list[str], - modified_bounds: Optional[list[tuple[float, float]]] = None, + modified_bounds: list[tuple[float, float]] | None = None, search_space_digest: SearchSpaceDigest | None = None, ) -> None: """Initialize the test problem runner. @@ -127,7 +127,7 @@ def __eq__(self, other: Base) -> bool: == other.test_problem.__class__.__name__ ) - def get_noise_stds(self) -> Union[None, float, dict[str, float]]: + def get_noise_stds(self) -> None | float | dict[str, float]: noise_std = self.test_problem.noise_std noise_std_dict: dict[str, float] = {} num_obj = 1 if not self._is_moo else self.test_problem.num_objectives @@ -184,8 +184,8 @@ def serialize_init_args(cls, obj: Any) -> dict[str, Any]: def deserialize_init_args( cls, args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, ) -> dict[str, Any]: """Given a dictionary, deserialize the properties needed to initialize the runner. Used for storage. 
@@ -228,7 +228,7 @@ def __init__( test_problem_class: type[BaseTestProblem], test_problem_kwargs: dict[str, Any], outcome_names: list[str], - modified_bounds: Optional[list[tuple[float, float]]] = None, + modified_bounds: list[tuple[float, float]] | None = None, search_space_digest: SearchSpaceDigest | None = None, ) -> None: super().__init__( @@ -303,7 +303,7 @@ def __init__( test_problem_class: type[ParamBasedTestProblem], test_problem_kwargs: dict[str, Any], outcome_names: list[str], - modified_bounds: Optional[list[tuple[float, float]]] = None, + modified_bounds: list[tuple[float, float]] | None = None, search_space_digest: SearchSpaceDigest | None = None, ) -> None: if modified_bounds is not None: diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index fec322303c7..a68d84e5c6a 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -6,8 +6,8 @@ # pyre-strict import warnings -from collections.abc import Mapping -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Mapping +from typing import Any import torch from ax.benchmark.runners.base import BenchmarkRunner @@ -30,12 +30,12 @@ def __init__( name: str, search_space: SearchSpace, outcome_names: list[str], - surrogate: Optional[TorchModelBridge] = None, - datasets: Optional[list[SupervisedDataset]] = None, - noise_stds: Union[float, dict[str, float]] = 0.0, - get_surrogate_and_datasets: Optional[ + surrogate: TorchModelBridge | None = None, + datasets: list[SupervisedDataset] | None = None, + noise_stds: float | dict[str, float] = 0.0, + get_surrogate_and_datasets: None | ( Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]] - ] = None, + ) = None, search_space_digest: SearchSpaceDigest | None = None, ) -> None: """Runner for surrogate benchmark problems. 
@@ -92,7 +92,7 @@ def datasets(self) -> list[SupervisedDataset]: self.set_surrogate_and_datasets() return none_throws(self._datasets) - def get_noise_stds(self) -> Union[None, float, dict[str, float]]: + def get_noise_stds(self) -> None | float | dict[str, float]: return self.noise_stds # pyre-fixme[14]: Inconsistent override @@ -158,8 +158,8 @@ def serialize_init_args(cls, obj: Any) -> dict[str, Any]: def deserialize_init_args( cls, args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, ) -> dict[str, Any]: return {} diff --git a/ax/benchmark/tests/test_benchmark_problem.py b/ax/benchmark/tests/test_benchmark_problem.py index bae7cf43ba8..ac519d90b27 100644 --- a/ax/benchmark/tests/test_benchmark_problem.py +++ b/ax/benchmark/tests/test_benchmark_problem.py @@ -8,7 +8,6 @@ import math from itertools import product from math import pi -from typing import Optional, Union import torch @@ -203,8 +202,8 @@ def test_single_objective_from_botorch(self) -> None: def _test_constrained_from_botorch( self, observe_noise_sd: bool, - objective_noise_std: Optional[float], - constraint_noise_std: Optional[Union[float, list[float]]], + objective_noise_std: float | None, + constraint_noise_std: float | list[float] | None, test_problem_class: type[ConstrainedBaseTestProblem], ) -> None: ax_problem = create_problem_from_botorch( diff --git a/ax/core/__init__.py b/ax/core/__init__.py index 9b6fa1243a9..694742f65e9 100644 --- a/ax/core/__init__.py +++ b/ax/core/__init__.py @@ -57,6 +57,7 @@ "MultiObjectiveOptimizationConfig", "Objective", "ObjectiveThreshold", + "ObservationFeatures", "OptimizationConfig", "OrderConstraint", "OutcomeConstraint", diff --git a/ax/core/arm.py b/ax/core/arm.py index 99c75daa99e..85277c5d596 100644 --- a/ax/core/arm.py +++ b/ax/core/arm.py @@ -8,7 +8,6 @@ import hashlib import json -from typing import Optional from ax.core.types import TParameterization from ax.utils.common.base import SortableBase @@ -23,9 +22,7 @@ class Arm(SortableBase): encapsulates the parametrization needed by the unit. """ - def __init__( - self, parameters: TParameterization, name: Optional[str] = None - ) -> None: + def __init__(self, parameters: TParameterization, name: str | None = None) -> None: """Inits Arm. 
Args: diff --git a/ax/core/auxiliary.py b/ax/core/auxiliary.py index 748ad187711..db975fefa07 100644 --- a/ax/core/auxiliary.py +++ b/ax/core/auxiliary.py @@ -8,7 +8,7 @@ from __future__ import annotations from enum import Enum, unique -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from ax.core.data import Data from ax.utils.common.base import SortableBase @@ -25,7 +25,7 @@ class AuxiliaryExperiment(SortableBase): def __init__( self, experiment: core.experiment.Experiment, - data: Optional[Data] = None, + data: Data | None = None, ) -> None: """ Lightweight container of an experiment, and its data, diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 1377d0b9e52..3717d49a073 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -9,10 +9,11 @@ from __future__ import annotations from abc import ABC, abstractmethod, abstractproperty +from collections.abc import Callable from copy import deepcopy from datetime import datetime, timedelta from enum import Enum -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from ax.core.arm import Arm from ax.core.data import Data @@ -171,7 +172,6 @@ def __repr__(self) -> str: ] -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def immutable_once_run(func: Callable) -> Callable: """Decorator for methods that should throw Error when trial is running or has ever run and immutable. @@ -217,9 +217,9 @@ class BaseTrial(ABC, SortableBase): def __init__( self, experiment: core.experiment.Experiment, - trial_type: Optional[str] = None, - ttl_seconds: Optional[int] = None, - index: Optional[int] = None, + trial_type: str | None = None, + ttl_seconds: int | None = None, + index: int | None = None, ) -> None: """Initialize trial. @@ -229,7 +229,7 @@ def __init__( self._experiment = experiment if ttl_seconds is not None and ttl_seconds <= 0: raise ValueError("TTL must be a positive integer (or None).") - self._ttl_seconds: Optional[int] = ttl_seconds + self._ttl_seconds: int | None = ttl_seconds self._index: int = self._experiment._attach_trial(self, index=index) if trial_type is not None: @@ -239,25 +239,25 @@ def __init__( ) else: trial_type = self._experiment.default_trial_type - self._trial_type: Optional[str] = trial_type + self._trial_type: str | None = trial_type - self.__status: Optional[TrialStatus] = None + self.__status: TrialStatus | None = None # Uses `_status` setter, which updates trial statuses to trial indices # mapping on the experiment, with which this trial is associated. 
self._status = TrialStatus.CANDIDATE self._time_created: datetime = datetime.now() # Initialize fields to be used later in lifecycle - self._time_completed: Optional[datetime] = None - self._time_staged: Optional[datetime] = None - self._time_run_started: Optional[datetime] = None + self._time_completed: datetime | None = None + self._time_staged: datetime | None = None + self._time_run_started: datetime | None = None - self._abandoned_reason: Optional[str] = None - self._failed_reason: Optional[str] = None + self._abandoned_reason: str | None = None + self._failed_reason: str | None = None self._run_metadata: dict[str, Any] = {} self._stop_metadata: dict[str, Any] = {} - self._runner: Optional[Runner] = None + self._runner: Runner | None = None # Counter to maintain how many arms have been named by this BatchTrial self._num_arms_created = 0 @@ -265,7 +265,7 @@ def __init__( # If generator run(s) in this trial were generated from a generation # strategy, this property will be set to the generation step that produced # the generator run(s). - self._generation_step_index: Optional[int] = None + self._generation_step_index: int | None = None # pyre-fixme[4]: Attribute must be annotated. self._properties = {} @@ -290,7 +290,7 @@ def status(self, status: TrialStatus) -> None: raise NotImplementedError("Use `trial.mark_*` methods to set trial status.") @property - def ttl_seconds(self) -> Optional[int]: + def ttl_seconds(self) -> int | None: """This trial's time-to-live once ran, in seconds. If not set, trial will never be automatically considered failed (i.e. infinite TTL). Reflects after how many seconds since the time the trial was run it @@ -299,7 +299,7 @@ def ttl_seconds(self) -> Optional[int]: return self._ttl_seconds @ttl_seconds.setter - def ttl_seconds(self, ttl_seconds: Optional[int]) -> None: + def ttl_seconds(self, ttl_seconds: int | None) -> None: """Sets this trial's time-to-live once ran, in seconds. If None, trial will never be automatically considered failed (i.e. infinite TTL). Reflects after how many seconds since the time the trial was run it @@ -320,17 +320,17 @@ def did_not_complete(self) -> bool: return self.status.is_terminal and not self.completed_successfully @property - def runner(self) -> Optional[Runner]: + def runner(self) -> Runner | None: """The runner object defining how to deploy the trial.""" return self._runner @runner.setter @immutable_once_run - def runner(self, runner: Optional[Runner]) -> None: + def runner(self, runner: Runner | None) -> None: self._runner = runner @property - def deployed_name(self) -> Optional[str]: + def deployed_name(self) -> str | None: """Name of the experiment created in external framework. This property is derived from the name field in run_metadata. @@ -354,7 +354,7 @@ def stop_metadata(self) -> dict[str, Any]: return self._stop_metadata @property - def trial_type(self) -> Optional[str]: + def trial_type(self) -> str | None: """The type of the trial. Relevant for experiments containing different kinds of trials @@ -364,7 +364,7 @@ def trial_type(self) -> Optional[str]: @trial_type.setter @immutable_once_run - def trial_type(self, trial_type: Optional[str]) -> None: + def trial_type(self, trial_type: str | None) -> None: """Identifier used to distinguish trial types in experiments with multiple trial types. 
""" @@ -421,7 +421,7 @@ def run(self) -> BaseTrial: self.mark_running() return self - def stop(self, new_status: TrialStatus, reason: Optional[str] = None) -> BaseTrial: + def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: """Stops the trial according to the behavior on the runner. The runner returns a `stop_metadata` dict containining metadata @@ -458,7 +458,7 @@ def stop(self, new_status: TrialStatus, reason: Optional[str] = None) -> BaseTri self.mark_as(new_status) return self - def complete(self, reason: Optional[str] = None) -> BaseTrial: + def complete(self, reason: str | None = None) -> BaseTrial: """Stops the trial if functionality is defined on runner and marks trial completed. @@ -478,7 +478,7 @@ def complete(self, reason: Optional[str] = None) -> BaseTrial: return self def fetch_data_results( - self, metrics: Optional[list[Metric]] = None, **kwargs: Any + self, metrics: list[Metric] | None = None, **kwargs: Any ) -> dict[str, MetricFetchResult]: """Fetch data results for this trial for all metrics on experiment. @@ -496,7 +496,7 @@ def fetch_data_results( trial_index=self.index, metrics=metrics, **kwargs ) - def fetch_data(self, metrics: Optional[list[Metric]] = None, **kwargs: Any) -> Data: + def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data: """Fetch data for this trial for all metrics on experiment. # NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and @@ -543,12 +543,12 @@ def _check_existing_and_name_arm(self, arm: Arm) -> None: if arm.name == proposed_name: self._num_arms_created += 1 - def _get_default_name(self, arm_index: Optional[int] = None) -> str: + def _get_default_name(self, arm_index: int | None = None) -> str: if arm_index is None: arm_index = self._num_arms_created return f"{self.index}_{arm_index}" - def _set_generation_step_index(self, generation_step_index: Optional[int]) -> None: + def _set_generation_step_index(self, generation_step_index: int | None) -> None: """Sets the `generation_step_index` property of the trial, to reflect which generation step of a given generation strategy (if any) produced the generator run(s) attached to this trial. @@ -608,17 +608,17 @@ def time_created(self) -> datetime: return self._time_created @property - def time_completed(self) -> Optional[datetime]: + def time_completed(self) -> datetime | None: """Completion time of the trial.""" return self._time_completed @property - def time_staged(self) -> Optional[datetime]: + def time_staged(self) -> datetime | None: """Staged time of the trial.""" return self._time_staged @property - def time_run_started(self) -> Optional[datetime]: + def time_run_started(self) -> datetime | None: """Time the trial was started running (i.e. collecting data).""" return self._time_run_started @@ -628,11 +628,11 @@ def is_abandoned(self) -> bool: return self._status == TrialStatus.ABANDONED @property - def abandoned_reason(self) -> Optional[str]: + def abandoned_reason(self) -> str | None: return self._abandoned_reason @property - def failed_reason(self) -> Optional[str]: + def failed_reason(self) -> str | None: return self._failed_reason def mark_staged(self, unsafe: bool = False) -> BaseTrial: @@ -696,7 +696,7 @@ def mark_completed(self, unsafe: bool = False) -> BaseTrial: return self def mark_abandoned( - self, reason: Optional[str] = None, unsafe: bool = False + self, reason: str | None = None, unsafe: bool = False ) -> BaseTrial: """Mark trial as abandoned. 
@@ -721,9 +721,7 @@ def mark_abandoned( self._time_completed = datetime.now() return self - def mark_failed( - self, reason: Optional[str] = None, unsafe: bool = False - ) -> BaseTrial: + def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTrial: """Mark trial as failed. Args: @@ -784,9 +782,7 @@ def mark_as( raise ValueError(f"Cannot mark trial as {status}.") return self - def mark_arm_abandoned( - self, arm_name: str, reason: Optional[str] = None - ) -> BaseTrial: + def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTrial: raise NotImplementedError( "Abandoning arms is only supported for `BatchTrial`. " "Use `trial.mark_abandoned` if applicable." @@ -805,7 +801,7 @@ def _mark_failed_if_past_TTL(self) -> None: self.mark_failed() @property - def _status(self) -> Optional[TrialStatus]: + def _status(self) -> TrialStatus | None: """The status of the trial in the experimentation lifecycle. This private property exists to allow for a corresponding setter, since its important that the trial statuses mapping on the experiment is updated always when @@ -833,8 +829,8 @@ def _unique_id(self) -> str: def _make_evaluations_and_data( self, raw_data: dict[str, TEvaluationOutcome], - metadata: Optional[dict[str, Union[str, int]]], - sample_sizes: Optional[dict[str, int]] = None, + metadata: dict[str, str | int] | None, + sample_sizes: dict[str, int] | None = None, ) -> tuple[dict[str, TEvaluationOutcome], Data]: """Formats given raw data as Ax evaluations and `Data`. diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 6475e43283b..266897cb14f 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -16,7 +16,7 @@ from datetime import datetime from enum import Enum from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING import numpy as np from ax.core.arm import Arm @@ -64,7 +64,7 @@ class AbandonedArm(SortableBase): name: str time: datetime - reason: Optional[str] = None + reason: str | None = None @equality_typechecker def __eq__(self, other: AbandonedArm) -> bool: @@ -137,13 +137,13 @@ class BatchTrial(BaseTrial): def __init__( self, experiment: core.experiment.Experiment, - generator_run: Optional[GeneratorRun] = None, - generator_runs: Optional[list[GeneratorRun]] = None, - trial_type: Optional[str] = None, - optimize_for_power: Optional[bool] = False, - ttl_seconds: Optional[int] = None, - index: Optional[int] = None, - lifecycle_stage: Optional[LifecycleStage] = None, + generator_run: GeneratorRun | None = None, + generator_runs: list[GeneratorRun] | None = None, + trial_type: str | None = None, + optimize_for_power: bool | None = False, + ttl_seconds: int | None = None, + index: int | None = None, + lifecycle_stage: LifecycleStage | None = None, ) -> None: super().__init__( experiment=experiment, @@ -154,8 +154,8 @@ def __init__( self._arms_by_name: dict[str, Arm] = {} self._generator_run_structs: list[GeneratorRunStruct] = [] self._abandoned_arms_metadata: dict[str, AbandonedArm] = {} - self._status_quo: Optional[Arm] = None - self._status_quo_weight_override: Optional[float] = None + self._status_quo: Arm | None = None + self._status_quo_weight_override: float | None = None if generator_run is not None: if generator_runs is not None: raise UnsupportedError( @@ -182,8 +182,8 @@ def __init__( # Trial status quos are stored in the DB as a generator run # with one arm; thus we need to store two `db_id` values # for this object instead of one - 
self._status_quo_generator_run_db_id: Optional[int] = None - self._status_quo_arm_db_id: Optional[int] = None + self._status_quo_generator_run_db_id: int | None = None + self._status_quo_arm_db_id: int | None = None self._lifecycle_stage = lifecycle_stage @property @@ -230,7 +230,7 @@ def arm_weights(self) -> MutableMapping[Arm, float]: return arm_weights @property - def lifecycle_stage(self) -> Optional[LifecycleStage]: + def lifecycle_stage(self) -> LifecycleStage | None: return self._lifecycle_stage @arm_weights.setter @@ -254,7 +254,7 @@ def add_arm(self, arm: Arm, weight: float = 1.0) -> BatchTrial: def add_arms_and_weights( self, arms: list[Arm], - weights: Optional[list[float]] = None, + weights: list[float] | None = None, multiplier: float = 1.0, ) -> BatchTrial: """Add arms and weights to the trial. @@ -327,12 +327,12 @@ def add_generator_run( return self @property - def status_quo(self) -> Optional[Arm]: + def status_quo(self) -> Arm | None: """The control arm for this batch.""" return self._status_quo @status_quo.setter - def status_quo(self, status_quo: Optional[Arm]) -> None: + def status_quo(self, status_quo: Arm | None) -> None: raise NotImplementedError( "Use `set_status_quo_with_weight` or " "`set_status_quo_and_optimize_power` " @@ -347,7 +347,7 @@ def unset_status_quo(self) -> None: @immutable_once_run def set_status_quo_with_weight( - self, status_quo: Arm, weight: Optional[float] + self, status_quo: Arm, weight: float | None ) -> BatchTrial: """Sets status quo arm with given weight. This weight *overrides* any weight the status quo has from generator runs attached to this batch. @@ -477,7 +477,7 @@ def is_factorial(self) -> bool: sufficient_factors = all(len(arm.parameters or []) >= 2 for arm in self.arms) if not sufficient_factors: return False - param_levels: defaultdict[str, dict[Union[str, float], int]] = defaultdict(dict) + param_levels: defaultdict[str, dict[str | float, int]] = defaultdict(dict) for arm in self.arms: for param_name, param_value in arm.parameters.items(): param_levels[param_name][not_none(param_value)] = 1 @@ -490,7 +490,7 @@ def run(self) -> BatchTrial: return checked_cast(BatchTrial, super().run()) def normalized_arm_weights( - self, total: float = 1, trunc_digits: Optional[int] = None + self, total: float = 1, trunc_digits: int | None = None ) -> MutableMapping[Arm, float]: """Returns arms with a new set of weights normalized to the given total. @@ -526,7 +526,7 @@ def normalized_arm_weights( return OrderedDict(zip(self.arms, weights)) def mark_arm_abandoned( - self, arm_name: str, reason: Optional[str] = None + self, arm_name: str, reason: str | None = None ) -> BatchTrial: """Mark a arm abandoned. @@ -564,7 +564,7 @@ def clone(self) -> BatchTrial: def clone_to( self, - experiment: Optional[core.experiment.Experiment] = None, + experiment: core.experiment.Experiment | None = None, include_sq: bool = True, ) -> BatchTrial: """Clone the trial and attach it to a specified experiment. 
@@ -603,8 +603,8 @@ def clone_to( def attach_batch_trial_data( self, raw_data: dict[str, TEvaluationOutcome], - sample_sizes: Optional[dict[str, int]] = None, - metadata: Optional[dict[str, Union[str, int]]] = None, + sample_sizes: dict[str, int] | None = None, + metadata: dict[str, str | int] | None = None, ) -> None: """Attaches data to the trial diff --git a/ax/core/data.py b/ax/core/data.py index fd6226194f4..2eb84dcc9b9 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -15,7 +15,7 @@ from functools import reduce from hashlib import md5 from io import StringIO -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar import numpy as np import pandas as pd @@ -72,8 +72,8 @@ class BaseData(Base, SerializationMixin): def __init__( self: TBaseData, - df: Optional[pd.DataFrame] = None, - description: Optional[str] = None, + df: pd.DataFrame | None = None, + description: str | None = None, ) -> None: """Init Data. @@ -110,7 +110,7 @@ def _safecast_df( df: pd.DataFrame, # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. - extra_column_types: Optional[dict[str, type]] = None, + extra_column_types: dict[str, type] | None = None, ) -> pd.DataFrame: """Function for safely casting df to standard data types. @@ -138,7 +138,7 @@ def _safecast_df( cls.column_data_types(extra_column_types)[col] is int and df.loc[:, col].isnull().any() ) - and not (coltype is Any) + and coltype is not Any } return checked_cast(pd.DataFrame, df.astype(dtype=dtype)) @@ -150,7 +150,7 @@ def required_columns(cls) -> set[str]: @classmethod def supported_columns( - cls, extra_column_names: Optional[Iterable[str]] = None + cls, extra_column_names: Iterable[str] | None = None ) -> set[str]: """Names of columns supported (but not necessarily required) by this class.""" extra_column_names = set(extra_column_names or []) @@ -164,8 +164,8 @@ def column_data_types( cls, # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. - extra_column_types: Optional[dict[str, type]] = None, - excluded_columns: Optional[Iterable[str]] = None, + extra_column_types: dict[str, type] | None = None, + excluded_columns: Iterable[str] | None = None, # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. ) -> dict[str, type]: @@ -194,8 +194,8 @@ def serialize_init_args(cls, obj: Any) -> dict[str, Any]: def deserialize_init_args( cls, args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, ) -> dict[str, Any]: """Given a dictionary, extract the properties needed to initialize the object. Used for storage. @@ -292,9 +292,9 @@ def from_evaluations( cls: type[TBaseData], evaluations: dict[str, TTrialEvaluation], trial_index: int, - sample_sizes: Optional[dict[str, int]] = None, - start_time: Optional[Union[int, str]] = None, - end_time: Optional[Union[int, str]] = None, + sample_sizes: dict[str, int] | None = None, + start_time: int | str | None = None, + end_time: int | str | None = None, ) -> TBaseData: """ Convert dict of evaluations to Ax data object. 
@@ -338,9 +338,9 @@ def from_fidelity_evaluations( cls: type[TBaseData], evaluations: dict[str, TFidelityTrialEvaluation], trial_index: int, - sample_sizes: Optional[dict[str, int]] = None, - start_time: Optional[int] = None, - end_time: Optional[int] = None, + sample_sizes: dict[str, int] | None = None, + start_time: int | None = None, + end_time: int | None = None, ) -> TBaseData: """ Convert dict of fidelity evaluations to Ax data object. @@ -381,9 +381,9 @@ def _get_fidelity_records( @staticmethod def _add_cols_to_records( records: list[dict[str, Any]], - sample_sizes: Optional[dict[str, int]] = None, - start_time: Optional[Union[int, str]] = None, - end_time: Optional[Union[int, str]] = None, + sample_sizes: dict[str, int] | None = None, + start_time: int | str | None = None, + end_time: int | str | None = None, ) -> list[dict[str, Any]]: """Adds to records metadata columns that are available for all BaseData subclasses. @@ -477,8 +477,8 @@ def metric_names(self) -> set[str]: def filter( self, - trial_indices: Optional[Iterable[int]] = None, - metric_names: Optional[Iterable[str]] = None, + trial_indices: Iterable[int] | None = None, + metric_names: Iterable[str] | None = None, ) -> Data: """Construct a new object with the subset of rows corresponding to the provided trial indices AND metric names. If either trial_indices or @@ -494,8 +494,8 @@ def filter( @staticmethod def _filter_df( df: pd.DataFrame, - trial_indices: Optional[Iterable[int]] = None, - metric_names: Optional[Iterable[str]] = None, + trial_indices: Iterable[int] | None = None, + metric_names: Iterable[str] | None = None, ) -> pd.DataFrame: trial_indices_mask = ( reduce( @@ -519,7 +519,7 @@ def _filter_df( @staticmethod def from_multiple_data( - data: Iterable[Data], subset_metrics: Optional[Iterable[str]] = None + data: Iterable[Data], subset_metrics: Iterable[str] | None = None ) -> Data: """Combines multiple objects into one (with the concatenated underlying dataframe). @@ -582,9 +582,9 @@ def _ms_epoch_to_isoformat(epoch: int) -> str: def custom_data_class( # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. - column_data_types: Optional[dict[str, type]] = None, - required_columns: Optional[set[str]] = None, - time_columns: Optional[set[str]] = None, + column_data_types: dict[str, type] | None = None, + required_columns: set[str] | None = None, + time_columns: set[str] | None = None, ) -> type[Data]: """Creates a custom data class with additional columns. 
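Example (a minimal sketch, not part of this patch; the extra column is hypothetical)::

    RegionData = custom_data_class(
        column_data_types={"region": str},
        required_columns={"region"},
    )
    # The resulting class behaves like Data but requires a "region" column.
    data = RegionData(df=df_with_region_column)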
@@ -607,7 +607,7 @@ def required_columns(cls) -> set[str]: @classmethod def column_data_types( - cls, extra_column_types: Optional[dict[str, type]] = None + cls, extra_column_types: dict[str, type] | None = None ) -> dict[str, type]: return super().column_data_types( {**(extra_column_types or {}), **(column_data_types or {})} diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9c4759ad8b3..49eb7321a19 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -16,7 +16,7 @@ from datetime import datetime from functools import partial, reduce -from typing import Any, Optional +from typing import Any import ax.core.observation as observation import pandas as pd @@ -71,19 +71,19 @@ class Experiment(Base): def __init__( self, search_space: SearchSpace, - name: Optional[str] = None, - optimization_config: Optional[OptimizationConfig] = None, - tracking_metrics: Optional[list[Metric]] = None, - runner: Optional[Runner] = None, - status_quo: Optional[Arm] = None, - description: Optional[str] = None, + name: str | None = None, + optimization_config: OptimizationConfig | None = None, + tracking_metrics: list[Metric] | None = None, + runner: Runner | None = None, + status_quo: Arm | None = None, + description: str | None = None, is_test: bool = False, - experiment_type: Optional[str] = None, - properties: Optional[dict[str, Any]] = None, - default_data_type: Optional[DataType] = None, - auxiliary_experiments_by_purpose: Optional[ + experiment_type: str | None = None, + properties: dict[str, Any] | None = None, + default_data_type: DataType | None = None, + auxiliary_experiments_by_purpose: ( dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] - ] = None, + | None + ) = None, ) -> None: """Inits Experiment. @@ -104,7 +104,7 @@ def __init__( """ # appease pyre self._search_space: SearchSpace - self._status_quo: Optional[Arm] = None + self._status_quo: Arm | None = None self._is_test: bool self._name = name @@ -113,7 +113,7 @@ def __init__( self.is_test = is_test self._data_by_trial: dict[int, OrderedDict[int, Data]] = {} - self._experiment_type: Optional[str] = experiment_type + self._experiment_type: str | None = experiment_type # pyre-fixme[4]: Attribute must be annotated.
self._optimization_config = None self._tracking_metrics: dict[str, Metric] = {} @@ -181,12 +181,12 @@ def time_created(self) -> datetime: return self._time_created @property - def experiment_type(self) -> Optional[str]: + def experiment_type(self) -> str | None: """The type of the experiment.""" return self._experiment_type @experiment_type.setter - def experiment_type(self, experiment_type: Optional[str]) -> None: + def experiment_type(self, experiment_type: str | None) -> None: """Set the type of the experiment.""" self._experiment_type = experiment_type @@ -237,12 +237,12 @@ def search_space(self, search_space: SearchSpace) -> None: self._search_space = search_space @property - def status_quo(self) -> Optional[Arm]: + def status_quo(self) -> Arm | None: """The existing arm that new arms will be compared against.""" return self._status_quo @status_quo.setter - def status_quo(self, status_quo: Optional[Arm]) -> None: + def status_quo(self, status_quo: Arm | None) -> None: if status_quo is not None: self.search_space.check_types( parameterization=status_quo.parameters, raise_error=True @@ -329,7 +329,7 @@ def num_abandoned_arms(self) -> int: return len(abandoned) @property - def optimization_config(self) -> Optional[OptimizationConfig]: + def optimization_config(self) -> OptimizationConfig | None: """The experiment's optimization config.""" return self._optimization_config @@ -464,7 +464,7 @@ def metrics(self) -> dict[str, Metric]: return {**self._tracking_metrics, **optimization_config_metrics} def _metrics_by_class( - self, metrics: Optional[list[Metric]] = None + self, metrics: list[Metric] | None = None ) -> dict[type[Metric], list[Metric]]: metrics_by_class: dict[type[Metric], list[Metric]] = defaultdict(list) for metric in metrics or list(self.metrics.values()): @@ -478,7 +478,7 @@ def _metrics_by_class( def fetch_data_results( self, - metrics: Optional[list[Metric]] = None, + metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -516,7 +516,7 @@ def fetch_data_results( def fetch_trials_data_results( self, trial_indices: Iterable[int], - metrics: Optional[list[Metric]] = None, + metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -551,7 +551,7 @@ def fetch_trials_data_results( def fetch_data( self, - metrics: Optional[list[Metric]] = None, + metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -597,7 +597,7 @@ def fetch_data( def fetch_trials_data( self, trial_indices: Iterable[int], - metrics: Optional[list[Metric]] = None, + metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -641,7 +641,7 @@ def fetch_trials_data( def _lookup_or_fetch_trials_results( self, trials: list[BaseTrial], - metrics: Optional[Iterable[Metric]] = None, + metrics: Iterable[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -704,7 +704,7 @@ def _lookup_or_fetch_trials_results( @copy_doc(BaseTrial.fetch_data) def _fetch_trial_data( - self, trial_index: int, metrics: Optional[list[Metric]] = None, **kwargs: Any + self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] @@ -882,7 +882,7 @@ def attach_fetch_results( results: Mapping[int, Mapping[str, 
MetricFetchResult]], combine_with_last_data: bool = False, overwrite_existing_data: bool = False, - ) -> Optional[int]: + ) -> int | None: """ UNSAFE: Prefer to use attach_data directly instead. @@ -977,7 +977,7 @@ def lookup_data_for_trial( def lookup_data( self, - trial_indices: Optional[Iterable[int]] = None, + trial_indices: Iterable[int] | None = None, ) -> Data: """Look up stored data for trials on this experiment. @@ -1068,9 +1068,9 @@ def default_data_constructor(self) -> type: def new_trial( self, - generator_run: Optional[GeneratorRun] = None, - trial_type: Optional[str] = None, - ttl_seconds: Optional[int] = None, + generator_run: GeneratorRun | None = None, + trial_type: str | None = None, + ttl_seconds: int | None = None, ) -> Trial: """Create a new trial associated with this experiment. @@ -1099,12 +1099,12 @@ def new_trial( def new_batch_trial( self, - generator_run: Optional[GeneratorRun] = None, - generator_runs: Optional[list[GeneratorRun]] = None, - trial_type: Optional[str] = None, - optimize_for_power: Optional[bool] = False, - ttl_seconds: Optional[int] = None, - lifecycle_stage: Optional[LifecycleStage] = None, + generator_run: GeneratorRun | None = None, + generator_runs: list[GeneratorRun] | None = None, + trial_type: str | None = None, + optimize_for_power: bool | None = False, + ttl_seconds: int | None = None, + lifecycle_stage: LifecycleStage | None = None, ) -> BatchTrial: """Create a new batch trial associated with this experiment. @@ -1166,7 +1166,7 @@ def reset_runners(self, runner: Runner) -> None: trial.runner = runner self.runner = runner - def _attach_trial(self, trial: BaseTrial, index: Optional[int] = None) -> int: + def _attach_trial(self, trial: BaseTrial, index: int | None = None) -> int: """Attach a trial to this experiment. Should only be called within the trial constructor. @@ -1217,8 +1217,8 @@ def validate_trials(self, trials: Iterable[BaseTrial]) -> None: def warm_start_from_old_experiment( self, old_experiment: Experiment, - copy_run_metadata_keys: Optional[list[str]] = None, - trial_statuses_to_copy: Optional[list[TrialStatus]] = None, + copy_run_metadata_keys: list[str] | None = None, + trial_statuses_to_copy: list[TrialStatus] | None = None, search_space_check_membership_raise_error: bool = True, ) -> list[Trial]: """Copy all completed trials with data from an old Ax experiment to this one. @@ -1433,7 +1433,7 @@ def __repr__(self) -> str: # overridden in the MultiTypeExperiment class. @property - def default_trial_type(self) -> Optional[str]: + def default_trial_type(self) -> str | None: """Default trial type assigned to trials in this experiment. In the base experiment class this is always None. For experiments @@ -1441,7 +1441,7 @@ def default_trial_type(self) -> Optional[str]: """ return None - def runner_for_trial(self, trial: BaseTrial) -> Optional[Runner]: + def runner_for_trial(self, trial: BaseTrial) -> Runner | None: """The default runner to use for a given trial. In the base experiment class, this is always the default experiment runner. @@ -1449,7 +1449,7 @@ def runner_for_trial(self, trial: BaseTrial) -> Optional[Runner]: """ return self.runner - def supports_trial_type(self, trial_type: Optional[str]) -> bool: + def supports_trial_type(self, trial_type: str | None) -> bool: """Whether this experiment allows trials of the given type. The base experiment class only supports None.
For experiments @@ -1460,9 +1460,9 @@ def supports_trial_type(self, trial_type: Optional[str]) -> bool: def attach_trial( self, parameterizations: list[TParameterization], - arm_names: Optional[list[str]] = None, - ttl_seconds: Optional[int] = None, - run_metadata: Optional[dict[str, Any]] = None, + arm_names: list[str] | None = None, + ttl_seconds: int | None = None, + run_metadata: dict[str, Any] | None = None, optimize_for_power: bool = False, ) -> tuple[dict[str, TParameterization], int]: """Attach a new trial with the given parameterization to the experiment. @@ -1560,17 +1560,17 @@ def attach_trial( def clone_with( self, - search_space: Optional[SearchSpace] = None, - name: Optional[str] = None, - optimization_config: Optional[OptimizationConfig] = None, - tracking_metrics: Optional[list[Metric]] = None, - runner: Optional[Runner] = None, - status_quo: Optional[Arm] = None, - description: Optional[str] = None, - is_test: Optional[bool] = None, - properties: Optional[dict[str, Any]] = None, - trial_indices: Optional[list[int]] = None, - data: Optional[Data] = None, + search_space: SearchSpace | None = None, + name: str | None = None, + optimization_config: OptimizationConfig | None = None, + tracking_metrics: list[Metric] | None = None, + runner: Runner | None = None, + status_quo: Arm | None = None, + description: str | None = None, + is_test: bool | None = None, + properties: dict[str, Any] | None = None, + trial_indices: list[int] | None = None, + data: Data | None = None, ) -> Experiment: r""" Return a copy of this experiment with some attributes replaced. @@ -1753,7 +1753,7 @@ def metric_config_summary_df(self) -> pd.DataFrame: def add_arm_and_prevent_naming_collision( - new_trial: Trial, old_trial: Trial, old_experiment_name: Optional[str] = None + new_trial: Trial, old_trial: Trial, old_experiment_name: str | None = None ) -> None: # Add all of an old trial's arms to a new trial. Rename any arm with auto-generated # naming format to prevent naming collisions during warm-start. If an old diff --git a/ax/core/formatting_utils.py b/ax/core/formatting_utils.py index 08c5b2e5cf9..be088e4ffc3 100644 --- a/ax/core/formatting_utils.py +++ b/ax/core/formatting_utils.py @@ -7,7 +7,7 @@ # pyre-strict from enum import Enum -from typing import cast, Optional, Union +from typing import cast import numpy as np from ax.core.data import Data @@ -102,8 +102,8 @@ def data_and_evaluations_from_raw_data( trial_index: int, sample_sizes: dict[str, int], data_type: DataType, - start_time: Optional[Union[int, str]] = None, - end_time: Optional[Union[int, str]] = None, + start_time: int | str | None = None, + end_time: int | str | None = None, ) -> tuple[dict[str, TEvaluationOutcome], Data]: """Transforms evaluations into Ax Data. diff --git a/ax/core/generation_strategy_interface.py b/ax/core/generation_strategy_interface.py index 61915e18e28..c6877fb08a0 100644 --- a/ax/core/generation_strategy_interface.py +++ b/ax/core/generation_strategy_interface.py @@ -9,8 +9,6 @@ from abc import ABC, abstractmethod -from typing import Optional - from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun @@ -31,7 +29,7 @@ class GenerationStrategyInterface(ABC, Base): _name: str # Experiment, for which this generation strategy has generated trials, if # it exists. 
- _experiment: Optional[Experiment] = None + _experiment: Experiment | None = None # Constant for default number of arms to generate if `n` is not specified in # `gen` call and "total_concurrent_arms" is not set in experiment properties. @@ -44,11 +42,11 @@ def __init__(self, name: str) -> None: def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, - data: Optional[Data] = None, + data: Data | None = None, # TODO[drfreund, danielcohennyc, mgarrard]: Update the format of the arguments # below as we find the right one. num_generator_runs: int = 1, - n: Optional[int] = None, + n: int | None = None, ) -> list[list[GeneratorRun]]: """Produce ``GeneratorRun``-s for multiple trials at once with the possibility of joining ``GeneratorRun``-s from multiple models into one ``BatchTrial``. diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 246d058acc3..ef9d8de838e 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum from logging import Logger -from typing import Any, Optional +from typing import Any import pandas as pd from ax.core.arm import Arm @@ -89,24 +89,24 @@ class GeneratorRun(SortableBase): def __init__( self, arms: list[Arm], - weights: Optional[list[float]] = None, - optimization_config: Optional[OptimizationConfig] = None, - search_space: Optional[SearchSpace] = None, - model_predictions: Optional[TModelPredict] = None, - best_arm_predictions: Optional[tuple[Arm, Optional[TModelPredictArm]]] = None, - type: Optional[str] = None, - fit_time: Optional[float] = None, - gen_time: Optional[float] = None, - model_key: Optional[str] = None, - model_kwargs: Optional[dict[str, Any]] = None, - bridge_kwargs: Optional[dict[str, Any]] = None, - gen_metadata: Optional[TGenMetadata] = None, - model_state_after_gen: Optional[dict[str, Any]] = None, - generation_step_index: Optional[int] = None, - candidate_metadata_by_arm_signature: Optional[ + weights: list[float] | None = None, + optimization_config: OptimizationConfig | None = None, + search_space: SearchSpace | None = None, + model_predictions: TModelPredict | None = None, + best_arm_predictions: tuple[Arm, TModelPredictArm | None] | None = None, + type: str | None = None, + fit_time: float | None = None, + gen_time: float | None = None, + model_key: str | None = None, + model_kwargs: dict[str, Any] | None = None, + bridge_kwargs: dict[str, Any] | None = None, + gen_metadata: TGenMetadata | None = None, + model_state_after_gen: dict[str, Any] | None = None, + generation_step_index: int | None = None, + candidate_metadata_by_arm_signature: ( dict[str, TCandidateMetadata] - ] = None, + | None + ) = None, + generation_node_name: str | None = None, ) -> None: """ Inits GeneratorRun.
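Example (an illustrative sketch, not part of this patch; parameter names and values are hypothetical)::

    gr = GeneratorRun(
        arms=[Arm(parameters={"x": 0.0}), Arm(parameters={"x": 1.0})],
        weights=[1.0, 1.0],
    )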
@@ -178,13 +178,13 @@ def __init__( for arm, weight in zip(arms, weights): self.add_arm(arm=arm, weight=weight) - self._generator_run_type: Optional[str] = type + self._generator_run_type: str | None = type self._time_created: datetime = datetime.now() self._optimization_config = optimization_config self._search_space = search_space self._model_predictions = model_predictions self._best_arm_predictions = best_arm_predictions - self._index: Optional[int] = None + self._index: int | None = None self._fit_time = fit_time self._gen_time = gen_time self._model_key = model_key @@ -239,7 +239,7 @@ def arm_weights(self) -> MutableMapping[Arm, float]: return OrderedDict(zip(self.arms, self.weights)) @property - def generator_run_type(self) -> Optional[str]: + def generator_run_type(self) -> str | None: """The type of the generator run.""" return self._generator_run_type @@ -249,7 +249,7 @@ def time_created(self) -> datetime: return self._time_created @property - def index(self) -> Optional[int]: + def index(self) -> int | None: """The index of this generator run within a trial's list of generator run structs. This field is set when the generator run is added to a trial. """ @@ -267,34 +267,34 @@ def index(self, index: int) -> None: self._index = index @property - def optimization_config(self) -> Optional[OptimizationConfig]: + def optimization_config(self) -> OptimizationConfig | None: """The optimization config used during generation of this run.""" return self._optimization_config @property - def search_space(self) -> Optional[SearchSpace]: + def search_space(self) -> SearchSpace | None: """The search space used during generation of this run.""" return self._search_space @property - def model_predictions(self) -> Optional[TModelPredict]: + def model_predictions(self) -> TModelPredict | None: """Means and covariances for the arms in this run recorded at the time the run was executed. """ return self._model_predictions @property - def fit_time(self) -> Optional[float]: + def fit_time(self) -> float | None: """Time taken to fit the model in seconds.""" return self._fit_time @property - def gen_time(self) -> Optional[float]: + def gen_time(self) -> float | None: """Time taken to generate in seconds.""" return self._gen_time @property - def model_predictions_by_arm(self) -> Optional[dict[str, TModelPredictArm]]: + def model_predictions_by_arm(self) -> dict[str, TModelPredictArm] | None: """Model predictions for each arm in this run, at the time the run was executed. """ @@ -309,21 +309,21 @@ def model_predictions_by_arm(self) -> Optional[dict[str, TModelPredictArm]]: return predictions @property - def best_arm_predictions(self) -> Optional[tuple[Arm, Optional[TModelPredictArm]]]: + def best_arm_predictions(self) -> tuple[Arm, TModelPredictArm | None] | None: """Best arm in this run (according to the optimization config) and its optional respective model predictions. """ return self._best_arm_predictions @property - def gen_metadata(self) -> Optional[TGenMetadata]: + def gen_metadata(self) -> TGenMetadata | None: """Returns metadata generated by this run.""" return self._gen_metadata @property def candidate_metadata_by_arm_signature( self, - ) -> Optional[dict[str, TCandidateMetadata]]: + ) -> dict[str, TCandidateMetadata] | None: """Retrieves model-produced candidate metadata as a mapping from arm name (for the arm the candidate became when added to experiment) to the metadata dict.
""" diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 3e6e049ad0f..c3439c4a243 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -10,7 +10,7 @@ from collections.abc import Iterable, Sequence from copy import deepcopy from logging import Logger -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar import numpy as np import pandas as pd @@ -96,17 +96,17 @@ class MapData(Data): DEDUPLICATE_BY_COLUMNS = ["arm_name", "metric_name"] _map_df: pd.DataFrame - _memo_df: Optional[pd.DataFrame] + _memo_df: pd.DataFrame | None # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter. _map_key_infos: list[MapKeyInfo] def __init__( self, - df: Optional[pd.DataFrame] = None, + df: pd.DataFrame | None = None, # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter. - map_key_infos: Optional[Iterable[MapKeyInfo]] = None, - description: Optional[str] = None, + map_key_infos: Iterable[MapKeyInfo] | None = None, + description: str | None = None, ) -> None: if map_key_infos is None and df is not None: raise ValueError("map_key_infos may be `None` iff `df` is None.") @@ -174,7 +174,7 @@ def map_key_to_type(self) -> dict[str, type]: @staticmethod def from_multiple_map_data( data: Sequence[MapData], - subset_metrics: Optional[Iterable[str]] = None, + subset_metrics: Iterable[str] | None = None, ) -> MapData: if len(data) == 0: return MapData() @@ -208,7 +208,7 @@ def from_map_evaluations( evaluations: dict[str, TMapTrialEvaluation], trial_index: int, # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter. - map_key_infos: Optional[Iterable[MapKeyInfo]] = None, + map_key_infos: Iterable[MapKeyInfo] | None = None, ) -> MapData: records = [ { @@ -253,7 +253,7 @@ def map_df(self, df: pd.DataFrame): @staticmethod def from_multiple_data( data: Iterable[Data], - subset_metrics: Optional[Iterable[str]] = None, + subset_metrics: Iterable[str] | None = None, ) -> MapData: """Downcast instances of Data into instances of MapData with empty map_key_infos if necessary then combine as usual (filling in empty cells with @@ -292,8 +292,8 @@ def df(self) -> pd.DataFrame: @copy_doc(Data.filter) def filter( self, - trial_indices: Optional[Iterable[int]] = None, - metric_names: Optional[Iterable[str]] = None, + trial_indices: Iterable[int] | None = None, + metric_names: Iterable[str] | None = None, ) -> MapData: return MapData( @@ -318,8 +318,8 @@ def serialize_init_args(cls, obj: Any) -> dict[str, Any]: def deserialize_init_args( cls, args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, ) -> dict[str, Any]: """Given a dictionary, extract the properties needed to initialize the metric. Used for storage. 
@@ -341,10 +341,10 @@ def clone(self) -> MapData: def subsample( self, - map_key: Optional[str] = None, - keep_every: Optional[int] = None, - limit_rows_per_group: Optional[int] = None, - limit_rows_per_metric: Optional[int] = None, + map_key: str | None = None, + keep_every: int | None = None, + limit_rows_per_group: int | None = None, + limit_rows_per_metric: int | None = None, include_first_last: bool = True, ) -> MapData: """Subsample the `map_key` column in an equally-spaced manner (if there is @@ -415,10 +415,10 @@ def subsample( def _subsample_one_metric( map_df: pd.DataFrame, - map_key: Optional[str] = None, - keep_every: Optional[int] = None, - limit_rows_per_group: Optional[int] = None, - limit_rows_per_metric: Optional[int] = None, + map_key: str | None = None, + keep_every: int | None = None, + limit_rows_per_group: int | None = None, + limit_rows_per_metric: int | None = None, include_first_last: bool = True, ) -> pd.DataFrame: """Helper function to subsample a dataframe that holds a single metric.""" diff --git a/ax/core/metric.py b/ax/core/metric.py index 83714af4a78..6a325bcf9e0 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -17,7 +17,7 @@ from functools import reduce from logging import Logger -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from ax.core.data import Data from ax.utils.common.base import SortableBase @@ -38,7 +38,7 @@ class MetricFetchE: # TODO[mpolson64] Replace this with ExceptionE message: str - exception: Optional[Exception] + exception: Exception | None def __post_init__(self) -> None: logger.info(msg=f"MetricFetchE INFO: Initialized {self}") @@ -52,7 +52,7 @@ def __repr__(self) -> str: f"with Traceback:\n {self.tb_str()}" ) - def tb_str(self) -> Optional[str]: + def tb_str(self) -> str | None: if self.exception is None: return None @@ -86,8 +86,8 @@ class Metric(SortableBase, SerializationMixin): def __init__( self, name: str, - lower_is_better: Optional[bool] = None, - properties: Optional[dict[str, Any]] = None, + lower_is_better: bool | None = None, + properties: dict[str, Any] | None = None, ) -> None: """Inits Metric. @@ -226,7 +226,7 @@ def bulk_fetch_experiment_data( self, experiment: core.experiment.Experiment, metrics: list[Metric], - trials: Optional[list[core.base_trial.BaseTrial]] = None, + trials: list[core.base_trial.BaseTrial] | None = None, **kwargs: Any, ) -> dict[int, dict[str, MetricFetchResult]]: """Fetch multiple metrics data for multiple trials on an experiment, using @@ -271,7 +271,7 @@ def fetch_data_prefer_lookup( self, experiment: core.experiment.Experiment, metrics: list[Metric], - trials: Optional[list[core.base_trial.BaseTrial]] = None, + trials: list[core.base_trial.BaseTrial] | None = None, **kwargs: Any, ) -> tuple[dict[int, dict[str, MetricFetchResult]], bool]: """Fetch or lookup (with fallback to fetching) data for given metrics, @@ -392,7 +392,7 @@ def fetch_experiment_data_multi( cls, experiment: core.experiment.Experiment, metrics: Iterable[Metric], - trials: Optional[Iterable[core.base_trial.BaseTrial]] = None, + trials: Iterable[core.base_trial.BaseTrial] | None = None, **kwargs: Any, ) -> dict[int, dict[str, MetricFetchResult]]: """Fetch multiple metrics data for an experiment. 
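Example (a sketch, not part of this patch; assumes an ``experiment`` with attached metrics already exists)::

    results = Metric.fetch_experiment_data_multi(
        experiment=experiment,
        metrics=list(experiment.metrics.values()),
    )
    # `results` maps trial index -> {metric name: MetricFetchResult}.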
@@ -489,7 +489,7 @@ def _unwrap_trial_data_multi( cls, results: Mapping[str, MetricFetchResult], # TODO[mpolson64] Add critical_metric_names to other unwrap methods - critical_metric_names: Optional[list[str]] = None, + critical_metric_names: list[str] | None = None, ) -> Data: # NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and # lose rows) if some MetricFetchResults contain Data not of type diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index 388f8156fe5..5b601617695 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -7,7 +7,8 @@ # pyre-strict import logging -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -47,13 +48,13 @@ def __init__( search_space: SearchSpace, default_trial_type: str, default_runner: Runner, - optimization_config: Optional[OptimizationConfig] = None, - status_quo: Optional[Arm] = None, - description: Optional[str] = None, + optimization_config: OptimizationConfig | None = None, + status_quo: Arm | None = None, + description: str | None = None, is_test: bool = False, - experiment_type: Optional[str] = None, - properties: Optional[dict[str, Any]] = None, - default_data_type: Optional[DataType] = None, + experiment_type: str | None = None, + properties: dict[str, Any] | None = None, + default_data_type: DataType | None = None, ) -> None: """Inits Experiment. @@ -90,7 +91,7 @@ def __init__( # call super.__init__() after defining fields above, because we need # them to be populated before optimization config is set - super(MultiTypeExperiment, self).__init__( + super().__init__( name=name, search_space=search_space, optimization_config=optimization_config, @@ -144,7 +145,7 @@ def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment # pyre-fixme[14]: `add_tracking_metric` overrides method defined in `Experiment` # inconsistently. def add_tracking_metric( - self, metric: Metric, trial_type: str, canonical_name: Optional[str] = None + self, metric: Metric, trial_type: str, canonical_name: str | None = None ) -> "MultiTypeExperiment": """Add a new metric to the experiment. @@ -156,7 +157,7 @@ def add_tracking_metric( if not self.supports_trial_type(trial_type): raise ValueError(f"`{trial_type}` is not a supported trial type.") - super(MultiTypeExperiment, self).add_tracking_metric(metric) + super().add_tracking_metric(metric) self._metric_to_trial_type[metric.name] = trial_type if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name @@ -165,7 +166,7 @@ def add_tracking_metric( # pyre-fixme[14]: `update_tracking_metric` overrides method defined in # `Experiment` inconsistently. def update_tracking_metric( - self, metric: Metric, trial_type: str, canonical_name: Optional[str] = None + self, metric: Metric, trial_type: str, canonical_name: str | None = None ) -> "MultiTypeExperiment": """Update an existing metric on the experiment.
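Example (an illustrative sketch, not part of this patch; the metric and trial type names are hypothetical)::

    mt_experiment.update_tracking_metric(
        metric=Metric(name="latency_ms", lower_is_better=True),
        trial_type="online",
    )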
@@ -184,7 +185,7 @@ def update_tracking_metric( elif not self.supports_trial_type(trial_type): raise ValueError(f"`{trial_type}` is not a supported trial type.") - super(MultiTypeExperiment, self).update_tracking_metric(metric) + super().update_tracking_metric(metric) self._metric_to_trial_type[metric.name] = trial_type if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name @@ -207,7 +208,7 @@ def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment": @copy_doc(Experiment.fetch_data) def fetch_data( self, - metrics: Optional[list[Metric]] = None, + metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, **kwargs: Any, @@ -228,7 +229,7 @@ def fetch_data( @copy_doc(Experiment._fetch_trial_data) def _fetch_trial_data( - self, trial_index: int, metrics: Optional[list[Metric]] = None, **kwargs: Any + self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] metrics = [ @@ -262,18 +263,18 @@ def metric_to_trial_type(self) -> dict[str, str]: # -- Overridden functions from Base Experiment Class -- @property - def default_trial_type(self) -> Optional[str]: + def default_trial_type(self) -> str | None: """Default trial type assigned to trials in this experiment.""" return self._default_trial_type - def runner_for_trial(self, trial: BaseTrial) -> Optional[Runner]: + def runner_for_trial(self, trial: BaseTrial) -> Runner | None: """The default runner to use for a given trial. Looks up the appropriate runner for this trial type in the trial_type_to_runner. """ return self.runner_for_trial_type(trial_type=none_throws(trial.trial_type)) - def runner_for_trial_type(self, trial_type: str) -> Optional[Runner]: + def runner_for_trial_type(self, trial_type: str) -> Runner | None: """The default runner to use for a given trial type. Looks up the appropriate runner for this trial type in the trial_type_to_runner. @@ -282,7 +283,7 @@ def runner_for_trial_type(self, trial_type: str) -> Optional[Runner]: raise ValueError(f"Trial type `{trial_type}` is not supported.") return self._trial_type_to_runner[trial_type] - def supports_trial_type(self, trial_type: Optional[str]) -> bool: + def supports_trial_type(self, trial_type: str | None) -> bool: """Whether this experiment allows trials of the given type. Only trial types defined in the trial_type_to_runner are allowed. @@ -296,7 +297,7 @@ def reset_runners(self, runner: Runner) -> None: def filter_trials_by_type( - trials: Sequence[BaseTrial], trial_type: Optional[str] + trials: Sequence[BaseTrial], trial_type: str | None ) -> list[BaseTrial]: """Filter trials by trial type if provided. diff --git a/ax/core/objective.py b/ax/core/objective.py index 0e63cb1f71d..3530df1518b 100644 --- a/ax/core/objective.py +++ b/ax/core/objective.py @@ -11,7 +11,7 @@ import warnings from collections.abc import Iterable from logging import Logger -from typing import Any, Optional +from typing import Any from ax.core.metric import Metric from ax.exceptions.core import UserInputError @@ -29,7 +29,7 @@ class Objective(SortableBase): minimize: If True, minimize metric. """ - def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None: + def __init__(self, metric: Metric, minimize: bool | None = None) -> None: """Create a new objective. 
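Example (a minimal sketch, not part of this patch; the metric name is hypothetical)::

    objective = Objective(
        metric=Metric(name="error_rate", lower_is_better=True),
        minimize=True,
    )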
Args: @@ -104,7 +104,7 @@ class MultiObjective(Objective): def __init__( self, - objectives: Optional[list[Objective]] = None, + objectives: list[Objective] | None = None, **extra_kwargs: Any, # Here to satisfy serialization. ) -> None: """Create a new objective. @@ -186,7 +186,7 @@ class ScalarizedObjective(Objective): def __init__( self, metrics: list[Metric], - weights: Optional[list[float]] = None, + weights: list[float] | None = None, minimize: bool = False, ) -> None: """Create a new objective. diff --git a/ax/core/observation.py b/ax/core/observation.py index 6ee60bb9040..f96099d3f31 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -13,7 +13,6 @@ from collections.abc import Iterable from copy import deepcopy from logging import Logger -from typing import Optional import ax.core.experiment as experiment import numpy as np @@ -69,10 +68,10 @@ class ObservationFeatures(Base): def __init__( self, parameters: TParameterization, - trial_index: Optional[int] = None, - start_time: Optional[pd.Timestamp] = None, - end_time: Optional[pd.Timestamp] = None, - random_split: Optional[int] = None, + trial_index: int | None = None, + start_time: pd.Timestamp | None = None, + end_time: pd.Timestamp | None = None, + random_split: int | None = None, metadata: TCandidateMetadata = None, ) -> None: self.parameters = parameters @@ -85,10 +84,10 @@ def __init__( @staticmethod def from_arm( arm: Arm, - trial_index: Optional[int] = None, - start_time: Optional[pd.Timestamp] = None, - end_time: Optional[pd.Timestamp] = None, - random_split: Optional[int] = None, + trial_index: int | None = None, + start_time: pd.Timestamp | None = None, + end_time: pd.Timestamp | None = None, + random_split: int | None = None, metadata: TCandidateMetadata = None, ) -> ObservationFeatures: """Convert an Arm to an ObservationFeatures, including additional @@ -123,7 +122,7 @@ def update_features(self, new_features: ObservationFeatures) -> ObservationFeatu return self def clone( - self, replace_parameters: Optional[TParameterization] = None + self, replace_parameters: TParameterization | None = None ) -> ObservationFeatures: """Make a copy of these ``ObservationFeatures``. @@ -150,7 +149,7 @@ def __repr__(self) -> str: strs = [] for attr in ["trial_index", "start_time", "end_time", "random_split"]: if getattr(self, attr) is not None: - strs.append(", {attr}={val}".format(attr=attr, val=getattr(self, attr))) + strs.append(f", {attr}={getattr(self, attr)}") repr_str = "ObservationFeatures(parameters={parameters}".format( parameters=self.parameters ) @@ -192,9 +191,7 @@ def __init__( ) -> None: k = len(metric_names) if means.shape != (k,): - raise ValueError( - "Shape of means should be {}, is {}.".format((k,), (means.shape)) - ) + raise ValueError(f"Shape of means should be {(k,)}, is {(means.shape)}.") if covariance.shape != (k, k): raise ValueError( "Shape of covariance should be {}, is {}.".format( @@ -248,7 +245,7 @@ def __init__( self, features: ObservationFeatures, data: ObservationData, - arm_name: Optional[str] = None, + arm_name: str | None = None, ) -> None: self.features = features self.data = data @@ -371,7 +368,7 @@ def _observations_from_dataframe( def _filter_data_on_status( df: pd.DataFrame, experiment: experiment.Experiment, - trial_status: Optional[TrialStatus], + trial_status: TrialStatus | None, # Arms on a BatchTrial can be abandoned even if the BatchTrial is not.
# Data will be filtered out if is_arm_abandoned is True and the corresponding # statuses_to_include does not contain TrialStatus.ABANDONED. @@ -436,8 +433,8 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: def observations_from_data( experiment: experiment.Experiment, data: Data, - statuses_to_include: Optional[set[TrialStatus]] = None, - statuses_to_include_map_metric: Optional[set[TrialStatus]] = None, + statuses_to_include: set[TrialStatus] | None = None, + statuses_to_include_map_metric: set[TrialStatus] | None = None, ) -> list[Observation]: """Convert Data to observations. @@ -515,11 +512,11 @@ def observations_from_data( def observations_from_map_data( experiment: experiment.Experiment, map_data: MapData, - statuses_to_include: Optional[set[TrialStatus]] = None, - statuses_to_include_map_metric: Optional[set[TrialStatus]] = None, + statuses_to_include: set[TrialStatus] | None = None, + statuses_to_include_map_metric: set[TrialStatus] | None = None, map_keys_as_parameters: bool = False, - limit_rows_per_metric: Optional[int] = None, - limit_rows_per_group: Optional[int] = None, + limit_rows_per_metric: int | None = None, + limit_rows_per_group: int | None = None, ) -> list[Observation]: """Convert MapData to observations. @@ -645,7 +642,7 @@ def separate_observations( def recombine_observations( observation_features: list[ObservationFeatures], observation_data: list[ObservationData], - arm_names: Optional[list[str]] = None, + arm_names: list[str] | None = None, ) -> list[Observation]: """ Construct a list of `Observation`s from the given arguments. diff --git a/ax/core/optimization_config.py b/ax/core/optimization_config.py index 7e2681b7140..68ed22534ee 100644 --- a/ax/core/optimization_config.py +++ b/ax/core/optimization_config.py @@ -8,7 +8,6 @@ from itertools import groupby from logging import Logger -from typing import Optional, Union from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -47,8 +46,8 @@ class OptimizationConfig(Base): def __init__( self, objective: Objective, - outcome_constraints: Optional[list[OutcomeConstraint]] = None, - risk_measure: Optional[RiskMeasure] = None, + outcome_constraints: list[OutcomeConstraint] | None = None, + risk_measure: RiskMeasure | None = None, ) -> None: """Inits OptimizationConfig. 
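Example (an illustrative sketch, not part of this patch; metric names and the bound are hypothetical)::

    config = OptimizationConfig(
        objective=Objective(metric=Metric(name="ctr"), minimize=False),
        outcome_constraints=[
            # Require cost <= 100 in absolute (non-relative) terms.
            OutcomeConstraint(
                metric=Metric(name="cost"),
                op=ComparisonOp.LEQ,
                bound=100.0,
                relative=False,
            )
        ],
    )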
@@ -68,7 +67,7 @@ def __init__( ) self._objective: Objective = objective self._outcome_constraints: list[OutcomeConstraint] = constraints - self.risk_measure: Optional[RiskMeasure] = risk_measure + self.risk_measure: RiskMeasure | None = risk_measure def clone(self) -> "OptimizationConfig": """Make a copy of this optimization config.""" @@ -76,11 +75,9 @@ def clone(self) -> "OptimizationConfig": def clone_with_args( self, - objective: Optional[Objective] = None, - outcome_constraints: Optional[ - list[OutcomeConstraint] - ] = _NO_OUTCOME_CONSTRAINTS, - risk_measure: Optional[RiskMeasure] = _NO_RISK_MEASURE, + objective: Objective | None = None, + outcome_constraints: list[OutcomeConstraint] | None = _NO_OUTCOME_CONSTRAINTS, + risk_measure: RiskMeasure | None = _NO_RISK_MEASURE, ) -> "OptimizationConfig": """Make a copy of this optimization config.""" objective = self.objective.clone() if objective is None else objective @@ -159,8 +156,8 @@ def outcome_constraints(self, outcome_constraints: list[OutcomeConstraint]) -> N @staticmethod def _validate_optimization_config( objective: Objective, - outcome_constraints: Optional[list[OutcomeConstraint]] = None, - risk_measure: Optional[RiskMeasure] = None, + outcome_constraints: list[OutcomeConstraint] | None = None, + risk_measure: RiskMeasure | None = None, ) -> None: """Ensure outcome constraints are valid and the risk measure is compatible with the objective. @@ -179,10 +176,8 @@ def _validate_optimization_config( if type(objective) is MultiObjective: # Raise error on exact equality; `ScalarizedObjective` is OK raise ValueError( - ( - "OptimizationConfig does not support MultiObjective. " - "Use MultiObjectiveOptimizationConfig instead." - ) + "OptimizationConfig does not support MultiObjective. " + "Use MultiObjectiveOptimizationConfig instead." ) outcome_constraints = outcome_constraints or [] # Only validate `outcome_constraints` @@ -266,10 +261,10 @@ class MultiObjectiveOptimizationConfig(OptimizationConfig): def __init__( self, - objective: Union[MultiObjective, ScalarizedObjective], - outcome_constraints: Optional[list[OutcomeConstraint]] = None, - objective_thresholds: Optional[list[ObjectiveThreshold]] = None, - risk_measure: Optional[RiskMeasure] = None, + objective: MultiObjective | ScalarizedObjective, + outcome_constraints: list[OutcomeConstraint] | None = None, + objective_thresholds: list[ObjectiveThreshold] | None = None, + risk_measure: RiskMeasure | None = None, ) -> None: """Inits OptimizationConfig. @@ -293,22 +288,20 @@ def __init__( objective_thresholds=objective_thresholds, risk_measure=risk_measure, ) - self._objective: Union[MultiObjective, ScalarizedObjective] = objective + self._objective: MultiObjective | ScalarizedObjective = objective self._outcome_constraints: list[OutcomeConstraint] = constraints self._objective_thresholds: list[ObjectiveThreshold] = objective_thresholds - self.risk_measure: Optional[RiskMeasure] = risk_measure + self.risk_measure: RiskMeasure | None = risk_measure # pyre-fixme[14]: Inconsistent override.
def clone_with_args( self, - objective: Optional[Union[MultiObjective, ScalarizedObjective]] = None, - outcome_constraints: Optional[ - list[OutcomeConstraint] - ] = _NO_OUTCOME_CONSTRAINTS, - objective_thresholds: Optional[ + objective: MultiObjective | ScalarizedObjective | None = None, + outcome_constraints: list[OutcomeConstraint] | None = _NO_OUTCOME_CONSTRAINTS, + objective_thresholds: ( list[ObjectiveThreshold] - ] = _NO_OBJECTIVE_THRESHOLDS, - risk_measure: Optional[RiskMeasure] = _NO_RISK_MEASURE, + | None + ) = _NO_OBJECTIVE_THRESHOLDS, + risk_measure: RiskMeasure | None = _NO_RISK_MEASURE, ) -> "MultiObjectiveOptimizationConfig": """Make a copy of this optimization config.""" objective = self.objective.clone() if objective is None else objective @@ -334,12 +327,12 @@ def clone_with_args( ) @property - def objective(self) -> Union[MultiObjective, ScalarizedObjective]: + def objective(self) -> MultiObjective | ScalarizedObjective: """Get objective.""" return self._objective @objective.setter - def objective(self, objective: Union[MultiObjective, ScalarizedObjective]) -> None: + def objective(self, objective: MultiObjective | ScalarizedObjective) -> None: """Set objective if not present in outcome constraints.""" self._validate_optimization_config( objective=objective, @@ -401,9 +394,9 @@ def objective_thresholds_dict(self) -> dict[str, ObjectiveThreshold]: @staticmethod def _validate_optimization_config( objective: Objective, - outcome_constraints: Optional[list[OutcomeConstraint]] = None, - objective_thresholds: Optional[list[ObjectiveThreshold]] = None, - risk_measure: Optional[RiskMeasure] = None, + outcome_constraints: list[OutcomeConstraint] | None = None, + objective_thresholds: list[ObjectiveThreshold] | None = None, + risk_measure: RiskMeasure | None = None, ) -> None: """Ensure outcome constraints are valid and the risk measure is compatible with the objective. @@ -422,12 +415,10 @@ def _validate_optimization_config( """ if not isinstance(objective, (MultiObjective, ScalarizedObjective)): raise TypeError( - ( - "`MultiObjectiveOptimizationConfig` requires an objective " - "of type `MultiObjective` or `ScalarizedObjective`. " - "Use `OptimizationConfig` instead if using a " - "single-metric objective." - ) + "`MultiObjectiveOptimizationConfig` requires an objective " + "of type `MultiObjective` or `ScalarizedObjective`. " + "Use `OptimizationConfig` instead if using a " + "single-metric objective." ) outcome_constraints = outcome_constraints or [] objective_thresholds = objective_thresholds or [] diff --git a/ax/core/outcome_constraint.py b/ax/core/outcome_constraint.py index 5e2d1a7a59d..ab9543d25f0 100644 --- a/ax/core/outcome_constraint.py +++ b/ax/core/outcome_constraint.py @@ -10,7 +10,6 @@ import logging from collections.abc import Iterable -from typing import Optional from ax.core.metric import Metric from ax.core.types import ComparisonOp @@ -123,7 +122,7 @@ def _validate_metric_constraint_op( msg = CONSTRAINT_WARNING_MESSAGE.format(**fmt_data) logger.debug(msg) return False, msg - return True, str() + return True, "" def _validate_constraint(self) -> tuple[bool, str]: """Ensure constraint is compatible with metric definition.
@@ -146,7 +145,7 @@ def _validate_constraint(self) -> tuple[bool, str]: return False, msg if not self.relative: - return True, str() + return True, "" fmt_data = None if self.metric.lower_is_better is not None: @@ -160,7 +159,7 @@ def _validate_constraint(self) -> tuple[bool, str]: logger.debug(msg) return False, msg - return True, str() + return True, "" def __repr__(self) -> str: op = ">=" if self.op == ComparisonOp.GEQ else "<=" @@ -200,14 +199,12 @@ def __init__( metric: Metric, bound: float, relative: bool = True, - op: Optional[ComparisonOp] = None, + op: ComparisonOp | None = None, ) -> None: if metric.lower_is_better is None and op is None: raise ValueError( - ( - f"Metric {metric} must have attribute `lower_is_better` set or " - f"op {op} must be manually specified." - ) + f"Metric {metric} must have attribute `lower_is_better` set or " + f"op {op} must be manually specified." ) elif op is None: op = ComparisonOp.LEQ if metric.lower_is_better else ComparisonOp.GEQ @@ -259,7 +256,7 @@ def __init__( op: ComparisonOp, bound: float, relative: bool = True, - weights: Optional[list[float]] = None, + weights: list[float] | None = None, ) -> None: for metric in metrics: self._validate_metric_constraint_op(metric=metric, op=op) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 95f49706121..91556b00bcc 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -12,7 +12,7 @@ from copy import deepcopy from enum import Enum from math import inf -from typing import cast, Optional, Union +from typing import cast, Union from warnings import warn from ax.core.types import TNumeric, TParamValue, TParamValueList @@ -65,7 +65,7 @@ def is_numeric(self) -> bool: } SUPPORTED_PARAMETER_TYPES: tuple[ - Union[type[bool], type[float], type[int], type[str]], ... + type[bool] | type[float] | type[int] | type[str], ... ] = tuple(PARAMETER_PYTHON_TYPE_MAP.values()) @@ -189,7 +189,7 @@ def available_flags(self) -> list[str]: @property def summary_dict( self, - ) -> dict[str, Union[TParamValueList, TParamValue, str, list[str]]]: + ) -> dict[str, TParamValueList | TParamValue | str | list[str]]: # Assemble dict. summary_dict = { @@ -233,7 +233,7 @@ def __init__( upper: float, log_scale: bool = False, logit_scale: bool = False, - digits: Optional[int] = None, + digits: int | None = None, is_fidelity: bool = False, target_value: TParamValue = None, ) -> None: @@ -269,7 +269,7 @@ def __init__( self._log_scale = log_scale self._logit_scale = logit_scale self._is_fidelity = is_fidelity - self._target_value: Optional[TNumeric] = self.cast(target_value) + self._target_value: TNumeric | None = self.cast(target_value) self._validate_range_param( parameter_type=parameter_type, @@ -291,7 +291,7 @@ def _validate_range_param( upper: TNumeric, log_scale: bool, logit_scale: bool, - parameter_type: Optional[ParameterType] = None, + parameter_type: ParameterType | None = None, ) -> None: if parameter_type and parameter_type not in ( ParameterType.INT, @@ -362,7 +362,7 @@ def lower(self, value: TNumeric) -> None: self._lower = not_none(self.cast(value)) @property - def digits(self) -> Optional[int]: + def digits(self) -> int | None: """Number of digits to round values to for float type. Upper and lower bound are re-cast after this property is changed. 
@@ -380,7 +380,7 @@ def logit_scale(self) -> bool: return self._logit_scale def update_range( - self, lower: Optional[float] = None, upper: Optional[float] = None + self, lower: float | None = None, upper: float | None = None ) -> RangeParameter: """Set the range to the given values. @@ -479,7 +479,7 @@ def clone(self) -> RangeParameter: target_value=self._target_value, ) - def cast(self, value: TParamValue) -> Optional[TNumeric]: + def cast(self, value: TParamValue) -> TNumeric | None: if value is None: return None if self.parameter_type is ParameterType.FLOAT and self._digits is not None: @@ -532,12 +532,12 @@ def __init__( name: str, parameter_type: ParameterType, values: list[TParamValue], - is_ordered: Optional[bool] = None, + is_ordered: bool | None = None, is_task: bool = False, is_fidelity: bool = False, target_value: TParamValue = None, - sort_values: Optional[bool] = None, - dependents: Optional[dict[TParamValue, list[str]]] = None, + sort_values: bool | None = None, + dependents: dict[TParamValue, list[str]] | None = None, ) -> None: if (is_fidelity or is_task) and (target_value is None): ptype = "fidelity" if is_fidelity else "task" @@ -750,7 +750,7 @@ def __init__( value: TParamValue, is_fidelity: bool = False, target_value: TParamValue = None, - dependents: Optional[dict[TParamValue, list[str]]] = None, + dependents: dict[TParamValue, list[str]] | None = None, ) -> None: """Initialize FixedParameter diff --git a/ax/core/parameter_constraint.py b/ax/core/parameter_constraint.py index a7207280717..12bd95b0876 100644 --- a/ax/core/parameter_constraint.py +++ b/ax/core/parameter_constraint.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import Union - from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, RangeParameter from ax.core.types import ComparisonOp from ax.utils.common.base import SortableBase @@ -52,7 +50,7 @@ def bound(self, bound: float) -> None: """Set bound.""" self._bound = bound - def check(self, parameter_dict: dict[str, Union[int, float]]) -> bool: + def check(self, parameter_dict: dict[str, int | float]) -> bool: """Whether or not the set of parameter values satisfies the constraint. Does a weighted sum of the parameter values based on the constraint_dict @@ -91,10 +89,8 @@ def clone_with_transformed_parameters( def __repr__(self) -> str: return ( "ParameterConstraint(" - + " + ".join( - "{}*{}".format(v, k) for k, v in sorted(self.constraint_dict.items()) - ) - + " <= {})".format(self._bound) + + " + ".join(f"{v}*{k}" for k, v in sorted(self.constraint_dict.items())) + + f" <= {self._bound})" ) @property diff --git a/ax/core/parameter_distribution.py b/ax/core/parameter_distribution.py index ff2e885cd3b..75d355b6eb3 100644 --- a/ax/core/parameter_distribution.py +++ b/ax/core/parameter_distribution.py @@ -9,7 +9,7 @@ from copy import deepcopy from importlib import import_module -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from ax.exceptions.core import UserInputError from ax.utils.common.base import SortableBase @@ -33,7 +33,7 @@ def __init__( self, parameters: list[TParamName], distribution_class: TDistribution, - distribution_parameters: Optional[dict[str, Any]], + distribution_parameters: dict[str, Any] | None, multiplicative: bool = False, ) -> None: """Initialize a parameter distribution. 
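Example (a sketch, not part of this patch; assumes ``distribution_class`` names a ``scipy.stats`` distribution such as ``"norm"``, with matching keyword arguments in ``distribution_parameters``)::

    dist = ParameterDistribution(
        parameters=["x"],
        distribution_class="norm",
        distribution_parameters={"loc": 0.0, "scale": 0.1},
    )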
@@ -56,7 +56,7 @@ def __init__( self._distribution_class = distribution_class self._distribution_parameters: dict[str, Any] = distribution_parameters or {} self.multiplicative = multiplicative - self._distribution: Optional[rv_frozen] = None # pyre-ignore [11] + self._distribution: rv_frozen | None = None # pyre-ignore [11] @property def distribution_class(self) -> TDistribution: diff --git a/ax/core/risk_measures.py b/ax/core/risk_measures.py index a7a6bf01005..5e1db4b235d 100644 --- a/ax/core/risk_measures.py +++ b/ax/core/risk_measures.py @@ -9,7 +9,6 @@ from __future__ import annotations from copy import deepcopy -from typing import Union from ax.utils.common.base import SortableBase from ax.utils.common.equality import equality_typechecker @@ -30,7 +29,7 @@ class RiskMeasure(SortableBase): def __init__( self, risk_measure: str, - options: dict[str, Union[int, float, bool, list[float]]], + options: dict[str, int | float | bool | list[float]], ) -> None: """Initialize a risk measure. diff --git a/ax/core/runner.py b/ax/core/runner.py index b75940e52ea..e21d12550f8 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from ax.utils.common.base import Base from ax.utils.common.serialization import SerializationMixin @@ -129,7 +129,7 @@ def poll_exception(self, trial: core.base_trial.BaseTrial) -> str: ) def stop( - self, trial: core.base_trial.BaseTrial, reason: Optional[str] = None + self, trial: core.base_trial.BaseTrial, reason: str | None = None ) -> dict[str, Any]: """Stop a trial based on custom runner subclass implementation. diff --git a/ax/core/search_space.py b/ax/core/search_space.py index bf778c56ff4..ac934de976e 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -10,12 +10,11 @@ import math import warnings -from collections.abc import Hashable, Mapping +from collections.abc import Callable, Hashable, Mapping from dataclasses import dataclass, field from functools import reduce from logging import Logger from random import choice, uniform -from typing import Callable, Optional, Union import numpy as np import pandas as pd @@ -68,7 +67,7 @@ class SearchSpace(Base): def __init__( self, parameters: list[Parameter], - parameter_constraints: Optional[list[ParameterConstraint]] = None, + parameter_constraints: list[ParameterConstraint] | None = None, ) -> None: """Initialize SearchSpace @@ -326,7 +325,7 @@ def out_of_design_arm(self) -> Arm: return self.construct_arm() def construct_arm( - self, parameters: Optional[TParameterization] = None, name: Optional[str] = None + self, parameters: TParameterization | None = None, name: str | None = None ) -> Arm: """Construct new arm using given parameters and name. Any missing parameters fall back to the experiment defaults, represented as None. @@ -440,7 +439,7 @@ class HierarchicalSearchSpace(SearchSpace): def __init__( self, parameters: list[Parameter], - parameter_constraints: Optional[list[ParameterConstraint]] = None, + parameter_constraints: list[ParameterConstraint] | None = None, ) -> None: super().__init__( parameters=parameters, parameter_constraints=parameter_constraints ) @@ -626,7 +625,7 @@ def hierarchical_structure_str(self, parameter_names_only: bool = False) -> str: representation.
""" - def _hrepr(param: Optional[Parameter], value: Optional[str], level: int) -> str: + def _hrepr(param: Parameter | None, value: str | None, level: int) -> str: is_level_param = param and not value if is_level_param: param = not_none(param) @@ -862,8 +861,8 @@ def __init__( parameters: list[Parameter], parameter_distributions: list[ParameterDistribution], num_samples: int, - environmental_variables: Optional[list[Parameter]] = None, - parameter_constraints: Optional[list[ParameterConstraint]] = None, + environmental_variables: list[Parameter] | None = None, + parameter_constraints: list[ParameterConstraint] | None = None, ) -> None: """Initialize the robust search space. @@ -1072,16 +1071,14 @@ class SearchSpaceDigest: """ feature_names: list[str] - bounds: list[tuple[Union[int, float], Union[int, float]]] + bounds: list[tuple[int | float, int | float]] ordinal_features: list[int] = field(default_factory=list) categorical_features: list[int] = field(default_factory=list) - discrete_choices: Mapping[int, list[Union[int, float]]] = field( - default_factory=dict - ) + discrete_choices: Mapping[int, list[int | float]] = field(default_factory=dict) task_features: list[int] = field(default_factory=list) fidelity_features: list[int] = field(default_factory=list) - target_values: dict[int, Union[int, float]] = field(default_factory=dict) - robust_digest: Optional[RobustSearchSpaceDigest] = None + target_values: dict[int, int | float] = field(default_factory=dict) + robust_digest: RobustSearchSpaceDigest | None = None @dataclass @@ -1105,8 +1102,8 @@ class RobustSearchSpaceDigest: Only relevant if paired with a `distribution_sampler`. """ - sample_param_perturbations: Optional[Callable[[], np.ndarray]] = None - sample_environmental: Optional[Callable[[], np.ndarray]] = None + sample_param_perturbations: Callable[[], np.ndarray] | None = None + sample_environmental: Callable[[], np.ndarray] | None = None environmental_variables: list[str] = field(default_factory=list) multiplicative: bool = False diff --git a/ax/core/tests/test_generation_strategy_interface.py b/ax/core/tests/test_generation_strategy_interface.py index 89ac00cb2d3..d0f478b463d 100644 --- a/ax/core/tests/test_generation_strategy_interface.py +++ b/ax/core/tests/test_generation_strategy_interface.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.core.data import Data from ax.core.experiment import Experiment @@ -20,11 +19,11 @@ class MyGSI(GenerationStrategyInterface): def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, - data: Optional[Data] = None, + data: Data | None = None, # TODO[drfreund, danielcohennyc, mgarrard]: Update the format of the arguments # below as we find the right one. 
         num_generator_runs: int = 1,
-        n: Optional[int] = None,
+        n: int | None = None,
     ) -> list[list[GeneratorRun]]:
         raise NotImplementedError
diff --git a/ax/core/trial.py b/ax/core/trial.py
index 52a17b394a0..a6980d1f603 100644
--- a/ax/core/trial.py
+++ b/ax/core/trial.py
@@ -12,7 +12,7 @@
 from logging import Logger

-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING

 from ax.core.arm import Arm
 from ax.core.base_trial import BaseTrial, immutable_once_run
@@ -66,10 +66,10 @@ class Trial(BaseTrial):
     def __init__(
         self,
         experiment: core.experiment.Experiment,
-        generator_run: Optional[GeneratorRun] = None,
-        trial_type: Optional[str] = None,
-        ttl_seconds: Optional[int] = None,
-        index: Optional[int] = None,
+        generator_run: GeneratorRun | None = None,
+        trial_type: str | None = None,
+        ttl_seconds: int | None = None,
+        index: int | None = None,
     ) -> None:
         super().__init__(
             experiment=experiment,
@@ -83,7 +83,7 @@ def __init__(
             self.add_generator_run(generator_run=generator_run)

     @property
-    def generator_run(self) -> Optional[GeneratorRun]:
+    def generator_run(self) -> GeneratorRun | None:
         """Generator run attached to this trial."""
         return self._generator_run
@@ -95,7 +95,7 @@ def generator_runs(self) -> list[GeneratorRun]:
         return [gr] if gr is not None else []

     @property
-    def arm(self) -> Optional[Arm]:
+    def arm(self) -> Arm | None:
         """The arm associated with this batch."""
         if self.generator_run is None:
             return None
@@ -112,7 +112,7 @@ def arm(self) -> Optional[Arm]:

     @immutable_once_run
     def add_arm(
-        self, arm: Arm, candidate_metadata: Optional[dict[str, Any]] = None
+        self, arm: Arm, candidate_metadata: dict[str, Any] | None = None
     ) -> Trial:
         """Add arm to the trial.
@@ -286,8 +286,8 @@ def validate_data_for_trial(self, data: Data) -> None:
     def update_trial_data(
         self,
         raw_data: TEvaluationOutcome,
-        metadata: Optional[dict[str, Union[str, int]]] = None,
-        sample_size: Optional[int] = None,
+        metadata: dict[str, str | int] | None = None,
+        sample_size: int | None = None,
         combine_with_last_data: bool = False,
     ) -> str:
         """Utility method that attaches data to a trial and
@@ -333,7 +333,7 @@ def update_trial_data(
     def clone_to(
         self,
-        experiment: Optional[core.experiment.Experiment] = None,
+        experiment: core.experiment.Experiment | None = None,
     ) -> Trial:
         """Clone the trial and attach it to the specified experiment.
         If no experiment is provided, the original experiment will be used.
diff --git a/ax/core/types.py b/ax/core/types.py
index 220aabcccf8..46502b32bda 100644
--- a/ax/core/types.py
+++ b/ax/core/types.py
@@ -8,8 +8,8 @@
 import enum
 from collections import defaultdict
-from collections.abc import Hashable
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Hashable
+from typing import Any, Optional, Union

 import numpy as np
diff --git a/ax/core/utils.py b/ax/core/utils.py
index 53f106baf51..eb5339ed70c 100644
--- a/ax/core/utils.py
+++ b/ax/core/utils.py
@@ -8,7 +8,7 @@
 from collections.abc import Iterable
 from copy import deepcopy
-from typing import NamedTuple, Optional
+from typing import NamedTuple

 import numpy as np
 from ax.core.arm import Arm
@@ -168,7 +168,7 @@ def _extract_generator_runs(trial: BaseTrial) -> list[GeneratorRun]:

 def get_model_trace_of_times(
     experiment: Experiment,
-) -> tuple[list[Optional[float]], list[Optional[float]]]:
+) -> tuple[list[float | None], list[float | None]]:
     """
     Get time spent fitting the model and generating candidates during each trial.
     Not cumulative.
@@ -191,8 +191,8 @@ def get_model_times(experiment: Experiment) -> tuple[float, float]:
     course of the experiment.
     """
     fit_times, gen_times = get_model_trace_of_times(experiment)
-    fit_time = sum((t for t in fit_times if t is not None))
-    gen_time = sum((t for t in gen_times if t is not None))
+    fit_time = sum(t for t in fit_times if t is not None)
+    gen_time = sum(t for t in gen_times if t is not None)
     return fit_time, gen_time


@@ -202,7 +202,7 @@ def get_model_times(experiment: Experiment) -> tuple[float, float]:
 def extract_pending_observations(
     experiment: Experiment,
     include_out_of_design_points: bool = False,
-) -> Optional[dict[str, list[ObservationFeatures]]]:
+) -> dict[str, list[ObservationFeatures]] | None:
     """Computes a list of pending observation features (corresponding to:
     - arms that have been generated and run in the course of the experiment,
     but have not been completed with data,
@@ -240,7 +240,7 @@ def get_pending_observation_features(
     experiment: Experiment,
     *,
     include_out_of_design_points: bool = False,
-) -> Optional[dict[str, list[ObservationFeatures]]]:
+) -> dict[str, list[ObservationFeatures]] | None:
     """Computes a list of pending observation features (corresponding to:
     - arms that have been generated in the course of the experiment,
     but have not been completed with data,
@@ -274,7 +274,7 @@ def create_observation_feature(
         arm: Arm,
         trial_index: int,
         trial: BaseTrial,
-    ) -> Optional[ObservationFeatures]:
+    ) -> ObservationFeatures | None:
         if not include_out_of_design_points and not _is_in_design(arm=arm):
             return None
         return ObservationFeatures.from_arm(
@@ -331,7 +331,7 @@ def create_observation_feature(
 def get_pending_observation_features_based_on_trial_status(
     experiment: Experiment,
     include_out_of_design_points: bool = False,
-) -> Optional[dict[str, list[ObservationFeatures]]]:
+) -> dict[str, list[ObservationFeatures]] | None:
     """A faster analogue of ``get_pending_observation_features`` that makes
     assumptions about trials in experiment in order to speed up extraction
     of pending points.
diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py
index d6c1aa58981..480269952ac 100644
--- a/ax/early_stopping/strategies/base.py
+++ b/ax/early_stopping/strategies/base.py
@@ -11,7 +11,6 @@
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from logging import Logger
-from typing import Optional

 import numpy as np
 import pandas as pd
@@ -58,7 +57,7 @@ class EarlyStoppingTrainingData:
     X: np.ndarray
     Y: np.ndarray
     Yvar: np.ndarray
-    arm_names: list[Optional[str]]
+    arm_names: list[str | None]


 class BaseEarlyStoppingStrategy(ABC, Base):
@@ -67,12 +66,12 @@ class BaseEarlyStoppingStrategy(ABC, Base):
     def __init__(
         self,
-        metric_names: Optional[Iterable[str]] = None,
+        metric_names: Iterable[str] | None = None,
         seconds_between_polls: int = 300,
-        min_progression: Optional[float] = None,
-        max_progression: Optional[float] = None,
-        min_curves: Optional[int] = None,
-        trial_indices_to_ignore: Optional[list[int]] = None,
+        min_progression: float | None = None,
+        max_progression: float | None = None,
+        min_curves: int | None = None,
+        trial_indices_to_ignore: list[int] | None = None,
         normalize_progressions: bool = False,
     ) -> None:
         """A BaseEarlyStoppingStrategy class.
@@ -116,7 +115,7 @@ def should_stop_trials_early(
         self,
         trial_indices: set[int],
         experiment: Experiment,
-    ) -> dict[int, Optional[str]]:
+    ) -> dict[int, str | None]:
         """Decide whether to complete trials before evaluation is fully concluded.
         Typical examples include stopping a machine learning model's training, or
@@ -134,7 +133,7 @@ def should_stop_trials_early(
     def _check_validity_and_get_data(
         self, experiment: Experiment, metric_names: list[str]
-    ) -> Optional[MapData]:
+    ) -> MapData | None:
         """Validity checks and returns the `MapData` used for early stopping that
         is associated with `metric_names`. This function also handles normalizing
         progressions.
@@ -212,8 +211,8 @@ def _log_and_return_progression_range(
         logger: logging.Logger,
         trial_index: int,
         trial_last_progression: float,
-        min_progression: Optional[float],
-        max_progression: Optional[float],
+        min_progression: float | None,
+        max_progression: float | None,
         metric_name: str,
     ) -> tuple[bool, str]:
         """Helper function for logging/constructing a reason when min progression
@@ -273,7 +272,7 @@ def is_eligible_any(
         trial_indices: set[int],
         experiment: Experiment,
         df: pd.DataFrame,
-        map_key: Optional[str] = None,
+        map_key: str | None = None,
     ) -> bool:
         """Perform a series of default checks for a set of trials `trial_indices` and
         determine if at least one of them is eligible for further stopping logic:
@@ -317,7 +316,7 @@ def is_eligible(
         experiment: Experiment,
         df: pd.DataFrame,
         map_key: str,
-    ) -> tuple[bool, Optional[str]]:
+    ) -> tuple[bool, str | None]:
         """Perform a series of default checks for a specific trial `trial_index` and
         determines whether it is eligible for further stopping logic:
         1. Check for ignored indices based on `self.trial_indices_to_ignore`
@@ -411,14 +410,14 @@ class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
     def __init__(
         self,
-        metric_names: Optional[Iterable[str]] = None,
+        metric_names: Iterable[str] | None = None,
         seconds_between_polls: int = 300,
-        min_progression: Optional[float] = None,
-        max_progression: Optional[float] = None,
-        min_curves: Optional[int] = None,
-        trial_indices_to_ignore: Optional[list[int]] = None,
+        min_progression: float | None = None,
+        max_progression: float | None = None,
+        min_curves: int | None = None,
+        trial_indices_to_ignore: list[int] | None = None,
         normalize_progressions: bool = False,
-        min_progression_modeling: Optional[float] = None,
+        min_progression_modeling: float | None = None,
     ) -> None:
         """A ModelBasedEarlyStoppingStrategy class.
@@ -462,7 +461,7 @@ def __init__(
     def _check_validity_and_get_data(
         self, experiment: Experiment, metric_names: list[str]
-    ) -> Optional[MapData]:
+    ) -> MapData | None:
         """Validity checks and returns the `MapData` used for early stopping that
         is associated with `metric_names`. This function also handles normalizing
         progressions.
@@ -486,9 +485,9 @@ def get_training_data(
         self,
         experiment: Experiment,
         map_data: MapData,
-        max_training_size: Optional[int] = None,
-        outcomes: Optional[Sequence[str]] = None,
-        parameters: Optional[list[str]] = None,
+        max_training_size: int | None = None,
+        outcomes: Sequence[str] | None = None,
+        parameters: list[str] | None = None,
     ) -> EarlyStoppingTrainingData:
         """Processes the raw (untransformed) training data into arrays for
         use in modeling. The trailing dimensions of `X` are the map keys, in
@@ -544,7 +543,7 @@ def get_training_data(
 def get_transform_helper_model(
     experiment: Experiment,
     data: Data,
-    transforms: Optional[list[type[Transform]]] = None,
+    transforms: list[type[Transform]] | None = None,
 ) -> MapTorchModelBridge:
     """
     Constructs a TorchModelBridge, to be used as a helper for transforming parameters.
diff --git a/ax/early_stopping/strategies/logical.py b/ax/early_stopping/strategies/logical.py
index 06f8f348c2b..3eddefede43 100644
--- a/ax/early_stopping/strategies/logical.py
+++ b/ax/early_stopping/strategies/logical.py
@@ -7,7 +7,7 @@
 from collections.abc import Sequence
 from functools import reduce
-from typing import Any, Optional
+from typing import Any

 from ax.core.experiment import Experiment
 from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -35,7 +35,7 @@ def should_stop_trials_early(
         trial_indices: set[int],
         experiment: Experiment,
         **kwargs: dict[str, Any],
-    ) -> dict[int, Optional[str]]:
+    ) -> dict[int, str | None]:
         left = self.left.should_stop_trials_early(
             trial_indices=trial_indices, experiment=experiment, **kwargs
@@ -68,7 +68,7 @@ def should_stop_trials_early(
         trial_indices: set[int],
         experiment: Experiment,
         **kwargs: dict[str, Any],
-    ) -> dict[int, Optional[str]]:
+    ) -> dict[int, str | None]:
         return {
             **self.left.should_stop_trials_early(
                 trial_indices=trial_indices, experiment=experiment, **kwargs
diff --git a/ax/early_stopping/strategies/percentile.py b/ax/early_stopping/strategies/percentile.py
index 6f69f3675c3..36fad77a13c 100644
--- a/ax/early_stopping/strategies/percentile.py
+++ b/ax/early_stopping/strategies/percentile.py
@@ -8,7 +8,6 @@
 from collections.abc import Iterable
 from logging import Logger
-from typing import Optional

 import numpy as np
 import pandas as pd
@@ -28,13 +27,13 @@ class PercentileEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
     def __init__(
         self,
-        metric_names: Optional[Iterable[str]] = None,
+        metric_names: Iterable[str] | None = None,
         seconds_between_polls: int = 300,
         percentile_threshold: float = 50.0,
-        min_progression: Optional[float] = 10,
-        max_progression: Optional[float] = None,
-        min_curves: Optional[int] = 5,
-        trial_indices_to_ignore: Optional[list[int]] = None,
+        min_progression: float | None = 10,
+        max_progression: float | None = None,
+        min_curves: int | None = 5,
+        trial_indices_to_ignore: list[int] | None = None,
         normalize_progressions: bool = False,
     ) -> None:
         """Construct a PercentileEarlyStoppingStrategy instance.
@@ -93,7 +92,7 @@ def should_stop_trials_early(
         self,
         trial_indices: set[int],
         experiment: Experiment,
-    ) -> dict[int, Optional[str]]:
+    ) -> dict[int, str | None]:
         """Stop a trial if its performance is in the bottom `percentile_threshold`
         of the trials at the same step.
@@ -165,7 +164,7 @@ def _should_stop_trial_early(
         df_raw: pd.DataFrame,
         map_key: str,
         minimize: bool,
-    ) -> tuple[bool, Optional[str]]:
+    ) -> tuple[bool, str | None]:
         """Stop a trial if its performance is in the bottom `percentile_threshold`
         of the trials at the same step.
diff --git a/ax/early_stopping/strategies/threshold.py b/ax/early_stopping/strategies/threshold.py
index 798cacfbc79..73aee707f69 100644
--- a/ax/early_stopping/strategies/threshold.py
+++ b/ax/early_stopping/strategies/threshold.py
@@ -8,7 +8,6 @@
 from collections.abc import Iterable
 from logging import Logger
-from typing import Optional

 import pandas as pd
 from ax.core.experiment import Experiment
@@ -25,13 +24,13 @@ class ThresholdEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
     def __init__(
         self,
-        metric_names: Optional[Iterable[str]] = None,
+        metric_names: Iterable[str] | None = None,
         seconds_between_polls: int = 300,
         metric_threshold: float = 0.2,
-        min_progression: Optional[float] = 10,
-        max_progression: Optional[float] = None,
-        min_curves: Optional[int] = 5,
-        trial_indices_to_ignore: Optional[list[int]] = None,
+        min_progression: float | None = 10,
+        max_progression: float | None = None,
+        min_curves: int | None = 5,
+        trial_indices_to_ignore: list[int] | None = None,
         normalize_progressions: bool = False,
     ) -> None:
         """Construct a ThresholdEarlyStoppingStrategy instance.
@@ -85,7 +84,7 @@ def should_stop_trials_early(
         self,
         trial_indices: set[int],
         experiment: Experiment,
-    ) -> dict[int, Optional[str]]:
+    ) -> dict[int, str | None]:
         """Stop a trial if its performance doesn't reach a pre-specified threshold
         by `min_progression`.
@@ -142,7 +141,7 @@ def _should_stop_trial_early(
         df: pd.DataFrame,
         map_key: str,
         minimize: bool,
-    ) -> tuple[bool, Optional[str]]:
+    ) -> tuple[bool, str | None]:
         """Stop a trial if its performance doesn't reach a pre-specified threshold
         by `min_progression`.
diff --git a/ax/early_stopping/tests/test_strategies.py b/ax/early_stopping/tests/test_strategies.py
index 6bf8794234f..51edb756b64 100644
--- a/ax/early_stopping/tests/test_strategies.py
+++ b/ax/early_stopping/tests/test_strategies.py
@@ -7,7 +7,7 @@
 # pyre-strict

 from copy import deepcopy
-from typing import Any, cast, Optional
+from typing import Any, cast

 import numpy as np
 from ax.core import OptimizationConfig
@@ -54,7 +54,7 @@ def should_stop_trials_early(
             trial_indices: set[int],
             experiment: Experiment,
             **kwargs: dict[str, Any],
-        ) -> dict[int, Optional[str]]:
+        ) -> dict[int, str | None]:
             return {}

         test_experiment = get_test_map_data_experiment(
@@ -142,7 +142,7 @@ def should_stop_trials_early(
             trial_indices: set[int],
             experiment: Experiment,
             **kwargs: dict[str, Any],
-        ) -> dict[int, Optional[str]]:
+        ) -> dict[int, str | None]:
             return {}

         experiment = get_test_map_data_experiment(
@@ -217,7 +217,7 @@ def should_stop_trials_early(
             trial_indices: set[int],
             experiment: Experiment,
             **kwargs: dict[str, Any],
-        ) -> dict[int, Optional[str]]:
+        ) -> dict[int, str | None]:
             return {}

         experiment = get_test_map_data_experiment(
@@ -695,7 +695,7 @@ def _evaluate_early_stopping_with_df(
     early_stopping_strategy: PercentileEarlyStoppingStrategy,
     experiment: Experiment,
     metric_name: str,
-) -> dict[int, Optional[str]]:
+) -> dict[int, str | None]:
     """Helper function for testing PercentileEarlyStoppingStrategy
     on an arbitrary (MapData) df."""
     data = not_none(
diff --git a/ax/early_stopping/utils.py b/ax/early_stopping/utils.py
index 5d34d3f4187..dfacd55ccfe 100644
--- a/ax/early_stopping/utils.py
+++ b/ax/early_stopping/utils.py
@@ -8,7 +8,6 @@
 from collections import defaultdict
 from logging import Logger
-from typing import Optional

 import pandas as pd
 from ax.core.base_trial import TrialStatus
@@ -119,7 +118,7 @@ def align_partial_results(
 def estimate_early_stopping_savings(
     experiment: Experiment,
-    map_key: Optional[str] = None,
+    map_key: str | None = None,
 ) -> float:
     """Estimate resource savings due to early stopping by considering
     COMPLETED and EARLY_STOPPED trials. First, use the mean of final
diff --git a/ax/exceptions/generation_strategy.py b/ax/exceptions/generation_strategy.py
index 95b31493b4a..8e088afc61c 100644
--- a/ax/exceptions/generation_strategy.py
+++ b/ax/exceptions/generation_strategy.py
@@ -6,7 +6,6 @@
 # pyre-strict

-from typing import Optional

 from ax.exceptions.core import AxError, OptimizationComplete
@@ -28,8 +27,8 @@ def __init__(
         self,
         model_name: str,
         num_running: int,
-        step_index: Optional[int] = None,
-        node_name: Optional[str] = None,
+        step_index: int | None = None,
+        node_name: str | None = None,
     ) -> None:
         if node_name is not None:
             msg_start = (
@@ -67,7 +66,7 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
 class GenerationStrategyMisconfiguredException(AxGenerationException):
     """Special exception indicating that the generation strategy is misconfigured."""

-    def __init__(self, error_info: Optional[str]) -> None:
+    def __init__(self, error_info: str | None) -> None:
         super().__init__(
             "This GenerationStrategy was unable to be initialized properly. Please "
             + "check the documentation, and adjust the configuration accordingly. "
diff --git a/ax/global_stopping/strategies/improvement.py b/ax/global_stopping/strategies/improvement.py
index 4aa70c67296..220909effc6 100644
--- a/ax/global_stopping/strategies/improvement.py
+++ b/ax/global_stopping/strategies/improvement.py
@@ -7,7 +7,6 @@
 # pyre-strict

 from logging import Logger
-from typing import Optional

 import numpy as np
 from ax.core.base_trial import BaseTrial, TrialStatus
@@ -80,13 +79,13 @@ def __init__(
         self.window_size = window_size
         self.improvement_bar = improvement_bar
         self.hv_by_trial: dict[int, float] = {}
-        self._inferred_objective_thresholds: Optional[list[ObjectiveThreshold]] = None
+        self._inferred_objective_thresholds: list[ObjectiveThreshold] | None = None

     def _should_stop_optimization(
         self,
         experiment: Experiment,
-        trial_to_check: Optional[int] = None,
-        objective_thresholds: Optional[list[ObjectiveThreshold]] = None,
+        trial_to_check: int | None = None,
+        objective_thresholds: list[ObjectiveThreshold] | None = None,
     ) -> tuple[bool, str]:
         """
         Check if the objective has improved significantly in the past
diff --git a/ax/health_check/search_space.py b/ax/health_check/search_space.py
index c980719bd64..a8e824e3fc9 100644
--- a/ax/health_check/search_space.py
+++ b/ax/health_check/search_space.py
@@ -52,7 +52,7 @@ def search_space_update_recommendation(
     # have parameter "a" value equal to a.lower and 20% have parameter "a" value
     # equal to a.upper.
     param_boundary_prop = defaultdict()
-    msg = str()
+    msg = ""
     num_suggestions = len(parametrizations)
diff --git a/ax/metrics/branin_map.py b/ax/metrics/branin_map.py
index 605dbf85239..afa3765e83b 100644
--- a/ax/metrics/branin_map.py
+++ b/ax/metrics/branin_map.py
@@ -12,7 +12,7 @@
 from collections import defaultdict
 from collections.abc import Iterable, Mapping
 from random import random
-from typing import Any, Optional
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -34,8 +34,8 @@ def __init__(
         name: str,
         param_names: Iterable[str],
         noise_sd: float = 0.0,
-        lower_is_better: Optional[bool] = None,
-        rate: Optional[float] = None,
+        lower_is_better: bool | None = None,
+        rate: float | None = None,
         cache_evaluations: bool = True,
     ) -> None:
         """A Branin map metric with an optional multiplicative factor
@@ -139,7 +139,7 @@ def __init__(
         name: str,
         param_names: Iterable[str],
         noise_sd: float = 0.0,
-        lower_is_better: Optional[bool] = None,
+        lower_is_better: bool | None = None,
     ) -> None:
         super().__init__(
             name=name,
diff --git a/ax/metrics/factorial.py b/ax/metrics/factorial.py
index 099f3014e91..bca28a77e0c 100644
--- a/ax/metrics/factorial.py
+++ b/ax/metrics/factorial.py
@@ -41,7 +41,7 @@ def __init__(
             noise_var: used in calculating the probability of each arm.
         """
-        super(FactorialMetric, self).__init__(name)
+        super().__init__(name)
         self.coefficients = coefficients
         self.batch_size = batch_size
@@ -122,9 +122,9 @@ def _parameterization_probability(
     z = 0.0
     for factor, level in parameterization.items():
         if factor not in coefficients.keys():
-            raise ValueError("{} not in supplied coefficients".format(factor))
+            raise ValueError(f"{factor} not in supplied coefficients")
         if level not in coefficients[factor].keys():
-            raise ValueError("{} not a valid level of {}".format(level, factor))
+            raise ValueError(f"{level} not a valid level of {factor}")
         z += coefficients[factor][level]
     z += np.sqrt(noise_var) * np.random.randn()
     return np.exp(z) / (1 + np.exp(z))
diff --git a/ax/metrics/noisy_function.py b/ax/metrics/noisy_function.py
index 87384737889..376888b49f9 100644
--- a/ax/metrics/noisy_function.py
+++ b/ax/metrics/noisy_function.py
@@ -8,7 +8,9 @@
 from __future__ import annotations

-from typing import Any, Callable, Optional
+from collections.abc import Callable
+
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -28,8 +30,8 @@ def __init__(
         self,
         name: str,
         param_names: list[str],
-        noise_sd: Optional[float] = 0.0,
-        lower_is_better: Optional[bool] = None,
+        noise_sd: float | None = 0.0,
+        lower_is_better: bool | None = None,
     ) -> None:
         """
         Metric is computed by evaluating a deterministic function, implemented
@@ -112,8 +114,8 @@ def __init__(
         self,
         name: str,
         f: Callable[[TParameterization], float],
-        noise_sd: Optional[float] = 0.0,
-        lower_is_better: Optional[bool] = None,
+        noise_sd: float | None = 0.0,
+        lower_is_better: bool | None = None,
     ) -> None:
         """
         Metric is computed by evaluating a deterministic function, implemented in f.
diff --git a/ax/metrics/noisy_function_map.py b/ax/metrics/noisy_function_map.py
index ad0cf006238..a77612ad3c7 100644
--- a/ax/metrics/noisy_function_map.py
+++ b/ax/metrics/noisy_function_map.py
@@ -12,7 +12,7 @@
 from logging import Logger

-from typing import Any, Optional
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -38,7 +38,7 @@ def __init__(
         name: str,
         param_names: Iterable[str],
         noise_sd: float = 0.0,
-        lower_is_better: Optional[bool] = None,
+        lower_is_better: bool | None = None,
         cache_evaluations: bool = True,
     ) -> None:
         """
diff --git a/ax/metrics/tensorboard.py b/ax/metrics/tensorboard.py
index f99771b0518..a501e8d49d6 100644
--- a/ax/metrics/tensorboard.py
+++ b/ax/metrics/tensorboard.py
@@ -11,7 +11,7 @@
 import logging
 from logging import Logger
-from typing import Any, Optional
+from typing import Any

 import numpy as np
@@ -46,7 +46,7 @@ def __init__(
         self,
         name: str,
         tag: str,
-        lower_is_better: Optional[bool] = True,
+        lower_is_better: bool | None = True,
         smoothing: float = SMOOTHING_DEFAULT,
         cumulative_best: bool = False,
     ) -> None:
diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py
index d00269d88c5..008f2050aaf 100644
--- a/ax/modelbridge/base.py
+++ b/ax/modelbridge/base.py
@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from logging import Logger

-from typing import Any, Optional
+from typing import Any

 from ax.core.arm import Arm
 from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus
@@ -49,16 +49,16 @@
 @dataclass(frozen=True)
 class BaseGenArgs:
     search_space: SearchSpace
-    optimization_config: Optional[OptimizationConfig]
+    optimization_config: OptimizationConfig | None
     pending_observations: dict[str, list[ObservationFeatures]]
-    fixed_features: Optional[ObservationFeatures]
+    fixed_features: ObservationFeatures | None


 @dataclass(frozen=True)
 class GenResults:
     observation_features: list[ObservationFeatures]
     weights: list[float]
-    best_observation_features: Optional[ObservationFeatures] = None
+    best_observation_features: ObservationFeatures | None = None
     gen_metadata: dict[str, Any] = field(default_factory=dict)


@@ -92,13 +92,13 @@ def __init__(
         search_space: SearchSpace,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         model: Any,
-        transforms: Optional[list[type[Transform]]] = None,
-        experiment: Optional[Experiment] = None,
-        data: Optional[Data] = None,
-        transform_configs: Optional[dict[str, TConfig]] = None,
-        status_quo_name: Optional[str] = None,
-        status_quo_features: Optional[ObservationFeatures] = None,
-        optimization_config: Optional[OptimizationConfig] = None,
+        transforms: list[type[Transform]] | None = None,
+        experiment: Experiment | None = None,
+        data: Data | None = None,
+        transform_configs: dict[str, TConfig] | None = None,
+        status_quo_name: str | None = None,
+        status_quo_features: ObservationFeatures | None = None,
+        optimization_config: OptimizationConfig | None = None,
         fit_out_of_design: bool = False,
         fit_abandoned: bool = False,
         fit_tracking_metrics: bool = True,
@@ -151,18 +151,18 @@ def __init__(
         self.fit_time_since_gen: float = 0.0
         self._metric_names: set[str] = set()
         self._training_data: list[Observation] = []
-        self._optimization_config: Optional[OptimizationConfig] = optimization_config
+        self._optimization_config: OptimizationConfig | None = optimization_config
         self._training_in_design: list[bool] = []
-        self._status_quo: Optional[Observation] = None
-        self._status_quo_name: Optional[str] = None
-        self._arms_by_signature: Optional[dict[str, Arm]] = None
+        self._status_quo: Observation | None = None
+        self._status_quo_name: str | None = None
+        self._arms_by_signature: dict[str, Arm] | None = None
         self.transforms: MutableMapping[str, Transform] = OrderedDict()
-        self._model_key: Optional[str] = None
-        self._model_kwargs: Optional[dict[str, Any]] = None
-        self._bridge_kwargs: Optional[dict[str, Any]] = None
+        self._model_key: str | None = None
+        self._model_kwargs: dict[str, Any] | None = None
+        self._bridge_kwargs: dict[str, Any] | None = None
         self._model_space: SearchSpace = search_space.clone()
         self._raw_transforms = transforms
-        self._transform_configs: Optional[dict[str, TConfig]] = transform_configs
+        self._transform_configs: dict[str, TConfig] | None = transform_configs
         self._fit_out_of_design = fit_out_of_design
         self._fit_abandoned = fit_abandoned
         self._fit_tracking_metrics = fit_tracking_metrics
@@ -246,8 +246,8 @@ def _fit_if_implemented(
     def _process_and_transform_data(
         self,
-        experiment: Optional[Experiment] = None,
-        data: Optional[Data] = None,
+        experiment: Experiment | None = None,
+        data: Data | None = None,
     ) -> tuple[list[Observation], SearchSpace]:
         r"""Processes the data into observations and returns transformed
         observations and the search space.
         This packages the following methods:
@@ -267,7 +267,7 @@ def _process_and_transform_data(
         )

     def _prepare_observations(
-        self, experiment: Optional[Experiment], data: Optional[Data]
+        self, experiment: Experiment | None, data: Data | None
     ) -> list[Observation]:
         if experiment is None or data is None:
             return []
@@ -282,8 +282,8 @@ def _transform_data(
         self,
         observations: list[Observation],
         search_space: SearchSpace,
-        transforms: Optional[list[type[Transform]]],
-        transform_configs: Optional[dict[str, TConfig]],
+        transforms: list[type[Transform]] | None,
+        transform_configs: dict[str, TConfig] | None,
         assign_transforms: bool = True,
     ) -> tuple[list[Observation], SearchSpace]:
         """Initialize transforms and apply them to provided data."""
@@ -370,9 +370,9 @@ def _compute_in_design(
     def _set_status_quo(
         self,
-        experiment: Optional[Experiment],
-        status_quo_name: Optional[str],
-        status_quo_features: Optional[ObservationFeatures],
+        experiment: Experiment | None,
+        status_quo_name: str | None,
+        status_quo_features: ObservationFeatures | None,
     ) -> None:
         """Set model status quo by matching status_quo_name or
         status_quo_features.
@@ -386,7 +386,7 @@ def _set_status_quo(
             status_quo_name: Name of status quo arm.
             status_quo_features: Features for status quo.
         """
-        self._status_quo: Optional[Observation] = None
+        self._status_quo: Observation | None = None
         sq_obs = None

         if (
@@ -434,7 +434,7 @@ def _set_status_quo(
         self._status_quo = sq_obs[0]

     @property
-    def status_quo_data_by_trial(self) -> Optional[dict[int, ObservationData]]:
+    def status_quo_data_by_trial(self) -> dict[int, ObservationData] | None:
         """A map of trial index to the status quo observation data of each trial"""
         return _get_status_quo_by_trial(
             observations=self._training_data,
@@ -449,7 +449,7 @@ def status_quo_data_by_trial(self) -> Optional[dict[int, ObservationData]]:
         )

     @property
-    def status_quo(self) -> Optional[Observation]:
+    def status_quo(self) -> Observation | None:
         """Observation corresponding to status quo, if any."""
         return self._status_quo
@@ -651,9 +651,9 @@ def update(self, new_data: Data, experiment: Experiment) -> None:
     def _get_transformed_gen_args(
         self,
         search_space: SearchSpace,
-        optimization_config: Optional[OptimizationConfig] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
-        fixed_features: Optional[ObservationFeatures] = None,
+        optimization_config: OptimizationConfig | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
+        fixed_features: ObservationFeatures | None = None,
     ) -> BaseGenArgs:
         if pending_observations is None:
             pending_observations = {}
@@ -709,11 +709,11 @@ def _get_transformed_gen_args(
     def _validate_gen_inputs(
         self,
         n: int,
-        search_space: Optional[SearchSpace] = None,
-        optimization_config: Optional[OptimizationConfig] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
-        fixed_features: Optional[ObservationFeatures] = None,
-        model_gen_options: Optional[TConfig] = None,
+        search_space: SearchSpace | None = None,
+        optimization_config: OptimizationConfig | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
+        fixed_features: ObservationFeatures | None = None,
+        model_gen_options: TConfig | None = None,
     ) -> None:
         """Validate inputs to `ModelBridge.gen`.
@@ -728,11 +728,11 @@ def _validate_gen_inputs(
     def gen(
         self,
         n: int,
-        search_space: Optional[SearchSpace] = None,
-        optimization_config: Optional[OptimizationConfig] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
-        fixed_features: Optional[ObservationFeatures] = None,
-        model_gen_options: Optional[TConfig] = None,
+        search_space: SearchSpace | None = None,
+        optimization_config: OptimizationConfig | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
+        fixed_features: ObservationFeatures | None = None,
+        model_gen_options: TConfig | None = None,
     ) -> GeneratorRun:
         """
         Generate new points from the underlying model according to
@@ -864,10 +864,10 @@ def _gen(
         self,
         n: int,
         search_space: SearchSpace,
-        optimization_config: Optional[OptimizationConfig],
+        optimization_config: OptimizationConfig | None,
         pending_observations: dict[str, list[ObservationFeatures]],
-        fixed_features: Optional[ObservationFeatures],
-        model_gen_options: Optional[TConfig],
+        fixed_features: ObservationFeatures | None,
+        model_gen_options: TConfig | None,
     ) -> GenResults:
         """Apply terminal transform, gen, and reverse terminal transform on
         output.
@@ -1090,8 +1090,8 @@ def unwrap_observation_data(observation_data: list[ObservationData]) -> TModelPr
 def gen_arms(
     observation_features: list[ObservationFeatures],
-    arms_by_signature: Optional[dict[str, Arm]] = None,
-) -> tuple[list[Arm], Optional[dict[str, TCandidateMetadata]]]:
+    arms_by_signature: dict[str, Arm] | None = None,
+) -> tuple[list[Arm], dict[str, TCandidateMetadata] | None]:
     """Converts observation features to a tuple of arms list and candidate metadata
     dict, where arm signatures are mapped to their respective candidate metadata.
     """
@@ -1140,9 +1140,9 @@ def clamp_observation_features(
 def _get_status_quo_by_trial(
     observations: list[Observation],
-    status_quo_name: Optional[str] = None,
-    status_quo_features: Optional[ObservationFeatures] = None,
-) -> Optional[dict[int, ObservationData]]:
+    status_quo_name: str | None = None,
+    status_quo_features: ObservationFeatures | None = None,
+) -> dict[int, ObservationData] | None:
     r"""
     Given a status quo observation, return a dictionary of trial index to the
     status quo observation data of each trial.
diff --git a/ax/modelbridge/best_model_selector.py b/ax/modelbridge/best_model_selector.py
index f153eb96a1b..d62704cb713 100644
--- a/ax/modelbridge/best_model_selector.py
+++ b/ax/modelbridge/best_model_selector.py
@@ -9,9 +9,10 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Union

 import numpy as np
 from ax.exceptions.core import UserInputError
@@ -85,7 +86,7 @@ def __init__(
         diagnostic: str,
         metric_aggregation: ReductionCriterion,
         criterion: ReductionCriterion,
-        model_cv_kwargs: Optional[dict[str, Any]] = None,
+        model_cv_kwargs: dict[str, Any] | None = None,
     ) -> None:
         self.diagnostic = diagnostic
         if not isinstance(metric_aggregation, ReductionCriterion) or not isinstance(
diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py
index 01a55babd93..a880e4dac5b 100644
--- a/ax/modelbridge/cross_validation.py
+++ b/ax/modelbridge/cross_validation.py
@@ -10,10 +10,10 @@
 import warnings
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from copy import deepcopy
 from logging import Logger
-from typing import Callable, NamedTuple, Optional
+from typing import NamedTuple
 from warnings import warn

 import numpy as np
@@ -71,8 +71,7 @@ class AssessModelFitResult(NamedTuple):
 def cross_validate(
     model: ModelBridge,
     folds: int = -1,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    test_selector: Optional[Callable] = None,
+    test_selector: Callable | None = None,
     untransform: bool = True,
     use_posterior_predictive: bool = False,
 ) -> list[CVResult]:
@@ -432,7 +431,7 @@ def _gen_train_test_split(
 def get_fit_and_std_quality_and_generalization_dict(
     fitted_model_bridge: ModelBridge,
-) -> dict[str, Optional[float]]:
+) -> dict[str, float | None]:
     """
     Get stats and gen from a fitted ModelBridge for analytics purposes.
""" @@ -471,7 +470,7 @@ def get_fit_and_std_quality_and_generalization_dict( def compute_model_fit_metrics_from_modelbridge( model_bridge: ModelBridge, - fit_metrics_dict: Optional[dict[str, ModelFitMetricProtocol]] = None, + fit_metrics_dict: dict[str, ModelFitMetricProtocol] | None = None, generalization: bool = False, untransform: bool = False, ) -> dict[str, dict[str, float]]: diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py index 45cd7080d64..9240b2e32b0 100644 --- a/ax/modelbridge/discrete.py +++ b/ax/modelbridge/discrete.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional from ax.core.observation import ( Observation, @@ -49,7 +48,7 @@ class DiscreteModelBridge(ModelBridge): model: DiscreteModel outcomes: list[str] parameters: list[str] - search_space: Optional[SearchSpace] + search_space: SearchSpace | None def _fit( self, @@ -97,11 +96,11 @@ def _predict( def _validate_gen_inputs( self, n: int, - search_space: Optional[SearchSpace] = None, - optimization_config: Optional[OptimizationConfig] = None, - pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None, - fixed_features: Optional[ObservationFeatures] = None, - model_gen_options: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + optimization_config: OptimizationConfig | None = None, + pending_observations: dict[str, list[ObservationFeatures]] | None = None, + fixed_features: ObservationFeatures | None = None, + model_gen_options: TConfig | None = None, ) -> None: """Validate inputs to `ModelBridge.gen`. @@ -118,9 +117,9 @@ def _gen( n: int, search_space: SearchSpace, pending_observations: dict[str, list[ObservationFeatures]], - fixed_features: Optional[ObservationFeatures], - model_gen_options: Optional[TConfig] = None, - optimization_config: Optional[OptimizationConfig] = None, + fixed_features: ObservationFeatures | None, + model_gen_options: TConfig | None = None, + optimization_config: OptimizationConfig | None = None, ) -> GenResults: """Generate new candidates according to search_space and optimization_config. 
@@ -153,7 +152,7 @@ def _gen(
         # Pending observations
         if len(pending_observations) == 0:
-            pending_array: Optional[list[list[TParamValueList]]] = None
+            pending_array: list[list[TParamValueList]] | None = None
         else:
             pending_array = [[] for _ in self.outcomes]
             for metric_name, po_list in pending_observations.items():
diff --git a/ax/modelbridge/dispatch_utils.py b/ax/modelbridge/dispatch_utils.py
index 1254e557dbd..ae58bd92908 100644
--- a/ax/modelbridge/dispatch_utils.py
+++ b/ax/modelbridge/dispatch_utils.py
@@ -9,7 +9,7 @@
 import logging
 import warnings
 from math import ceil
-from typing import Any, cast, Optional, Union
+from typing import Any, cast

 import torch
 from ax.core.experiment import Experiment
@@ -47,10 +47,10 @@
 def _make_sobol_step(
     num_trials: int = -1,
-    min_trials_observed: Optional[int] = None,
+    min_trials_observed: int | None = None,
     enforce_num_trials: bool = True,
-    max_parallelism: Optional[int] = None,
-    seed: Optional[int] = None,
+    max_parallelism: int | None = None,
+    seed: int | None = None,
     should_deduplicate: bool = False,
 ) -> GenerationStep:
     """Shortcut for creating a Sobol generation step."""
@@ -68,19 +68,19 @@ def _make_sobol_step(
 def _make_botorch_step(
     num_trials: int = -1,
-    min_trials_observed: Optional[int] = None,
+    min_trials_observed: int | None = None,
     enforce_num_trials: bool = True,
-    max_parallelism: Optional[int] = None,
+    max_parallelism: int | None = None,
     model: ModelRegistryBase = Models.BOTORCH_MODULAR,
-    model_kwargs: Optional[dict[str, Any]] = None,
-    winsorization_config: Optional[
-        Union[WinsorizationConfig, dict[str, WinsorizationConfig]]
-    ] = None,
+    model_kwargs: dict[str, Any] | None = None,
+    winsorization_config: None | (
+        WinsorizationConfig | dict[str, WinsorizationConfig]
+    ) = None,
     no_winsorization: bool = False,
     should_deduplicate: bool = False,
-    verbose: Optional[bool] = None,
-    disable_progbar: Optional[bool] = None,
-    jit_compile: Optional[bool] = None,
+    verbose: bool | None = None,
+    disable_progbar: bool | None = None,
+    jit_compile: bool | None = None,
     derelativize_with_raw_status_quo: bool = False,
     fit_out_of_design: bool = False,
 ) -> GenerationStep:
@@ -140,10 +140,10 @@ def _make_botorch_step(
 def _suggest_gp_model(
     search_space: SearchSpace,
-    num_trials: Optional[int] = None,
-    optimization_config: Optional[OptimizationConfig] = None,
+    num_trials: int | None = None,
+    optimization_config: OptimizationConfig | None = None,
     use_saasbo: bool = False,
-) -> Union[None, ModelRegistryBase]:
+) -> None | ModelRegistryBase:
     """Suggest a model based on the search space. None means we use Sobol.

     1. We use Sobol if the number of total iterations in the optimization is
@@ -265,7 +265,7 @@ def _suggest_gp_model(
 def calculate_num_initialization_trials(
     num_tunable_parameters: int,
-    num_trials: Optional[int],
+    num_trials: int | None,
     use_batch_trials: bool,
 ) -> int:
     """
@@ -289,30 +289,30 @@ def choose_generation_strategy(
     *,
     use_batch_trials: bool = False,
     enforce_sequential_optimization: bool = True,
-    random_seed: Optional[int] = None,
-    torch_device: Optional[torch.device] = None,
+    random_seed: int | None = None,
+    torch_device: torch.device | None = None,
     no_winsorization: bool = False,
-    winsorization_config: Optional[
-        Union[WinsorizationConfig, dict[str, WinsorizationConfig]]
-    ] = None,
+    winsorization_config: None | (
+        WinsorizationConfig | dict[str, WinsorizationConfig]
+    ) = None,
     derelativize_with_raw_status_quo: bool = False,
-    no_bayesian_optimization: Optional[bool] = None,
+    no_bayesian_optimization: bool | None = None,
     force_random_search: bool = False,
-    num_trials: Optional[int] = None,
-    num_initialization_trials: Optional[int] = None,
+    num_trials: int | None = None,
+    num_initialization_trials: int | None = None,
     num_completed_initialization_trials: int = 0,
-    max_initialization_trials: Optional[int] = None,
-    min_sobol_trials_observed: Optional[int] = None,
-    max_parallelism_cap: Optional[int] = None,
-    max_parallelism_override: Optional[int] = None,
-    optimization_config: Optional[OptimizationConfig] = None,
+    max_initialization_trials: int | None = None,
+    min_sobol_trials_observed: int | None = None,
+    max_parallelism_cap: int | None = None,
+    max_parallelism_override: int | None = None,
+    optimization_config: OptimizationConfig | None = None,
     should_deduplicate: bool = False,
     use_saasbo: bool = False,
-    verbose: Optional[bool] = None,
-    disable_progbar: Optional[bool] = None,
-    jit_compile: Optional[bool] = None,
-    experiment: Optional[Experiment] = None,
-    suggested_model_override: Optional[ModelRegistryBase] = None,
+    verbose: bool | None = None,
+    disable_progbar: bool | None = None,
+    jit_compile: bool | None = None,
+    experiment: Experiment | None = None,
+    suggested_model_override: ModelRegistryBase | None = None,
     fit_out_of_design: bool = False,
 ) -> GenerationStrategy:
     """Select an appropriate generation strategy based on the properties of
@@ -575,12 +575,10 @@ def choose_generation_strategy(
 def _get_winsorization_transform_config(
-    winsorization_config: Optional[
-        Union[WinsorizationConfig, dict[str, WinsorizationConfig]]
-    ],
+    winsorization_config: None | (WinsorizationConfig | dict[str, WinsorizationConfig]),
     derelativize_with_raw_status_quo: bool,
     no_winsorization: bool,
-) -> Optional[TConfig]:
+) -> TConfig | None:
     if no_winsorization:
         if winsorization_config is not None:
             warnings.warn(
diff --git a/ax/modelbridge/external_generation_node.py b/ax/modelbridge/external_generation_node.py
index 62c719763f0..4f20ee7d2dd 100644
--- a/ax/modelbridge/external_generation_node.py
+++ b/ax/modelbridge/external_generation_node.py
@@ -10,7 +10,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from logging import Logger
-from typing import Any, Optional
+from typing import Any

 from ax.core.arm import Arm
 from ax.core.data import Data
@@ -56,7 +56,7 @@ def __init__(
         self,
         node_name: str,
         should_deduplicate: bool = True,
-        transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
+        transition_criteria: Sequence[TransitionCriterion] | None = None,
     ) -> None:
         """Initialize an external generation node.
@@ -129,8 +129,8 @@ def fit(
         self,
         experiment: Experiment,
         data: Data,
-        search_space: Optional[SearchSpace] = None,
-        optimization_config: Optional[OptimizationConfig] = None,
+        search_space: SearchSpace | None = None,
+        optimization_config: OptimizationConfig | None = None,
         **kwargs: Any,
     ) -> None:
         """A method used to initialize or update the experiment state / data
@@ -164,8 +164,8 @@ def fit(
     def _gen(
         self,
-        n: Optional[int] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
+        n: int | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
         **model_gen_kwargs: Any,
     ) -> GeneratorRun:
         """Generate new candidates for evaluation.
diff --git a/ax/modelbridge/factory.py b/ax/modelbridge/factory.py
index f8fe99e4a36..69f709cd2d7 100644
--- a/ax/modelbridge/factory.py
+++ b/ax/modelbridge/factory.py
@@ -7,7 +7,6 @@
 # pyre-strict

 from logging import Logger
-from typing import Optional

 import torch
 from ax.core.data import Data
@@ -57,7 +56,7 @@
 def get_sobol(
     search_space: SearchSpace,
-    seed: Optional[int] = None,
+    seed: int | None = None,
     deduplicate: bool = False,
     init_position: int = 0,
     scramble: bool = True,
@@ -86,7 +85,7 @@ def get_sobol(
 def get_uniform(
-    search_space: SearchSpace, deduplicate: bool = False, seed: Optional[int] = None
+    search_space: SearchSpace, deduplicate: bool = False, seed: int | None = None
 ) -> RandomModelBridge:
     """Instantiate uniform generator.
@@ -106,17 +105,17 @@ def get_uniform(
 def get_botorch(
     experiment: Experiment,
     data: Data,
-    search_space: Optional[SearchSpace] = None,
+    search_space: SearchSpace | None = None,
     dtype: torch.dtype = torch.double,
     device: torch.device = DEFAULT_TORCH_DEVICE,
     transforms: list[type[Transform]] = Cont_X_trans + Y_trans,
-    transform_configs: Optional[dict[str, TConfig]] = None,
+    transform_configs: dict[str, TConfig] | None = None,
     model_constructor: TModelConstructor = get_and_fit_model,
     model_predictor: TModelPredictor = predict_from_model,
     acqf_constructor: TAcqfConstructor = get_qLogNEI,
     acqf_optimizer: TOptimizer = scipy_optimizer,  # pyre-ignore[9]
     refit_on_cv: bool = False,
-    optimization_config: Optional[OptimizationConfig] = None,
+    optimization_config: OptimizationConfig | None = None,
 ) -> TorchModelBridge:
     """Instantiates a BotorchModel."""
     if data.df.empty:
@@ -144,7 +143,7 @@ def get_botorch(
 def get_GPEI(
     experiment: Experiment,
     data: Data,
-    search_space: Optional[SearchSpace] = None,
+    search_space: SearchSpace | None = None,
     dtype: torch.dtype = torch.double,
     device: torch.device = DEFAULT_TORCH_DEVICE,
 ) -> TorchModelBridge:
@@ -174,9 +173,9 @@ def get_factorial(search_space: SearchSpace) -> DiscreteModelBridge:
 def get_empirical_bayes_thompson(
     experiment: Experiment,
     data: Data,
-    search_space: Optional[SearchSpace] = None,
+    search_space: SearchSpace | None = None,
     num_samples: int = 10000,
-    min_weight: Optional[float] = None,
+    min_weight: float | None = None,
     uniform_weights: bool = False,
 ) -> DiscreteModelBridge:
     """Instantiates an empirical Bayes / Thompson sampling model."""
@@ -199,9 +198,9 @@ def get_empirical_bayes_thompson(
 def get_thompson(
     experiment: Experiment,
     data: Data,
-    search_space: Optional[SearchSpace] = None,
+    search_space: SearchSpace | None = None,
     num_samples: int = 10000,
-    min_weight: Optional[float] = None,
+    min_weight: float | None = None,
     uniform_weights: bool = False,
 ) -> DiscreteModelBridge:
     """Instantiates a Thompson sampling model."""
diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py
index 4c1db2a6069..1f68a2b819e 100644
--- a/ax/modelbridge/generation_node.py
+++ b/ax/modelbridge/generation_node.py
@@ -9,9 +9,9 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from logging import Logger
-from typing import Any, Callable, Optional, Union
+from typing import Any

 # Module-level import to avoid circular dependency b/w this file and
 # generation_strategy.py
@@ -105,35 +105,35 @@ class GenerationNode(SerializationMixin, SortableBase):
     _node_name: str

     # Optional specifications
-    _model_spec_to_gen_from: Optional[ModelSpec] = None
+    _model_spec_to_gen_from: ModelSpec | None = None
     # TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
     _transition_criteria: Sequence[TransitionCriterion]
     _input_constructors: dict[
         modelbridge.generation_node_input_constructors.InputConstructorPurpose,
         modelbridge.generation_node_input_constructors.NodeInputConstructors,
     ]
-    _previous_node_name: Optional[str] = None
+    _previous_node_name: str | None = None
     # [TODO] Handle experiment passing more eloquently by enforcing experiment
     # attribute is set in generation strategies class
-    _generation_strategy: Optional[
+    _generation_strategy: None | (
         modelbridge.generation_strategy.GenerationStrategy
-    ] = None
+    ) = None

     def __init__(
         self,
         node_name: str,
         model_specs: list[ModelSpec],
-        best_model_selector: Optional[BestModelSelector] = None,
+        best_model_selector: BestModelSelector | None = None,
         should_deduplicate: bool = False,
-        transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
-        input_constructors: Optional[
+        transition_criteria: Sequence[TransitionCriterion] | None = None,
+        input_constructors: None | (
             dict[
                 modelbridge.generation_node_input_constructors.InputConstructorPurpose,
                 modelbridge.generation_node_input_constructors.NodeInputConstructors,
             ]
-        ] = None,
-        previous_node_name: Optional[str] = None,
+        ) = None,
+        previous_node_name: str | None = None,
     ) -> None:
         self._node_name = node_name
         # Check that the model specs have unique model keys.
@@ -172,7 +172,7 @@ def model_spec_to_gen_from(self) -> ModelSpec:
         return self._model_spec_to_gen_from

     @property
-    def model_to_gen_from_name(self) -> Optional[str]:
+    def model_to_gen_from_name(self) -> str | None:
         """Returns the name of the model that will be used for gen, if there is one.
         Otherwise, returns None.
         """
@@ -232,7 +232,7 @@ def _unique_id(self) -> str:
         return self.node_name

     @property
-    def _fitted_model(self) -> Optional[ModelBridge]:
+    def _fitted_model(self) -> ModelBridge | None:
         """Private property to return optional fitted_model from
         self.model_spec_to_gen_from for convenience. If no model is fit,
         will return None. If using the non-private `fitted_model` property,
@@ -244,8 +244,8 @@ def fit(
         self,
         experiment: Experiment,
         data: Data,
-        search_space: Optional[SearchSpace] = None,
-        optimization_config: Optional[OptimizationConfig] = None,
+        search_space: SearchSpace | None = None,
+        optimization_config: OptimizationConfig | None = None,
         **kwargs: Any,
     ) -> None:
         """Fits the specified models to the given experiment + data using
@@ -330,10 +330,10 @@ def _get_model_state_from_last_generator_run(
     # TODO [drfreund]: Move this up to `GenerationNodeInterface` once implemented.
     def gen(
         self,
-        n: Optional[int] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
+        n: int | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
         max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
-        arms_by_signature_for_deduplication: Optional[dict[str, Arm]] = None,
+        arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
         **model_gen_kwargs: Any,
     ) -> GeneratorRun:
         """This method generates candidates using `self._gen` and handles deduplication
@@ -403,8 +403,8 @@ def gen(
     def _gen(
         self,
-        n: Optional[int] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
+        n: int | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
         **model_gen_kwargs: Any,
     ) -> GeneratorRun:
         """Picks a fitted model, from which to generate candidates (via
@@ -482,7 +482,7 @@ def trials_from_node(self) -> set[int]:
         return trials_from_node

     @property
-    def node_that_generated_last_gr(self) -> Optional[str]:
+    def node_that_generated_last_gr(self) -> str | None:
         """Returns the name of the node that generated the last generator run.

         Returns:
@@ -709,16 +709,16 @@ class GenerationStep(GenerationNode, SortableBase):
     def __init__(
         self,
-        model: Union[ModelRegistryBase, Callable[..., ModelBridge]],
+        model: ModelRegistryBase | Callable[..., ModelBridge],
         num_trials: int,
-        model_kwargs: Optional[dict[str, Any]] = None,
-        model_gen_kwargs: Optional[dict[str, Any]] = None,
-        completion_criteria: Optional[Sequence[TransitionCriterion]] = None,
+        model_kwargs: dict[str, Any] | None = None,
+        model_gen_kwargs: dict[str, Any] | None = None,
+        completion_criteria: Sequence[TransitionCriterion] | None = None,
         min_trials_observed: int = 0,
-        max_parallelism: Optional[int] = None,
+        max_parallelism: int | None = None,
         enforce_num_trials: bool = True,
         should_deduplicate: bool = False,
-        model_name: Optional[str] = None,
+        model_name: str | None = None,
         use_update: bool = False,  # DEPRECATED.
         index: int = -1,  # Index of this step, set internally.
     ) -> None:
@@ -851,10 +851,10 @@ def _unique_id(self) -> str:
     def gen(
         self,
-        n: Optional[int] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
+        n: int | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
         max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
-        arms_by_signature_for_deduplication: Optional[dict[str, Arm]] = None,
+        arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
         **model_gen_kwargs: Any,
     ) -> GeneratorRun:
         gr = super().gen(
diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py
index d2787c3b55c..8cf2ab8a3de 100644
--- a/ax/modelbridge/generation_node_input_constructors.py
+++ b/ax/modelbridge/generation_node_input_constructors.py
@@ -7,7 +7,7 @@
 import sys
 from enum import Enum, unique
 from math import ceil
-from typing import Any, Dict, Optional
+from typing import Any

 from ax.modelbridge.generation_node import GenerationNode
@@ -27,9 +27,9 @@ class NodeInputConstructors(Enum):
     def __call__(
         self,
-        previous_node: Optional[GenerationNode],
+        previous_node: GenerationNode | None,
         next_node: GenerationNode,
-        gs_gen_call_kwargs: Dict[str, Any],
+        gs_gen_call_kwargs: dict[str, Any],
     ) -> int:
         """Defines a callable method for the Enum as all values are methods"""
         try:
@@ -61,9 +61,9 @@ class InputConstructorPurpose(Enum):
 def consume_all_n(
-    previous_node: Optional[GenerationNode],
+    previous_node: GenerationNode | None,
     next_node: GenerationNode,
-    gs_gen_call_kwargs: Dict[str, Any],
+    gs_gen_call_kwargs: dict[str, Any],
 ) -> int:
     """Generate total requested number of arms from the next node.
@@ -91,9 +91,9 @@ def consume_all_n(
 def repeat_arm_n(
-    previous_node: Optional[GenerationNode],
+    previous_node: GenerationNode | None,
     next_node: GenerationNode,
-    gs_gen_call_kwargs: Dict[str, Any],
+    gs_gen_call_kwargs: dict[str, Any],
 ) -> int:
     """Generate a small percentage of arms requested to be used for repeat arms in
     the next trial.
@@ -125,9 +125,9 @@ def repeat_arm_n(
 def remaining_n(
-    previous_node: Optional[GenerationNode],
+    previous_node: GenerationNode | None,
     next_node: GenerationNode,
-    gs_gen_call_kwargs: Dict[str, Any],
+    gs_gen_call_kwargs: dict[str, Any],
 ) -> int:
     """Generate the remaining number of arms requested for this trial in gs.gen().
diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py
index e316e7742e2..4d67de1c132 100644
--- a/ax/modelbridge/generation_strategy.py
+++ b/ax/modelbridge/generation_strategy.py
@@ -8,10 +8,12 @@
 from __future__ import annotations

+from collections.abc import Callable
+
 from copy import deepcopy
 from functools import wraps
 from logging import Logger
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, TypeVar

 import pandas as pd
 from ax.core.data import Data
@@ -59,9 +61,7 @@ def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]:
     """

     @wraps(f)
-    def impl(
-        self: "GenerationStrategy", *args: list[Any], **kwargs: dict[str, Any]
-    ) -> T:
+    def impl(self: GenerationStrategy, *args: list[Any], **kwargs: dict[str, Any]) -> T:
         if self.is_node_based:
             raise UnsupportedError(
                 f"{f.__name__} is not supported for GenerationNode based"
@@ -104,14 +104,14 @@ class GenerationStrategy(GenerationStrategyInterface):
     _generator_runs: list[GeneratorRun]
     # Experiment, for which this generation strategy has generated trials, if
     # it exists.
-    _experiment: Optional[Experiment] = None
-    _model: Optional[ModelBridge] = None  # Current model.
+    _experiment: Experiment | None = None
+    _model: ModelBridge | None = None  # Current model.

     def __init__(
         self,
-        steps: Optional[list[GenerationStep]] = None,
-        name: Optional[str] = None,
-        nodes: Optional[list[GenerationNode]] = None,
+        steps: list[GenerationStep] | None = None,
+        name: str | None = None,
+        nodes: list[GenerationNode] | None = None,
     ) -> None:
         # Validate that one and only one of steps or nodes is provided
         if not ((steps is None) ^ (nodes is None)):
@@ -232,7 +232,7 @@ def current_step_index(self) -> int:
         return node_names_for_all_steps.index(self._curr.node_name)

     @property
-    def model(self) -> Optional[ModelBridge]:
+    def model(self) -> ModelBridge | None:
         """Current model in this strategy. Returns None if no model has been set
         yet (i.e., if no generator runs have been produced from this GS).
         """
@@ -263,7 +263,7 @@ def experiment(self, experiment: Experiment) -> None:
             )

     @property
-    def last_generator_run(self) -> Optional[GeneratorRun]:
+    def last_generator_run(self) -> GeneratorRun | None:
         """Latest generator run produced by this generation strategy.
         Returns None if no generator runs have been produced yet.
         """
@@ -277,7 +277,7 @@ def uses_non_registered_models(self) -> bool:
         return not self._uses_registered_models

     @property
-    def trials_as_df(self) -> Optional[pd.DataFrame]:
+    def trials_as_df(self) -> pd.DataFrame | None:
         """Puts information on individual trials into a data frame for easy
         viewing.
@@ -344,9 +344,9 @@ def _steps(self) -> list[GenerationStep]:
     def gen(
         self,
         experiment: Experiment,
-        data: Optional[Data] = None,
+        data: Data | None = None,
         n: int = 1,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
         **kwargs: Any,
     ) -> GeneratorRun:
         """Produce the next points in the experiment. Additional kwargs passed to
@@ -388,10 +388,10 @@ def gen(
     def gen_with_multiple_nodes(
         self,
         experiment: Experiment,
-        data: Optional[Data] = None,
-        pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None,
-        arms_per_node: Optional[dict[str, int]] = None,
-        n: Optional[int] = None,
+        data: Data | None = None,
+        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
+        arms_per_node: dict[str, int] | None = None,
+        n: int | None = None,
     ) -> list[GeneratorRun]:
         """Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
         ``BatchTrial``, and if producing a ``BatchTrial`` allows for multiple
@@ -484,8 +484,8 @@ def gen_for_multiple_trials_with_multiple_models(
         self,
         experiment: Experiment,
         num_generator_runs: int,
-        data: Optional[Data] = None,
-        n: Optional[int] = None,
+        data: Data | None = None,
+        n: int | None = None,
     ) -> list[list[GeneratorRun]]:
         """Produce GeneratorRuns for multiple trials at once with the possibility
         of ensembling, or using multiple models per trial, getting multiple
@@ -564,7 +564,7 @@ def clone_reset(self) -> GenerationStrategy:
             name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes)
         )

-    def _get_n(self, experiment: Experiment, n: Optional[int]) -> int:
+    def _get_n(self, experiment: Experiment, n: int | None) -> int:
         """Get the number of arms to generate from the current generation node.
Args: @@ -781,9 +781,9 @@ def _gen_multiple( self, experiment: Experiment, num_generator_runs: int, - data: Optional[Data] = None, + data: Data | None = None, n: int = 1, - pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None, + pending_observations: dict[str, list[ObservationFeatures]] | None = None, **model_gen_kwargs: Any, ) -> list[GeneratorRun]: """Produce multiple generator runs at once, to be made into multiple @@ -895,7 +895,7 @@ def _determine_arms_from_node( node_to_gen_from_name: str, node_names: list[str], gen_kwargs: dict[str, Any], - arms_per_node: Optional[dict[str, int]] = None, + arms_per_node: dict[str, int] | None = None, ) -> int: """Calculates the number of arms to generate from the node that will be used during generation. @@ -947,7 +947,7 @@ def _determine_arms_from_node( # ------------------------- Model selection logic helpers. ------------------------- - def _fit_current_model(self, data: Optional[Data]) -> None: + def _fit_current_model(self, data: Data | None) -> None: """Fits or update the model on the current generation node (does not move between generation nodes). diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index c46f2146e16..6f082bb2186 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any import numpy as np @@ -59,18 +59,18 @@ def __init__( data: Data, model: TorchModel, transforms: list[type[Transform]], - transform_configs: Optional[dict[str, TConfig]] = None, - torch_dtype: Optional[torch.dtype] = None, - torch_device: Optional[torch.device] = None, - status_quo_name: Optional[str] = None, - status_quo_features: Optional[ObservationFeatures] = None, - optimization_config: Optional[OptimizationConfig] = None, + transform_configs: dict[str, TConfig] | None = None, + torch_dtype: torch.dtype | None = None, + torch_device: torch.device | None = None, + status_quo_name: str | None = None, + status_quo_features: ObservationFeatures | None = None, + optimization_config: OptimizationConfig | None = None, fit_out_of_design: bool = False, fit_on_init: bool = True, fit_abandoned: bool = False, - default_model_gen_options: Optional[TConfig] = None, - map_data_limit_rows_per_metric: Optional[int] = None, - map_data_limit_rows_per_group: Optional[int] = None, + default_model_gen_options: TConfig | None = None, + map_data_limit_rows_per_metric: int | None = None, + map_data_limit_rows_per_group: int | None = None, ) -> None: """ Applies transforms and fits model. @@ -183,7 +183,7 @@ def _fit( model: TorchModel, search_space: SearchSpace, observations: list[Observation], - parameters: Optional[list[str]] = None, + parameters: list[str] | None = None, **kwargs: Any, ) -> None: """The difference from `TorchModelBridge._fit(...)` is that we use @@ -205,9 +205,9 @@ def _gen( n: int, search_space: SearchSpace, pending_observations: dict[str, list[ObservationFeatures]], - fixed_features: Optional[ObservationFeatures], - model_gen_options: Optional[TConfig] = None, - optimization_config: Optional[OptimizationConfig] = None, + fixed_features: ObservationFeatures | None, + model_gen_options: TConfig | None = None, + optimization_config: OptimizationConfig | None = None, ) -> GenResults: """An updated version of `TorchModelBridge._gen(...) 
that first injects `map_dim_to_target` (e.g., `{-1: 1.0}`) into `model_gen_options` so that @@ -224,7 +224,7 @@ def _gen( ) def _array_to_observation_features( - self, X: np.ndarray, candidate_metadata: Optional[list[TCandidateMetadata]] + self, X: np.ndarray, candidate_metadata: list[TCandidateMetadata] | None ) -> list[ObservationFeatures]: """The difference b/t this method and TorchModelBridge._array_to_observation_features(...) is @@ -237,7 +237,7 @@ def _array_to_observation_features( ) def _prepare_observations( - self, experiment: Optional[Experiment], data: Optional[Data] + self, experiment: Experiment | None, data: Data | None ) -> list[Observation]: """The difference b/t this method and ModelBridge._prepare_observations(...) is that this one uses `observations_from_map_data`. @@ -278,7 +278,7 @@ def _cross_validate( search_space: SearchSpace, cv_training_data: list[Observation], cv_test_points: list[ObservationFeatures], - parameters: Optional[list[str]] = None, + parameters: list[str] | None = None, use_posterior_predictive: bool = False, **kwargs: Any, ) -> list[ObservationData]: @@ -321,7 +321,7 @@ def _filter_outcomes_out_of_map_range( observation_features: list[ObservationFeatures], observation_data: list[ObservationData], # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - map_key_ranges: dict[str, dict[str, Optional[tuple]]], + map_key_ranges: dict[str, dict[str, tuple | None]], ) -> list[ObservationData]: """Uses `map_key_ranges` to detect which `observation_features` have out-of-range map_keys and filters out the corresponding outcomes in @@ -361,7 +361,7 @@ def _get_map_key_ranges( observation_features: list[ObservationFeatures], observation_data: list[ObservationData], # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - ) -> dict[str, dict[str, Optional[tuple]]]: + ) -> dict[str, dict[str, tuple | None]]: """Get ranges of map_key values in observation features. Returns a dict of the form: {"outcome": {"map_key": (min_val, max_val)}}. """ diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py index c07972b3281..06c972b7ccc 100644 --- a/ax/modelbridge/model_spec.py +++ b/ax/modelbridge/model_spec.py @@ -10,9 +10,10 @@ import json import warnings +from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any from ax.core.data import Data from ax.core.experiment import Experiment @@ -63,23 +64,23 @@ class ModelSpec(SortableBase, SerializationMixin): model_cv_kwargs: dict[str, Any] = field(default_factory=dict) # An optional override for the model key. Each `ModelSpec` in a # `GenerationNode` must have a unique key to ensure identifiability. - model_key_override: Optional[str] = None + model_key_override: str | None = None # Fitted model, constructed using specified `model_kwargs` and `Data` # on `ModelSpec.fit` - _fitted_model: Optional[ModelBridge] = None + _fitted_model: ModelBridge | None = None # Stored cross validation results set in cross validate. - _cv_results: Optional[list[CVResult]] = None + _cv_results: list[CVResult] | None = None # Stored cross validation diagnostics set in cross validate. - _diagnostics: Optional[CVDiagnostics] = None + _diagnostics: CVDiagnostics | None = None # Stored to check if the CV result & diagnostic cache is safe to reuse. 
- _last_cv_kwargs: Optional[dict[str, Any]] = None + _last_cv_kwargs: dict[str, Any] | None = None # Stored to check if the model can be safely updated in fit. - _last_fit_arg_ids: Optional[dict[str, int]] = None + _last_fit_arg_ids: dict[str, int] | None = None def __post_init__(self) -> None: self.model_kwargs = self.model_kwargs or {} @@ -93,14 +94,14 @@ def fitted_model(self) -> ModelBridge: return not_none(self._fitted_model) @property - def fixed_features(self) -> Optional[ObservationFeatures]: + def fixed_features(self) -> ObservationFeatures | None: """ Fixed generation features to pass into the Model's `.gen` function. """ return self.model_gen_kwargs.get("fixed_features", None) @fixed_features.setter - def fixed_features(self, value: Optional[ObservationFeatures]) -> None: + def fixed_features(self, value: ObservationFeatures | None) -> None: """ Fixed generation features to pass into the Model's `.gen` function. """ @@ -155,8 +156,8 @@ def fit( def cross_validate( self, - model_cv_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[Optional[list[CVResult]], Optional[CVDiagnostics]]: + model_cv_kwargs: dict[str, Any] | None = None, + ) -> tuple[list[CVResult] | None, CVDiagnostics | None]: """ Call cross_validate, compute_diagnostics and cache the results. If the model cannot be cross validated, warn and return None. @@ -195,7 +196,7 @@ def cross_validate( return self._cv_results, self._diagnostics @property - def cv_results(self) -> Optional[list[CVResult]]: + def cv_results(self) -> list[CVResult] | None: """ Cached CV results from `self.cross_validate()` if it has been successfully called @@ -203,7 +204,7 @@ def cv_results(self) -> Optional[list[CVResult]]: return self._cv_results @property - def diagnostics(self) -> Optional[CVDiagnostics]: + def diagnostics(self) -> CVDiagnostics | None: """ Cached CV diagnostics from `self.cross_validate()` if it has been successfully called @@ -327,9 +328,9 @@ def _unique_id(self) -> str: @dataclass class FactoryFunctionModelSpec(ModelSpec): - factory_function: Optional[TModelFactory] = None + factory_function: TModelFactory | None = None # pyre-ignore[15]: `ModelSpec` has this as non-optional - model_enum: Optional[ModelRegistryBase] = None + model_enum: ModelRegistryBase | None = None def __post_init__(self) -> None: super().__post_init__() @@ -366,8 +367,8 @@ def fit( self, experiment: Experiment, data: Data, - search_space: Optional[SearchSpace] = None, - optimization_config: Optional[OptimizationConfig] = None, + search_space: SearchSpace | None = None, + optimization_config: OptimizationConfig | None = None, **model_kwargs: Any, ) -> None: """Fits the specified model on the given experiment + data using the diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py index 89822c9658c..18ffdde2a44 100644 --- a/ax/modelbridge/modelbridge_utils.py +++ b/ax/modelbridge/modelbridge_utils.py @@ -9,11 +9,11 @@ from __future__ import annotations import warnings -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Callable, Iterable, Mapping, MutableMapping from copy import deepcopy from functools import partial from logging import Logger -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import numpy as np import torch @@ -140,7 +140,7 @@ def extract_risk_measure(risk_measure: RiskMeasure) -> RiskMeasureMCObjective: def check_has_multi_objective_and_data( experiment: Experiment, data: Data, - optimization_config: 
Optional[OptimizationConfig] = None, + optimization_config: OptimizationConfig | None = None, ) -> None: """Raise an error if not using a `MultiObjective` or if the data is empty.""" optimization_config = not_none( @@ -207,13 +207,13 @@ def extract_search_space_digest( * The target_value is added to target_values. * Its index is added to fidelity_features. """ - bounds: list[tuple[Union[int, float], Union[int, float]]] = [] + bounds: list[tuple[int | float, int | float]] = [] ordinal_features: list[int] = [] categorical_features: list[int] = [] - discrete_choices: dict[int, list[Union[int, float]]] = {} + discrete_choices: dict[int, list[int | float]] = {} task_features: list[int] = [] fidelity_features: list[int] = [] - target_values: dict[int, Union[int, float]] = {} + target_values: dict[int, int | float] = {} for i, p_name in enumerate(param_names): p = search_space.parameters[p_name] @@ -260,7 +260,7 @@ def extract_search_space_digest( def extract_robust_digest( search_space: SearchSpace, param_names: list[str] -) -> Optional[RobustSearchSpaceDigest]: +) -> RobustSearchSpaceDigest | None: """Extracts the `RobustSearchSpaceDigest`. Args: @@ -365,7 +365,7 @@ def extract_objective_thresholds( objective_thresholds: TRefPoint, objective: Objective, outcomes: list[str], -) -> Optional[np.ndarray]: +) -> np.ndarray | None: """Extracts objective thresholds' values, in the order of `outcomes`. Will return None if no objective thresholds, otherwise the extracted tensor @@ -471,17 +471,17 @@ def extract_outcome_constraints( def validate_and_apply_final_transform( objective_weights: np.ndarray, - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]], - linear_constraints: Optional[tuple[np.ndarray, np.ndarray]], - pending_observations: Optional[list[np.ndarray]], - objective_thresholds: Optional[np.ndarray] = None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None, + linear_constraints: tuple[np.ndarray, np.ndarray] | None, + pending_observations: list[np.ndarray] | None, + objective_thresholds: np.ndarray | None = None, final_transform: Callable[[np.ndarray], Tensor] = torch.tensor, ) -> tuple[ Tensor, - Optional[tuple[Tensor, Tensor]], - Optional[tuple[Tensor, Tensor]], - Optional[list[Tensor]], - Optional[Tensor], + tuple[Tensor, Tensor] | None, + tuple[Tensor, Tensor] | None, + list[Tensor] | None, + Tensor | None, ]: # TODO: use some container down the road (similar to # SearchSpaceDigest) to limit the return arguments @@ -517,8 +517,8 @@ def validate_and_apply_final_transform( def get_fixed_features( - fixed_features: Optional[ObservationFeatures], param_names: list[str] -) -> Optional[dict[int, float]]: + fixed_features: ObservationFeatures | None, param_names: list[str] +) -> dict[int, float] | None: """Reformat a set of fixed_features.""" if fixed_features is None: return None @@ -547,7 +547,7 @@ def pending_observations_as_array_list( pending_observations: dict[str, list[ObservationFeatures]], outcome_names: list[str], param_names: list[str], -) -> Optional[list[np.ndarray]]: +) -> list[np.ndarray] | None: """Re-format pending observations. Args: @@ -581,7 +581,7 @@ def pending_observations_as_array_list( def parse_observation_features( X: np.ndarray, param_names: list[str], - candidate_metadata: Optional[list[TCandidateMetadata]] = None, + candidate_metadata: list[TCandidateMetadata] | None = None, ) -> list[ObservationFeatures]: """Re-format raw model-generated candidates into ObservationFeatures. 
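# Note: the digest annotations above (`list[tuple[int | float, int | float]]`,
# `dict[int, list[int | float]]`) nest PEP 604 unions inside builtin generics.
# A side benefit is that such unions are real runtime objects on Python 3.10+,
# so they also work with isinstance(). Throwaway sketch, names invented:
bounds: list[tuple[int | float, int | float]] = [(0, 1.0), (2, 3)]
assert all(isinstance(b, int | float) for pair in bounds for b in pair)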
@@ -671,12 +671,12 @@ def _roundtrip_transform(x: np.ndarray) -> np.ndarray: def get_pareto_frontier_and_configs( modelbridge: modelbridge_module.torch.TorchModelBridge, observation_features: list[ObservationFeatures], - observation_data: Optional[list[ObservationData]] = None, - objective_thresholds: Optional[TRefPoint] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - arm_names: Optional[list[Optional[str]]] = None, + observation_data: list[ObservationData] | None = None, + objective_thresholds: TRefPoint | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + arm_names: list[str | None] | None = None, use_model_predictions: bool = True, -) -> tuple[list[Observation], Tensor, Tensor, Optional[Tensor]]: +) -> tuple[list[Observation], Tensor, Tensor, Tensor | None]: """Helper that applies transforms and calls ``frontier_evaluator``. Returns the ``frontier_evaluator`` configs in addition to the Pareto @@ -815,10 +815,10 @@ def get_pareto_frontier_and_configs( def pareto_frontier( modelbridge: modelbridge_module.torch.TorchModelBridge, observation_features: list[ObservationFeatures], - observation_data: Optional[list[ObservationData]] = None, - objective_thresholds: Optional[TRefPoint] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - arm_names: Optional[list[Optional[str]]] = None, + observation_data: list[ObservationData] | None = None, + objective_thresholds: TRefPoint | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + arm_names: list[str | None] | None = None, use_model_predictions: bool = True, ) -> list[Observation]: """Compute the list of points on the Pareto frontier as `Observation`-s @@ -885,9 +885,9 @@ def pareto_frontier( def predicted_pareto_frontier( modelbridge: modelbridge_module.torch.TorchModelBridge, - objective_thresholds: Optional[TRefPoint] = None, - observation_features: Optional[list[ObservationFeatures]] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, + objective_thresholds: TRefPoint | None = None, + observation_features: list[ObservationFeatures] | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, ) -> list[Observation]: """Generate a Pareto frontier based on the posterior means of given observation features. Given a model and optionally features to evaluate @@ -930,8 +930,8 @@ def predicted_pareto_frontier( def observed_pareto_frontier( modelbridge: modelbridge_module.torch.TorchModelBridge, - objective_thresholds: Optional[TRefPoint] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, + objective_thresholds: TRefPoint | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, ) -> list[Observation]: """Generate a pareto frontier based on observed data. 
Given observed data (sourced from model training data), return points on the Pareto frontier @@ -967,10 +967,10 @@ def observed_pareto_frontier( def hypervolume( modelbridge: modelbridge_module.torch.TorchModelBridge, observation_features: list[ObservationFeatures], - objective_thresholds: Optional[TRefPoint] = None, - observation_data: Optional[list[ObservationData]] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - selected_metrics: Optional[list[str]] = None, + objective_thresholds: TRefPoint | None = None, + observation_data: list[ObservationData] | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + selected_metrics: list[str] | None = None, use_model_predictions: bool = True, ) -> float: """Helper function that computes (feasible) hypervolume. @@ -1041,8 +1041,8 @@ def hypervolume( def _get_multiobjective_optimization_config( modelbridge: modelbridge_module.torch.TorchModelBridge, - optimization_config: Optional[OptimizationConfig] = None, - objective_thresholds: Optional[TRefPoint] = None, + optimization_config: OptimizationConfig | None = None, + objective_thresholds: TRefPoint | None = None, ) -> MultiObjectiveOptimizationConfig: # Optimization_config mooc = optimization_config or checked_cast_optional( @@ -1050,11 +1050,9 @@ def _get_multiobjective_optimization_config( ) if not mooc: raise ValueError( - ( - "Experiment must have an existing optimization_config " - "of type `MultiObjectiveOptimizationConfig` " - "or `optimization_config` must be passed as an argument." - ) + "Experiment must have an existing optimization_config " + "of type `MultiObjectiveOptimizationConfig` " + "or `optimization_config` must be passed as an argument." ) if not isinstance(mooc, MultiObjectiveOptimizationConfig): raise ValueError( @@ -1068,10 +1066,10 @@ def _get_multiobjective_optimization_config( def predicted_hypervolume( modelbridge: modelbridge_module.torch.TorchModelBridge, - objective_thresholds: Optional[TRefPoint] = None, - observation_features: Optional[list[ObservationFeatures]] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - selected_metrics: Optional[list[str]] = None, + objective_thresholds: TRefPoint | None = None, + observation_features: list[ObservationFeatures] | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + selected_metrics: list[str] | None = None, ) -> float: """Calculate hypervolume of a pareto frontier based on the posterior means of given observation features. @@ -1115,9 +1113,9 @@ def predicted_hypervolume( def observed_hypervolume( modelbridge: modelbridge_module.torch.TorchModelBridge, - objective_thresholds: Optional[TRefPoint] = None, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - selected_metrics: Optional[list[str]] = None, + objective_thresholds: TRefPoint | None = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + selected_metrics: list[str] | None = None, ) -> float: """Calculate hypervolume of a pareto frontier based on observed data. 
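# Note: nested optionals stay readable in the new syntax because `|` binds
# inside the subscript, e.g. `arm_names: list[str | None] | None` in the
# Pareto helpers above is an optional list of optional strings. Hypothetical
# usage:
arm_names: list[str | None] | None = ["0_0", None]
if arm_names is not None:
    named = [name for name in arm_names if name is not None]  # ["0_0"]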
@@ -1273,8 +1271,8 @@ def feasible_hypervolume( def _array_to_tensor( - array: Union[np.ndarray, list[float]], - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, + array: np.ndarray | list[float], + modelbridge: modelbridge_module.base.ModelBridge | None = None, ) -> Tensor: if modelbridge and hasattr(modelbridge, "_array_to_tensor"): # pyre-ignore[16]: modelbridge does not have attribute `_array_to_tensor` @@ -1285,14 +1283,14 @@ def _array_to_tensor( def _get_modelbridge_training_data( modelbridge: modelbridge_module.torch.TorchModelBridge, -) -> tuple[list[ObservationFeatures], list[ObservationData], list[Optional[str]]]: +) -> tuple[list[ObservationFeatures], list[ObservationData], list[str | None]]: obs = modelbridge.get_training_data() return _unpack_observations(obs=obs) def _unpack_observations( obs: list[Observation], -) -> tuple[list[ObservationFeatures], list[ObservationData], list[Optional[str]]]: +) -> tuple[list[ObservationFeatures], list[ObservationData], list[str | None]]: obs_feats, obs_data, arm_names = [], [], [] for ob in obs: obs_feats.append(ob.features) @@ -1333,7 +1331,7 @@ def process_contextual_datasets( datasets: list[SupervisedDataset], outcomes: list[str], parameter_decomposition: dict[str, list[str]], - metric_decomposition: Optional[dict[str, list[str]]] = None, + metric_decomposition: dict[str, list[str]] | None = None, ) -> list[ContextualDataset]: """Contruct a list of `ContextualDataset`. diff --git a/ax/modelbridge/pairwise.py b/ax/modelbridge/pairwise.py index b725d231fed..ad76c92f844 100644 --- a/ax/modelbridge/pairwise.py +++ b/ax/modelbridge/pairwise.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import Optional - import numpy as np import torch from ax.core.observation import ObservationData, ObservationFeatures @@ -31,9 +29,9 @@ def _convert_observations( observation_features: list[ObservationFeatures], outcomes: list[str], parameters: list[str], - search_space_digest: Optional[SearchSpaceDigest], + search_space_digest: SearchSpaceDigest | None, ) -> tuple[ - list[SupervisedDataset], list[str], Optional[list[list[TCandidateMetadata]]] + list[SupervisedDataset], list[str], list[list[TCandidateMetadata]] | None ]: """Converts observations to a dictionary of `Dataset` containers and (optional) candidate metadata. diff --git a/ax/modelbridge/prediction_utils.py b/ax/modelbridge/prediction_utils.py index 1258b7490a0..832b65e9ec1 100644 --- a/ax/modelbridge/prediction_utils.py +++ b/ax/modelbridge/prediction_utils.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any import numpy as np from ax.core.observation import ObservationFeatures @@ -19,7 +19,7 @@ def predict_at_point( model: ModelBridge, obsf: ObservationFeatures, metric_names: set[str], - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, ) -> tuple[dict[str, float], dict[str, float]]: """Make a prediction at a point. 
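# Note: where a module ends up with no remaining `Optional`/`Union` usage, the
# codemod also deletes the now-unused typing import, as in random.py below. A
# converted toy module needs no typing import at all (invented example):
def choose(seed: int | None) -> str | None:
    return None if seed is None else str(seed)

assert choose(None) is None and choose(7) == "7"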
diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py index 0cb7581cf53..1f4a3b87f9b 100644 --- a/ax/modelbridge/random.py +++ b/ax/modelbridge/random.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional from ax.core.experiment import Experiment from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -48,7 +47,7 @@ def _fit( self, model: RandomModel, search_space: SearchSpace, - observations: Optional[list[Observation]] = None, + observations: list[Observation] | None = None, ) -> None: self.model = model # Extract and fix parameters from initial search space. @@ -59,9 +58,9 @@ def _gen( n: int, search_space: SearchSpace, pending_observations: dict[str, list[ObservationFeatures]], - fixed_features: Optional[ObservationFeatures], - optimization_config: Optional[OptimizationConfig], - model_gen_options: Optional[TConfig], + fixed_features: ObservationFeatures | None, + optimization_config: OptimizationConfig | None, + model_gen_options: TConfig | None, ) -> GenResults: """Generate new candidates according to a search_space.""" # Extract parameter values @@ -106,8 +105,8 @@ def _cross_validate( def _set_status_quo( self, - experiment: Optional[Experiment], - status_quo_name: Optional[str], - status_quo_features: Optional[ObservationFeatures], + experiment: Experiment | None, + status_quo_name: str | None, + status_quo_features: ObservationFeatures | None, ) -> None: pass diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 67f166323cd..74e2a79d046 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -22,7 +22,7 @@ from enum import Enum from inspect import isfunction, signature from logging import Logger -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import torch from ax.core.data import Data @@ -146,9 +146,9 @@ class ModelSetup(NamedTuple): bridge_class: type[ModelBridge] model_class: type[Model] transforms: list[type[Transform]] - default_model_kwargs: Optional[dict[str, Any]] = None - standard_bridge_kwargs: Optional[dict[str, Any]] = None - not_saved_model_kwargs: Optional[list[str]] = None + default_model_kwargs: dict[str, Any] | None = None + standard_bridge_kwargs: dict[str, Any] | None = None + not_saved_model_kwargs: list[str] | None = None """A mapping of string keys that indicate a model, to the corresponding @@ -275,9 +275,9 @@ def model_bridge_class(self) -> type[ModelBridge]: def __call__( self, - search_space: Optional[SearchSpace] = None, - experiment: Optional[Experiment] = None, - data: Optional[Data] = None, + search_space: SearchSpace | None = None, + experiment: Experiment | None = None, + data: Data | None = None, silently_filter_kwargs: bool = False, **kwargs: Any, ) -> ModelBridge: @@ -392,7 +392,7 @@ def view_kwargs(self) -> tuple[dict[str, Any], dict[str, Any]]: @staticmethod def _get_model_kwargs( - info: ModelSetup, kwargs: Optional[dict[str, Any]] = None + info: ModelSetup, kwargs: dict[str, Any] | None = None ) -> dict[str, Any]: return consolidate_kwargs( [get_function_default_arguments(info.model_class), kwargs], @@ -401,7 +401,7 @@ def _get_model_kwargs( @staticmethod def _get_bridge_kwargs( - info: ModelSetup, kwargs: Optional[dict[str, Any]] = None + info: ModelSetup, kwargs: dict[str, Any] | None = None ) -> dict[str, Any]: return consolidate_kwargs( [ @@ -548,7 +548,7 @@ def get_model_from_generator_run( def _combine_model_kwargs_and_state( generator_run: GeneratorRun, model_class: type[Model], - model_kwargs: Optional[dict[str, 
Any]] = None, + model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Produces a combined dict of model kwargs and model state after gen, extracted from generator run. If model kwargs are not specified, diff --git a/ax/modelbridge/tests/test_external_generation_node.py b/ax/modelbridge/tests/test_external_generation_node.py index 8ae829e1beb..2bd032f67fc 100644 --- a/ax/modelbridge/tests/test_external_generation_node.py +++ b/ax/modelbridge/tests/test_external_generation_node.py @@ -7,7 +7,6 @@ # pyre-strict from copy import deepcopy -from typing import Optional from unittest.mock import MagicMock from ax.core.data import Data @@ -32,7 +31,7 @@ def __init__(self) -> None: super().__init__(node_name="dummy") self.update_count = 0 self.gen_count = 0 - self.generator: Optional[RandomModelBridge] = None + self.generator: RandomModelBridge | None = None self.last_pending: list[TParameterization] = [] def update_generator_state(self, experiment: Experiment, data: Data) -> None: diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 20d8c629ad5..4f6ddaa84dc 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -7,7 +7,7 @@ import inspect from collections import Counter -from typing import Any, Dict, get_type_hints, Optional +from typing import Any, get_type_hints from ax.core.arm import Arm from ax.core.generator_run import GeneratorRun @@ -165,8 +165,10 @@ def test_all_constructors_have_same_signature(self) -> None: ), ) self.assertEqual( - func_parameters["previous_node"], Optional[GenerationNode] + func_parameters["previous_node"], GenerationNode | None ) self.assertEqual(func_parameters["next_node"], GenerationNode) - self.assertEqual(func_parameters["gs_gen_call_kwargs"], Dict[str, Any]) + # pyre-ignore [16]: Undefined attribute [16]: `dict` has no attribute + # `__getitem__`.¸ + self.assertEqual(func_parameters["gs_gen_call_kwargs"], dict[str, Any]) self.assertEqual(method_signature, inspect.signature(constructor)) diff --git a/ax/modelbridge/tests/test_modelbridge_utils.py b/ax/modelbridge/tests/test_modelbridge_utils.py index f4dbb537c7e..9d39cf319da 100644 --- a/ax/modelbridge/tests/test_modelbridge_utils.py +++ b/ax/modelbridge/tests/test_modelbridge_utils.py @@ -7,7 +7,6 @@ # pyre-strict from dataclasses import dataclass -from typing import Union import numpy as np import torch @@ -43,7 +42,7 @@ def test__array_to_tensor(self) -> None: @dataclass class MockModelbridge(ModelBridge): - def _array_to_tensor(self, array: Union[np.ndarray, list[float]]): + def _array_to_tensor(self, array: np.ndarray | list[float]): return _array_to_tensor(array=array) mock_modelbridge = MockModelbridge() diff --git a/ax/modelbridge/tests/test_robust_modelbridge.py b/ax/modelbridge/tests/test_robust_modelbridge.py index 17c560eb6e3..8327ebf6482 100644 --- a/ax/modelbridge/tests/test_robust_modelbridge.py +++ b/ax/modelbridge/tests/test_robust_modelbridge.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional from ax.core import Objective, OptimizationConfig from ax.core.objective import MultiObjective @@ -35,9 +34,9 @@ class TestRobust(TestCase): @fast_botorch_optimize def test_robust( self, - risk_measure: Optional[RiskMeasure] = None, - optimization_config: Optional[OptimizationConfig] = None, - acqf_class: Optional[str] = None, + risk_measure: RiskMeasure | None = None, + optimization_config: 
OptimizationConfig | None = None, + acqf_class: str | None = None, ) -> None: exp = get_robust_branin_experiment( risk_measure=risk_measure, diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py index be011511d83..e8bfc7dfdbf 100644 --- a/ax/modelbridge/tests/test_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_modelbridge.py @@ -7,7 +7,7 @@ # pyre-strict from contextlib import ExitStack -from typing import Any, Optional +from typing import Any from unittest import mock from unittest.mock import Mock @@ -58,8 +58,8 @@ def _get_mock_modelbridge( - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, fit_out_of_design: bool = False, ) -> TorchModelBridge: return TorchModelBridge( @@ -83,8 +83,8 @@ class TorchModelBridgeTest(TestCase): def test_TorchModelBridge( self, mock_init: Mock, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, ) -> None: ma = _get_mock_modelbridge(dtype=dtype, device=device) ma._fit_tracking_metrics = True diff --git a/ax/modelbridge/tests/test_torch_moo_modelbridge.py b/ax/modelbridge/tests/test_torch_moo_modelbridge.py index 3039704c0e7..968737c33ce 100644 --- a/ax/modelbridge/tests/test_torch_moo_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_moo_modelbridge.py @@ -7,7 +7,6 @@ # pyre-strict from contextlib import ExitStack -from typing import Optional from unittest.mock import patch import numpy as np @@ -69,7 +68,7 @@ class MultiObjectiveTorchModelBridgeTest(TestCase): ) @fast_botorch_optimize def helper_test_pareto_frontier( - self, _, outcome_constraints: Optional[list[OutcomeConstraint]] + self, _, outcome_constraints: list[OutcomeConstraint] | None ) -> None: """ Make sure Pareto-related functions run. diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index dcb147af0c9..a0156039a8f 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -9,9 +9,10 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable from copy import deepcopy from logging import Logger -from typing import Any, Callable, Optional, Union +from typing import Any import numpy as np import torch @@ -90,11 +91,11 @@ class TorchModelBridge(ModelBridge): them to the model. 
""" - model: Optional[TorchModel] = None + model: TorchModel | None = None outcomes: list[str] parameters: list[str] _default_model_gen_options: TConfig - _last_observations: Optional[list[Observation]] = None + _last_observations: list[Observation] | None = None def __init__( self, @@ -103,17 +104,17 @@ def __init__( data: Data, model: TorchModel, transforms: list[type[Transform]], - transform_configs: Optional[dict[str, TConfig]] = None, - torch_dtype: Optional[torch.dtype] = None, - torch_device: Optional[torch.device] = None, - status_quo_name: Optional[str] = None, - status_quo_features: Optional[ObservationFeatures] = None, - optimization_config: Optional[OptimizationConfig] = None, + transform_configs: dict[str, TConfig] | None = None, + torch_dtype: torch.dtype | None = None, + torch_device: torch.device | None = None, + status_quo_name: str | None = None, + status_quo_features: ObservationFeatures | None = None, + optimization_config: OptimizationConfig | None = None, fit_out_of_design: bool = False, fit_abandoned: bool = False, fit_tracking_metrics: bool = True, fit_on_init: bool = True, - default_model_gen_options: Optional[TConfig] = None, + default_model_gen_options: TConfig | None = None, ) -> None: self.dtype: torch.dtype = torch.double if torch_dtype is None else torch_dtype self.device = torch_device @@ -153,9 +154,9 @@ def feature_importances(self, metric_name: str) -> dict[str, float]: def infer_objective_thresholds( self, - search_space: Optional[SearchSpace] = None, - optimization_config: Optional[OptimizationConfig] = None, - fixed_features: Optional[ObservationFeatures] = None, + search_space: SearchSpace | None = None, + optimization_config: OptimizationConfig | None = None, + fixed_features: ObservationFeatures | None = None, ) -> list[ObjectiveThreshold]: """Infer objective thresholds. 
@@ -227,12 +228,12 @@ def infer_objective_thresholds( def model_best_point( self, - search_space: Optional[SearchSpace] = None, - optimization_config: Optional[OptimizationConfig] = None, - pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None, - fixed_features: Optional[ObservationFeatures] = None, - model_gen_options: Optional[TConfig] = None, - ) -> Optional[tuple[Arm, Optional[TModelPredictArm]]]: + search_space: SearchSpace | None = None, + optimization_config: OptimizationConfig | None = None, + pending_observations: dict[str, list[ObservationFeatures]] | None = None, + fixed_features: ObservationFeatures | None = None, + model_gen_options: TConfig | None = None, + ) -> tuple[Arm, TModelPredictArm | None] | None: # Get modifiable versions if search_space is None: search_space = self._model_space @@ -292,7 +293,7 @@ def _array_callable_to_tensor_callable( def _array_list_to_tensors(self, arrays: list[np.ndarray]) -> list[Tensor]: return [self._array_to_tensor(x) for x in arrays] - def _array_to_tensor(self, array: Union[np.ndarray, list[float]]) -> Tensor: + def _array_to_tensor(self, array: np.ndarray | list[float]) -> Tensor: return torch.as_tensor(array, dtype=self.dtype, device=self.device) def _convert_observations( @@ -301,9 +302,9 @@ def _convert_observations( observation_features: list[ObservationFeatures], outcomes: list[str], parameters: list[str], - search_space_digest: Optional[SearchSpaceDigest], + search_space_digest: SearchSpaceDigest | None, ) -> tuple[ - list[SupervisedDataset], list[str], Optional[list[list[TCandidateMetadata]]] + list[SupervisedDataset], list[str], list[list[TCandidateMetadata]] | None ]: """Converts observations to a dictionary of `Dataset` containers and (optional) candidate metadata. @@ -428,7 +429,7 @@ def _cross_validate( search_space: SearchSpace, cv_training_data: list[Observation], cv_test_points: list[ObservationFeatures], - parameters: Optional[list[str]] = None, + parameters: list[str] | None = None, use_posterior_predictive: bool = False, **kwargs: Any, ) -> list[ObservationData]: @@ -467,14 +468,14 @@ def _cross_validate( def evaluate_acquisition_function( self, - observation_features: Union[ - list[ObservationFeatures], list[list[ObservationFeatures]] - ], - search_space: Optional[SearchSpace] = None, - optimization_config: Optional[OptimizationConfig] = None, - pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None, - fixed_features: Optional[ObservationFeatures] = None, - acq_options: Optional[dict[str, Any]] = None, + observation_features: ( + list[ObservationFeatures] | list[list[ObservationFeatures]] + ), + search_space: SearchSpace | None = None, + optimization_config: OptimizationConfig | None = None, + pending_observations: dict[str, list[ObservationFeatures]] | None = None, + fixed_features: ObservationFeatures | None = None, + acq_options: dict[str, Any] | None = None, ) -> list[float]: """Evaluate the acquisition function for given set of observation features. 
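# Note: unlike `Union[...]`, a `|` union has no enclosing brackets of its own,
# so the codemod parenthesizes long unions to allow line breaks, as in
# `observation_features: (list[ObservationFeatures] | list[list[...]])` above.
# Equivalent toy signature:
def evaluate(
    features: (
        list[float] | list[list[float]]  # a single config or batched configs
    ),
) -> int:
    return len(features)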
@@ -537,9 +538,9 @@ def _evaluate_acquisition_function( observation_features: list[list[ObservationFeatures]], search_space: SearchSpace, optimization_config: OptimizationConfig, - pending_observations: Optional[dict[str, list[ObservationFeatures]]] = None, - fixed_features: Optional[ObservationFeatures] = None, - acq_options: Optional[dict[str, Any]] = None, + pending_observations: dict[str, list[ObservationFeatures]] | None = None, + fixed_features: ObservationFeatures | None = None, + acq_options: dict[str, Any] | None = None, ) -> list[float]: if self.model is None: raise RuntimeError( @@ -569,11 +570,11 @@ def _get_fit_args( self, search_space: SearchSpace, observations: list[Observation], - parameters: Optional[list[str]], + parameters: list[str] | None, update_outcomes_and_parameters: bool, ) -> tuple[ list[SupervisedDataset], - Optional[list[list[TCandidateMetadata]]], + list[list[TCandidateMetadata]] | None, SearchSpaceDigest, ]: """Helper for consolidating some common argument processing between @@ -633,7 +634,7 @@ def _fit( model: TorchModel, search_space: SearchSpace, observations: list[Observation], - parameters: Optional[list[str]] = None, + parameters: list[str] | None = None, **kwargs: Any, ) -> None: if self.model is not None and observations == self._last_observations: @@ -662,9 +663,9 @@ def _gen( n: int, search_space: SearchSpace, pending_observations: dict[str, list[ObservationFeatures]], - fixed_features: Optional[ObservationFeatures], - model_gen_options: Optional[TConfig] = None, - optimization_config: Optional[OptimizationConfig] = None, + fixed_features: ObservationFeatures | None, + model_gen_options: TConfig | None = None, + optimization_config: OptimizationConfig | None = None, ) -> GenResults: """Generate new candidates according to search_space and optimization_config. 
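# Note: `X | None` changes only the type, not the calling convention —
# `fixed_features: ObservationFeatures | None` in `_gen` above keeps no
# default, so callers must still pass it explicitly (possibly as None).
# Sketch with an invented stand-in:
def _gen_sketch(n: int, fixed_features: dict | None) -> int:
    return n if fixed_features is None else n + len(fixed_features)

_gen_sketch(2, None)   # OK: None is passed explicitly
# _gen_sketch(2)       # TypeError: missing required positional argument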
@@ -764,7 +765,7 @@ def _predict( return array_to_observation_data(f=f, cov=cov, outcomes=self.outcomes) def _array_to_observation_features( - self, X: np.ndarray, candidate_metadata: Optional[list[TCandidateMetadata]] + self, X: np.ndarray, candidate_metadata: list[TCandidateMetadata] | None ) -> list[ObservationFeatures]: return parse_observation_features( X=X, param_names=self.parameters, candidate_metadata=candidate_metadata @@ -793,9 +794,9 @@ def _get_transformed_model_gen_args( self, search_space: SearchSpace, pending_observations: dict[str, list[ObservationFeatures]], - fixed_features: Optional[ObservationFeatures], - model_gen_options: Optional[TConfig] = None, - optimization_config: Optional[OptimizationConfig] = None, + fixed_features: ObservationFeatures | None, + model_gen_options: TConfig | None = None, + optimization_config: OptimizationConfig | None = None, ) -> tuple[SearchSpaceDigest, TorchOptConfig]: # Validation if not self.parameters: @@ -900,9 +901,9 @@ def _untransform_objective_thresholds( self, objective_thresholds: Tensor, objective_weights: Tensor, - bounds: list[tuple[Union[int, float], Union[int, float]]], + bounds: list[tuple[int | float, int | float]], opt_config_metrics: dict[str, Metric], - fixed_features: Optional[dict[int, float]], + fixed_features: dict[int, float] | None, ) -> list[ObjectiveThreshold]: thresholds_np = objective_thresholds.cpu().numpy() idxs = objective_weights.nonzero().view(-1).tolist() diff --git a/ax/modelbridge/transforms/base.py b/ax/modelbridge/transforms/base.py index 9f1b7e6212a..01ebe75a21f 100644 --- a/ax/modelbridge/transforms/base.py +++ b/ax/modelbridge/transforms/base.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from ax.core.observation import ( Observation, @@ -53,14 +53,14 @@ class Transform: """ config: TConfig - modelbridge: Optional[modelbridge_module.base.ModelBridge] + modelbridge: modelbridge_module.base.ModelBridge | None def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: """Do any initial computations for preparing the transform. @@ -109,8 +109,8 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - fixed_features: Optional[ObservationFeatures] = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: """Transform optimization config. @@ -241,7 +241,7 @@ def _untransform_observation_data( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: """Untransform outcome constraints. 
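# Note: several transforms below (e.g. cap_parameter.py, cast.py) keep
# `Optional["modelbridge_module.base.ModelBridge"]` with the quoted forward
# reference, while modules using `from __future__ import annotations` (such as
# base.py above and log_y.py below) do get `... | None`. Presumably the codemod
# is conservative here because, without postponed evaluation, the PEP 604
# spelling is evaluated eagerly and a plain string has no `__or__`.
# Demonstration with an invented class:
from typing import Optional

class Later: ...

ok = Optional["Later"]  # typing special forms accept string forward refs
try:
    bad = "Later" | None  # eager `|` on a quoted name raises
except TypeError:
    pass  # unsupported operand type(s) for |: 'str' and 'NoneType'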
diff --git a/ax/modelbridge/transforms/cap_parameter.py b/ax/modelbridge/transforms/cap_parameter.py index a28154c3054..bbfb885e946 100644 --- a/ax/modelbridge/transforms/cap_parameter.py +++ b/ax/modelbridge/transforms/cap_parameter.py @@ -29,10 +29,10 @@ class CapParameter(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: # pyre-fixme[4]: Attribute must be annotated. self.config = config or {} diff --git a/ax/modelbridge/transforms/cast.py b/ax/modelbridge/transforms/cast.py index 9b37bbc6172..45ac255c321 100644 --- a/ax/modelbridge/transforms/cast.py +++ b/ax/modelbridge/transforms/cast.py @@ -42,10 +42,10 @@ class Cast(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: self.search_space: SearchSpace = not_none(search_space).clone() config = (config or {}).copy() diff --git a/ax/modelbridge/transforms/choice_encode.py b/ax/modelbridge/transforms/choice_encode.py index 027d35dfb2b..6a885aea772 100644 --- a/ax/modelbridge/transforms/choice_encode.py +++ b/ax/modelbridge/transforms/choice_encode.py @@ -52,10 +52,10 @@ class ChoiceToNumericChoice(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "ChoiceToNumericChoice requires search space" # Identify parameters that should be transformed @@ -151,7 +151,7 @@ def __init__( search_space: SearchSpace, observations: list[Observation], modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: # Identify parameters that should be transformed self.encoded_parameters: dict[str, dict[TParamValue, int]] = {} diff --git a/ax/modelbridge/transforms/convert_metric_names.py b/ax/modelbridge/transforms/convert_metric_names.py index a8097e0bce7..66852af075a 100644 --- a/ax/modelbridge/transforms/convert_metric_names.py +++ b/ax/modelbridge/transforms/convert_metric_names.py @@ -39,10 +39,10 @@ class ConvertMetricNames(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "ConvertMetricNames requires observations" if config is None: diff --git a/ax/modelbridge/transforms/derelativize.py b/ax/modelbridge/transforms/derelativize.py index c9d33a6b4aa..8d9364ae187 100644 --- a/ax/modelbridge/transforms/derelativize.py +++ 
b/ax/modelbridge/transforms/derelativize.py @@ -48,7 +48,7 @@ def transform_optimization_config( self, optimization_config: OptimizationConfig, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: use_raw_sq = self.config.get("use_raw_status_quo", False) has_relative_constraint = any( @@ -112,7 +112,7 @@ def transform_optimization_config( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: # We intentionally leave outcome constraints derelativized when # untransforming. diff --git a/ax/modelbridge/transforms/int_range_to_choice.py b/ax/modelbridge/transforms/int_range_to_choice.py index 3314d672954..1ec3959f984 100644 --- a/ax/modelbridge/transforms/int_range_to_choice.py +++ b/ax/modelbridge/transforms/int_range_to_choice.py @@ -29,10 +29,10 @@ class IntRangeToChoice(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "IntRangeToChoice requires search space" config = config or {} diff --git a/ax/modelbridge/transforms/int_to_float.py b/ax/modelbridge/transforms/int_to_float.py index e625c4da7cc..7e75d07b652 100644 --- a/ax/modelbridge/transforms/int_to_float.py +++ b/ax/modelbridge/transforms/int_to_float.py @@ -48,10 +48,10 @@ class IntToFloat(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: self.search_space: SearchSpace = not_none( search_space, "IntToFloat requires search space" diff --git a/ax/modelbridge/transforms/inverse_gaussian_cdf_y.py b/ax/modelbridge/transforms/inverse_gaussian_cdf_y.py index fab074b4632..398b53cb21b 100644 --- a/ax/modelbridge/transforms/inverse_gaussian_cdf_y.py +++ b/ax/modelbridge/transforms/inverse_gaussian_cdf_y.py @@ -35,10 +35,10 @@ class InverseGaussianCdfY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["base_modelbridge.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: # pyre-fixme[4]: Attribute must be annotated. 
self.dist = norm(loc=0, scale=1) diff --git a/ax/modelbridge/transforms/log.py b/ax/modelbridge/transforms/log.py index 1f2c61d2c43..0e344bf84ca 100644 --- a/ax/modelbridge/transforms/log.py +++ b/ax/modelbridge/transforms/log.py @@ -28,10 +28,10 @@ class Log(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "Log requires search space" # Identify parameters that should be transformed diff --git a/ax/modelbridge/transforms/log_y.py b/ax/modelbridge/transforms/log_y.py index 1f07e9ada44..f92173147a1 100644 --- a/ax/modelbridge/transforms/log_y.py +++ b/ax/modelbridge/transforms/log_y.py @@ -8,9 +8,11 @@ from __future__ import annotations +from collections.abc import Callable + from logging import Logger -from typing import Callable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -45,10 +47,10 @@ class LogY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional["base_modelbridge.ModelBridge"] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: base_modelbridge.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: if config is None: raise ValueError("LogY requires a config.") @@ -79,8 +81,8 @@ def __init__( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[base_modelbridge.ModelBridge] = None, - fixed_features: Optional[ObservationFeatures] = None, + modelbridge: base_modelbridge.ModelBridge | None = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: if c.metric.name in self.metric_names: @@ -144,7 +146,7 @@ def _untransform_observation_data( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: for c in outcome_constraints: if c.metric.name in self.metric_names: diff --git a/ax/modelbridge/transforms/logit.py b/ax/modelbridge/transforms/logit.py index 637ddea0395..3c3e4ddb416 100644 --- a/ax/modelbridge/transforms/logit.py +++ b/ax/modelbridge/transforms/logit.py @@ -28,10 +28,10 @@ class Logit(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "Logit requires search space" # Identify parameters that should be transformed diff --git a/ax/modelbridge/transforms/map_unit_x.py b/ax/modelbridge/transforms/map_unit_x.py index 50807ed6884..881e44c23cb 100644 --- a/ax/modelbridge/transforms/map_unit_x.py +++ b/ax/modelbridge/transforms/map_unit_x.py @@ -10,7 +10,7 @@ from 
collections import defaultdict -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from ax.core.observation import Observation, ObservationFeatures from ax.core.search_space import SearchSpace @@ -34,10 +34,10 @@ class MapUnitX(UnitX): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "MapUnitX requires observations" assert search_space is not None, "MapUnitX requires search space" diff --git a/ax/modelbridge/transforms/merge_repeated_measurements.py b/ax/modelbridge/transforms/merge_repeated_measurements.py index d7be3a5f9d6..f406520a800 100644 --- a/ax/modelbridge/transforms/merge_repeated_measurements.py +++ b/ax/modelbridge/transforms/merge_repeated_measurements.py @@ -10,7 +10,6 @@ from collections import defaultdict from copy import deepcopy -from typing import Optional import numpy as np from ax.core.arm import Arm @@ -35,10 +34,10 @@ class MergeRepeatedMeasurements(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[ModelBridge] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: ModelBridge | None = None, + config: TConfig | None = None, ) -> None: if observations is None: raise RuntimeError("MergeRepeatedMeasurements requires observations") diff --git a/ax/modelbridge/transforms/metrics_as_task.py b/ax/modelbridge/transforms/metrics_as_task.py index 4664d3b2486..f3473366853 100644 --- a/ax/modelbridge/transforms/metrics_as_task.py +++ b/ax/modelbridge/transforms/metrics_as_task.py @@ -41,10 +41,10 @@ class MetricsAsTask(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: # Use config to specify metric task map if config is None or "metric_task_map" not in config: diff --git a/ax/modelbridge/transforms/one_hot.py b/ax/modelbridge/transforms/one_hot.py index 35c95da69ef..aab07d0bf7a 100644 --- a/ax/modelbridge/transforms/one_hot.py +++ b/ax/modelbridge/transforms/one_hot.py @@ -87,10 +87,10 @@ class OneHot(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "OneHot requires search space" # Identify parameters that should be transformed @@ -111,8 +111,7 @@ def __init__( self.encoded_parameters[p.name] = [p.name + OH_PARAM_INFIX] else: self.encoded_parameters[p.name] = [ - "{}{}_{}".format(p.name, OH_PARAM_INFIX, i) - for i in range(encoded_len) + f"{p.name}{OH_PARAM_INFIX}_{i}" for i in 
range(encoded_len) ] def transform_observation_features( diff --git a/ax/modelbridge/transforms/percentile_y.py b/ax/modelbridge/transforms/percentile_y.py index 237033e83e0..a73d2562756 100644 --- a/ax/modelbridge/transforms/percentile_y.py +++ b/ax/modelbridge/transforms/percentile_y.py @@ -35,10 +35,10 @@ class PercentileY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "PercentileY requires observations" if len(observations) == 0: diff --git a/ax/modelbridge/transforms/power_transform_y.py b/ax/modelbridge/transforms/power_transform_y.py index 5b1ece7a2cb..71591a5dc53 100644 --- a/ax/modelbridge/transforms/power_transform_y.py +++ b/ax/modelbridge/transforms/power_transform_y.py @@ -10,7 +10,7 @@ from collections import defaultdict from logging import Logger -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -52,10 +52,10 @@ class PowerTransformY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "PowerTransformY requires observations" if config is None: @@ -120,8 +120,8 @@ def _untransform_observation_data( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - fixed_features: Optional[ObservationFeatures] = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: if isinstance(c, ScalarizedOutcomeConstraint): @@ -146,7 +146,7 @@ def transform_optimization_config( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: for c in outcome_constraints: if isinstance(c, ScalarizedOutcomeConstraint): diff --git a/ax/modelbridge/transforms/relativize.py b/ax/modelbridge/transforms/relativize.py index 8b9a2b66b9b..c1a298e0681 100644 --- a/ax/modelbridge/transforms/relativize.py +++ b/ax/modelbridge/transforms/relativize.py @@ -9,9 +9,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from math import sqrt -from typing import Callable, Optional, Tuple, TYPE_CHECKING, Union +from typing import TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -49,10 +50,10 @@ class BaseRelativize(Transform, ABC): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = 
None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: cls_name = self.__class__.__name__ assert observations is not None, f"{cls_name} requires observations" @@ -84,8 +85,8 @@ def control_as_constant(self) -> bool: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - fixed_features: Optional[ObservationFeatures] = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: r""" Change the relative flag of the given relative optimization configuration @@ -144,7 +145,7 @@ def transform_optimization_config( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: for c in outcome_constraints: c.relative = True @@ -253,7 +254,7 @@ def _get_rel_mean_sem( sem_c: float, metric: str, rel_op: Callable[..., tuple[np.ndarray, np.ndarray]], - ) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray]]: + ) -> tuple[float | np.ndarray, float | np.ndarray]: """Compute (un)relativized mean and sem for a single metric.""" # if the is the status quo if means_t == mean_c and sems_t == sem_c: diff --git a/ax/modelbridge/transforms/remove_fixed.py b/ax/modelbridge/transforms/remove_fixed.py index 730a2abd961..acb0c7aebbc 100644 --- a/ax/modelbridge/transforms/remove_fixed.py +++ b/ax/modelbridge/transforms/remove_fixed.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING from ax.core.observation import Observation, ObservationFeatures from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter @@ -31,10 +31,10 @@ class RemoveFixed(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "RemoveFixed requires search space" # Identify parameters that should be transformed @@ -53,14 +53,14 @@ def transform_observation_features( return observation_features def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: - tunable_parameters: list[Union[ChoiceParameter, RangeParameter]] = [] + tunable_parameters: list[ChoiceParameter | RangeParameter] = [] for p in search_space.parameters.values(): if p.name not in self.fixed_parameters: # If it's not in fixed_parameters, it must be a tunable param. # pyre: p_ is declared to have type `Union[ChoiceParameter, # pyre: RangeParameter]` but is used as type `ax.core. # pyre-fixme[9]: parameter.Parameter`. 
- p_: Union[ChoiceParameter, RangeParameter] = p + p_: ChoiceParameter | RangeParameter = p tunable_parameters.append(p_) return construct_new_search_space( search_space=search_space, diff --git a/ax/modelbridge/transforms/search_space_to_choice.py b/ax/modelbridge/transforms/search_space_to_choice.py index 673d8ac62ad..462f15abf56 100644 --- a/ax/modelbridge/transforms/search_space_to_choice.py +++ b/ax/modelbridge/transforms/search_space_to_choice.py @@ -36,10 +36,10 @@ class SearchSpaceToChoice(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "SearchSpaceToChoice requires search space" assert observations is not None, "SearchSpaceToChoice requires observations" diff --git a/ax/modelbridge/transforms/standardize_y.py b/ax/modelbridge/transforms/standardize_y.py index 695f3f7fcd7..ce7fc936601 100644 --- a/ax/modelbridge/transforms/standardize_y.py +++ b/ax/modelbridge/transforms/standardize_y.py @@ -8,7 +8,7 @@ from collections import defaultdict from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -39,10 +39,10 @@ class StandardizeY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["base_modelbridge.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: if observations is None or len(observations) == 0: raise DataRequiredError("`StandardizeY` transform requires non-empty data.") @@ -69,7 +69,7 @@ def transform_optimization_config( self, optimization_config: OptimizationConfig, modelbridge: Optional["base_modelbridge.ModelBridge"] = None, - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: if c.relative: @@ -131,7 +131,7 @@ def _untransform_observation_data( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: for c in outcome_constraints: if c.relative: @@ -147,10 +147,8 @@ def untransform_outcome_constraints( def compute_standardization_parameters( - Ys: defaultdict[Union[str, tuple[str, TParamValue]], list[float]] -) -> tuple[ - dict[Union[str, tuple[str, str]], float], dict[Union[str, tuple[str, str]], float] -]: + Ys: defaultdict[str | tuple[str, TParamValue], list[float]] +) -> tuple[dict[str | tuple[str, str], float], dict[str | tuple[str, str], float]]: """Compute mean and std.
dev of Ys.""" Ymean = {k: np.mean(y) for k, y in Ys.items()} # We use the Bessel correction term (divide by N-1) here in order to diff --git a/ax/modelbridge/transforms/stratified_standardize_y.py b/ax/modelbridge/transforms/stratified_standardize_y.py index 7bfea695b56..82635601611 100644 --- a/ax/modelbridge/transforms/stratified_standardize_y.py +++ b/ax/modelbridge/transforms/stratified_standardize_y.py @@ -8,7 +8,7 @@ from collections import defaultdict from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationFeatures, separate_observations @@ -49,10 +49,10 @@ class StratifiedStandardizeY(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: """Initialize StratifiedStandardizeY. @@ -83,7 +83,7 @@ def __init__( if "strata_mapping" in config: # pyre-ignore [8] self.strata_mapping: dict[ - Union[bool, float, int, str], Union[bool, float, int, str] + bool | float | int | str, bool | float | int | str ] = config["strata_mapping"] if set(strat_p.values) != set(self.strata_mapping.keys()): raise ValueError( @@ -149,7 +149,7 @@ def transform_optimization_config( self, optimization_config: OptimizationConfig, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: if len(optimization_config.all_constraints) == 0: return optimization_config @@ -187,7 +187,7 @@ def untransform_observations( def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: if fixed_features is None or self.p_name not in fixed_features.parameters: raise ValueError( diff --git a/ax/modelbridge/transforms/task_encode.py b/ax/modelbridge/transforms/task_encode.py index d32b0b870b4..e1e2111c445 100644 --- a/ax/modelbridge/transforms/task_encode.py +++ b/ax/modelbridge/transforms/task_encode.py @@ -39,10 +39,10 @@ class TaskChoiceToIntTaskChoice(OrderedChoiceToIntegerRange): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert ( search_space is not None diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index 778c35544ae..982532e2919 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -6,7 +6,6 @@ # pyre-strict from copy import deepcopy -from typing import List, Tuple, Type from unittest.mock import Mock import numpy as np @@ -45,11 +44,11 @@ class RelativizeDataTest(TestCase): - relativize_classes: List[Type[Transform]] = [ + relativize_classes: list[type[Transform]] = [ Relativize, 
RelativizeWithConstantControl, ] - cases: List[Tuple[Type[Transform], List[Tuple[np.ndarray, np.ndarray]]]] = [ + cases: list[tuple[type[Transform], list[tuple[np.ndarray, np.ndarray]]]] = [ ( Relativize, [ diff --git a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py index f74ff7e7043..2f96114a008 100644 --- a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py +++ b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import List, Tuple, Type import numpy as np from ax.core.batch_trial import BatchTrial @@ -29,7 +28,7 @@ class TransformToNewSQTest(RelativizeDataTest): # [Type[TransformToNewSQ]]` is not a subtype of the # overridden attribute `List[Type[Transform]]` relativize_classes = [TransformToNewSQ] - cases: List[Tuple[Type[Transform], List[Tuple[np.ndarray, np.ndarray]]]] = [ + cases: list[tuple[type[Transform], list[tuple[np.ndarray, np.ndarray]]]] = [ ( TransformToNewSQ, [ diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py index e327de99ba6..62a18484720 100644 --- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py +++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py @@ -8,7 +8,7 @@ import warnings from copy import deepcopy -from typing import Any, Optional +from typing import Any, SupportsIndex from unittest import mock import numpy as np @@ -44,7 +44,6 @@ get_observations_with_invalid_value, get_optimization_config, ) -from typing_extensions import SupportsIndex INF = float("inf") OBSERVATION_DATA = [ @@ -684,7 +683,7 @@ def get_transform(observation_data, config=None, optimization_config=None) -> Wi def get_default_transform_cutoffs( optimization_config: OptimizationConfig, - winsorization_config: Optional[dict[str, WinsorizationConfig]] = None, + winsorization_config: dict[str, WinsorizationConfig] | None = None, obs_data_len: SupportsIndex = 6, ) -> dict[str, tuple[float, float]]: obsd = ObservationData( diff --git a/ax/modelbridge/transforms/time_as_feature.py b/ax/modelbridge/transforms/time_as_feature.py index 7104bcb2c26..f1fe6501b09 100644 --- a/ax/modelbridge/transforms/time_as_feature.py +++ b/ax/modelbridge/transforms/time_as_feature.py @@ -44,10 +44,10 @@ class TimeAsFeature(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "TimeAsFeature requires observations" if isinstance(search_space, RobustSearchSpace): @@ -77,9 +77,7 @@ def __init__( # no need to case-distinguish during normalization self.duration_range = 1.0 - def _get_duration( - self, start_time: float, end_time: Optional[pd.Timestamp] - ) -> float: + def _get_duration(self, start_time: float, end_time: pd.Timestamp | None) -> float: return ( self.current_time if end_time is None else end_time.timestamp() ) - start_time diff --git a/ax/modelbridge/transforms/transform_to_new_sq.py b/ax/modelbridge/transforms/transform_to_new_sq.py index 7d609eb0ebc..623d710974e 100644 --- a/ax/modelbridge/transforms/transform_to_new_sq.py +++ b/ax/modelbridge/transforms/transform_to_new_sq.py @@ -8,8 +8,10 @@ from __future__ import annotations 
+from collections.abc import Callable + from math import sqrt -from typing import Callable, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -41,10 +43,10 @@ class TransformToNewSQ(BaseRelativize): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - config: Optional[TConfig] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + config: TConfig | None = None, ) -> None: super().__init__( search_space=search_space, @@ -69,15 +71,15 @@ def control_as_constant(self) -> bool: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, - fixed_features: Optional[ObservationFeatures] = None, + modelbridge: modelbridge_module.base.ModelBridge | None = None, + fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: return optimization_config def untransform_outcome_constraints( self, outcome_constraints: list[OutcomeConstraint], - fixed_features: Optional[ObservationFeatures] = None, + fixed_features: ObservationFeatures | None = None, ) -> list[OutcomeConstraint]: return outcome_constraints @@ -170,7 +172,7 @@ def _get_rel_mean_sem( sem_c: float, metric: str, rel_op: Callable[..., tuple[np.ndarray, np.ndarray]], - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Compute (un)transformed mean and sem for a single metric.""" target_status_quo_data = self.status_quo_data_by_trial[self.default_trial_idx] j = get_metric_index(data=target_status_quo_data, metric_name=metric) diff --git a/ax/modelbridge/transforms/trial_as_task.py b/ax/modelbridge/transforms/trial_as_task.py index 1dcdfdeba0d..f7c6e357c07 100644 --- a/ax/modelbridge/transforms/trial_as_task.py +++ b/ax/modelbridge/transforms/trial_as_task.py @@ -7,7 +7,7 @@ # pyre-strict from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING from ax.core.observation import Observation, ObservationFeatures from ax.core.parameter import ChoiceParameter, ParameterType @@ -54,10 +54,10 @@ class TrialAsTask(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert observations is not None, "TrialAsTask requires observations" # Identify values of trial. 
@@ -74,11 +74,11 @@ def __init__( # Get trial level map if config is not None and "trial_level_map" in config: # pyre-ignore [9] - trial_level_map: dict[str, dict[Union[int, str], Union[int, str]]] = config[ + trial_level_map: dict[str, dict[int | str, int | str]] = config[ "trial_level_map" ] # Validate - self.trial_level_map: dict[str, dict[int, Union[int, str]]] = {} + self.trial_level_map: dict[str, dict[int, int | str]] = {} for _p_name, level_dict in trial_level_map.items(): # cast trial index as an integer int_keyed_level_dict = { @@ -99,13 +99,13 @@ def __init__( self.trial_level_map = {TRIAL_PARAM: {int(b): str(b) for b in trials}} if len(self.trial_level_map) == 1: level_dict = next(iter(self.trial_level_map.values())) - self.inverse_map: Optional[dict[Union[int, str], int]] = { + self.inverse_map: dict[int | str, int] | None = { v: k for k, v in level_dict.items() } else: self.inverse_map = None # Compute target values - self.target_values: dict[str, Union[int, str]] = {} + self.target_values: dict[str, int | str] = {} for p_name, trial_map in self.trial_level_map.items(): if config is not None and "target_trial" in config: target_trial = int(config["target_trial"]) # pyre-ignore [6] diff --git a/ax/modelbridge/transforms/unit_x.py b/ax/modelbridge/transforms/unit_x.py index b7f5eabf00f..43b24025b12 100644 --- a/ax/modelbridge/transforms/unit_x.py +++ b/ax/modelbridge/transforms/unit_x.py @@ -38,10 +38,10 @@ class UnitX(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: assert search_space is not None, "UnitX requires search space" # Identify parameters that should be transformed diff --git a/ax/modelbridge/transforms/utils.py b/ax/modelbridge/transforms/utils.py index 1f705b9cf32..8b961cb26f5 100644 --- a/ax/modelbridge/transforms/utils.py +++ b/ax/modelbridge/transforms/utils.py @@ -9,9 +9,10 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable from math import isnan from numbers import Number -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures @@ -67,7 +68,7 @@ def __getitem__(self, key: Number) -> Any: def get_data( observation_data: list[ObservationData], - metric_names: Union[list[str], None] = None, + metric_names: list[str] | None = None, raise_on_non_finite_data: bool = True, ) -> dict[str, list[float]]: """Extract all metrics if `metric_names` is None. @@ -122,7 +123,7 @@ def match_ci_width_truncated( def construct_new_search_space( search_space: SearchSpace, parameters: list[Parameter], - parameter_constraints: Optional[list[ParameterConstraint]] = None, + parameter_constraints: list[ParameterConstraint] | None = None, ) -> SearchSpace: """Construct a search space with the transformed arguments. 
@@ -154,8 +155,8 @@ def construct_new_search_space( def derelativize_optimization_config_with_raw_status_quo( optimization_config: OptimizationConfig, - modelbridge: "modelbridge_module.base.ModelBridge", - observations: Optional[list[Observation]], + modelbridge: modelbridge_module.base.ModelBridge, + observations: list[Observation] | None, ) -> OptimizationConfig: """Derelativize optimization_config using raw status-quo values""" tf = Derelativize( diff --git a/ax/modelbridge/transforms/winsorize.py b/ax/modelbridge/transforms/winsorize.py index 0424755ae76..674245a26db 100644 --- a/ax/modelbridge/transforms/winsorize.py +++ b/ax/modelbridge/transforms/winsorize.py @@ -8,7 +8,7 @@ import warnings from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING import numpy as np from ax.core.objective import MultiObjective, ScalarizedObjective @@ -92,10 +92,10 @@ class Winsorize(Transform): def __init__( self, - search_space: Optional[SearchSpace] = None, - observations: Optional[list[Observation]] = None, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, - config: Optional[TConfig] = None, + config: TConfig | None = None, ) -> None: if observations is None or len(observations) == 0: raise DataRequiredError("`Winsorize` transform requires non-empty data.") @@ -171,10 +171,10 @@ def _transform_observation_data( def _get_cutoffs( metric_name: str, metric_values: list[float], - winsorization_config: Union[WinsorizationConfig, dict[str, WinsorizationConfig]], + winsorization_config: WinsorizationConfig | dict[str, WinsorizationConfig], modelbridge: Optional["modelbridge_module.base.ModelBridge"], - observations: Optional[list[Observation]], - optimization_config: Optional[OptimizationConfig], + observations: list[Observation] | None, + optimization_config: OptimizationConfig | None, use_raw_sq: bool, ) -> tuple[float, float]: # (1) Use the same config for all metrics if one WinsorizationConfig was specified @@ -373,7 +373,7 @@ def _get_auto_winsorization_cutoffs_single_objective( def _get_auto_winsorization_cutoffs_outcome_constraint( metric_values: list[float], - outcome_constraints: Union[list[ObjectiveThreshold], list[OutcomeConstraint]], + outcome_constraints: list[ObjectiveThreshold] | list[OutcomeConstraint], ) -> tuple[float, float]: """Automatic winsorization to an outcome constraint. diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index d8f5248dc97..2e5b82d55f4 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -7,7 +7,6 @@ from abc import abstractmethod from logging import Logger -from typing import Optional from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment @@ -48,14 +47,14 @@ class TransitionCriterion(SortableBase, SerializationMixin): create a given ``BatchTrial``. 
""" - _transition_to: Optional[str] = None + _transition_to: str | None = None def __init__( self, - transition_to: Optional[str] = None, - block_transition_if_unmet: Optional[bool] = True, - block_gen_if_met: Optional[bool] = False, - continue_trial_generation: Optional[bool] = False, + transition_to: str | None = None, + block_transition_if_unmet: bool | None = True, + block_gen_if_met: bool | None = False, + continue_trial_generation: bool | None = False, ) -> None: self._transition_to = transition_to self.block_transition_if_unmet = block_transition_if_unmet @@ -63,7 +62,7 @@ def __init__( self.continue_trial_generation = continue_trial_generation @property - def transition_to(self) -> Optional[str]: + def transition_to(self) -> str | None: """The name of the next GenerationNode after this TransitionCriterion is completed, if it exists. """ @@ -73,9 +72,9 @@ def transition_to(self) -> Optional[str]: def is_met( self, experiment: Experiment, - trials_from_node: Optional[set[int]] = None, - node_that_generated_last_gr: Optional[str] = None, - curr_node_name: Optional[str] = None, + trials_from_node: set[int] | None = None, + node_that_generated_last_gr: str | None = None, + curr_node_name: str | None = None, ) -> bool: """If the criterion of this TransitionCriterion is met, returns True.""" pass @@ -83,10 +82,10 @@ def is_met( @abstractmethod def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: """Error to be raised if the `block_gen_if_met` flag is set to True.""" pass @@ -127,8 +126,8 @@ class AutoTransitionAfterGen(TransitionCriterion): def __init__( self, transition_to: str, - block_transition_if_unmet: Optional[bool] = True, - continue_trial_generation: Optional[bool] = True, + block_transition_if_unmet: bool | None = True, + continue_trial_generation: bool | None = True, ) -> None: super().__init__( transition_to=transition_to, @@ -139,9 +138,9 @@ def __init__( def is_met( self, experiment: Experiment, - trials_from_node: Optional[set[int]] = None, - node_that_generated_last_gr: Optional[str] = None, - curr_node_name: Optional[str] = None, + trials_from_node: set[int] | None = None, + node_that_generated_last_gr: str | None = None, + curr_node_name: str | None = None, ) -> bool: """Return True as soon as any GeneratorRun is generated by this GenerationNode. 
@@ -150,10 +149,10 @@ def is_met( def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: """Error to be raised if the `block_gen_if_met` flag is set to True.""" pass @@ -194,13 +193,13 @@ class TrialBasedCriterion(TransitionCriterion): def __init__( self, threshold: int, - block_transition_if_unmet: Optional[bool] = True, - block_gen_if_met: Optional[bool] = False, - only_in_statuses: Optional[list[TrialStatus]] = None, - not_in_statuses: Optional[list[TrialStatus]] = None, - transition_to: Optional[str] = None, - use_all_trials_in_exp: Optional[bool] = False, - continue_trial_generation: Optional[bool] = False, + block_transition_if_unmet: bool | None = True, + block_gen_if_met: bool | None = False, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + transition_to: str | None = None, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = False, ) -> None: self.threshold = threshold self.only_in_statuses = only_in_statuses @@ -252,7 +251,7 @@ def all_trials_to_check(self, experiment: Experiment) -> set[int]: return trials_to_check def num_contributing_to_threshold( - self, experiment: Experiment, trials_from_node: Optional[set[int]] + self, experiment: Experiment, trials_from_node: set[int] | None ) -> int: """Returns the number of trials contributing to the threshold. @@ -275,7 +274,7 @@ def num_contributing_to_threshold( return len(trials_from_node.intersection(all_trials_to_check)) def num_till_threshold( - self, experiment: Experiment, trials_from_node: Optional[set[int]] + self, experiment: Experiment, trials_from_node: set[int] | None ) -> int: """Returns the number of trials needed to meet the threshold. @@ -290,10 +289,10 @@ def num_till_threshold( def is_met( self, experiment: Experiment, - trials_from_node: Optional[set[int]] = None, - block_continued_generation: Optional[bool] = False, - node_that_generated_last_gr: Optional[str] = None, - curr_node_name: Optional[str] = None, + trials_from_node: set[int] | None = None, + block_continued_generation: bool | None = False, + node_that_generated_last_gr: str | None = None, + curr_node_name: str | None = None, ) -> bool: """Returns if this criterion has been met given its constraints. 
Args: @@ -356,13 +355,13 @@ class MaxGenerationParallelism(TrialBasedCriterion): def __init__( self, threshold: int, - only_in_statuses: Optional[list[TrialStatus]] = None, - not_in_statuses: Optional[list[TrialStatus]] = None, - transition_to: Optional[str] = None, - block_transition_if_unmet: Optional[bool] = False, - block_gen_if_met: Optional[bool] = True, - use_all_trials_in_exp: Optional[bool] = False, - continue_trial_generation: Optional[bool] = True, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + transition_to: str | None = None, + block_transition_if_unmet: bool | None = False, + block_gen_if_met: bool | None = True, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = True, ) -> None: super().__init__( threshold=threshold, @@ -377,10 +376,10 @@ def __init__( def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: """If the block_continued_generation flag is set, raises the MaxParallelismReachedException error. @@ -438,13 +437,13 @@ class MaxTrials(TrialBasedCriterion): def __init__( self, threshold: int, - only_in_statuses: Optional[list[TrialStatus]] = None, - not_in_statuses: Optional[list[TrialStatus]] = None, - transition_to: Optional[str] = None, - block_transition_if_unmet: Optional[bool] = True, - block_gen_if_met: Optional[bool] = False, - use_all_trials_in_exp: Optional[bool] = False, - continue_trial_generation: Optional[bool] = False, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + transition_to: str | None = None, + block_transition_if_unmet: bool | None = True, + block_gen_if_met: bool | None = False, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = False, ) -> None: super().__init__( threshold=threshold, @@ -459,10 +458,10 @@ def __init__( def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: """If the block_continued_generation flag is set, raises an error because the remaining TransitionCriterion cannot be completed in the current state. 
@@ -514,13 +513,13 @@ class MinTrials(TrialBasedCriterion): def __init__( self, threshold: int, - only_in_statuses: Optional[list[TrialStatus]] = None, - not_in_statuses: Optional[list[TrialStatus]] = None, - transition_to: Optional[str] = None, - block_transition_if_unmet: Optional[bool] = True, - block_gen_if_met: Optional[bool] = False, - use_all_trials_in_exp: Optional[bool] = False, - continue_trial_generation: Optional[bool] = False, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + transition_to: str | None = None, + block_transition_if_unmet: bool | None = True, + block_gen_if_met: bool | None = False, + use_all_trials_in_exp: bool | None = False, + continue_trial_generation: bool | None = False, ) -> None: super().__init__( threshold=threshold, @@ -535,10 +534,10 @@ def __init__( def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: """If the enforce flag is set, raises an error because the remaining TransitionCriterion cannot be completed in the current state. @@ -578,9 +577,9 @@ def __init__( self, metric_name: str, threshold: int, - transition_to: Optional[str] = None, - block_gen_if_met: Optional[bool] = False, - block_transition_if_unmet: Optional[bool] = True, + transition_to: str | None = None, + block_gen_if_met: bool | None = False, + block_transition_if_unmet: bool | None = True, ) -> None: self.metric_name = metric_name self.threshold = threshold @@ -593,9 +592,9 @@ def __init__( def is_met( self, experiment: Experiment, - trials_from_node: Optional[set[int]] = None, - node_that_generated_last_gr: Optional[str] = None, - curr_node_name: Optional[str] = None, + trials_from_node: set[int] | None = None, + node_that_generated_last_gr: str | None = None, + curr_node_name: str | None = None, ) -> bool: # TODO: @mgarrard replace fetch_data with lookup_data data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) @@ -607,10 +606,10 @@ def is_met( def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: pass @@ -625,7 +624,7 @@ def __init__( self, status: TrialStatus, threshold: int, - transition_to: Optional[str] = None, + transition_to: str | None = None, ) -> None: self.status = status self.threshold = threshold @@ -634,17 +633,17 @@ def __init__( def is_met( self, experiment: Experiment, - trials_from_node: Optional[set[int]] = None, - node_that_generated_last_gr: Optional[str] = None, - curr_node_name: Optional[str] = None, + trials_from_node: set[int] | None = None, + node_that_generated_last_gr: str | None = None, + curr_node_name: str | None = None, ) -> bool: return len(experiment.trial_indices_by_status[self.status]) >= self.threshold def block_continued_generation_error( self, - node_name: Optional[str], - model_name: Optional[str], - experiment: Optional[Experiment], - trials_from_node: Optional[set[int]] = None, + node_name: str | None, + model_name: str | None, + experiment: Experiment | None, + trials_from_node: set[int] | None = None, ) -> None: pass diff --git 
a/ax/models/discrete/full_factorial.py b/ax/models/discrete/full_factorial.py index a1eb8286f1a..3246c5274e7 100644 --- a/ax/models/discrete/full_factorial.py +++ b/ax/models/discrete/full_factorial.py @@ -10,7 +10,6 @@ import logging from functools import reduce from operator import mul -from typing import Optional import numpy as np from ax.core.types import TGenMetadata, TParamValue, TParamValueList @@ -53,11 +52,11 @@ def gen( self, n: int, parameter_values: list[TParamValueList], - objective_weights: Optional[np.ndarray], - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, TParamValue]] = None, - pending_observations: Optional[list[list[TParamValueList]]] = None, - model_gen_options: Optional[TConfig] = None, + objective_weights: np.ndarray | None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, TParamValue] | None = None, + pending_observations: list[list[TParamValueList]] | None = None, + model_gen_options: TConfig | None = None, ) -> tuple[list[TParamValueList], list[float], TGenMetadata]: if n != -1: logger.warning( diff --git a/ax/models/discrete/thompson.py b/ax/models/discrete/thompson.py index 08f3e81c8d1..103180a95ff 100644 --- a/ax/models/discrete/thompson.py +++ b/ax/models/discrete/thompson.py @@ -8,7 +8,6 @@ import hashlib import json -from typing import Optional import numpy as np from ax.core.types import TGenMetadata, TParamValue, TParamValueList @@ -30,7 +29,7 @@ class ThompsonSampler(DiscreteModel): def __init__( self, num_samples: int = 10000, - min_weight: Optional[float] = None, + min_weight: float | None = None, uniform_weights: bool = False, ) -> None: """ @@ -77,11 +76,11 @@ def gen( self, n: int, parameter_values: list[TParamValueList], - objective_weights: Optional[np.ndarray], - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, TParamValue]] = None, - pending_observations: Optional[list[list[TParamValueList]]] = None, - model_gen_options: Optional[TConfig] = None, + objective_weights: np.ndarray | None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, TParamValue] | None = None, + pending_observations: list[list[TParamValueList]] | None = None, + model_gen_options: TConfig | None = None, ) -> tuple[list[TParamValueList], list[float], TGenMetadata]: if objective_weights is None: raise ValueError("ThompsonSampler requires objective weights.") @@ -143,7 +142,7 @@ def predict(self, X: list[TParamValueList]) -> tuple[np.ndarray, np.ndarray]: def _generate_weights( self, objective_weights: np.ndarray, - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None = None, ) -> list[float]: samples, fraction_all_infeasible = self._produce_samples( num_samples=self.num_samples, @@ -192,7 +191,7 @@ def _produce_samples( self, num_samples: int, objective_weights: np.ndarray, - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]], + outcome_constraints: tuple[np.ndarray, np.ndarray] | None, ) -> tuple[np.ndarray, float]: k = len(self.X) samples_per_metric = self._generate_samples_per_metric(num_samples=num_samples) diff --git a/ax/models/discrete_base.py b/ax/models/discrete_base.py index 95f7f1dfa5d..a5bdadd1f9c 100644 --- a/ax/models/discrete_base.py +++ b/ax/models/discrete_base.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional import numpy as np from ax.core.types import 
TGenMetadata, TParamValue, TParamValueList @@ -62,11 +61,11 @@ def gen( self, n: int, parameter_values: list[TParamValueList], - objective_weights: Optional[np.ndarray], - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, TParamValue]] = None, - pending_observations: Optional[list[list[TParamValueList]]] = None, - model_gen_options: Optional[TConfig] = None, + objective_weights: np.ndarray | None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, TParamValue] | None = None, + pending_observations: list[list[TParamValueList]] | None = None, + model_gen_options: TConfig | None = None, ) -> tuple[list[TParamValueList], list[float], TGenMetadata]: """ Generate new candidates. @@ -134,12 +133,12 @@ def best_point( self, n: int, parameter_values: list[TParamValueList], - objective_weights: Optional[np.ndarray], - outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, TParamValue]] = None, - pending_observations: Optional[list[list[TParamValueList]]] = None, - model_gen_options: Optional[TConfig] = None, - ) -> Optional[TParamValueList]: + objective_weights: np.ndarray | None, + outcome_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, TParamValue] | None = None, + pending_observations: list[list[TParamValueList]] | None = None, + model_gen_options: TConfig | None = None, + ) -> TParamValueList | None: """Obtains the point that has the best value according to the model prediction and its model predictions. diff --git a/ax/models/model_utils.py b/ax/models/model_utils.py index 9cecadef3bb..d0896a75618 100644 --- a/ax/models/model_utils.py +++ b/ax/models/model_utils.py @@ -10,8 +10,8 @@ import itertools import warnings -from collections.abc import Mapping -from typing import Callable, Optional, Protocol, Union +from collections.abc import Callable, Mapping +from typing import Protocol, Union import numpy as np import torch @@ -49,17 +49,17 @@ def predict(self, X: Tensor) -> tuple[Tensor, Tensor]: def rejection_sample( gen_unconstrained: Callable[ - [int, int, np.ndarray, Optional[dict[int, float]]], np.ndarray + [int, int, np.ndarray, dict[int, float] | None], np.ndarray ], n: int, d: int, tunable_feature_indices: np.ndarray, - linear_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, + linear_constraints: tuple[np.ndarray, np.ndarray] | None = None, deduplicate: bool = False, - max_draws: Optional[int] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - existing_points: Optional[np.ndarray] = None, + max_draws: int | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[np.ndarray], np.ndarray] | None = None, + existing_points: np.ndarray | None = None, ) -> tuple[np.ndarray, int]: """Rejection sample in parameter space. Parameter space is typically [0, 1] for all tunable parameters. @@ -175,7 +175,7 @@ def check_duplicate(point: np.ndarray, points: np.ndarray) -> bool: def add_fixed_features( tunable_points: np.ndarray, d: int, - fixed_features: Optional[dict[int, float]], + fixed_features: dict[int, float] | None, tunable_feature_indices: np.ndarray, ) -> np.ndarray: """Add fixed features to points in tunable space. 
@@ -227,7 +227,7 @@ def check_param_constraints( def tunable_feature_indices( - bounds: list[tuple[float, float]], fixed_features: Optional[dict[int, float]] = None + bounds: list[tuple[float, float]], fixed_features: dict[int, float] | None = None ) -> np.ndarray: """Get the feature indices of tunable features. @@ -270,13 +270,13 @@ def validate_bounds( def best_observed_point( model: TorchModelLike, bounds: list[tuple[float, float]], - objective_weights: Optional[Tensoray], - outcome_constraints: Optional[tuple[Tensoray, Tensoray]] = None, - linear_constraints: Optional[tuple[Tensoray, Tensoray]] = None, - fixed_features: Optional[dict[int, float]] = None, - risk_measure: Optional[RiskMeasureMCObjective] = None, - options: Optional[TConfig] = None, -) -> Optional[Tensoray]: + objective_weights: Tensoray | None, + outcome_constraints: tuple[Tensoray, Tensoray] | None = None, + linear_constraints: tuple[Tensoray, Tensoray] | None = None, + fixed_features: dict[int, float] | None = None, + risk_measure: RiskMeasureMCObjective | None = None, + options: TConfig | None = None, +) -> Tensoray | None: """Select the best point that has been observed. Implements two approaches to selecting the best point. @@ -343,16 +343,16 @@ def best_observed_point( def best_in_sample_point( - Xs: Union[list[torch.Tensor], list[np.ndarray]], + Xs: list[torch.Tensor] | list[np.ndarray], model: TorchModelLike, bounds: list[tuple[float, float]], - objective_weights: Optional[Tensoray], - outcome_constraints: Optional[tuple[Tensoray, Tensoray]] = None, - linear_constraints: Optional[tuple[Tensoray, Tensoray]] = None, - fixed_features: Optional[dict[int, float]] = None, - risk_measure: Optional[RiskMeasureMCObjective] = None, - options: Optional[TConfig] = None, -) -> Optional[tuple[Tensoray, float]]: + objective_weights: Tensoray | None, + outcome_constraints: tuple[Tensoray, Tensoray] | None = None, + linear_constraints: tuple[Tensoray, Tensoray] | None = None, + fixed_features: dict[int, float] | None = None, + risk_measure: RiskMeasureMCObjective | None = None, + options: TConfig | None = None, +) -> tuple[Tensoray, float] | None: """Select the best point that has been observed. Implements two approaches to selecting the best point. @@ -414,7 +414,7 @@ def best_in_sample_point( if options is None: options = {} method: str = options.get("best_point_method", "max_utility") - B: Optional[float] = options.get("utility_baseline", None) + B: float | None = options.get("utility_baseline", None) threshold: float = options.get("probability_threshold", 0.95) nsamp: int = options.get("feasibility_mc_samples", 10000) # Get points observed for all objective and constraint outcomes @@ -468,9 +468,7 @@ def best_in_sample_point( return X_obs[i, :], utility[i] -def as_array( - x: Union[Tensoray, tuple[Tensoray, ...]] -) -> Union[np.ndarray, tuple[np.ndarray, ...]]: +def as_array(x: Tensoray | tuple[Tensoray, ...]) -> np.ndarray | tuple[np.ndarray, ...]: """Convert every item in a tuple of tensors/arrays into an array. Args: @@ -490,9 +488,9 @@ def as_array( def get_observed( - Xs: Union[list[torch.Tensor], list[np.ndarray]], + Xs: list[torch.Tensor] | list[np.ndarray], objective_weights: Tensoray, - outcome_constraints: Optional[tuple[Tensoray, Tensoray]] = None, + outcome_constraints: tuple[Tensoray, Tensoray] | None = None, ) -> Tensoray: """Filter points to those that are observed for objective outcomes and outcomes that show up in outcome_constraints (if there are any). 
@@ -537,8 +535,8 @@ def get_observed( def filter_constraints_and_fixed_features( X: Tensoray, bounds: list[tuple[float, float]], - linear_constraints: Optional[tuple[Tensoray, Tensoray]] = None, - fixed_features: Optional[dict[int, float]] = None, + linear_constraints: tuple[Tensoray, Tensoray] | None = None, + fixed_features: dict[int, float] | None = None, ) -> Tensoray: """Filter points to those that satisfy bounds, linear_constraints, and fixed_features. @@ -578,8 +576,8 @@ def filter_constraints_and_fixed_features( def mk_discrete_choices( ssd: SearchSpaceDigest, - fixed_features: Optional[dict[int, float]] = None, -) -> Mapping[int, list[Union[int, float]]]: + fixed_features: dict[int, float] | None = None, +) -> Mapping[int, list[int | float]]: discrete_choices = ssd.discrete_choices # Add in fixed features. if fixed_features is not None: @@ -592,8 +590,8 @@ def mk_discrete_choices( def enumerate_discrete_combinations( - discrete_choices: Mapping[int, list[Union[int, float]]], -) -> list[dict[int, Union[float, int]]]: + discrete_choices: Mapping[int, list[int | float]], +) -> list[dict[int, float | int]]: n_combos = np.prod([len(v) for v in discrete_choices.values()]) if n_combos > 50: warnings.warn( diff --git a/ax/models/random/base.py b/ax/models/random/base.py index 252cbfc2449..ff70ed1b98d 100644 --- a/ax/models/random/base.py +++ b/ax/models/random/base.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, Optional +from typing import Any import numpy as np import torch @@ -62,9 +63,9 @@ class RandomModel(Model): def __init__( self, deduplicate: bool = True, - seed: Optional[int] = None, + seed: int | None = None, init_position: int = 0, - generated_points: Optional[np.ndarray] = None, + generated_points: np.ndarray | None = None, fallback_to_sample_polytope: bool = False, ) -> None: super().__init__() @@ -84,10 +85,10 @@ def gen( self, n: int, bounds: list[tuple[float, float]], - linear_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, float]] = None, - model_gen_options: Optional[TConfig] = None, - rounding_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + linear_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, float] | None = None, + model_gen_options: TConfig | None = None, + rounding_func: Callable[[np.ndarray], np.ndarray] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Generate new candidates. @@ -200,7 +201,7 @@ def _gen_unconstrained( n: int, d: int, tunable_feature_indices: np.ndarray, - fixed_features: Optional[dict[int, float]] = None, + fixed_features: dict[int, float] | None = None, ) -> np.ndarray: """Generate n points, from an unconstrained parameter space, using _gen_samples. @@ -237,8 +238,8 @@ def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray: raise NotImplementedError("Base RandomModel can't generate samples.") def _convert_inequality_constraints( - self, linear_constraints: Optional[tuple[np.ndarray, np.ndarray]] - ) -> Optional[tuple[Tensor, Tensor]]: + self, linear_constraints: tuple[np.ndarray, np.ndarray] | None + ) -> tuple[Tensor, Tensor] | None: """Helper method to convert inequality constraints used by the rejection sampler to the format required for the polytope sampler. 
@@ -258,8 +259,8 @@ def _convert_inequality_constraints( return A, b def _convert_equality_constraints( - self, d: int, fixed_features: Optional[dict[int, float]] - ) -> Optional[tuple[Tensor, Tensor]]: + self, d: int, fixed_features: dict[int, float] | None + ) -> tuple[Tensor, Tensor] | None: """Helper method to convert the fixed feature dictionary used by the rejection sampler to the corresponding matrix representation required for the polytope sampler. @@ -286,7 +287,7 @@ def _convert_equality_constraints( constraint_matrix[index, fixed_indices[index]] = 1.0 return constraint_matrix, fixed_vals - def _convert_bounds(self, bounds: list[tuple[float, float]]) -> Optional[Tensor]: + def _convert_bounds(self, bounds: list[tuple[float, float]]) -> Tensor | None: """Helper method to convert the bounds list used by the rejection sampler to the tensor format required for the polytope sampler. @@ -302,7 +303,7 @@ def _convert_bounds(self, bounds: list[tuple[float, float]]) -> Optional[Tensor] else: return torch.tensor(bounds, dtype=torch.double).transpose(-1, -2) - def _get_last_point(self) -> Optional[Tensor]: + def _get_last_point(self) -> Tensor | None: # Return the last sampled point when points have been sampled if self.generated_points is None: return None diff --git a/ax/models/random/sobol.py b/ax/models/random/sobol.py index 407747a75bd..8bf9eb142ef 100644 --- a/ax/models/random/sobol.py +++ b/ax/models/random/sobol.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Callable, Optional +from collections.abc import Callable import numpy as np import torch @@ -32,10 +32,10 @@ class SobolGenerator(RandomModel): def __init__( self, deduplicate: bool = True, - seed: Optional[int] = None, + seed: int | None = None, init_position: int = 0, scramble: bool = True, - generated_points: Optional[np.ndarray] = None, + generated_points: np.ndarray | None = None, fallback_to_sample_polytope: bool = False, ) -> None: super().__init__( @@ -47,7 +47,7 @@ def __init__( ) self.scramble = scramble # Initialize engine on gen. - self._engine: Optional[SobolEngine] = None + self._engine: SobolEngine | None = None def init_engine(self, n_tunable_features: int) -> SobolEngine: """Initialize singleton SobolEngine, only on gen. @@ -67,7 +67,7 @@ def init_engine(self, n_tunable_features: int) -> SobolEngine: return self._engine @property - def engine(self) -> Optional[SobolEngine]: + def engine(self) -> SobolEngine | None: """Return a singleton SobolEngine.""" return self._engine @@ -75,10 +75,10 @@ def gen( self, n: int, bounds: list[tuple[float, float]], - linear_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None, - fixed_features: Optional[dict[int, float]] = None, - model_gen_options: Optional[TConfig] = None, - rounding_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + linear_constraints: tuple[np.ndarray, np.ndarray] | None = None, + fixed_features: dict[int, float] | None = None, + model_gen_options: TConfig | None = None, + rounding_func: Callable[[np.ndarray], np.ndarray] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Generate new candidates.
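Note (illustrative commentary, not part of the patch): PEP 604 annotations such as `seed: int | None = None` above are evaluated when the function is defined, so they need Python 3.10+ unless the module defers annotation evaluation. Several files touched by this patch (e.g. relativize.py and transform_to_new_sq.py) already carry `from __future__ import annotations`, under which either spelling parses on older interpreters. A minimal sketch; `draw_seeded` is a hypothetical helper, not an Ax API:

    from __future__ import annotations  # annotations stored as strings; `int | None` parses pre-3.10

    def draw_seeded(seed: int | None = None) -> int:
        # Same contract as the pre-codemod `Optional[int]` spelling.
        return 42 if seed is None else seed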
diff --git a/ax/models/random/uniform.py b/ax/models/random/uniform.py index d99650b974a..41b9d2e9aa2 100644 --- a/ax/models/random/uniform.py +++ b/ax/models/random/uniform.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional import numpy as np from ax.models.random.base import RandomModel @@ -24,9 +23,9 @@ class UniformGenerator(RandomModel): def __init__( self, deduplicate: bool = True, - seed: Optional[int] = None, + seed: int | None = None, init_position: int = 0, - generated_points: Optional[np.ndarray] = None, + generated_points: np.ndarray | None = None, fallback_to_sample_polytope: bool = False, ) -> None: super().__init__( diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 0be1a671598..087f286df9f 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -9,9 +9,10 @@ from __future__ import annotations import warnings +from collections.abc import Callable from copy import deepcopy from logging import Logger -from typing import Any, Callable, Optional, Union +from typing import Any, Optional import numpy as np import torch @@ -232,13 +233,13 @@ class BotorchModel(TorchModel): optimization problems. % TODO: refer to an example. """ - dtype: Optional[torch.dtype] - device: Optional[torch.device] + dtype: torch.dtype | None + device: torch.device | None Xs: list[Tensor] Ys: list[Tensor] Yvars: list[Tensor] - _model: Optional[Model] - _search_space_digest: Optional[SearchSpaceDigest] = None + _model: Model | None + _search_space_digest: SearchSpaceDigest | None = None def __init__( self, @@ -252,7 +253,7 @@ def __init__( warm_start_refitting: bool = True, use_input_warping: bool = False, use_loocv_pseudo_likelihood: bool = False, - prior: Optional[dict[str, Any]] = None, + prior: dict[str, Any] | None = None, **kwargs: Any, ) -> None: warnings.warn( @@ -277,7 +278,7 @@ def __init__( self.use_input_warping = use_input_warping self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood self.prior = prior - self._model: Optional[Model] = None + self._model: Model | None = None self.Xs = [] self.Ys = [] self.Yvars = [] @@ -292,7 +293,7 @@ def fit( self, datasets: list[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: if len(datasets) == 0: raise DataRequiredError("BotorchModel.fit requires non-empty data sets.") @@ -440,7 +441,7 @@ def best_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - ) -> Optional[Tensor]: + ) -> Tensor | None: if torch_opt_config.is_moo: raise NotImplementedError( "Best observed point is incompatible with MOO problems." @@ -523,8 +524,8 @@ def model(self, model: Model) -> None: def get_rounding_func( - rounding_func: Optional[Callable[[Tensor], Tensor]] -) -> Optional[Callable[[Tensor], Tensor]]: + rounding_func: Callable[[Tensor], Tensor] | None +) -> Callable[[Tensor], Tensor] | None: if rounding_func is None: botorch_rounding_func = rounding_func else: @@ -540,7 +541,7 @@ def botorch_rounding_func(X: Tensor) -> Tensor: def get_feature_importances_from_botorch_model( - model: Union[Model, ModuleList, None], + model: Model | ModuleList | None, ) -> np.ndarray: """Get feature importances from a list of BoTorch models. 
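Note (illustrative commentary, not part of the patch): the hunks above leave quoted forward references such as `modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None` unchanged. A plausible reason is that `|` is a genuine runtime operator: applied to a plain string it raises `TypeError`, so a forward reference can only be spelled with `|` in modules that defer annotation evaluation. A small sketch of the runtime equivalence the codemod relies on:

    from typing import Optional, Union

    # On Python 3.10+ the old and new spellings compare equal:
    assert Optional[int] == Union[int, None] == (int | None)

    # A quoted forward reference cannot take part in `|` at runtime:
    # "ModelBridge" | None  ->  TypeError: unsupported operand type(s) for |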
diff --git a/ax/models/torch/botorch_defaults.py b/ax/models/torch/botorch_defaults.py index 05e16b203a2..b6faa4b817a 100644 --- a/ax/models/torch/botorch_defaults.py +++ b/ax/models/torch/botorch_defaults.py @@ -7,9 +7,10 @@ # pyre-strict import functools +from collections.abc import Callable from copy import deepcopy from random import randint -from typing import Any, Callable, Optional, Protocol, Union +from typing import Any, Protocol import torch from ax.models.model_utils import best_observed_point, get_observed @@ -56,16 +57,16 @@ def _construct_model( - task_feature: Optional[int], + task_feature: int | None, Xs: list[Tensor], Ys: list[Tensor], Yvars: list[Tensor], fidelity_features: list[int], metric_names: list[str], use_input_warping: bool = False, - prior: Optional[dict[str, Any]] = None, + prior: dict[str, Any] | None = None, *, - multitask_gp_ranks: Optional[dict[str, Union[Prior, float]]] = None, + multitask_gp_ranks: dict[str, Prior | float] | None = None, **kwargs: Any, ) -> GPyTorchModel: """ @@ -146,13 +147,13 @@ def get_and_fit_model( task_features: list[int], fidelity_features: list[int], metric_names: list[str], - state_dict: Optional[dict[str, Tensor]] = None, + state_dict: dict[str, Tensor] | None = None, refit_model: bool = True, use_input_warping: bool = False, use_loocv_pseudo_likelihood: bool = False, - prior: Optional[dict[str, Any]] = None, + prior: dict[str, Any] | None = None, *, - multitask_gp_ranks: Optional[dict[str, Union[Prior, float]]] = None, + multitask_gp_ranks: dict[str, Prior | float] | None = None, **kwargs: Any, ) -> GPyTorchModel: r"""Instantiates and fits a botorch GPyTorchModel using the given data. @@ -242,9 +243,9 @@ def __call__( self, # making this a static method makes Pyre unhappy, better to keep `self` model: Model, objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, **kwargs: Any, ) -> AcquisitionFunction: ... # pragma: no cover @@ -294,9 +295,9 @@ def decorator(empty_acqf_getter: Callable[[], None]) -> TAcqfConstructor: def wrapper( model: Model, objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, **kwargs: Any, ) -> AcquisitionFunction: kwargs.pop("objective_thresholds", None) @@ -342,21 +343,21 @@ def _get_acquisition_func( model: Model, acquisition_function_name: str, objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, mc_objective: type[GenericMCObjective] = GenericMCObjective, - constrained_mc_objective: Optional[ + constrained_mc_objective: None | ( type[ConstrainedMCObjective] - ] = ConstrainedMCObjective, + ) = ConstrainedMCObjective, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. 
- mc_objective_kwargs: Optional[dict] = None, + mc_objective_kwargs: dict | None = None, *, chebyshev_scalarization: bool = False, prune_baseline: bool = True, mc_samples: int = 512, - marginalize_dim: Optional[int] = None, + marginalize_dim: int | None = None, ) -> AcquisitionFunction: r"""Instantiates an acquisition function. @@ -414,7 +415,7 @@ def _get_acquisition_func( obj_tf = get_objective_weights_transform(objective_weights) # pyre-fixme[53]: Captured variable `obj_tf` is not annotated. - def objective(samples: Tensor, X: Optional[Tensor] = None) -> Tensor: + def objective(samples: Tensor, X: Tensor | None = None) -> Tensor: return obj_tf(samples) mc_objective_kwargs = {} if mc_objective_kwargs is None else mc_objective_kwargs @@ -459,15 +460,15 @@ def scipy_optimizer( acq_function: AcquisitionFunction, bounds: Tensor, n: int, - inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None, - equality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, *, num_restarts: int = 20, - raw_samples: Optional[int] = None, + raw_samples: int | None = None, joint_optimization: bool = False, - options: Optional[dict[str, Union[bool, float, int, str]]] = None, + options: dict[str, bool | float | int | str] | None = None, ) -> tuple[Tensor, Tensor]: r"""Optimizer using scipy's minimize module on a numpy-adaptor. @@ -499,7 +500,7 @@ def scipy_optimizer( """ sequential = not joint_optimization - optimize_acqf_options: dict[str, Union[bool, float, int, str]] = { + optimize_acqf_options: dict[str, bool | float | int | str] = { "batch_limit": 5, "init_batch_limit": 32, } @@ -525,12 +526,12 @@ def recommend_best_observed_point( model: TorchModel, bounds: list[tuple[float, float]], objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, - model_gen_options: Optional[TConfig] = None, - target_fidelities: Optional[dict[int, float]] = None, -) -> Optional[Tensor]: + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, + model_gen_options: TConfig | None = None, + target_fidelities: dict[int, float] | None = None, +) -> Tensor | None: """ A wrapper around `ax.models.model_utils.best_observed_point` for TorchModel that recommends a best point from previously observed points using either a @@ -581,12 +582,12 @@ def recommend_best_out_of_sample_point( model: TorchModel, bounds: list[tuple[float, float]], objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, - model_gen_options: Optional[TConfig] = None, - target_fidelities: Optional[dict[int, float]] = None, -) -> Optional[Tensor]: + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, + model_gen_options: TConfig | None = None, +
target_fidelities: dict[int, float] | None = None, +) -> Tensor | None: """ Identify the current best point by optimizing the posterior mean of the model. This is "out-of-sample" because it considers un-observed designs as well. @@ -652,7 +653,7 @@ def recommend_best_out_of_sample_point( if non_fixed_idcs is not None: bounds_ = bounds_[..., non_fixed_idcs] - opt_options: dict[str, Union[bool, float, int, str]] = { + opt_options: dict[str, bool | float | int | str] = { "batch_limit": 8, "maxiter": 200, "method": "L-BFGS-B", @@ -682,11 +683,11 @@ def _get_model( X: Tensor, Y: Tensor, Yvar: Tensor, - task_feature: Optional[int] = None, - fidelity_features: Optional[list[int]] = None, + task_feature: int | None = None, + fidelity_features: list[int] | None = None, use_input_warping: bool = False, - covar_module: Optional[Kernel] = None, - prior: Optional[dict[str, Any]] = None, + covar_module: Kernel | None = None, + prior: dict[str, Any] | None = None, **kwargs: Any, ) -> GPyTorchModel: """Instantiate a model of type depending on the input data. @@ -816,7 +817,7 @@ def _get_customized_covar_module( covar_module_prior_dict: dict[str, Prior], ard_num_dims: int, aug_batch_shape: torch.Size, - task_feature: Optional[int] = None, + task_feature: int | None = None, ) -> Kernel: """Construct a GP kernel based on customized prior dict. @@ -866,8 +867,8 @@ def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size: def get_warping_transform( d: int, - batch_shape: Optional[torch.Size] = None, - task_feature: Optional[int] = None, + batch_shape: torch.Size | None = None, + task_feature: int | None = None, ) -> Warp: """Construct input warping transform. diff --git a/ax/models/torch/botorch_kg.py b/ax/models/torch/botorch_kg.py index b3698b5fc6e..99905d788af 100644 --- a/ax/models/torch/botorch_kg.py +++ b/ax/models/torch/botorch_kg.py @@ -7,7 +7,8 @@ # pyre-strict import dataclasses -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -209,13 +210,13 @@ def _get_best_point_acqf( X_observed: Tensor, objective_weights: Tensor, mc_samples: int = 512, - fixed_features: Optional[dict[int, float]] = None, - target_fidelities: Optional[dict[int, float]] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - seed_inner: Optional[int] = None, + fixed_features: dict[int, float] | None = None, + target_fidelities: dict[int, float] | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + seed_inner: int | None = None, qmc: bool = True, **kwargs: Any, - ) -> tuple[AcquisitionFunction, Optional[list[int]]]: + ) -> tuple[AcquisitionFunction, list[int] | None]: return get_out_of_sample_best_point_acqf( model=not_none(self.model), Xs=self.Xs, @@ -235,7 +236,7 @@ def _get_current_value( search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, X_observed: Tensor, - seed_inner: Optional[int], + seed_inner: int | None, qmc: bool, ) -> Tensor: r"""Computes the value of the current best point. 
This is the current_value @@ -281,18 +282,18 @@ def _get_current_value( def _instantiate_KG( model: Model, - objective: Optional[MCAcquisitionObjective] = None, - posterior_transform: Optional[PosteriorTransform] = None, + objective: MCAcquisitionObjective | None = None, + posterior_transform: PosteriorTransform | None = None, qmc: bool = True, n_fantasies: int = 64, mc_samples: int = 256, num_trace_observations: int = 0, - seed_inner: Optional[int] = None, - seed_outer: Optional[int] = None, - X_pending: Optional[Tensor] = None, - current_value: Optional[Tensor] = None, - target_fidelities: Optional[dict[int, float]] = None, - fidelity_weights: Optional[dict[int, float]] = None, + seed_inner: int | None = None, + seed_outer: int | None = None, + X_pending: Tensor | None = None, + current_value: Tensor | None = None, + target_fidelities: dict[int, float] | None = None, + fidelity_weights: dict[int, float] | None = None, cost_intercept: float = 1.0, ) -> qKnowledgeGradient: r"""Instantiate either a `qKnowledgeGradient` or `qMultiFidelityKnowledgeGradient` @@ -367,9 +368,9 @@ def _optimize_and_get_candidates( # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. optimizer_options: dict, - rounding_func: Optional[Callable[[Tensor], Tensor]], - inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]], - fixed_features: Optional[dict[int, float]], + rounding_func: Callable[[Tensor], Tensor] | None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None, + fixed_features: dict[int, float] | None, ) -> Tensor: r"""Generates initial conditions for optimization, optimizes the acquisition function, and returns the candidates. @@ -389,7 +390,7 @@ def _optimize_and_get_candidates( botorch_rounding_func = get_rounding_func(rounding_func) - opt_options: dict[str, Union[bool, float, int, str]] = { + opt_options: dict[str, bool | float | int | str] = { "batch_limit": 8, "maxiter": 200, "method": "L-BFGS-B", diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index c96801528e6..271515e99fe 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -12,11 +12,11 @@ import functools import operator import warnings -from collections.abc import Mapping +from collections.abc import Callable, Mapping from functools import partial, reduce from itertools import product from logging import Logger -from typing import Any, Callable, Optional +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -97,7 +97,7 @@ def __init__( search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, botorch_acqf_class: type[AcquisitionFunction], - options: Optional[dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> None: self.surrogates = surrogates self.options = options or {} @@ -175,7 +175,7 @@ def __init__( ) # Store objective thresholds for all outcomes (including non-objectives).
- self._objective_thresholds: Optional[Tensor] = ( + self._objective_thresholds: Tensor | None = ( torch_opt_config.objective_thresholds ) self._full_objective_weights: Tensor = torch_opt_config.objective_weights @@ -326,8 +326,8 @@ def __init__( **{k: v for k, v in input_constructor_kwargs.items() if v is not None}, ) self.acqf = botorch_acqf_class(**acqf_inputs) # pyre-ignore [45] - self.X_pending: Optional[Tensor] = unique_Xs_pending - self.X_observed: Optional[Tensor] = unique_Xs_observed + self.X_pending: Tensor | None = unique_Xs_pending + self.X_observed: Tensor | None = unique_Xs_observed @property def botorch_acqf_class(self) -> type[AcquisitionFunction]: @@ -335,7 +335,7 @@ def botorch_acqf_class(self) -> type[AcquisitionFunction]: return self.acqf.__class__ @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: """Torch data type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ @@ -352,7 +352,7 @@ def dtype(self) -> Optional[torch.dtype]: return dtypes_list[0] @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: """Torch device type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ @@ -370,7 +370,7 @@ def device(self) -> Optional[torch.device]: return devices_list[0] @property - def objective_thresholds(self) -> Optional[Tensor]: + def objective_thresholds(self) -> Tensor | None: """The objective thresholds for all outcomes. For non-objective outcomes, the objective thresholds are nans. @@ -378,7 +378,7 @@ def objective_thresholds(self) -> Optional[Tensor]: return self._objective_thresholds @property - def objective_weights(self) -> Optional[Tensor]: + def objective_weights(self) -> Tensor | None: """The objective weights for all outcomes.""" return self._full_objective_weights @@ -386,10 +386,10 @@ def optimize( self, n: int, search_space_digest: SearchSpaceDigest, - inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, - optimizer_options: Optional[dict[str, Any]] = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, + optimizer_options: dict[str, Any] | None = None, ) -> tuple[Tensor, Tensor, Tensor]: """Generate a set of candidates via multi-start optimization. Obtains candidates and their associated acquisition function values. @@ -606,7 +606,7 @@ def compute_model_dependencies( surrogates: Mapping[str, Surrogate], search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - options: Optional[dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> dict[str, Any]: """Computes inputs to acquisition function class based on the given surrogate model. 
@@ -639,11 +639,11 @@ def get_botorch_objective_and_transform( botorch_acqf_class: type[AcquisitionFunction], model: Model, objective_weights: Tensor, - objective_thresholds: Optional[Tensor] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - risk_measure: Optional[RiskMeasureMCObjective] = None, - ) -> tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]: + objective_thresholds: Tensor | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + risk_measure: RiskMeasureMCObjective | None = None, + ) -> tuple[MCAcquisitionObjective | None, PosteriorTransform | None]: return get_botorch_objective_and_transform( botorch_acqf_class=botorch_acqf_class, model=model, diff --git a/ax/models/torch/botorch_modular/input_constructors/covar_modules.py b/ax/models/torch/botorch_modular/input_constructors/covar_modules.py index 556ecb02034..e93e1c78c61 100644 --- a/ax/models/torch/botorch_modular/input_constructors/covar_modules.py +++ b/ax/models/torch/botorch_modular/input_constructors/covar_modules.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional, Union +from typing import Any import torch from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel @@ -31,9 +31,9 @@ @covar_module_argparse.register(Kernel) def _covar_module_argparse_base( covar_module_class: type[Kernel], - botorch_model_class: Optional[type[Model]] = None, - dataset: Optional[SupervisedDataset] = None, - covar_module_options: Optional[dict[str, Any]] = None, + botorch_model_class: type[Model] | None = None, + dataset: SupervisedDataset | None = None, + covar_module_options: dict[str, Any] | None = None, **kwargs: Any, ) -> dict[str, Any]: """ @@ -74,11 +74,11 @@ def _covar_module_argparse_scale_matern( covar_module_class: type[ScaleMaternKernel], botorch_model_class: type[Model], dataset: SupervisedDataset, - ard_num_dims: Union[int, _DefaultType] = DEFAULT, - batch_shape: Union[torch.Size, _DefaultType] = DEFAULT, - lengthscale_prior: Optional[Prior] = None, - outputscale_prior: Optional[Prior] = None, - covar_module_options: Optional[dict[str, Any]] = None, + ard_num_dims: int | _DefaultType = DEFAULT, + batch_shape: torch.Size | _DefaultType = DEFAULT, + lengthscale_prior: Prior | None = None, + outputscale_prior: Prior | None = None, + covar_module_options: dict[str, Any] | None = None, **kwargs: Any, ) -> dict[str, Any]: """Extract the base covar module kwargs from the given arguments.
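The `ard_num_dims: int | _DefaultType = DEFAULT` signature above keeps a sentinel default that is distinct from both a real value and an explicit None, and PEP 604 unions spell that cleanly. A minimal sketch of the pattern, not part of the patch, using local stand-ins for the `_DefaultType`/`DEFAULT` names the module imports:

    from __future__ import annotations

    class _DefaultType:
        """Sentinel type: 'caller did not choose', as opposed to an explicit value."""

        def __repr__(self) -> str:
            return "DEFAULT"

    DEFAULT = _DefaultType()

    def resolve_ard_num_dims(ard_num_dims: int | _DefaultType = DEFAULT, d: int = 1) -> int:
        # Fall back to the data dimension only when the caller passed nothing;
        # an explicit integer always wins.
        return d if isinstance(ard_num_dims, _DefaultType) else ard_num_dims

    assert resolve_ard_num_dims(d=7) == 7
    assert resolve_ard_num_dims(3, d=7) == 3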
diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index 1c5e7988463..d6afa6ad0b8 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -33,9 +33,9 @@ @input_transform_argparse.register(InputTransform) def _input_transform_argparse_base( input_transform_class: type[InputTransform], - dataset: Optional[SupervisedDataset] = None, - search_space_digest: Optional[SearchSpaceDigest] = None, - input_transform_options: Optional[dict[str, Any]] = None, + dataset: SupervisedDataset | None = None, + search_space_digest: SearchSpaceDigest | None = None, + input_transform_options: dict[str, Any] | None = None, ) -> dict[str, Any]: """ Extract the input transform kwargs from the given arguments. @@ -63,7 +63,7 @@ def _input_transform_argparse_warp( input_transform_class: type[Warp], dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - input_transform_options: Optional[dict[str, Any]] = None, + input_transform_options: dict[str, Any] | None = None, ) -> dict[str, Any]: """Extract the base input transform kwargs from the given arguments. @@ -93,9 +93,9 @@ def _input_transform_argparse_normalize( input_transform_class: type[Normalize], dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - input_transform_options: Optional[dict[str, Any]] = None, - torch_device: Optional[torch.device] = None, - torch_dtype: Optional[torch.dtype] = None, + input_transform_options: dict[str, Any] | None = None, + torch_device: torch.device | None = None, + torch_dtype: torch.dtype | None = None, ) -> dict[str, Any]: """ Extract the base input transform kwargs from the given arguments. @@ -153,10 +153,10 @@ def _input_transform_argparse_normalize( def _input_transform_argparse_input_perturbation( input_transform_class: type[InputPerturbation], search_space_digest: SearchSpaceDigest, - dataset: Optional[SupervisedDataset] = None, - input_transform_options: Optional[dict[str, Any]] = None, - torch_device: Optional[torch.device] = None, - torch_dtype: Optional[torch.dtype] = None, + dataset: SupervisedDataset | None = None, + input_transform_options: dict[str, Any] | None = None, + torch_device: torch.device | None = None, + torch_dtype: torch.dtype | None = None, ) -> dict[str, Any]: """Extract the base input transform kwargs from the given arguments.
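The normalize constructor above ultimately produces keyword arguments for BoTorch's `Normalize` input transform. A rough sketch of the end product, not part of the patch, with hand-written bounds standing in for the ones the real helper derives from the `SearchSpaceDigest`:

    import torch
    from botorch.models.transforms.input import Normalize

    # Hypothetical 2-d search space with ranges [0, 1] and [-1, 1].
    bounds = torch.tensor([[0.0, -1.0], [1.0, 1.0]])  # shape 2 x d: lower / upper rows
    kwargs = {"d": bounds.shape[-1], "bounds": bounds}

    transform = Normalize(**kwargs)
    # Min-max scaling onto the unit cube: both range midpoints map to 0.5.
    print(transform(torch.tensor([[0.5, 0.0]])))  # tensor([[0.5000, 0.5000]])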
diff --git a/ax/models/torch/botorch_modular/input_constructors/outcome_transform.py b/ax/models/torch/botorch_modular/input_constructors/outcome_transform.py index c3352cceb9d..248df3a369e 100644 --- a/ax/models/torch/botorch_modular/input_constructors/outcome_transform.py +++ b/ax/models/torch/botorch_modular/input_constructors/outcome_transform.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from ax.utils.common.typeutils import _argparse_type_encoder from botorch.models.transforms.outcome import OutcomeTransform, Standardize @@ -23,8 +23,8 @@ @outcome_transform_argparse.register(OutcomeTransform) def _outcome_transform_argparse_base( outcome_transform_class: type[OutcomeTransform], - dataset: Optional[SupervisedDataset] = None, - outcome_transform_options: Optional[dict[str, Any]] = None, + dataset: SupervisedDataset | None = None, + outcome_transform_options: dict[str, Any] | None = None, ) -> dict[str, Any]: """ Extract the outcome transform kwargs from the given arguments. @@ -50,7 +50,7 @@ def _outcome_transform_argparse_base( def _outcome_transform_argparse_standardize( outcome_transform_class: type[Standardize], dataset: SupervisedDataset, - outcome_transform_options: Optional[dict[str, Any]] = None, + outcome_transform_options: dict[str, Any] | None = None, ) -> dict[str, Any]: """Extract the outcome transform kwargs from the given arguments. diff --git a/ax/models/torch/botorch_modular/kernels.py b/ax/models/torch/botorch_modular/kernels.py index a8bcd7165d3..db50da2817a 100644 --- a/ax/models/torch/botorch_modular/kernels.py +++ b/ax/models/torch/botorch_modular/kernels.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any import torch from ax.exceptions.core import AxError @@ -22,12 +22,12 @@ class ScaleMaternKernel(ScaleKernel): def __init__( self, - ard_num_dims: Optional[int] = None, - batch_shape: Optional[torch.Size] = None, - lengthscale_prior: Optional[Prior] = None, - outputscale_prior: Optional[Prior] = None, - lengthscale_constraint: Optional[Interval] = None, - outputscale_constraint: Optional[Interval] = None, + ard_num_dims: int | None = None, + batch_shape: torch.Size | None = None, + lengthscale_prior: Prior | None = None, + outputscale_prior: Prior | None = None, + lengthscale_constraint: Interval | None = None, + outputscale_constraint: Interval | None = None, **kwargs: Any, ) -> None: r""" @@ -70,17 +70,17 @@ def __init__( self, dim: int, temporal_features: list[int], - matern_ard_num_dims: Optional[int] = None, - batch_shape: Optional[torch.Size] = None, - lengthscale_prior: Optional[Prior] = None, - temporal_lengthscale_prior: Optional[Prior] = None, - period_length_prior: Optional[Prior] = None, - fixed_period_length: Optional[float] = None, - outputscale_prior: Optional[Prior] = None, - lengthscale_constraint: Optional[Interval] = None, - outputscale_constraint: Optional[Interval] = None, - temporal_lengthscale_constraint: Optional[Interval] = None, - period_length_constraint: Optional[Interval] = None, + matern_ard_num_dims: int | None = None, + batch_shape: torch.Size | None = None, + lengthscale_prior: Prior | None = None, + temporal_lengthscale_prior: Prior | None = None, + period_length_prior: Prior | None = None, + fixed_period_length: float | None = None, + outputscale_prior: Prior | None = None, + lengthscale_constraint: Interval | None = None, + outputscale_constraint: Interval | None = None, +
temporal_lengthscale_constraint: Interval | None = None, + period_length_constraint: Interval | None = None, **kwargs: Any, ) -> None: r""" diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 84be96666d5..068cce14945 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -8,12 +8,12 @@ import dataclasses from collections import OrderedDict -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass, field from functools import wraps from itertools import chain -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar import numpy as np import torch @@ -82,23 +82,23 @@ class SurrogateSpec: If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. """ - botorch_model_class: Optional[type[Model]] = None + botorch_model_class: type[Model] | None = None botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood mll_kwargs: dict[str, Any] = field(default_factory=dict) - covar_module_class: Optional[type[Kernel]] = None - covar_module_kwargs: Optional[dict[str, Any]] = None + covar_module_class: type[Kernel] | None = None + covar_module_kwargs: dict[str, Any] | None = None - likelihood_class: Optional[type[Likelihood]] = None - likelihood_kwargs: Optional[dict[str, Any]] = None + likelihood_class: type[Likelihood] | None = None + likelihood_kwargs: dict[str, Any] | None = None - input_transform_classes: Optional[list[type[InputTransform]]] = None - input_transform_options: Optional[dict[str, dict[str, Any]]] = None + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = None - outcome_transform_classes: Optional[list[type[OutcomeTransform]]] = None - outcome_transform_options: Optional[dict[str, dict[str, Any]]] = None + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] | None = None allow_batched_models: bool = True @@ -144,19 +144,19 @@ class BoTorchModel(TorchModel, Base): surrogate_specs: dict[str, SurrogateSpec] _surrogates: dict[str, Surrogate] - _output_order: Optional[list[int]] = None + _output_order: list[int] | None = None - _botorch_acqf_class: Optional[type[AcquisitionFunction]] - _search_space_digest: Optional[SearchSpaceDigest] = None + _botorch_acqf_class: type[AcquisitionFunction] | None + _search_space_digest: SearchSpaceDigest | None = None _supports_robust_optimization: bool = True def __init__( self, - surrogate_specs: Optional[Mapping[str, SurrogateSpec]] = None, - surrogate: Optional[Surrogate] = None, - acquisition_class: Optional[type[Acquisition]] = None, - acquisition_options: Optional[dict[str, Any]] = None, - botorch_acqf_class: Optional[type[AcquisitionFunction]] = None, + surrogate_specs: Mapping[str, SurrogateSpec] | None = None, + surrogate: Surrogate | None = None, + acquisition_class: type[Acquisition] | None = None, + acquisition_options: dict[str, Any] | None = None, + botorch_acqf_class: type[AcquisitionFunction] | None = None, # TODO: [T168715924] Revisit these "refit" arguments. 
refit_on_cv: bool = False, warm_start_refit: bool = True, @@ -242,9 +242,9 @@ def fit( self, datasets: Sequence[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, # state dict by surrogate label - state_dicts: Optional[Mapping[str, OrderedDict[str, Tensor]]] = None, + state_dicts: Mapping[str, OrderedDict[str, Tensor]] | None = None, refit: bool = True, **additional_model_inputs: Any, ) -> None: @@ -469,7 +469,7 @@ def best_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - ) -> Optional[Tensor]: + ) -> Tensor | None: try: return self.surrogate.best_in_sample_point( search_space_digest=search_space_digest, @@ -484,7 +484,7 @@ def evaluate_acquisition_function( X: Tensor, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - acq_options: Optional[dict[str, Any]] = None, + acq_options: dict[str, Any] | None = None, ) -> Tensor: acqf = self._instantiate_acquisition( search_space_digest=search_space_digest, @@ -606,7 +606,7 @@ def _instantiate_acquisition( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - acq_options: Optional[dict[str, Any]] = None, + acq_options: dict[str, Any] | None = None, ) -> Acquisition: """Set a BoTorch acquisition function class for this model if needed and instantiate it. diff --git a/ax/models/torch/botorch_modular/optimizer_argparse.py b/ax/models/torch/botorch_modular/optimizer_argparse.py index 726a8c01a40..eeccf4eef2d 100644 --- a/ax/models/torch/botorch_modular/optimizer_argparse.py +++ b/ax/models/torch/botorch_modular/optimizer_argparse.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar, Union import torch from ax.utils.common.constants import Keys @@ -43,7 +43,7 @@ def _argparse_base( raw_samples: int = RAW_SAMPLES, init_batch_limit: int = INIT_BATCH_LIMIT, batch_limit: int = BATCH_LIMIT, - optimizer_options: Optional[dict[str, Any]] = None, + optimizer_options: dict[str, Any] | None = None, **ignore: Any, ) -> dict[str, Any]: """Extract the kwargs to be passed to a BoTorch optimizer. 
@@ -132,7 +132,7 @@ def _argparse_kg( num_restarts: int = NUM_RESTARTS, raw_samples: int = RAW_SAMPLES, frac_random: float = 0.1, - optimizer_options: Optional[dict[str, Any]] = None, + optimizer_options: dict[str, Any] | None = None, **kwargs: Any, ) -> dict[str, Any]: """ diff --git a/ax/models/torch/botorch_modular/sebo.py b/ax/models/torch/botorch_modular/sebo.py index 35cb9f1bd4c..cfda0557ba6 100644 --- a/ax/models/torch/botorch_modular/sebo.py +++ b/ax/models/torch/botorch_modular/sebo.py @@ -7,9 +7,10 @@ # pyre-strict import functools +from collections.abc import Callable from copy import deepcopy from functools import partial -from typing import Any, Callable, Optional +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -56,7 +57,7 @@ def __init__( search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, botorch_acqf_class: type[AcquisitionFunction], - options: Optional[dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> None: if len(surrogates) > 1: raise ValueError("SEBO does not support multiple surrogates.") @@ -210,10 +211,10 @@ def optimize( self, n: int, search_space_digest: SearchSpaceDigest, - inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, - optimizer_options: Optional[dict[str, Any]] = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, + optimizer_options: dict[str, Any] | None = None, ) -> tuple[Tensor, Tensor, Tensor]: """Generate a set of candidates via multi-start optimization. Obtains candidates and their associated acquisition function values. @@ -278,9 +279,9 @@ def _optimize_with_homotopy( self, n: int, search_space_digest: SearchSpaceDigest, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, - optimizer_options: Optional[dict[str, Any]] = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, + optimizer_options: dict[str, Any] | None = None, ) -> tuple[Tensor, Tensor, Tensor]: """Optimize SEBO ACQF with L0 norm using homotopy.""" # TODO: extend to a fixed / no-homotopy schedule diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 2ea9accc55d..7e940821eb2 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from copy import deepcopy from logging import Logger -from typing import Any, Optional, Union +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -80,7 +80,7 @@ def _extract_model_kwargs( search_space_digest: SearchSpaceDigest, -) -> dict[str, Union[list[int], int]]: +) -> dict[str, list[int] | int]: """ Extracts keyword arguments that are passed to the `construct_inputs` method of a BoTorch `Model` class.
@@ -105,7 +105,7 @@ def _extract_model_kwargs( if len(task_features) > 1: raise NotImplementedError("Multiple task features are not supported.") - kwargs: dict[str, Union[list[int], int]] = {} + kwargs: dict[str, list[int] | int] = {} if len(search_space_digest.categorical_features) > 0: kwargs["categorical_features"] = search_space_digest.categorical_features if len(fidelity_features) > 0: @@ -120,7 +120,7 @@ def _make_botorch_input_transform( dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, input_options: dict[str, dict[str, Any]], -) -> Optional[InputTransform]: +) -> InputTransform | None: """ Makes a BoTorch input transform from the provided input classes and options. """ @@ -165,7 +165,7 @@ def _make_botorch_outcome_transform( input_classes: list[type[OutcomeTransform]], input_options: dict[str, dict[str, Any]], dataset: SupervisedDataset, -) -> Optional[OutcomeTransform]: +) -> OutcomeTransform | None: """ Makes a BoTorch outcome transform from the provided classes and options. """ @@ -327,18 +327,18 @@ class string names and the values are dictionaries of input transform def __init__( self, - botorch_model_class: Optional[type[Model]] = None, - model_options: Optional[dict[str, Any]] = None, + botorch_model_class: type[Model] | None = None, + model_options: dict[str, Any] | None = None, mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood, - mll_options: Optional[dict[str, Any]] = None, - outcome_transform_classes: Optional[Sequence[type[OutcomeTransform]]] = None, - outcome_transform_options: Optional[dict[str, dict[str, Any]]] = None, - input_transform_classes: Optional[Sequence[type[InputTransform]]] = None, - input_transform_options: Optional[dict[str, dict[str, Any]]] = None, - covar_module_class: Optional[type[Kernel]] = None, - covar_module_options: Optional[dict[str, Any]] = None, - likelihood_class: Optional[type[Likelihood]] = None, - likelihood_options: Optional[dict[str, Any]] = None, + mll_options: dict[str, Any] | None = None, + outcome_transform_classes: Sequence[type[OutcomeTransform]] | None = None, + outcome_transform_options: dict[str, dict[str, Any]] | None = None, + input_transform_classes: Sequence[type[InputTransform]] | None = None, + input_transform_options: dict[str, dict[str, Any]] | None = None, + covar_module_class: type[Kernel] | None = None, + covar_module_options: dict[str, Any] | None = None, + likelihood_class: type[Likelihood] | None = None, + likelihood_options: dict[str, Any] | None = None, allow_batched_models: bool = True, ) -> None: self.botorch_model_class = botorch_model_class @@ -367,12 +367,12 @@ def __init__( self._submodels: dict[tuple[str], Model] = {} # Store a reference to search space digest used while fitting the cached models. # We will re-fit the models if the search space digest changes. - self._last_search_space_digest: Optional[SearchSpaceDigest] = None + self._last_search_space_digest: SearchSpaceDigest | None = None # These are later updated during model fitting. 
- self._training_data: Optional[list[SupervisedDataset]] = None - self._outcomes: Optional[list[str]] = None - self._model: Optional[Model] = None + self._training_data: list[SupervisedDataset] | None = None + self._outcomes: list[str] | None = None + self._model: Model | None = None def __repr__(self) -> str: return ( @@ -432,7 +432,7 @@ def _construct_model( dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, botorch_model_class: type[Model], - state_dict: Optional[OrderedDict[str, Tensor]], + state_dict: OrderedDict[str, Tensor] | None, refit: bool, ) -> Model: """Constructs the underlying BoTorch ``Model`` using the training data. @@ -501,8 +501,8 @@ def fit( self, datasets: Sequence[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, - state_dict: Optional[OrderedDict[str, Tensor]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, + state_dict: OrderedDict[str, Tensor] | None = None, refit: bool = True, ) -> None: """Fits the underlying BoTorch ``Model`` to ``m`` outcomes. @@ -622,7 +622,7 @@ def best_in_sample_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - options: Optional[TConfig] = None, + options: TConfig | None = None, ) -> tuple[Tensor, float]: """Finds the best observed point and the corresponding observed outcome values. @@ -654,7 +654,7 @@ def best_out_of_sample_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - options: Optional[TConfig] = None, + options: TConfig | None = None, ) -> tuple[Tensor, Tensor]: """Finds the best predicted point and the corresponding value of the appropriate best point acquisition function. @@ -743,7 +743,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: def _extract_construct_input_transform_args( self, search_space_digest: SearchSpaceDigest - ) -> tuple[Optional[Sequence[type[InputTransform]]], dict[str, dict[str, Any]]]: + ) -> tuple[Sequence[type[InputTransform]] | None, dict[str, dict[str, Any]]]: """ Extracts input transform classes and input transform options that will be used in `_set_formatted_inputs` and ultimately passed to diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 239d267b974..cf8b80f1442 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -8,9 +8,9 @@ import warnings from collections import OrderedDict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from logging import Logger -from typing import Any, Callable, Optional +from typing import Any import torch from ax.core.search_space import SearchSpaceDigest @@ -130,12 +130,12 @@ def choose_model_class( def choose_botorch_acqf_class( - pending_observations: Optional[list[Tensor]] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, - objective_thresholds: Optional[Tensor] = None, - objective_weights: Optional[Tensor] = None, + pending_observations: list[Tensor] | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, + objective_thresholds: Tensor | None = None, + objective_weights: Tensor | None = None, ) -> type[AcquisitionFunction]: """Chooses a BoTorch 
`AcquisitionFunction` class.""" if objective_thresholds is not None or ( @@ -154,7 +154,7 @@ def choose_botorch_acqf_class( def construct_acquisition_and_optimizer_options( - acqf_options: TConfig, model_gen_options: Optional[TConfig] = None + acqf_options: TConfig, model_gen_options: TConfig | None = None ) -> tuple[TConfig, TConfig]: """Extract acquisition and optimizer options from `model_gen_options`.""" acq_options = acqf_options.copy() @@ -282,7 +282,7 @@ def _get_shared_rows(Xs: list[Tensor]) -> tuple[Tensor, list[Tensor]]: def fit_botorch_model( model: Model, mll_class: type[MarginalLogLikelihood], - mll_options: Optional[dict[str, Any]] = None, + mll_options: dict[str, Any] | None = None, ) -> None: """Fit a BoTorch model.""" mll_options = mll_options or {} @@ -317,9 +317,9 @@ def _tensor_difference(A: Tensor, B: Tensor) -> Tensor: def get_post_processing_func( - rounding_func: Optional[Callable[[Tensor], Tensor]], + rounding_func: Callable[[Tensor], Tensor] | None, optimizer_options: dict[str, Any], -) -> Optional[Callable[[Tensor], Tensor]]: +) -> Callable[[Tensor], Tensor] | None: """Get the post processing function by combining the rounding function with the post processing function provided as part of the optimizer options. If both are given, the post processing function is applied before diff --git a/ax/models/torch/botorch_moo.py b/ax/models/torch/botorch_moo.py index acc30922ec8..bf1c97b9044 100644 --- a/ax/models/torch/botorch_moo.py +++ b/ax/models/torch/botorch_moo.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, Optional +from typing import Any, Optional import torch from ax.core.search_space import SearchSpaceDigest @@ -180,8 +181,8 @@ class MultiObjectiveBotorchModel(BotorchModel): a tuple of tensors describing the (linear) outcome constraints. 
""" - dtype: Optional[torch.dtype] - device: Optional[torch.device] + dtype: torch.dtype | None + device: torch.device | None Xs: list[Tensor] Ys: list[Tensor] Yvars: list[Tensor] @@ -209,7 +210,7 @@ def __init__( warm_start_refitting: bool = False, use_input_warping: bool = False, use_loocv_pseudo_likelihood: bool = False, - prior: Optional[dict[str, Any]] = None, + prior: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.model_constructor = model_constructor @@ -225,7 +226,7 @@ def __init__( self.use_input_warping = use_input_warping self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood self.prior = prior - self.model: Optional[Model] = None + self.model: Model | None = None self.Xs = [] self.Ys = [] self.Yvars = [] diff --git a/ax/models/torch/botorch_moo_defaults.py b/ax/models/torch/botorch_moo_defaults.py index a7108e813f6..762b4bb95f7 100644 --- a/ax/models/torch/botorch_moo_defaults.py +++ b/ax/models/torch/botorch_moo_defaults.py @@ -27,7 +27,9 @@ from __future__ import annotations -from typing import Callable, cast, Optional, Union +from collections.abc import Callable + +from typing import cast, Optional, Union import torch from ax.exceptions.core import AxError @@ -113,16 +115,16 @@ def get_NEHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, prune_baseline: bool = True, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - marginalize_dim: Optional[int] = None, + alpha: float | None = None, + marginalize_dim: int | None = None, cache_root: bool = True, - seed: Optional[int] = None, + seed: int | None = None, ) -> qNoisyExpectedHypervolumeImprovement: r"""Instantiates a qNoisyExpectedHyperVolumeImprovement acquisition function. @@ -180,16 +182,16 @@ def get_qLogNEHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, prune_baseline: bool = True, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - marginalize_dim: Optional[int] = None, + alpha: float | None = None, + marginalize_dim: int | None = None, cache_root: bool = True, - seed: Optional[int] = None, + seed: int | None = None, ) -> qLogNoisyExpectedHypervolumeImprovement: r"""Instantiates a qLogNoisyExpectedHyperVolumeImprovement acquisition function. 
@@ -248,19 +250,17 @@ def _get_NEHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, prune_baseline: bool = True, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - marginalize_dim: Optional[int] = None, + alpha: float | None = None, + marginalize_dim: int | None = None, cache_root: bool = True, - seed: Optional[int] = None, -) -> Union[ - qNoisyExpectedHypervolumeImprovement, qLogNoisyExpectedHypervolumeImprovement -]: + seed: int | None = None, +) -> qNoisyExpectedHypervolumeImprovement | qLogNoisyExpectedHypervolumeImprovement: if X_observed is None: raise ValueError(NO_FEASIBLE_POINTS_MESSAGE) # construct Objective module @@ -311,13 +311,13 @@ def get_EHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - seed: Optional[int] = None, + alpha: float | None = None, + seed: int | None = None, ) -> qExpectedHypervolumeImprovement: r"""Instantiates a qExpectedHyperVolumeImprovement acquisition function. @@ -370,13 +370,13 @@ def get_qLogEHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - seed: Optional[int] = None, + alpha: float | None = None, + seed: int | None = None, ) -> qLogExpectedHypervolumeImprovement: r"""Instantiates a qLogExpectedHyperVolumeImprovement acquisition function. 
@@ -430,14 +430,14 @@ def _get_EHVI( model: Model, objective_weights: Tensor, objective_thresholds: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + X_pending: Tensor | None = None, *, mc_samples: int = DEFAULT_EHVI_MC_SAMPLES, - alpha: Optional[float] = None, - seed: Optional[int] = None, -) -> Union[qExpectedHypervolumeImprovement, qLogExpectedHypervolumeImprovement]: + alpha: float | None = None, + seed: int | None = None, +) -> qExpectedHypervolumeImprovement | qLogExpectedHypervolumeImprovement: if X_observed is None: raise ValueError(NO_FEASIBLE_POINTS_MESSAGE) # construct Objective module @@ -487,12 +487,12 @@ def _get_EHVI( def scipy_optimizer_list( acq_function_list: list[AcquisitionFunction], bounds: Tensor, - inequality_constraints: Optional[list[tuple[Tensor, Tensor, float]]] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, num_restarts: int = 20, - raw_samples: Optional[int] = None, - options: Optional[dict[str, Union[bool, float, int, str]]] = None, + raw_samples: int | None = None, + options: dict[str, bool | float | int | str] | None = None, ) -> tuple[Tensor, Tensor]: r"""Sequential optimizer using scipy's minimize module on a numpy-adaptor. @@ -521,7 +521,7 @@ def scipy_optimizer_list( conditional on having observed candidates `0,1,...,i-1`. """ # Use SLSQP by default for small problems since it yields faster wall times. - optimize_options: dict[str, Union[bool, float, int, str]] = { + optimize_options: dict[str, bool | float | int | str] = { "batch_limit": 5, "init_batch_limit": 32, "method": "SLSQP", } @@ -542,13 +542,13 @@ def scipy_optimizer_list( def pareto_frontier_evaluator( - model: Optional[TorchModel], + model: TorchModel | None, objective_weights: Tensor, - objective_thresholds: Optional[Tensor] = None, - X: Optional[Tensor] = None, - Y: Optional[Tensor] = None, - Yvar: Optional[Tensor] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, + objective_thresholds: Tensor | None = None, + X: Tensor | None = None, + Y: Tensor | None = None, + Yvar: Tensor | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, ) -> tuple[Tensor, Tensor, Tensor]: """Return outcomes predicted to lie on a Pareto frontier.
@@ -654,14 +654,14 @@ def pareto_frontier_evaluator( def infer_objective_thresholds( model: Model, objective_weights: Tensor, # objective_directions - bounds: Optional[list[tuple[float, float]]] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, - subset_idcs: Optional[Tensor] = None, - Xs: Optional[list[Tensor]] = None, - X_observed: Optional[Tensor] = None, - objective_thresholds: Optional[Tensor] = None, + bounds: list[tuple[float, float]] | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, + subset_idcs: Tensor | None = None, + Xs: list[Tensor] | None = None, + X_observed: Tensor | None = None, + objective_thresholds: Tensor | None = None, ) -> Tensor: """Infer objective thresholds. @@ -783,7 +783,7 @@ def infer_objective_thresholds( def _check_posterior_type( posterior: Posterior, -) -> Union[GPyTorchPosterior, PosteriorList]: +) -> GPyTorchPosterior | PosteriorList: """Check whether the posterior type is `GPyTorchPosterior` or `PosteriorList`.""" if isinstance(posterior, GPyTorchPosterior) or isinstance(posterior, PosteriorList): return posterior diff --git a/ax/models/torch/cbo_lcea.py b/ax/models/torch/cbo_lcea.py index d6a65039218..c3d2f667847 100644 --- a/ax/models/torch/cbo_lcea.py +++ b/ax/models/torch/cbo_lcea.py @@ -7,7 +7,7 @@ # pyre-strict from logging import Logger -from typing import Any, cast, Optional, Union +from typing import Any, cast, Union from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata @@ -38,14 +38,14 @@ def get_map_model( train_embedding: bool = True, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. - cat_feature_dict: Optional[dict] = None, + cat_feature_dict: dict | None = None, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. - embs_feature_dict: Optional[dict] = None, - embs_dim_list: Optional[list[int]] = None, + embs_feature_dict: dict | None = None, + embs_dim_list: list[int] | None = None, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. - context_weight_dict: Optional[dict] = None, + context_weight_dict: dict | None = None, ) -> tuple[LCEAGP, ExactMarginalLogLikelihood]: """Obtain MAP fitting of Latent Context Embedding Additive (LCE-A) GP.""" # assert train_X is non-batched @@ -85,15 +85,15 @@ def __init__( decomposition: dict[str, list[str]], # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. - cat_feature_dict: Optional[dict] = None, + cat_feature_dict: dict | None = None, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. - embs_feature_dict: Optional[dict] = None, + embs_feature_dict: dict | None = None, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict` to avoid runtime subscripting errors. 
- context_weight_dict: Optional[dict] = None, - embs_dim_list: Optional[list[int]] = None, - gp_model_args: Optional[dict[str, Any]] = None, + context_weight_dict: dict | None = None, + embs_dim_list: list[int] | None = None, + gp_model_args: dict[str, Any] | None = None, ) -> None: # add validation for input decomposition for param_list in list(decomposition.values()): @@ -119,7 +119,7 @@ def fit( self, datasets: list[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: if len(search_space_digest.feature_names) == 0: raise ValueError("feature names are required for LCEABO") @@ -134,7 +134,7 @@ def best_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - ) -> Optional[Tensor]: + ) -> Tensor | None: raise NotImplementedError def get_and_fit_model( @@ -145,8 +145,8 @@ def get_and_fit_model( task_features: list[int], fidelity_features: list[int], metric_names: list[str], - state_dict: Optional[dict[str, Tensor]] = None, - fidelity_model_id: Optional[int] = None, + state_dict: dict[str, Tensor] | None = None, + fidelity_model_id: int | None = None, **kwargs: Any, ) -> GPyTorchModel: """Get a fitted LCEAGP model for each outcome. @@ -185,5 +185,5 @@ def get_and_fit_model( return model @property - def model(self) -> Union[LCEAGP, ModelListGP]: + def model(self) -> LCEAGP | ModelListGP: return cast(Union[LCEAGP, ModelListGP], super().model) diff --git a/ax/models/torch/cbo_lcem.py b/ax/models/torch/cbo_lcem.py index 119dffb4902..eb320c9ddc6 100644 --- a/ax/models/torch/cbo_lcem.py +++ b/ax/models/torch/cbo_lcem.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, Optional +from typing import Any import torch from ax.models.torch.botorch import BotorchModel @@ -25,9 +25,9 @@ class LCEMBO(BotorchModel): def __init__( self, - context_cat_feature: Optional[Tensor] = None, - context_emb_feature: Optional[Tensor] = None, - embs_dim_list: Optional[list[int]] = None, + context_cat_feature: Tensor | None = None, + context_emb_feature: Tensor | None = None, + embs_dim_list: list[int] | None = None, ) -> None: self.context_cat_feature = context_cat_feature self.context_emb_feature = context_emb_feature @@ -42,8 +42,8 @@ def get_and_fit_model( task_features: list[int], fidelity_features: list[int], metric_names: list[str], - state_dict: Optional[dict[str, Tensor]] = None, - fidelity_model_id: Optional[int] = None, + state_dict: dict[str, Tensor] | None = None, + fidelity_model_id: int | None = None, **kwargs: Any, ) -> ModelListGP: """Get a fitted multi-task contextual GP model for each outcome. 
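Note the `model` property above: its annotation becomes `LCEAGP | ModelListGP`, yet the body still calls `cast(Union[LCEAGP, ModelListGP], ...)` and `cbo_lcea.py` keeps its `Union` import. The codemod rewrites annotations only; `cast` evaluates its first argument eagerly at runtime, where `X | Y` requires Python 3.10+, while annotations under `from __future__ import annotations` are never evaluated at all. A small sketch, not part of the patch:

    from __future__ import annotations  # annotations stay unevaluated strings

    from typing import Union, cast

    def as_int_or_str(value: object) -> int | str:  # PEP 604 is free here
        # `typing.Union` keeps this runtime expression working on pre-3.10
        # interpreters; `cast(int | str, value)` is equivalent on 3.10+.
        return cast(Union[int, str], value)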
diff --git a/ax/models/torch/cbo_sac.py b/ax/models/torch/cbo_sac.py index b583e136b30..e0276369a0d 100644 --- a/ax/models/torch/cbo_sac.py +++ b/ax/models/torch/cbo_sac.py @@ -7,7 +7,7 @@ # pyre-strict from logging import Logger -from typing import Any, Optional +from typing import Any from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata @@ -53,7 +53,7 @@ def fit( self, datasets: list[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: if len(search_space_digest.feature_names) == 0: raise ValueError("feature names are required for SACBO") @@ -71,8 +71,8 @@ def get_and_fit_model( task_features: list[int], fidelity_features: list[int], metric_names: list[str], - state_dict: Optional[dict[str, Tensor]] = None, - fidelity_model_id: Optional[int] = None, + state_dict: dict[str, Tensor] | None = None, + fidelity_model_id: int | None = None, **kwargs: Any, ) -> GPyTorchModel: """Get a fitted StructuralAdditiveContextualGP model for each outcome. diff --git a/ax/models/torch/randomforest.py b/ax/models/torch/randomforest.py index 453a5912975..82fb6688da4 100644 --- a/ax/models/torch/randomforest.py +++ b/ax/models/torch/randomforest.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import Optional - import numpy as np import torch from ax.core.search_space import SearchSpaceDigest @@ -38,9 +36,7 @@ class RandomForest(TorchModel): num_trees: Number of trees. """ - def __init__( - self, max_features: Optional[str] = "sqrt", num_trees: int = 500 - ) -> None: + def __init__(self, max_features: str | None = "sqrt", num_trees: int = 500) -> None: self.max_features = max_features self.num_trees = num_trees self.models: list[RandomForestRegressor] = [] @@ -50,7 +46,7 @@ def fit( self, datasets: list[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets) for X, Y, Yvar in zip(Xs, Ys, Yvars): @@ -95,7 +91,7 @@ def _get_rf( Y: np.ndarray, Yvar: np.ndarray, num_trees: int, - max_features: Optional[str], + max_features: str | None, ) -> RandomForestRegressor: """Fit a Random Forest model. 
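The `max_features: str | None = "sqrt"` signature above is a useful reminder that `X | None` only widens the accepted type; it says nothing about the default value. A minimal sketch, not part of the patch, assuming scikit-learn (which the wrapped regressor comes from):

    from sklearn.ensemble import RandomForestRegressor

    def make_rf(max_features: str | None = "sqrt", num_trees: int = 500) -> RandomForestRegressor:
        # None is a legitimate caller choice (consider all features at each split);
        # "sqrt" remains the default when nothing is passed.
        return RandomForestRegressor(n_estimators=num_trees, max_features=max_features)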
diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index aa385916b43..1fa02cf2a69 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -12,7 +12,7 @@ import itertools from contextlib import ExitStack from itertools import chain -from typing import Any, Optional +from typing import Any from unittest import mock from unittest.mock import Mock @@ -172,7 +172,7 @@ def setUp(self) -> None: ) def get_acquisition_function( - self, fixed_features: Optional[dict[int, float]] = None, one_shot: bool = False + self, fixed_features: dict[int, float] | None = None, one_shot: bool = False ) -> Acquisition: return Acquisition( botorch_acqf_class=( diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py index f4e3f04ac4d..f5bb18c32f7 100644 --- a/ax/models/torch/tests/test_sebo.py +++ b/ax/models/torch/tests/test_sebo.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Optional, Union +from typing import Any from unittest import mock from unittest.mock import Mock @@ -106,8 +106,8 @@ def setUp(self) -> None: def get_acquisition_function( self, - fixed_features: Optional[dict[int, float]] = None, - options: Optional[dict[str, Union[str, float]]] = None, + fixed_features: dict[int, float] | None = None, + options: dict[str, str | float] | None = None, ) -> SEBOAcquisition: return SEBOAcquisition( botorch_acqf_class=qNoisyExpectedHypervolumeImprovement, diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py index 4c8cc894df0..0d29ca27f3d 100644 --- a/ax/models/torch/utils.py +++ b/ax/models/torch/utils.py @@ -6,9 +6,10 @@ # pyre-strict +from collections.abc import Callable from dataclasses import dataclass from logging import Logger -from typing import Any, Callable, cast, Optional +from typing import Any, cast import numpy as np import torch @@ -76,8 +77,8 @@ class SubsetModelData: model: Model objective_weights: Tensor - outcome_constraints: Optional[tuple[Tensor, Tensor]] - objective_thresholds: Optional[Tensor] + outcome_constraints: tuple[Tensor, Tensor] | None + objective_thresholds: Tensor | None indices: Tensor @@ -94,11 +95,11 @@ def _filter_X_observed( Xs: list[Tensor], objective_weights: Tensor, bounds: list[tuple[float, float]], - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, fit_out_of_design: bool = False, -) -> Optional[Tensor]: +) -> Tensor | None: r"""Filter input points to those appearing in objective or constraints. 
Args: @@ -143,12 +144,12 @@ def _get_X_pending_and_observed( Xs: list[Tensor], objective_weights: Tensor, bounds: list[tuple[float, float]], - pending_observations: Optional[list[Tensor]] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, + pending_observations: list[Tensor] | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, fit_out_of_design: bool = False, -) -> tuple[Optional[Tensor], Optional[Tensor]]: +) -> tuple[Tensor | None, Tensor | None]: r"""Get pending and observed points. If all points would otherwise be filtered, remove `linear_constraints` @@ -216,10 +217,10 @@ def _generate_sobol_points( n_sobol: int, bounds: list[tuple[float, float]], device: torch.device, - linear_constraints: Optional[tuple[Tensor, Tensor]] = None, - fixed_features: Optional[dict[int, float]] = None, - rounding_func: Optional[Callable[[Tensor], Tensor]] = None, - model_gen_options: Optional[TConfig] = None, + linear_constraints: tuple[Tensor, Tensor] | None = None, + fixed_features: dict[int, float] | None = None, + rounding_func: Callable[[Tensor], Tensor] | None = None, + model_gen_options: TConfig | None = None, ) -> Tensor: linear_constraints_array = None @@ -271,8 +272,8 @@ def normalize_indices(indices: list[int], d: int) -> list[int]: def subset_model( model: Model, objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - objective_thresholds: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + objective_thresholds: Tensor | None = None, ) -> SubsetModelData: """Subset a botorch model to the outputs used in the optimization. @@ -344,8 +345,8 @@ def subset_model( def _to_inequality_constraints( - linear_constraints: Optional[tuple[Tensor, Tensor]] = None -) -> Optional[list[tuple[Tensor, Tensor, float]]]: + linear_constraints: tuple[Tensor, Tensor] | None = None +) -> list[tuple[Tensor, Tensor, float]] | None: if linear_constraints is not None: A, b = linear_constraints inequality_constraints = [] @@ -388,8 +389,8 @@ def _get_risk_measure( model: Model, objective_weights: Tensor, risk_measure: RiskMeasureMCObjective, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, ) -> RiskMeasureMCObjective: r"""Processes the risk measure for `get_botorch_objective_and_transform`. See the docstring of `get_botorch_objective_and_transform` for the arguments. @@ -431,10 +432,10 @@ def get_botorch_objective_and_transform( botorch_acqf_class: type[AcquisitionFunction], model: Model, objective_weights: Tensor, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - risk_measure: Optional[RiskMeasureMCObjective] = None, -) -> tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]: + outcome_constraints: tuple[Tensor, Tensor] | None = None, + X_observed: Tensor | None = None, + risk_measure: RiskMeasureMCObjective | None = None, +) -> tuple[MCAcquisitionObjective | None, PosteriorTransform | None]: """Constructs a BoTorch `AcquisitionObjective` object. 
Args: @@ -481,11 +482,11 @@ def get_botorch_objective_and_transform( "X_observed is required to construct a constrained BoTorch objective." ) # If there are outcome constraints, we use MC Acquisition functions. - obj_tf: Callable[[Tensor, Optional[Tensor]], Tensor] = ( + obj_tf: Callable[[Tensor, Tensor | None], Tensor] = ( get_objective_weights_transform(objective_weights) ) - def objective(samples: Tensor, X: Optional[Tensor] = None) -> Tensor: + def objective(samples: Tensor, X: Tensor | None = None) -> Tensor: return obj_tf(samples, X) # SampleReducingMCAcquisitionFunctions take care of the constraint handling @@ -511,15 +512,15 @@ def get_out_of_sample_best_point_acqf( X_observed: Tensor, objective_weights: Tensor, mc_samples: int = 512, - fixed_features: Optional[dict[int, float]] = None, - fidelity_features: Optional[list[int]] = None, - target_fidelities: Optional[dict[int, float]] = None, - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, - seed_inner: Optional[int] = None, + fixed_features: dict[int, float] | None = None, + fidelity_features: list[int] | None = None, + target_fidelities: dict[int, float] | None = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, + seed_inner: int | None = None, qmc: bool = True, - risk_measure: Optional[RiskMeasureMCObjective] = None, + risk_measure: RiskMeasureMCObjective | None = None, **kwargs: Any, -) -> tuple[AcquisitionFunction, Optional[list[int]]]: +) -> tuple[AcquisitionFunction, list[int] | None]: """Picks an appropriate acquisition function to find the best out-of-sample (predicted by the given surrogate model) point and instantiates it. @@ -601,11 +602,11 @@ def get_out_of_sample_best_point_acqf( def pick_best_out_of_sample_point_acqf_class( - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None, + outcome_constraints: tuple[Tensor, Tensor] | None = None, mc_samples: int = 512, qmc: bool = True, - seed_inner: Optional[int] = None, - risk_measure: Optional[RiskMeasureMCObjective] = None, + seed_inner: int | None = None, + risk_measure: RiskMeasureMCObjective | None = None, ) -> tuple[type[AcquisitionFunction], dict[str, Any]]: if outcome_constraints is None and risk_measure is None: acqf_class = PosteriorMean diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py index c8427db767f..2f75427327b 100644 --- a/ax/models/torch_base.py +++ b/ax/models/torch_base.py @@ -8,8 +8,10 @@ from __future__ import annotations +from collections.abc import Callable + from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any import torch from ax.core.metric import Metric @@ -78,16 +80,16 @@ class TorchOptConfig: """ objective_weights: Tensor - outcome_constraints: Optional[tuple[Tensor, Tensor]] = None - objective_thresholds: Optional[Tensor] = None - linear_constraints: Optional[tuple[Tensor, Tensor]] = None - fixed_features: Optional[dict[int, float]] = None - pending_observations: Optional[list[Tensor]] = None + outcome_constraints: tuple[Tensor, Tensor] | None = None + objective_thresholds: Tensor | None = None + linear_constraints: tuple[Tensor, Tensor] | None = None + fixed_features: dict[int, float] | None = None + pending_observations: list[Tensor] | None = None model_gen_options: TConfig = field(default_factory=dict) - rounding_func: Optional[Callable[[Tensor], Tensor]] = None - opt_config_metrics: Optional[dict[str, Metric]] = None + rounding_func: Callable[[Tensor], Tensor] | None = None + opt_config_metrics: dict[str, Metric] | None = None is_moo: 
bool = False - risk_measure: Optional[RiskMeasureMCObjective] = None + risk_measure: RiskMeasureMCObjective | None = None fit_out_of_design: bool = False @@ -105,7 +107,7 @@ class TorchGenResults: points: Tensor # (n x d)-dim weights: Tensor # n-dim gen_metadata: dict[str, Any] = field(default_factory=dict) - candidate_metadata: Optional[list[TCandidateMetadata]] = None + candidate_metadata: list[TCandidateMetadata] | None = None class TorchModel(BaseModel): @@ -115,15 +117,15 @@ class TorchModel(BaseModel): of Ax. """ - dtype: Optional[torch.dtype] = None - device: Optional[torch.device] = None + dtype: torch.dtype | None = None + device: torch.device | None = None _supports_robust_optimization: bool = False def fit( self, datasets: list[SupervisedDataset], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: """Fit model to m outcomes. @@ -177,7 +179,7 @@ def best_point( self, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - ) -> Optional[Tensor]: + ) -> Tensor | None: """ Identify the current best point, satisfying the constraints in the same format as to gen. @@ -231,7 +233,7 @@ def update( datasets: list[SupervisedDataset], metric_names: list[str], search_space_digest: SearchSpaceDigest, - candidate_metadata: Optional[list[list[TCandidateMetadata]]] = None, + candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: """Update the model. @@ -257,7 +259,7 @@ def evaluate_acquisition_function( X: Tensor, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, - acq_options: Optional[dict[str, Any]] = None, + acq_options: dict[str, Any] | None = None, ) -> Tensor: """Evaluate the acquisition function on the candidate set `X`. diff --git a/ax/models/winsorization_config.py b/ax/models/winsorization_config.py index 501d31babf0..3f4e47c45e3 100644 --- a/ax/models/winsorization_config.py +++ b/ax/models/winsorization_config.py @@ -7,7 +7,6 @@ # pyre-strict from dataclasses import dataclass -from typing import Optional @dataclass @@ -31,5 +30,5 @@ class WinsorizationConfig: lower_quantile_margin: float = 0.0 upper_quantile_margin: float = 0.0 - lower_boundary: Optional[float] = None - upper_boundary: Optional[float] = None + lower_boundary: float | None = None + upper_boundary: float | None = None diff --git a/ax/plot/bandit_rollout.py b/ax/plot/bandit_rollout.py index ec6b68e480f..b6576cd8d00 100644 --- a/ax/plot/bandit_rollout.py +++ b/ax/plot/bandit_rollout.py @@ -44,7 +44,7 @@ def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig: arms[arm.name]["x"].append(category) arms[arm.name]["y"].append(weight) - arms[arm.name]["text"].append("{:.2f}%".format(weight)) + arms[arm.name]["text"].append(f"{weight:.2f}%") for key in arms.keys(): data.append(arms[key]) diff --git a/ax/plot/base.py b/ax/plot/base.py index ec113fbdd5a..baa057e19c7 100644 --- a/ax/plot/base.py +++ b/ax/plot/base.py @@ -8,7 +8,7 @@ import enum import json -from typing import Any, NamedTuple, Optional, Union +from typing import Any, NamedTuple from ax.core.types import TParameterization from ax.utils.common.serialization import named_tuple_to_dict @@ -50,7 +50,7 @@ def __new__(cls, data: dict[str, Any], plot_type: enum.Enum) -> "AxPlotConfig": json.dumps(named_tuple_to_dict(data), cls=utils.PlotlyJSONEncoder) ) # pyre-fixme[7]: Expected `AxPlotConfig` but got `NamedTuple`. 
- return super(AxPlotConfig, cls).__new__(cls, dict_data, plot_type) + return super().__new__(cls, dict_data, plot_type) # Structs for plot data @@ -63,7 +63,7 @@ class PlotInSampleArm(NamedTuple): y_hat: dict[str, float] se: dict[str, float] se_hat: dict[str, float] - context_stratum: Optional[dict[str, Union[str, float]]] + context_stratum: dict[str, str | float] | None class PlotOutOfSampleArm(NamedTuple): @@ -73,7 +73,7 @@ class PlotOutOfSampleArm(NamedTuple): parameters: TParameterization y_hat: dict[str, float] se_hat: dict[str, float] - context_stratum: Optional[dict[str, Union[str, float]]] + context_stratum: dict[str, str | float] | None class PlotData(NamedTuple): @@ -81,8 +81,8 @@ class PlotData(NamedTuple): metrics: list[str] in_sample: dict[str, PlotInSampleArm] - out_of_sample: Optional[dict[str, dict[str, PlotOutOfSampleArm]]] - status_quo_name: Optional[str] + out_of_sample: dict[str, dict[str, PlotOutOfSampleArm]] | None + status_quo_name: str | None class PlotMetric(NamedTuple): diff --git a/ax/plot/contour.py b/ax/plot/contour.py index 530d52ee90a..5d0ec9888ac 100644 --- a/ax/plot/contour.py +++ b/ax/plot/contour.py @@ -8,7 +8,7 @@ import re from copy import deepcopy -from typing import Any, Optional +from typing import Any import numpy as np import plotly.graph_objs as go @@ -58,8 +58,8 @@ def _get_contour_predictions( metric: str, generator_runs_dict: TNullableGeneratorRunsDict, density: int, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, ) -> ContourPredictions: """ slice_values is a dictionary {param_name: value} for the parameters that @@ -114,10 +114,10 @@ def plot_contour_plotly( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, + slice_values: dict[str, Any] | None = None, lower_is_better: bool = False, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> go.Figure: """Plot predictions for a 2-d slice of the parameter space. @@ -287,10 +287,10 @@ def plot_contour( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, + slice_values: dict[str, Any] | None = None, lower_is_better: bool = False, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> AxPlotConfig: """Plot predictions for a 2-d slice of the parameter space. @@ -341,11 +341,11 @@ def interact_contour_plotly( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, + slice_values: dict[str, Any] | None = None, lower_is_better: bool = False, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, - parameters_to_use: Optional[list[str]] = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, + parameters_to_use: list[str] | None = None, ) -> go.Figure: """Create interactive plot with predictions for a 2-d slice of the parameter space. 
@@ -896,11 +896,11 @@ def interact_contour( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, + slice_values: dict[str, Any] | None = None, lower_is_better: bool = False, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, - parameters_to_use: Optional[list[str]] = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, + parameters_to_use: list[str] | None = None, ) -> AxPlotConfig: """Create interactive plot with predictions for a 2-d slice of the parameter space. diff --git a/ax/plot/diagnostic.py b/ax/plot/diagnostic.py index 8570b762c2e..09ea83b4716 100644 --- a/ax/plot/diagnostic.py +++ b/ax/plot/diagnostic.py @@ -7,7 +7,7 @@ # pyre-strict from copy import deepcopy -from typing import Any, Optional +from typing import Any import numpy as np import plotly.graph_objs as go @@ -279,7 +279,7 @@ def _get_batch_comparison_plot_data( batch_x: int, batch_y: int, rel: bool = False, - status_quo_name: Optional[str] = None, + status_quo_name: str | None = None, ) -> PlotData: """Compute PlotData for comparing repeated arms across trials. @@ -354,7 +354,7 @@ def _get_batch_comparison_plot_data( def _get_cv_plot_data( - cv_results: list[CVResult], label_dict: Optional[dict[str, str]] + cv_results: list[CVResult], label_dict: dict[str, str] | None ) -> PlotData: if len(cv_results) == 0: return PlotData( @@ -488,7 +488,7 @@ def interact_cross_validation_plotly( cv_results: list[CVResult], show_context: bool = True, caption: str = "", - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, autoset_axis_limits: bool = True, ) -> go.Figure: """Interactive cross-validation (CV) plotting; select metric via dropdown. @@ -525,7 +525,7 @@ def interact_cross_validation( cv_results: list[CVResult], show_context: bool = True, caption: str = "", - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, autoset_axis_limits: bool = True, ) -> AxPlotConfig: """Interactive cross-validation (CV) plotting; select metric via dropdown. @@ -560,7 +560,7 @@ def tile_cross_validation( cv_results: list[CVResult], show_arm_details_on_hover: bool = True, show_context: bool = True, - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Tile version of CV plots; sorted by 'best fitting' outcomes. 
@@ -626,8 +626,8 @@ def tile_cross_validation( # if odd number of plots, need to manually remove the last blank subplot # generated by `subplots.make_subplots` if len(metrics) % 2 == 1: - fig["layout"].pop("xaxis{}".format(nrows * ncols)) - fig["layout"].pop("yaxis{}".format(nrows * ncols)) + fig["layout"].pop(f"xaxis{nrows * ncols}") + fig["layout"].pop(f"yaxis{nrows * ncols}") # allocate 400 px per plot (equal aspect ratio) fig["layout"].update( @@ -642,10 +642,10 @@ def tile_cross_validation( # update subplot title size and the axis labels for i, ant in enumerate(fig["layout"]["annotations"]): ant["font"].update(size=12) - fig["layout"]["xaxis{}".format(i + 1)].update( + fig["layout"][f"xaxis{i + 1}"].update( title="Actual Outcome", mirror=True, linecolor="black", linewidth=0.5 ) - fig["layout"]["yaxis{}".format(i + 1)].update( + fig["layout"][f"yaxis{i + 1}"].update( title="Predicted Outcome", mirror=True, linecolor="black", linewidth=0.5 ) @@ -659,9 +659,9 @@ def interact_batch_comparison( batch_x: int, batch_y: int, rel: bool = False, - status_quo_name: Optional[str] = None, - x_label: Optional[str] = None, - y_label: Optional[str] = None, + status_quo_name: str | None = None, + x_label: str | None = None, + y_label: str | None = None, ) -> AxPlotConfig: """Compare repeated arms from two trials; select metric via dropdown. diff --git a/ax/plot/feature_importances.py b/ax/plot/feature_importances.py index 77eac4f1cb5..2c4253d80ca 100644 --- a/ax/plot/feature_importances.py +++ b/ax/plot/feature_importances.py @@ -7,7 +7,7 @@ # pyre-strict from logging import Logger -from typing import Any, Optional, Union +from typing import Any import numpy as np import pandas as pd @@ -101,12 +101,12 @@ def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig: def plot_feature_importance_by_feature_plotly( - model: Optional[ModelBridge] = None, - sensitivity_values: Optional[dict[str, dict[str, Union[float, np.ndarray]]]] = None, + model: ModelBridge | None = None, + sensitivity_values: dict[str, dict[str, float | np.ndarray]] | None = None, relative: bool = False, caption: str = "", importance_measure: str = "", - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, ) -> go.Figure: """One plot per metric, showing importances by feature. @@ -277,12 +277,12 @@ def plot_feature_importance_by_feature_plotly( def plot_feature_importance_by_feature( - model: Optional[ModelBridge] = None, - sensitivity_values: Optional[dict[str, dict[str, Union[float, np.ndarray]]]] = None, + model: ModelBridge | None = None, + sensitivity_values: dict[str, dict[str, float | np.ndarray]] | None = None, relative: bool = False, caption: str = "", importance_measure: str = "", - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Wrapper method to convert `plot_feature_importance_by_feature_plotly` to AxPlotConfig""" @@ -311,7 +311,7 @@ def plot_relative_feature_importance_plotly(model: ModelBridge) -> go.Figure: importances.append(vals) except Exception: logger.warning( - "Model for {} does not support feature importances.".format(metric_name) + f"Model for {metric_name} does not support feature importances." 
) df = pd.DataFrame(importances) df.set_index("index", inplace=True) diff --git a/ax/plot/helper.py b/ax/plot/helper.py index 236085738fb..756bed1d14a 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -8,9 +8,10 @@ import math from collections import Counter +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np from ax.core.generator_run import GeneratorRun @@ -70,7 +71,7 @@ def _format_dict(param_dict: TParameterization, name: str = "Parameterization") ) else: blob = "<br><em>{}:</em><br>{}".format( - name, "<br>".join("{}: {}".format(n, v) for n, v in param_dict.items()) + name, "<br>".join(f"{n}: {v}" for n, v in param_dict.items()) ) return blob @@ -108,7 +109,7 @@ def _format_CI(estimate: float, sd: float, relative: bool, zval: float = Z) -> s ) -def arm_name_to_tuple(arm_name: str) -> Union[tuple[int, int], tuple[int]]: +def arm_name_to_tuple(arm_name: str) -> tuple[int, int] | tuple[int]: tup = arm_name.split("_") if len(tup) == 2: try: @@ -149,9 +150,9 @@ def _filter_dict( def _get_in_sample_arms( model: ModelBridge, metric_names: set[str], - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - scalarized_metric_config: Optional[list[dict[str, dict[str, float]]]] = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + scalarized_metric_config: list[dict[str, dict[str, float]]] | None = None, ) -> tuple[dict[str, PlotInSampleArm], RawData, dict[str, TParameterization]]: """Get in-sample arms from a model with observed and predicted values for specified metrics. @@ -283,8 +284,8 @@ def _get_out_of_sample_arms( model: ModelBridge, generator_runs_dict: dict[str, GeneratorRun], metric_names: set[str], - fixed_features: Optional[ObservationFeatures] = None, - scalarized_metric_config: Optional[list[dict[str, dict[str, float]]]] = None, + fixed_features: ObservationFeatures | None = None, + scalarized_metric_config: list[dict[str, dict[str, float]]] | None = None, ) -> dict[str, dict[str, PlotOutOfSampleArm]]: """Get out-of-sample predictions from a model given a dict of generator runs. @@ -336,10 +337,10 @@ def _get_out_of_sample_arms( def get_plot_data( model: ModelBridge, generator_runs_dict: dict[str, GeneratorRun], - metric_names: Optional[set[str]] = None, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - scalarized_metric_config: Optional[list[dict[str, dict[str, float]]]] = None, + metric_names: set[str] | None = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + scalarized_metric_config: list[dict[str, dict[str, float]]] | None = None, ) -> tuple[PlotData, RawData, dict[str, TParameterization]]: """Format data object with metrics for in-sample and out-of-sample arms. @@ -482,8 +483,8 @@ def get_grid_for_parameter(parameter: RangeParameter, density: int) -> np.ndarra def get_fixed_values( model: ModelBridge, - slice_values: Optional[dict[str, Any]] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + trial_index: int | None = None, ) -> TParameterization: """Get fixed values for parameters in a slice plot. diff --git a/ax/plot/marginal_effects.py b/ax/plot/marginal_effects.py index dfcb10162fb..84d439bc789 100644 --- a/ax/plot/marginal_effects.py +++ b/ax/plot/marginal_effects.py @@ -67,7 +67,7 @@ def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig: fig.layout.title = "Marginal Effects by Factor" fig.layout.yaxis = { "title": "% higher than experiment average", - "hoverformat": ".{}f".format(DECIMALS), + "hoverformat": f".{DECIMALS}f", } # pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`. 
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) diff --git a/ax/plot/parallel_coordinates.py b/ax/plot/parallel_coordinates.py index df7208c1bea..fcbb93e9917 100644 --- a/ax/plot/parallel_coordinates.py +++ b/ax/plot/parallel_coordinates.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional import pandas as pd from ax.core.experiment import Experiment @@ -17,7 +16,7 @@ def prepare_experiment_for_plotting( experiment: Experiment, - ignored_names: Optional[list[str]] = None, + ignored_names: list[str] | None = None, ) -> pd.DataFrame: """Strip variables not desired in the final plot and truncate names for readability @@ -48,7 +47,7 @@ def prepare_experiment_for_plotting( def plot_parallel_coordinates_plotly( - experiment: Experiment, ignored_names: Optional[list[str]] = None + experiment: Experiment, ignored_names: list[str] | None = None ) -> go.Figure: """Plot trials as a parallel coordinates graph @@ -69,7 +68,7 @@ def plot_parallel_coordinates_plotly( def plot_parallel_coordinates( - experiment: Experiment, ignored_names: Optional[list[str]] = None + experiment: Experiment, ignored_names: list[str] | None = None ) -> AxPlotConfig: """Plot trials as a parallel coordinates graph diff --git a/ax/plot/pareto_frontier.py b/ax/plot/pareto_frontier.py index 716a883297e..dedfdfc9892 100644 --- a/ax/plot/pareto_frontier.py +++ b/ax/plot/pareto_frontier.py @@ -8,7 +8,6 @@ import warnings from collections.abc import Iterable -from typing import Optional, Union import numpy as np import pandas as pd @@ -36,7 +35,7 @@ def _make_label( - mean: float, sem: float, name: str, is_relative: bool, Z: Optional[float] + mean: float, sem: float, name: str, is_relative: bool, Z: float | None ) -> str: estimate = str(round(mean, DECIMALS)) perc = "%" if is_relative else "" @@ -82,12 +81,12 @@ def scatter_plot_with_hypervolume_trace_plotly(experiment: Experiment) -> go.Fig def scatter_plot_with_pareto_frontier_plotly( Y: np.ndarray, - Y_pareto: Optional[np.ndarray], - metric_x: Optional[str], - metric_y: Optional[str], - reference_point: Optional[tuple[float, float]], - minimize: Optional[Union[bool, tuple[bool, bool]]] = True, - hovertext: Optional[Iterable[str]] = None, + Y_pareto: np.ndarray | None, + metric_x: str | None, + metric_y: str | None, + reference_point: tuple[float, float] | None, + minimize: bool | tuple[bool, bool] | None = True, + hovertext: Iterable[str] | None = None, ) -> go.Figure: """Plots a scatter of all points in ``Y`` for ``metric_x`` and ``metric_y`` with a reference point and Pareto frontier from ``Y_pareto``. 
@@ -530,7 +529,7 @@ def interact_pareto_frontier( frontier_list: list[ParetoFrontierResults], CI_level: float = DEFAULT_CI_LEVEL, show_parameterization_on_hover: bool = True, - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Plot a pareto frontier from a list of objects @@ -778,10 +777,10 @@ def interact_multiple_pareto_frontier( def _pareto_frontier_plot_input_processing( experiment: Experiment, - metric_names: Optional[tuple[str, str]] = None, - reference_point: Optional[tuple[float, float]] = None, - minimize: Optional[Union[bool, tuple[bool, bool]]] = None, -) -> tuple[tuple[str, str], Optional[tuple[float, float]], Optional[tuple[bool, bool]]]: + metric_names: tuple[str, str] | None = None, + reference_point: tuple[float, float] | None = None, + minimize: bool | tuple[bool, bool] | None = None, +) -> tuple[tuple[str, str], tuple[float, float] | None, tuple[bool, bool] | None]: """Processes inputs for Pareto frontier + scatterplot. Args: @@ -830,10 +829,10 @@ def _pareto_frontier_plot_input_processing( def _validate_experiment_and_get_optimization_config( experiment: Experiment, - metric_names: Optional[tuple[str, str]] = None, - reference_point: Optional[tuple[float, float]] = None, - minimize: Optional[Union[bool, tuple[bool, bool]]] = None, -) -> Optional[OptimizationConfig]: + metric_names: tuple[str, str] | None = None, + reference_point: tuple[float, float] | None = None, + minimize: bool | tuple[bool, bool] | None = None, +) -> OptimizationConfig | None: # If `optimization_config` is unspecified, check what inputs are missing and # error/warn accordingly if experiment.optimization_config is None: @@ -856,8 +855,8 @@ def _validate_experiment_and_get_optimization_config( def _validate_and_maybe_get_default_metric_names( - metric_names: Optional[tuple[str, str]], - optimization_config: Optional[OptimizationConfig], + metric_names: tuple[str, str] | None, + optimization_config: OptimizationConfig | None, ) -> tuple[str, str]: # Default metric_names is all metrics, producing an error if more than 2 if metric_names is None: @@ -883,9 +882,9 @@ def _validate_and_maybe_get_default_metric_names( def _validate_experiment_and_maybe_get_objective_thresholds( - optimization_config: Optional[OptimizationConfig], + optimization_config: OptimizationConfig | None, metric_names: tuple[str, str], - reference_point: Optional[tuple[float, float]], + reference_point: tuple[float, float] | None, ) -> list[ObjectiveThreshold]: objective_thresholds = [] # Validate `objective_thresholds` if `reference_point` is unspecified. 
@@ -920,10 +919,10 @@ def _validate_experiment_and_maybe_get_objective_thresholds( def _validate_and_maybe_get_default_reference_point( - reference_point: Optional[tuple[float, float]], + reference_point: tuple[float, float] | None, objective_thresholds: list[ObjectiveThreshold], metric_names: tuple[str, str], -) -> Optional[tuple[float, float]]: +) -> tuple[float, float] | None: if reference_point is None: reference_point = { objective_threshold.metric.name: objective_threshold.bound @@ -953,11 +952,11 @@ def _validate_and_maybe_get_default_reference_point( def _validate_and_maybe_get_default_minimize( - minimize: Optional[Union[bool, tuple[bool, bool]]], + minimize: bool | tuple[bool, bool] | None, objective_thresholds: list[ObjectiveThreshold], metric_names: tuple[str, str], - optimization_config: Optional[OptimizationConfig] = None, -) -> Optional[tuple[bool, bool]]: + optimization_config: OptimizationConfig | None = None, +) -> tuple[bool, bool] | None: if minimize is None: # Determine `minimize` defaults minimize = tuple( @@ -995,8 +994,8 @@ def _validate_and_maybe_get_default_minimize( def _maybe_get_default_minimize_single_metric( metric_name: str, objective_thresholds: list[ObjectiveThreshold], - optimization_config: Optional[OptimizationConfig] = None, -) -> Optional[bool]: + optimization_config: OptimizationConfig | None = None, +) -> bool | None: minimize = None # First try to get metric_name from optimization_config if ( diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index e8e37b63e6f..19aed4e6687 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -9,7 +9,7 @@ from copy import deepcopy from itertools import combinations from logging import Logger -from typing import NamedTuple, Optional, Union +from typing import NamedTuple import numpy as np import torch @@ -54,8 +54,8 @@ def _extract_observed_pareto_2d( Y: np.ndarray, - reference_point: Optional[tuple[float, float]], - minimize: Union[bool, tuple[bool, bool]] = True, + reference_point: tuple[float, float] | None, + minimize: bool | tuple[bool, bool] = True, ) -> np.ndarray: if Y.shape[1] != 2: raise NotImplementedError("Currently only the 2-dim case is handled.") @@ -113,8 +113,8 @@ class ParetoFrontierResults(NamedTuple): primary_metric: str secondary_metric: str absolute_metrics: list[str] - objective_thresholds: Optional[dict[str, float]] - arm_names: Optional[list[Optional[str]]] + objective_thresholds: dict[str, float] | None + arm_names: list[str | None] | None def _extract_sq_data( @@ -162,9 +162,9 @@ def _relativize_values( def get_observed_pareto_frontiers( experiment: Experiment, - data: Optional[Data] = None, - rel: Optional[bool] = None, - arm_names: Optional[list[str]] = None, + data: Data | None = None, + rel: bool | None = None, + arm_names: list[str] | None = None, ) -> list[ParetoFrontierResults]: """ Find all Pareto points from an experiment. @@ -342,11 +342,11 @@ def compute_posterior_pareto_frontier( experiment: Experiment, primary_objective: Metric, secondary_objective: Metric, - data: Optional[Data] = None, - outcome_constraints: Optional[list[OutcomeConstraint]] = None, - absolute_metrics: Optional[list[str]] = None, + data: Data | None = None, + outcome_constraints: list[OutcomeConstraint] | None = None, + absolute_metrics: list[str] | None = None, num_points: int = 10, - trial_index: Optional[int] = None, + trial_index: int | None = None, ) -> ParetoFrontierResults: """Compute the Pareto frontier between two objectives. 
For experiments with batch trials, a trial index or data object must be provided. @@ -489,8 +489,8 @@ def _extract_pareto_frontier_results( primary_metric: str, secondary_metric: str, absolute_metrics: list[str], - outcome_constraints: Optional[list[OutcomeConstraint]], - status_quo_prediction: Optional[tuple[Mu, Cov]], + outcome_constraints: list[OutcomeConstraint] | None, + status_quo_prediction: tuple[Mu, Cov] | None, ) -> ParetoFrontierResults: """Extract prediction results into ParetoFrontierResults struture.""" metrics = list(means.keys()) @@ -548,7 +548,7 @@ def _build_scalarized_optimization_config( weights: np.ndarray, primary_objective: Metric, secondary_objective: Metric, - outcome_constraints: Optional[list[OutcomeConstraint]] = None, + outcome_constraints: list[OutcomeConstraint] | None = None, ) -> MultiObjectiveOptimizationConfig: obj = ScalarizedObjective( metrics=[primary_objective, secondary_objective], diff --git a/ax/plot/render.py b/ax/plot/render.py index c249c1f6ffe..b1d7a30f2b7 100644 --- a/ax/plot/render.py +++ b/ax/plot/render.py @@ -97,7 +97,7 @@ def _get_plot_js( def _wrap_js(script: str) -> str: """Wrap JS in <script> tag for injection into HTML.""" - return "<script type='text/javascript'>{script}</script>".format(script=script) + return f"<script type='text/javascript'>{script}</script>" def _plot_js_to_html(js_script: str, plotdivid: str) -> str: diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index cdedfc6a274..9b5d65c7e6a 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -9,10 +9,10 @@ import numbers import warnings from collections import OrderedDict -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from logging import Logger -from typing import Any, Callable, Optional, Union +from typing import Any import numpy as np import plotly.graph_objs as go @@ -55,11 +55,11 @@ def _error_scatter_data( - arms: Iterable[Union[PlotInSampleArm, PlotOutOfSampleArm]], + arms: Iterable[PlotInSampleArm | PlotOutOfSampleArm], y_axis_var: PlotMetric, - x_axis_var: Optional[PlotMetric] = None, - status_quo_arm: Optional[PlotInSampleArm] = None, -) -> tuple[list[float], Optional[list[float]], list[float], list[float]]: + x_axis_var: PlotMetric | None = None, + status_quo_arm: PlotInSampleArm | None = None, +) -> tuple[list[float], list[float] | None, list[float], list[float]]: y_metric_key = "y_hat" if y_axis_var.pred else "y" y_sd_key = "se_hat" if y_axis_var.pred else "se" @@ -106,24 +106,24 @@ def _error_scatter_data( def _error_scatter_trace( - arms: Sequence[Union[PlotInSampleArm, PlotOutOfSampleArm]], + arms: Sequence[PlotInSampleArm | PlotOutOfSampleArm], y_axis_var: PlotMetric, - x_axis_var: Optional[PlotMetric] = None, - y_axis_label: Optional[str] = None, - x_axis_label: Optional[str] = None, - status_quo_arm: Optional[PlotInSampleArm] = None, + x_axis_var: PlotMetric | None = None, + y_axis_label: str | None = None, + x_axis_label: str | None = None, + status_quo_arm: PlotInSampleArm | None = None, show_CI: bool = True, name: str = "In-sample", color: tuple[int] = COLORS.STEELBLUE.value, visible: bool = True, - legendgroup: Optional[str] = None, + legendgroup: str | None = None, showlegend: bool = True, hoverinfo: str = "text", show_arm_details_on_hover: bool = True, show_context: bool = False, arm_noun: str = "arm", - color_parameter: Optional[str] = None, - color_metric: Optional[str] = None, + color_parameter: str | None = None, + color_metric: str | None = None, ) -> dict[str, Any]: """Plot scatterplot with error bars. 
@@ -305,10 +305,10 @@ def _multiple_metric_traces( generator_runs_dict: TNullableGeneratorRunsDict, rel_x: bool, rel_y: bool, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - color_parameter: Optional[str] = None, - color_metric: Optional[str] = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + color_parameter: str | None = None, + color_metric: str | None = None, ) -> Traces: """Plot traces for multiple metrics given a model and metrics. @@ -390,10 +390,10 @@ def plot_multiple_metrics( generator_runs_dict: TNullableGeneratorRunsDict = None, rel_x: bool = True, rel_y: bool = True, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - color_parameter: Optional[str] = None, - color_metric: Optional[str] = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + color_parameter: str | None = None, + color_metric: str | None = None, **kwargs: Any, ) -> AxPlotConfig: """Plot raw values or predictions of two metrics for arms. @@ -543,15 +543,15 @@ def plot_multiple_metrics( def plot_objective_vs_constraints( model: ModelBridge, objective: str, - subset_metrics: Optional[list[str]] = None, + subset_metrics: list[str] | None = None, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, - infer_relative_constraints: Optional[bool] = False, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - color_parameter: Optional[str] = None, - color_metric: Optional[str] = None, - label_dict: Optional[dict[str, str]] = None, + infer_relative_constraints: bool | None = False, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + color_parameter: str | None = None, + color_metric: str | None = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Plot the tradeoff between an objective and all other metrics in a model. @@ -778,7 +778,7 @@ def plot_objective_vs_constraints( return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) -def _replace_str(input_str: str, str_dict: Optional[dict[str, str]] = None) -> str: +def _replace_str(input_str: str, str_dict: dict[str, str] | None = None) -> str: """Utility function to replace a string based on a mapping dictionary. Args: @@ -811,7 +811,7 @@ def lattice_multiple_metrics( generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = False, - data_selector: Optional[Callable[[Observation], bool]] = None, + data_selector: Callable[[Observation], bool] | None = None, ) -> AxPlotConfig: """Plot raw values or predictions of combinations of two metrics for arms. 
@@ -1050,10 +1050,10 @@ def lattice_multiple_metrics( for i, o in enumerate(metrics): pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1 pos_y = 1 + (len(metrics) * i) - fig["layout"]["xaxis{}".format(pos_x)].update( + fig["layout"][f"xaxis{pos_x}"].update( title=_wrap_metric(o), titlefont={"size": 10} ) - fig["layout"]["yaxis{}".format(pos_y)].update( + fig["layout"][f"yaxis{pos_y}"].update( title=_wrap_metric(o), titlefont={"size": 10} ) @@ -1085,9 +1085,9 @@ def _single_metric_traces( showlegend: bool = True, show_CI: bool = True, arm_noun: str = "arm", - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, ) -> Traces: """Plot scatterplots with errors for a single metric (y-axis). @@ -1168,11 +1168,11 @@ def plot_fitted( metric: str, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, - custom_arm_order: Optional[list[str]] = None, + custom_arm_order: list[str] | None = None, custom_arm_order_name: str = "Custom", show_CI: bool = True, - data_selector: Optional[Callable[[Observation], bool]] = None, - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, + data_selector: Callable[[Observation], bool] | None = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, ) -> AxPlotConfig: """Plot fitted metrics. @@ -1306,11 +1306,11 @@ def tile_fitted( show_arm_details_on_hover: bool = False, show_CI: bool = True, arm_noun: str = "arm", - metrics: Optional[list[str]] = None, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, - label_dict: Optional[dict[str, str]] = None, + metrics: list[str] | None = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """Tile version of fitted outcome plots. @@ -1394,16 +1394,16 @@ def tile_fitted( # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial # axis. 
label = "" if i == 0 else i + 1 - name_order_args["xaxis{}.categoryorder".format(label)] = "array" - name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm - effect_order_args["xaxis{}.categoryorder".format(label)] = "array" - effect_order_args["xaxis{}.categoryarray".format(label)] = names_by_effect - name_order_axes["xaxis{}".format(i + 1)] = { + name_order_args[f"xaxis{label}.categoryorder"] = "array" + name_order_args[f"xaxis{label}.categoryarray"] = names_by_arm + effect_order_args[f"xaxis{label}.categoryorder"] = "array" + effect_order_args[f"xaxis{label}.categoryarray"] = names_by_effect + name_order_axes[f"xaxis{i + 1}"] = { "categoryorder": "array", "categoryarray": names_by_arm, "type": "category", } - name_order_axes["yaxis{}".format(i + 1)] = { + name_order_axes[f"yaxis{i + 1}"] = { "ticksuffix": "%" if rel else "", "zerolinecolor": "red", } @@ -1418,8 +1418,8 @@ def tile_fitted( # if odd number of plots, need to manually remove the last blank subplot # generated by `subplots.make_subplots` if len(metrics) % 2 == 1: - fig["layout"].pop("xaxis{}".format(nrows * ncols)) - fig["layout"].pop("yaxis{}".format(nrows * ncols)) + fig["layout"].pop(f"xaxis{nrows * ncols}") + fig["layout"].pop(f"yaxis{nrows * ncols}") # allocate 400 px per plot fig["layout"].update( @@ -1484,11 +1484,11 @@ def interact_fitted_plotly( show_arm_details_on_hover: bool = True, show_CI: bool = True, arm_noun: str = "arm", - metrics: Optional[list[str]] = None, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - label_dict: Optional[dict[str, str]] = None, - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, + metrics: list[str] | None = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + label_dict: dict[str, str] | None = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, ) -> go.Figure: """Interactive fitted outcome plots for each arm used in fitting the model. @@ -1632,11 +1632,11 @@ def interact_fitted( show_arm_details_on_hover: bool = True, show_CI: bool = True, arm_noun: str = "arm", - metrics: Optional[list[str]] = None, - fixed_features: Optional[ObservationFeatures] = None, - data_selector: Optional[Callable[[Observation], bool]] = None, - label_dict: Optional[dict[str, str]] = None, - scalarized_metric_config: Optional[list[dict[str, Any]]] = None, + metrics: list[str] | None = None, + fixed_features: ObservationFeatures | None = None, + data_selector: Callable[[Observation], bool] | None = None, + label_dict: dict[str, str] | None = None, + scalarized_metric_config: list[dict[str, Any]] | None = None, ) -> AxPlotConfig: """Interactive fitted outcome plots for each arm used in fitting the model. @@ -1685,12 +1685,12 @@ def interact_fitted( def tile_observations( experiment: Experiment, - data: Optional[Data] = None, + data: Data | None = None, rel: bool = True, - metrics: Optional[list[str]] = None, - arm_names: Optional[list[str]] = None, + metrics: list[str] | None = None, + arm_names: list[str] | None = None, arm_noun: str = "arm", - label_dict: Optional[dict[str, str]] = None, + label_dict: dict[str, str] | None = None, ) -> AxPlotConfig: """ Tiled plot with all observed outcomes. 
diff --git a/ax/plot/slice.py b/ax/plot/slice.py index df5853c7861..785171f187c 100644 --- a/ax/plot/slice.py +++ b/ax/plot/slice.py @@ -50,9 +50,9 @@ def _get_slice_predictions( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> SlicePredictions: """Computes slice prediction configuration values for a single metric name. @@ -132,9 +132,9 @@ def plot_slice_plotly( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> go.Figure: """Plot predictions for a 1-d slice of the parameter space. @@ -246,9 +246,9 @@ def plot_slice( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> AxPlotConfig: """Plot predictions for a 1-d slice of the parameter space. @@ -295,9 +295,9 @@ def interact_slice_plotly( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> go.Figure: """Create interactive plot with predictions for a 1-d slice of the parameter space. @@ -545,9 +545,9 @@ def interact_slice( generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, - slice_values: Optional[dict[str, Any]] = None, - fixed_features: Optional[ObservationFeatures] = None, - trial_index: Optional[int] = None, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, ) -> AxPlotConfig: """Create interactive plot with predictions for a 1-d slice of the parameter space. 
diff --git a/ax/plot/table_view.py b/ax/plot/table_view.py index fef9cdbbd9c..b8074a180f1 100644 --- a/ax/plot/table_view.py +++ b/ax/plot/table_view.py @@ -122,10 +122,7 @@ def table_view_plot( ] ) records.append( - [ - "{:.3f} ± {:.3f}".format(y, Z * y_se) - for (_, y, y_se) in results_by_arm - ] + [f"{y:.3f} ± {Z * y_se:.3f}" for (_, y, y_se) in results_by_arm] ) records_with_mean.append({arm_name: y for (arm_name, y, _) in results_by_arm}) records_with_ci.append( diff --git a/ax/plot/tests/test_tile_fitted.py b/ax/plot/tests/test_tile_fitted.py index 92bd313bb1c..b68b8a13443 100644 --- a/ax/plot/tests/test_tile_fitted.py +++ b/ax/plot/tests/test_tile_fitted.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional from unittest import mock from ax.core.arm import Arm @@ -38,7 +37,7 @@ def get_modelbridge( mock_gen_arms, # pyre-fixme[2]: Parameter must be annotated. mock_observations_from_data, - status_quo_name: Optional[str] = None, + status_quo_name: str | None = None, ) -> ModelBridge: exp = get_experiment() modelbridge = ModelBridge( diff --git a/ax/plot/trace.py b/ax/plot/trace.py index 36dbb9ef477..83691695cc7 100644 --- a/ax/plot/trace.py +++ b/ax/plot/trace.py @@ -7,7 +7,7 @@ # pyre-strict from datetime import datetime, timedelta -from typing import Any, Optional, Union +from typing import Any import numpy as np import pandas as pd @@ -102,7 +102,7 @@ def map_data_multiple_metrics_dropdown_plotly( legend_labels_by_metric: dict[str, list[str]], stopping_markers_by_metric: dict[str, list[bool]], xlabels_by_metric: dict[str, str], - lower_is_better_by_metric: dict[str, Optional[bool]], + lower_is_better_by_metric: dict[str, bool | None], opacity: float = 0.75, color_map: str = "viridis", autoset_axis_limits: bool = True, @@ -212,7 +212,7 @@ def mean_trace_scatter( y: np.ndarray, trace_color: tuple[int] = COLORS.STEELBLUE.value, legend_label: str = "mean", - hover_labels: Optional[list[str]] = None, + hover_labels: list[str] | None = None, ) -> go.Scatter: """Creates a graph object for trace of the mean of the given series across runs. @@ -288,7 +288,7 @@ def mean_markers_scatter( y: np.ndarray, marker_color: tuple[int] = COLORS.LIGHT_PURPLE.value, legend_label: str = "", - hover_labels: Optional[list[str]] = None, + hover_labels: list[str] | None = None, ) -> go.Scatter: """Creates a graph object for trace of the mean of the given series across runs, with errorbars. @@ -347,15 +347,15 @@ def optimum_objective_scatter( def optimization_trace_single_method_plotly( y: np.ndarray, - optimum: Optional[float] = None, - model_transitions: Optional[list[int]] = None, + optimum: float | None = None, + model_transitions: list[int] | None = None, title: str = "", ylabel: str = "", - hover_labels: Optional[list[str]] = None, + hover_labels: list[str] | None = None, trace_color: tuple[int] = COLORS.STEELBLUE.value, optimum_color: tuple[int] = COLORS.ORANGE.value, generator_change_color: tuple[int] = COLORS.TEAL.value, - optimization_direction: Optional[str] = "passthrough", + optimization_direction: str | None = "passthrough", plot_trial_points: bool = False, trial_points_color: tuple[int] = COLORS.LIGHT_PURPLE.value, autoset_axis_limits: bool = True, @@ -451,7 +451,7 @@ def optimization_trace_single_method_plotly( def _autoset_axis_limits( y: np.ndarray, optimization_direction: str, - force_include_value: Optional[float] = None, + force_include_value: float | None = None, ) -> list[float]: """Provides automatic axis limits based on the data and optimization direction. 
All best points are included in this range, and by default the worst points are @@ -482,15 +482,15 @@ def _autoset_axis_limits( def optimization_trace_single_method( y: np.ndarray, - optimum: Optional[float] = None, - model_transitions: Optional[list[int]] = None, + optimum: float | None = None, + model_transitions: list[int] | None = None, title: str = "", ylabel: str = "", - hover_labels: Optional[list[str]] = None, + hover_labels: list[str] | None = None, trace_color: tuple[int] = COLORS.STEELBLUE.value, optimum_color: tuple[int] = COLORS.ORANGE.value, generator_change_color: tuple[int] = COLORS.TEAL.value, - optimization_direction: Optional[str] = "passthrough", + optimization_direction: str | None = "passthrough", plot_trial_points: bool = False, trial_points_color: tuple[int] = COLORS.LIGHT_PURPLE.value, autoset_axis_limits: bool = True, @@ -549,10 +549,10 @@ def optimization_trace_single_method( def optimization_trace_all_methods( y_dict: dict[str, np.ndarray], - optimum: Optional[float] = None, + optimum: float | None = None, title: str = "", ylabel: str = "", - hover_labels: Optional[list[str]] = None, + hover_labels: list[str] | None = None, trace_colors: list[tuple[int]] = DISCRETE_COLOR_SCALE, optimum_color: tuple[int] = COLORS.ORANGE.value, ) -> AxPlotConfig: @@ -627,12 +627,12 @@ def optimization_times( """ # Compute means and SEs methods = list(fit_times.keys()) - fit_res: dict[str, Union[str, list[float]]] = {"name": "Fitting"} + fit_res: dict[str, str | list[float]] = {"name": "Fitting"} fit_res["mean"] = [np.mean(fit_times[m]) for m in methods] fit_res["2sems"] = [ 2 * np.std(fit_times[m]) / np.sqrt(len(fit_times[m])) for m in methods ] - gen_res: dict[str, Union[str, list[float]]] = {"name": "Generation"} + gen_res: dict[str, str | list[float]] = {"name": "Generation"} gen_res["mean"] = [np.mean(gen_times[m]) for m in methods] gen_res["2sems"] = [ 2 * np.std(gen_times[m]) / np.sqrt(len(gen_times[m])) for m in methods @@ -643,7 +643,7 @@ def optimization_times( totals = np.array(fit_times[m]) + np.array(gen_times[m]) total_mean.append(np.mean(totals)) total_2sems.append(2 * np.std(totals) / np.sqrt(len(totals))) - total_res: dict[str, Union[str, list[float]]] = { + total_res: dict[str, str | list[float]] = { "name": "Total", "mean": total_mean, "2sems": total_2sems, @@ -688,7 +688,7 @@ def get_running_trials_per_minute( experiment: Experiment, show_until_latest_end_plus_timedelta: timedelta = FIVE_MINUTES, ) -> AxPlotConfig: - trial_runtimes: list[tuple[int, datetime, Optional[datetime]]] = [ + trial_runtimes: list[tuple[int, datetime, datetime | None]] = [ ( trial.index, not_none(trial._time_run_started), @@ -739,8 +739,8 @@ def plot_objective_value_vs_trial_index( exp_df: pd.DataFrame, metric_colname: str, minimize: bool, - title: Optional[str] = None, - hover_data_colnames: Optional[list[str]] = None, + title: str | None = None, + hover_data_colnames: list[str] | None = None, autoset_axis_limits: bool = True, ) -> go.Figure: """Returns a plotly figure showing the optimization trace for a single metric. @@ -818,7 +818,7 @@ def compute_running_feasible_optimum_df( exp_df: pd.DataFrame, metric_colname: str, minimize: bool, - is_feasible_colname: Optional[str], + is_feasible_colname: str | None, ) -> pd.DataFrame: """Computes the running feasible optimum for a given metric.""" # If feasibility column is not provided, assume all arms are feasible. 
diff --git a/ax/runners/simulated_backend.py b/ax/runners/simulated_backend.py index a0d9cdbf4bc..d15cbe2de72 100644 --- a/ax/runners/simulated_backend.py +++ b/ax/runners/simulated_backend.py @@ -8,8 +8,8 @@ from collections import defaultdict -from collections.abc import Iterable -from typing import Any, Callable, Optional +from collections.abc import Callable, Iterable +from typing import Any import numpy as np from ax.core.base_trial import BaseTrial, TrialStatus @@ -23,7 +23,7 @@ class SimulatedBackendRunner(Runner): def __init__( self, simulator: BackendSimulator, - sample_runtime_func: Optional[Callable[[BaseTrial], float]] = None, + sample_runtime_func: Callable[[BaseTrial], float] | None = None, ) -> None: """Runner for a BackendSimulator. @@ -67,7 +67,7 @@ def run(self, trial: BaseTrial) -> dict[str, Any]: self.simulator.run_trial(trial_index=trial.index, runtime=runtime) return {"runtime": runtime} - def stop(self, trial: BaseTrial, reason: Optional[str] = None) -> dict[str, Any]: + def stop(self, trial: BaseTrial, reason: str | None = None) -> dict[str, Any]: """Stop a trial on the BackendSimulator. Args: diff --git a/ax/runners/synthetic.py b/ax/runners/synthetic.py index 91fb93db0e8..fc08c9a35b5 100644 --- a/ax/runners/synthetic.py +++ b/ax/runners/synthetic.py @@ -7,7 +7,7 @@ # pyre-strict from collections.abc import Iterable -from typing import Any, Optional +from typing import Any from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.runner import Runner @@ -19,7 +19,7 @@ class SyntheticRunner(Runner): Currently acts as a shell runner, only creating a name. """ - def __init__(self, dummy_metadata: Optional[str] = None) -> None: + def __init__(self, dummy_metadata: str | None = None) -> None: self.dummy_metadata = dummy_metadata def run(self, trial: BaseTrial) -> dict[str, Any]: diff --git a/ax/runners/torchx.py b/ax/runners/torchx.py index 0722fb17f9a..fbf7bc211e2 100644 --- a/ax/runners/torchx.py +++ b/ax/runners/torchx.py @@ -7,10 +7,10 @@ # pyre-strict import inspect -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from logging import Logger -from typing import Any, Callable, Optional +from typing import Any from ax.core import Trial from ax.core.base_trial import BaseTrial, TrialStatus @@ -117,13 +117,13 @@ def __init__( self, tracker_base: str, component: Callable[..., AppDef], - component_const_params: Optional[dict[str, Any]] = None, + component_const_params: dict[str, Any] | None = None, scheduler: str = "local", - cfg: Optional[Mapping[str, CfgVal]] = None, + cfg: Mapping[str, CfgVal] | None = None, ) -> None: self._component: Callable[..., AppDef] = component self._scheduler: str = scheduler - self._cfg: Optional[Mapping[str, CfgVal]] = cfg + self._cfg: Mapping[str, CfgVal] | None = cfg # need to use the same runner in case it has state # e.g. 
torchx's local_scheduler has state hence need to poll status # on the same scheduler instance @@ -180,9 +180,7 @@ def poll_trial_status( return trial_statuses - def stop( - self, trial: BaseTrial, reason: Optional[str] = None - ) -> dict[str, Any]: + def stop(self, trial: BaseTrial, reason: str | None = None) -> dict[str, Any]: """Kill the given trial.""" app_handle: str = trial.run_metadata[TORCHX_APP_HANDLE] self._torchx_runner.stop(app_handle) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 0e508b3b546..4e1f4982a86 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -9,11 +9,11 @@ import json import logging import warnings -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial from logging import Logger -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar import ax.service.utils.early_stopping as early_stopping_utils import numpy as np @@ -174,19 +174,19 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase): whether the full optimization should be stopped or not. """ - _experiment: Optional[Experiment] = None + _experiment: Experiment | None = None def __init__( self, - generation_strategy: Optional[GenerationStrategy] = None, - db_settings: Optional[DBSettings] = None, + generation_strategy: GenerationStrategy | None = None, + db_settings: DBSettings | None = None, enforce_sequential_optimization: bool = True, - random_seed: Optional[int] = None, - torch_device: Optional[torch.device] = None, + random_seed: int | None = None, + torch_device: torch.device | None = None, verbose_logging: bool = True, suppress_storage_errors: bool = False, - early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] = None, - global_stopping_strategy: Optional[BaseGlobalStoppingStrategy] = None, + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, + global_stopping_strategy: BaseGlobalStoppingStrategy | None = None, ) -> None: super().__init__( db_settings=db_settings, @@ -233,23 +233,23 @@ def __init__( def create_experiment( self, parameters: list[ - dict[str, Union[TParamValue, Sequence[TParamValue], dict[str, list[str]]]] + dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]] ], - name: Optional[str] = None, - description: Optional[str] = None, - owners: Optional[list[str]] = None, - objectives: Optional[dict[str, ObjectiveProperties]] = None, - parameter_constraints: Optional[list[str]] = None, - outcome_constraints: Optional[list[str]] = None, - status_quo: Optional[TParameterization] = None, + name: str | None = None, + description: str | None = None, + owners: list[str] | None = None, + objectives: dict[str, ObjectiveProperties] | None = None, + parameter_constraints: list[str] | None = None, + outcome_constraints: list[str] | None = None, + status_quo: TParameterization | None = None, overwrite_existing_experiment: bool = False, - experiment_type: Optional[str] = None, - tracking_metric_names: Optional[list[str]] = None, - choose_generation_strategy_kwargs: Optional[dict[str, Any]] = None, + experiment_type: str | None = None, + tracking_metric_names: list[str] | None = None, + choose_generation_strategy_kwargs: dict[str, Any] | None = None, support_intermediate_data: bool = False, immutable_search_space_and_opt_config: bool = True, is_test: bool = False, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> 
None: """Create a new experiment and save it if DBSettings available. @@ -356,13 +356,13 @@ def create_experiment( self._save_generation_strategy_to_db_if_possible() @property - def status_quo(self) -> Optional[TParameterization]: + def status_quo(self) -> TParameterization | None: """The parameterization of the status quo arm of the experiment.""" if self.experiment.status_quo: return self.experiment.status_quo.parameters return None - def set_status_quo(self, params: Optional[TParameterization]) -> None: + def set_status_quo(self, params: TParameterization | None) -> None: """Set, or unset status quo on the experiment. There may be risk in using this after a trial with the status quo arm has run. @@ -375,9 +375,9 @@ def set_status_quo(self, params: Optional[TParameterization]) -> None: def set_optimization_config( self, - objectives: Optional[dict[str, ObjectiveProperties]] = None, - outcome_constraints: Optional[list[str]] = None, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + objectives: dict[str, ObjectiveProperties] | None = None, + outcome_constraints: list[str] | None = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> None: """Overwrite experiment's optimization config @@ -414,7 +414,7 @@ def set_optimization_config( def add_tracking_metrics( self, metric_names: list[str], - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> None: """Add a list of new metrics to the experiment. @@ -449,9 +449,9 @@ def remove_tracking_metric(self, metric_name: str) -> None: def set_search_space( self, parameters: list[ - dict[str, Union[TParamValue, Sequence[TParamValue], dict[str, list[str]]]] + dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]] ], - parameter_constraints: Optional[list[str]] = None, + parameter_constraints: list[str] | None = None, ) -> None: """Sets the search space on the experiment and saves. This is expected to fail on base AxClient as experiment will have @@ -500,9 +500,9 @@ def set_search_space( ) def get_next_trial( self, - ttl_seconds: Optional[int] = None, + ttl_seconds: int | None = None, force: bool = False, - fixed_features: Optional[FixedFeatures] = None, + fixed_features: FixedFeatures | None = None, ) -> tuple[TParameterization, int]: """ Generate trial with the next set of parameters to try in the iteration process. @@ -595,8 +595,8 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]: def get_next_trials( self, max_trials: int, - ttl_seconds: Optional[int] = None, - fixed_features: Optional[FixedFeatures] = None, + ttl_seconds: int | None = None, + fixed_features: FixedFeatures | None = None, ) -> tuple[dict[int, TParameterization], bool]: """Generate as many trials as currently possible. @@ -650,7 +650,7 @@ def get_next_trials( _, optimization_complete = self.get_current_trial_generation_limit() return trials_dict, optimization_complete - def abandon_trial(self, trial_index: int, reason: Optional[str] = None) -> None: + def abandon_trial(self, trial_index: int, reason: str | None = None) -> None: """Abandons a trial and adds optional metadata to it. 
Args: @@ -663,8 +663,8 @@ def update_running_trial_with_intermediate_data( self, trial_index: int, raw_data: TEvaluationOutcome, - metadata: Optional[dict[str, Union[str, int]]] = None, - sample_size: Optional[int] = None, + metadata: dict[str, str | int] | None = None, + sample_size: int | None = None, ) -> None: """ Updates the trial with given metric values without completing it. Also @@ -733,8 +733,8 @@ def complete_trial( self, trial_index: int, raw_data: TEvaluationOutcome, - metadata: Optional[dict[str, Union[str, int]]] = None, - sample_size: Optional[int] = None, + metadata: dict[str, str | int] | None = None, + sample_size: int | None = None, ) -> None: """ Completes the trial with given metric values and adds optional metadata @@ -784,8 +784,8 @@ def update_trial_data( self, trial_index: int, raw_data: TEvaluationOutcome, - metadata: Optional[dict[str, Union[str, int]]] = None, - sample_size: Optional[int] = None, + metadata: dict[str, str | int] | None = None, + sample_size: int | None = None, ) -> None: """ Attaches additional data or updates the existing data for a trial in a @@ -826,7 +826,7 @@ def update_trial_data( logger.info(f"Added data: {data_update_repr} to trial {trial.index}.") def log_trial_failure( - self, trial_index: int, metadata: Optional[dict[str, str]] = None + self, trial_index: int, metadata: dict[str, str] | None = None ) -> None: """Mark that the given trial has failed while running. @@ -846,9 +846,9 @@ def log_trial_failure( def attach_trial( self, parameters: TParameterization, - ttl_seconds: Optional[int] = None, - run_metadata: Optional[dict[str, Any]] = None, - arm_name: Optional[str] = None, + ttl_seconds: int | None = None, + run_metadata: dict[str, Any] | None = None, + arm_name: str | None = None, ) -> tuple[TParameterization, int]: """Attach a new trial with the given parameterization to the experiment. @@ -922,7 +922,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: return parallelism_settings def get_optimization_trace( - self, objective_optimum: Optional[float] = None + self, objective_optimum: float | None = None ) -> AxPlotConfig: """Retrieves the plot configuration for optimization trace, which shows the evolution of the objective mean over iterations. @@ -978,9 +978,9 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float: def get_contour_plot( self, - param_x: Optional[str] = None, - param_y: Optional[str] = None, - metric_name: Optional[str] = None, + param_x: str | None = None, + param_y: str | None = None, + metric_name: str | None = None, ) -> AxPlotConfig: """Retrieves a plot configuration for a contour plot of the response surface. For response surfaces with more than two parameters, @@ -1099,7 +1099,7 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig: def load_experiment_from_database( self, experiment_name: str, - choose_generation_strategy_kwargs: Optional[dict[str, Any]] = None, + choose_generation_strategy_kwargs: dict[str, Any] | None = None, ) -> None: """Load an existing experiment from database using the `DBSettings` passed to this `AxClient` on instantiation. @@ -1132,7 +1132,7 @@ def load_experiment_from_database( def get_model_predictions_for_parameterizations( self, parameterizations: list[TParameterization], - metric_names: Optional[list[str]] = None, + metric_names: list[str] | None = None, ) -> list[dict[str, tuple[float, float]]]: """Retrieve model-estimated means and covariances for all metrics for the provided parameterizations. 
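Alongside the union rewrite, the ax_client.py import hunk above moves `Callable` from `typing` to `collections.abc`. That follows PEP 585 (Python 3.9), which made the `collections.abc` ABCs subscriptable and deprecated the `typing` aliases. A small standalone sketch; the names are illustrative, not from Ax:

    from collections.abc import Callable, Iterable

    def apply_all(fn: Callable[[int], int], xs: Iterable[int]) -> list[int]:
        return [fn(x) for x in xs]

    assert apply_all(lambda x: x + 1, range(3)) == [1, 2, 3]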
@@ -1162,9 +1162,9 @@ def get_model_predictions_for_parameterizations( def get_model_predictions( self, - metric_names: Optional[list[str]] = None, - include_out_of_sample: Optional[bool] = True, - parameterizations: Optional[dict[int, TParameterization]] = None, + metric_names: list[str] | None = None, + include_out_of_sample: bool | None = True, + parameterizations: dict[int, TParameterization] | None = None, ) -> dict[int, dict[str, tuple[float, float]]]: """Retrieve model-estimated means and covariances for all metrics. @@ -1280,7 +1280,7 @@ def verify_trial_parameterization( def should_stop_trials_early( self, trial_indices: set[int] - ) -> dict[int, Optional[str]]: + ) -> dict[int, str | None]: """Evaluate whether to early-stop running trials. Args: @@ -1309,7 +1309,7 @@ def stop_trial_early(self, trial_index: int) -> None: experiment=self.experiment, trial=trial ) - def estimate_early_stopping_savings(self, map_key: Optional[str] = None) -> float: + def estimate_early_stopping_savings(self, map_key: str | None = None) -> float: """Estimate early stopping savings using progressions of the MapMetric present on the EarlyStoppingConfig as a proxy for resource usage. @@ -1367,7 +1367,7 @@ def load_from_json_file( """Restore an `AxClient` and its state from a JSON-serialized snapshot, residing in a .json file by the given path. """ - with open(filepath, "r") as file: + with open(filepath) as file: serialized = json.loads(file.read()) return cls.from_json_snapshot(serialized=serialized, **kwargs) @@ -1376,13 +1376,13 @@ def to_json_snapshot( # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. - encoder_registry: Optional[dict[type, Callable[[Any], dict[str, Any]]]] = None, + encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] | None = None, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. - class_encoder_registry: Optional[ + class_encoder_registry: None | ( dict[type, Callable[[Any], dict[str, Any]]] - ] = None, + ) = None, ) -> dict[str, Any]: """Serialize this `AxClient` to JSON to be able to interrupt and restart optimization and save it to file by the provided path. @@ -1415,11 +1415,11 @@ def to_json_snapshot( def from_json_snapshot( cls: type[AxClientSubclass], serialized: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - class_decoder_registry: Optional[ + class_decoder_registry: None | ( dict[str, Callable[[dict[str, Any]], Any]] - ] = None, + ) = None, # pyre-fixme[2]: Parameter must be annotated. 
**kwargs, ) -> AxClientSubclass: @@ -1519,7 +1519,7 @@ def metric_names(self) -> set[str]: return set(self.experiment.metrics) @property - def early_stopping_strategy(self) -> Optional[BaseEarlyStoppingStrategy]: + def early_stopping_strategy(self) -> BaseEarlyStoppingStrategy | None: """The early stopping strategy used on the experiment.""" return self._early_stopping_strategy @@ -1529,7 +1529,7 @@ def early_stopping_strategy(self, ess: BaseEarlyStoppingStrategy) -> None: self._early_stopping_strategy = ess @property - def global_stopping_strategy(self) -> Optional[BaseGlobalStoppingStrategy]: + def global_stopping_strategy(self) -> BaseGlobalStoppingStrategy | None: """The global stopping strategy used on the experiment.""" return self._global_stopping_strategy @@ -1541,10 +1541,10 @@ def global_stopping_strategy(self, gss: BaseGlobalStoppingStrategy) -> None: @copy_doc(BestPointMixin.get_best_trial) def get_best_trial( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, - ) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: return self._get_best_trial( experiment=self.experiment, generation_strategy=self.generation_strategy, @@ -1555,8 +1555,8 @@ def get_best_trial( @copy_doc(BestPointMixin.get_pareto_optimal_parameters) def get_pareto_optimal_parameters( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: return self._get_pareto_optimal_parameters( @@ -1569,8 +1569,8 @@ def get_pareto_optimal_parameters( @copy_doc(BestPointMixin.get_hypervolume) def get_hypervolume( self, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> float: return BestPointMixin._get_hypervolume( @@ -1584,7 +1584,7 @@ def get_hypervolume( @copy_doc(BestPointMixin.get_trace) def get_trace( self, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, ) -> list[float]: return BestPointMixin._get_trace( experiment=self.experiment, @@ -1594,8 +1594,8 @@ def get_trace( @copy_doc(BestPointMixin.get_trace_by_progression) def get_trace_by_progression( self, - optimization_config: Optional[OptimizationConfig] = None, - bins: Optional[list[float]] = None, + optimization_config: OptimizationConfig | None = None, + bins: list[float] | None = None, final_progression_only: bool = False, ) -> tuple[list[float], list[float]]: return BestPointMixin._get_trace_by_progression( @@ -1609,8 +1609,8 @@ def _update_trial_with_raw_data( self, trial_index: int, raw_data: TEvaluationOutcome, - metadata: Optional[dict[str, Union[str, int]]] = None, - sample_size: Optional[int] = None, + metadata: dict[str, str | int] | None = None, + sample_size: int | None = None, complete_trial: bool = False, combine_with_last_data: bool = False, ) -> str: @@ -1711,7 +1711,7 @@ def _set_runner(self, 
experiment: Experiment) -> None: experiment.runner = None def _set_generation_strategy( - self, choose_generation_strategy_kwargs: Optional[dict[str, Any]] = None + self, choose_generation_strategy_kwargs: dict[str, Any] | None = None ) -> None: """Selects the generation strategy and applies specified dispatch kwargs, if any. @@ -1746,14 +1746,14 @@ def _set_generation_strategy( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: Optional[GenerationStrategyInterface] = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> bool: return super()._save_generation_strategy_to_db_if_possible( generation_strategy=generation_strategy or self.generation_strategy, ) def _gen_new_generator_run( - self, n: int = 1, fixed_features: Optional[FixedFeatures] = None + self, n: int = 1, fixed_features: FixedFeatures | None = None ) -> GeneratorRun: """Generate new generator run for this experiment. @@ -1824,7 +1824,7 @@ def _validate_all_required_metrics_present( def _get_pending_observation_features( cls, experiment: Experiment, - ) -> Optional[dict[str, list[ObservationFeatures]]]: + ) -> dict[str, list[ObservationFeatures]] | None: """Extract pending points for the given experiment. NOTE: With one-arm `Trial`-s, we use a more performant @@ -1870,14 +1870,14 @@ def load_experiment(experiment_name: str) -> None: ) @staticmethod - def load(filepath: Optional[str] = None) -> None: + def load(filepath: str | None = None) -> None: raise NotImplementedError( "Use `load_experiment_from_database` to load from SQL database or " "`load_from_json_file` to load optimization state from .json file." ) @staticmethod - def save(filepath: Optional[str] = None) -> None: + def save(filepath: str | None = None) -> None: raise NotImplementedError( "Use `save_to_json_file` to save optimization state to .json file." 
) diff --git a/ax/service/interactive_loop.py b/ax/service/interactive_loop.py index 12c9e493ec2..c986d10b2c8 100644 --- a/ax/service/interactive_loop.py +++ b/ax/service/interactive_loop.py @@ -6,10 +6,11 @@ # pyre-strict import time +from collections.abc import Callable from logging import Logger from queue import Queue from threading import Event, Lock, Thread -from typing import Any, Callable, Optional, Tuple +from typing import Any, Tuple from ax.core.types import TEvaluationOutcome, TParameterization @@ -30,9 +31,9 @@ def interactive_optimize( data_attacher_function: Callable[..., None], # pyre-ignore[2]: Missing parameter annotation elicitation_function: Callable[..., Any], - candidate_generator_kwargs: Optional[dict[str, Any]] = None, - data_attacher_kwargs: Optional[dict[str, Any]] = None, - elicitation_function_kwargs: Optional[dict[str, Any]] = None, + candidate_generator_kwargs: dict[str, Any] | None = None, + data_attacher_kwargs: dict[str, Any] | None = None, + elicitation_function_kwargs: dict[str, Any] | None = None, ) -> bool: """ Function to facilitate running Ax experiments with candidate pregeneration (the diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index fa6e2e15727..9f85bf18451 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -11,7 +11,6 @@ import inspect import logging from collections.abc import Iterable -from typing import Optional from ax.core.arm import Arm from ax.core.base_trial import BaseTrial @@ -59,10 +58,10 @@ def __init__( evaluation_function: TEvaluationFunction, total_trials: int = 20, arms_per_trial: int = 1, - random_seed: Optional[int] = None, + random_seed: int | None = None, wait_time: int = 0, run_async: bool = False, # TODO[Lena], - generation_strategy: Optional[GenerationStrategy] = None, + generation_strategy: GenerationStrategy | None = None, ) -> None: assert not run_async, "OptimizationLoop does not yet support async." 
self.wait_time = wait_time @@ -91,17 +90,17 @@ def __init__( def with_evaluation_function( parameters: list[TParameterRepresentation], evaluation_function: TEvaluationFunction, - experiment_name: Optional[str] = None, - objective_name: Optional[str] = None, + experiment_name: str | None = None, + objective_name: str | None = None, minimize: bool = False, - parameter_constraints: Optional[list[str]] = None, - outcome_constraints: Optional[list[str]] = None, + parameter_constraints: list[str] | None = None, + outcome_constraints: list[str] | None = None, total_trials: int = 20, arms_per_trial: int = 1, wait_time: int = 0, - random_seed: Optional[int] = None, - generation_strategy: Optional[GenerationStrategy] = None, - ) -> "OptimizationLoop": + random_seed: int | None = None, + generation_strategy: GenerationStrategy | None = None, + ) -> OptimizationLoop: """Constructs a synchronous `OptimizationLoop` using an evaluation function.""" if objective_name is None: @@ -129,23 +128,23 @@ def with_runners_and_metrics( parameters: list[TParameterRepresentation], path_to_runner: str, paths_to_metrics: list[str], - experiment_name: Optional[str] = None, - objective_name: Optional[str] = None, + experiment_name: str | None = None, + objective_name: str | None = None, minimize: bool = False, - parameter_constraints: Optional[list[str]] = None, - outcome_constraints: Optional[list[str]] = None, + parameter_constraints: list[str] | None = None, + outcome_constraints: list[str] | None = None, total_trials: int = 20, arms_per_trial: int = 1, wait_time: int = 0, - random_seed: Optional[int] = None, - ) -> "OptimizationLoop": + random_seed: int | None = None, + ) -> OptimizationLoop: """Constructs an asynchronous `OptimizationLoop` using Ax runners and metrics.""" # NOTE: Could use `Scheduler` to implement this if needed. 
raise NotImplementedError def _call_evaluation_function( - self, parameterization: TParameterization, weight: Optional[float] = None + self, parameterization: TParameterization, weight: float | None = None ) -> TEvaluationOutcome: signature = inspect.signature(self.evaluation_function) num_evaluation_function_params = len(signature.parameters.items()) @@ -186,7 +185,7 @@ def _get_new_trial(self) -> BaseTrial: def _get_weights_by_arm( self, trial: BaseTrial - ) -> Iterable[tuple[Arm, Optional[float]]]: + ) -> Iterable[tuple[Arm, float | None]]: if isinstance(trial, Trial): if trial.arm is not None: return [(not_none(trial.arm), None)] @@ -247,7 +246,7 @@ def full_run(self) -> OptimizationLoop: return self return self - def get_best_point(self) -> tuple[TParameterization, Optional[TModelPredictArm]]: + def get_best_point(self) -> tuple[TParameterization, TModelPredictArm | None]: """Obtains the best point encountered in the course of this optimization.""" # Find latest trial which has a generator_run attached and get its predictions @@ -270,7 +269,7 @@ def get_best_point(self) -> tuple[TParameterization, Optional[TModelPredictArm]] ), ) - def get_current_model(self) -> Optional[ModelBridge]: + def get_current_model(self) -> ModelBridge | None: """Obtain the most recently used model in optimization.""" return self.generation_strategy.model @@ -278,18 +277,16 @@ def get_current_model(self) -> Optional[ModelBridge]: def optimize( parameters: list[TParameterRepresentation], evaluation_function: TEvaluationFunction, - experiment_name: Optional[str] = None, - objective_name: Optional[str] = None, + experiment_name: str | None = None, + objective_name: str | None = None, minimize: bool = False, - parameter_constraints: Optional[list[str]] = None, - outcome_constraints: Optional[list[str]] = None, + parameter_constraints: list[str] | None = None, + outcome_constraints: list[str] | None = None, total_trials: int = 20, arms_per_trial: int = 1, - random_seed: Optional[int] = None, - generation_strategy: Optional[GenerationStrategy] = None, -) -> tuple[ - TParameterization, Optional[TModelPredictArm], Experiment, Optional[ModelBridge] -]: + random_seed: int | None = None, + generation_strategy: GenerationStrategy | None = None, +) -> tuple[TParameterization, TModelPredictArm | None, Experiment, ModelBridge | None]: """Construct and run a full optimization loop.""" loop = OptimizationLoop.with_evaluation_function( parameters=parameters, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 2d4bdf5051a..c61548dbcc1 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -9,12 +9,12 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Generator, Iterable +from collections.abc import Callable, Generator, Iterable from copy import deepcopy from datetime import datetime from logging import LoggerAdapter from time import sleep -from typing import Any, Callable, cast, NamedTuple, Optional +from typing import Any, cast, NamedTuple import ax.service.utils.early_stopping as early_stopping_utils from ax.analysis.analysis import Analysis, AnalysisCard @@ -158,11 +158,11 @@ class Scheduler(WithDBSettingsBase, BestPointMixin): _num_trials_bad_due_to_err: int = 0 # Timestamp of last optimization start time (milliseconds since Unix epoch); # recorded in each `run_n_trials`. - _latest_optimization_start_timestamp: Optional[int] = None + _latest_optimization_start_timestamp: int | None = None # Timeout setting for current optimization. 
- _timeout_hours: Optional[float] = None + _timeout_hours: float | None = None # Timestamp of when the last deployed trial started running. - _latest_trial_start_timestamp: Optional[float] = None + _latest_trial_start_timestamp: float | None = None # Will be set to `True` if generation strategy signals that the optimization # is complete, in which case the optimization should gracefully exit early. _optimization_complete: bool = False @@ -180,7 +180,7 @@ def __init__( experiment: Experiment, generation_strategy: GenerationStrategyInterface, options: SchedulerOptions, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, _skip_experiment_save: bool = False, ) -> None: self.experiment = experiment @@ -242,8 +242,8 @@ def from_stored_experiment( cls, experiment_name: str, options: SchedulerOptions, - db_settings: Optional[DBSettings] = None, - generation_strategy: Optional[GenerationStrategy] = None, + db_settings: DBSettings | None = None, + generation_strategy: GenerationStrategy | None = None, reduced_state: bool = True, **kwargs: Any, ) -> Scheduler: @@ -321,7 +321,7 @@ def options(self, options: SchedulerOptions) -> None: self._validate_runner_and_implemented_metrics(experiment=self.experiment) @property - def trial_type(self) -> Optional[str]: + def trial_type(self) -> str | None: """Trial type for the experiment this scheduler is running. This returns None if the experiment is not a MultitypeExperiment @@ -518,10 +518,10 @@ def completion_criterion(self) -> tuple[bool, str]: @copy_doc(BestPointMixin.get_best_trial) def get_best_trial( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, - ) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: return self._get_best_trial( experiment=self.experiment, generation_strategy=self.standard_generation_strategy, @@ -533,8 +533,8 @@ def get_best_trial( @copy_doc(BestPointMixin.get_pareto_optimal_parameters) def get_pareto_optimal_parameters( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: return self._get_pareto_optimal_parameters( @@ -548,8 +548,8 @@ def get_pareto_optimal_parameters( @copy_doc(BestPointMixin.get_hypervolume) def get_hypervolume( self, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> float: return BestPointMixin._get_hypervolume( @@ -563,7 +563,7 @@ def get_hypervolume( @copy_doc(BestPointMixin.get_trace) def get_trace( self, - optimization_config: Optional[OptimizationConfig] = None, + optimization_config: OptimizationConfig | None = None, ) -> list[float]: return BestPointMixin._get_trace( experiment=self.experiment, @@ -573,8 +573,8 @@ def get_trace( @copy_doc(BestPointMixin.get_trace_by_progression) def get_trace_by_progression( self, - optimization_config: Optional[OptimizationConfig] = None, - bins: 
Optional[list[float]] = None, + optimization_config: OptimizationConfig | None = None, + bins: list[float] | None = None, final_progression_only: bool = False, ) -> tuple[list[float], list[float]]: return BestPointMixin._get_trace_by_progression( @@ -609,7 +609,7 @@ def summarize_final_result(self) -> OptimizationResult: def get_improvement_over_baseline( self, - baseline_arm_name: Optional[str] = None, + baseline_arm_name: str | None = None, ) -> float: """Returns the scalarized improvement over baseline, if applicable. @@ -727,7 +727,7 @@ def poll_trial_status( @retry_on_exception(retries=3, no_retry_on_exception_types=NO_RETRY_EXCEPTIONS) def stop_trial_runs( - self, trials: list[BaseTrial], reasons: Optional[list[Optional[str]]] = None + self, trials: list[BaseTrial], reasons: list[str | None] | None = None ) -> None: """Stops the jobs that execute given trials. @@ -754,7 +754,7 @@ def stop_trial_runs( def wait_for_completed_trials_and_report_results( self, - idle_callback: Optional[Callable[[Scheduler], None]] = None, + idle_callback: Callable[[Scheduler], None] | None = None, force_refit: bool = False, ) -> dict[str, Any]: """Continuously poll for successful trials, with limited exponential @@ -955,7 +955,7 @@ def run_trials_and_yield_results( max_trials: int, ignore_global_stopping_strategy: bool = False, timeout_hours: int | float | None = None, - idle_callback: Optional[Callable[[Scheduler], None]] = None, + idle_callback: Callable[[Scheduler], None] | None = None, ) -> Generator[dict[str, Any], None, None]: """Make continuous calls to `run` and `process_results` to run up to ``max_trials`` trials, until completion criterion is reached. This is the 'main' @@ -1072,9 +1072,9 @@ def run_trials_and_yield_results( def _check_exit_status_and_report_results( self, n_existing: int, - idle_callback: Optional[Callable[[Scheduler], None]], + idle_callback: Callable[[Scheduler], None] | None, force_refit: bool, - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: if not self.should_wait_for_running_trials: return None return self.wait_for_completed_trials_and_report_results( @@ -1085,9 +1085,9 @@ def run_n_trials( self, max_trials: int, ignore_global_stopping_strategy: bool = False, - timeout_hours: Optional[int] = None, + timeout_hours: int | None = None, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - idle_callback: Optional[Callable[[Scheduler], Any]] = None, + idle_callback: Callable[[Scheduler], Any] | None = None, ) -> OptimizationResult: """Run up to ``max_trials`` trials; will run all ``max_trials`` unless completion criterion is reached. For base ``Scheduler``, completion criterion @@ -1135,9 +1135,9 @@ def run_n_trials( def run_all_trials( self, - timeout_hours: Optional[int] = None, + timeout_hours: int | None = None, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 
- idle_callback: Optional[Callable[[Scheduler], Any]] = None, + idle_callback: Callable[[Scheduler], Any] | None = None, ) -> OptimizationResult: """Run all trials until ``completion_criterion`` is reached (by default, completion criterion is reaching the ``num_trials`` setting, passed to @@ -1529,7 +1529,7 @@ def _process_completed_trials(self, newly_completed: set[int]) -> None: trial_indices=newly_completed, ) - def estimate_early_stopping_savings(self, map_key: Optional[str] = None) -> float: + def estimate_early_stopping_savings(self, map_key: str | None = None) -> float: """Estimate early stopping savings using progressions of the MapMetric present on the EarlyStoppingConfig as a proxy for resource usage. @@ -1599,7 +1599,7 @@ def _complete_optimization( self, num_preexisting_trials: int, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - idle_callback: Optional[Callable[[Scheduler], Any]] = None, + idle_callback: Callable[[Scheduler], Any] | None = None, ) -> dict[str, Any]: """Conclude optimization with waiting for anymore running trials and return final results via `wait_for_completed_trials_and_report_results`. @@ -1709,7 +1709,7 @@ def _prepare_trials( return existing_candidate_trials, new_trials def _get_next_trials( - self, num_trials: int = 1, n: Optional[int] = None + self, num_trials: int = 1, n: int | None = None ) -> list[BaseTrial]: """Produce up to `num_trials` new generator runs from the underlying generation strategy and create new trials with them. Logs errors @@ -1834,7 +1834,7 @@ def generate_candidates( return new_trials def compute_analyses( - self, analyses: Optional[Iterable[Analysis]] = None + self, analyses: Iterable[Analysis] | None = None ) -> list[AnalysisCard]: analyses = analyses if analyses is not None else self._choose_analyses() @@ -1866,7 +1866,7 @@ def _choose_analyses(self) -> list[Analysis]: def _gen_new_trials_from_generation_strategy( self, num_trials: int, - n: Optional[int] = None, + n: int | None = None, ) -> list[list[GeneratorRun]]: """Generates a list ``GeneratorRun``s of length of ``num_trials`` using the ``_gen_multiple`` method of the scheduler's ``generation_strategy``, taking @@ -2169,8 +2169,8 @@ def _report_metric_fetch_e( def _mark_err_trial_status( self, trial: BaseTrial, - metric_name: Optional[str] = None, - metric_fetch_e: Optional[MetricFetchE] = None, + metric_name: str | None = None, + metric_fetch_e: MetricFetchE | None = None, ) -> TrialStatus: trial.mark_failed(unsafe=True) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index cbe1de1519f..52d3202840f 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -8,13 +8,13 @@ import os import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import datetime, timedelta from logging import WARNING from math import ceil from random import randint from tempfile import NamedTemporaryFile -from typing import Any, Callable, cast, Optional +from typing import Any, cast from unittest.mock import call, Mock, patch, PropertyMock import pandas as pd @@ -348,7 +348,7 @@ def setUp(self) -> None: def _get_generation_strategy_strategy_for_test( self, experiment: Experiment, - generation_strategy: Optional[GenerationStrategy] = None, + generation_strategy: GenerationStrategy | None = None, ) -> GenerationStrategyInterface: return not_none(generation_strategy) @@ -404,7 +404,7 @@ def db_settings(self) -> DBSettings: return 
DBSettings(encoder=encoder, decoder=decoder) @property - def db_settings_if_always_needed(self) -> Optional[DBSettings]: + def db_settings_if_always_needed(self) -> DBSettings | None: if self.ALWAYS_USE_DB: return self.db_settings return None @@ -639,7 +639,7 @@ def write_n_trials(scheduler: Scheduler) -> None: def base_run_n_trials( self, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - idle_callback: Optional[Callable[[Scheduler], Any]], + idle_callback: Callable[[Scheduler], Any] | None, ) -> None: gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, @@ -1416,7 +1416,7 @@ def should_stop_trials_early( trial_indices: set[int], experiment: Experiment, **kwargs: dict[str, Any], - ) -> dict[int, Optional[str]]: + ) -> dict[int, str | None]: # Make sure that we can lookup data for the trial, # even though we won't use it in this dummy strategy data = experiment.lookup_data(trial_indices=trial_indices) @@ -1616,10 +1616,8 @@ def test_optimization_complete(self) -> None: self.assertEqual(len(scheduler.experiment.trials), 0) @patch( - ( - f"{WithDBSettingsBase.__module__}.WithDBSettingsBase." - "_save_generation_strategy_to_db_if_possible" - ) + f"{WithDBSettingsBase.__module__}.WithDBSettingsBase." + "_save_generation_strategy_to_db_if_possible" ) @patch( f"{WithDBSettingsBase.__module__}._save_experiment", side_effect=StaleDataError diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index b0b054852d8..9439f23a05e 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -11,7 +11,7 @@ import time from itertools import product from math import ceil -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock, patch @@ -130,7 +130,7 @@ def get_branin_currin_optimization_with_N_sobol_trials( minimize: bool = False, include_objective_thresholds: bool = True, random_seed: int = RANDOM_SEED, - outcome_constraints: Optional[list[str]] = None, + outcome_constraints: list[str] | None = None, ) -> tuple[AxClient, BraninCurrin]: branin_currin = get_branin_currin(minimize=minimize) ax_client = AxClient() @@ -182,8 +182,8 @@ def get_branin_currin_optimization_with_N_sobol_trials( def get_branin_optimization( - generation_strategy: Optional[GenerationStrategy] = None, - torch_device: Optional[torch.device] = None, + generation_strategy: GenerationStrategy | None = None, + torch_device: torch.device | None = None, ) -> AxClient: ax_client = AxClient( generation_strategy=generation_strategy, torch_device=torch_device @@ -2284,7 +2284,7 @@ def helper_test_get_pareto_optimal_points( self, mock_observed_pareto: Mock, mock_predicted_pareto: Mock, - outcome_constraints: Optional[list[str]] = None, + outcome_constraints: list[str] | None = None, ) -> None: ax_client, branin_currin = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, outcome_constraints=outcome_constraints @@ -2355,7 +2355,7 @@ def test_get_pareto_optimal_points(self) -> None: ) def helper_test_get_pareto_optimal_points_from_sobol_step( - self, minimize: bool, outcome_constraints: Optional[list[str]] = None + self, minimize: bool, outcome_constraints: list[str] | None = None ) -> None: ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints @@ -2698,7 +2698,7 @@ def test_with_hss(self) -> None: ) def test_should_stop_trials_early(self) -> None: - expected: 
dict[int, Optional[str]] = { + expected: dict[int, str | None] = { 1: "Stopped due to testing.", 3: "Stopped due to testing.", } diff --git a/ax/service/tests/test_early_stopping.py b/ax/service/tests/test_early_stopping.py index d9bc32c1859..666dbe2aef0 100644 --- a/ax/service/tests/test_early_stopping.py +++ b/ax/service/tests/test_early_stopping.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.service.utils import early_stopping as early_stopping_utils from ax.utils.common.testutils import TestCase @@ -24,7 +23,7 @@ def setUp(self) -> None: self.branin_experiment = get_branin_experiment() def test_should_stop_trials_early(self) -> None: - expected: dict[int, Optional[str]] = { + expected: dict[int, str | None] = { 1: "Stopped due to testing.", 3: "Stopped due to testing.", } diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index 120714585e5..843411f05dc 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -10,7 +10,6 @@ import functools import time from logging import WARN -from typing import Optional import numpy as np from ax.core.types import TEvaluationOutcome @@ -29,7 +28,7 @@ class TestInteractiveLoop(TestCase): def test_interactive_loop(self) -> None: def _elicit( parameterization_with_trial_index: tuple[TParameterization, int] - ) -> Optional[tuple[int, TEvaluationOutcome]]: + ) -> tuple[int, TEvaluationOutcome] | None: parameterization, trial_index = parameterization_with_trial_index x = np.array([parameterization.get(f"x{i+1}") for i in range(6)]) @@ -43,7 +42,7 @@ def _elicit( def _aborted_elicit( parameterization_with_trial_index: tuple[TParameterization, int] - ) -> Optional[tuple[int, TEvaluationOutcome]]: + ) -> tuple[int, TEvaluationOutcome] | None: return None ax_client = AxClient() diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index 728fce17b8e..a6bfa320737 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Union from unittest.mock import Mock, patch import numpy as np @@ -22,7 +21,7 @@ def _branin_evaluation_function( parameterization, weight=None # pyre-fixme[2]: Parameter must be annotated. -) -> dict[str, tuple[Union[float, ndarray], float]]: +) -> dict[str, tuple[float | ndarray, float]]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") x1, x2 = parameterization["x1"], parameterization["x2"] @@ -34,7 +33,7 @@ def _branin_evaluation_function( def _branin_evaluation_function_v2( parameterization, weight=None # pyre-fixme[2]: Parameter must be annotated. -) -> tuple[Union[float, ndarray], float]: +) -> tuple[float | ndarray, float]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") x1, x2 = parameterization["x1"], parameterization["x2"] @@ -43,7 +42,7 @@ def _branin_evaluation_function_v2( def _branin_evaluation_function_with_unknown_sem( parameterization, weight=None # pyre-fixme[2]: Parameter must be annotated. 
-) -> tuple[Union[float, ndarray], None]: +) -> tuple[float | ndarray, None]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") x1, x2 = parameterization["x1"], parameterization["x2"] diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 096facad63a..7f19765cdd1 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -11,7 +11,6 @@ from functools import reduce from logging import Logger -from typing import Optional import pandas as pd import torch @@ -61,8 +60,8 @@ def get_best_raw_objective_point_with_trial_index( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, ) -> tuple[int, TParameterization, dict[str, tuple[float, float]]]: """Given an experiment, identifies the arm that had the best raw objective, based on the data fetched from the experiment. @@ -139,8 +138,8 @@ def get_best_raw_objective_point_with_trial_index( def get_best_raw_objective_point( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, ) -> tuple[TParameterization, dict[str, tuple[float, float]]]: _, parameterization, vals = get_best_raw_objective_point_with_trial_index( @@ -153,7 +152,7 @@ def get_best_raw_objective_point( def _gr_to_prediction_with_trial_index( idx: int, gr: GeneratorRun -) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: +) -> tuple[int, TParameterization, TModelPredictArm | None] | None: if gr.best_arm_predictions is None: return None @@ -177,9 +176,9 @@ def _raw_values_to_model_predict_arm( def get_best_parameters_from_model_predictions_with_trial_index( experiment: Experiment, models_enum: type[ModelRegistryBase], - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, +) -> tuple[int, TParameterization, TModelPredictArm | None] | None: """Given an experiment, returns the best predicted parameterization and corresponding prediction based on the most recent Trial with predictions. If no trials have predictions returns None. @@ -283,8 +282,8 @@ def get_best_parameters_from_model_predictions_with_trial_index( def get_best_parameters_from_model_predictions( experiment: Experiment, models_enum: type[ModelRegistryBase], - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[TParameterization, Optional[TModelPredictArm]]]: + trial_indices: Iterable[int] | None = None, +) -> tuple[TParameterization, TModelPredictArm | None] | None: """Given an experiment, returns the best predicted parameterization and corresponding prediction based on the most recent Trial with predictions. If no trials have predictions returns None. 
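The best-point helpers above are doubly optional after the rewrite: per the docstrings, the outer `| None` roughly means "no qualifying trial with predictions", while the inner `TModelPredictArm | None` marks a trial that lacks a model estimate. A hypothetical caller, not from Ax, with stand-in types for illustration (Python 3.10+):

    Prediction = tuple[float, float] | None  # stand-in for TModelPredictArm
    BestResult = tuple[int, dict[str, float], Prediction]

    def describe(result: BestResult | None) -> str:
        if result is None:  # outer None: no trial with predictions yet
            return "no best trial"
        index, params, prediction = result
        if prediction is None:  # inner None: trial found, no model estimate
            return f"trial {index}: {params}"
        return f"trial {index}: {params}, predicted {prediction}"

    assert describe(None) == "no best trial"
    assert describe((3, {"x1": 0.5}, None)) == "trial 3: {'x1': 0.5}"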
@@ -320,9 +319,9 @@ def get_best_parameters_from_model_predictions( def get_best_by_raw_objective_with_trial_index( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, +) -> tuple[int, TParameterization, TModelPredictArm | None] | None: """Given an experiment, identifies the arm that had the best raw objective, based on the data fetched from the experiment. @@ -363,9 +362,9 @@ def get_best_by_raw_objective_with_trial_index( def get_best_by_raw_objective( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[TParameterization, Optional[TModelPredictArm]]]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, +) -> tuple[TParameterization, TModelPredictArm | None] | None: """Given an experiment, identifies the arm that had the best raw objective, based on the data fetched from the experiment. @@ -398,9 +397,9 @@ def get_best_by_raw_objective( def get_best_parameters_with_trial_index( experiment: Experiment, models_enum: type[ModelRegistryBase], - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, +) -> tuple[int, TParameterization, TModelPredictArm | None] | None: """Given an experiment, identifies the best arm. First attempts according to do so with models used in optimization and @@ -454,9 +453,9 @@ def get_best_parameters_with_trial_index( def get_best_parameters( experiment: Experiment, models_enum: type[ModelRegistryBase], - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, -) -> Optional[tuple[TParameterization, Optional[TModelPredictArm]]]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, +) -> tuple[TParameterization, TModelPredictArm | None] | None: """Given an experiment, identifies the best arm. 
First attempts according to do so with models used in optimization and @@ -495,8 +494,8 @@ def get_best_parameters( def get_pareto_optimal_parameters( experiment: Experiment, generation_strategy: GenerationStrategy, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: """Identifies the best parameterizations tried in the experiment so far, @@ -741,9 +740,9 @@ def _is_all_noiseless(df: pd.DataFrame, metric_name: str) -> bool: def _derel_opt_config_wrapper( optimization_config: OptimizationConfig, - modelbridge: Optional[ModelBridge] = None, - experiment: Optional[Experiment] = None, - observations: Optional[list[Observation]] = None, + modelbridge: ModelBridge | None = None, + experiment: Experiment | None = None, + observations: list[Observation] | None = None, ) -> OptimizationConfig: """Derelativize optimization_config using raw status-quo values""" @@ -782,7 +781,7 @@ def _derel_opt_config_wrapper( def extract_Y_from_data( experiment: Experiment, metric_names: list[str], - data: Optional[Data] = None, + data: Data | None = None, ) -> tuple[Tensor, Tensor]: r"""Converts the experiment observation data into a tensor. @@ -855,7 +854,7 @@ def extract_Y_from_data( def _objective_threshold_from_nadir( experiment: Experiment, objective: Objective, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, ) -> ObjectiveThreshold: """ Find the worst value observed for each objective and create an ObjectiveThreshold diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index b85f8d8bfbb..a9640801455 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -10,7 +10,6 @@ from collections.abc import Iterable from functools import partial from logging import Logger -from typing import Optional import numpy as np import torch @@ -58,10 +57,10 @@ class BestPointMixin(metaclass=ABCMeta): @abstractmethod def get_best_trial( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, - ) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: """Identifies the best parameterization tried in the experiment so far. First attempts to do so with the model used in optimization and @@ -89,10 +88,10 @@ def get_best_trial( def get_best_parameters( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, - ) -> Optional[tuple[TParameterization, Optional[TModelPredictArm]]]: + ) -> tuple[TParameterization, TModelPredictArm | None] | None: """Identifies the best parameterization tried in the experiment so far. 
First attempts to do so with the model used in optimization and @@ -131,8 +130,8 @@ def get_best_parameters( @abstractmethod def get_pareto_optimal_parameters( self, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: """Identifies the best parameterizations tried in the experiment so far, @@ -171,8 +170,8 @@ def get_pareto_optimal_parameters( @abstractmethod def get_hypervolume( self, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> float: """Calculate hypervolume of a pareto frontier based on either the posterior @@ -193,7 +192,7 @@ def get_hypervolume( @abstractmethod def get_trace( - optimization_config: Optional[OptimizationConfig] = None, + optimization_config: OptimizationConfig | None = None, ) -> list[float]: """Get the optimization trace of the given experiment. @@ -214,8 +213,8 @@ def get_trace( @abstractmethod def get_trace_by_progression( - optimization_config: Optional[OptimizationConfig] = None, - bins: Optional[list[float]] = None, + optimization_config: OptimizationConfig | None = None, + bins: list[float] | None = None, final_progression_only: bool = False, ) -> tuple[list[float], list[float]]: """Get the optimization trace with respect to trial progressions instead of @@ -255,10 +254,10 @@ def get_trace_by_progression( def _get_best_trial( experiment: Experiment, generation_strategy: GenerationStrategy, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, - ) -> Optional[tuple[int, TParameterization, Optional[TModelPredictArm]]]: + ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: optimization_config = optimization_config or not_none( experiment.optimization_config ) @@ -301,9 +300,9 @@ def _get_best_trial( @staticmethod def _get_best_observed_value( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, - ) -> Optional[float]: + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, + ) -> float | None: """Identifies the best objective value observed in the experiment among the trials indicated by `trial_indices`. 
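Though this patch only touches annotations, PEP 604 unions are ordinary runtime objects on Python 3.10+ and also work in `isinstance` checks; a standalone sketch, unrelated to any Ax call:

    def as_float(value: int | float | str) -> float:
        if isinstance(value, int | float):
            return float(value)
        return float(value.strip())

    assert as_float(2) == 2.0
    assert as_float(" 3.5 ") == 3.5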
@@ -349,8 +348,8 @@ def _get_best_observed_value( def _get_pareto_optimal_parameters( experiment: Experiment, generation_strategy: GenerationStrategy, - optimization_config: Optional[OptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: OptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: optimization_config = optimization_config or not_none( @@ -372,8 +371,8 @@ def _get_pareto_optimal_parameters( def _get_hypervolume( experiment: Experiment, generation_strategy: GenerationStrategy, - optimization_config: Optional[MultiObjectiveOptimizationConfig] = None, - trial_indices: Optional[Iterable[int]] = None, + optimization_config: MultiObjectiveOptimizationConfig | None = None, + trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> float: data = experiment.lookup_data() @@ -430,7 +429,7 @@ def _get_hypervolume( @staticmethod def _get_trace( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, + optimization_config: OptimizationConfig | None = None, ) -> list[float]: """Compute the optimization trace at each iteration. @@ -570,8 +569,8 @@ def _get_trace( @staticmethod def _get_trace_by_progression( experiment: Experiment, - optimization_config: Optional[OptimizationConfig] = None, - bins: Optional[list[float]] = None, + optimization_config: OptimizationConfig | None = None, + bins: list[float] | None = None, final_progression_only: bool = False, ) -> tuple[list[float], list[float]]: optimization_config = optimization_config or not_none( diff --git a/ax/service/utils/early_stopping.py b/ax/service/utils/early_stopping.py index 5fdea7dfb6b..936508f933f 100644 --- a/ax/service/utils/early_stopping.py +++ b/ax/service/utils/early_stopping.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional from ax.core.experiment import Experiment from ax.early_stopping.strategies import BaseEarlyStoppingStrategy @@ -13,10 +12,10 @@ def should_stop_trials_early( - early_stopping_strategy: Optional[BaseEarlyStoppingStrategy], + early_stopping_strategy: BaseEarlyStoppingStrategy | None, trial_indices: set[int], experiment: Experiment, -) -> dict[int, Optional[str]]: +) -> dict[int, str | None]: """Evaluate whether to early-stop running trials. 
Args: @@ -39,7 +38,7 @@ def should_stop_trials_early( def get_early_stopping_metrics( - experiment: Experiment, early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] + experiment: Experiment, early_stopping_strategy: BaseEarlyStoppingStrategy | None ) -> list[str]: """A helper function that returns a list of metric names on which a given `early_stopping_strategy` is operating.""" diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index afcbcbd292a..aebf8aeaa17 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from logging import Logger -from typing import Any, Optional, Union +from typing import Any, Union from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose @@ -114,7 +114,7 @@ class ObjectiveProperties: """ minimize: bool - threshold: Optional[float] = None + threshold: float | None = None @dataclass(frozen=True) @@ -122,7 +122,7 @@ class FixedFeatures: """Class for representing fixed features via the Service API.""" parameters: TParameterization - trial_index: Optional[int] = None + trial_index: int | None = None class InstantiationBase: @@ -136,7 +136,7 @@ class InstantiationBase: def _get_deserialized_metric_kwargs( metric_class: type[Metric], name: str, - metric_definitions: Optional[dict[str, dict[str, Any]]], + metric_definitions: dict[str, dict[str, Any]] | None, ) -> dict[str, Any]: """Get metric kwargs from metric_definitions if available and deserialize if so. Deserialization is necessary because they were serialized on creation""" @@ -152,10 +152,10 @@ def _get_deserialized_metric_kwargs( def _make_metric( cls, name: str, - lower_is_better: Optional[bool] = None, + lower_is_better: bool | None = None, metric_class: type[Metric] = Metric, for_opt_config: bool = False, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> Metric: if " " in name: raise ValueError( @@ -184,7 +184,7 @@ def _get_parameter_type(python_type: TParameterType) -> ParameterType: def _to_parameter_type( cls, vals: list[TParamValue], - typ: Optional[str], + typ: str | None, param_name: str, field_name: str, ) -> ParameterType: @@ -209,7 +209,7 @@ def _make_range_param( cls, name: str, representation: TParameterRepresentation, - parameter_type: Optional[str], + parameter_type: str | None, ) -> RangeParameter: assert "bounds" in representation, "Bounds are required for range parameters." bounds = representation["bounds"] @@ -238,7 +238,7 @@ def _make_choice_param( cls, name: str, representation: TParameterRepresentation, - parameter_type: Optional[str], + parameter_type: str | None, ) -> ChoiceParameter: values = representation["values"] assert isinstance(values, list) and len(values) > 1, ( @@ -268,7 +268,7 @@ def _make_fixed_param( cls, name: str, representation: TParameterRepresentation, - parameter_type: Optional[str], + parameter_type: str | None, ) -> FixedParameter: assert "value" in representation, "Value is required for fixed parameters." 
value = representation["value"] @@ -472,7 +472,7 @@ def constraint_from_str( def outcome_constraint_from_str( cls, representation: str, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> OutcomeConstraint: """Parse string representation of an outcome constraint.""" tokens = representation.split() @@ -509,7 +509,7 @@ def outcome_constraint_from_str( def objective_threshold_constraint_from_str( cls, representation: str, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> ObjectiveThreshold: oc = cls.outcome_constraint_from_str( representation, metric_definitions=metric_definitions @@ -525,7 +525,7 @@ def objective_threshold_constraint_from_str( def make_objectives( cls, objectives: dict[str, str], - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> list[Objective]: try: output_objectives = [] @@ -557,7 +557,7 @@ def make_outcome_constraints( cls, outcome_constraints: list[str], status_quo_defined: bool, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> list[OutcomeConstraint]: typed_outcome_constraints = [ @@ -579,7 +579,7 @@ def make_objective_thresholds( cls, objective_thresholds: list[str], status_quo_defined: bool, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> list[ObjectiveThreshold]: typed_objective_thresholds = ( @@ -647,7 +647,7 @@ def make_optimization_config( objective_thresholds: list[str], outcome_constraints: list[str], status_quo_defined: bool, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, ) -> OptimizationConfig: return cls.optimization_config_from_objectives( @@ -667,11 +667,11 @@ def make_optimization_config( @classmethod def make_optimization_config_from_properties( cls, - objectives: Optional[dict[str, ObjectiveProperties]] = None, - outcome_constraints: Optional[list[str]] = None, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, + objectives: dict[str, ObjectiveProperties] | None = None, + outcome_constraints: list[str] | None = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, status_quo_defined: bool = False, - ) -> Optional[OptimizationConfig]: + ) -> OptimizationConfig | None: """Makes optimization config based on ObjectiveProperties objects Args: @@ -707,7 +707,7 @@ def make_optimization_config_from_properties( def make_search_space( cls, parameters: list[TParameterRepresentation], - parameter_constraints: Optional[list[str]], + parameter_constraints: list[str] | None, ) -> SearchSpace: parameter_constraints = ( parameter_constraints if parameter_constraints is not None else [] @@ -765,7 +765,7 @@ def make_search_space( ) @classmethod - def _get_default_objectives(cls) -> Optional[dict[str, str]]: + def _get_default_objectives(cls) -> dict[str, str] | None: """Get the default objective and its optimization direction. 
The return type is optional since some subclasses may not wish to @@ -777,22 +777,22 @@ def _get_default_objectives(cls) -> Optional[dict[str, str]]: def make_experiment( cls, parameters: list[TParameterRepresentation], - name: Optional[str] = None, - description: Optional[str] = None, - owners: Optional[list[str]] = None, - parameter_constraints: Optional[list[str]] = None, - outcome_constraints: Optional[list[str]] = None, - status_quo: Optional[TParameterization] = None, - experiment_type: Optional[str] = None, - tracking_metric_names: Optional[list[str]] = None, - metric_definitions: Optional[dict[str, dict[str, Any]]] = None, - objectives: Optional[dict[str, str]] = None, - objective_thresholds: Optional[list[str]] = None, + name: str | None = None, + description: str | None = None, + owners: list[str] | None = None, + parameter_constraints: list[str] | None = None, + outcome_constraints: list[str] | None = None, + status_quo: TParameterization | None = None, + experiment_type: str | None = None, + tracking_metric_names: list[str] | None = None, + metric_definitions: dict[str, dict[str, Any]] | None = None, + objectives: dict[str, str] | None = None, + objective_thresholds: list[str] | None = None, support_intermediate_data: bool = False, immutable_search_space_and_opt_config: bool = True, - auxiliary_experiments_by_purpose: Optional[ + auxiliary_experiments_by_purpose: None | ( dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] - ] = None, + ) = None, is_test: bool = False, ) -> Experiment: """Instantiation wrapper that allows for Ax `Experiment` creation diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 4ea27dabe2f..3294ab60623 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -9,10 +9,10 @@ import itertools import logging from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import timedelta from logging import Logger -from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +from typing import Any, cast, TYPE_CHECKING import gpytorch import numpy as np @@ -99,7 +99,7 @@ def _get_cross_validation_plots(model: ModelBridge) -> list[go.Figure]: def _get_objective_trace_plot( experiment: Experiment, data: Data, - true_objective_metric_name: Optional[str] = None, + true_objective_metric_name: str | None = None, ) -> Iterable[go.Figure]: if experiment.is_moo_problem: return [ @@ -146,9 +146,9 @@ def _get_objective_trace_plot( def _get_objective_v_param_plots( experiment: Experiment, model: ModelBridge, - importance: Optional[ - Union[dict[str, dict[str, np.ndarray]], dict[str, dict[str, float]]] - ] = None, + importance: None | ( + dict[str, dict[str, np.ndarray]] | dict[str, dict[str, float]] + ) = None, # Chosen to take ~1min on local benchmarks. max_num_slice_plots: int = 200, # Chosen to take ~2min on local benchmarks. 
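Where an `Optional[...]` annotation spanned multiple source lines, the codemod emits a parenthesized union such as `None | (dict[...] | dict[...])` above rather than re-flowing the annotation. Union syntax is associative, so the parenthesized and flat spellings denote the same type; a minimal standalone sketch (the aliases are hypothetical, not from this patch):

    from typing import Any, get_args

    # `None | (A | B)` and `A | B | None` flatten to the same union members.
    ParenForm = None | (dict[str, Any] | list[str])
    FlatForm = dict[str, Any] | list[str] | None
    assert set(get_args(ParenForm)) == set(get_args(FlatForm))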
@@ -311,11 +311,11 @@ def _get_shortest_unique_suffix_dict( def get_standard_plots( experiment: Experiment, - model: Optional[ModelBridge], - data: Optional[Data] = None, - true_objective_metric_name: Optional[str] = None, - early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] = None, - limit_points_per_plot: Optional[int] = None, + model: ModelBridge | None, + data: Data | None = None, + true_objective_metric_name: str | None = None, + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, + limit_points_per_plot: int | None = None, global_sensitivity_analysis: bool = True, ) -> list[go.Figure]: """Extract standard plots for single-objective optimization. @@ -522,7 +522,7 @@ def get_standard_plots( def _transform_progression_to_walltime( progressions: np.ndarray, exp_df: pd.DataFrame, trial_idx: int -) -> Optional[np.ndarray]: +) -> np.ndarray | None: try: trial_df = exp_df[exp_df["trial_index"] == trial_idx] time_run_started = trial_df["time_run_started"].iloc[0] @@ -542,10 +542,10 @@ def _get_curve_plot_dropdown( experiment: Experiment, map_metrics: Iterable[MapMetric], data: MapData, - early_stopping_strategy: Optional[BaseEarlyStoppingStrategy], + early_stopping_strategy: BaseEarlyStoppingStrategy | None, by_walltime: bool = False, - limit_points_per_plot: Optional[int] = None, -) -> Optional[go.Figure]: + limit_points_per_plot: int | None = None, +) -> go.Figure | None: """Plot curve metrics by either progression or walltime. Args: @@ -763,13 +763,13 @@ def _merge_results_if_no_duplicates( def exp_to_df( exp: Experiment, - metrics: Optional[list[Metric]] = None, - run_metadata_fields: Optional[list[str]] = None, - trial_properties_fields: Optional[list[str]] = None, - trial_attribute_fields: Optional[list[str]] = None, - additional_fields_callables: Optional[ - dict[str, Callable[[Experiment], dict[int, Union[str, float]]]] - ] = None, + metrics: list[Metric] | None = None, + run_metadata_fields: list[str] | None = None, + trial_properties_fields: list[str] | None = None, + trial_attribute_fields: list[str] | None = None, + additional_fields_callables: None | ( + dict[str, Callable[[Experiment], dict[int, str | float]]] + ) = None, always_include_field_columns: bool = False, **kwargs: Any, ) -> pd.DataFrame: @@ -1001,7 +1001,7 @@ def exp_to_df( def compute_maximum_map_values( - experiment: Experiment, map_key: Optional[str] = None + experiment: Experiment, map_key: str | None = None ) -> dict[int, float]: """A function that returns a map from trial_index to the maximum map value reached. 
If map_key is not specified, it uses the first map_key.""" @@ -1030,9 +1030,9 @@ def compute_maximum_map_values( def _pairwise_pareto_plotly_scatter( experiment: Experiment, - metric_names: Optional[tuple[str, str]] = None, - reference_point: Optional[tuple[float, float]] = None, - minimize: Optional[Union[bool, tuple[bool, bool]]] = None, + metric_names: tuple[str, str] | None = None, + reference_point: tuple[float, float] | None = None, + minimize: bool | tuple[bool, bool] | None = None, ) -> Iterable[go.Figure]: metric_name_pairs = _get_metric_name_pairs(experiment=experiment) return [ @@ -1073,9 +1073,9 @@ def _get_metric_name_pairs( def _pareto_frontier_scatter_2d_plotly( experiment: Experiment, - metric_names: Optional[tuple[str, str]] = None, - reference_point: Optional[tuple[float, float]] = None, - minimize: Optional[Union[bool, tuple[bool, bool]]] = None, + metric_names: tuple[str, str] | None = None, + reference_point: tuple[float, float] | None = None, + minimize: bool | tuple[bool, bool] | None = None, ) -> go.Figure: # Determine defaults for unspecified inputs using `optimization_config` @@ -1094,8 +1094,8 @@ def _pareto_frontier_scatter_2d_plotly( def pareto_frontier_scatter_2d_plotly( experiment: Experiment, metric_names: tuple[str, str], - reference_point: Optional[tuple[float, float]] = None, - minimize: Optional[Union[bool, tuple[bool, bool]]] = None, + reference_point: tuple[float, float] | None = None, + minimize: bool | tuple[bool, bool] | None = None, ) -> go.Figure: df = exp_to_df(experiment) @@ -1224,7 +1224,7 @@ def _construct_comparison_message( comparison_arm_name: str, comparison_value: float, digits: int = 2, -) -> Optional[str]: +) -> str | None: # TODO: allow for user configured digits value if baseline_value == 0: logger.info( @@ -1287,7 +1287,7 @@ def _build_result_tuple( def select_baseline_arm( - experiment: Experiment, arms_df: pd.DataFrame, baseline_arm_name: Optional[str] + experiment: Experiment, arms_df: pd.DataFrame, baseline_arm_name: str | None ) -> tuple[str, bool]: """ Choose a baseline arm that is found in arms_df @@ -1332,10 +1332,10 @@ def select_baseline_arm( def maybe_extract_baseline_comparison_values( experiment: Experiment, - optimization_config: Optional[OptimizationConfig], - comparison_arm_names: Optional[list[str]], - baseline_arm_name: Optional[str], -) -> Optional[list[tuple[str, bool, str, float, str, float]]]: + optimization_config: OptimizationConfig | None, + comparison_arm_names: list[str] | None, + baseline_arm_name: str | None, +) -> list[tuple[str, bool, str, float, str, float]] | None: """ Extracts the baseline values from the experiment, for use in comparing the baseline arm to the optimal results. @@ -1431,7 +1431,7 @@ def maybe_extract_baseline_comparison_values( def compare_to_baseline_impl( comparison_list: list[tuple[str, bool, str, float, str, float]] -) -> Optional[str]: +) -> str | None: """Implementation of compare_to_baseline, taking in a list of arm comparisons. Can be used directly with the output of @@ -1457,10 +1457,10 @@ def compare_to_baseline_impl( def compare_to_baseline( experiment: Experiment, - optimization_config: Optional[OptimizationConfig], - comparison_arm_names: Optional[list[str]], - baseline_arm_name: Optional[str] = None, -) -> Optional[str]: + optimization_config: OptimizationConfig | None, + comparison_arm_names: list[str] | None, + baseline_arm_name: str | None = None, +) -> str | None: """Calculate metric improvement of the experiment against baseline. 
Returns the message(s) added to markdown_messages.""" @@ -1480,9 +1480,9 @@ def warn_if_unpredictable_metrics( experiment: Experiment, generation_strategy: GenerationStrategy, model_fit_threshold: float, - metric_names: Optional[list[str]] = None, + metric_names: list[str] | None = None, model_fit_metric_name: str = "coefficient_of_determination", -) -> Optional[str]: +) -> str | None: """Warn if any optimization config metrics are considered unpredictable, i.e., their coefficient of determination is less than model_fit_threshold. Args: diff --git a/ax/service/utils/scheduler_options.py b/ax/service/utils/scheduler_options.py index 94eb989e706..b2c034bbb03 100644 --- a/ax/service/utils/scheduler_options.py +++ b/ax/service/utils/scheduler_options.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from enum import Enum from logging import INFO -from typing import Any, Optional +from typing import Any from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy @@ -121,25 +121,25 @@ class SchedulerOptions: max_pending_trials: int = 10 trial_type: TrialType = TrialType.TRIAL - batch_size: Optional[int] = None - total_trials: Optional[int] = None + batch_size: int | None = None + total_trials: int | None = None tolerated_trial_failure_rate: float = 0.5 min_failed_trials_for_failure_rate_check: int = 5 - log_filepath: Optional[str] = None + log_filepath: str | None = None logging_level: int = INFO - ttl_seconds_for_trials: Optional[int] = None - init_seconds_between_polls: Optional[int] = 1 + ttl_seconds_for_trials: int | None = None + init_seconds_between_polls: int | None = 1 min_seconds_before_poll: float = 1.0 seconds_between_polls_backoff_factor: float = 1.5 - timeout_hours: Optional[float] = None + timeout_hours: float | None = None run_trials_in_batches: bool = False debug_log_run_metadata: bool = False - early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] = None - global_stopping_strategy: Optional[BaseGlobalStoppingStrategy] = None + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None + global_stopping_strategy: BaseGlobalStoppingStrategy | None = None suppress_storage_errors_after_retries: bool = False wait_for_running_trials: bool = True fetch_kwargs: dict[str, Any] = field(default_factory=dict) validate_metrics: bool = True status_quo_weight: float = 0.0 enforce_immutable_search_space_and_opt_config: bool = True - mt_experiment_trial_type: Optional[str] = None + mt_experiment_trial_type: str | None = None diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index aa418469232..2163cd4b1a3 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -8,9 +8,10 @@ import re import time +from collections.abc import Iterable from logging import INFO, Logger -from typing import Any, Iterable, Optional +from typing import Any from ax.analysis.analysis import AnalysisCard @@ -92,7 +93,7 @@ class WithDBSettingsBase: if `db_settings` property is set to a non-None value on the instance. 
""" - _db_settings: Optional[DBSettings] = None + _db_settings: DBSettings | None = None # Mapping of object types to mapping of fields to override values # loaded objects will all be instantiated with fields set to @@ -102,7 +103,7 @@ class WithDBSettingsBase: def __init__( self, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, logging_level: int = INFO, suppress_all_errors: bool = False, ) -> None: @@ -122,7 +123,7 @@ def __init__( logger.setLevel(logging_level) @staticmethod - def _get_default_db_settings() -> Optional[DBSettings]: + def _get_default_db_settings() -> DBSettings | None: """Overridable method to get default db_settings if none are passed in __init__ """ @@ -142,7 +143,7 @@ def db_settings(self) -> DBSettings: def _get_experiment_and_generation_strategy_db_id( self, experiment_name: str - ) -> tuple[Optional[int], Optional[int]]: + ) -> tuple[int | None, int | None]: """Retrieve DB ids of experiment by the given name and the associated generation strategy. Each ID is None if corresponding object is not found. @@ -221,7 +222,7 @@ def _load_experiment_and_generation_strategy( experiment_name: str, reduced_state: bool = False, skip_runners_and_metrics: bool = False, - ) -> tuple[Optional[Experiment], Optional[GenerationStrategy]]: + ) -> tuple[Experiment | None, GenerationStrategy | None]: """Loads experiment and its corresponding generation strategy from database if DB settings are set on this `WithDBSettingsBase` instance. @@ -380,7 +381,7 @@ def _save_or_update_trials_in_db_if_possible( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: Optional[GenerationStrategyInterface] = None, + generation_strategy: GenerationStrategyInterface | None = None, ) -> bool: """Saves given generation strategy if DB settings are set on this `WithDBSettingsBase` instance and the generation strategy is an @@ -617,9 +618,9 @@ def _save_analysis_cards_to_db_if_possible( def try_load_generation_strategy( experiment_name: str, decoder: Decoder, - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, reduced_state: bool = False, -) -> Optional[GenerationStrategy]: +) -> GenerationStrategy | None: """Load generation strategy by experiment name, if it exists.""" try: start_time = time.time() diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 59df188c285..2103856ca3d 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -12,7 +12,7 @@ from inspect import isclass from io import StringIO from logging import Logger -from typing import Any, Optional, Union +from typing import Any import numpy as np import pandas as pd @@ -357,7 +357,7 @@ def trial_transition_criteria_from_json( transition_criteria_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> Optional[TransitionCriterion]: +) -> TransitionCriterion | None: """Load Ax transition criteria that depend on Trials from JSON. 
Since ``TrialBasedCriterion`` contain lists of ``TrialStatus``, @@ -799,7 +799,7 @@ def model_spec_from_json( def generation_strategy_from_json( generation_strategy_json: dict[str, Any], - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> GenerationStrategy: @@ -924,10 +924,10 @@ def surrogate_from_list_surrogate_json( def get_input_transform_json_components( - input_transforms_json: Optional[Union[list[dict[str, Any]], dict[str, Any]]], + input_transforms_json: list[dict[str, Any]] | dict[str, Any] | None, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> tuple[Optional[list[dict[str, Any]]], Optional[dict[str, Any]]]: +) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | None]: if input_transforms_json is None: return None, None if isinstance(input_transforms_json, dict): @@ -952,10 +952,10 @@ def get_input_transform_json_components( def get_outcome_transform_json_components( - outcome_transforms_json: Optional[list[dict[str, Any]]], + outcome_transforms_json: list[dict[str, Any]] | None, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> tuple[Optional[list[dict[str, Any]]], Optional[dict[str, Any]]]: +) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | None]: if outcome_transforms_json is None: return None, None diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index f4d50e2aad9..811df051022 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -13,7 +13,7 @@ from collections.abc import Iterable from datetime import datetime from pathlib import Path -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import torch from ax.core.arm import Arm @@ -55,29 +55,29 @@ def batch_trial_from_json( experiment: core.experiment.Experiment, index: int, - trial_type: Optional[str], + trial_type: str | None, status: TrialStatus, time_created: datetime, - time_completed: Optional[datetime], - time_staged: Optional[datetime], - time_run_started: Optional[datetime], - abandoned_reason: Optional[str], - run_metadata: Optional[dict[str, Any]], + time_completed: datetime | None, + time_staged: datetime | None, + time_run_started: datetime | None, + abandoned_reason: str | None, + run_metadata: dict[str, Any] | None, generator_run_structs: list[GeneratorRunStruct], - runner: Optional[Runner], + runner: Runner | None, abandoned_arms_metadata: dict[str, AbandonedArm], num_arms_created: int, - status_quo: Optional[Arm], + status_quo: Arm | None, status_quo_weight_override: float, - optimize_for_power: Optional[bool], + optimize_for_power: bool | None, # Allowing default values for backwards compatibility with # objects stored before these fields were added. 
- failed_reason: Optional[str] = None, - ttl_seconds: Optional[int] = None, - generation_step_index: Optional[int] = None, - properties: Optional[dict[str, Any]] = None, - stop_metadata: Optional[dict[str, Any]] = None, - lifecycle_stage: Optional[LifecycleStage] = None, + failed_reason: str | None = None, + ttl_seconds: int | None = None, + generation_step_index: int | None = None, + properties: dict[str, Any] | None = None, + stop_metadata: dict[str, Any] | None = None, + lifecycle_stage: LifecycleStage | None = None, **kwargs: Any, ) -> BatchTrial: """Load Ax BatchTrial from JSON. @@ -117,24 +117,24 @@ def batch_trial_from_json( def trial_from_json( experiment: core.experiment.Experiment, index: int, - trial_type: Optional[str], + trial_type: str | None, status: TrialStatus, time_created: datetime, - time_completed: Optional[datetime], - time_staged: Optional[datetime], - time_run_started: Optional[datetime], - abandoned_reason: Optional[str], - run_metadata: Optional[dict[str, Any]], + time_completed: datetime | None, + time_staged: datetime | None, + time_run_started: datetime | None, + abandoned_reason: str | None, + run_metadata: dict[str, Any] | None, generator_run: GeneratorRun, - runner: Optional[Runner], + runner: Runner | None, num_arms_created: int, # Allowing default values for backwards compatibility with # objects stored before these fields were added. - failed_reason: Optional[str] = None, - ttl_seconds: Optional[int] = None, - generation_step_index: Optional[int] = None, - properties: Optional[dict[str, Any]] = None, - stop_metadata: Optional[dict[str, Any]] = None, + failed_reason: str | None = None, + ttl_seconds: int | None = None, + generation_step_index: int | None = None, + properties: dict[str, Any] | None = None, + stop_metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Trial: """Load Ax trial from JSON. @@ -241,7 +241,7 @@ def tensor_from_json(json: dict[str, Any]) -> torch.Tensor: ) -def tensor_or_size_from_json(json: dict[str, Any]) -> Union[torch.Tensor, torch.Size]: +def tensor_or_size_from_json(json: dict[str, Any]) -> torch.Tensor | torch.Size: if json["__type"] == "Tensor": return tensor_from_json(json) elif json["__type"] == "torch_Size": @@ -322,7 +322,7 @@ def botorch_component_from_json(botorch_class: Any, json: dict[str, Any]) -> typ ) -def pathlib_from_json(pathsegments: Union[str, Iterable[str]]) -> Path: +def pathlib_from_json(pathsegments: str | Iterable[str]) -> Path: if isinstance(pathsegments, str): return Path(pathsegments) diff --git a/ax/storage/json_store/encoder.py b/ax/storage/json_store/encoder.py index 025d46a735c..7e47be23bfa 100644 --- a/ax/storage/json_store/encoder.py +++ b/ax/storage/json_store/encoder.py @@ -10,8 +10,9 @@ import datetime import enum from collections import OrderedDict +from collections.abc import Callable from inspect import isclass -from typing import Any, Callable +from typing import Any import numpy as np import pandas as pd diff --git a/ax/storage/json_store/load.py b/ax/storage/json_store/load.py index cca8ae0b696..c8ff1da6658 100644 --- a/ax/storage/json_store/load.py +++ b/ax/storage/json_store/load.py @@ -7,7 +7,8 @@ # pyre-strict import json -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from ax.core.experiment import Experiment from ax.storage.json_store.decoder import object_from_json @@ -31,7 +32,7 @@ def load_experiment( 1) Read file. 2) Convert dictionary to Ax experiment instance. 
""" - with open(filepath, "r") as file: + with open(filepath) as file: json_experiment = json.loads(file.read()) return object_from_json( json_experiment, decoder_registry, class_decoder_registry diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index a99866124bd..389b4f0bc45 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -7,7 +7,8 @@ # pyre-strict import pathlib -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from ax.benchmark.benchmark_method import BenchmarkMethod @@ -251,7 +252,6 @@ SobolQMCNormalSampler: botorch_component_to_dict, SumConstraint: sum_parameter_constraint_to_dict, Surrogate: surrogate_to_dict, - BenchmarkMetric: metric_to_dict, SurrogateRunner: runner_to_dict, SyntheticRunner: runner_to_dict, ThresholdEarlyStoppingStrategy: threshold_early_stopping_strategy_to_dict, diff --git a/ax/storage/json_store/save.py b/ax/storage/json_store/save.py index 6501c8d16c1..5bcb23974f8 100644 --- a/ax/storage/json_store/save.py +++ b/ax/storage/json_store/save.py @@ -7,7 +7,8 @@ # pyre-strict import json -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from ax.core.experiment import Experiment from ax.storage.json_store.encoder import object_to_json diff --git a/ax/storage/metric_registry.py b/ax/storage/metric_registry.py index 6dc8463f140..10df0d3d988 100644 --- a/ax/storage/metric_registry.py +++ b/ax/storage/metric_registry.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, Optional +from typing import Any from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -51,7 +52,7 @@ # pyre-fixme[3]: Return annotation cannot contain `Any`. def register_metrics( - metric_clss: dict[type[Metric], Optional[int]], + metric_clss: dict[type[Metric], int | None], # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. diff --git a/ax/storage/registry_bundle.py b/ax/storage/registry_bundle.py index 571c6d13dab..cfe8e3eeeb5 100644 --- a/ax/storage/registry_bundle.py +++ b/ax/storage/registry_bundle.py @@ -8,7 +8,9 @@ from __future__ import annotations from abc import ABC, abstractproperty -from typing import Any, Callable, ChainMap, Optional +from collections import ChainMap +from collections.abc import Callable +from typing import Any from ax.core.metric import Metric from ax.core.runner import Runner @@ -52,8 +54,8 @@ class to JSON. def __init__( self, - metric_clss: dict[type[Metric], Optional[int]], - runner_clss: dict[type[Runner], Optional[int]], + metric_clss: dict[type[Metric], int | None], + runner_clss: dict[type[Runner], int | None], # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. @@ -139,21 +141,17 @@ def from_registry_bundles( metric_clss={}, runner_clss={}, json_encoder_registry=dict( - # pyre-ignore[29] `typing._Alias` is not a function. ChainMap(*[bundle.encoder_registry for bundle in registry_bundles]) ), json_class_encoder_registry=dict( - # pyre-ignore[29] `typing._Alias` is not a function. 
ChainMap( *[bundle.class_encoder_registry for bundle in registry_bundles] ) ), json_decoder_registry=dict( - # pyre-ignore[29] `typing._Alias` is not a function. ChainMap(*[bundle.decoder_registry for bundle in registry_bundles]) ), json_class_decoder_registry=dict( - # pyre-ignore[29] `typing._Alias` is not a function. ChainMap( *[bundle.class_decoder_registry for bundle in registry_bundles] ) @@ -166,8 +164,8 @@ class RegistryBundle(RegistryBundleBase): def __init__( self, - metric_clss: dict[type[Metric], Optional[int]], - runner_clss: dict[type[Runner], Optional[int]], + metric_clss: dict[type[Metric], int | None], + runner_clss: dict[type[Runner], int | None], # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use # `typing.Type` to avoid runtime subscripting errors. diff --git a/ax/storage/runner_registry.py b/ax/storage/runner_registry.py index a0aeee83eb2..5f5e0ff256d 100644 --- a/ax/storage/runner_registry.py +++ b/ax/storage/runner_registry.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, Optional +from typing import Any from ax.core.runner import Runner from ax.runners.synthetic import SyntheticRunner @@ -45,7 +46,7 @@ def register_runner( type, Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, - val: Optional[int] = None, + val: int | None = None, # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to # avoid runtime subscripting errors. ) -> tuple[ @@ -67,7 +68,7 @@ def register_runner( # pyre-fixme[3]: Return annotation cannot contain `Any`. def register_runners( - runner_clss: dict[type[Runner], Optional[int]], + runner_clss: dict[type[Runner], int | None], runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index 39f28253e4e..fce52c7be75 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -8,10 +8,10 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar from sqlalchemy import create_engine from sqlalchemy.engine.base import Engine @@ -34,7 +34,7 @@ LONGTEXT_BYTES: int = 2**32 - 1 # global database variables -SESSION_FACTORY: Optional[Session] = None +SESSION_FACTORY: Session | None = None # set this to false to prevent SQLAlchemy for automatically expiring objects # on commit, which essentially makes them unusable outside of a session @@ -54,7 +54,6 @@ class SQABase: def create_mysql_engine_from_creator( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. creator: Callable, echo: bool = False, pool_recycle: int = 10, @@ -99,7 +98,7 @@ def create_mysql_engine_from_url( return create_engine(url, pool_recycle=pool_recycle, echo=echo, **kwargs) -def create_test_engine(path: Optional[str] = None, echo: bool = True) -> Engine: +def create_test_engine(path: str | None = None, echo: bool = True) -> Engine: """Creates a SQLAlchemy engine object for use in unit tests. 
Args: @@ -118,14 +117,13 @@ def create_test_engine(path: Optional[str] = None, echo: bool = True) -> Engine: # (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlite) db_path = "sqlite://" else: - db_path = "sqlite:///{path}".format(path=path) + db_path = f"sqlite:///{path}" return create_engine(db_path, echo=echo) def init_engine_and_session_factory( - url: Optional[str] = None, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - creator: Optional[Callable] = None, + url: str | None = None, + creator: Callable | None = None, echo: bool = False, force_init: bool = False, **kwargs: Any, @@ -169,7 +167,7 @@ def init_engine_and_session_factory( def init_test_engine_and_session_factory( - tier_or_path: Optional[str] = None, + tier_or_path: str | None = None, echo: bool = False, force_init: bool = False, **kwargs: Any, diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index eac2e0caf4d..c2842e0ce97 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -11,7 +11,7 @@ from enum import Enum from io import StringIO from logging import Logger -from typing import Any, cast, Optional, Union +from typing import Any, cast, Union import pandas as pd from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel @@ -90,8 +90,8 @@ def __init__(self, config: SQAConfig) -> None: self.config = config def get_enum_name( - self, value: Optional[int], enum: Optional[Union[Enum, type[Enum]]] - ) -> Optional[str]: + self, value: int | None, enum: Enum | type[Enum] | None + ) -> str | None: """Given an enum value (int) and an enum (of ints), return the corresponding enum name. If the value is not present in the enum, throw an error. @@ -106,7 +106,7 @@ def get_enum_name( def _auxiliary_experiments_by_purpose_from_experiment_sqa( self, experiment_sqa: SQAExperiment - ) -> Optional[dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]]: + ) -> dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None: auxiliary_experiments_by_purpose = None if experiment_sqa.auxiliary_experiments_by_purpose: from ax.storage.sqa_store.load import load_experiment @@ -138,7 +138,7 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa( def _init_experiment_from_sqa( self, experiment_sqa: SQAExperiment, - ax_object_field_overrides: Optional[dict[str, Any]] = None, + ax_object_field_overrides: dict[str, Any] | None = None, load_auxiliary_experiments: bool = True, ) -> Experiment: """First step of conversion within experiment_from_sqa.""" @@ -259,7 +259,7 @@ def experiment_from_sqa( self, experiment_sqa: SQAExperiment, reduced_state: bool = False, - ax_object_field_overrides: Optional[dict[str, Any]] = None, + ax_object_field_overrides: dict[str, Any] | None = None, load_auxiliary_experiments: bool = True, ) -> Experiment: """Convert SQLAlchemy Experiment to Ax Experiment. 
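These rewrites are behavior-preserving at runtime: on Python 3.10+, `X | None` produces a `types.UnionType` object that compares equal to the `typing` spelling it replaces. A standalone illustration (not part of the diff):

    from typing import Optional, Union

    # PEP 604 unions compare equal to their typing-module equivalents...
    assert int | None == Optional[int]
    assert int | str == Union[int, str]
    # ...and, unlike subscripted Optional/Union, work directly with isinstance.
    assert isinstance(None, int | None)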
@@ -455,7 +455,7 @@ def parameter_constraint_from_sqa(
     def parameter_distributions_from_sqa(
         self,
         parameter_constraint_sqa_list: list[SQAParameterConstraint],
-    ) -> tuple[list[ParameterDistribution], Optional[int]]:
+    ) -> tuple[list[ParameterDistribution], int | None]:
         """Convert SQLAlchemy ParameterConstraints to Ax ParameterDistributions."""
         parameter_distributions: list[ParameterDistribution] = []
         num_samples = None
@@ -506,7 +506,7 @@ def search_space_from_sqa(
         self,
         parameters_sqa: list[SQAParameter],
         parameter_constraints_sqa: list[SQAParameterConstraint],
-    ) -> Optional[SearchSpace]:
+    ) -> SearchSpace | None:
         """Convert a list of SQLAlchemy Parameters and ParameterConstraints
         to an Ax SearchSpace.
         """
@@ -555,7 +555,7 @@ def search_space_from_sqa(
     def metric_from_sqa(
         self, metric_sqa: SQAMetric
-    ) -> Union[Metric, Objective, OutcomeConstraint, RiskMeasure]:
+    ) -> Metric | Objective | OutcomeConstraint | RiskMeasure:
         """Convert SQLAlchemy Metric to Ax Metric, Objective, or OutcomeConstraint."""
         metric = self._metric_from_sqa_util(metric_sqa)
@@ -596,7 +596,7 @@ def metric_from_sqa(
     def opt_config_and_tracking_metrics_from_sqa(
         self, metrics_sqa: list[SQAMetric]
-    ) -> tuple[Optional[OptimizationConfig], list[Metric]]:
+    ) -> tuple[OptimizationConfig | None, list[Metric]]:
         """Convert a list of SQLAlchemy Metrics to a tuple of Ax
         OptimizationConfig and tracking metrics.
         """
@@ -793,7 +793,7 @@ def generator_run_from_sqa(
     def generation_strategy_from_sqa(
         self,
         gs_sqa: SQAGenerationStrategy,
-        experiment: Optional[Experiment] = None,
+        experiment: Experiment | None = None,
         reduced_state: bool = False,
     ) -> GenerationStrategy:
         """Convert SQLAlchemy generation strategy to Ax `GenerationStrategy`."""
@@ -871,7 +871,7 @@ def generation_strategy_from_sqa(
         return gs
 
     def runner_from_sqa(
-        self, runner_sqa: SQARunner, runner_kwargs: Optional[dict[str, Any]] = None
+        self, runner_sqa: SQARunner, runner_kwargs: dict[str, Any] | None = None
     ) -> Runner:
         """Convert SQLAlchemy Runner to Ax Runner."""
         if runner_sqa.runner_type not in self.config.reverse_runner_registry:
@@ -897,7 +897,7 @@ def trial_from_sqa(
         trial_sqa: SQATrial,
         experiment: Experiment,
         reduced_state: bool = False,
-        ax_object_field_overrides: Optional[dict[str, Any]] = None,
+        ax_object_field_overrides: dict[str, Any] | None = None,
     ) -> BaseTrial:
         """Convert SQLAlchemy Trial to Ax Trial.
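A multi-member return annotation like `Metric | Objective | OutcomeConstraint | RiskMeasure` above is read by type checkers exactly like the `Union[...]` it replaces, and callers narrow it the same way. A sketch with stand-in classes (hypothetical, not Ax's real ones):

    # Stand-in classes for illustration only.
    class Metric: ...
    class Objective: ...

    def from_sqa(intent: str) -> Metric | Objective:
        # Flat PEP 604 union in the annotation; runtime behavior is unchanged.
        return Objective() if intent == "objective" else Metric()

    result = from_sqa("objective")
    if isinstance(result, Objective):
        pass  # type checkers narrow `result` to Objective here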
diff --git a/ax/storage/sqa_store/delete.py b/ax/storage/sqa_store/delete.py index 237a31d4aa7..9ac6ea233b6 100644 --- a/ax/storage/sqa_store/delete.py +++ b/ax/storage/sqa_store/delete.py @@ -6,7 +6,6 @@ # pyre-strict from logging import Logger -from typing import Optional from ax.core.experiment import Experiment from ax.modelbridge.generation_strategy import GenerationStrategy @@ -35,9 +34,7 @@ def delete_experiment(exp_name: str) -> None: ) -def delete_generation_strategy( - exp_name: str, config: Optional[SQAConfig] = None -) -> None: +def delete_generation_strategy(exp_name: str, config: SQAConfig | None = None) -> None: """Delete the generation strategy associated with an experiment Args: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index fe4a920ecc6..6b6d341b410 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -10,7 +10,7 @@ from enum import Enum from logging import Logger -from typing import Any, cast, Optional, Union +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -88,7 +88,7 @@ def __init__(self, config: SQAConfig) -> None: def validate_experiment_metadata( cls, experiment: Experiment, - existing_sqa_experiment_id: Optional[int], + existing_sqa_experiment_id: int | None, ) -> None: """Validates required experiment metadata.""" if experiment.db_id is not None: @@ -115,8 +115,8 @@ def validate_experiment_metadata( ) def get_enum_value( - self, value: Optional[str], enum: Optional[Union[Enum, type[Enum]]] - ) -> Optional[int]: + self, value: str | None, enum: Enum | type[Enum] | None + ) -> int | None: """Given an enum name (string) and an enum (of ints), return the corresponding enum value. If the name is not present in the enum, throw an error. @@ -311,7 +311,7 @@ def parameter_constraint_to_sqa( ) def search_space_to_sqa( - self, search_space: Optional[SearchSpace] + self, search_space: SearchSpace | None ) -> tuple[list[SQAParameter], list[SQAParameterConstraint]]: """Convert Ax SearchSpace to a list of SQLAlchemy Parameters and ParameterConstraints. @@ -690,7 +690,7 @@ def risk_measure_to_sqa(self, risk_measure: RiskMeasure) -> SQAMetric: ) def optimization_config_to_sqa( - self, optimization_config: Optional[OptimizationConfig] + self, optimization_config: OptimizationConfig | None ) -> list[SQAMetric]: """Convert Ax OptimizationConfig to a list of SQLAlchemy Metrics.""" if optimization_config is None: @@ -717,7 +717,7 @@ def optimization_config_to_sqa( metrics_sqa.append(risk_measure_sqa) return metrics_sqa - def arm_to_sqa(self, arm: Arm, weight: Optional[float] = 1.0) -> SQAArm: + def arm_to_sqa(self, arm: Arm, weight: float | None = 1.0) -> SQAArm: """Convert Ax Arm to SQLAlchemy.""" # pyre-fixme: Expected `Base` for 1st... got `typing.Type[Arm]`. arm_class: SQAArm = self.config.class_to_sqa_class[Arm] @@ -743,7 +743,7 @@ def abandoned_arm_to_sqa(self, abandoned_arm: AbandonedArm) -> SQAAbandonedArm: def generator_run_to_sqa( self, generator_run: GeneratorRun, - weight: Optional[float] = None, + weight: float | None = None, reduced_state: bool = False, ) -> SQAGeneratorRun: """Convert Ax GeneratorRun to SQLAlchemy. 
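The codemod rewrites only the annotation, never the default, which also covers the less common Optional-with-a-non-None-default pattern in `arm_to_sqa` above (`weight: float | None = 1.0`). A minimal sketch of that pattern:

    # None stays an accepted sentinel while the default is a concrete float
    # (mirrors `arm_to_sqa`'s `weight` parameter; helper name is hypothetical).
    def effective_weight(weight: float | None = 1.0) -> float:
        return 1.0 if weight is None else weight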
@@ -852,7 +852,7 @@ def generator_run_to_sqa( def generation_strategy_to_sqa( self, generation_strategy: GenerationStrategy, - experiment_id: Optional[int], + experiment_id: int | None, generator_run_reduced_state: bool = False, ) -> SQAGenerationStrategy: """Convert an Ax `GenerationStrategy` to SQLAlchemy, preserving its state, @@ -907,9 +907,7 @@ def generation_strategy_to_sqa( ) return gs_sqa - def runner_to_sqa( - self, runner: Runner, trial_type: Optional[str] = None - ) -> SQARunner: + def runner_to_sqa(self, runner: Runner, trial_type: str | None = None) -> SQARunner: """Convert Ax Runner to SQLAlchemy.""" runner_class = type(runner) runner_type = self.config.runner_registry.get(runner_class) @@ -1043,7 +1041,7 @@ def experiment_data_to_sqa(self, experiment: Experiment) -> list[SQAData]: ] def data_to_sqa( - self, data: Data, trial_index: Optional[int], timestamp: int + self, data: Data, trial_index: int | None, timestamp: int ) -> SQAData: """Convert Ax data to SQLAlchemy.""" # pyre-fixme: Expected `Base` for 1st...ot `typing.Type[Data]`. diff --git a/ax/storage/sqa_store/json.py b/ax/storage/sqa_store/json.py index d31fb7e6959..9d09eaca96b 100644 --- a/ax/storage/sqa_store/json.py +++ b/ax/storage/sqa_store/json.py @@ -8,7 +8,7 @@ import json from json import JSONDecodeError -from typing import Any, Optional +from typing import Any from ax.storage.sqa_store.db import JSON_FIELD_LENGTH, LONGTEXT_BYTES, MEDIUMTEXT_BYTES from sqlalchemy.ext.mutable import MutableDict, MutableList @@ -40,7 +40,7 @@ def __init__( super().__init__(*args, **kwargs) # pyre-fixme[2]: Parameter annotation cannot be `Any`. - def process_bind_param(self, value: Any, dialect: Any) -> Optional[str]: + def process_bind_param(self, value: Any, dialect: Any) -> str | None: if value is not None: return json.dumps(value) else: diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index c5d8ae6f8e8..089f2cdbae2 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -7,7 +7,7 @@ # pyre-strict from math import ceil -from typing import Any, cast, Optional +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -45,9 +45,9 @@ def load_experiment( experiment_name: str, - config: Optional[SQAConfig] = None, + config: SQAConfig | None = None, reduced_state: bool = False, - load_trials_in_batches_of_size: Optional[int] = None, + load_trials_in_batches_of_size: int | None = None, skip_runners_and_metrics: bool = False, load_auxiliary_experiments: bool = True, ) -> Experiment: @@ -84,8 +84,8 @@ def _load_experiment( experiment_name: str, decoder: Decoder, reduced_state: bool = False, - load_trials_in_batches_of_size: Optional[int] = None, - ax_object_field_overrides: Optional[dict[str, Any]] = None, + load_trials_in_batches_of_size: int | None = None, + ax_object_field_overrides: dict[str, Any] | None = None, skip_runners_and_metrics: bool = False, load_auxiliary_experiments: bool = True, ) -> Experiment: @@ -180,8 +180,8 @@ def _get_experiment_sqa( exp_sqa_class: type[SQAExperiment], trial_sqa_class: type[SQATrial], # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 
- trials_query_options: Optional[list[Any]] = None, - load_trials_in_batches_of_size: Optional[int] = None, + trials_query_options: list[Any] | None = None, + load_trials_in_batches_of_size: int | None = None, skip_runners_and_metrics: bool = False, ) -> SQAExperiment: """Obtains SQLAlchemy experiment object from DB.""" @@ -216,9 +216,9 @@ def _get_experiment_sqa( def _get_trials_sqa( experiment_id: int, trial_sqa_class: type[SQATrial], - load_trials_in_batches_of_size: Optional[int] = None, + load_trials_in_batches_of_size: int | None = None, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - trials_query_options: Optional[list[Any]] = None, + trials_query_options: list[Any] | None = None, skip_runners_and_metrics: bool = False, ) -> list[SQATrial]: """Obtains SQLAlchemy trial objects for given experiment ID from DB, @@ -276,7 +276,7 @@ def _get_experiment_sqa_reduced_state( experiment_name: str, exp_sqa_class: type[SQAExperiment], trial_sqa_class: type[SQATrial], - load_trials_in_batches_of_size: Optional[int] = None, + load_trials_in_batches_of_size: int | None = None, skip_runners_and_metrics: bool = False, ) -> SQAExperiment: """Obtains most of the SQLAlchemy experiment object from DB, with some attributes @@ -301,7 +301,7 @@ def _get_experiment_sqa_immutable_opt_config_and_search_space( experiment_name: str, exp_sqa_class: type[SQAExperiment], trial_sqa_class: type[SQATrial], - load_trials_in_batches_of_size: Optional[int] = None, + load_trials_in_batches_of_size: int | None = None, skip_runners_and_metrics: bool = False, ) -> SQAExperiment: """For experiments where the search space and opt config are @@ -341,7 +341,7 @@ def _get_experiment_immutable_opt_config_and_search_space( ) -def _get_experiment_id(experiment_name: str, config: SQAConfig) -> Optional[int]: +def _get_experiment_id(experiment_name: str, config: SQAConfig) -> int | None: """Get DB ID of the experiment by the given name if its in DB, return None otherwise. 
""" @@ -363,8 +363,8 @@ def _get_experiment_id(experiment_name: str, config: SQAConfig) -> Optional[int] def load_generation_strategy_by_experiment_name( experiment_name: str, - config: Optional[SQAConfig] = None, - experiment: Optional[Experiment] = None, + config: SQAConfig | None = None, + experiment: Experiment | None = None, reduced_state: bool = False, skip_runners_and_metrics: bool = False, ) -> GenerationStrategy: @@ -384,8 +384,8 @@ def load_generation_strategy_by_experiment_name( def load_generation_strategy_by_id( gs_id: int, - config: Optional[SQAConfig] = None, - experiment: Optional[Experiment] = None, + config: SQAConfig | None = None, + experiment: Experiment | None = None, reduced_state: bool = False, ) -> GenerationStrategy: """Finds a generation strategy stored by a given ID and restores it.""" @@ -399,7 +399,7 @@ def load_generation_strategy_by_id( def _load_generation_strategy_by_experiment_name( experiment_name: str, decoder: Decoder, - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, reduced_state: bool = False, skip_runners_and_metrics: bool = False, ) -> GenerationStrategy: @@ -430,7 +430,7 @@ def _load_generation_strategy_by_experiment_name( def _load_generation_strategy_by_id( gs_id: int, decoder: Decoder, - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, reduced_state: bool = False, ) -> GenerationStrategy: """Finds a generation strategy stored by a given ID and restores it.""" @@ -470,7 +470,7 @@ def _load_generation_strategy_by_id( ) -def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> Optional[int]: +def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> int | None: """Get DB ID of the generation strategy, associated with the experiment with the given name if its in DB, return None otherwise. """ @@ -494,7 +494,7 @@ def get_generation_strategy_sqa( gs_id: int, decoder: Decoder, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 
- query_options: Optional[list[Any]] = None, + query_options: list[Any] | None = None, ) -> SQAGenerationStrategy: """Obtains the SQLAlchemy generation strategy object from DB.""" gs_sqa_class = cast( @@ -596,7 +596,7 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space( def load_analysis_cards_by_experiment_name( experiment_name: str, - config: Optional[SQAConfig] = None, + config: SQAConfig | None = None, ) -> list[AnalysisCard]: """Loads analysis cards for an experiment.""" config = SQAConfig() if config is None else config diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index 385d07fc830..d0f9c381670 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -7,11 +7,11 @@ # pyre-strict import os -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from logging import Logger -from typing import Any, Callable, cast, Optional, Union +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -45,7 +45,7 @@ logger: Logger = get_logger(__name__) -def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None) -> None: +def save_experiment(experiment: Experiment, config: SQAConfig | None = None) -> None: """Save experiment (using default SQAConfig).""" if not isinstance(experiment, Experiment): raise ValueError("Can only save instances of Experiment") @@ -62,8 +62,8 @@ def _save_experiment( encoder: Encoder, decoder: Decoder, return_sqa: bool = False, - validation_kwargs: Optional[dict[str, Any]] = None, -) -> Optional[SQABase]: + validation_kwargs: dict[str, Any] | None = None, +) -> SQABase | None: """Save experiment, using given Encoder instance. 1) Convert Ax object to SQLAlchemy object. @@ -101,7 +101,7 @@ def _save_experiment( def save_generation_strategy( - generation_strategy: GenerationStrategy, config: Optional[SQAConfig] = None + generation_strategy: GenerationStrategy, config: SQAConfig | None = None ) -> int: """Save generation strategy (using default SQAConfig if no config is specified). If the generation strategy has an experiment set, the experiment @@ -150,7 +150,7 @@ def _save_generation_strategy( def save_or_update_trial( - experiment: Experiment, trial: BaseTrial, config: Optional[SQAConfig] = None + experiment: Experiment, trial: BaseTrial, config: SQAConfig | None = None ) -> None: """Add new trial to the experiment, or update if already exists (using default SQAConfig).""" @@ -182,8 +182,8 @@ def _save_or_update_trial( def save_or_update_trials( experiment: Experiment, trials: list[BaseTrial], - config: Optional[SQAConfig] = None, - batch_size: Optional[int] = None, + config: SQAConfig | None = None, + batch_size: int | None = None, reduce_state_generator_runs: bool = False, ) -> None: """Add new trials to the experiment, or update if already exists @@ -211,7 +211,7 @@ def _save_or_update_trials( trials: list[BaseTrial], encoder: Encoder, decoder: Decoder, - batch_size: Optional[int] = None, + batch_size: int | None = None, reduce_state_generator_runs: bool = False, ) -> None: """Add new trials to the experiment, or update if they already exist. 
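The save/update entry points above share one defaulting idiom that the rewrite leaves intact: annotate the parameter as `SQAConfig | None = None` and construct the default lazily in the body. A sketch with a hypothetical `Config` stand-in:

    class Config:  # hypothetical stand-in for SQAConfig
        pass

    def save(name: str, config: Config | None = None) -> None:
        # Build the default per call instead of in the signature, so callers
        # never share a single default instance.
        config = Config() if config is None else config

Constructing the default in the body rather than in the signature matters because a default value in the signature would be evaluated only once, at function definition time.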
@@ -225,7 +225,7 @@ def _save_or_update_trials( experiment_id: int = experiment._db_id - def add_experiment_id(sqa: Union[SQATrial, SQAData]) -> None: + def add_experiment_id(sqa: SQATrial | SQAData) -> None: sqa.experiment_id = experiment_id if reduce_state_generator_runs: @@ -304,8 +304,8 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: def update_generation_strategy( generation_strategy: GenerationStrategy, generator_runs: list[GeneratorRun], - config: Optional[SQAConfig] = None, - batch_size: Optional[int] = None, + config: SQAConfig | None = None, + batch_size: int | None = None, reduce_state_generator_runs: bool = False, ) -> None: """Update generation strategy's current step and attach generator runs @@ -328,7 +328,7 @@ def _update_generation_strategy( generator_runs: list[GeneratorRun], encoder: Encoder, decoder: Decoder, - batch_size: Optional[int] = None, + batch_size: int | None = None, reduce_state_generator_runs: bool = False, ) -> None: """Update generation strategy's current step and attach generator runs.""" @@ -367,7 +367,7 @@ def add_generation_strategy_id(sqa: SQAGeneratorRun): sqa.generation_strategy_id = gs_id # pyre-fixme[3]: Return type must be annotated. - def generator_run_to_sqa_encoder(gr: GeneratorRun, weight: Optional[float] = None): + def generator_run_to_sqa_encoder(gr: GeneratorRun, weight: float | None = None): return encoder.generator_run_to_sqa( gr, weight=weight, @@ -423,7 +423,7 @@ def update_outcome_constraint_on_experiment( ) -> None: oc_sqa_class = encoder.config.class_to_sqa_class[Metric] - exp_id: Optional[int] = experiment.db_id + exp_id: int | None = experiment.db_id if exp_id is None: raise UserInputError("Experiment must be saved before being updated.") oc_id = outcome_constraint.db_id @@ -451,7 +451,7 @@ def add_experiment_id(sqa: SQAMetric) -> None: def update_properties_on_experiment( experiment_with_updated_properties: Experiment, - config: Optional[SQAConfig] = None, + config: SQAConfig | None = None, ) -> None: config = SQAConfig() if config is None else config exp_sqa_class = config.class_to_sqa_class[Experiment] @@ -470,7 +470,7 @@ def update_properties_on_experiment( def update_properties_on_trial( trial_with_updated_properties: BaseTrial, - config: Optional[SQAConfig] = None, + config: SQAConfig | None = None, ) -> None: config = SQAConfig() if config is None else config trial_sqa_class = config.class_to_sqa_class[Trial] @@ -490,7 +490,7 @@ def update_properties_on_trial( def save_analysis_cards( analysis_cards: list[AnalysisCard], experiment: Experiment, - config: Optional[SQAConfig] = None, + config: SQAConfig | None = None, ) -> None: # Start up SQA encoder. config = SQAConfig() if config is None else config @@ -535,14 +535,11 @@ def _save_analysis_cards( def _merge_into_session( obj: Base, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. encode_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. decode_func: Callable, - encode_args: Optional[dict[str, Any]] = None, - decode_args: Optional[dict[str, Any]] = None, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
- modify_sqa: Optional[Callable] = None, + encode_args: dict[str, Any] | None = None, + decode_args: dict[str, Any] | None = None, + modify_sqa: Callable | None = None, ) -> SQABase: """Given a user-facing object (that may or may not correspond to an existing DB object), perform the following steps to either create or @@ -577,15 +574,12 @@ def _merge_into_session( def _bulk_merge_into_session( objs: Sequence[Base], - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. encode_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. decode_func: Callable, - encode_args_list: Optional[Union[list[None], list[dict[str, Any]]]] = None, - decode_args_list: Optional[Union[list[None], list[dict[str, Any]]]] = None, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - modify_sqa: Optional[Callable] = None, - batch_size: Optional[int] = None, + encode_args_list: list[None] | list[dict[str, Any]] | None = None, + decode_args_list: list[None] | list[dict[str, Any]] | None = None, + modify_sqa: Callable | None = None, + batch_size: int | None = None, ) -> list[SQABase]: """Bulk version of _merge_into_session. diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 7ae9e93e2c8..24082ee378d 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -10,7 +10,7 @@ from datetime import datetime from decimal import Decimal -from typing import Any, Optional +from typing import Any from ax.analysis.analysis import AnalysisCardLevel @@ -62,36 +62,32 @@ class SQAParameter(Base): __tablename__: str = "parameter_v2" domain_type: Column[DomainType] = Column(IntEnum(DomainType), nullable=False) - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[Optional[int]] = Column( + generator_run_id: Column[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) parameter_type: Column[ParameterType] = Column( IntEnum(ParameterType), nullable=False ) - is_fidelity: Column[Optional[bool]] = Column(Boolean) - target_value: Column[Optional[TParamValue]] = Column(JSONEncodedObject) + is_fidelity: Column[bool | None] = Column(Boolean) + target_value: Column[TParamValue | None] = Column(JSONEncodedObject) # Attributes for Range Parameters - digits: Column[Optional[int]] = Column(Integer) - log_scale: Column[Optional[bool]] = Column(Boolean) - lower: Column[Optional[Decimal]] = Column(Float) - upper: Column[Optional[Decimal]] = Column(Float) + digits: Column[int | None] = Column(Integer) + log_scale: Column[bool | None] = Column(Boolean) + lower: Column[Decimal | None] = Column(Float) + upper: Column[Decimal | None] = Column(Float) # Attributes for Choice Parameters - choice_values: Column[Optional[list[TParamValue]]] = Column(JSONEncodedList) - is_ordered: Column[Optional[bool]] = Column(Boolean) - is_task: Column[Optional[bool]] = Column(Boolean) - dependents: Column[Optional[dict[TParamValue, list[str]]]] = Column( - JSONEncodedObject - ) + choice_values: Column[list[TParamValue] | None] = Column(JSONEncodedList) + is_ordered: Column[bool | None] = Column(Boolean) + is_task: Column[bool | None] = Column(Boolean) + dependents: Column[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) # 
Attributes for Fixed Parameters - fixed_value: Column[Optional[TParamValue]] = Column(JSONEncodedObject) + fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject) class SQAParameterConstraint(Base): @@ -99,11 +95,9 @@ class SQAParameterConstraint(Base): bound: Column[Decimal] = Column(Float, nullable=False) constraint_dict: Column[dict[str, float]] = Column(JSONEncodedDict, nullable=False) - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) id: Column[int] = Column(Integer, primary_key=True) - generator_run_id: Column[Optional[int]] = Column( + generator_run_id: Column[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) type: Column[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False) @@ -112,33 +106,29 @@ class SQAParameterConstraint(Base): class SQAMetric(Base): __tablename__: str = "metric_v2" - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) - generator_run_id: Column[Optional[int]] = Column( + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + generator_run_id: Column[int | None] = Column( Integer, ForeignKey("generator_run_v2.id") ) id: Column[int] = Column(Integer, primary_key=True) - lower_is_better: Column[Optional[bool]] = Column(Boolean) + lower_is_better: Column[bool | None] = Column(Boolean) intent: Column[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False) metric_type: Column[int] = Column(Integer, nullable=False) name: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False) - properties: Column[Optional[dict[str, Any]]] = Column( - JSONEncodedTextDict, default={} - ) + properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) # Attributes for Objectives - minimize: Column[Optional[bool]] = Column(Boolean) + minimize: Column[bool | None] = Column(Boolean) # Attributes for Outcome Constraints - op: Column[Optional[ComparisonOp]] = Column(IntEnum(ComparisonOp)) - bound: Column[Optional[Decimal]] = Column(Float) - relative: Column[Optional[bool]] = Column(Boolean) + op: Column[ComparisonOp | None] = Column(IntEnum(ComparisonOp)) + bound: Column[Decimal | None] = Column(Float) + relative: Column[bool | None] = Column(Boolean) # Multi-type Experiment attributes - trial_type: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - canonical_name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - scalarized_objective_id: Column[Optional[int]] = Column( + trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + canonical_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + scalarized_objective_id: Column[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) @@ -154,8 +144,8 @@ class SQAMetric(Base): ) # Attribute only defined for the children of Scalarized Objective - scalarized_objective_weight: Column[Optional[Decimal]] = Column(Float) - scalarized_outcome_constraint_id: Column[Optional[int]] = Column( + scalarized_objective_weight: Column[Decimal | None] = Column(Float) + scalarized_outcome_constraint_id: Column[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) scalarized_outcome_constraint_children_metrics: list[SQAMetric] = relationship( @@ -164,7 +154,7 @@ class SQAMetric(Base): lazy=True, foreign_keys=[scalarized_outcome_constraint_id], ) - scalarized_outcome_constraint_weight: 
Column[Optional[Decimal]] = Column(Float) + scalarized_outcome_constraint_weight: Column[Decimal | None] = Column(Float) class SQAArm(Base): @@ -174,7 +164,7 @@ class SQAArm(Base): Integer, ForeignKey("generator_run_v2.id"), nullable=False ) id: Column[int] = Column(Integer, primary_key=True) - name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) parameters: Column[TParameterization] = Column(JSONEncodedTextDict, nullable=False) weight: Column[Decimal] = Column(Float, nullable=False, default=1.0) @@ -182,7 +172,7 @@ class SQAArm(Base): class SQAAbandonedArm(Base): __tablename__: str = "abandoned_arm_v2" - abandoned_reason: Column[Optional[str]] = Column(String(LONG_STRING_FIELD_LENGTH)) + abandoned_reason: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) id: Column[int] = Column(Integer, primary_key=True) name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) time_abandoned: Column[datetime] = Column( @@ -194,39 +184,33 @@ class SQAAbandonedArm(Base): class SQAGeneratorRun(Base): __tablename__: str = "generator_run_v2" - best_arm_name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - best_arm_parameters: Column[Optional[TParameterization]] = Column( - JSONEncodedTextDict - ) - best_arm_predictions: Column[Optional[TModelPredictArm]] = Column(JSONEncodedList) - generator_run_type: Column[Optional[int]] = Column(Integer) + best_arm_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + best_arm_parameters: Column[TParameterization | None] = Column(JSONEncodedTextDict) + best_arm_predictions: Column[TModelPredictArm | None] = Column(JSONEncodedList) + generator_run_type: Column[int | None] = Column(Integer) id: Column[int] = Column(Integer, primary_key=True) - index: Column[Optional[int]] = Column(Integer) - model_predictions: Column[Optional[TModelPredict]] = Column(JSONEncodedList) + index: Column[int | None] = Column(Integer) + model_predictions: Column[TModelPredict | None] = Column(JSONEncodedList) time_created: Column[datetime] = Column( IntTimestamp, nullable=False, default=datetime.now ) - trial_id: Column[Optional[int]] = Column(Integer, ForeignKey("trial_v2.id")) - weight: Column[Optional[Decimal]] = Column(Float) - fit_time: Column[Optional[Decimal]] = Column(Float) - gen_time: Column[Optional[Decimal]] = Column(Float) - model_key: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - model_kwargs: Column[Optional[dict[str, Any]]] = Column(JSONEncodedTextDict) - bridge_kwargs: Column[Optional[dict[str, Any]]] = Column(JSONEncodedTextDict) - gen_metadata: Column[Optional[dict[str, Any]]] = Column(JSONEncodedTextDict) - model_state_after_gen: Column[Optional[dict[str, Any]]] = Column( - JSONEncodedTextDict - ) - generation_strategy_id: Column[Optional[int]] = Column( + trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) + weight: Column[Decimal | None] = Column(Float) + fit_time: Column[Decimal | None] = Column(Float) + gen_time: Column[Decimal | None] = Column(Float) + model_key: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + model_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) + bridge_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) + gen_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) + model_state_after_gen: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) + generation_strategy_id: Column[int | 
None] = Column( Integer, ForeignKey("generation_strategy.id") ) - generation_step_index: Column[Optional[int]] = Column(Integer) - candidate_metadata_by_arm_signature: Column[Optional[dict[str, Any]]] = Column( + generation_step_index: Column[int | None] = Column(Integer) + candidate_metadata_by_arm_signature: Column[dict[str, Any] | None] = Column( JSONEncodedTextDict ) - generation_node_name: Column[Optional[str]] = Column( - String(NAME_OR_TYPE_FIELD_LENGTH) - ) + generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) # relationships # Use selectin loading for collections to prevent idle timeout errors @@ -252,17 +236,15 @@ class SQARunner(Base): __tablename__: str = "runner" id: Column[int] = Column(Integer, primary_key=True) - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) - properties: Column[Optional[dict[str, Any]]] = Column( + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) + properties: Column[dict[str, Any] | None] = Column( JSONEncodedLongTextDict, default={} ) runner_type: Column[int] = Column(Integer, nullable=False) - trial_id: Column[Optional[int]] = Column(Integer, ForeignKey("trial_v2.id")) + trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id")) # Multi-type Experiment attributes - trial_type: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) class SQAData(Base): @@ -270,16 +252,14 @@ class SQAData(Base): id: Column[int] = Column(Integer, primary_key=True) data_json: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False) - description: Column[Optional[str]] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) + description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) time_created: Column[int] = Column(BigInteger, nullable=False) - trial_index: Column[Optional[int]] = Column(Integer) - generation_strategy_id: Column[Optional[int]] = Column( + trial_index: Column[int | None] = Column(Integer) + generation_strategy_id: Column[int | None] = Column( Integer, ForeignKey("generation_strategy.id") ) - structure_metadata_json: Column[Optional[str]] = Column( + structure_metadata_json: Column[str | None] = Column( Text(LONGTEXT_BYTES), nullable=True ) @@ -290,12 +270,10 @@ class SQAGenerationStrategy(Base): id: Column[int] = Column(Integer, primary_key=True) name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) steps: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) - curr_index: Column[Optional[int]] = Column(Integer, nullable=True) - experiment_id: Column[Optional[int]] = Column( - Integer, ForeignKey("experiment_v2.id") - ) + curr_index: Column[int | None] = Column(Integer, nullable=True) + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) nodes: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=True) - curr_node_name: Column[Optional[str]] = Column( + curr_node_name: Column[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) @@ -310,36 +288,34 @@ class SQAGenerationStrategy(Base): class SQATrial(Base): __tablename__: str = "trial_v2" - abandoned_reason: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - failed_reason: 
Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - deployed_name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + abandoned_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + failed_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + deployed_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) experiment_id: Column[int] = Column( Integer, ForeignKey("experiment_v2.id"), nullable=False ) id: Column[int] = Column(Integer, primary_key=True) index: Column[int] = Column(Integer, index=True, nullable=False) is_batch: Column[bool] = Column("is_batched", Boolean, nullable=False, default=True) - lifecycle_stage: Column[Optional[LifecycleStage]] = Column( + lifecycle_stage: Column[LifecycleStage | None] = Column( IntEnum(LifecycleStage), nullable=True ) num_arms_created: Column[int] = Column(Integer, nullable=False, default=0) - optimize_for_power: Column[Optional[bool]] = Column(Boolean) - ttl_seconds: Column[Optional[int]] = Column(Integer) - run_metadata: Column[Optional[dict[str, Any]]] = Column(JSONEncodedLongTextDict) - stop_metadata: Column[Optional[dict[str, Any]]] = Column(JSONEncodedTextDict) + optimize_for_power: Column[bool | None] = Column(Boolean) + ttl_seconds: Column[int | None] = Column(Integer) + run_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedLongTextDict) + stop_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict) status: Column[TrialStatus] = Column( IntEnum(TrialStatus), nullable=False, default=TrialStatus.CANDIDATE ) - status_quo_name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - time_completed: Column[Optional[datetime]] = Column(IntTimestamp) + status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + time_completed: Column[datetime | None] = Column(IntTimestamp) time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - time_staged: Column[Optional[datetime]] = Column(IntTimestamp) - time_run_started: Column[Optional[datetime]] = Column(IntTimestamp) - trial_type: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - generation_step_index: Column[Optional[int]] = Column(Integer) - properties: Column[Optional[dict[str, Any]]] = Column( - JSONEncodedTextDict, default={} - ) + time_staged: Column[datetime | None] = Column(IntTimestamp) + time_run_started: Column[datetime | None] = Column(IntTimestamp) + trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + generation_step_index: Column[int | None] = Column(Integer) + properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) # relationships # Trials and experiments are mutable, so the children relationships need @@ -382,27 +358,23 @@ class SQAAnalysisCard(Base): class SQAExperiment(Base): __tablename__: str = "experiment_v2" - description: Column[Optional[str]] = Column(String(LONG_STRING_FIELD_LENGTH)) - experiment_type: Column[Optional[int]] = Column(Integer) + description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH)) + experiment_type: Column[int | None] = Column(Integer) id: Column[int] = Column(Integer, primary_key=True) is_test: Column[bool] = Column(Boolean, nullable=False, default=False) name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - properties: Column[Optional[dict[str, Any]]] = Column( - JSONEncodedTextDict, default={} - ) - status_quo_name: Column[Optional[str]] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) - 
status_quo_parameters: Column[Optional[TParameterization]] = Column( + properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={}) + status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + status_quo_parameters: Column[TParameterization | None] = Column( JSONEncodedTextDict ) time_created: Column[datetime] = Column(IntTimestamp, nullable=False) - default_trial_type: Column[Optional[str]] = Column( - String(NAME_OR_TYPE_FIELD_LENGTH) - ) + default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True) # pyre-fixme[8]: Incompatible attribute type [8]: Attribute # `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has # type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]` - auxiliary_experiments_by_purpose: Optional[dict[str, list[str]]] = Column( + auxiliary_experiments_by_purpose: dict[str, list[str]] | None = Column( JSONEncodedTextDict, nullable=True, default={} ) @@ -430,7 +402,7 @@ class SQAExperiment(Base): trials: list[SQATrial] = relationship( "SQATrial", cascade="all, delete-orphan", lazy="selectin" ) - generation_strategy: Optional[SQAGenerationStrategy] = relationship( + generation_strategy: SQAGenerationStrategy | None = relationship( "SQAGenerationStrategy", backref=backref("experiment", lazy=True), uselist=False, diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index f66f4db6dbd..3e615151fbf 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -6,9 +6,10 @@ # pyre-strict +from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any from ax.analysis.analysis import AnalysisCard @@ -85,8 +86,8 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: class_to_sqa_class: dict[type[Base], type[SQABase]] = field( default_factory=_default_class_to_sqa_class ) - experiment_type_enum: Optional[Union[Enum, type[Enum]]] = None - generator_run_type_enum: Optional[Union[Enum, type[Enum]]] = GeneratorRunType + experiment_type_enum: Enum | type[Enum] | None = None + generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose # pyre-fixme[4]: Attribute annotation cannot contain `Any`. diff --git a/ax/storage/sqa_store/structs.py b/ax/storage/sqa_store/structs.py index b27b1c47913..c6c591a79f5 100644 --- a/ax/storage/sqa_store/structs.py +++ b/ax/storage/sqa_store/structs.py @@ -6,7 +6,8 @@ # pyre-strict -from typing import Callable, NamedTuple, Optional +from collections.abc import Callable +from typing import NamedTuple from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.encoder import Encoder @@ -19,8 +20,7 @@ class DBSettings(NamedTuple): Either creator or url must be specified as a way to connect to the SQL db. """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
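Note: the Column[int | None] annotations in sqa_classes.py above only require Python 3.10+ where the union is evaluated at runtime; in pure annotation position the same spelling also parses on older interpreters once PEP 563 stringifies it. A sketch with a stand-in generic, not SQLAlchemy's real Column:

    from __future__ import annotations  # PEP 563: store annotations as strings

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Col(Generic[T]):  # illustrative stand-in for an ORM column type
        pass

    class Record:
        # Stored as the string "Col[int | None]" and never evaluated at class
        # definition time, so this parses even where runtime `|` unions do not.
        experiment_id: Col[int | None]
        name: Col[str]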
- creator: Optional[Callable] = None + creator: Callable | None = None decoder: Decoder = Decoder(config=SQAConfig()) encoder: Encoder = Encoder(config=SQAConfig()) - url: Optional[str] = None + url: str | None = None diff --git a/ax/storage/sqa_store/timestamp.py b/ax/storage/sqa_store/timestamp.py index ff3d9ea58a4..d3036002dad 100644 --- a/ax/storage/sqa_store/timestamp.py +++ b/ax/storage/sqa_store/timestamp.py @@ -7,7 +7,6 @@ # pyre-strict import datetime -from typing import Optional from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.types import Integer, TypeDecorator @@ -20,14 +19,14 @@ class IntTimestamp(TypeDecorator): # pyre-fixme[15]: `process_bind_param` overrides method defined in # `TypeDecorator` inconsistently. def process_bind_param( - self, value: Optional[datetime.datetime], dialect: Dialect - ) -> Optional[int]: + self, value: datetime.datetime | None, dialect: Dialect + ) -> int | None: if value is None: return None else: return int(value.timestamp()) def process_result_value( - self, value: Optional[int], dialect: Dialect - ) -> Optional[datetime.datetime]: + self, value: int | None, dialect: Dialect + ) -> datetime.datetime | None: return None if value is None else datetime.datetime.fromtimestamp(value) diff --git a/ax/storage/sqa_store/utils.py b/ax/storage/sqa_store/utils.py index 22a240316e8..7264dde8256 100644 --- a/ax/storage/sqa_store/utils.py +++ b/ax/storage/sqa_store/utils.py @@ -7,7 +7,7 @@ # pyre-strict import warnings -from typing import Any, Optional +from typing import Any from ax.core.experiment import Experiment from ax.core.search_space import SearchSpace @@ -48,7 +48,7 @@ def is_foreign_key_field(field: str) -> bool: # pyre-fixme[2]: Parameter annotation cannot be `Any`. -def copy_db_ids(source: Any, target: Any, path: Optional[list[str]] = None) -> None: +def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None: """Takes as input two objects, `source` and `target`, that should be identical, except that `source` has _db_ids set and `target` doesn't. Recursively copies the _db_ids from `source` to `target`. diff --git a/ax/storage/sqa_store/validation.py b/ax/storage/sqa_store/validation.py index d6dd673dda4..314ee973bd0 100644 --- a/ax/storage/sqa_store/validation.py +++ b/ax/storage/sqa_store/validation.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from logging import Logger -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from ax.storage.sqa_store.db import SQABase from ax.storage.sqa_store.reduced_state import GR_LARGE_MODEL_ATTRS @@ -37,15 +38,12 @@ def listens_for_multiple( identifier: str, *args: Any, **kwargs: Any, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. ) -> Callable: """Analogue of SQLAlchemy `listen_for`, but applies the same listening handler function to multiple instrumented attributes. """ - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
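Note: the diff drops these pyre-fixme[24] suppressions alongside the move of Callable from typing to collections.abc. For reference, the abc version has been subscriptable since Python 3.9, so signatures can be spelled out wherever they are known; for example:

    from collections.abc import Callable

    def apply(fn: Callable[[int], str], value: int) -> str:
        # Callable[[int], str]: one positional int argument, str return value.
        return fn(value)

    print(apply(lambda n: f"n={n}", 7))  # prints "n=7"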
- def wrapper(fn: Callable): + def wrapper(fn: Callable) -> Callable: for target in targets: event.listen(target, identifier, fn, *args, **kwargs) return fn diff --git a/ax/telemetry/ax_client.py b/ax/telemetry/ax_client.py index 82d4d24abe9..13df668ca25 100644 --- a/ax/telemetry/ax_client.py +++ b/ax/telemetry/ax_client.py @@ -8,7 +8,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass -from typing import Any, Optional +from typing import Any from ax.service.ax_client import AxClient from ax.telemetry.common import _get_max_transformed_dimensionality @@ -29,12 +29,12 @@ class AxClientCreatedRecord: generation_strategy_created_record: GenerationStrategyCreatedRecord arms_per_trial: int - early_stopping_strategy_cls: Optional[str] - global_stopping_strategy_cls: Optional[str] + early_stopping_strategy_cls: str | None + global_stopping_strategy_cls: str | None # Dimensionality of transformed SearchSpace can often be much higher due to one-hot # encoding of unordered ChoiceParameters - transformed_dimensionality: Optional[int] + transformed_dimensionality: int | None @classmethod def from_ax_client(cls, ax_client: AxClient) -> AxClientCreatedRecord: diff --git a/ax/telemetry/common.py b/ax/telemetry/common.py index ce6e26b5db9..267e5b88b70 100644 --- a/ax/telemetry/common.py +++ b/ax/telemetry/common.py @@ -7,7 +7,7 @@ import warnings from datetime import datetime -from typing import Any, Optional +from typing import Any from ax.core.experiment import Experiment from ax.core.parameter import FixedParameter @@ -32,7 +32,7 @@ def _get_max_transformed_dimensionality( search_space: SearchSpace, generation_strategy: GenerationStrategy -) -> Optional[int]: +) -> int | None: """ Get dimensionality of transformed SearchSpace for all steps in the GenerationStrategy and return the maximum. diff --git a/ax/telemetry/experiment.py b/ax/telemetry/experiment.py index 08d5cfd2f5a..55d05e7de09 100644 --- a/ax/telemetry/experiment.py +++ b/ax/telemetry/experiment.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from math import inf -from typing import Optional from ax.core.base_trial import TrialStatus @@ -38,8 +37,8 @@ class ExperimentCreatedRecord: bools, and None. """ - experiment_name: Optional[str] - experiment_type: Optional[str] + experiment_name: str | None + experiment_type: str | None # SearchSpace info num_continuous_range_parameters: int diff --git a/ax/telemetry/generation_strategy.py b/ax/telemetry/generation_strategy.py index 96af0d9e027..2bfb760aa4c 100644 --- a/ax/telemetry/generation_strategy.py +++ b/ax/telemetry/generation_strategy.py @@ -10,7 +10,6 @@ import warnings from dataclasses import dataclass from math import inf -from typing import Optional from ax.exceptions.core import AxWarning from ax.modelbridge.generation_strategy import GenerationStrategy @@ -29,14 +28,14 @@ class GenerationStrategyCreatedRecord: generation_strategy_name: str # -1 indicates unlimited trials requested, 0 indicates no trials requested - num_requested_initialization_trials: Optional[ + num_requested_initialization_trials: None | ( int # Typically the number of Sobol trials - ] - num_requested_bayesopt_trials: Optional[int] - num_requested_other_trials: Optional[int] + ) + num_requested_bayesopt_trials: int | None + num_requested_other_trials: int | None # Minimum `max_parallelism` across GenerationSteps, i.e. 
the bottleneck - max_parallelism: Optional[int] + max_parallelism: int | None @classmethod def from_generation_strategy( diff --git a/ax/telemetry/optimization.py b/ax/telemetry/optimization.py index 9aee480a10f..aa485e6486c 100644 --- a/ax/telemetry/optimization.py +++ b/ax/telemetry/optimization.py @@ -8,7 +8,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Union from ax.core.experiment import Experiment from ax.modelbridge.generation_strategy import GenerationStrategy @@ -37,8 +36,8 @@ class OptimizationCreatedRecord: owner: str # ExperimentCreatedRecord fields - experiment_name: Optional[str] - experiment_type: Optional[str] + experiment_name: str | None + experiment_type: str | None num_continuous_range_parameters: int num_int_range_parameters_small: int num_int_range_parameters_medium: int @@ -59,17 +58,17 @@ class OptimizationCreatedRecord: runner_cls: str # GenerationStrategyCreatedRecord fields - generation_strategy_name: Optional[str] - num_requested_initialization_trials: Optional[int] - num_requested_bayesopt_trials: Optional[int] - num_requested_other_trials: Optional[int] - max_parallelism: Optional[int] + generation_strategy_name: str | None + num_requested_initialization_trials: int | None + num_requested_bayesopt_trials: int | None + num_requested_other_trials: int | None + max_parallelism: int | None # {AxClient, Scheduler}CreatedRecord fields - early_stopping_strategy_cls: Optional[str] - global_stopping_strategy_cls: Optional[str] - transformed_dimensionality: Optional[int] - scheduler_total_trials: Optional[int] + early_stopping_strategy_cls: str | None + global_stopping_strategy_cls: str | None + transformed_dimensionality: int | None + scheduler_total_trials: int | None scheduler_max_pending_trials: int arms_per_trial: int @@ -78,11 +77,11 @@ class OptimizationCreatedRecord: launch_surface: str deployed_job_id: int - trial_evaluation_identifier: Optional[str] + trial_evaluation_identifier: str | None # Miscellaneous product info is_manual_generation_strategy: bool - warm_started_from: Optional[str] + warm_started_from: str | None num_custom_trials: int support_tier: str @@ -95,9 +94,9 @@ def from_scheduler( product_surface: str, launch_surface: str, deployed_job_id: int, - trial_evaluation_identifier: Optional[str], + trial_evaluation_identifier: str | None, is_manual_generation_strategy: bool, - warm_started_from: Optional[str], + warm_started_from: str | None, num_custom_trials: int, support_tier: str, ) -> OptimizationCreatedRecord: @@ -197,9 +196,9 @@ def from_ax_client( product_surface: str, launch_surface: str, deployed_job_id: int, - trial_evaluation_identifier: Optional[str], + trial_evaluation_identifier: str | None, is_manual_generation_strategy: bool, - warm_started_from: Optional[str], + warm_started_from: str | None, num_custom_trials: int, ) -> OptimizationCreatedRecord: ax_client_created_record = AxClientCreatedRecord.from_ax_client( @@ -294,7 +293,7 @@ def from_ax_client( def from_experiment( cls, experiment: Experiment, - generation_strategy: Optional[GenerationStrategy], + generation_strategy: GenerationStrategy | None, unique_identifier: str, owner: str, product_surface: str, @@ -302,9 +301,9 @@ def from_experiment( deployed_job_id: int, is_manual_generation_strategy: bool, num_custom_trials: int, - warm_started_from: Optional[str] = None, - arms_per_trial: Optional[int] = None, - trial_evaluation_identifier: Optional[str] = None, + warm_started_from: str | None = None, + arms_per_trial: int | 
None = None, + trial_evaluation_identifier: str | None = None, ) -> OptimizationCreatedRecord: experiment_created_record = ExperimentCreatedRecord.from_experiment( experiment=experiment, @@ -453,7 +452,7 @@ class OptimizationCompletedRecord: num_trials_bad_due_to_err: int # TODO[mpolson64] Deprecate this field as it is redundant with unique_identifier - deployed_job_id: Optional[int] + deployed_job_id: int | None # Miscellaneous deployment specific info estimated_early_stopping_savings: float @@ -464,7 +463,7 @@ def from_scheduler( cls, scheduler: Scheduler, unique_identifier: str, - deployed_job_id: Optional[int], + deployed_job_id: int | None, estimated_early_stopping_savings: float, estimated_global_stopping_savings: float, ) -> OptimizationCompletedRecord: @@ -511,7 +510,7 @@ def from_ax_client( cls, ax_client: AxClient, unique_identifier: str, - deployed_job_id: Optional[int], + deployed_job_id: int | None, estimated_early_stopping_savings: float, estimated_global_stopping_savings: float, ) -> OptimizationCompletedRecord: @@ -550,7 +549,7 @@ def from_ax_client( def _extract_model_fit_dict( - completed_record: Union[SchedulerCompletedRecord, AxClientCompletedRecord], + completed_record: SchedulerCompletedRecord | AxClientCompletedRecord, ) -> dict[str, float]: model_fit_names = [ "model_fit_quality", diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py index e3d0604b2bb..9c8059c8b37 100644 --- a/ax/telemetry/scheduler.py +++ b/ax/telemetry/scheduler.py @@ -8,7 +8,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass -from typing import Any, Optional +from typing import Any from warnings import warn from ax.modelbridge.cross_validation import ( @@ -35,15 +35,15 @@ class SchedulerCreatedRecord: generation_strategy_created_record: GenerationStrategyCreatedRecord # SchedulerOptions info - scheduler_total_trials: Optional[int] + scheduler_total_trials: int | None scheduler_max_pending_trials: int arms_per_trial: int - early_stopping_strategy_cls: Optional[str] - global_stopping_strategy_cls: Optional[str] + early_stopping_strategy_cls: str | None + global_stopping_strategy_cls: str | None # Dimensionality of transformed SearchSpace can often be much higher due to one-hot # encoding of unordered ChoiceParameters - transformed_dimensionality: Optional[int] + transformed_dimensionality: int | None @classmethod def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord: @@ -104,10 +104,10 @@ class SchedulerCompletedRecord: experiment_completed_record: ExperimentCompletedRecord best_point_quality: float - model_fit_quality: Optional[float] - model_std_quality: Optional[float] - model_fit_generalization: Optional[float] - model_std_generalization: Optional[float] + model_fit_quality: float | None + model_std_quality: float | None + model_fit_generalization: float | None + model_std_generalization: float | None improvement_over_baseline: float diff --git a/ax/telemetry/tests/test_ax_client.py b/ax/telemetry/tests/test_ax_client.py index d820b6f6ac2..ad811d37061 100644 --- a/ax/telemetry/tests/test_ax_client.py +++ b/ax/telemetry/tests/test_ax_client.py @@ -8,7 +8,6 @@ import logging from collections.abc import Sequence -from typing import Union import numpy as np @@ -54,7 +53,7 @@ def test_ax_client_created_record_from_ax_client(self) -> None: # Test with HSS & MOO. 
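Note, before the HSS/MOO test setup continues below: the telemetry records in the hunks above are plain dataclasses, so the new spelling is invisible to serialization; a field typed `str | None` defaults and round-trips through asdict unchanged. An illustrative sketch, with made-up field names rather than the real record schema:

    from dataclasses import asdict, dataclass

    @dataclass
    class CreatedRecordSketch:
        unique_identifier: str
        experiment_type: str | None = None  # absent metadata stays None

    record = CreatedRecordSketch(unique_identifier="exp_123")
    print(asdict(record))  # {'unique_identifier': 'exp_123', 'experiment_type': None}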
ax_client = AxClient() parameters: list[ - dict[str, Union[TParamValue, Sequence[TParamValue], dict[str, list[str]]]] + dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]] ] = [ { "name": "SearchSpace.optimizer", diff --git a/ax/utils/common/base.py b/ax/utils/common/base.py index 5ece9e5b290..5115c0f9a71 100644 --- a/ax/utils/common/base.py +++ b/ax/utils/common/base.py @@ -9,7 +9,6 @@ from __future__ import annotations import abc -from typing import Optional from ax.utils.common.equality import equality_typechecker, object_attribute_dicts_equal @@ -19,10 +18,10 @@ class Base: property for SQA storage. """ - _db_id: Optional[int] = None + _db_id: int | None = None @property - def db_id(self) -> Optional[int]: + def db_id(self) -> int | None: return self._db_id @db_id.setter diff --git a/ax/utils/common/decorator.py b/ax/utils/common/decorator.py index 28b29ea72bd..f4686b15214 100644 --- a/ax/utils/common/decorator.py +++ b/ax/utils/common/decorator.py @@ -6,7 +6,8 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar T = TypeVar("T") diff --git a/ax/utils/common/deprecation.py b/ax/utils/common/deprecation.py index da7402eab18..d2dc45d8d2f 100644 --- a/ax/utils/common/deprecation.py +++ b/ax/utils/common/deprecation.py @@ -7,13 +7,12 @@ # pyre-strict import warnings -from typing import Optional, Type def _validate_force_random_search( - no_bayesian_optimization: Optional[bool] = None, + no_bayesian_optimization: bool | None = None, force_random_search: bool = False, - exception_cls: Type[Exception] = ValueError, + exception_cls: type[Exception] = ValueError, ) -> None: """Helper function to validate interaction between `force_random_search` and `no_bayesian_optimization` (supported until deprecation in [T199632397]) diff --git a/ax/utils/common/docutils.py b/ax/utils/common/docutils.py index 06d8da4a256..cac882d14c2 100644 --- a/ax/utils/common/docutils.py +++ b/ax/utils/common/docutils.py @@ -10,7 +10,8 @@ """ -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar _T = TypeVar("_T") diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index 96ad4b28e2f..13fe375924b 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -8,15 +8,16 @@ from __future__ import annotations +from collections.abc import Callable + from datetime import datetime -from typing import Any, Callable, Optional +from typing import Any import numpy as np import pandas as pd from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def equality_typechecker(eq_func: Callable) -> Callable: """A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type. 
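Note: the `_db_id` pattern in ax/utils/common/base.py above is the usual "None until first save" idiom. Condensed to its core, as a sketch of the idiom rather than the full Ax class:

    class HasDbId:
        _db_id: int | None = None  # None until the object is persisted

        @property
        def db_id(self) -> int | None:
            return self._db_id

        @db_id.setter
        def db_id(self, db_id: int) -> None:
            self._db_id = db_id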
@@ -100,7 +101,7 @@ def is_ax_equal(one_val: Any, other_val: Any) -> bool: return False -def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool: +def datetime_equals(dt1: datetime | None, dt2: datetime | None) -> bool: """Compare equality of two datetimes, ignoring microseconds.""" if not dt1 and not dt2: return True diff --git a/ax/utils/common/executils.py b/ax/utils/common/executils.py index eb4014986e5..f07b07808d0 100644 --- a/ax/utils/common/executils.py +++ b/ax/utils/common/executils.py @@ -10,11 +10,11 @@ import functools import threading import time -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from functools import partial from logging import Logger -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar MAX_WAIT_SECONDS: int = 600 @@ -23,17 +23,17 @@ # pyre-fixme[3]: Return annotation cannot be `Any`. def retry_on_exception( - exception_types: Optional[tuple[type[Exception], ...]] = None, - no_retry_on_exception_types: Optional[tuple[type[Exception], ...]] = None, - check_message_contains: Optional[list[str]] = None, + exception_types: tuple[type[Exception], ...] | None = None, + no_retry_on_exception_types: tuple[type[Exception], ...] | None = None, + check_message_contains: list[str] | None = None, retries: int = 3, suppress_all_errors: bool = False, - logger: Optional[Logger] = None, + logger: Logger | None = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. - default_return_on_suppression: Optional[Any] = None, - wrap_error_message_in: Optional[str] = None, - initial_wait_seconds: Optional[int] = None, -) -> Optional[Any]: + default_return_on_suppression: Any | None = None, + wrap_error_message_in: str | None = None, + initial_wait_seconds: int | None = None, +) -> Any | None: """ A decorator for instance methods or standalone functions that makes them retry on failure and allows to specify on which types of exceptions the @@ -180,10 +180,10 @@ def handle_exceptions_in_retries( no_retry_exceptions: tuple[type[Exception], ...], retry_exceptions: tuple[type[Exception], ...], suppress_errors: bool, - check_message_contains: Optional[str], + check_message_contains: str | None, last_retry: bool, - logger: Optional[Logger], - wrap_error_message_in: Optional[str], + logger: Logger | None, + wrap_error_message_in: str | None, ) -> Generator[None, None, None]: try: yield # Perform action within the context manager. @@ -218,8 +218,8 @@ def handle_exceptions_in_retries( def _validate_and_fill_defaults( - retry_on_exception_types: Optional[tuple[type[Exception], ...]], - no_retry_on_exception_types: Optional[tuple[type[Exception], ...]], + retry_on_exception_types: tuple[type[Exception], ...] | None, + no_retry_on_exception_types: tuple[type[Exception], ...] 
| None, suppress_errors: bool, **kwargs: Any, ) -> tuple[tuple[type[Exception], ...], tuple[type[Exception], ...], bool]: diff --git a/ax/utils/common/kwargs.py b/ax/utils/common/kwargs.py index a9f2e66fe7d..7c4a23c1c02 100644 --- a/ax/utils/common/kwargs.py +++ b/ax/utils/common/kwargs.py @@ -6,11 +6,11 @@ # pyre-strict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from inspect import Parameter, signature from logging import Logger -from typing import Any, Callable, Optional +from typing import Any from ax.utils.common.logger import get_logger @@ -20,7 +20,7 @@ def consolidate_kwargs( - kwargs_iterable: Iterable[Optional[dict[str, Any]]], keywords: Iterable[str] + kwargs_iterable: Iterable[dict[str, Any] | None], keywords: Iterable[str] ) -> dict[str, Any]: """Combine an iterable of kwargs into a single dict of kwargs, where kwargs by duplicate keys that appear later in the iterable get priority over the @@ -42,16 +42,14 @@ def consolidate_kwargs( def get_function_argument_names( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. function: Callable, - omit: Optional[list[str]] = None, + omit: list[str] | None = None, ) -> list[str]: """Extract parameter names from function signature.""" omit = omit or [] return [p for p in signature(function).parameters.keys() if p not in omit] -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def get_function_default_arguments(function: Callable) -> dict[str, Any]: """Extract default arguments from function signature.""" params = signature(function).parameters @@ -60,7 +58,6 @@ def get_function_default_arguments(function: Callable) -> dict[str, Any]: } -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any) -> None: """Log a warning when a decoder function receives unexpected kwargs. @@ -80,7 +77,6 @@ def warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any) -> None: # pyre-fixme[3]: Return annotation cannot be `Any`. -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def filter_kwargs(function: Callable, **kwargs: Any) -> Any: """Filter out kwargs that are not applicable for a given function. Return a copy of given kwargs dict with only the required kwargs.""" diff --git a/ax/utils/common/logger.py b/ax/utils/common/logger.py index 8f05f295e81..0c6fe27cf43 100644 --- a/ax/utils/common/logger.py +++ b/ax/utils/common/logger.py @@ -9,9 +9,9 @@ import logging import os import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from ax.utils.common.decorator import ClassDecorator @@ -57,9 +57,7 @@ def get_logger( The logging.Logger object. 
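Note: retry_on_exception above carries the widest signature in this file, but the pattern underneath is an ordinary retry decorator. A self-contained sketch of that pattern, deliberately far simpler than Ax's version (no suppression, message filtering, or logging):

    import time
    from collections.abc import Callable
    from typing import Any

    def retry(retries: int = 3, wait_seconds: float = 0.0) -> Callable:
        def decorator(fn: Callable) -> Callable:
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                for attempt in range(retries):
                    try:
                        return fn(*args, **kwargs)
                    except Exception:
                        if attempt == retries - 1:
                            raise  # out of attempts: surface the last error
                        time.sleep(wait_seconds)
            return wrapped
        return decorator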
""" # because handlers are attached to the "ax" module - if not force_name and not re.search( - r"^{ax_root}(\.|$)".format(ax_root=AX_ROOT_LOGGER_NAME), name - ): + if not force_name and not re.search(rf"^{AX_ROOT_LOGGER_NAME}(\.|$)", name): name = f"{AX_ROOT_LOGGER_NAME}.{name}" logger = logging.getLogger(name) logger.setLevel(level) diff --git a/ax/utils/common/mock.py b/ax/utils/common/mock.py index d6c35e6d138..47cd29266d6 100644 --- a/ax/utils/common/mock.py +++ b/ax/utils/common/mock.py @@ -6,8 +6,9 @@ # pyre-strict +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from unittest.mock import MagicMock, patch diff --git a/ax/utils/common/random.py b/ax/utils/common/random.py index 0b68efe1bcd..a9fb6e45329 100644 --- a/ax/utils/common/random.py +++ b/ax/utils/common/random.py @@ -9,7 +9,6 @@ import random from collections.abc import Generator from contextlib import contextmanager -from typing import Optional import numpy as np import torch @@ -28,7 +27,7 @@ def set_rng_seed(seed: int) -> None: @contextmanager -def with_rng_seed(seed: Optional[int]) -> Generator[None, None, None]: +def with_rng_seed(seed: int | None) -> Generator[None, None, None]: """Context manager that sets the random number generator seeds to a given value and restores the previous state on exit. diff --git a/ax/utils/common/result.py b/ax/utils/common/result.py index a38bb89ab8a..03949f74b34 100644 --- a/ax/utils/common/result.py +++ b/ax/utils/common/result.py @@ -10,9 +10,10 @@ import traceback from abc import ABC, abstractmethod, abstractproperty +from collections.abc import Callable from functools import reduce -from typing import Any, Callable, cast, Generic, NoReturn, Optional, TypeVar, Union +from typing import Any, cast, Generic, NoReturn, TypeVar T = TypeVar("T", covariant=True) @@ -37,15 +38,15 @@ def is_err(self) -> bool: pass @abstractproperty - def ok(self) -> Optional[T]: + def ok(self) -> T | None: pass @abstractproperty - def err(self) -> Optional[E]: + def err(self) -> E | None: pass @abstractproperty - def value(self) -> Union[T, E]: + def value(self) -> T | E: pass @abstractmethod @@ -281,7 +282,7 @@ def __repr__(self) -> str: f"with Traceback:\n {self.tb_str()}" ) - def tb_str(self) -> Optional[str]: + def tb_str(self) -> str | None: if self.exception is None: return None diff --git a/ax/utils/common/serialization.py b/ax/utils/common/serialization.py index 512224b75b4..d42eab59fc5 100644 --- a/ax/utils/common/serialization.py +++ b/ax/utils/common/serialization.py @@ -10,8 +10,9 @@ import inspect import pydoc +from collections.abc import Callable from types import FunctionType -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar, Union T = TypeVar("T") @@ -53,7 +54,6 @@ def _is_named_tuple(x: Any) -> bool: return all(isinstance(n, str) for n in f) -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def callable_to_reference(callable: Callable) -> str: """Obtains path to the callable of form ..""" if not isinstance(callable, (FunctionType, type)): @@ -69,7 +69,6 @@ def callable_to_reference(callable: Callable) -> str: ) -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
def callable_from_reference(path: str) -> Callable: """Retrieves a callable by its path.""" return pydoc.locate(path) # pyre-ignore[7] @@ -78,7 +77,7 @@ def callable_from_reference(path: str) -> Callable: def serialize_init_args( # pyre-fixme[2]: Parameter annotation cannot be `Any`. obj: Any, - exclude_fields: Optional[list[str]] = None, + exclude_fields: list[str] | None = None, ) -> dict[str, Any]: """Given an object, return a dictionary of the arguments that are needed by its constructor. @@ -158,8 +157,8 @@ def serialize_init_args(cls, obj: SerializationMixin) -> dict[str, Any]: def deserialize_init_args( cls, args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, ) -> dict[str, Any]: """Given a dictionary, deserialize the properties needed to initialize the object. Used for storage. diff --git a/ax/utils/common/testutils.py b/ax/utils/common/testutils.py index 33d4b177c19..32fbd899c0e 100644 --- a/ax/utils/common/testutils.py +++ b/ax/utils/common/testutils.py @@ -21,12 +21,12 @@ import types import unittest import warnings -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import AbstractContextManager from logging import Logger from pstats import Stats from types import FrameType -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar, Union from unittest.mock import MagicMock import numpy as np @@ -71,16 +71,16 @@ class _AssertRaisesContextOn(unittest.case._AssertRaisesContext): filename: the file in which the error occured """ - _expected_line: Optional[str] - lineno: Optional[int] - filename: Optional[str] + _expected_line: str | None + lineno: int | None + filename: str | None def __init__( self, expected: type[Exception], test_case: unittest.TestCase, - expected_line: Optional[str] = None, - expected_regex: Optional[str] = None, + expected_line: str | None = None, + expected_regex: str | None = None, ) -> None: self._expected_line = ( expected_line.strip() if expected_line is not None else None @@ -95,9 +95,9 @@ def __init__( # inconsistently. def __exit__( self, - exc_type: Optional[type[Exception]], - exc_value: Optional[Exception], - tb: Optional[types.TracebackType], + exc_type: type[Exception] | None, + exc_value: Exception | None, + tb: types.TracebackType | None, ) -> bool: """This is called when the context closes. If an exception was raised `exc_type`, `exc_value` and `tb` will be set. @@ -122,7 +122,6 @@ def __exit__( # Instead of showing a warning (like in the standard library) we throw an error when # deprecated functions are called. -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def _deprecate(original_func: Callable) -> Callable: def _deprecated_func(*args: list[Any], **kwargs: dict[str, Any]) -> None: raise RuntimeError( @@ -248,7 +247,7 @@ def _unequal_str(first: Any, second: Any) -> str: # pyre-ignore[2] def setup_import_mocks( - mocked_import_paths: list[str], mock_config_dict: Optional[dict[str, Any]] = None + mocked_import_paths: list[str], mock_config_dict: dict[str, Any] | None = None ) -> None: """This function mocks expensive modules used in tests. It must be called before those modules are imported or it will not work. 
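Note: callable_to_reference and callable_from_reference implement a dotted-path round trip, with pydoc.locate doing the resolution (visible in the hunk above). The mechanics in miniature:

    import json
    import pydoc
    from collections.abc import Callable

    def to_reference(fn: Callable) -> str:
        return f"{fn.__module__}.{fn.__qualname__}"  # e.g. "json.dumps"

    ref = to_reference(json.dumps)
    assert pydoc.locate(ref) is json.dumps  # resolves back to the same object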
Stubbing out these modules @@ -295,10 +294,10 @@ class TestCase(fake_filesystem_unittest.TestCase): MAX_TEST_SECONDS = 60 NUMBER_OF_PROFILER_LINES_TO_OUTPUT = 20 PROFILE_TESTS = False - _long_test_active_reason: Optional[str] = None + _long_test_active_reason: str | None = None def __init__(self, methodName: str = "runTest") -> None: - def signal_handler(signum: int, frame: Optional[FrameType]) -> None: + def signal_handler(signum: int, frame: FrameType | None) -> None: message = f"Test took longer than {self.MAX_TEST_SECONDS} seconds." if self.PROFILE_TESTS: self._print_profiler_output() @@ -376,8 +375,8 @@ def setUp(self) -> None: ) def run( - self, result: Optional[unittest.result.TestResult] = ... - ) -> Optional[unittest.result.TestResult]: + self, result: unittest.result.TestResult | None = ... + ) -> unittest.result.TestResult | None: # Arrange for a SIGALRM signal to be delivered to the calling process # in specified number of seconds. signal.alarm(self.MAX_TEST_SECONDS) @@ -391,7 +390,7 @@ def assertEqual( self, first: Any, # pyre-ignore[2] second: Any, # pyre-ignore[2] - msg: Optional[str] = None, + msg: str | None = None, ) -> None: if isinstance(first, Base) and isinstance(second, Base): self.assertAxBaseEqual(first=first, second=second, msg=msg) @@ -402,7 +401,7 @@ def assertAxBaseEqual( self, first: Base, second: Base, - msg: Optional[str] = None, + msg: str | None = None, skip_db_id_check: bool = False, ) -> None: """Check that two Ax objects that subclass ``Base`` are equal or raise @@ -442,8 +441,8 @@ def assertAxBaseEqual( def assertRaisesOn( self, exc: type[Exception], - line: Optional[str] = None, - regex: Optional[str] = None, + line: str | None = None, + regex: str | None = None, ) -> AbstractContextManager[None]: """Assert that an exception is raised on a specific line.""" context = _AssertRaisesContextOn(exc, self, line, regex) @@ -520,7 +519,7 @@ def _print_profiler_output(self) -> None: @classmethod @contextlib.contextmanager - def ax_long_test(cls, reason: Optional[str]) -> Generator[None, None, None]: + def ax_long_test(cls, reason: str | None) -> Generator[None, None, None]: cls._long_test_active_reason = reason yield cls._long_test_active_reason = None diff --git a/ax/utils/common/typeutils.py b/ax/utils/common/typeutils.py index 2415f96dc4a..bd3377fb015 100644 --- a/ax/utils/common/typeutils.py +++ b/ax/utils/common/typeutils.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar T = TypeVar("T") @@ -15,7 +15,7 @@ Y = TypeVar("Y") -def not_none(val: Optional[T], message: Optional[str] = None) -> T: +def not_none(val: T | None, message: str | None = None) -> T: """ Unbox an optional type. @@ -32,7 +32,7 @@ def not_none(val: Optional[T], message: Optional[str] = None) -> T: return val -def checked_cast(typ: type[T], val: V, exception: Optional[Exception] = None) -> T: +def checked_cast(typ: type[T], val: V, exception: Exception | None = None) -> T: """ Cast a value to a type (with a runtime safety check). 
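Note: call sites of these helpers are untouched by the codemod; only the annotation spelling changed. Typical usage, mirroring the docstrings above:

    from ax.utils.common.typeutils import checked_cast, not_none

    maybe_count: int | None = 5
    count: int = not_none(maybe_count)  # unboxes; raises ValueError on None
    ratio: float = checked_cast(float, 0.5)  # passes through; raises on a type mismatch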
@@ -61,7 +61,7 @@ def checked_cast(typ: type[T], val: V, exception: Optional[Exception] = None) -> return val -def checked_cast_optional(typ: type[T], val: Optional[V]) -> Optional[T]: +def checked_cast_optional(typ: type[T], val: V | None) -> T | None: """Calls checked_cast only if value is not None.""" if val is None: return val diff --git a/ax/utils/common/typeutils_torch.py b/ax/utils/common/typeutils_torch.py index 9c933652a79..1aa4a84dde5 100644 --- a/ax/utils/common/typeutils_torch.py +++ b/ax/utils/common/typeutils_torch.py @@ -7,13 +7,12 @@ # pyre-strict import json -from typing import Union import torch from ax.utils.common.typeutils import checked_cast -def torch_type_to_str(value: Union[torch.dtype, torch.device, torch.Size]) -> str: +def torch_type_to_str(value: torch.dtype | torch.device | torch.Size) -> str: """Converts torch types, commonly used in Ax, to string representations.""" if isinstance(value, torch.dtype): return str(value) @@ -26,7 +25,7 @@ def torch_type_to_str(value: Union[torch.dtype, torch.device, torch.Size]) -> st def torch_type_from_str( identifier: str, type_name: str -) -> Union[torch.dtype, torch.device, torch.Size]: +) -> torch.dtype | torch.device | torch.Size: if type_name == "device": return torch.device(identifier) if type_name == "dtype": diff --git a/ax/utils/flake8_plugins/docstring_checker.py b/ax/utils/flake8_plugins/docstring_checker.py index b5b546431ff..041044b3033 100644 --- a/ax/utils/flake8_plugins/docstring_checker.py +++ b/ax/utils/flake8_plugins/docstring_checker.py @@ -6,8 +6,9 @@ import ast import itertools +from collections.abc import Callable from pathlib import Path -from typing import Callable, NamedTuple +from typing import NamedTuple class Error(NamedTuple): diff --git a/ax/utils/measurement/synthetic_functions.py b/ax/utils/measurement/synthetic_functions.py index 417be1fbbfa..a52bdce7597 100644 --- a/ax/utils/measurement/synthetic_functions.py +++ b/ax/utils/measurement/synthetic_functions.py @@ -7,7 +7,7 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Optional, TypeVar, Union +from typing import TypeVar import numpy as np import torch @@ -23,12 +23,12 @@ class SyntheticFunction(ABC): _required_dimensionality: int _domain: list[tuple[float, float]] - _minimums: Optional[list[tuple[float, ...]]] = None - _maximums: Optional[list[tuple[float, ...]]] = None - _fmin: Optional[float] = None - _fmax: Optional[float] = None + _minimums: list[tuple[float, ...]] | None = None + _maximums: list[tuple[float, ...]] | None = None + _fmin: float | None = None + _fmax: float | None = None - def informative_failure_on_none(self, attr: Optional[T]) -> T: + def informative_failure_on_none(self, attr: T | None) -> T: if attr is None: raise NotImplementedError(f"{self.name} does not specify property.") return not_none(attr) @@ -39,9 +39,9 @@ def name(self) -> str: def __call__( self, - *args: Union[int, float, np.ndarray], - **kwargs: Union[int, float, np.ndarray], - ) -> Union[float, np.ndarray]: + *args: int | float | np.ndarray, + **kwargs: int | float | np.ndarray, + ) -> float | np.ndarray: """Simplified way to call the synthetic function and pass the argument numbers directly, e.g. `branin(2.0, 3.0)`. """ @@ -69,7 +69,7 @@ def __call__( x = float(x) return checked_cast(float, self.f(np.array(args))) - def f(self, X: np.ndarray) -> Union[float, np.ndarray]: + def f(self, X: np.ndarray) -> float | np.ndarray: """Synthetic function implementation. 
Args: @@ -165,7 +165,7 @@ def __init__( self._botorch_function = botorch_synthetic_function self._required_dimensionality: int = self._botorch_function.dim self._domain: list[tuple[float, float]] = self._botorch_function._bounds - self._fmin: Optional[float] = self._botorch_function._optimal_value + self._fmin: float | None = self._botorch_function._optimal_value @override @property diff --git a/ax/utils/report/render.py b/ax/utils/report/render.py index 40c3f376c1b..293b4476894 100644 --- a/ax/utils/report/render.py +++ b/ax/utils/report/render.py @@ -8,7 +8,6 @@ import os import pkgutil -from typing import Optional from ax.plot.render import _js_requires, _load_css_resource as _load_plot_css_resource from jinja2 import Environment, FunctionLoader @@ -30,22 +29,22 @@ def _load_css_resource() -> str:
 
 
 def p_html(text: str) -> str:
     """Embed text in paragraph tag."""
-    return "<p>{}</p>".format(text)
+    return f"<p>{text}</p>"
 
 
 def h2_html(text: str) -> str:
     """Embed text in subheading tag."""
-    return "<h2>{}</h2>".format(text)
+    return f"<h2>{text}</h2>"
 
 
 def h3_html(text: str) -> str:
     """Embed text in subsubheading tag."""
-    return "<h3>{}</h3>".format(text)
+    return f"<h3>{text}</h3>"
 
 
 def list_item_html(text: str) -> str:
     """Embed text in list element tag."""
-    return "<li>{}</li>".format(text)
+    return f"<li>{text}</li>"
 
 
 def unordered_list_html(list_items: list[str]) -> str:
@@ -55,10 +54,10 @@ def unordered_list_html(list_items: list[str]) -> str:
 
 def link_html(text: str, href: str) -> str:
     """Embed text and reference address into link tag."""
-    return '<a href="{}">{}</a>'.format(href, text)
+    return f'<a href="{href}">{text}</a>'
 
 
-def table_cell_html(text: str, width: Optional[str] = None) -> str:
+def table_cell_html(text: str, width: str | None = None) -> str:
     """Embed text or an HTML element into table cell tag."""
     if width:
         return f"<td width={width}>{text}</td>"
@@ -68,7 +67,7 @@
 
 def table_heading_cell_html(text: str) -> str:
     """Embed text or an HTML element into table heading cell tag."""
-    return "<th>{}</th>".format(text)
+    return f"<th>{text}</th>"
 
 
 def table_row_html(table_cells: list[str]) -> str:
diff --git a/ax/utils/sensitivity/derivative_measures.py b/ax/utils/sensitivity/derivative_measures.py index 583bbc91093..4209c5946ac 100644 --- a/ax/utils/sensitivity/derivative_measures.py +++ b/ax/utils/sensitivity/derivative_measures.py @@ -5,9 +5,10 @@ # pyre-strict +from collections.abc import Callable from copy import deepcopy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any import torch from ax.utils.common.typeutils import checked_cast, not_none @@ -23,7 +24,7 @@ def sample_discrete_parameters( input_mc_samples: torch.Tensor, - discrete_features: Union[None, list[int]], + discrete_features: None | list[int], bounds: torch.Tensor, num_mc_samples: int, ) -> torch.Tensor: @@ -49,24 +50,24 @@ def sample_discrete_parameters( return input_mc_samples -class GpDGSMGpMean(object): +class GpDGSMGpMean: - mean_gradients: Optional[torch.Tensor] = None - bootstrap_indices: Optional[torch.Tensor] = None - mean_gradients_btsp: Optional[list[torch.Tensor]] = None + mean_gradients: torch.Tensor | None = None + bootstrap_indices: torch.Tensor | None = None + mean_gradients_btsp: list[torch.Tensor] | None = None def __init__( self, model: Model, bounds: torch.Tensor, derivative_gp: bool = False, - kernel_type: Optional[str] = None, + kernel_type: str | None = None, Y_scale: float = 1.0, num_mc_samples: int = 10**4, input_qmc: bool = False, dtype: torch.dtype = torch.double, num_bootstrap_samples: int = 1, - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, ) -> None: r"""Computes three types of derivative based measures: the gradient, the gradient square and the gradient absolute measures.
@@ -136,7 +137,7 @@ def __init__( self._compute_gradient_quantities(posterior, Y_scale) def _compute_gradient_quantities( - self, posterior: Union[GPyTorchPosterior, MultivariateNormal], Y_scale: float + self, posterior: GPyTorchPosterior | MultivariateNormal, Y_scale: float ) -> None: if self.derivative_gp: self.mean_gradients = checked_cast(torch.Tensor, posterior.mean) * Y_scale @@ -237,8 +238,8 @@ def gradients_square_measure(self) -> torch.Tensor: class GpDGSMGpSampling(GpDGSMGpMean): - samples_gradients: Optional[torch.Tensor] = None - samples_gradients_btsp: Optional[list[torch.Tensor]] = None + samples_gradients: torch.Tensor | None = None + samples_gradients_btsp: list[torch.Tensor] | None = None def __init__( self, @@ -246,7 +247,7 @@ def __init__( bounds: torch.Tensor, num_gp_samples: int, derivative_gp: bool = False, - kernel_type: Optional[str] = None, + kernel_type: str | None = None, Y_scale: float = 1.0, num_mc_samples: int = 10**4, input_qmc: bool = False, @@ -300,7 +301,7 @@ def __init__( ) def _compute_gradient_quantities( - self, posterior: Union[Posterior, MultivariateNormal], Y_scale: float + self, posterior: Posterior | MultivariateNormal, Y_scale: float ) -> None: if self.gp_sample_qmc: sampler = SobolQMCNormalSampler( @@ -415,7 +416,7 @@ def aggregation( def compute_derivatives_from_model_list( model_list: list[Model], bounds: torch.Tensor, - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, **kwargs: Any, ) -> torch.Tensor: """ diff --git a/ax/utils/sensitivity/sobol_measures.py b/ax/utils/sensitivity/sobol_measures.py index 63d7dcb422f..33a0c1b853a 100644 --- a/ax/utils/sensitivity/sobol_measures.py +++ b/ax/utils/sensitivity/sobol_measures.py @@ -5,9 +5,10 @@ # pyre-strict +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any import numpy as np @@ -29,17 +30,17 @@ from torch import Tensor -class SobolSensitivity(object): +class SobolSensitivity: def __init__( self, bounds: torch.Tensor, - input_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + input_function: Callable[[torch.Tensor], torch.Tensor] | None = None, num_mc_samples: int = 10**4, input_qmc: bool = False, second_order: bool = False, num_bootstrap_samples: int = 1, bootstrap_array: bool = False, - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, ) -> None: r"""Computes three types of Sobol indices: first order indices, total indices and second order indices (if specified ). @@ -102,32 +103,32 @@ def __init__( self.bootstrap_indices = torch.randint( 0, num_mc_samples, (self.num_bootstrap_samples, subset_size) ) - self.f_A: Optional[torch.Tensor] = None - self.f_B: Optional[torch.Tensor] = None + self.f_A: torch.Tensor | None = None + self.f_B: torch.Tensor | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_ABis: Optional[list] = None - self.f_total_var: Optional[torch.Tensor] = None + self.f_ABis: list | None = None + self.f_total_var: torch.Tensor | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_A_btsp: Optional[list] = None + self.f_A_btsp: list | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. 
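Note: for context on the quantities these classes estimate: with paired Monte Carlo matrices A and B and hybrid matrices AB_i (A with column i taken from B), the standard pick-and-freeze estimators for the first order and total indices look as follows. This is a sketch of the textbook estimators (Saltelli et al. 2010; Jansen 1999); Ax's internals may use a different variant:

    import numpy as np

    def first_order_index(f_A: np.ndarray, f_B: np.ndarray, f_ABi: np.ndarray) -> float:
        # S_i ~= E[f(B) * (f(AB_i) - f(A))] / Var(Y)
        var_y = float(np.var(np.concatenate([f_A, f_B])))
        return float(np.mean(f_B * (f_ABi - f_A))) / var_y

    def total_order_index(f_A: np.ndarray, f_ABi: np.ndarray) -> float:
        # ST_i ~= E[(f(A) - f(AB_i))^2] / (2 * Var(Y))
        return float(np.mean((f_A - f_ABi) ** 2)) / (2.0 * float(np.var(f_A)))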
- self.f_B_btsp: Optional[list] = None + self.f_B_btsp: list | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_ABis_btsp: Optional[list] = None + self.f_ABis_btsp: list | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_total_var_btsp: Optional[list] = None + self.f_total_var_btsp: list | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_BAis: Optional[list] = None + self.f_BAis: list | None = None # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use # `typing.List` to avoid runtime subscripting errors. - self.f_BAis_btsp: Optional[list] = None - self.first_order_idxs: Optional[torch.Tensor] = None - self.first_order_idxs_btsp: Optional[torch.Tensor] = None + self.f_BAis_btsp: list | None = None + self.first_order_idxs: torch.Tensor | None = None + self.first_order_idxs_btsp: torch.Tensor | None = None def generate_all_input_matrix(self) -> torch.Tensor: A_B_ABi_list = [self.A, self.B] @@ -143,7 +144,7 @@ def generate_all_input_matrix(self) -> torch.Tensor: A_B_ABi = torch.cat(A_B_ABi_list, dim=0) return A_B_ABi - def evalute_function(self, f_A_B_ABi: Optional[torch.Tensor] = None) -> None: + def evalute_function(self, f_A_B_ABi: torch.Tensor | None = None) -> None: r"""evaluates the objective function and devides the evaluation into torch.Tensors needed for the indices computation. Args: @@ -318,8 +319,8 @@ def total_order_indices(self) -> Tensor: def second_order_indices( self, - first_order_idxs: Optional[torch.Tensor] = None, - first_order_idxs_btsp: Optional[torch.Tensor] = None, + first_order_idxs: torch.Tensor | None = None, + first_order_idxs_btsp: torch.Tensor | None = None, ) -> Tensor: r"""Computes the Second order Sobol indices: Args: @@ -404,7 +405,7 @@ def ProbitLinkMean(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor: return torch.distributions.Normal(0, 1).cdf(a) -class SobolSensitivityGPMean(object): +class SobolSensitivityGPMean: def __init__( self, model: Model, # TODO: narrow type down. E.g. ModelListGP does not work. @@ -417,7 +418,7 @@ def __init__( [torch.Tensor, torch.Tensor], torch.Tensor ] = GaussianLinkMean, mini_batch_size: int = 128, - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, ) -> None: r"""Computes three types of Sobol indices: first order indices, total indices and second order indices (if specified ). @@ -506,7 +507,7 @@ def second_order_indices(self) -> Tensor: return self.sensitivity.second_order_indices() -class SobolSensitivityGPSampling(object): +class SobolSensitivityGPSampling: def __init__( self, model: Model, @@ -517,7 +518,7 @@ def __init__( input_qmc: bool = False, gp_sample_qmc: bool = False, num_bootstrap_samples: int = 1, - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, ) -> None: r"""Computes three types of Sobol indices: first order indices, total indices and second order indices (if specified ). 
@@ -753,7 +754,7 @@ def compute_sobol_indices_from_model_list( model_list: list[Model], bounds: Tensor, order: str = "first", - discrete_features: Optional[list[int]] = None, + discrete_features: list[int] | None = None, **sobol_kwargs: Any, ) -> Tensor: """ @@ -788,7 +789,7 @@ def compute_sobol_indices_from_model_list( def ax_parameter_sens( model_bridge: TorchModelBridge, - metrics: Optional[list[str]] = None, + metrics: list[str] | None = None, order: str = "first", signed: bool = True, **sobol_kwargs: Any, @@ -856,7 +857,7 @@ def ax_parameter_sens( def _get_torch_model( model_bridge: TorchModelBridge, -) -> Union[BotorchModel, ModularBoTorchModel]: +) -> BotorchModel | ModularBoTorchModel: """Returns the TorchModel of the model_bridge, if it is a type that stores SearchSpaceDigest during model fitting. At this point, this is BotorchModel, and ModularBoTorchModel. @@ -875,7 +876,7 @@ def _get_torch_model( def _get_model_per_metric( - model: Union[BotorchModel, ModularBoTorchModel], metrics: list[str] + model: BotorchModel | ModularBoTorchModel, metrics: list[str] ) -> list[Model]: """For a given TorchModel model, returns a list of botorch.models.model.Model objects corresponding to - and in the same order as - the given metrics. diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index cbcfec1c3a1..71efa484dac 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -7,7 +7,7 @@ from collections.abc import Mapping from logging import Logger -from typing import Optional, Protocol +from typing import Protocol import numpy as np @@ -71,7 +71,7 @@ def compute_model_fit_metrics( def coefficient_of_determination( y_obs: np.ndarray, y_pred: np.ndarray, - se_pred: Optional[np.ndarray] = None, + se_pred: np.ndarray | None = None, eps: float = 1e-12, ) -> float: """Computes coefficient of determination, the proportion of variance in `y_obs` diff --git a/ax/utils/stats/statstools.py b/ax/utils/stats/statstools.py index 3b7f68a1148..0da3d6ab159 100644 --- a/ax/utils/stats/statstools.py +++ b/ax/utils/stats/statstools.py @@ -142,12 +142,12 @@ def positive_part_james_stein( def relativize( - means_t: Union[np.ndarray, list[float], float], - sems_t: Union[np.ndarray, list[float], float], + means_t: np.ndarray | list[float] | float, + sems_t: np.ndarray | list[float] | float, mean_c: float, sem_c: float, bias_correction: bool = True, - cov_means: Union[np.ndarray, list[float], float] = 0.0, + cov_means: np.ndarray | list[float] | float = 0.0, as_percent: bool = False, control_as_constant: bool = False, ) -> tuple[np.ndarray, np.ndarray]: @@ -213,7 +213,7 @@ def relativize( epsilon = 1e-10 if np.any(np.abs(mean_c) < epsilon): raise ValueError( - "mean_control ({0} +/- {1}) is smaller than 1 in 10 billion, " + "mean_control ({} +/- {}) is smaller than 1 in 10 billion, " "which is too small to reliably analyze ratios using the delta " "method. This usually occurs because winsorization has truncated " "all values down to zero. 
Try using a delta type that applies " @@ -243,12 +243,12 @@ def relativize( def unrelativize( - means_t: Union[np.ndarray, list[float], float], - sems_t: Union[np.ndarray, list[float], float], + means_t: np.ndarray | list[float] | float, + sems_t: np.ndarray | list[float] | float, mean_c: float, sem_c: float, bias_correction: bool = True, - cov_means: Union[np.ndarray, list[float], float] = 0.0, + cov_means: np.ndarray | list[float] | float = 0.0, as_percent: bool = False, control_as_constant: bool = False, ) -> tuple[np.ndarray, np.ndarray]: @@ -309,11 +309,11 @@ def unrelativize( def agresti_coull_sem( - n_numer: Union[pd.Series, np.ndarray, int], - n_denom: Union[pd.Series, np.ndarray, int], + n_numer: pd.Series | np.ndarray | int, + n_denom: pd.Series | np.ndarray | int, prior_successes: int = 2, prior_failures: int = 2, -) -> Union[np.ndarray, float]: +) -> np.ndarray | float: """Compute the Agresti-Coull style standard error for a binomial proportion. Reference: diff --git a/ax/utils/testing/backend_simulator.py b/ax/utils/testing/backend_simulator.py index b79479a1a10..6f5b5e1e41c 100644 --- a/ax/utils/testing/backend_simulator.py +++ b/ax/utils/testing/backend_simulator.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from logging import Logger -from typing import Optional from ax.core.base_trial import TrialStatus from ax.utils.common.logger import get_logger @@ -38,11 +37,11 @@ class SimTrial: # The simulation runtime in seconds sim_runtime: float # the start time in seconds - sim_start_time: Optional[float] = None + sim_start_time: float | None = None # the queued time in seconds - sim_queued_time: Optional[float] = None + sim_queued_time: float | None = None # the completed time (used for early stopping) - sim_completed_time: Optional[float] = None + sim_completed_time: float | None = None @dataclass @@ -87,7 +86,7 @@ class BackendSimulatorOptions: max_concurrency: int = 1 time_scaling: float = 1.0 failure_rate: float = 0.0 - internal_clock: Optional[float] = None + internal_clock: float | None = None use_update_as_start_time: bool = False @@ -106,10 +105,10 @@ class BackendSimulatorState: options: BackendSimulatorOptions verbose_logging: bool - queued: list[dict[str, Optional[float]]] - running: list[dict[str, Optional[float]]] - failed: list[dict[str, Optional[float]]] - completed: list[dict[str, Optional[float]]] + queued: list[dict[str, float | None]] + running: list[dict[str, float | None]] + failed: list[dict[str, float | None]] + completed: list[dict[str, float | None]] class BackendSimulator: @@ -117,11 +116,11 @@ class BackendSimulator: def __init__( self, - options: Optional[BackendSimulatorOptions] = None, - queued: Optional[list[SimTrial]] = None, - running: Optional[list[SimTrial]] = None, - failed: Optional[list[SimTrial]] = None, - completed: Optional[list[SimTrial]] = None, + options: BackendSimulatorOptions | None = None, + queued: list[SimTrial] | None = None, + running: list[SimTrial] | None = None, + failed: list[SimTrial] | None = None, + completed: list[SimTrial] | None = None, verbose_logging: bool = True, ) -> None: """A simulator for a concurrent dispatch with a queue. @@ -372,7 +371,7 @@ def status(self) -> SimStatus: completed=[t.trial_index for t in self._completed], ) - def lookup_trial_index_status(self, trial_index: int) -> Optional[TrialStatus]: + def lookup_trial_index_status(self, trial_index: int) -> TrialStatus | None: """Lookup the trial status of a ``trial_index``. 
Args: @@ -392,7 +391,7 @@ def lookup_trial_index_status(self, trial_index: int) -> Optional[TrialStatus]: return TrialStatus.FAILED return None - def get_sim_trial_by_index(self, trial_index: int) -> Optional[SimTrial]: + def get_sim_trial_by_index(self, trial_index: int) -> SimTrial | None: """Get a ``SimTrial`` by ``trial_index``. Args: @@ -473,7 +472,7 @@ def _create_index_to_trial_map(self) -> None: self._index_to_trial_map = {t.trial_index: t for t in self.all_trials} -def format(trial_list: list[dict[str, Optional[float]]]) -> str: +def format(trial_list: list[dict[str, float | None]]) -> str: """Helper function for formatting a list.""" trial_list_str = [str(i) for i in trial_list] return "\n".join(trial_list_str) diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 7f3bfd7b524..6aab3ee3f0a 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch @@ -45,7 +45,7 @@ def get_single_objective_benchmark_problem( observe_noise_sd: bool = False, num_trials: int = 4, - test_problem_kwargs: Optional[dict[str, Any]] = None, + test_problem_kwargs: dict[str, Any] | None = None, report_inference_value_as_trace: bool = False, ) -> BenchmarkProblem: return create_problem_from_botorch( @@ -230,7 +230,7 @@ class TestParamBasedTestProblem(ParamBasedTestProblem): def __init__( self, num_objectives: int, - noise_std: Optional[Union[float, list[float]]] = None, + noise_std: float | list[float] | None = None, dim: int = 6, ) -> None: self.num_objectives = num_objectives diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 050f2399d70..6469c2f49a1 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -14,7 +14,7 @@ from datetime import datetime, timedelta from logging import Logger from pathlib import Path -from typing import Any, cast, Optional +from typing import Any, cast import numpy as np import pandas as pd @@ -242,7 +242,7 @@ def get_branin_experiment( with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False, - search_space: Optional[SearchSpace] = None, + search_space: SearchSpace | None = None, minimize: bool = False, named: bool = True, num_batch_trial: int = 1, @@ -316,8 +316,8 @@ def get_branin_experiment_with_status_quo_trials( def get_robust_branin_experiment( - risk_measure: Optional[RiskMeasure] = None, - optimization_config: Optional[OptimizationConfig] = None, + risk_measure: RiskMeasure | None = None, + optimization_config: OptimizationConfig | None = None, num_sobol_trials: int = 2, ) -> Experiment: x1_dist = ParameterDistribution( @@ -375,7 +375,7 @@ def get_map_metric(name: str, rate: float | None = None) -> BraninTimestampMapMe def get_branin_experiment_with_timestamp_map_metric( with_status_quo: bool = False, - rate: Optional[float] = None, + rate: float | None = None, map_tracking_metric: bool = False, ) -> Experiment: tracking_metric = ( @@ -405,7 +405,7 @@ def get_branin_experiment_with_timestamp_map_metric( def run_branin_experiment_with_generation_strategy( generation_strategy: GenerationStrategy, num_trials: int = 6, - kwargs_for_get_branin_experiment: Optional[dict[str, Any]] = None, + kwargs_for_get_branin_experiment: dict[str, Any] | None = None, ) -> Experiment: """Gets a Branin experiment using any given kwargs and runs num_trials trials using the given 
generation strategy.""" @@ -738,7 +738,7 @@ def get_experiment_with_observations( scalarized: bool = False, constrained: bool = False, with_tracking_metrics: bool = False, - search_space: Optional[SearchSpace] = None, + search_space: SearchSpace | None = None, ) -> Experiment: if observations: multi_objective = (len(observations[0]) - constrained) > 1 @@ -1213,7 +1213,7 @@ def get_robust_search_space_environmental( def get_batch_trial( abandon_arm: bool = True, - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, constrain_search_space: bool = True, ) -> BatchTrial: experiment = experiment or get_experiment( @@ -1321,12 +1321,12 @@ class TestTrial(BaseTrial): def __repr__(self) -> str: return "test" - def _get_candidate_metadata(self, arm_name: str) -> Optional[dict[str, Any]]: + def _get_candidate_metadata(self, arm_name: str) -> dict[str, Any] | None: return None def _get_candidate_metadata_from_all_generator_runs( self, - ) -> dict[str, Optional[dict[str, Any]]]: + ) -> dict[str, dict[str, Any] | None]: return {"test": None} def abandoned_arms(self) -> str: @@ -1527,7 +1527,7 @@ def get_objective_threshold( def get_outcome_constraint( - metric: Optional[Metric] = None, relative: bool = True, bound: float = -0.25 + metric: Metric | None = None, relative: bool = True, bound: float = -0.25 ) -> OutcomeConstraint: if metric is None: metric = Metric(name="m2") @@ -1962,8 +1962,8 @@ def get_map_key_info() -> MapKeyInfo: def get_branin_data( - trial_indices: Optional[Iterable[int]] = None, - trials: Optional[Iterable[Trial]] = None, + trial_indices: Iterable[int] | None = None, + trials: Iterable[Trial] | None = None, ) -> Data: if trial_indices and trials: raise ValueError("Expected `trial_indices` or `trials`, not both.") @@ -2021,7 +2021,7 @@ def get_branin_data_batch(batch: BatchTrial) -> Data: def get_branin_data_multi_objective( - trial_indices: Optional[Iterable[int]] = None, num_objectives: int = 2 + trial_indices: Iterable[int] | None = None, num_objectives: int = 2 ) -> Data: _validate_num_objectives(num_objectives=num_objectives) suffixes = ["a", "b"] @@ -2088,10 +2088,8 @@ def get_or_early_stopping_strategy() -> OrEarlyStoppingStrategy: class DummyEarlyStoppingStrategy(BaseEarlyStoppingStrategy): - def __init__( - self, early_stop_trials: Optional[dict[int, Optional[str]]] = None - ) -> None: - self.early_stop_trials: dict[int, Optional[str]] = early_stop_trials or {} + def __init__(self, early_stop_trials: dict[int, str | None] | None = None) -> None: + self.early_stop_trials: dict[int, str | None] = early_stop_trials or {} self.seconds_between_polls = 1 def should_stop_trials_early( @@ -2099,7 +2097,7 @@ def should_stop_trials_early( trial_indices: set[int], experiment: Experiment, **kwargs: dict[str, Any], - ) -> dict[int, Optional[str]]: + ) -> dict[int, str | None]: return self.early_stop_trials @@ -2321,10 +2319,10 @@ def get_dataset( d: int = 2, m: int = 2, has_observation_noise: bool = False, - feature_names: Optional[list[str]] = None, - outcome_names: Optional[list[str]] = None, - tkwargs: Optional[dict[str, Any]] = None, - seed: Optional[int] = None, + feature_names: list[str] | None = None, + outcome_names: list[str] | None = None, + tkwargs: dict[str, Any] | None = None, + seed: int | None = None, ) -> SupervisedDataset: """Constructs a SupervisedDataset based on the given arguments. 
@@ -2427,7 +2425,7 @@ def __init__( self, name: str, test_attribute: str, - lower_is_better: Optional[bool] = None, + lower_is_better: bool | None = None, ) -> None: self.test_attribute = test_attribute super().__init__(name=name, lower_is_better=lower_is_better) @@ -2446,8 +2444,8 @@ def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, num_generator_runs: int, - data: Optional[Data] = None, - n: Optional[int] = None, + data: Data | None = None, + n: int | None = None, ) -> list[list[GeneratorRun]]: return [] diff --git a/ax/utils/testing/metrics/branin_backend_map.py b/ax/utils/testing/metrics/branin_backend_map.py index 00661fb390f..f56c60ec9ce 100644 --- a/ax/utils/testing/metrics/branin_backend_map.py +++ b/ax/utils/testing/metrics/branin_backend_map.py @@ -6,7 +6,6 @@ # pyre-strict -from typing import Optional import numpy as np from ax.metrics.branin_map import BraninTimestampMapMetric @@ -26,7 +25,7 @@ def __init__( name: str, param_names: list[str], noise_sd: float = 0.0, - lower_is_better: Optional[bool] = True, + lower_is_better: bool | None = True, cache_evaluations: bool = True, rate: float = 0.5, delta_t: float = 1.0, @@ -56,7 +55,7 @@ def __init__( self._timestamp = -1 def convert_to_timestamps( - self, start_time: Optional[float], end_time: float + self, start_time: float | None, end_time: float ) -> list[float]: """Given a starting and current time, get the list of intermediate timestamps at which we have observations.""" diff --git a/ax/utils/testing/mock.py b/ax/utils/testing/mock.py index 29c61338b7c..d4183162838 100644 --- a/ax/utils/testing/mock.py +++ b/ax/utils/testing/mock.py @@ -5,10 +5,10 @@ # pyre-strict -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager, ExitStack from functools import wraps -from typing import Any, Callable, Optional +from typing import Any from unittest import mock from botorch.fit import fit_fully_bayesian_model_nuts @@ -48,7 +48,7 @@ def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor: return gen_batch_initial_conditions(*args, **kwargs) - def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Optional[Tensor]: + def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None: kwargs["num_restarts"] = 2 kwargs["raw_samples"] = 4 @@ -111,7 +111,6 @@ def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None: ) -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def fast_botorch_optimize(f: Callable) -> Callable: """Wraps f in the fast_botorch_optimize_context_manager for use as a decorator.""" @@ -141,7 +140,6 @@ def skip_fit_gpytorch_mll_context_manager() -> Generator[None, None, None]: ) -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
def skip_fit_gpytorch_mll(f: Callable) -> Callable: """Wraps f in the skip_fit_gpytorch_mll_context_manager for use as a decorator.""" diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 4a9b522315f..c05671eded5 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -7,7 +7,7 @@ # pyre-strict from logging import Logger -from typing import Any, Optional +from typing import Any import numpy as np from ax.core.base_trial import TrialStatus @@ -518,7 +518,7 @@ def get_surrogate_as_dict() -> dict[str, Any]: def get_surrogate_spec_as_dict( - model_class: Optional[str] = None, with_legacy_input_transform: bool = False + model_class: str | None = None, with_legacy_input_transform: bool = False ) -> dict[str, Any]: """ For use ensuring backwards compatibility when loading SurrogateSpec @@ -593,8 +593,8 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[ModelBridge], - fixed_features: Optional[ObservationFeatures], + modelbridge: ModelBridge | None, + fixed_features: ObservationFeatures | None, ) -> OptimizationConfig: return ( # pyre-ignore[7]: pyre is right, this is a hack for testing. # pyre-fixme[58]: `+` is not supported for operand types @@ -652,8 +652,8 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional[ModelBridge], - fixed_features: Optional[ObservationFeatures], + modelbridge: ModelBridge | None, + fixed_features: ObservationFeatures | None, ) -> OptimizationConfig: return ( # pyre-fixme[58]: `**` is not supported for operand types diff --git a/ax/utils/testing/preference_stubs.py b/ax/utils/testing/preference_stubs.py index c1ed162dc94..be024b7a0a0 100644 --- a/ax/utils/testing/preference_stubs.py +++ b/ax/utils/testing/preference_stubs.py @@ -5,7 +5,8 @@ # pyre-strict -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import numpy as np import torch @@ -75,8 +76,8 @@ def experimental_metric_eval( def get_pbo_experiment( num_parameters: int = 2, num_experimental_metrics: int = 3, - parameter_names: Optional[list[str]] = None, - tracking_metric_names: Optional[list[str]] = None, + parameter_names: list[str] | None = None, + tracking_metric_names: list[str] | None = None, num_experimental_trials: int = 3, num_preference_trials: int = 3, num_preference_trials_w_repeated_arm: int = 5, diff --git a/ax/utils/testing/torch_stubs.py b/ax/utils/testing/torch_stubs.py index dd3e4474f8b..fa258b3ebdd 100644 --- a/ax/utils/testing/torch_stubs.py +++ b/ax/utils/testing/torch_stubs.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any import torch @@ -17,7 +17,7 @@ def get_torch_test_data( dtype: torch.dtype = torch.float, cuda: bool = False, constant_noise: bool = True, - task_features: Optional[list[int]] = None, + task_features: list[int] | None = None, offset: float = 0.0, ) -> tuple[ list[torch.Tensor], diff --git a/ax/utils/tutorials/cnn_utils.py b/ax/utils/tutorials/cnn_utils.py index bed932bce7a..d64cabb4177 100644 --- a/ax/utils/tutorials/cnn_utils.py +++ b/ax/utils/tutorials/cnn_utils.py @@ -7,7 +7,6 @@ # pyre-strict from itertools import accumulate -from typing import Optional import torch import torch.nn as nn @@ -46,7 +45,7 @@ def load_mnist( 
batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, - downsample_pct_test: Optional[float] = None, + downsample_pct_test: float | None = None, ) -> tuple[DataLoader, DataLoader, DataLoader]: """ Load MNIST dataset (download if necessary) and split data into training, @@ -102,7 +101,7 @@ def get_partition_data_loaders( batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, - downsample_pct_test: Optional[float] = None, + downsample_pct_test: float | None = None, ) -> tuple[DataLoader, DataLoader, DataLoader]: """ Helper function for partitioning training data into training and validation sets, diff --git a/scripts/make_tutorials.py b/scripts/make_tutorials.py index 40e352a2d57..005c2bbb342 100644 --- a/scripts/make_tutorials.py +++ b/scripts/make_tutorials.py @@ -10,7 +10,6 @@ import tarfile import time from pathlib import Path -from typing import Optional import nbformat import papermill @@ -66,7 +65,7 @@ class TutorialPage extends React.Component {{ """ -def _get_paths(repo_dir: str, t_dir: Optional[str], tid: str) -> dict[str, str]: +def _get_paths(repo_dir: str, t_dir: str | None, tid: str) -> dict[str, str]: if t_dir is not None: tutorial_dir = os.path.join(repo_dir, "tutorials", t_dir) html_dir = os.path.join(repo_dir, "website", "_tutorials", t_dir) @@ -76,26 +75,20 @@ def _get_paths(repo_dir: str, t_dir: Optional[str], tid: str) -> dict[str, str]: for d in [tutorial_dir, html_dir, js_dir, py_dir]: os.makedirs(d, exist_ok=True) - tutorial_path = os.path.join(tutorial_dir, "{}.ipynb".format(tid)) - html_path = os.path.join(html_dir, "{}.html".format(tid)) - js_path = os.path.join(js_dir, "{}.js".format(tid)) - ipynb_path = os.path.join(py_dir, "{}.ipynb".format(tid)) - py_path = os.path.join(py_dir, "{}.py".format(tid)) + tutorial_path = os.path.join(tutorial_dir, f"{tid}.ipynb") + html_path = os.path.join(html_dir, f"{tid}.html") + js_path = os.path.join(js_dir, f"{tid}.js") + ipynb_path = os.path.join(py_dir, f"{tid}.ipynb") + py_path = os.path.join(py_dir, f"{tid}.py") else: tutorial_dir = os.path.join(repo_dir, "tutorials") - tutorial_path = os.path.join(repo_dir, "tutorials", "{}.ipynb".format(tid)) - html_path = os.path.join( - repo_dir, "website", "_tutorials", "{}.html".format(tid) - ) - js_path = os.path.join( - repo_dir, "website", "pages", "tutorials", "{}.js".format(tid) - ) + tutorial_path = os.path.join(repo_dir, "tutorials", f"{tid}.ipynb") + html_path = os.path.join(repo_dir, "website", "_tutorials", f"{tid}.html") + js_path = os.path.join(repo_dir, "website", "pages", "tutorials", f"{tid}.js") ipynb_path = os.path.join( - repo_dir, "website", "static", "files", "{}.ipynb".format(tid) - ) - py_path = os.path.join( - repo_dir, "website", "static", "files", "{}.py".format(tid) + repo_dir, "website", "static", "files", f"{tid}.ipynb" ) + py_path = os.path.join(repo_dir, "website", "static", "files", f"{tid}.py") paths = { "tutorial_dir": tutorial_dir, @@ -106,12 +99,12 @@ def _get_paths(repo_dir: str, t_dir: Optional[str], tid: str) -> dict[str, str]: "py_path": py_path, } if t_dir is not None: - paths["tar_path"] = os.path.join(py_dir, "{}.tar.gz".format(tid)) + paths["tar_path"] = os.path.join(py_dir, f"{tid}.tar.gz") return paths def run_script( - tutorial: Path, timeout_minutes: int, env: Optional[dict[str, str]] = None + tutorial: Path, timeout_minutes: int, env: dict[str, str] | None = None ) -> None: if env is not None: os.environ.update(env) @@ -126,7 +119,7 @@ def run_script( def gen_tutorials( 
repo_dir: str, exec_tutorials: bool, - name: Optional[str] = None, + name: str | None = None, smoke_test: bool = False, ) -> None: """Generate HTML tutorials for Docusaurus Ax site from Jupyter notebooks. @@ -136,7 +129,7 @@ def gen_tutorials( """ has_errors = False - with open(os.path.join(repo_dir, "website", "tutorials.json"), "r") as infile: + with open(os.path.join(repo_dir, "website", "tutorials.json")) as infile: tutorial_config = json.loads(infile.read()) # flatten config dict tutorial_configs = [ @@ -156,7 +149,7 @@ def gen_tutorials( tid = config["id"] t_dir = config.get("dir") exec_on_build = config.get("exec_on_build", True) - print("Generating {} tutorial".format(tid)) + print(f"Generating {tid} tutorial") paths = _get_paths(repo_dir=repo_dir, t_dir=t_dir, tid=tid) total_time = None @@ -166,7 +159,7 @@ def gen_tutorials( continue elif exec_tutorials and exec_on_build: tutorial_path = Path(paths["tutorial_path"]) - print("Executing tutorial {}".format(tid)) + print(f"Executing tutorial {tid}") start_time = time.monotonic() # Try / catch failures for now. We will re-raise at the end. @@ -193,7 +186,7 @@ def gen_tutorials( print(f"Encountered error running tutorial {tid}: \n {e}") # load notebook - with open(paths["tutorial_path"], "r") as infile: + with open(paths["tutorial_path"]) as infile: nb_str = infile.read() nb = nbformat.reads(nb_str, nbformat.NO_CONVERT) # convert notebook to HTML diff --git a/scripts/parse_sphinx.py b/scripts/parse_sphinx.py index 952c7c67b3f..4a113f09ce9 100644 --- a/scripts/parse_sphinx.py +++ b/scripts/parse_sphinx.py @@ -47,7 +47,7 @@ def parse_sphinx(input_dir: str, output_dir: str) -> None: for cur, _, files in os.walk(input_dir): for fname in files: if fname.endswith(".html"): - with open(os.path.join(cur, fname), "r") as f: + with open(os.path.join(cur, fname)) as f: soup = BeautifulSoup(f.read(), "html.parser") doc = soup.find("div", {"class": "document"}) wrapped_doc = doc.wrap(soup.new_tag("div", **{"class": "sphinx"})) @@ -62,7 +62,7 @@ def parse_sphinx(input_dir: str, output_dir: str) -> None: fout.write(out) # update reference in JS file - with open(os.path.join(input_dir, "_static/searchtools.js"), "r") as js_file: + with open(os.path.join(input_dir, "_static/searchtools.js")) as js_file: js = js_file.read() js = js.replace( "DOCUMENTATION_OPTIONS.URL_ROOT + '_sources/'", "'_sphinx-sources/'" diff --git a/scripts/patch_site_config.py b/scripts/patch_site_config.py index 9579ff58307..8bd504a2a46 100644 --- a/scripts/patch_site_config.py +++ b/scripts/patch_site_config.py @@ -11,10 +11,10 @@ def patch_config( config_file: str, base_url: str = None, disable_algolia: bool = True ) -> None: - config = open(config_file, "r").read() + config = open(config_file).read() if base_url is not None: - config = re.sub("baseUrl = '/';", "baseUrl = '{}';".format(base_url), config) + config = re.sub("baseUrl = '/';", f"baseUrl = '{base_url}';", config) if disable_algolia is True: config = re.sub( "const includeAlgolia = true;", "const includeAlgolia = false;", config diff --git a/scripts/update_versions_html.py b/scripts/update_versions_html.py index f432e34b047..0457511d6e9 100644 --- a/scripts/update_versions_html.py +++ b/scripts/update_versions_html.py @@ -50,12 +50,10 @@ def prepend_url(a_tag, base_url, version): h3.string = v # output files - with open( - base_path + "/new-site/versions/{}/versions.html".format(v), "w" - ) as outfile: + with open(base_path + f"/new-site/versions/{v}/versions.html", "w") as outfile: outfile.write(str(soup)) with 
open( - base_path + "/new-site/versions/{}/en/versions.html".format(v), "w" + base_path + f"/new-site/versions/{v}/en/versions.html", "w" ) as outfile: outfile.write(str(soup)) diff --git a/scripts/validate_sphinx.py b/scripts/validate_sphinx.py index 0d2e0226c43..fd9146b661f 100755 --- a/scripts/validate_sphinx.py +++ b/scripts/validate_sphinx.py @@ -8,7 +8,6 @@ import os import pkgutil import re -from typing import Set # Paths are relative to top-level Ax directory (which is passed into fxn below) @@ -29,10 +28,10 @@ } -def parse_rst(rst_filename: str) -> Set[str]: +def parse_rst(rst_filename: str) -> set[str]: """Extract automodule directives from rst.""" ret = set() - with open(rst_filename, "r") as f: + with open(rst_filename) as f: lines = f.readlines() for line in lines: line = line.strip() diff --git a/setup.py b/setup.py index 278b2ed9251..da926b7c3e0 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def local_version(version): def setup_package() -> None: """Used for installing the Ax package.""" - with open("README.md", "r") as fh: + with open("README.md") as fh: long_description = fh.read() setup( diff --git a/sphinx/source/conf.py b/sphinx/source/conf.py index 31cf8946507..a5ab08ef9ba 100644 --- a/sphinx/source/conf.py +++ b/sphinx/source/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the @@ -60,9 +59,9 @@ index_doc = "index" # General information about the project. -project = u"Ax" -copyright = u"2019, Facebook Inc." -author = u"Facebook Inc." +project = "Ax" +copyright = "2019, Facebook Inc." +author = "Facebook Inc." # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -211,14 +210,14 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (index_doc, "Ax.tex", u"Ax Documentation", u"Facebook Inc.", "manual") + (index_doc, "Ax.tex", "Ax Documentation", "Facebook Inc.", "manual") ] # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(index_doc, "ax", u"Ax Documentation", [author], 1)] +man_pages = [(index_doc, "ax", "Ax Documentation", [author], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -233,7 +232,7 @@ ( index_doc, "Ax", - u"Ax Documentation", + "Ax Documentation", author, "Ax", "Platform for automated optimization and experimentation.",
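
For reviewers skimming the hunks above, the patch applies three mechanical rewrites: `Optional[X]`/`Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), `Callable` is imported from `collections.abc` instead of `typing`, and `str.format()` calls become f-strings. Below is a minimal sketch collecting all three patterns in one place; the `evaluate` function and its parameters are invented for illustration and are not part of Ax or of this patch.

    from collections.abc import Callable  # was: from typing import Callable, Optional, Union

    # Before the codemod (hypothetical old-style signature):
    #   def evaluate(fn: Optional[Callable[[float], float]] = None,
    #                scale: Union[float, list[float]] = 1.0) -> Optional[float]: ...
    def evaluate(
        fn: Callable[[float], float] | None = None,
        scale: float | list[float] = 1.0,
    ) -> float | None:
        """Toy function written with the modernized annotations."""
        if fn is None:
            return None
        s = scale if isinstance(scale, float) else scale[0]
        print(f"evaluating at scale {s}")  # was: print("evaluating at scale {}".format(s))
        return fn(s)

    print(evaluate(lambda x: x * 2, scale=[1.5]))  # prints 3.0

Note that `X | None` evaluated at runtime builds a `types.UnionType` and requires Python 3.10+; under `from __future__ import annotations` (or in quoted annotations) the syntax also parses on earlier versions. Because the rewrite is purely syntactic and type checkers treat `X | None` and `Optional[X]` as identical, a codemod like this one changes no behavior.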