[PERF] update index select for beam search (#576)
ZYHowell authored and merrymercy committed on Jul 2, 2022
1 parent 65d3cdc · commit 724a148
Showing 7 changed files with 215 additions and 138 deletions.
39 changes: 4 additions & 35 deletions alpa/device_mesh.py
@@ -32,7 +32,7 @@
 import jax
 from jax import core, xla, device_put
 from jax._src.api import ShapeDtypeStruct
-from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
+from jax._src.lib import xla_bridge as xb, xla_extension as xe
 from jax._src.tree_util import tree_leaves
 from jax.abstract_arrays import array_types
 from jax.core import ShapedArray
@@ -48,14 +48,11 @@
 import alpa.collective as col
 from alpa.global_env import global_config
 from alpa.monkey_patch import set_override_backend
-from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
-                                               LogicalDeviceMesh,
-                                               run_spmd_partitioner_pass)
-from alpa.parallel_plan import PlacementSpec, StagePlan
+from alpa.shard_parallel.auto_sharding import (LogicalDeviceMesh)
+from alpa.parallel_plan import PlacementSpec
 from alpa.timer import timers
-from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
-                       update_jax_platform, is_ray_node_resource,
-                       get_index_select_computation)
+from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
+                       update_jax_platform, is_ray_node_resource)
 
 if global_config.nccl_mode == "cupy":
     import alpa.collective.worker_nccl_util_cupy as worker_nccl_util
@@ -1534,34 +1531,6 @@ def __float__(self):
 
     # TODO(lmzheng): copy more functions from DeviceArray
     # (jax/_src/device_array.py)
-    def index_select(self, dim, index):
-        """Compile and run index select operation."""
-        # pylint: disable=import-outside-toplevel
-        from alpa.mesh_executable import NormalMeshDriverExecutable
-        if type(index) not in [ShapedArray, ShapeDtypeStruct]:
-            index = xla.canonicalize_dtype(index)
-        index_shape = xc.shape_from_pyval(index)
-        key = hash(("index_select", self.aval, dim, index_shape))
-        if key in self.device_mesh.operation_executables:
-            executable = self.device_mesh.operation_executables[key]
-        else:
-            index_aval = ShapedArray(index.shape, index.dtype)
-            c = get_index_select_computation(self.sharding_spec, dim, self.aval,
-                                             index_shape).as_hlo_module()
-            hlo_module = run_spmd_partitioner_pass(c,
-                                                   self.device_mesh.num_devices)
-
-            as_option = AutoShardingOption()
-            strategy_config = StagePlan(global_config.compile_random_seed,
-                                        self.device_mesh.shape, 1 << 60,
-                                        as_option.all_reduce_threshold, None,
-                                        -1)
-            executable = NormalMeshDriverExecutable(self.device_mesh,
-                                                    hlo_module, strategy_config,
-                                                    [self.aval, index_aval],
-                                                    [self.aval], [False, False])
-            self.device_mesh.operation_executables[key] = executable
-        return executable.launch_on_driver(self, index)
 
     def __str__(self):
         return (f"DistributedArray(sharding_spec={self.sharding_spec}, "
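Note: this hunk removes the per-array `DistributedArray.index_select` method. Its compile-and-cache logic moves to `get_index_select_mesh_executable` in `alpa/mesh_executable.py` (next file), generalized so that several arrays sharing one index are gathered by a single executable — the pattern beam search needs when reordering every cache tensor at once. A sketch of the call-site migration, with illustrative names (`x` a `DistributedArray`, `beam_idx` an int32 index array), not code from this commit:

```python
# Before (removed here): one compiled executable per DistributedArray.
#     out = x.index_select(0, beam_idx)
# After: one executable for a batch of arrays that share the index.
from alpa.mesh_executable import get_index_select_mesh_executable

exe = get_index_select_mesh_executable(
    [x.aval], [x.sharding_spec], beam_idx, 0,  # avals, specs, index, dim
    x.device_mesh, [False])                    # mesh, donation flags
out, = exe.launch_on_driver(x, beam_idx)       # one output per input aval
```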
44 changes: 39 additions & 5 deletions alpa/mesh_executable.py
@@ -13,8 +13,10 @@
 from typing import Sequence, Optional
 import os
 
+from jax import xla
 import jax.numpy as jnp
-from jax._src.lib import xla_bridge as xb, xla_extension as xe
+from jax._src.api import ShapeDtypeStruct
+from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
 from jax.core import ShapedArray
 from jax.interpreters import pxla
 from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
@@ -26,14 +28,16 @@
                               next_array_uuids)
 from alpa.global_env import global_config
 from alpa.parallel_plan import PlacementSpec, StagePlan
-from alpa.shard_parallel.auto_sharding import (get_input_output_sharding_specs,
+from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
+                                               get_input_output_sharding_specs,
                                                make_replicated_spec, HloStatus,
-                                               run_backend_compilation)
+                                               run_backend_compilation,
+                                               run_spmd_partitioner_pass)
 from alpa.timer import timers
 from alpa.util import (compile_allocate_zero_buffers,
                        compile_memset_zero_buffers, get_compile_options,
-                       get_shard_shape, get_microbatch_sharding_spec,
-                       profile_xla_executable)
+                       get_index_select_computation, get_shard_shape,
+                       get_microbatch_sharding_spec, profile_xla_executable)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -1117,3 +1121,33 @@ def execute_on_worker(self, input_uuids: Sequence[int],
 
     def __del__(self):
         self.concat.delete()
+
+
+def get_index_select_mesh_executable(avals, sharding_specs, index, dim,
+                                     device_mesh, donate_avals):
+    if type(index) not in [ShapedArray, ShapeDtypeStruct]:
+        index = xla.canonicalize_dtype(index)
+    index_shape = xc.shape_from_pyval(index)
+    key = hash(("index_select", tuple(avals), tuple(sharding_specs),
+                tuple(donate_avals), dim, index_shape))
+    if key in device_mesh.operation_executables:
+        return device_mesh.operation_executables[key]
+    index_aval = ShapedArray(index.shape, index.dtype)
+    assert len(avals) == len(sharding_specs) == len(donate_avals)
+    c = get_index_select_computation(sharding_specs, dim, avals,
+                                     index_shape).as_hlo_module()
+    hlo_module = run_spmd_partitioner_pass(c, device_mesh.num_devices)
+
+    as_option = AutoShardingOption()
+    strategy_config = StagePlan(global_config.compile_random_seed,
+                                device_mesh.shape, 1 << 60,
+                                as_option.all_reduce_threshold, None, -1)
+    out_tree = tree_flatten(avals)[1]
+    executable = NormalMeshDriverExecutable(device_mesh,
+                                            hlo_module,
+                                            strategy_config,
+                                            [*avals, index_aval],
+                                            avals, [*donate_avals, False],
+                                            out_tree=out_tree)
+    device_mesh.operation_executables[key] = executable
+    return executable
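Note: compared with the removed method, the cache key now covers the full tuples of avals, sharding specs, and donation flags, so a beam-search loop compiles once and reuses the executable on every subsequent step with matching shapes and shardings. A sketch of reordering two KV-cache arrays in one launch (illustrative names, not from this commit; donation lets the runtime recycle the input buffers):

```python
exe = get_index_select_mesh_executable(
    [k.aval, v.aval],                    # avals of the arrays to gather from
    [k.sharding_spec, v.sharding_spec],  # their matching sharding specs
    beam_idx,                            # shared int32 beam-order index
    0,                                   # gather along dimension 0
    mesh, [True, True])                  # device mesh; donate both inputs
new_k, new_v = exe.launch_on_driver(k, v, beam_idx)
```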
26 changes: 16 additions & 10 deletions alpa/util.py
@@ -560,20 +560,26 @@ def compile_concatenate(backend, mesh_shape, sharding_spec, batch_size,
     return hlo_proto
 
 
-def get_index_select_computation(sharding_spec, dim, aval, index_shape):
-    sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+def get_index_select_computation(sharding_specs, dim, avals, index_shape):
     c = xc.XlaBuilder("index_select")
-    c.set_sharding(sharding)
-    operand = xc.ops.Parameter(
-        c, 0, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
-    c.clear_sharding()
-    index = xc.ops.Parameter(c, 1, index_shape)
-    index_selected = xc.ops.IndexSelect(operand, index, dim)
+    shardings = []
+    selected = []
+    index = xc.ops.Parameter(c, len(avals), index_shape)
+    for i, aval in enumerate(avals):
+        sharding_spec = sharding_specs[i]
+        sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+        c.set_sharding(sharding)
+        operand = xc.ops.Parameter(
+            c, i, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
+        c.clear_sharding()
+        index_selected = xc.ops.IndexSelect(operand, index, dim)
+        shardings.append(sharding)
+        selected.append(index_selected)
     sharding2 = xc.OpSharding()
     sharding2.type = sharding.type.TUPLE
-    sharding2.tuple_shardings = [sharding]
+    sharding2.tuple_shardings = shardings
     c.set_sharding(sharding2)
-    c = c.build(xc.ops.Tuple(c, [index_selected]))
+    c = c.build(xc.ops.Tuple(c, selected))
     return c


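Note: the builder now takes lists of avals and sharding specs, emits one `IndexSelect` per operand against a single shared index parameter (registered last, at parameter slot `len(avals)`), and roots the computation in a tuple whose `OpSharding` has type `TUPLE` with one entry per result. A minimal sketch of driving it directly, assuming `spec0`/`spec1` are existing `pxla.ShardingSpec` objects (e.g. taken from `DistributedArray`s):

```python
import numpy as np
from jax.core import ShapedArray
from jax._src.lib import xla_client as xc
from alpa.util import get_index_select_computation

# Two 4x8 float32 operands, one shared int32[4] index, gather on dim 0.
avals = [ShapedArray((4, 8), np.float32), ShapedArray((4, 8), np.float32)]
index_shape = xc.shape_from_pyval(np.zeros(4, np.int32))
c = get_index_select_computation([spec0, spec1], 0, avals, index_shape)
print(c.as_hlo_text())  # root is a 2-tuple of index-select results
```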
45 changes: 31 additions & 14 deletions examples/opt_serving/benchmark/benchmark_text_gen.py
@@ -28,9 +28,21 @@
 import torch
 from transformers import AutoTokenizer
 
-from examples.opt_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding, test_prompts
+from examples.opt_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding
 from examples.opt_serving.model.wrapper import get_model
 
+test_prompts = [
+    "Computer science is the study of computation and",
+    "Ion Stoica is a Romanian-American computer scientist specializing in",
+    "The University of California, Berkeley is a public",
+    "Today is a good day and I want to", "What is the valuation of Databricks?",
+    "Paris is the capital city of", "Which country has the most population?",
+    "What do you think about the future of Cryptocurrency?",
+    "What do you think about the meaning of life?",
+    "Donald Trump is the president of",
+    "GPT-3 is a large language model that is capable of"
+]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="alpa/opt-125m")
@@ -41,6 +53,7 @@
     parser.add_argument("--decoder-length", type=int, default=1)
     parser.add_argument("--nb", type=int, default=1)
     parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--num-beams", type=int, default=1)
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--dtype", type=str, default="fp16")
     args = parser.parse_args()
@@ -60,6 +73,7 @@
     num_micro_batches = args.nb
     decoder_length_per_step = args.decoder_length
     batch_size = args.batch_size
+    num_beams = args.num_beams
     autoregressive = not args.forward
     dtype = jnp.float16 if args.dtype == "fp16" else jnp.float32

@@ -159,7 +173,8 @@
         args.path,
         autoregressive,
         dtype=dtype,
-        dummy=args.dummy)
+        dummy=args.dummy,
+        num_beams=num_beams)
     load_time = time.time() - tic
 
     # warm up
@@ -169,7 +184,8 @@
         max_length=256,
         do_sample=False,
         return_dict_in_generate=True,
-        output_hidden_states=False)
+        output_hidden_states=False,
+        num_beams=num_beams)
 
     H = model.transformer_config.H
     L = model.transformer_config.L
@@ -194,7 +210,8 @@
             max_length=256,
             do_sample=False,
             return_dict_in_generate=True,
-            output_hidden_states=False)
+            output_hidden_states=False,
+            num_beams=num_beams)
         latency = time.time() - tic
         generated_ids = output.sequences
         generated_string = tokenizer.batch_decode(generated_ids,
@@ -208,11 +225,11 @@
         else:
             compute_latency = latency
         tflops = compute_gpt_tflops_inference_with_padding(
-            batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus,
-            latency)
+            num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,
+            num_gpus, latency)
         compute_tflops = compute_gpt_tflops_inference_with_padding(
-            batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus,
-            compute_latency)
+            num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,
+            num_gpus, compute_latency)
         speed = np.prod(generated_ids.shape) / latency
         if args.debug:
             print(
@@ -224,20 +241,20 @@
         tflopss.append(tflops)
         compute_tflopss.append(compute_tflops)
 
-    avg_speed = sum(decode_speeds) / n_iters
-    avg_tflops = sum(tflopss) / n_iters
-    avg_compute_tflops = sum(compute_tflopss) / n_iters
+    avg_speed = np.mean(decode_speeds)
+    avg_tflops = np.mean(tflopss)
+    avg_compute_tflops = np.mean(compute_tflopss)
     latency_32_tokens = 32.0 / (avg_speed / batch_size)
 
     heads = [
         "Model", "Device", "Dummy", "Load (s)", "Autoregressive", "Batchsize",
-        "#Microbatches", "#Stages", "Decoder step length", "TFlops",
+        "#Microbatches", "#Beams", "#Stages", "Decoder step length", "TFlops",
         "Compute TFlops", "Speed (token/s)", "latency (32 token)"
     ]
     values = [
         args.model, args.device, args.dummy, f"{load_time:.2f}",
-        f"{autoregressive}", f"{batch_size}", f"{num_micro_batches}", "2",
-        f"{decoder_length_per_step}", f"{avg_tflops:.4f}",
+        f"{autoregressive}", f"{batch_size}", f"{num_micro_batches}",
+        f"{num_beams}", "2", f"{decoder_length_per_step}", f"{avg_tflops:.4f}",
         f"{avg_compute_tflops:.4f}", f"{avg_speed:.2f}",
         f"{latency_32_tokens:.2f}"
     ]
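Note: the TFLOPS accounting above multiplies the batch by the beam width because beam search keeps `num_beams` live hypotheses per input sequence at every decoder step, while token throughput (`speed`) still counts only the returned sequences. A small worked sketch with made-up numbers, using the same helper:

```python
# Hypothetical: batch_size=2, num_beams=4. Each decoder step now processes
# 8 sequences, so the FLOP estimate (linear in its batch argument) grows
# 4x relative to greedy decoding at the same latency.
flops_beam = compute_gpt_tflops_inference_with_padding(
    4 * 2, gen_len, seq_len, L, H, vocab_size, num_gpus, latency)
```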
26 changes: 17 additions & 9 deletions examples/opt_serving/model/opt_model.py
@@ -684,9 +684,17 @@ def get_pipeshard_executable(config,
     model, params = init_model_aval(config)
 
     # Parallelize
-    method = alpa.PipeshardParallel(num_micro_batches=num_micro_batches,
-                                    pipeline_schedule="inference",
-                                    layer_option="manual")
+    method = alpa.PipeshardParallel(
+        num_micro_batches=num_micro_batches,
+        pipeline_schedule="inference",
+        layer_option="manual",
+        default_auto_sharding_option=alpa.AutoShardingOption(
+            # Force operator model parallel
+            force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,
+            # Disabling all-to-all and all-gather generates better intra-op strategies.
+            allow_all_to_all=False,
+            allow_all_gather=False,
+        ))
     #method = alpa.ShardParallel()
 
     if autoregressive:
@@ -705,9 +713,12 @@ def inference_step_with_cache(params, batch):
         alpa.global_config.always_donate_micro_batch_vars = False
         executable = inference_step_with_cache.get_executable(
             params, {
-                "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
-                "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
-                "cache": init_cache_aval(config, batch_size),
+                "input_ids":
+                    jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "position_ids":
+                    jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "cache":
+                    init_cache_aval(config, batch_size),
             })
     else:

@@ -724,9 +735,6 @@ def inference_step(params, batch):
         assert batch_size % num_micro_batches == 0, "cannot divide batch_size by num_micro_batches"
         micro_batch_size = batch_size // num_micro_batches
 
-        # Disable all-to-all and all-gather generates better intra-op strategies.
-        method.as_option.allow_all_to_all = False
-        method.as_option.allow_all_gather = False
         executable = inference_step.get_executable(
             params, {
                 "input_ids":
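Note: the sharding switches that the last hunk deletes were previously set only on the non-autoregressive path, after the method object was built; passing them via `default_auto_sharding_option` applies them to both paths at construction time. With `batch_size == 1` there is no batch dimension worth sharding, so `force_batch_dim_to_mesh_dim=None` lets the auto-sharding solver fall back to operator (model) parallelism. A standalone sketch of the equivalent configuration, assuming the top-level `alpa` API shown in the hunk:

```python
as_option = alpa.AutoShardingOption(
    # No batch dim to pin when batch_size == 1; otherwise map it to mesh dim 0.
    force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,
    allow_all_to_all=False,   # forbid all-to-all resharding
    allow_all_gather=False)   # forbid all-gather resharding
method = alpa.PipeshardParallel(num_micro_batches=num_micro_batches,
                                pipeline_schedule="inference",
                                layer_option="manual",
                                default_auto_sharding_option=as_option)
```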
39 changes: 28 additions & 11 deletions examples/opt_serving/model/opt_utils.py
@@ -1,3 +1,8 @@
+from functools import partial
+
+from jax import xla, jit
+from jax.core import Primitive
+from jax._src.lib import xla_client as xc
 from transformers.generation_utils import dataclass


@@ -27,14 +32,26 @@ def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len,
     return tflops
 
 
-test_prompts = [
-    "Computer science is the study of computation and",
-    "Ion Stoica is a Romanian-American computer scientist specializing in",
-    "The University of California, Berkeley is a public",
-    "Today is a good day and I want to", "What is the valuation of Databricks?",
-    "Paris is the capital city of", "Which country has the most population?",
-    "What do you think about the future of Cryptocurrency?",
-    "What do you think about the meaning of life?",
-    "Donald Trump is the president of",
-    "GPT-3 is a large language model that is capable of"
-]
+def is_power_of_two(n):
+    return (n != 0) and (n & (n-1) == 0)
+
+
+index_select_p = Primitive("index-select")
+
+
+@partial(jit, static_argnums=(2,))
+def jax_index_select(input, index, dim=0):
+    return index_select_p.bind(input, index, dim=dim)
+
+
+def _index_select_eval(input, index, dim):
+    return input
+
+
+def _index_select_translation(c, input, index, dim):
+    return xc.ops.IndexSelect(input, index, dim)
+
+
+index_select_p.def_abstract_eval(_index_select_eval)
+index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
+xla.translations[index_select_p] = _index_select_translation
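Note: `jax_index_select` is the jittable entry point for the new primitive; it lowers to XLA's `IndexSelect` via the translation rule, and `_index_select_eval` declares the output aval equal to the input's, which holds for the beam-search use case where the index permutes (and so preserves the length of) the selected dimension. A quick usage sketch:

```python
import jax.numpy as jnp

# Permute the beam dimension of a (num_beams, hidden) activation.
x = jnp.arange(8.0).reshape(4, 2)                 # 4 beams, hidden size 2
beam_idx = jnp.array([2, 0, 1, 3], dtype=jnp.int32)
y = jax_index_select(x, beam_idx, 0)              # rows in beam_idx order
print(y)  # [[4. 5.] [0. 1.] [2. 3.] [6. 7.]]
```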
(The diff for the seventh changed file was not loaded on this page.)
