diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index 6a1cef306be..c11884fa370 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -2697,6 +2697,117 @@ def test_aten_native_layer_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

+  def test_aten_native_batch_norm_legit(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 2, 2)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        False,
+        0.5,
+        1,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 4)).to(torch.float32),
+        None,
+        None,
+        torch.ones(channel),
+        torch.zeros(channel),
+        False,
+        0.5,
+        1,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_training_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        None,
+        None,
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_no_training(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)
+
+  def test_aten_native_batch_norm_training(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.1,
+        1e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
+  def test_aten_native_batch_norm_training_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        None,
+        None,
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.1,
+        1e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
+  def test_aten_native_batch_norm_eval(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        False,
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
   def test_aten_ne_Scalar_0(self):
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 20686f2fe6c..1e10706f100 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -13,7 +13,6 @@
     "__getitem__",
     "__rmatmul__",
     "__rpow__",
-    "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
     "argsort",
@@ -198,7 +197,6 @@
     "nansum",
     "narrow_copy",
     "narrow",
-    "native_batch_norm",
     "native_layer_norm",
     "new_empty",
     "new_empty_strided",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index a628a648441..c5ca628908f 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -30,6 +30,7 @@
     torch.ops.aten.eq_: torch.ops.aten.eq,
     torch.ops.aten.ne_: torch.ops.aten.ne,
     torch.ops.aten.uniform_: torch.ops.aten.uniform,
+    torch.ops.aten.relu_: torch.ops.aten.relu,
 }

@@ -545,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
 def _aten__native_batch_norm_legit(
     input, weight, bias, running_mean, running_var, training, momentum, eps
 ):
-  return _aten__native_batch_norm_legit_no_training(
-      input, weight, bias, running_mean, running_var, momentum, eps
-  )
+  """JAX implementation of batch normalization with optional parameters.
+  Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
+
+  Args:
+    input (DeviceArray): Input data (N, C, H, W).
+    weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None.
+    bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None.
+    running_mean (DeviceArray): Running mean of input (C,).
+    running_var (DeviceArray): Running variance of input (C,).
+    training (bool): If True, normalize with batch statistics;
+      if False, use the running statistics.
+    momentum (float): Momentum factor for updating running statistics.
+    eps (float): Small constant for numerical stability.
+
+  Returns:
+    DeviceArray: Normalized output.
+    DeviceArray: Batch mean (C,), or an empty array if training is False.
+    DeviceArray: Reciprocal of the batch standard deviation (rstd) (C,), or an
+      empty array if training is False.
+  """
+  reduction_dims = [0] + list(range(2, input.ndim))
+  reshape_dims = [1, -1] + [1] * (input.ndim - 2)
+
+  if training:
+    # Calculate batch mean and variance
+    mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
+    saved_mean = jnp.squeeze(mean, reduction_dims)
+    var = jnp.var(input, axis=reduction_dims)
+    rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
+    # Update running statistics using momentum
+    running_mean = (1 - momentum) * running_mean + momentum * saved_mean
+    running_var = (1 - momentum) * running_var + momentum * var
+    saved_rstd = jnp.squeeze(rstd, reduction_dims)
+  else:
+    rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
+    # No need to calculate batch statistics in inference mode
+    saved_mean = jnp.array([])
+    saved_rstd = jnp.array([])
+
+  # Normalize
+  if training:
+    # Use batch statistics in training mode
+    x_hat = (input - mean) * rstd
+  else:
+    # Use running statistics in inference mode
+    x_hat = (input - running_mean.reshape(reshape_dims)) * rstd
+
+  # Scale and shift
+  if weight is not None:
+    x_hat *= weight.reshape(reshape_dims)  # Reshape weight for broadcasting
+  if bias is not None:
+    x_hat += bias.reshape(reshape_dims)  # Reshape bias for broadcasting
+
+  return x_hat, saved_mean, saved_rstd
+

 @op(torch.ops.aten._native_batch_norm_legit_no_training)
 def _aten__native_batch_norm_legit_no_training(
     input, weight, bias, running_mean, running_var, momentum, eps
 ):
-  if weight is None:
-    weight = jnp.ones_like(running_mean)
-  if bias is None:
-    bias = jnp.zeros_like(running_mean)
-
-  def broadcast(t):
-    return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,))
-
-  if running_mean is not None:
-    a = input - broadcast(running_mean)
-  else:
-    a = input
-  if running_var is not None:
-    b = broadcast(jnp.sqrt(running_var + eps))
-  else:
-    b = broadcast(jnp.sqrt(eps))
-  return (
-      a / b * broadcast(weight) + broadcast(bias),
-      jnp.array([]),
-      jnp.array([]),
+  return _aten__native_batch_norm_legit(
+      input, weight, bias, running_mean, running_var, False, momentum, eps
   )
@@ -1950,3 +1983,15 @@ def _aten_outer(a, b):

 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
   return jnp.allclose(input, other, rtol, atol, equal_nan)
+
+
+@op(torch.ops.aten.native_batch_norm)
+def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
+  if running_mean is None:
+    running_mean = jnp.zeros(input.shape[1])  # Initialize running mean if None
+  if running_var is None:
+    running_var = jnp.ones(input.shape[1])  # Initialize running variance if None
+
+  if training:
+    return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
+  else:
+    return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
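
For reference, the snippet below is a minimal standalone JAX sketch, not part of the diff, that mirrors the training-mode arithmetic of _aten__native_batch_norm_legit: per-channel statistics are reduced over the batch and spatial axes, reshaped to (1, C, 1, 1) for broadcasting, and returned alongside the normalized output as saved_mean and saved_rstd. The helper name batch_norm_training and the shapes are illustrative only.

# Standalone sketch (not part of the diff); mirrors the training branch above.
import jax
import jax.numpy as jnp

def batch_norm_training(x, weight, bias, eps=1e-5):
  # Reduce over batch and spatial axes; keep the channel axis (axis 1).
  reduction_dims = [0] + list(range(2, x.ndim))
  # Broadcast shape (1, C, 1, 1) for per-channel statistics.
  reshape_dims = [1, -1] + [1] * (x.ndim - 2)
  mean = jnp.mean(x, axis=reduction_dims, keepdims=True)
  var = jnp.var(x, axis=reduction_dims)
  rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
  x_hat = (x - mean) * rstd
  if weight is not None:
    x_hat = x_hat * weight.reshape(reshape_dims)
  if bias is not None:
    x_hat = x_hat + bias.reshape(reshape_dims)
  return x_hat, jnp.squeeze(mean, reduction_dims), jnp.squeeze(rstd, reduction_dims)

x = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 4, 4))
y, saved_mean, saved_rstd = batch_norm_training(x, jnp.ones(2), jnp.zeros(2))
print(jnp.mean(y, axis=(0, 2, 3)))  # approximately 0 per channel
print(jnp.var(y, axis=(0, 2, 3)))   # approximately 1 per channel (up to eps)

With weight of ones and bias of zeros the output is standardized per channel; the new test_aten_native_batch_norm_* tests exercise the same code path end to end through run_export_and_compare.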