Skip to content

Commit

Permalink
feat(jax-backend) - Implementing native jax while loop in nms body (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mobley-trent authored Feb 26, 2024
1 parent 7ba2e79 commit e1c93cb
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion ivy/functional/backends/jax/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ def nms(
keep = jnp.zeros((size,), dtype=jnp.int64)
keep_idx = 0

while jnp.unique(order).size > 1:
def body_fn(loop_vars):
keep, keep_idx, boxes, areas, order = loop_vars
max_iou_idx = order[0]
keep = keep.at[keep_idx].set(max_iou_idx)
keep_idx += 1
Expand All @@ -522,6 +523,15 @@ def nms(
boxes = boxes.at[forward].set(boxes[forward[::-1]])
areas = areas.at[forward].set(areas[forward[::-1]])

return keep, keep_idx, boxes, areas, order

def cond_fn(loop_vars):
_, _, _, _, order = loop_vars
return jnp.min(order) != jnp.max(order)

init_vars = (keep, keep_idx, boxes, areas, order)
keep, keep_idx, boxes, _, _ = jlax.while_loop(cond_fn, body_fn, init_vars)

ret = jnp.array(keep[:keep_idx], dtype=jnp.int64)

if len(ret) > 1 and scores is not None:
Expand Down

0 comments on commit e1c93cb

Please sign in to comment.