diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py
index 6a66cef97..afe71b095 100644
--- a/alpa/device_mesh.py
+++ b/alpa/device_mesh.py
@@ -32,7 +32,7 @@
 import jax
 from jax import core, xla, device_put
 from jax._src.api import ShapeDtypeStruct
-from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
+from jax._src.lib import xla_bridge as xb, xla_extension as xe
 from jax._src.tree_util import tree_leaves
 from jax.abstract_arrays import array_types
 from jax.core import ShapedArray
@@ -48,14 +48,11 @@
 import alpa.collective as col
 from alpa.global_env import global_config
 from alpa.monkey_patch import set_override_backend
-from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
-                                               LogicalDeviceMesh,
-                                               run_spmd_partitioner_pass)
-from alpa.parallel_plan import PlacementSpec, StagePlan
+from alpa.shard_parallel.auto_sharding import (LogicalDeviceMesh)
+from alpa.parallel_plan import PlacementSpec
 from alpa.timer import timers
 from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
-                       update_jax_platform, is_ray_node_resource,
-                       get_index_select_computation)
+                       update_jax_platform, is_ray_node_resource)
 
 if global_config.nccl_mode == "cupy":
     import alpa.collective.worker_nccl_util_cupy as worker_nccl_util
@@ -1534,34 +1531,6 @@ def __float__(self):
     # TODO(lmzheng): copy more functions from DeviceArray
     # (jax/_src/device_array.py)
 
-    def index_select(self, dim, index):
-        """Compile and run index select operation."""
-        # pylint: disable=import-outside-toplevel
-        from alpa.mesh_executable import NormalMeshDriverExecutable
-        if type(index) not in [ShapedArray, ShapeDtypeStruct]:
-            index = xla.canonicalize_dtype(index)
-        index_shape = xc.shape_from_pyval(index)
-        key = hash(("index_select", self.aval, dim, index_shape))
-        if key in self.device_mesh.operation_executables:
-            executable = self.device_mesh.operation_executables[key]
-        else:
-            index_aval = ShapedArray(index.shape, index.dtype)
-            c = get_index_select_computation(self.sharding_spec, dim, self.aval,
-                                             index_shape).as_hlo_module()
-            hlo_module = run_spmd_partitioner_pass(c,
-                                                   self.device_mesh.num_devices)
-
-            as_option = AutoShardingOption()
-            strategy_config = StagePlan(global_config.compile_random_seed,
-                                        self.device_mesh.shape, 1 << 60,
-                                        as_option.all_reduce_threshold, None,
-                                        -1)
-            executable = NormalMeshDriverExecutable(self.device_mesh,
-                                                    hlo_module, strategy_config,
-                                                    [self.aval, index_aval],
-                                                    [self.aval], [False, False])
-            self.device_mesh.operation_executables[key] = executable
-        return executable.launch_on_driver(self, index)
 
     def __str__(self):
         return (f"DistributedArray(sharding_spec={self.sharding_spec}, "
diff --git a/alpa/mesh_executable.py b/alpa/mesh_executable.py
index 2d7a840e5..668d9bfa8 100644
--- a/alpa/mesh_executable.py
+++ b/alpa/mesh_executable.py
@@ -13,8 +13,10 @@
 from typing import Sequence, Optional
 import os
 
+from jax import xla
 import jax.numpy as jnp
-from jax._src.lib import xla_bridge as xb, xla_extension as xe
+from jax._src.api import ShapeDtypeStruct
+from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
 from jax.core import ShapedArray
 from jax.interpreters import pxla
 from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
@@ -26,14 +28,16 @@
                               next_array_uuids)
 from alpa.global_env import global_config
 from alpa.parallel_plan import PlacementSpec, StagePlan
-from alpa.shard_parallel.auto_sharding import (get_input_output_sharding_specs,
+from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
+                                               get_input_output_sharding_specs,
                                                make_replicated_spec, HloStatus,
-                                               run_backend_compilation)
+                                               run_backend_compilation,
+                                               run_spmd_partitioner_pass)
 from alpa.timer import timers
 from alpa.util import (compile_allocate_zero_buffers,
                        compile_memset_zero_buffers, get_compile_options,
-                       get_shard_shape, get_microbatch_sharding_spec,
-                       profile_xla_executable)
+                       get_index_select_computation, get_shard_shape,
+                       get_microbatch_sharding_spec, profile_xla_executable)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -1117,3 +1121,33 @@ def execute_on_worker(self, input_uuids: Sequence[int],
 
     def __del__(self):
         self.concat.delete()
+
+
+def get_index_select_mesh_executable(avals, sharding_specs, index, dim,
+                                     device_mesh, donate_avals):
+    if type(index) not in [ShapedArray, ShapeDtypeStruct]:
+        index = xla.canonicalize_dtype(index)
+    index_shape = xc.shape_from_pyval(index)
+    key = hash(("index_select", tuple(avals), tuple(sharding_specs),
+                tuple(donate_avals), dim, index_shape))
+    if key in device_mesh.operation_executables:
+        return device_mesh.operation_executables[key]
+    index_aval = ShapedArray(index.shape, index.dtype)
+    assert len(avals) == len(sharding_specs) == len(donate_avals)
+    c = get_index_select_computation(sharding_specs, dim, avals,
+                                     index_shape).as_hlo_module()
+    hlo_module = run_spmd_partitioner_pass(c, device_mesh.num_devices)
+
+    as_option = AutoShardingOption()
+    strategy_config = StagePlan(global_config.compile_random_seed,
+                                device_mesh.shape, 1 << 60,
+                                as_option.all_reduce_threshold, None, -1)
+    out_tree = tree_flatten(avals)[1]
+    executable = NormalMeshDriverExecutable(device_mesh,
+                                            hlo_module,
+                                            strategy_config,
+                                            [*avals, index_aval],
+                                            avals, [*donate_avals, False],
+                                            out_tree=out_tree)
+    device_mesh.operation_executables[key] = executable
+    return executable
diff --git a/alpa/util.py b/alpa/util.py
index 7069706b2..49ecaeb13 100644
--- a/alpa/util.py
+++ b/alpa/util.py
@@ -560,20 +560,26 @@ def compile_concatenate(backend, mesh_shape, sharding_spec, batch_size,
     return hlo_proto
 
 
-def get_index_select_computation(sharding_spec, dim, aval, index_shape):
-    sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+def get_index_select_computation(sharding_specs, dim, avals, index_shape):
     c = xc.XlaBuilder("index_select")
-    c.set_sharding(sharding)
-    operand = xc.ops.Parameter(
-        c, 0, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
-    c.clear_sharding()
-    index = xc.ops.Parameter(c, 1, index_shape)
-    index_selected = xc.ops.IndexSelect(operand, index, dim)
+    shardings = []
+    selected = []
+    index = xc.ops.Parameter(c, len(avals), index_shape)
+    for i, aval in enumerate(avals):
+        sharding_spec = sharding_specs[i]
+        sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+        c.set_sharding(sharding)
+        operand = xc.ops.Parameter(
+            c, i, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
+        c.clear_sharding()
+        index_selected = xc.ops.IndexSelect(operand, index, dim)
+        shardings.append(sharding)
+        selected.append(index_selected)
     sharding2 = xc.OpSharding()
     sharding2.type = sharding.type.TUPLE
-    sharding2.tuple_shardings = [sharding]
+    sharding2.tuple_shardings = shardings
     c.set_sharding(sharding2)
-    c = c.build(xc.ops.Tuple(c, [index_selected]))
+    c = c.build(xc.ops.Tuple(c, selected))
     return c
diff --git a/examples/opt_serving/benchmark/benchmark_text_gen.py b/examples/opt_serving/benchmark/benchmark_text_gen.py
index 0b8f46fcc..b5f8d4b92 100644
--- a/examples/opt_serving/benchmark/benchmark_text_gen.py
+++ b/examples/opt_serving/benchmark/benchmark_text_gen.py
@@ -28,9 +28,21 @@
 import torch
 from transformers import AutoTokenizer
 
-from examples.opt_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding, test_prompts
+from examples.opt_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding
 from examples.opt_serving.model.wrapper import get_model
 
+test_prompts = [
+    "Computer science is the study of computation and",
+    "Ion Stoica is a Romanian-American computer scientist specializing in",
+    "The University of California, Berkeley is a public",
+    "Today is a good day and I want to", "What is the valuation of Databricks?",
+    "Paris is the capital city of", "Which country has the most population?",
+    "What do you think about the future of Cryptocurrency?",
+    "What do you think about the meaning of life?",
+    "Donald Trump is the president of",
+    "GPT-3 is a large language model that is capable of"
+]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="alpa/opt-125m")
@@ -41,6 +53,7 @@
     parser.add_argument("--decoder-length", type=int, default=1)
     parser.add_argument("--nb", type=int, default=1)
     parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--num-beams", type=int, default=1)
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--dtype", type=str, default="fp16")
     args = parser.parse_args()
@@ -60,6 +73,7 @@
     num_micro_batches = args.nb
     decoder_length_per_step = args.decoder_length
     batch_size = args.batch_size
+    num_beams = args.num_beams
     autoregressive = not args.forward
     dtype = jnp.float16 if args.dtype == "fp16" else jnp.float32
 
@@ -159,7 +173,8 @@
                       args.path,
                       autoregressive,
                       dtype=dtype,
-                      dummy=args.dummy)
+                      dummy=args.dummy,
+                      num_beams=num_beams)
     load_time = time.time() - tic
 
     # warm up
@@ -169,7 +184,8 @@
                             max_length=256,
                             do_sample=False,
                             return_dict_in_generate=True,
-                            output_hidden_states=False)
+                            output_hidden_states=False,
+                            num_beams=num_beams)
 
     H = model.transformer_config.H
     L = model.transformer_config.L
@@ -194,7 +210,8 @@
                                 max_length=256,
                                 do_sample=False,
                                 return_dict_in_generate=True,
-                                output_hidden_states=False)
+                                output_hidden_states=False,
+                                num_beams=num_beams)
         latency = time.time() - tic
         generated_ids = output.sequences
         generated_string = tokenizer.batch_decode(generated_ids,
@@ -208,11 +225,11 @@
         else:
            compute_latency = latency
         tflops = compute_gpt_tflops_inference_with_padding(
-            batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus,
-            latency)
+            num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,
+            num_gpus, latency)
         compute_tflops = compute_gpt_tflops_inference_with_padding(
-            batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus,
-            compute_latency)
+            num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,
+            num_gpus, compute_latency)
         speed = np.prod(generated_ids.shape) / latency
         if args.debug:
             print(
@@ -224,20 +241,20 @@
         tflopss.append(tflops)
         compute_tflopss.append(compute_tflops)
 
-    avg_speed = sum(decode_speeds) / n_iters
-    avg_tflops = sum(tflopss) / n_iters
-    avg_compute_tflops = sum(compute_tflopss) / n_iters
+    avg_speed = np.mean(decode_speeds)
+    avg_tflops = np.mean(tflopss)
+    avg_compute_tflops = np.mean(compute_tflopss)
     latency_32_tokens = 32.0 / (avg_speed / batch_size)
 
     heads = [
         "Model", "Device", "Dummy", "Load (s)", "Autoregressive", "Batchsize",
-        "#Microbatches", "#Stages", "Decoder step length", "TFlops",
+        "#Microbatches", "#Beams", "#Stages", "Decoder step length", "TFlops",
         "Compute TFlops", "Speed (token/s)", "latency (32 token)"
     ]
     values = [
         args.model, args.device, args.dummy, f"{load_time:.2f}",
-        f"{autoregressive}", f"{batch_size}", f"{num_micro_batches}", "2",
-        f"{decoder_length_per_step}", f"{avg_tflops:.4f}",
+        f"{autoregressive}", f"{batch_size}", f"{num_micro_batches}",
+        f"{num_beams}", "2", f"{decoder_length_per_step}", f"{avg_tflops:.4f}",
         f"{avg_compute_tflops:.4f}", f"{avg_speed:.2f}",
         f"{latency_32_tokens:.2f}"
     ]
diff --git a/examples/opt_serving/model/opt_model.py b/examples/opt_serving/model/opt_model.py
index bb9f0a9e3..ef69fa723 100644
--- a/examples/opt_serving/model/opt_model.py
+++ b/examples/opt_serving/model/opt_model.py
@@ -684,9 +684,17 @@ def get_pipeshard_executable(config, model,
     params = init_model_aval(config)
 
     # Parallelize
-    method = alpa.PipeshardParallel(num_micro_batches=num_micro_batches,
-                                    pipeline_schedule="inference",
-                                    layer_option="manual")
+    method = alpa.PipeshardParallel(
+        num_micro_batches=num_micro_batches,
+        pipeline_schedule="inference",
+        layer_option="manual",
+        default_auto_sharding_option=alpa.AutoShardingOption(
+            # Force operator model parallel
+            force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,
+            # Disabling all-to-all and all-gather generates better intra-op strategies.
+            allow_all_to_all=False,
+            allow_all_gather=False,
+        ))
     #method = alpa.ShardParallel()
 
     if autoregressive:
@@ -705,9 +713,12 @@ def inference_step_with_cache(params, batch):
         alpa.global_config.always_donate_micro_batch_vars = False
         executable = inference_step_with_cache.get_executable(
             params, {
-                "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
-                "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
-                "cache": init_cache_aval(config, batch_size),
+                "input_ids":
+                    jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "position_ids":
+                    jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "cache":
+                    init_cache_aval(config, batch_size),
             })
 
     else:
@@ -724,9 +735,6 @@ def inference_step(params, batch):
         assert batch_size % num_micro_batches == 0, "cannot divide batch_size by num_micro_batches"
         micro_batch_size = batch_size // num_micro_batches
 
-        # Disable all-to-all and all-gather generates better intra-op strategies.
-        method.as_option.allow_all_to_all = False
-        method.as_option.allow_all_gather = False
         executable = inference_step.get_executable(
             params, {
                 "input_ids":
diff --git a/examples/opt_serving/model/opt_utils.py b/examples/opt_serving/model/opt_utils.py
index 6b84ad891..e857865c6 100644
--- a/examples/opt_serving/model/opt_utils.py
+++ b/examples/opt_serving/model/opt_utils.py
@@ -1,3 +1,8 @@
+from functools import partial
+
+from jax import xla, jit
+from jax.core import Primitive
+from jax._src.lib import xla_client as xc
 from transformers.generation_utils import dataclass
 
 
@@ -27,14 +32,26 @@ def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len,
     return tflops
 
 
-test_prompts = [
-    "Computer science is the study of computation and",
-    "Ion Stoica is a Romanian-American computer scientist specializing in",
-    "The University of California, Berkeley is a public",
-    "Today is a good day and I want to", "What is the valuation of Databricks?",
-    "Paris is the capital city of", "Which country has the most population?",
-    "What do you think about the future of Cryptocurrency?",
-    "What do you think about the meaning of life?",
-    "Donald Trump is the president of",
-    "GPT-3 is a large language model that is capable of"
-]
+def is_power_of_two(n):
+    return (n != 0) and (n & (n-1) == 0)
+
+
+index_select_p = Primitive("index-select")
+
+
+@partial(jit, static_argnums=(2,))
+def jax_index_select(input, index, dim=0):
+    return index_select_p.bind(input, index, dim=dim)
+
+
+def _index_select_eval(input, index, dim):
+    return input
+
+
+def _index_select_translation(c, input, index, dim):
+    return xc.ops.IndexSelect(input, index, dim)
+
+
+index_select_p.def_abstract_eval(_index_select_eval)
+index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
+xla.translations[index_select_p] = _index_select_translation
diff --git a/examples/opt_serving/model/wrapper.py b/examples/opt_serving/model/wrapper.py
index 2cc7880e4..577a9e953 100644
--- a/examples/opt_serving/model/wrapper.py
+++ b/examples/opt_serving/model/wrapper.py
@@ -1,13 +1,13 @@
-from functools import partial
+from collections import defaultdict
 import os
 from typing import Sequence, Any
 
 import alpa
+from alpa.device_mesh import DistributedArray
+from alpa.mesh_executable import get_index_select_mesh_executable
 import jax
 from jax import xla
 from jax import ShapeDtypeStruct, ShapedArray
-from jax._src.lib import xla_client as xc
-from jax.core import Primitive
 from jax.interpreters import pxla
 from jax.interpreters.pxla import NoSharding, Replicated, ShardingSpec
 import jax.numpy as jnp
@@ -19,22 +19,9 @@
 from examples.opt_serving.model.opt_model import (
     get_opt_config, get_pipeshard_executable, load_params_dis_array,
     init_cache_dis_array, load_params_np, init_cache_np, get_jax_executable)
-from examples.opt_serving.model.opt_utils import TransformerModelConfig
-
-
-index_select_p = Primitive("index-select")
-def jax_index_select(input, index, dim=0):
-    return index_select_p.bind(input, index, dim=dim)
-def _index_select_eval(input, index, dim):
-    return input
-def _index_select_translation(c, input, index, dim):
-    return xc.ops.IndexSelect(input, index, dim)
-index_select_p.def_abstract_eval(_index_select_eval)
-index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
-xla.translations[index_select_p] = _index_select_translation
+from examples.opt_serving.model.opt_utils import (TransformerModelConfig,
+                                                  jax_index_select,
+                                                  is_power_of_two)
 
 
 @dataclass
@@ -94,8 +81,11 @@ def __init__(self, inference_func, config, executable, transformer_config):
         self.main_input_name = "input_ids"
         self.executable = executable
         self.transformer_config = transformer_config
+        self.index_select_executables = {}
+        self.cache_location = None
 
     def forward(self, attention_mask):
+        # This function is never used
         raise NotImplementedError()
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -114,6 +104,7 @@
                  output_attentions=None,
                  output_hidden_states=None,
                  return_dict=None):
+        # Decompose the call to token by token
         for i in range(input_ids.shape[1]):
             ret = self.inference_func(input_ids[:, i:i + 1],
                                       past_key_values,
@@ -123,42 +114,73 @@
         return ret
 
     def _reorder_cache(self, past, beam_idx):
-        # Current beam_idx is a torch tensor from beam scorer. To speedup,
-        # we need to have alpa's own beam scorer
-        cache = {}
-        cpu_idx = beam_idx.to("cpu").numpy()
-        if type(cpu_idx) not in [ShapedArray, ShapeDtypeStruct]:
-            cpu_idx = xla.canonicalize_dtype(cpu_idx)
-
-        def to_mesh(mesh):
-            if mesh in cache:
-                return cache[mesh]
-            avals = [ShapedArray(cpu_idx.shape, cpu_idx.dtype)]
-            replicated_spec = ShardingSpec([NoSharding()] * len(cpu_idx.shape),
-                                           [Replicated(mesh.num_devices)])
-            specs = [replicated_spec]
-            indices = [pxla.spec_to_indices(cpu_idx.shape, replicated_spec)]
-            ary = mesh.shard_args_to_arrays(avals, indices, specs, [cpu_idx])[0]
-            cache[mesh] = ary
-            return ary
-
-        def single_element_reorder_cache(ary):
-            if hasattr(ary, "index_select"):
-                # Torch or Alpa path
-                device_idx = None
-                if hasattr(ary, "device"):  # Torch to_device
-                    device_idx = beam_idx.to(ary.device)
-                else:
-                    device_idx = to_mesh(ary.device_mesh)
-                return ary.index_select(0, device_idx)
+        # Reorder cache for beam search
+
+        # PyTorch
+        if hasattr(past[0][0], "index_select"):
+            return tuple(
+                tuple(
+                    past_state.index_select(0, beam_idx)
+                    for past_state in layer_past)
+                for layer_past in past)
+
+        # Jax (single-device)
+        if not isinstance(past[0][0], DistributedArray):
+            beam_idx = jnp.array(beam_idx.to("cpu").numpy())
+            return tuple(
+                tuple(
+                    jax_index_select(past_state, beam_idx, 0)
+                    for past_state in layer_past)
+                for layer_past in past)
+
+        # Alpa
+        mesh_groups = defaultdict(list)
+        if self.cache_location is None:
+            self.cache_location = []
+            for layer_past in past:
+                tmp_loc = []
+                for past_state in layer_past:
+                    assert isinstance(past_state, DistributedArray)
+                    mesh = past_state.device_mesh
+                    mesh_groups[mesh].append(past_state)
+                    tmp_loc.append((mesh, len(mesh_groups[mesh]) - 1))
+                self.cache_location.append(tmp_loc)
+        else:
+            for layer_past in past:
+                for past_state in layer_past:
+                    assert isinstance(past_state, DistributedArray)
+                    mesh = past_state.device_mesh
+                    mesh_groups[mesh].append(past_state)
+
+        beam_idx = beam_idx.to("cpu").numpy()
+
+        def grouped_reorder_cache(arys, device_mesh):
+            if len(arys) == 0:
+                return []
+            if device_mesh in self.index_select_executables:
+                executable = self.index_select_executables[device_mesh]
             else:
-                # Jax path
-                return jax_index_select(ary, cpu_idx, 0)
+                dim = 0
+                avals = [ary.aval for ary in arys]
+                specs = [ary.sharding_spec for ary in arys]
+                executable = get_index_select_mesh_executable(
+                    avals, specs, beam_idx, dim, device_mesh,
+                    [False] * len(avals))
+                self.index_select_executables[device_mesh] = executable
+            ret = executable(*arys, beam_idx)
+            for v in ret:
+                v.skip_shard_args_check = True
+            return ret
+
+        results = {
+            mesh: grouped_reorder_cache(mesh_groups[mesh], mesh)
+            for mesh in mesh_groups
+        }
+
         return tuple(
-            tuple(
-                single_element_reorder_cache(past_state)
-                for past_state in layer_past)
-            for layer_past in past)
+            tuple(results[mesh][loc]
+                  for mesh, loc in layer_loc)
+            for layer_loc in self.cache_location)
 
 
 def get_hf_gpt_model(model_name, device, num_beams):
@@ -262,7 +284,8 @@
     # weight path
     path = os.path.join(path, f"{name}_np")
-    assert os.path.exists(path), f"No such file or directory: '{path}'"
+    if not dummy:
+        assert os.path.exists(path), f"No such file or directory: '{path}'"
 
     if "jax/opt" in model_name:
         config = get_opt_config(name,
@@ -287,6 +310,7 @@
         params, init_cache = jax.tree_map(jnp.array, (params, init_cache))
     else:
         assert "alpa/opt" in model_name
+        assert is_power_of_two(num_beams), "num_beams must be a power of two"
         alpa.init()
 
         print(
        )
        num_pp_stages = max(2, alpa.get_global_cluster().num_hosts)
+        num_pp_stages = min(num_pp_stages,
+                            alpa.get_global_cluster().num_devices)
        config = get_opt_config(name, num_pp_stages=num_pp_stages, dtype=dtype)
        transformer_config = TransformerModelConfig(
            H=config.decoder_embed_dim,
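Usage sketch (not part of the patch): a minimal example of how the jax_index_select primitive added in examples/opt_serving/model/opt_utils.py is expected to be called on the single-device JAX path of _reorder_cache. The array shapes and values below are made up for illustration; the real caller passes cached key/value tensors and the beam indices produced by the Hugging Face beam scorer.

    import jax.numpy as jnp
    from examples.opt_serving.model.opt_utils import jax_index_select

    # Fake cached state with layout (batch * num_beams, seq_len, hidden).
    past_state = jnp.arange(4 * 3 * 2, dtype=jnp.float32).reshape(4, 3, 2)
    # Beam indices chosen by the beam scorer for this decoding step.
    beam_idx = jnp.array([2, 0, 3, 1], dtype=jnp.int32)

    # Select rows along dim 0 (like past_state[beam_idx]); the call is routed
    # through the custom "index-select" primitive, whose translation rule
    # lowers it to xc.ops.IndexSelect.
    reordered = jax_index_select(past_state, beam_idx, 0)
    assert reordered.shape == past_state.shape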