From 7d0495eca08ed8debafd34c060b7503184b505bc Mon Sep 17 00:00:00 2001
From: zpcore
Date: Wed, 22 May 2024 06:09:41 +0000
Subject: [PATCH] fix core aten ops

---
 .../torch_xla2/test/test_core_aten_ops.py     | 113 ++++++++++++++++
 experimental/torch_xla2/test/test_ops.py      |   1 -
 .../torch_xla2/torch_xla2/ops/jaten.py        | 126 +++++++++---------
 3 files changed, 176 insertions(+), 64 deletions(-)

diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index 6a1cef306be..175ff01eb03 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -41,6 +41,8 @@ def run_export_and_compare(testcase,
       with testcase.env:
         res2 = func(*args2, **kwargs2)
       res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
+      print(res)
+      print(res2)
       # import pdb; pdb.set_trace()
       with testcase.subTest("torch_xla2_diff:" + str(atol)):
         if ignore_indices and isinstance(res, tuple) and len(res) == 2:
@@ -2697,6 +2699,117 @@ def test_aten_native_layer_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

+  def test_aten_native_batch_norm_legit(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 2, 2)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        False,
+        0.5,
+        1,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 4)).to(torch.float32),
+        None,
+        None,
+        torch.ones(channel),
+        torch.zeros(channel),
+        False,
+        0.5,
+        1,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_training_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        None,
+        None,
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
+
+  def test_aten_native_batch_norm_legit_no_training(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)
+
+  def test_aten_native_batch_norm_training(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.1,
+        1e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
+  def test_aten_native_batch_norm_training_none(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        None,
+        None,
+        torch.zeros(channel),
+        torch.ones(channel),
+        True,
+        0.1,
+        1e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
+  def test_aten_native_batch_norm_eval(self):
+    batch = 3
+    channel = 2
+    args = (
+        torch.randn((batch, channel, 4, 3)).to(torch.float32),
+        torch.ones(channel),
+        torch.zeros(channel),
+        torch.zeros(channel),
+        torch.ones(channel),
+        False,
+        0.2,
+        2e-5,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)
+
   def test_aten_ne_Scalar_0(self):
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 1c6a9d77785..1e10706f100 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -13,7 +13,6 @@
     "__getitem__",
     "__rmatmul__",
     "__rpow__",
-    "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
     "argsort",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 6dd94c8e42d..469a2fe9a13 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -546,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
 def _aten__native_batch_norm_legit(
     input, weight, bias, running_mean, running_var, training, momentum, eps
 ):
-  return _aten__native_batch_norm_legit_no_training(
-      input, weight, bias, running_mean, running_var, momentum, eps
-  )
+  """JAX implementation of batch normalization with optional weight and bias.
+  Reference: https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
+
+  Args:
+    input (DeviceArray): Input data (N, C, H, W).
+    weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None.
+    bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None.
+    running_mean (DeviceArray): Running mean of input (C,).
+    running_var (DeviceArray): Running variance of input (C,).
+    training (bool): If True, normalize with batch statistics;
+      if False, normalize with the running statistics.
+    momentum (float): Momentum factor for updating running statistics.
+    eps (float): Small constant for numerical stability.
+
+  Returns:
+    DeviceArray: Normalized output.
+    DeviceArray: Batch mean (C,), or empty if training is False.
+    DeviceArray: Batch rstd, i.e. 1/sqrt(var + eps) (C,), or empty if training is False.
+  """
+  reduction_dims = [0] + list(range(2, input.ndim))
+  reshape_dims = [1, -1] + [1] * (input.ndim - 2)
+
+  if training:
+    # Calculate batch mean and variance.
+    mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
+    saved_mean = jnp.squeeze(mean, reduction_dims)
+    var = jnp.var(input, axis=reduction_dims)
+    rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
+    # Update running statistics using momentum (computed locally; not returned by this op).
+    running_mean = (1 - momentum) * running_mean + momentum * saved_mean
+    running_var = (1 - momentum) * running_var + momentum * var
+    saved_rstd = jnp.squeeze(rstd, reduction_dims)
+  else:
+    rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
+    saved_mean = jnp.array([])  # No need to compute batch statistics in inference mode.
+    saved_rstd = jnp.array([])
+
+  # Normalize.
+  if training:
+    # Use batch statistics in training mode.
+    x_hat = (input - mean) * rstd
+  else:
+    # Use running statistics in inference mode.
+    x_hat = (input - running_mean.reshape(reshape_dims)) * rstd
+
+  # Scale and shift.
+  if weight is not None:
+    x_hat *= weight.reshape(reshape_dims)  # Reshape weight for broadcasting.
+  if bias is not None:
+    x_hat += bias.reshape(reshape_dims)  # Reshape bias for broadcasting.
+
+  return x_hat, saved_mean, saved_rstd
+


 @op(torch.ops.aten._native_batch_norm_legit_no_training)
 def _aten__native_batch_norm_legit_no_training(
     input, weight, bias, running_mean, running_var, momentum, eps
 ):
-  if weight is None:
-    weight = jnp.ones_like(running_mean)
-  if bias is None:
-    bias = jnp.zeros_like(running_mean)
-
-  def broadcast(t):
-    return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,))
-
-  if running_mean is not None:
-    a = input - broadcast(running_mean)
-  else:
-    a = input
-  if running_var is not None:
-    b = broadcast(jnp.sqrt(running_var + eps))
-  else:
-    b = broadcast(jnp.sqrt(eps))
-  return (
-      a / b * broadcast(weight) + broadcast(bias),
-      jnp.array([]),
-      jnp.array([]),
+  return _aten__native_batch_norm_legit(
+      input, weight, bias, running_mean, running_var, False, momentum, eps
   )


@@ -1953,45 +1985,13 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):

 @op(torch.ops.aten.native_batch_norm)
 def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
-  """JAX implementation of batch normalization.
-
-  Args:
-    input: Input data (N, C, ...)
-    weight: Scaling factor (gamma) (C,), can be None
-    bias: Shift factor (beta) (C,), can be None
-    running_mean: Running mean of input (C,)
-    running_var: Running variance of input (C,)
-    training: Whether to perform training-time or inference-time batch norm
-    momentum: Momentum factor for updating running mean and variance
-    eps: Small constant added to the variance to avoid division by zero
-
-  Returns:
-    Output data, updated running mean, updated running var
-  """
-
+
+  if running_mean is None:
+    running_mean = jnp.zeros(input.shape[1])  # Initialize running mean if None
+  if running_var is None:
+    running_var = jnp.ones(input.shape[1])  # Initialize running variance if None
+
   if training:
-    # Training-time batch norm: compute statistics across the batch
-    mean = jnp.mean(input, axis=(0, 2, 3))
-    var = jnp.var(input, axis=(0, 2, 3))
-
-    # Update running statistics
-    running_mean = momentum * mean + (1 - momentum) * running_mean
-    running_var = momentum * var + (1 - momentum) * running_var
-
+    return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
   else:
-    # Inference-time batch norm: use pre-computed running statistics
-    mean = running_mean
-    var = running_var
-
-  # Normalize
-  xmu = input - mean.reshape(1, -1, 1, 1)  # Broadcast mean across batch
-  ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1)  # Reciprocal of square root
-
-  # Scale and shift
-  out = xmu * ivar
-  if weight is not None:
-    out *= weight.reshape(1, -1, 1, 1)
-  if bias is not None:
-    out += bias.reshape(1, -1, 1, 1)
-
-  return out, running_mean, running_var
+    return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
\ No newline at end of file
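
Note: the decomposition added to jaten.py above can be sanity-checked outside of
torch_xla2 with the small standalone JAX sketch below. The helper name
batch_norm_legit and the __main__ harness are illustrative only and not part of
this patch; the sketch mirrors _aten__native_batch_norm_legit under the
assumption that running statistics are given as jnp arrays of shape (C,).

# Standalone sketch of the batch-norm decomposition used above (illustrative only).
import jax
import jax.numpy as jnp


def batch_norm_legit(x, weight, bias, running_mean, running_var,
                     training, momentum, eps):
  # Reduce over every dimension except the channel dimension (dim 1).
  reduction_dims = [0] + list(range(2, x.ndim))
  # Shape used to broadcast per-channel statistics back over x.
  reshape_dims = [1, -1] + [1] * (x.ndim - 2)

  if training:
    mean = jnp.mean(x, axis=reduction_dims, keepdims=True)
    var = jnp.var(x, axis=reduction_dims)
    rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
    saved_mean = jnp.squeeze(mean, tuple(reduction_dims))
    saved_rstd = jnp.squeeze(rstd, tuple(reduction_dims))
    x_hat = (x - mean) * rstd
  else:
    rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
    saved_mean = jnp.array([])
    saved_rstd = jnp.array([])
    x_hat = (x - running_mean.reshape(reshape_dims)) * rstd

  if weight is not None:
    x_hat = x_hat * weight.reshape(reshape_dims)
  if bias is not None:
    x_hat = x_hat + bias.reshape(reshape_dims)
  return x_hat, saved_mean, saved_rstd


if __name__ == "__main__":
  x = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 4, 4))
  c = x.shape[1]
  out, mean, rstd = batch_norm_legit(
      x, jnp.ones(c), jnp.zeros(c), jnp.zeros(c), jnp.ones(c),
      training=True, momentum=0.1, eps=1e-5)
  # In training mode the output is normalized per channel:
  # per-channel mean is ~0 and per-channel variance is ~1.
  print(out.shape, mean.shape, rstd.shape)
  print(jnp.mean(out, axis=(0, 2, 3)), jnp.var(out, axis=(0, 2, 3)))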