From b64268da34a8c2756ae894c894aa94d2885d4d68 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:10:35 +0000 Subject: [PATCH 1/9] add missing aten op --- .../torch_xla2/torch_xla2/ops/jaten.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a628a648441..a4e421e1faf 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -30,6 +30,7 @@ torch.ops.aten.eq_: torch.ops.aten.eq, torch.ops.aten.ne_: torch.ops.aten.ne, torch.ops.aten.uniform_: torch.ops.aten.uniform, + torch.ops.aten.relu_: torch.ops.aten.relu, } @@ -1950,3 +1951,47 @@ def _aten_outer(a, b): def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(input, other, rtol, atol, equal_nan) +@op(torch.ops.aten.native_batch_norm) +def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5): + """JAX implementation of batch normalization. + + Args: + input: Input data (N, C, ...) + running_mean: Running mean of input (C,) + running_var: Running variance of input (C,) + weight: Optional scaling factor (gamma) (C,) + bias: Optional shift factor (beta) (C,) + training: Whether to perform training-time or inference-time batch norm + momentum: Momentum factor for updating running mean and variance + eps: Small constant added to the variance to avoid division by zero + + Returns: + Output data, updated running mean, updated running var + """ + + if training: + # Training-time batch norm: compute statistics across the batch + mean = jnp.mean(input, axis=(0, 2, 3)) + var = jnp.var(input, axis=(0, 2, 3)) + + # Update running statistics + running_mean = momentum * mean + (1 - momentum) * running_mean + running_var = momentum * var + (1 - momentum) * running_var + + else: + # Inference-time batch norm: use pre-computed running statistics + mean = running_mean + var = running_var + + # Normalize + xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch + ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root + + # Scale and shift + out = xmu * ivar + if weight is not None: + out *= weight.reshape(1, -1, 1, 1) + if bias is not None: + out += bias.reshape(1, -1, 1, 1) + + return out, running_mean, running_var \ No newline at end of file From 4e19dfba8cbd449f498656df7a8b5745e39ece3c Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:12:31 +0000 Subject: [PATCH 2/9] nit update --- .../torch_xla2/torch_xla2/ops/jaten.py | 387 +++++++++--------- 1 file changed, 191 insertions(+), 196 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a4e421e1faf..ea6d4351c5d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -18,19 +18,19 @@ # and need to be implemented in jax mutation_ops_to_functional = { - torch.ops.aten.add_: torch.ops.aten.add, - torch.ops.aten.sub_: torch.ops.aten.sub, - torch.ops.aten.mul_: torch.ops.aten.mul, - torch.ops.aten.div_: torch.ops.aten.div, - torch.ops.aten.pow_: torch.ops.aten.pow, - torch.ops.aten.lt_: torch.ops.aten.lt, - torch.ops.aten.le_: torch.ops.aten.le, - torch.ops.aten.gt_: torch.ops.aten.gt, - torch.ops.aten.ge_: torch.ops.aten.ge, - torch.ops.aten.eq_: torch.ops.aten.eq, - torch.ops.aten.ne_: torch.ops.aten.ne, - torch.ops.aten.uniform_: 
torch.ops.aten.uniform, - torch.ops.aten.relu_: torch.ops.aten.relu, + torch.ops.aten.add_: torch.ops.aten.add, + torch.ops.aten.sub_: torch.ops.aten.sub, + torch.ops.aten.mul_: torch.ops.aten.mul, + torch.ops.aten.div_: torch.ops.aten.div, + torch.ops.aten.pow_: torch.ops.aten.pow, + torch.ops.aten.lt_: torch.ops.aten.lt, + torch.ops.aten.le_: torch.ops.aten.le, + torch.ops.aten.gt_: torch.ops.aten.gt, + torch.ops.aten.ge_: torch.ops.aten.ge, + torch.ops.aten.eq_: torch.ops.aten.eq, + torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.uniform_: torch.ops.aten.uniform, + torch.ops.aten.relu_: torch.ops.aten.relu, } @@ -40,11 +40,11 @@ def make_mutation(op): for op in mutation_ops_to_functional.keys(): ops_registry.register_torch_dispatch_op( - op, make_mutation(op), is_jax_function=False - ) + op, make_mutation(op), is_jax_function=False) def op(*aten, **kwargs): + def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) @@ -54,10 +54,10 @@ def inner(func): @op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, ) def _aten_unsafe_view(x, shape): return jnp.reshape(x, shape) @@ -277,6 +277,7 @@ def _aten_rsqrt(x): @op(torch.ops.aten.expand) @op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): + def fix_dims(d, xs): if d == -1: return xs @@ -350,8 +351,8 @@ def make_range(rank, dim, start, end): return tuple(res) return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) ] @@ -384,9 +385,11 @@ def _aten_cumsum(x, y, dtype=None): @op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm( - input, normalized_shape, weight=None, bias=None, eps=1e-5 -): +def _aten_native_layer_norm(input, + normalized_shape, + weight=None, + bias=None, + eps=1e-5): """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. 
Args: @@ -437,9 +440,8 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): alpha = jnp.array(alpha).astype(batch1.dtype) beta = jnp.array(beta).astype(batch1.dtype) mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond( - beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm - ) + return jax.lax.cond(beta == 0, lambda: alpha * mm, + lambda: beta * input + alpha * mm) @op(torch.ops.aten.gelu) @@ -487,15 +489,15 @@ def fix_dim(p): @op(torch.ops.aten.convolution) def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, ): if transposed: raise NotImplementedError("Transposed convolution is not implemented.") @@ -516,19 +518,18 @@ def create_default_conv_dimension_numbers(num_spatial_dims): rhs_spec.append(i + 2) out_spec.append(i + 2) return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec)) - ) + *map(tuple, (lhs_spec, rhs_spec, out_spec))) res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, + input, + weight, + stride, + make_padding(padding), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, ) if bias is not None: @@ -543,18 +544,17 @@ def create_default_conv_dimension_numbers(num_spatial_dims): # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) 
running_var, bool training, float momentum, float eps) @op(torch.ops.aten._native_batch_norm_legit) -def _aten__native_batch_norm_legit( - input, weight, bias, running_mean, running_var, training, momentum, eps -): - return _aten__native_batch_norm_legit_no_training( - input, weight, bias, running_mean, running_var, momentum, eps - ) +def _aten__native_batch_norm_legit(input, weight, bias, running_mean, + running_var, training, momentum, eps): + return _aten__native_batch_norm_legit_no_training(input, weight, bias, + running_mean, running_var, + momentum, eps) @op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training( - input, weight, bias, running_mean, running_var, momentum, eps -): +def _aten__native_batch_norm_legit_no_training(input, weight, bias, + running_mean, running_var, + momentum, eps): if weight is None: weight = jnp.ones_like(running_mean) if bias is None: @@ -572,9 +572,9 @@ def broadcast(t): else: b = broadcast(jnp.sqrt(eps)) return ( - a / b * broadcast(weight) + broadcast(bias), - jnp.array([]), - jnp.array([]), + a / b * broadcast(weight) + broadcast(bias), + jnp.array([]), + jnp.array([]), ) @@ -590,9 +590,12 @@ def _aten_cat(tensors, dims=0): @op(torch.ops.aten.max_pool2d_with_indices) @op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices( - inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False -): +def _aten_max_pool2d_with_indices(inputs, + kernel_size, + strides, + padding=0, + dilation=1, + ceil_mode=False): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) strides = tuple(strides) @@ -605,8 +608,7 @@ def _aten_max_pool2d_with_indices( num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) assert len(window_shape) == len( - strides - ), f"len({window_shape}) must equal len({strides})" + strides), f"len({window_shape}) must equal len({strides})" strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + window_shape @@ -623,12 +625,10 @@ def _aten_max_pool2d_with_indices( if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}" - ) - assert all( - [len(x) == 2 for x in padding] - ), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}") + assert all([len(x) == 2 for x in padding + ]), f"each entry in padding {padding} must be length 2" padding = ((0, 0), (0, 0)) + padding indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) @@ -644,17 +644,15 @@ def reduce_fn(a, b): init_val = -(1 << 31) init_val = jnp.array(init_val).astype(inputs.dtype) - indices, y = jax.lax.reduce_window( - (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding - ) + indices, y = jax.lax.reduce_window((indices, inputs), (0, init_val), + reduce_fn, dims, strides, padding) if is_single_input: indices = jnp.squeeze(indices, axis=0) y = jnp.squeeze(y, axis=0) return y, indices - batch_result = pool( - inputs, -jnp.inf, jax.lax.max, kernel_size, strides, padding - ) + batch_result = pool(inputs, -jnp.inf, jax.lax.max, kernel_size, strides, + padding) indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) return batch_result, indices @@ -696,8 +694,7 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, 
out=None): @op(torch.ops.prims.broadcast_in_dim) def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions - ) + t, shape, broadcast_dimensions=broadcast_dimensions) # aten.native_group_norm -- should use decomp table @@ -737,17 +734,15 @@ def group_norm_body(x): # Function to apply within each group normalized = (x - mean) * rstd return normalized, mean, rstd - normalized, group_mean, group_rstd = jax.lax.map( - group_norm_body, reshaped_input - ) + normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, + reshaped_input) # Reshape back to original input shape output = jnp.reshape(normalized, input_shape) # **Affine transformation** - affine_shape = [ - -1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting + affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting if weight is not None and bias is not None: output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) elif weight is not None: @@ -781,13 +776,12 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): if ord not in {2, float("inf"), float("-inf"), "fro"}: raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'." - ) + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'.") # Special cases (for efficiency and clarity) if ord == 2: # Euclidean norm - result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) elif ord == float("inf"): result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) @@ -796,12 +790,11 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) elif ord == "fro": # Frobenius norm - result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) else: # General case (e.g., ord = 1, ord = 3) - result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** ( - 1.0 / ord - ) + result = jnp.sum( + jnp.abs(self)**ord, axis=dim, keepdims=keepdim)**(1.0 / ord) # (Optional) dtype conversion if dtype is not None: @@ -833,9 +826,12 @@ def _aten_sinh(self): # aten.native_layer_norm_backward @op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward( - grad_out, input, normalized_shape, weight, bias, eps=1e-5 -): +def _aten_native_layer_norm_backward(grad_out, + input, + normalized_shape, + weight, + bias, + eps=1e-5): """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. Args: @@ -849,9 +845,8 @@ def _aten_native_layer_norm_backward( Returns: A tuple of (grad_input, grad_weight, grad_bias). 
""" - return jax.lax.native_layer_norm_backward( - grad_out, input, normalized_shape, weight, bias, eps - ) + return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, + weight, bias, eps) # aten.reflection_pad3d_backward @@ -936,10 +931,8 @@ def _scatter_index(dim, index): target_shape = [1] * len(index_shape) target_shape[i] = index_shape[i] input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape - ) - ) + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) return tuple(input_indexes), tuple(source_indexes) @@ -1033,19 +1026,17 @@ def _aten_pixel_shuffle(x, upscale_factor): if channels % (upscale_factor**2) != 0: raise ValueError( - "Number of channels must be divisible by the square of the upscale factor." + "Number of channels must be divisible by the square of the upscale factor." ) new_channels = channels // (upscale_factor**2) new_height = height * upscale_factor new_width = width * upscale_factor - x = x.reshape( - batch_size, new_channels, upscale_factor, upscale_factor, height, width - ) - x = jnp.transpose( - x, (0, 1, 2, 4, 3, 5) - ) # Move channels to spatial dimensions + x = x.reshape(batch_size, new_channels, upscale_factor, upscale_factor, + height, width) + x = jnp.transpose(x, + (0, 1, 2, 4, 3, 5)) # Move channels to spatial dimensions x = x.reshape(batch_size, new_channels, new_height, new_width) return x @@ -1082,8 +1073,7 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) assert len(window_shape) == len( - strides - ), f"len({window_shape}) must equal len({strides})" + strides), f"len({window_shape}) must equal len({strides})" strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + window_shape @@ -1100,12 +1090,10 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}" - ) - assert all( - [len(x) == 2 for x in padding] - ), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}") + assert all([len(x) == 2 for x in padding + ]), f"each entry in padding {padding} must be length 2" padding = ((0, 0), (0, 0)) + padding y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) if is_single_input: @@ -1124,13 +1112,13 @@ def _aten_adaptive_avg_pool3d(x, output_shape): def _aten_adaptive_avg_pool(x, output_shape, pool_dim): + def adaptive_kernel_size(input_shape, output_shape): sizes = [1, 1] spatial_dim_off = len(input_shape) - pool_dim for spatial_dim in range(pool_dim): - sizes.append( - input_shape[spatial_dim_off + spatial_dim] // output_shape[spatial_dim] - ) + sizes.append(input_shape[spatial_dim_off + spatial_dim] // + output_shape[spatial_dim]) return tuple(sizes) kernel_sizes = adaptive_kernel_size(x.shape, output_shape) @@ -1143,8 +1131,8 @@ def adaptive_kernel_size(input_shape, output_shape): if len(div_shape) - 2 == len(kernel_sizes): div_shape = (1,) + div_shape[1:] y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, "VALID" - ) + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, + "VALID") return y @@ -1152,13 +1140,13 @@ def 
adaptive_kernel_size(input_shape, output_shape): @op(torch.ops.aten.avg_pool2d) @op(torch.ops.aten.avg_pool3d) def _aten_avg_pool( - inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, ): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) @@ -1178,8 +1166,7 @@ def _aten_avg_pool( if len(div_shape) - 2 == len(kernel_size): div_shape = (1,) + div_shape[1:] y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding - ) + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding) return y @@ -1232,9 +1219,9 @@ def _aten_round(input, decimals=0): # aten.max @op(torch.ops.aten.max) def _aten_max(self, dim=None, keepdim=False): - return jnp.max(self, axis=dim, keepdims=keepdim), jnp.argmax( - self, axis=dim, keepdims=keepdim - ) + return jnp.max( + self, axis=dim, keepdims=keepdim), jnp.argmax( + self, axis=dim, keepdims=keepdim) # aten.maximum @@ -1281,15 +1268,15 @@ def _aten_any(self, dim=None, keepdim=False): @op(torch.ops.aten.arange.start) @op(torch.ops.aten.arange.default) def _aten_arange( - start, - end=None, - step=1, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False, + start, + end=None, + step=1, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, ): if end is None: end = start @@ -1297,10 +1284,10 @@ def _aten_arange( if dtype: dtype = tensor.t2j_dtype(dtype) return jnp.arange( - start, - end, - step, - dtype=dtype, + start, + end, + step, + dtype=dtype, ) @@ -1389,9 +1376,8 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""): @op(torch.ops.aten._pdist_forward) def _aten__pdist_forward(x, p): pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[ - jnp.triu_indices(pairwise_dists.shape[0], k=1) - ] + condensed_dists = pairwise_dists[jnp.triu_indices( + pairwise_dists.shape[0], k=1)] return condensed_dists @@ -1602,7 +1588,6 @@ def _aten_prod(self, dim=None, keepdim=False): # aten.randperm - # aten.reflection_pad3d @@ -1644,8 +1629,8 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), ) @@ -1684,8 +1669,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if dim != -1 and dim != len(input.shape) - 1: transpose_shape = list(range(len(input.shape))) transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], + transpose_shape[-1], + transpose_shape[dim], ) input = jnp.transpose(input, transpose_shape) @@ -1694,8 +1679,7 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if sorted: values = jnp.sort(values, descending=True) indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 - ) + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) if not largest: values = -values # Negate values back if we found smallest @@ -1717,9 +1701,8 @@ def _aten_trunc(a): @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, 
dim=0): return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim]) - ) + _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) + for i in range(a.shape[dim])) # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d @@ -1738,9 +1721,11 @@ def _aten_where(condition, x, y): # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) -def _aten_to_dtype( - a, dtype, non_blocking=False, copy=False, memory_format=None -): +def _aten_to_dtype(a, + dtype, + non_blocking=False, + copy=False, + memory_format=None): if dtype: jaxdtype = tensor.t2j_dtype(dtype) return a.astype(jaxdtype) @@ -1753,15 +1738,17 @@ def _aten_to_dtype( @op(torch.ops.aten.var_mean.correction) def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): return ( - jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), - jnp.mean(self, dim, keepdims=keepdim), + jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), + jnp.mean(self, dim, keepdims=keepdim), ) @op(torch.ops.aten.scalar_tensor) -def _aten_scalar_tensor( - s, dtype=None, layout=None, device=None, pin_memory=None -): +def _aten_scalar_tensor(s, + dtype=None, + layout=None, + device=None, + pin_memory=None): if dtype is not None: dtype = tensor.t2j_dtype(dtype) return jnp.array(s, dtype=dtype) @@ -1774,9 +1761,9 @@ def _aten_to_device(x, device, dtype): @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom( - grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices -): +def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, + stride, padding, dilation, + ceil_mode, indices): """ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. @@ -1832,15 +1819,15 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0): @op(torch.ops.aten.randn, needs_env=True) @op_base.convert_dtype() def _randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -1855,15 +1842,15 @@ def _randn( @op(torch.ops.aten.rand, needs_env=True) @op_base.convert_dtype() def _rand( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -1887,9 +1874,9 @@ def _aten_to_device(x, device, dtype): @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom( - grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices -): +def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, + stride, padding, dilation, + ceil_mode, indices): """ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. 
@@ -1951,8 +1938,16 @@ def _aten_outer(a, b): def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(input, other, rtol, atol, equal_nan) + @op(torch.ops.aten.native_batch_norm) -def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5): +def _aten_native_batch_norm(input, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=1e-5): """JAX implementation of batch normalization. Args: @@ -1990,8 +1985,8 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai # Scale and shift out = xmu * ivar if weight is not None: - out *= weight.reshape(1, -1, 1, 1) + out *= weight.reshape(1, -1, 1, 1) if bias is not None: - out += bias.reshape(1, -1, 1, 1) + out += bias.reshape(1, -1, 1, 1) - return out, running_mean, running_var \ No newline at end of file + return out, running_mean, running_var From 7e43402d92a4cf17684f8fa9c96707c946ebabb1 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:13:27 +0000 Subject: [PATCH 3/9] revert back format... --- .../torch_xla2/torch_xla2/ops/jaten.py | 385 +++++++++--------- 1 file changed, 195 insertions(+), 190 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ea6d4351c5d..5f07f09086d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -18,19 +18,19 @@ # and need to be implemented in jax mutation_ops_to_functional = { - torch.ops.aten.add_: torch.ops.aten.add, - torch.ops.aten.sub_: torch.ops.aten.sub, - torch.ops.aten.mul_: torch.ops.aten.mul, - torch.ops.aten.div_: torch.ops.aten.div, - torch.ops.aten.pow_: torch.ops.aten.pow, - torch.ops.aten.lt_: torch.ops.aten.lt, - torch.ops.aten.le_: torch.ops.aten.le, - torch.ops.aten.gt_: torch.ops.aten.gt, - torch.ops.aten.ge_: torch.ops.aten.ge, - torch.ops.aten.eq_: torch.ops.aten.eq, - torch.ops.aten.ne_: torch.ops.aten.ne, - torch.ops.aten.uniform_: torch.ops.aten.uniform, - torch.ops.aten.relu_: torch.ops.aten.relu, + torch.ops.aten.add_: torch.ops.aten.add, + torch.ops.aten.sub_: torch.ops.aten.sub, + torch.ops.aten.mul_: torch.ops.aten.mul, + torch.ops.aten.div_: torch.ops.aten.div, + torch.ops.aten.pow_: torch.ops.aten.pow, + torch.ops.aten.lt_: torch.ops.aten.lt, + torch.ops.aten.le_: torch.ops.aten.le, + torch.ops.aten.gt_: torch.ops.aten.gt, + torch.ops.aten.ge_: torch.ops.aten.ge, + torch.ops.aten.eq_: torch.ops.aten.eq, + torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.uniform_: torch.ops.aten.uniform, + torch.ops.aten.relu_: torch.ops.aten.relu, } @@ -40,11 +40,11 @@ def make_mutation(op): for op in mutation_ops_to_functional.keys(): ops_registry.register_torch_dispatch_op( - op, make_mutation(op), is_jax_function=False) + op, make_mutation(op), is_jax_function=False + ) def op(*aten, **kwargs): - def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) @@ -54,10 +54,10 @@ def inner(func): @op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, ) def _aten_unsafe_view(x, shape): return jnp.reshape(x, shape) @@ -277,7 +277,6 @@ def _aten_rsqrt(x): @op(torch.ops.aten.expand) @op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): - def fix_dims(d, xs): if d == -1: return xs @@ -351,8 +350,8 
@@ def make_range(rank, dim, start, end): return tuple(res) return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) ] @@ -385,11 +384,9 @@ def _aten_cumsum(x, y, dtype=None): @op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. Args: @@ -440,8 +437,9 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): alpha = jnp.array(alpha).astype(batch1.dtype) beta = jnp.array(beta).astype(batch1.dtype) mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond(beta == 0, lambda: alpha * mm, - lambda: beta * input + alpha * mm) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) @op(torch.ops.aten.gelu) @@ -489,15 +487,15 @@ def fix_dim(p): @op(torch.ops.aten.convolution) def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, ): if transposed: raise NotImplementedError("Transposed convolution is not implemented.") @@ -518,18 +516,19 @@ def create_default_conv_dimension_numbers(num_spatial_dims): rhs_spec.append(i + 2) out_spec.append(i + 2) return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, + input, + weight, + stride, + make_padding(padding), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, ) if bias is not None: @@ -544,17 +543,18 @@ def create_default_conv_dimension_numbers(num_spatial_dims): # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) 
running_var, bool training, float momentum, float eps) @op(torch.ops.aten._native_batch_norm_legit) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) @op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): if weight is None: weight = jnp.ones_like(running_mean) if bias is None: @@ -572,9 +572,9 @@ def broadcast(t): else: b = broadcast(jnp.sqrt(eps)) return ( - a / b * broadcast(weight) + broadcast(bias), - jnp.array([]), - jnp.array([]), + a / b * broadcast(weight) + broadcast(bias), + jnp.array([]), + jnp.array([]), ) @@ -590,12 +590,9 @@ def _aten_cat(tensors, dims=0): @op(torch.ops.aten.max_pool2d_with_indices) @op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides, - padding=0, - dilation=1, - ceil_mode=False): +def _aten_max_pool2d_with_indices( + inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False +): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) strides = tuple(strides) @@ -608,7 +605,8 @@ def _aten_max_pool2d_with_indices(inputs, num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) assert len(window_shape) == len( - strides), f"len({window_shape}) must equal len({strides})" + strides + ), f"len({window_shape}) must equal len({strides})" strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + window_shape @@ -625,10 +623,12 @@ def _aten_max_pool2d_with_indices(inputs, if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" padding = ((0, 0), (0, 0)) + padding indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) @@ -644,15 +644,17 @@ def reduce_fn(a, b): init_val = -(1 << 31) init_val = jnp.array(init_val).astype(inputs.dtype) - indices, y = jax.lax.reduce_window((indices, inputs), (0, init_val), - reduce_fn, dims, strides, padding) + indices, y = jax.lax.reduce_window( + (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding + ) if is_single_input: indices = jnp.squeeze(indices, axis=0) y = jnp.squeeze(y, axis=0) return y, indices - batch_result = pool(inputs, -jnp.inf, jax.lax.max, kernel_size, strides, - padding) + batch_result = pool( + inputs, -jnp.inf, jax.lax.max, kernel_size, strides, padding + ) indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) return batch_result, indices @@ -694,7 +696,8 @@ def _aten_var(x, dim=None, *, correction=1, 
keepdim=False, out=None): @op(torch.ops.prims.broadcast_in_dim) def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) + t, shape, broadcast_dimensions=broadcast_dimensions + ) # aten.native_group_norm -- should use decomp table @@ -734,15 +737,17 @@ def group_norm_body(x): # Function to apply within each group normalized = (x - mean) * rstd return normalized, mean, rstd - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) # Reshape back to original input shape output = jnp.reshape(normalized, input_shape) # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting if weight is not None and bias is not None: output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) elif weight is not None: @@ -776,12 +781,13 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): if ord not in {2, float("inf"), float("-inf"), "fro"}: raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) # Special cases (for efficiency and clarity) if ord == 2: # Euclidean norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) elif ord == float("inf"): result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) @@ -790,11 +796,12 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) elif ord == "fro": # Frobenius norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) else: # General case (e.g., ord = 1, ord = 3) - result = jnp.sum( - jnp.abs(self)**ord, axis=dim, keepdims=keepdim)**(1.0 / ord) + result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** ( + 1.0 / ord + ) # (Optional) dtype conversion if dtype is not None: @@ -826,12 +833,9 @@ def _aten_sinh(self): # aten.native_layer_norm_backward @op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. Args: @@ -845,8 +849,9 @@ def _aten_native_layer_norm_backward(grad_out, Returns: A tuple of (grad_input, grad_weight, grad_bias). 
""" - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) # aten.reflection_pad3d_backward @@ -931,8 +936,10 @@ def _scatter_index(dim, index): target_shape = [1] * len(index_shape) target_shape[i] = index_shape[i] input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), index_shape + ) + ) return tuple(input_indexes), tuple(source_indexes) @@ -1026,17 +1033,19 @@ def _aten_pixel_shuffle(x, upscale_factor): if channels % (upscale_factor**2) != 0: raise ValueError( - "Number of channels must be divisible by the square of the upscale factor." + "Number of channels must be divisible by the square of the upscale factor." ) new_channels = channels // (upscale_factor**2) new_height = height * upscale_factor new_width = width * upscale_factor - x = x.reshape(batch_size, new_channels, upscale_factor, upscale_factor, - height, width) - x = jnp.transpose(x, - (0, 1, 2, 4, 3, 5)) # Move channels to spatial dimensions + x = x.reshape( + batch_size, new_channels, upscale_factor, upscale_factor, height, width + ) + x = jnp.transpose( + x, (0, 1, 2, 4, 3, 5) + ) # Move channels to spatial dimensions x = x.reshape(batch_size, new_channels, new_height, new_width) return x @@ -1073,7 +1082,8 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) assert len(window_shape) == len( - strides), f"len({window_shape}) must equal len({strides})" + strides + ), f"len({window_shape}) must equal len({strides})" strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + window_shape @@ -1090,10 +1100,12 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" padding = ((0, 0), (0, 0)) + padding y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) if is_single_input: @@ -1112,13 +1124,13 @@ def _aten_adaptive_avg_pool3d(x, output_shape): def _aten_adaptive_avg_pool(x, output_shape, pool_dim): - def adaptive_kernel_size(input_shape, output_shape): sizes = [1, 1] spatial_dim_off = len(input_shape) - pool_dim for spatial_dim in range(pool_dim): - sizes.append(input_shape[spatial_dim_off + spatial_dim] // - output_shape[spatial_dim]) + sizes.append( + input_shape[spatial_dim_off + spatial_dim] // output_shape[spatial_dim] + ) return tuple(sizes) kernel_sizes = adaptive_kernel_size(x.shape, output_shape) @@ -1131,8 +1143,8 @@ def adaptive_kernel_size(input_shape, output_shape): if len(div_shape) - 2 == len(kernel_sizes): div_shape = (1,) + div_shape[1:] y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, - "VALID") + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, "VALID" + ) return y @@ -1140,13 +1152,13 @@ def 
adaptive_kernel_size(input_shape, output_shape): @op(torch.ops.aten.avg_pool2d) @op(torch.ops.aten.avg_pool3d) def _aten_avg_pool( - inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, ): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) @@ -1166,7 +1178,8 @@ def _aten_avg_pool( if len(div_shape) - 2 == len(kernel_size): div_shape = (1,) + div_shape[1:] y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding) + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding + ) return y @@ -1219,9 +1232,9 @@ def _aten_round(input, decimals=0): # aten.max @op(torch.ops.aten.max) def _aten_max(self, dim=None, keepdim=False): - return jnp.max( - self, axis=dim, keepdims=keepdim), jnp.argmax( - self, axis=dim, keepdims=keepdim) + return jnp.max(self, axis=dim, keepdims=keepdim), jnp.argmax( + self, axis=dim, keepdims=keepdim + ) # aten.maximum @@ -1268,15 +1281,15 @@ def _aten_any(self, dim=None, keepdim=False): @op(torch.ops.aten.arange.start) @op(torch.ops.aten.arange.default) def _aten_arange( - start, - end=None, - step=1, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False, + start, + end=None, + step=1, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, ): if end is None: end = start @@ -1284,10 +1297,10 @@ def _aten_arange( if dtype: dtype = tensor.t2j_dtype(dtype) return jnp.arange( - start, - end, - step, - dtype=dtype, + start, + end, + step, + dtype=dtype, ) @@ -1376,8 +1389,9 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""): @op(torch.ops.aten._pdist_forward) def _aten__pdist_forward(x, p): pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] return condensed_dists @@ -1588,6 +1602,7 @@ def _aten_prod(self, dim=None, keepdim=False): # aten.randperm + # aten.reflection_pad3d @@ -1629,8 +1644,8 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), ) @@ -1669,8 +1684,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if dim != -1 and dim != len(input.shape) - 1: transpose_shape = list(range(len(input.shape))) transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], + transpose_shape[-1], + transpose_shape[dim], ) input = jnp.transpose(input, transpose_shape) @@ -1679,7 +1694,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if sorted: values = jnp.sort(values, descending=True) indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) if not largest: values = -values # Negate values back if we found smallest @@ -1701,8 +1717,9 @@ def _aten_trunc(a): @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, 
dim=0): return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim])) + _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) + for i in range(a.shape[dim]) + ) # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d @@ -1721,11 +1738,9 @@ def _aten_where(condition, x, y): # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): +def _aten_to_dtype( + a, dtype, non_blocking=False, copy=False, memory_format=None +): if dtype: jaxdtype = tensor.t2j_dtype(dtype) return a.astype(jaxdtype) @@ -1738,17 +1753,15 @@ def _aten_to_dtype(a, @op(torch.ops.aten.var_mean.correction) def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): return ( - jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), - jnp.mean(self, dim, keepdims=keepdim), + jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), + jnp.mean(self, dim, keepdims=keepdim), ) @op(torch.ops.aten.scalar_tensor) -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): +def _aten_scalar_tensor( + s, dtype=None, layout=None, device=None, pin_memory=None +): if dtype is not None: dtype = tensor.t2j_dtype(dtype) return jnp.array(s, dtype=dtype) @@ -1761,9 +1774,9 @@ def _aten_to_device(x, device, dtype): @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, - stride, padding, dilation, - ceil_mode, indices): +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): """ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. @@ -1819,15 +1832,15 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0): @op(torch.ops.aten.randn, needs_env=True) @op_base.convert_dtype() def _randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -1842,15 +1855,15 @@ def _randn( @op(torch.ops.aten.rand, needs_env=True) @op_base.convert_dtype() def _rand( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -1874,9 +1887,9 @@ def _aten_to_device(x, device, dtype): @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, - stride, padding, dilation, - ceil_mode, indices): +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): """ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. 
@@ -1938,16 +1951,8 @@ def _aten_outer(a, b): def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(input, other, rtol, atol, equal_nan) - @op(torch.ops.aten.native_batch_norm) -def _aten_native_batch_norm(input, - weight, - bias, - running_mean, - running_var, - training=False, - momentum=0.1, - eps=1e-5): +def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5): """JAX implementation of batch normalization. Args: @@ -1985,8 +1990,8 @@ def _aten_native_batch_norm(input, # Scale and shift out = xmu * ivar if weight is not None: - out *= weight.reshape(1, -1, 1, 1) + out *= weight.reshape(1, -1, 1, 1) if bias is not None: - out += bias.reshape(1, -1, 1, 1) + out += bias.reshape(1, -1, 1, 1) return out, running_mean, running_var From 6e0a0d39b2e626614a13fd2aaa542ee08aabdff8 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:17:51 +0000 Subject: [PATCH 4/9] update comment --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5f07f09086d..e0acf954d88 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1957,10 +1957,10 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai Args: input: Input data (N, C, ...) + weight: Scaling factor (gamma) (C,), can be None + bias: Shift factor (beta) (C,), can be None running_mean: Running mean of input (C,) running_var: Running variance of input (C,) - weight: Optional scaling factor (gamma) (C,) - bias: Optional shift factor (beta) (C,) training: Whether to perform training-time or inference-time batch norm momentum: Momentum factor for updating running mean and variance eps: Small constant added to the variance to avoid division by zero From 13838ec7c033c32316fbd89d1f27da84aca5f6c8 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:25:10 +0000 Subject: [PATCH 5/9] fix lax dependency --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e0acf954d88..6dd94c8e42d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1985,7 +1985,7 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai # Normalize xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch - ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root + ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root # Scale and shift out = xmu * ivar From e60a646ca3ee65733a07704317856db620feb426 Mon Sep 17 00:00:00 2001 From: zpcore Date: Mon, 20 May 2024 17:38:26 +0000 Subject: [PATCH 6/9] add native_batch_norm test --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 20686f2fe6c..1c6a9d77785 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -198,7 +198,6 @@ "nansum", "narrow_copy", "narrow", - "native_batch_norm", "native_layer_norm", "new_empty", "new_empty_strided", From 7d0495eca08ed8debafd34c060b7503184b505bc Mon Sep 17 
00:00:00 2001 From: zpcore Date: Wed, 22 May 2024 06:09:41 +0000 Subject: [PATCH 7/9] fix core aten ops --- .../torch_xla2/test/test_core_aten_ops.py | 114 ++++++++++++++++ experimental/torch_xla2/test/test_ops.py | 1 - .../torch_xla2/torch_xla2/ops/jaten.py | 126 +++++++++--------- 3 files changed, 177 insertions(+), 64 deletions(-) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 6a1cef306be..175ff01eb03 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -41,6 +41,8 @@ def run_export_and_compare(testcase, with testcase.env: res2 = func(*args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) + print(res) + print(res2) # import pdb; pdb.set_trace() with testcase.subTest("torch_xla2_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: @@ -2697,6 +2699,118 @@ def test_aten_native_layer_norm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs) + def test_aten_native_batch_norm_legit(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,2,2)).to(torch.float32), + torch.ones(channel), + torch.zeros(channel), + torch.zeros(channel), + torch.ones(channel), + False, + 0.5, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs) + + def test_aten_native_batch_norm_legit_none(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,4)).to(torch.float32), + None, + None, + torch.ones(channel), + torch.zeros(channel), + False, + 0.5, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs) + + def test_aten_native_batch_norm_legit_training_none(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,3)).to(torch.float32), + None, + None, + torch.zeros(channel), + torch.ones(channel), + True, + 0.2, + 2e-5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs) + + def test_aten_native_batch_norm_legit_no_training(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,3)).to(torch.float32), + torch.ones(channel), + torch.zeros(channel), + torch.zeros(channel), + torch.ones(channel), + True, + 0.2, + 2e-5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs) + + def test_aten_native_batch_norm_training(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,3)).to(torch.float32), + torch.ones(channel), + torch.zeros(channel), + torch.zeros(channel), + torch.ones(channel), + True, + 0.1, + 1e-5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) + + def test_aten_native_batch_norm_training_none(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,3)).to(torch.float32), + None, + None, + torch.zeros(channel), + torch.ones(channel), + True, + 0.1, + 1e-5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) + + def test_aten_native_batch_norm_eval(self): + batch = 3 + channel = 2 + args = ( + torch.randn((batch,channel,4,3)).to(torch.float32), + torch.ones(channel), + torch.zeros(channel), + torch.zeros(channel), + torch.ones(channel), + False, + 0.2, + 2e-5, + ) + kwargs = dict() + 
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) + def test_aten_ne_Scalar_0(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 1c6a9d77785..1e10706f100 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -13,7 +13,6 @@ "__getitem__", "__rmatmul__", "__rpow__", - "_native_batch_norm_legit", "_segment_reduce", "_upsample_bilinear2d_aa", "argsort", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 6dd94c8e42d..469a2fe9a13 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -546,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims): def _aten__native_batch_norm_legit( input, weight, bias, running_mean, running_var, training, momentum, eps ): - return _aten__native_batch_norm_legit_no_training( - input, weight, bias, running_mean, running_var, momentum, eps - ) + """JAX implementation of batch normalization with optional parameters. + Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. + + Args: + input (DeviceArray): Input data (N, C, H, W). + running_mean ([DeviceArray]): Running mean of input (C,). + running_var ([DeviceArray]): Running variance of input (C,). + weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. + bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. + training (bool): If True, use batch statistics for normalization. + If False, use running statistics. + momentum (float): Momentum factor for updating running statistics. + eps (float): Small constant for numerical stability. 
+ + Returns: + DeviceArray: Normalized output + DeviceArray: Batch mean (C,) or empty if training is False + DeviceArray: Reversed batch variance (C,) or empty if training is False + """ + reduction_dims = [0] + list(range(2, input.ndim)) + reshape_dims = [1, -1] + [1]*(input.ndim-2) + + if training: + # Calculate batch mean and variance + mean = jnp.mean(input, axis=reduction_dims, keepdims=True) + saved_mean = jnp.squeeze(mean, reduction_dims) + var = jnp.var(input, axis=reduction_dims) + rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) + # Update running statistics using momentum + running_mean = (1 - momentum) * running_mean + momentum * saved_mean + running_var = (1 - momentum) * running_var + momentum * var + saved_rstd = jnp.squeeze(rstd, reduction_dims) + else: + rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) + saved_mean = jnp.array([]) # No need to calculate batch statistics in inference mode + saved_rstd = jnp.array([]) + + # Normalize + if training: + # use batch statistics if training + x_hat = (input - mean) * rstd + else: + # Use running statistics in inference mode + x_hat = (input - running_mean.reshape(reshape_dims)) * rstd + + # Scale and shift + if weight is not None: + x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting + if bias is not None: + x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting + + return x_hat, saved_mean, saved_rstd + @op(torch.ops.aten._native_batch_norm_legit_no_training) def _aten__native_batch_norm_legit_no_training( input, weight, bias, running_mean, running_var, momentum, eps ): - if weight is None: - weight = jnp.ones_like(running_mean) - if bias is None: - bias = jnp.zeros_like(running_mean) - - def broadcast(t): - return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) - - if running_mean is not None: - a = input - broadcast(running_mean) - else: - a = input - if running_var is not None: - b = broadcast(jnp.sqrt(running_var + eps)) - else: - b = broadcast(jnp.sqrt(eps)) - return ( - a / b * broadcast(weight) + broadcast(bias), - jnp.array([]), - jnp.array([]), + return _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, False, momentum, eps ) @@ -1953,45 +1985,13 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): @op(torch.ops.aten.native_batch_norm) def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5): - """JAX implementation of batch normalization. - - Args: - input: Input data (N, C, ...) 
- weight: Scaling factor (gamma) (C,), can be None - bias: Shift factor (beta) (C,), can be None - running_mean: Running mean of input (C,) - running_var: Running variance of input (C,) - training: Whether to perform training-time or inference-time batch norm - momentum: Momentum factor for updating running mean and variance - eps: Small constant added to the variance to avoid division by zero - - Returns: - Output data, updated running mean, updated running var - """ - + + if running_mean is None: + running_mean = jnp.zeros(input.shape[1]) # Initialize running mean if None + if running_var is None: + running_var = jnp.ones(input.shape[1]) # Initialize running variance if None + if training: - # Training-time batch norm: compute statistics across the batch - mean = jnp.mean(input, axis=(0, 2, 3)) - var = jnp.var(input, axis=(0, 2, 3)) - - # Update running statistics - running_mean = momentum * mean + (1 - momentum) * running_mean - running_var = momentum * var + (1 - momentum) * running_var - + return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps) else: - # Inference-time batch norm: use pre-computed running statistics - mean = running_mean - var = running_var - - # Normalize - xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch - ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root - - # Scale and shift - out = xmu * ivar - if weight is not None: - out *= weight.reshape(1, -1, 1, 1) - if bias is not None: - out += bias.reshape(1, -1, 1, 1) - - return out, running_mean, running_var + return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps) \ No newline at end of file From a10991ae1869ff42ca823170e56dccca5c90fe38 Mon Sep 17 00:00:00 2001 From: zpcore Date: Wed, 22 May 2024 06:12:26 +0000 Subject: [PATCH 8/9] nit update --- experimental/torch_xla2/test/test_core_aten_ops.py | 2 -- experimental/torch_xla2/torch_xla2/ops/jaten.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 175ff01eb03..388986364d9 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -41,8 +41,6 @@ def run_export_and_compare(testcase, with testcase.env: res2 = func(*args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) - print(res) - print(res2) # import pdb; pdb.set_trace() with testcase.subTest("torch_xla2_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 469a2fe9a13..c5ca628908f 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1994,4 +1994,4 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai if training: return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps) else: - return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps) \ No newline at end of file + return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps) From 57b18dd9b9e22b1c7d689c35b976a1a31276ac90 Mon Sep 17 00:00:00 
2001 From: zpcore Date: Wed, 22 May 2024 06:17:47 +0000 Subject: [PATCH 9/9] nit update --- experimental/torch_xla2/test/test_core_aten_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 388986364d9..c11884fa370 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -2754,7 +2754,6 @@ def test_aten_native_batch_norm_legit_no_training(self): torch.zeros(channel), torch.zeros(channel), torch.ones(channel), - True, 0.2, 2e-5, )
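
Editorial note (not part of the patch series): the batch-norm support added in PATCH 1 and reworked in PATCH 7 reduces over every dimension except the channel dimension, normalizes with jax.lax.rsqrt, and blends the batch statistics into the running statistics with the momentum factor. The snippet below is a minimal standalone sketch of that training-mode math so it can be run outside torch_xla2; the helper name batch_norm_training and the example shapes are illustrative only and do not appear in the patches.

import jax
import jax.numpy as jnp


def batch_norm_training(x, weight, bias, running_mean, running_var,
                        momentum=0.1, eps=1e-5):
  # Reduce over every dimension except the channel dimension (dim 1),
  # mirroring the reduction_dims/reshape_dims logic in PATCH 7.
  reduction_dims = [0] + list(range(2, x.ndim))
  reshape_dims = [1, -1] + [1] * (x.ndim - 2)

  mean = jnp.mean(x, axis=reduction_dims, keepdims=True)   # shape (1, C, 1, 1)
  var = jnp.var(x, axis=reduction_dims)                    # shape (C,)
  rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)    # 1 / sqrt(var + eps)

  # Blend the batch statistics into the running statistics with momentum.
  new_running_mean = (1 - momentum) * running_mean + momentum * jnp.squeeze(mean)
  new_running_var = (1 - momentum) * running_var + momentum * var

  out = (x - mean) * rstd
  if weight is not None:                                   # gamma, may be None
    out = out * weight.reshape(reshape_dims)
  if bias is not None:                                     # beta, may be None
    out = out + bias.reshape(reshape_dims)
  return out, new_running_mean, new_running_var


x = jnp.ones((3, 2, 4, 4))                                 # (N, C, H, W)
out, rm, rv = batch_norm_training(x, jnp.ones(2), jnp.zeros(2),
                                  jnp.zeros(2), jnp.ones(2))
print(out.shape, rm, rv)                                   # (3, 2, 4, 4) [0.1 0.1] [0.9 0.9]

In eval mode the same normalization is applied with the stored running statistics instead of the batch statistics, which is why _aten__native_batch_norm_legit_no_training in PATCH 7 simply delegates to _aten__native_batch_norm_legit with training=False.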