unsafe_rbg + vmap --> 10x slow down #16792

Closed
dlwh opened this issue Jul 19, 2023 · 1 comment · Fixed by #20094
Labels: bug (Something isn't working)

dlwh (Contributor) commented Jul 19, 2023

Description

unsafe_rbg is advertised as the solution to RNG performance issues, but it causes a surprising pessimization in the presence of vmap: more than 10x in this script, and more like 4x in real use.

The key is vmap(loss), where loss calls the RNG.

This is with LIBTPU_INIT_ARGS=--xla_tpu_spmd_rng_bit_generator_unsafe=true
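
For context, here is a stripped-down sketch of just that pattern (illustrative only; per_example and batched are made-up names, and this is not the script that produced the numbers below): a per-example function that draws a dropout mask from its own key, vmapped over a batch of split keys.

import jax
import jax.numpy as jnp

jax.config.update("jax_default_prng_impl", "unsafe_rbg")  # flip this to compare against threefry

def per_example(x, key):
    # each example draws its own dropout mask from its own key
    mask = jax.random.bernoulli(key, 0.1, shape=x.shape)
    return jnp.where(mask, 0.0, x).mean()

@jax.jit
def batched(xs, key):
    keys = jax.random.split(key, xs.shape[0])  # one key per example
    return jax.vmap(per_example)(xs, keys).mean()

xs = jnp.ones((256, 2048))
print(batched(xs, jax.random.PRNGKey(0)).block_until_ready())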

Times (seconds per iteration, excluding the first, compile-heavy step):

  • vmap, no dropout: 0.057
  • vmap, dropout, threefry: 0.092
  • vmap, dropout, unsafe_rbg: 0.736 (!!!!!)
  • no vmap, no dropout: 0.057
  • no vmap, dropout, threefry: 0.088
  • no vmap, dropout, unsafe_rbg: 0.075
import time

import jax
import jax.numpy as jnp
import numpy as onp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


batch_size = 256
seq_len = 2048
embed_size = 256
vocab_size = 2000
num_layers = 20

pdrop = 0.1
USE_VMAP = True
USE_UNSAFE_RBG = True

mesh = Mesh(onp.array(jax.devices()), ("dp",))


if USE_UNSAFE_RBG:
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
else:
    jax.config.update("jax_threefry_partitionable", True)

with mesh:
    key = jax.random.PRNGKey(0)

    def model(tokens, key):
        embed = jnp.take(jnp.ones((vocab_size, embed_size)), tokens, axis=0)
        # dumb fake gpt2 attn
        for i in range(0, num_layers):
            attn = jnp.einsum("...ld,...kd->...lk", embed, embed)

            if pdrop > 0.0:
                # dropout on the attention scores -- this per-example RNG call is what triggers the slowdown
                key, subkey = jax.random.split(key)
                dout = jax.random.bernoulli(subkey, pdrop, shape=attn.shape)
                attn = jnp.where(dout, jnp.zeros_like(attn), attn)

            attn = jax.nn.softmax(attn, axis=-1)
            embed = jnp.einsum("...ld,...lk->...kd", attn, embed)

        out = jnp.einsum("...ld,...kd->...lk", embed, jnp.ones((vocab_size, embed_size)))

        return out

    def compute_loss(example, key):
        pred_y = model(example, key=key)
        return jnp.mean(pred_y)

    def compute_loss_vmap(examples, key):
        # one key per example, so each vmapped call draws independent dropout masks
        key = jax.random.split(key, batch_size)
        per_ex_loss = jax.vmap(compute_loss)(examples, key)
        return jnp.mean(per_ex_loss)

    if USE_VMAP:
        compute_loss_pjit = jax.jit(compute_loss_vmap)
    else:
        compute_loss_pjit = jax.jit(compute_loss)

    # i still honestly find the way to turn a "replicated" array like batch into a sharded array to be a bit confusing
    batch = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    batch = jax.make_array_from_callback(
        (batch_size, seq_len), NamedSharding(mesh, PartitionSpec("dp", None)), lambda idx: batch[idx]
    )

    total_loss = 0.0
    total_time = 0.0

    num_steps = 20
    for n in range(num_steps):
        this_key, key = jax.random.split(key)
        time_in = time.time()
        loss = compute_loss_pjit(batch, this_key)

        total_loss += loss.item()  # .item() blocks, so time_out includes device execution
        time_out = time.time()

        if n > 0:  # skip the first iteration, which includes compilation
            total_time += time_out - time_in

    print(f"eval loss: {total_loss / num_steps:.3f}")
    print(f"eval time: {total_time / (num_steps - 1):.3f}")

What jax/jaxlib version are you using?

0.4.13

Which accelerator(s) are you using?

TPU

Additional system info

v3-32

NVIDIA GPU info

No response
