From 7b2a02352a146277fe417318facec1df5d6c6fe0 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 1 Jul 2022 06:01:06 -0700
Subject: [PATCH] [Fix] Improve compilation speed by using Set instead of List for membership queries (#567)

---
 alpa/pipeline_parallel/runtime_emitter.py | 47 +++++++++++------------
 1 file changed, 23 insertions(+), 24 deletions(-)

diff --git a/alpa/pipeline_parallel/runtime_emitter.py b/alpa/pipeline_parallel/runtime_emitter.py
index 1ecff93b3..be21055b9 100644
--- a/alpa/pipeline_parallel/runtime_emitter.py
+++ b/alpa/pipeline_parallel/runtime_emitter.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 import enum
 import logging
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union, Set
 
 from jax._src.tree_util import PyTreeDef, tree_unflatten
 from jax.core import Var
@@ -160,11 +160,11 @@ def flatten_uuid_set(container):
 class PipelineInstEmitterHelper:
     """Environment for PipelineInstEmitter."""
 
-    def __init__(self, global_invars, grad_dummy_invars, is_batch,
-                 schedule: PipelineSchedule):
-        self.global_invars = global_invars
-        self.global_batch_invars = OrderedSet(
-            v for v, b in zip(global_invars, is_batch) if b)
+    def __init__(self, global_invar_set: Set[Var],
+                 global_batch_invar_set: Set[Var],
+                 grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule):
+        self.global_invar_set = global_invar_set
+        self.global_batch_invar_set = global_batch_invar_set
         self.grad_dummy_invars = grad_dummy_invars
         self.schedule = schedule
         # Dict[var_key -> Dict[mesh_idx -> array_uuid]]
@@ -172,7 +172,8 @@ def __init__(self, global_invars, grad_dummy_invars, is_batch,
         self.env = {}
 
     def _get_var_key(self, var, batch_idx):
-        if var in self.global_invars and var not in self.global_batch_invars:
+        if (var in self.global_invar_set and
+                var not in self.global_batch_invar_set):
             key = (var, 0)
         elif (var in self.grad_dummy_invars and
               batch_idx != self.schedule.first_backward_batch_index):
@@ -283,8 +284,12 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],
 
         ##### Internal states #####
         self.uuid_counter = 0  # counter for local buffer uuid
-        self.env = PipelineInstEmitterHelper(global_invars, grad_dummy_invars,
-                                             is_batch, schedule)
+        global_invar_set = OrderedSet(global_invars)
+        global_batch_invar_set = OrderedSet(
+            v for v, b in zip(global_invars, is_batch) if b)
+        self.env = PipelineInstEmitterHelper(global_invar_set,
+                                             global_batch_invar_set,
+                                             grad_dummy_invars, schedule)
         self._communicator = None
         self._resharding_tasks = [
             [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
@@ -390,12 +395,8 @@ def compile(self):
             executable_config_lists)
 
         # Split input into micro batches
-        global_batch_invar_set = OrderedSet([
-            var for var, batch in zip(self.global_invars, self.is_batch)
-            if batch
-        ])
-        (input_config, input_shard_specs
-        ) = self._compile_split_input_to_microbatches(global_batch_invar_set)
+        (input_config,
+         input_shard_specs) = self._compile_split_input_to_microbatches()
 
         # Simulate the pipeline schedule and generate instructions
         donation_mapping = [DisjointDict() for _ in range(num_mesh)]
@@ -618,7 +619,7 @@ def _compile_grad_buffer_allocations(self, executable_config_lists):
 
         return grad_uuids, instruction_lists
 
-    def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
+    def _compile_collect_mesh_input(self, mesh_idx):
         mesh_arg_set = OrderedSet()
         var_to_spec = {}
         mesh_batch_vars = OrderedSet()
@@ -630,9 +631,9 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:
             stage = self.stages[stage_idx]
             for spec, invar in zip(stage.input_sharding_specs, stage.invars):
-                if invar in self.global_invars:
+                if invar in self.env.global_invar_set:
                     var_to_spec[invar] = spec
-                    if invar in batch_vars:
+                    if invar in self.env.global_batch_invar_set:
                         # Split batch arg
                         for batch_idx in range(num_batch):
                             mesh_arg_set.add((invar, batch_idx))
@@ -666,7 +667,7 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         return (mesh_arg_list, mesh_arg_indices, input_shard_indices,
                 input_shard_specs, mesh_invar_is_batch)
 
-    def _compile_split_input_to_microbatches(self, global_batch_invar_set):
+    def _compile_split_input_to_microbatches(self):
         """
         Split batch arguments into micro batches.
 
@@ -675,10 +676,9 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1
         """
         donated_invar_set = OrderedSet()
-        global_invar_set = OrderedSet(self.global_invars)
         for stage in self.stages:
             for invar, donate in zip(stage.invars, stage.donated_invars):
-                if donate and invar in global_invar_set:
+                if donate and invar in self.env.global_invar_set:
                     donated_invar_set.add(invar)
         num_mesh = len(self.mesh_group)
         mesh_arg_lists = [None for _ in range(num_mesh)]
@@ -692,13 +692,12 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         batch_invars = []
         for mesh_idx in range(num_mesh):
             (mesh_arg_list, arg_indices, shard_indices, shard_specs,
-             is_batch) = self._compile_collect_mesh_input(
-                 mesh_idx, global_batch_invar_set)
+             is_batch) = self._compile_collect_mesh_input(mesh_idx)
             mesh_arg_lists[mesh_idx] = mesh_arg_list
             delete_after_run = [
                 var in donated_invar_set or
-                (var in global_batch_invar_set and
+                (var in self.env.global_batch_invar_set and
                  global_config.always_donate_micro_batch_vars)
                 for var, _ in mesh_arg_list
             ]
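
Note on why this helps (explanatory, not part of the diff): before this patch,
queries such as `var in self.global_invars` in `_get_var_key` and in the
input-collection loops scanned a Python list, which costs O(n) per query.
These queries run repeatedly, roughly once per input variable per stage, so
total query cost grew about quadratically with the number of global invars.
Converting once to `OrderedSet` (hash-backed membership with deterministic
iteration order) makes each query O(1) on average, and storing the two sets on
the helper also removes the per-call `OrderedSet` rebuilds that `compile` and
`_compile_split_input_to_microbatches` used to do. A minimal standalone sketch
of the effect, using plain ints and the built-in list/set as stand-ins for
Alpa's `Var` and `OrderedSet` types:

    import timeit

    n = 10_000
    invars = list(range(n))   # stand-in for the global input variables
    invar_list = list(invars)
    invar_set = set(invars)   # one-time O(n) construction

    # One membership query per invar, as the emitter's loops perform.
    list_time = timeit.timeit(
        lambda: [v in invar_list for v in invars], number=1)  # ~O(n^2) total
    set_time = timeit.timeit(
        lambda: [v in invar_set for v in invars], number=1)   # ~O(n) total
    print(f"list scans: {list_time:.3f}s, set lookups: {set_time:.4f}s")

On typical hardware the list version takes seconds while the set version takes
milliseconds at this size, which mirrors the compilation speedup the patch
targets.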