This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

BUG: Don't update multi-node env vars for single node training #796

Merged
4 commits merged on Sep 2, 2022
38 changes: 31 additions & 7 deletions InnerEye/Azure/azure_runner.py
@@ -25,13 +25,15 @@

 # Environment variables used for multi-node training
 ENV_AZ_BATCHAI_MPI_MASTER_NODE = "AZ_BATCHAI_MPI_MASTER_NODE"
+ENV_AZ_BATCH_MASTER_NODE = "AZ_BATCH_MASTER_NODE"
 ENV_MASTER_ADDR = "MASTER_ADDR"
 ENV_MASTER_IP = "MASTER_IP"
 ENV_MASTER_PORT = "MASTER_PORT"
 ENV_OMPI_COMM_WORLD_RANK = "OMPI_COMM_WORLD_RANK"
 ENV_NODE_RANK = "NODE_RANK"
 ENV_GLOBAL_RANK = "GLOBAL_RANK"
 ENV_LOCAL_RANK = "LOCAL_RANK"
+MASTER_PORT_DEFAULT = 6105


 def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
@@ -294,21 +296,43 @@ def set_environment_variables_for_multi_node() -> None:
     """
     Sets the environment variables that PyTorch Lightning needs for multi-node training.
     """
-    if ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ:
+    if ENV_AZ_BATCH_MASTER_NODE in os.environ:
+        master_node = os.environ[ENV_AZ_BATCH_MASTER_NODE]
+        logging.debug(
+            f"Found AZ_BATCH_MASTER_NODE: {master_node} in environment variables")
+        # For AML BATCHAI
+        split_master_node_addr = master_node.split(":")
+        if len(split_master_node_addr) == 2:
+            master_addr, port = split_master_node_addr
+            os.environ[ENV_MASTER_PORT] = port
+        elif len(split_master_node_addr) == 1:
+            master_addr = split_master_node_addr[0]
+        else:
+            raise ValueError(f"Format not recognized: {master_node}")
+        os.environ[ENV_MASTER_ADDR] = master_addr
+    elif ENV_AZ_BATCHAI_MPI_MASTER_NODE in os.environ and os.environ.get(ENV_AZ_BATCHAI_MPI_MASTER_NODE) != "localhost":
+        mpi_master_node = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
+        logging.debug(
+            f"Found AZ_BATCHAI_MPI_MASTER_NODE: {mpi_master_node} in environment variables")
         # For AML BATCHAI
-        os.environ[ENV_MASTER_ADDR] = os.environ[ENV_AZ_BATCHAI_MPI_MASTER_NODE]
+        os.environ[ENV_MASTER_ADDR] = mpi_master_node
     elif ENV_MASTER_IP in os.environ:
+        master_ip = os.environ[ENV_MASTER_IP]
+        logging.debug(
+            f"Found MASTER_IP: {master_ip} in environment variables")
         # AKS
-        os.environ[ENV_MASTER_ADDR] = os.environ[ENV_MASTER_IP]
+        os.environ[ENV_MASTER_ADDR] = master_ip
     else:
         logging.info("No settings for the MPI central node found. Assuming that this is a single node training job.")
         return

     if ENV_MASTER_PORT not in os.environ:
-        os.environ[ENV_MASTER_PORT] = "6105"
+        os.environ[ENV_MASTER_PORT] = str(MASTER_PORT_DEFAULT)

     if ENV_OMPI_COMM_WORLD_RANK in os.environ:
-        os.environ[ENV_NODE_RANK] = os.environ[ENV_OMPI_COMM_WORLD_RANK]  # node rank is the world_rank from mpi run
+        world_rank = os.environ[ENV_OMPI_COMM_WORLD_RANK]
+        logging.debug(f"Found OMPI_COMM_WORLD_RANK: {world_rank} in environment variables")
+        os.environ[ENV_NODE_RANK] = world_rank  # node rank is the world_rank from mpi run

     env_vars = ", ".join(f"{var} = {os.environ[var]}" for var in [ENV_MASTER_ADDR, ENV_MASTER_PORT, ENV_NODE_RANK])
-    print(f"Distributed training: {env_vars}")
+    logging.info(f"Distributed training: {env_vars}")
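For reviewers: a sketch of a unit test that could exercise the new single-node behaviour. It is not part of this PR, and it assumes the function is importable from InnerEye.Azure.azure_runner as in the diff above.

```python
# Sketch only: a test along these lines would confirm that a batch service reporting
# "localhost" as the MPI master node no longer triggers the distributed setup.
import os
from unittest import mock

from InnerEye.Azure.azure_runner import set_environment_variables_for_multi_node


def test_localhost_master_node_is_treated_as_single_node() -> None:
    env = {"AZ_BATCHAI_MPI_MASTER_NODE": "localhost"}
    with mock.patch.dict(os.environ, env, clear=True):
        set_environment_variables_for_multi_node()
        # The function should take the "single node training job" branch and
        # leave MASTER_ADDR unset.
        assert "MASTER_ADDR" not in os.environ
```

A companion test with AZ_BATCH_MASTER_NODE set to, say, "10.0.0.4:6105" would be expected to leave MASTER_ADDR=10.0.0.4 and MASTER_PORT=6105, per the parsing branch added above.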
3 changes: 2 additions & 1 deletion InnerEye/ML/runner.py
@@ -405,7 +405,8 @@ def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
         else:
             # Set environment variables for multi-node training if needed. This function will terminate early
             # if it detects that it is not in a multi-node environment.
-            set_environment_variables_for_multi_node()
+            if self.azure_config.num_nodes > 1:
+                set_environment_variables_for_multi_node()
             self.ml_runner = self.create_ml_runner()
             self.ml_runner.setup(azure_run_info)
             self.ml_runner.run()
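To illustrate the guard above, a minimal standalone sketch (not the repository's code; configure_lightning_env is a hypothetical stand-in for the guarded call, and num_nodes plays the role of azure_config.num_nodes): single-node runs leave the Lightning variables untouched even when a master-node address is available.

```python
import os


def configure_lightning_env(num_nodes: int, master_node: str = "") -> None:
    # Hypothetical stand-in for the guarded call to set_environment_variables_for_multi_node():
    # only touch the PyTorch Lightning variables when more than one node is requested.
    if num_nodes > 1 and master_node:
        addr, _, port = master_node.partition(":")
        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = port or "6105"


configure_lightning_env(num_nodes=1, master_node="10.0.0.4:6105")
# In a fresh process this prints False: nothing was set for a single-node job.
print("MASTER_ADDR" in os.environ)
```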