Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply PEP 604 union type syntax codemod #2808

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

# pyre-strict

from collections.abc import Iterable
from enum import Enum
from logging import Logger
from typing import Iterable, Optional, Protocol
from typing import Protocol

import pandas as pd
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -105,8 +106,8 @@ class Analysis(Protocol):

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> AnalysisCard:
# Note: when implementing compute always prefer experiment.lookup_data() to
# experiment.fetch_data() to avoid unintential data fetching within the report
Expand All @@ -115,8 +116,8 @@ def compute(

def compute_result(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> Result[AnalysisCard, ExceptionE]:
"""
Utility method to compute an AnalysisCard as a Result. This can be useful for
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/healthcheck/healthcheck_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict
import json
from enum import IntEnum
from typing import Optional

from ax.analysis.analysis import AnalysisCard
from ax.core.experiment import Experiment
Expand All @@ -33,6 +32,6 @@ class HealthcheckAnalysis:

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard: ...
5 changes: 2 additions & 3 deletions ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -34,6 +33,6 @@ class MarkdownAnalysis(Analysis):

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> MarkdownAnalysisCard: ...
11 changes: 5 additions & 6 deletions ax/analysis/old/analysis_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

import pandas as pd
import plotly.graph_objects as go
Expand All @@ -27,15 +26,15 @@ class AnalysisReport:
analyses: list[BaseAnalysis] = []
experiment: Experiment

time_started: Optional[int] = None
time_completed: Optional[int] = None
time_started: int | None = None
time_completed: int | None = None

def __init__(
self,
experiment: Experiment,
analyses: list[BaseAnalysis],
time_started: Optional[int] = None,
time_completed: Optional[int] = None,
time_started: int | None = None,
time_completed: int | None = None,
) -> None:
"""
This class is a collection of AnalysisReport.
Expand Down Expand Up @@ -65,7 +64,7 @@ def run_analysis_report(
tuple[
BaseAnalysis,
pd.DataFrame,
Optional[go.Figure],
go.Figure | None,
]
]:
"""
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/old/base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

import pandas as pd
from ax.core.experiment import Experiment
Expand All @@ -21,7 +20,7 @@ class BaseAnalysis:
def __init__(
self,
experiment: Experiment,
df_input: Optional[pd.DataFrame] = None,
df_input: pd.DataFrame | None = None,
# TODO: add support for passing in experiment name, and markdown message
) -> None:
"""
Expand All @@ -30,7 +29,7 @@ def __init__(
we can pass the dataframe as an input.
"""
self._experiment = experiment
self._df: Optional[pd.DataFrame] = df_input
self._df: pd.DataFrame | None = df_input

@property
def experiment(self) -> Experiment:
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/old/base_plotly_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

import pandas as pd

Expand All @@ -25,8 +24,8 @@ class BasePlotlyVisualization(BaseAnalysis):
def __init__(
self,
experiment: Experiment,
df_input: Optional[pd.DataFrame] = None,
fig_input: Optional[go.Figure] = None,
df_input: pd.DataFrame | None = None,
fig_input: go.Figure | None = None,
) -> None:
"""
Initialize the analysis with the experiment object.
Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/old/cross_validation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict

from copy import deepcopy
from typing import Any, Optional
from typing import Any

import pandas as pd

Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
self,
experiment: Experiment,
model: ModelBridge,
label_dict: Optional[dict[str, str]] = None,
label_dict: dict[str, str] | None = None,
caption: str = CROSS_VALIDATION_CAPTION,
) -> None:
"""
Expand All @@ -56,7 +56,7 @@ def __init__(
self.model = model
self.cv: list[CVResult] = cross_validate(model=model)

self.label_dict: Optional[dict[str, str]] = label_dict
self.label_dict: dict[str, str] | None = label_dict
if self.label_dict:
self.cv = self.remap_label(cv_results=self.cv, label_dict=self.label_dict)

Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/old/helpers/cross_validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import Any, Optional
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -186,7 +186,7 @@ def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str,
)


def default_value_se_raw(se_raw: Optional[list[float]], out_length: int) -> list[float]:
def default_value_se_raw(se_raw: list[float] | None, out_length: int) -> list[float]:
"""
Takes a list of standard errors and maps edge cases to default list
of floats.
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/old/helpers/plot_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _format_dict(param_dict: TParameterization, name: str = "Parameterization")
)
else:
blob = "<br><em>{}:</em><br>{}".format(
name, "<br>".join("{}: {}".format(n, v) for n, v in param_dict.items())
name, "<br>".join(f"{n}: {v}" for n, v in param_dict.items())
)
return blob

Expand Down
8 changes: 4 additions & 4 deletions ax/analysis/old/helpers/scatter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numbers

from typing import Any, Optional
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -119,7 +119,7 @@ def extract_mean_and_error_from_df(

def make_label(
arm_name: str,
x_axis_values: Optional[tuple[str, float, float]],
x_axis_values: tuple[str, float, float] | None,
y_axis_values: tuple[str, float, float],
param_blob: TParameterization,
rel: bool,
Expand Down Expand Up @@ -240,8 +240,8 @@ def error_scatter_trace_from_df(
df: pd.DataFrame,
show_CI: bool = True,
visible: bool = True,
y_axis_label: Optional[str] = None,
x_axis_label: Optional[str] = None,
y_axis_label: str | None = None,
x_axis_label: str | None = None,
) -> dict[str, Any]:
"""Plot scatterplot with error bars.

Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/old/predicted_outcomes_dot_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def get_fig(
reverse=True,
)

name_order_axes["xaxis{}".format(i + 1)] = {
name_order_axes[f"xaxis{i + 1}"] = {
"categoryorder": "array",
"categoryarray": names_by_arm,
"type": "category",
}
name_order_axes["yaxis{}".format(i + 1)] = {
name_order_axes[f"yaxis{i + 1}"] = {
"ticksuffix": "%",
"zerolinecolor": "red",
}
Expand Down
8 changes: 4 additions & 4 deletions ax/analysis/plotly/parallel_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import Any, Optional
from typing import Any

import numpy as np
import pandas as pd
Expand All @@ -32,7 +32,7 @@ class ParallelCoordinatesPlot(PlotlyAnalysis):
- **PARAMETER_NAME: The value of said parameter for the arm, for each parameter
"""

def __init__(self, metric_name: Optional[str] = None) -> None:
def __init__(self, metric_name: str | None = None) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
Expand All @@ -44,8 +44,8 @@ def __init__(self, metric_name: Optional[str] = None) -> None:

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ParallelCoordinatesPlot requires an Experiment")
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -35,6 +34,6 @@ class PlotlyAnalysis(Analysis):

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> PlotlyAnalysisCard: ...
12 changes: 6 additions & 6 deletions ax/analysis/plotly/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from itertools import chain
from typing import Any, Optional
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -45,8 +45,8 @@ def __init__(self, metric_name: str) -> None:

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("PredictedEffectsPlot requires an Experiment.")
Expand Down Expand Up @@ -223,8 +223,8 @@ def _get_predictions(
model: ModelBridge,
metric_name: str,
outcome_constraints: list[OutcomeConstraint],
gr: Optional[GeneratorRun] = None,
trial_index: Optional[int] = None,
gr: GeneratorRun | None = None,
trial_index: int | None = None,
) -> list[dict[str, Any]]:
if gr is None:
observations = model.get_training_data()
Expand Down Expand Up @@ -294,7 +294,7 @@ def _get_predictions(
]


def _get_max_observed_trial_index(model: ModelBridge) -> Optional[int]:
def _get_max_observed_trial_index(model: ModelBridge) -> int | None:
"""Returns the max observed trial index to appease multitask models for prediction
by giving fixed features. This is not necessarily accurate and should eventually
come from the generation strategy.
Expand Down
4 changes: 2 additions & 2 deletions ax/benchmark/benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import Any, Optional
from typing import Any

import pandas as pd
from ax.core.base_trial import BaseTrial
Expand All @@ -26,7 +26,7 @@ def __init__(
name: str,
lower_is_better: bool, # TODO: Do we need to define this here?
observe_noise_sd: bool = True,
outcome_index: Optional[int] = None,
outcome_index: int | None = None,
) -> None:
"""
Args:
Expand Down
6 changes: 3 additions & 3 deletions ax/benchmark/benchmark_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any

import pandas as pd

Expand Down Expand Up @@ -38,7 +38,7 @@
def _get_name(
test_problem: BaseTestProblem,
observe_noise_sd: bool,
dim: Optional[int] = None,
dim: int | None = None,
) -> str:
"""
Get a string name describing the problem, in a format such as
Expand Down Expand Up @@ -86,7 +86,7 @@ class BenchmarkProblem(Base):
name: str
optimization_config: OptimizationConfig
num_trials: int
observe_noise_stds: Union[bool, dict[str, bool]] = False
observe_noise_stds: bool | dict[str, bool] = False
optimal_value: float

search_space: SearchSpace = field(repr=False)
Expand Down
5 changes: 2 additions & 3 deletions ax/benchmark/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

import numpy as np
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -87,8 +86,8 @@ class BenchmarkResult(Base):
fit_time: float
gen_time: float

experiment: Optional[Experiment] = None
experiment_storage_id: Optional[str] = None
experiment: Experiment | None = None
experiment_storage_id: str | None = None

def __post_init__(self) -> None:
if self.experiment is not None and self.experiment_storage_id is not None:
Expand Down
Loading