Skip to content

Commit

Permalink
fix lax dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed May 20, 2024
1 parent 395d291 commit cef02c1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,7 +1985,7 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai

# Normalize
xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch
ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root
ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root

# Scale and shift
out = xmu * ivar
Expand Down

0 comments on commit cef02c1

Please sign in to comment.