JIT-compatible equivalent to segment_sum that concatenates instead of summing? #25183
-
Hi, I am new to (thinking in) JAX and am trying to implement a function `def segment_concatenate(data, segment_ids, num_segments, max_duplicate_indices, padding_value):` where the output is an array of shape `(num_segments, max_duplicate_indices)` whose row `i` holds the elements of `data` with segment ID `i`, padded with `padding_value`. Is it possible to implement this in a JIT-compatible way? The output shape is static, but I have not yet come up with a solution that avoids intermediate variable-shaped arrays.
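To make the intended semantics concrete, here is a small illustration contrasting `jax.ops.segment_sum`, which reduces each segment to a single value, with the padded layout the question asks for. The input values are made up for illustration, and the "desired" output assumes each row keeps its segment's elements in input order with `padding_value = -1`.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Illustrative inputs: 4 elements spread over 3 segments.
data = jnp.array([10, 20, 30, 40])
segment_ids = jnp.array([0, 1, 0, 2])

# segment_sum collapses each segment to one value: shape (num_segments,).
summed = segment_sum(data, segment_ids, num_segments=3)
# summed -> [40, 20, 40]

# A segment_concatenate(data, segment_ids, 3, 2, -1) would instead keep every
# value, padded to a fixed width: shape (num_segments, max_duplicate_indices).
# [[10, 30],
#  [20, -1],
#  [40, -1]]
```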
-
This function does what I want and is JIT-compatible, but I am not sure if the combination of …

```python
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
from jaxtyping import Array, Integer, Shaped


# num_segments and max_duplicate_indexes determine the output shape, so they must be static.
@partial(jax.jit, static_argnums=(2, 3))
def segment_concatenate(
    data: Shaped[Array, "E"],
    segment_ids: Integer[Array, "E"],
    num_segments: int,
    max_duplicate_indexes: int,
    padding_value: Any,
) -> Shaped[Array, "{num_segments} {max_duplicate_indexes}"]:
    # Start from an output filled entirely with the padding value.
    output = jnp.full(
        (num_segments, max_duplicate_indexes), padding_value, dtype=data.dtype
    )
    # Number of elements already written to each segment, i.e. its next free column.
    segment_id_counts = jnp.zeros(num_segments, dtype=jnp.int32)
    # Iterating over a traced array unrolls this Python loop at trace time,
    # producing one pair of scatter updates per element of segment_ids.
    for i, segment_id in enumerate(segment_ids):
        output = output.at[segment_id, segment_id_counts[segment_id]].set(data[i])
        segment_id_counts = segment_id_counts.at[segment_id].add(1)
    return output
```
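Because iterating over `segment_ids` inside `jit` unrolls the Python loop, the traced program grows linearly with the number of elements. A minimal sketch that keeps the same sequential logic but traces the loop body only once swaps in `jax.lax.fori_loop` (a different technique from the code above; the name `segment_concatenate_fori` is hypothetical). It assumes every segment ID lies in `[0, num_segments)` and passes `mode="drop"` explicitly so that elements beyond `max_duplicate_indexes` in a segment are discarded.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(2, 3))
def segment_concatenate_fori(data, segment_ids, num_segments, max_duplicate_indexes, padding_value):
    output = jnp.full((num_segments, max_duplicate_indexes), padding_value, dtype=data.dtype)
    counts = jnp.zeros(num_segments, dtype=jnp.int32)

    def body(i, carry):
        output, counts = carry
        seg = segment_ids[i]
        # Write data[i] into the next free column of its segment's row;
        # out-of-bounds columns (segment overflow) are dropped.
        output = output.at[seg, counts[seg]].set(data[i], mode="drop")
        counts = counts.at[seg].add(1)
        return output, counts

    # The loop body is traced once; the trip count is the static length of segment_ids.
    output, _ = jax.lax.fori_loop(0, segment_ids.shape[0], body, (output, counts))
    return output
```

The loop still runs sequentially on device, so this mainly addresses compile time rather than runtime.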
This looks like a decent enough approach, and it should be efficient as long as `segment_ids` is quite small. If `segment_ids` is longer, the Python `for` loop is unrolled (flattened) during tracing, which could lead to prohibitively long compilation times; I'd suggest rewriting that part in a vectorized manner if possible. It seems like it would effectively be a generalization of some of the approaches used in the implementation of `jnp.unique`.
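A minimal sketch of such a vectorized rewrite, under stated assumptions: a stable sort groups the elements by segment while preserving their input order (sorting-based, similar in spirit to `jnp.unique`, though not its actual implementation), `jnp.bincount` gives the size of each segment, and each element's column is its offset from the start of its segment in the sorted order. The name `segment_concatenate_vectorized` is hypothetical; it assumes all segment IDs lie in `[0, num_segments)` and uses `mode="drop"` to discard elements past `max_duplicate_indexes`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(2, 3))
def segment_concatenate_vectorized(data, segment_ids, num_segments, max_duplicate_indexes, padding_value):
    # Stable sort: equal segment IDs become contiguous, input order is kept within each segment.
    order = jnp.argsort(segment_ids, stable=True)
    sorted_ids = segment_ids[order]
    sorted_data = data[order]

    # Size of each segment and the index where it starts in the sorted array.
    counts = jnp.bincount(segment_ids, length=num_segments)
    starts = jnp.cumsum(counts) - counts

    # Column of each sorted element = its rank within its own segment.
    positions = jnp.arange(segment_ids.shape[0]) - starts[sorted_ids]

    output = jnp.full((num_segments, max_duplicate_indexes), padding_value, dtype=data.dtype)
    # (segment, column) pairs are unique, so there are no write conflicts.
    return output.at[sorted_ids, positions].set(sorted_data, mode="drop")
```

Every intermediate here has a static shape, so the traced program does not grow with the length of `segment_ids`; the trade-off is an O(E log E) sort instead of E sequential scatter updates.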