Skip to content

Commit

Permalink
Fix breakages from typing changes when SQA is not installed (#2819)
Browse files Browse the repository at this point in the history
Summary:
#2808 moved to PEP604 optional type annotiation. Apparently this broke some code that worked fined wiht the previously used `Optional[DBSettings]` even if `DBSettings` was not defined. I'm not entirely sure what's going on; this is more than about just delayed type annotations. Anyway, this is a hotfix that returns that code to the previous state.

Pull Request resolved: #2819

Reviewed By: saitcakmak

Differential Revision: D63834254

Pulled By: Balandat

fbshipit-source-id: 9ca80b263816d4bae9451375d5cd212752ed23c8
  • Loading branch information
Balandat authored and facebook-github-bot committed Oct 3, 2024
1 parent 33333d2 commit 1f542ff
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from functools import partial

from logging import Logger
from typing import Any, TypeVar
from typing import Any, Optional, TypeVar

import ax.service.utils.early_stopping as early_stopping_utils
import numpy as np
Expand Down Expand Up @@ -179,7 +179,7 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
def __init__(
self,
generation_strategy: GenerationStrategy | None = None,
db_settings: DBSettings | None = None,
db_settings: Optional[DBSettings] = None,
enforce_sequential_optimization: bool = True,
random_seed: int | None = None,
torch_device: torch.device | None = None,
Expand Down
6 changes: 3 additions & 3 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datetime import datetime
from logging import LoggerAdapter
from time import sleep
from typing import Any, cast, NamedTuple
from typing import Any, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
from ax.analysis.analysis import Analysis, AnalysisCard
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(
experiment: Experiment,
generation_strategy: GenerationStrategyInterface,
options: SchedulerOptions,
db_settings: DBSettings | None = None,
db_settings: Optional[DBSettings] = None,
_skip_experiment_save: bool = False,
) -> None:
self.experiment = experiment
Expand Down Expand Up @@ -242,7 +242,7 @@ def from_stored_experiment(
cls,
experiment_name: str,
options: SchedulerOptions,
db_settings: DBSettings | None = None,
db_settings: Optional[DBSettings] = None,
generation_strategy: GenerationStrategy | None = None,
reduced_state: bool = True,
**kwargs: Any,
Expand Down
4 changes: 2 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from math import ceil
from random import randint
from tempfile import NamedTemporaryFile
from typing import Any, cast
from typing import Any, cast, Optional
from unittest.mock import call, Mock, patch, PropertyMock

import pandas as pd
Expand Down Expand Up @@ -404,7 +404,7 @@ def db_settings(self) -> DBSettings:
return DBSettings(encoder=encoder, decoder=decoder)

@property
def db_settings_if_always_needed(self) -> DBSettings | None:
def db_settings_if_always_needed(self) -> Optional[DBSettings]:
if self.ALWAYS_USE_DB:
return self.db_settings
return None
Expand Down
8 changes: 4 additions & 4 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Iterable

from logging import INFO, Logger
from typing import Any
from typing import Any, Optional

from ax.analysis.analysis import AnalysisCard

Expand Down Expand Up @@ -93,7 +93,7 @@ class WithDBSettingsBase:
if `db_settings` property is set to a non-None value on the instance.
"""

_db_settings: DBSettings | None = None
_db_settings: Optional[DBSettings] = None

# Mapping of object types to mapping of fields to override values
# loaded objects will all be instantiated with fields set to
Expand All @@ -103,7 +103,7 @@ class WithDBSettingsBase:

def __init__(
self,
db_settings: DBSettings | None = None,
db_settings: Optional[DBSettings] = None,
logging_level: int = INFO,
suppress_all_errors: bool = False,
) -> None:
Expand All @@ -123,7 +123,7 @@ def __init__(
logger.setLevel(logging_level)

@staticmethod
def _get_default_db_settings() -> DBSettings | None:
def _get_default_db_settings() -> Optional[DBSettings]:
"""Overridable method to get default db_settings
if none are passed in __init__
"""
Expand Down

0 comments on commit 1f542ff

Please sign in to comment.