-
I have a recurring problem with Hydra and Google's AI Platform. When I submit a training job via gcloud, the CLI always forces me to add …
Beta Was this translation helpful? Give feedback.
Replies: 2 comments
-
There is no official way to "patch the function" to add an extra argument.
Beta Was this translation helpful? Give feedback.
-
The following did the job, although modifying `sys.argv` is not pretty:
import argparse
import sys
from typing import Any, Callable, Dict, List, Tuple, Union
import hydra
import wrapt
from omegaconf import DictConfig
def _namespace_to_hydra_overrides(args: argparse.Namespace) -> List[str]:
overrides = []
for key, val in vars(args).items():
if val is not None:
overrides.append(f"{key}={val}")
return overrides
@wrapt.decorator
def translate_gcloud_ai_platform_argparse_to_hydra_overrides(
    fn: Callable,
    instance: Union[None, Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
):
    """Rewrite an argparse-style ``--job-dir``/``--job_dir`` flag as a Hydra override.

    The flag (and its value) is removed from ``sys.argv`` and re-appended as
    ``job_dir=<value>``. The wrapped (``hydra.main``-decorated) function is
    then called, and the original ``sys.argv`` is restored afterwards. All
    other command-line arguments are passed through in their original order.

    NOTE(review): a plain ``job_dir=...`` override assumes ``job_dir`` is a
    key of the Hydra config; if it is not, Hydra may require ``+job_dir=...``
    instead — confirm against the config.

    Args:
        fn: The wrapped function (supplied by wrapt).
        instance: Object ``fn`` is bound to, or None (supplied by wrapt; unused).
        args: Positional arguments forwarded unchanged to ``fn``.
        kwargs: Keyword arguments forwarded unchanged to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    # add_help=False: with the default add_help=True this throwaway parser
    # would intercept -h/--help, print its own usage and exit, so Hydra's
    # --help could never be reached.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--job_dir", "--job-dir", required=False)
    # By default argparse uses sys.argv[1:] to search for arguments; anything
    # it does not recognise (Hydra overrides and flags) is returned untouched.
    namespace, unparsed_args = parser.parse_known_args()
    original_argv = sys.argv[:]
    # hydra.main reads the command line from sys.argv, so the translated
    # arguments must be written back there before fn is called.
    sys.argv[1:] = unparsed_args + _namespace_to_hydra_overrides(namespace)
    try:
        return fn(*args, **kwargs)
    finally:
        # Limit the sys.argv mutation to this call instead of leaking it to
        # the rest of the process.
        sys.argv[:] = original_argv
# Decorator order matters: the argv translator must be the OUTER decorator so
# that sys.argv is rewritten before hydra.main parses the command line.
@translate_gcloud_ai_platform_argparse_to_hydra_overrides
@hydra.main(version_base="1.2", config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    """Training entry point; currently it only prints the composed config.

    ``cfg`` is composed by ``hydra.main`` from the ``train`` config in the
    ``configs`` directory plus any command-line overrides.
    """
    print(cfg)
if __name__ == "__main__":
    # Called without arguments: hydra.main composes and supplies ``cfg``.
    main()
Beta Was this translation helpful? Give feedback.
The following did the job, although modifying `sys.argv` is not pretty.