diff --git a/InnerEye/Azure/azure_runner.py b/InnerEye/Azure/azure_runner.py
index c64bcb182..e90c42d7b 100644
--- a/InnerEye/Azure/azure_runner.py
+++ b/InnerEye/Azure/azure_runner.py
@@ -25,6 +25,7 @@
 # Environment variables used for multi-node training
 ENV_AZ_BATCHAI_MPI_MASTER_NODE = "AZ_BATCHAI_MPI_MASTER_NODE"
+ENV_AZ_BATCH_MASTER_NODE = "AZ_BATCH_MASTER_NODE"
 ENV_MASTER_ADDR = "MASTER_ADDR"
 ENV_MASTER_IP = "MASTER_IP"
 ENV_MASTER_PORT = "MASTER_PORT"
@@ -32,6 +33,7 @@
 ENV_NODE_RANK = "NODE_RANK"
 ENV_GLOBAL_RANK = "GLOBAL_RANK"
 ENV_LOCAL_RANK = "LOCAL_RANK"
+MASTER_PORT_DEFAULT = 6105
 
 
 def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
@@ -294,21 +296,43 @@ def set_environment_variables_for_multi_node() -> None:
     """
     Sets the environment variables that PyTorch Lightning needs for multi-node training.
     """
-
-    if ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ:
+    if ENV_AZ_BATCH_MASTER_NODE in os.environ:
+        master_node = os.environ[ENV_AZ_BATCH_MASTER_NODE]
+        logging.debug(
+            f"Found AZ_BATCH_MASTER_NODE: {master_node} in environment variables")
+        # For AML BATCHAI
+        split_master_node_addr = master_node.split(":")
+        if len(split_master_node_addr) == 2:
+            master_addr, port = split_master_node_addr
+            os.environ[ENV_MASTER_PORT] = port
+        elif len(split_master_node_addr) == 1:
+            master_addr = split_master_node_addr[0]
+        else:
+            raise ValueError(f"Format not recognized: {master_node}")
+        os.environ[ENV_MASTER_ADDR] = master_addr
+    elif ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ and os.environ.get(ENV_AZ_BATCHAI_MPI_MASTER_NODE) != "localhost":
+        mpi_master_node = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
+        logging.debug(
+            f"Found AZ_BATCHAI_MPI_MASTER_NODE: {mpi_master_node} in environment variables")
         # For AML BATCHAI
-        os.environ[ENV_MASTER_ADDR] = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
+        os.environ[ENV_MASTER_ADDR] = mpi_master_node
     elif ENV_MASTER_IP in os.environ:
+        master_ip = os.environ[ENV_MASTER_IP]
+        logging.debug(
+            f"Found MASTER_IP: {master_ip} in environment variables")
         # AKS
-        os.environ[ENV_MASTER_ADDR] = os.environ[ENV_MASTER_IP]
+        os.environ[ENV_MASTER_ADDR] = master_ip
     else:
         logging.info("No settings for the MPI central node found. Assuming that this is a single node training job.")
         return
     if ENV_MASTER_PORT not in os.environ:
-        os.environ[ENV_MASTER_PORT] = "6105"
+        os.environ[ENV_MASTER_PORT] = str(MASTER_PORT_DEFAULT)
     if ENV_OMPI_COMM_WORLD_RANK in os.environ:
-        os.environ[ENV_NODE_RANK] = os.environ[ENV_OMPI_COMM_WORLD_RANK]  # node rank is the world_rank from mpi run
+        world_rank = os.environ[ENV_OMPI_COMM_WORLD_RANK]
+        logging.debug(f"Found OMPI_COMM_WORLD_RANK: {world_rank} in environment variables")
+        os.environ[ENV_NODE_RANK] = world_rank  # node rank is the world_rank from mpi run
     env_vars = ", ".join(f"{var} = {os.environ[var]}" for var in [ENV_MASTER_ADDR, ENV_MASTER_PORT, ENV_NODE_RANK])
-    print(f"Distributed training: {env_vars}")
+    logging.info(f"Distributed training: {env_vars}")
diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py
index 5f50fa7eb..35b954e89 100755
--- a/InnerEye/ML/runner.py
+++ b/InnerEye/ML/runner.py
@@ -405,7 +405,8 @@ def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
         else:
             # Set environment variables for multi-node training if needed. This function will terminate early
             # if it detects that it is not in a multi-node environment.
-            set_environment_variables_for_multi_node()
+            if self.azure_config.num_nodes > 1:
+                set_environment_variables_for_multi_node()
             self.ml_runner = self.create_ml_runner()
             self.ml_runner.setup(azure_run_info)
             self.ml_runner.run()
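
Reviewer note: the sketch below illustrates, outside the patch, how the new AZ_BATCH_MASTER_NODE branch resolves MASTER_ADDR and MASTER_PORT from a "<ip>" or "<ip>:<port>" value. The helper name parse_master_node and the sample addresses are hypothetical and not part of the InnerEye code; it is a minimal standalone sketch, not the implementation above.

    import logging
    import os
    from typing import Tuple

    MASTER_PORT_DEFAULT = 6105  # same default port as in the patch

    def parse_master_node(master_node: str) -> Tuple[str, int]:
        """Split an AZ_BATCH_MASTER_NODE value into (address, port), falling back to the default port."""
        parts = master_node.split(":")
        if len(parts) == 2:
            return parts[0], int(parts[1])
        if len(parts) == 1:
            return parts[0], MASTER_PORT_DEFAULT
        raise ValueError(f"Format not recognized: {master_node}")

    if __name__ == "__main__":
        # Illustrative values only; Azure Batch sets the real value on multi-node runs.
        for value in ["10.0.0.4:6000", "10.0.0.4"]:
            addr, port = parse_master_node(value)
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)
            logging.info("MASTER_ADDR=%s MASTER_PORT=%s", addr, port)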