diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py
index 827542835f..eb17798902 100644
--- a/flax/experimental/nnx/__init__.py
+++ b/flax/experimental/nnx/__init__.py
@@ -71,6 +71,7 @@
 from .nnx.nn.attention import make_attention_mask as make_attention_mask
 from .nnx.nn.attention import make_causal_mask as make_causal_mask
 from .nnx.nn.linear import Conv as Conv
+from .nnx.nn.linear import ConvTranspose as ConvTranspose
 from .nnx.nn.linear import Embed as Embed
 from .nnx.nn.linear import Linear as Linear
 from .nnx.nn.linear import LinearGeneral as LinearGeneral
diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py
index 8aeca41764..b670ea73f9 100644
--- a/flax/experimental/nnx/nnx/helpers.py
+++ b/flax/experimental/nnx/nnx/helpers.py
@@ -48,14 +48,12 @@ class Dict(Module, tp.Mapping[str, A]):
 
   @tp.overload
-  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /):
-    ...
+  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...
 
   @tp.overload
   def __init__(
     self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
-  ):
-    ...
+  ): ...
 
   def __init__(self, *args, **kwargs):
     for name, value in dict(*args, **kwargs).items():
@@ -126,15 +124,18 @@ def _graph_node_pop_key(self, key: Key):
     return super()._graph_node_pop_key(key)
 
 
-class Sequential(List):
+class Sequential(Module):
+  def __init__(self, *fns: tp.Callable[..., tp.Any]):
+    self.layers = list(fns)
+
   def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
     output: tp.Any = None
 
-    for i, f in enumerate(self):
+    for i, f in enumerate(self.layers):
       if not callable(f):
         raise TypeError(f'Sequence[{i}] is not callable: {f}')
       if i > 0:
-        if isinstance(output, tp.Tuple):
+        if isinstance(output, tuple):
           args = output
           kwargs = {}
         elif isinstance(output, dict):
@@ -154,8 +155,7 @@ def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
 class ModuleDefApply(tp.Protocol, tp.Generic[M]):
   def __call__(
     self, state: State, *states: State
-  ) -> ApplyCaller[tuple[State, GraphDef[M]]]:
-    ...
+  ) -> ApplyCaller[tuple[State, GraphDef[M]]]: ...
 
 
 class TrainState(tp.Generic[M], struct.PyTreeNode):
@@ -186,8 +186,7 @@ def create(
 
   if tp.TYPE_CHECKING:
 
-    def __getattr__(self, key: str) -> tp.Any:
-      ...
+    def __getattr__(self, key: str) -> tp.Any: ...
 
   def apply(
     self, state: tp.Union[State, str], *states: tp.Union[State, str]
diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py
index 990f6c6f3a..6905202895 100644
--- a/flax/experimental/nnx/nnx/nn/linear.py
+++ b/flax/experimental/nnx/nnx/nn/linear.py
@@ -41,7 +41,6 @@
 from flax.experimental.nnx.nnx.module import Module, first_from
 from flax.experimental.nnx.nnx.nn import dtypes, initializers
 from flax.typing import (
-  Array,
   Dtype,
   Shape,
   Initializer,
@@ -52,11 +51,13 @@
   LaxPadding,
 )
 
+Array = jax.Array
 Axis = int
 Size = int
 
 
 default_kernel_init = initializers.lecun_normal()
+default_bias_init = initializers.zeros_init()
 
 
 def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
@@ -158,7 +159,7 @@ def __init__(
     dtype: Dtype | None = None,
     param_dtype: Dtype = jnp.float32,
     kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
     precision: PrecisionLike = None,
     # Deprecated. Will be removed.
     dot_general: DotGeneralT | None = None,
@@ -316,7 +317,7 @@ def __init__(
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
     kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
     dot_general: DotGeneralT = lax.dot_general,
     rngs: rnglib.Rngs,
   ):
@@ -410,7 +411,7 @@ def __init__(
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
     kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
     rngs: rnglib.Rngs,
   ):
     einsum_str = einsum_str.replace(' ', '')
@@ -576,7 +577,7 @@ def __init__(
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
     kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
     conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
     rngs: rnglib.Rngs,
   ):
@@ -641,7 +642,7 @@ def __call__(self, inputs: Array) -> Array:
 
     def maybe_broadcast(
       x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
-    ) -> tp.Tuple[int, ...]:
+    ) -> tuple[int, ...]:
       if x is None:
         # backward compatibility with using None as sentinel for
         # broadcast 1
@@ -670,7 +671,7 @@ def maybe_broadcast(
      kernel_size_dilated = [
        (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
      ]
-      zero_pad: tp.List[tp.Tuple[int, int]] = [(0, 0)]
+      zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
      pads = (
        zero_pad
        + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
@@ -725,6 +726,210 @@ def maybe_broadcast(
     y = jnp.reshape(y, output_shape)
     return y
 
+
+class ConvTranspose(Module):
+  # features: int
+  # kernel_size: Union[int, Sequence[int]]
+  # strides: Optional[Sequence[int]] = None
+  # padding: PaddingLike = 'SAME'
+  # kernel_dilation: Optional[Sequence[int]] = None
+  # use_bias: bool = True
+  # mask: Optional[Array] = None
+  # dtype: Optional[Dtype] = None
+  # param_dtype: Dtype = jnp.float32
+  # precision: PrecisionLike = None
+  # kernel_init: Initializer = default_kernel_init
+  # bias_init: Initializer = initializers.zeros_init()
+  # transpose_kernel: bool = False
+
+  def __init__(
+    self,
+    in_features: int,
+    out_features: int,
+    kernel_size: int | tp.Sequence[int],
+    strides: int | tp.Sequence[int] | None = None,
+    *,
+    padding: PaddingLike = 'SAME',
+    kernel_dilation: int | tp.Sequence[int] | None = None,
+    use_bias: bool = True,
+    mask: Array | None = None,
+    dtype: Dtype | None = None,
+    param_dtype: Dtype = jnp.float32,
+    precision: PrecisionLike | None = None,
+    kernel_init: Initializer = default_kernel_init,
+    bias_init: Initializer = default_bias_init,
+    transpose_kernel: bool = False,
+    rngs: rnglib.Rngs,
+  ):
+    if isinstance(kernel_size, int):
+      kernel_size = (kernel_size,)
+    else:
+      kernel_size = tuple(kernel_size)
+
+    self.kernel_size = kernel_size
+    self.in_features = in_features
+    self.out_features = out_features
+    self.strides = strides
+    self.padding = padding
+    self.kernel_dilation = kernel_dilation
+    self.use_bias = use_bias
+    self.mask = mask
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.precision = precision
+    self.kernel_init = kernel_init
+    self.bias_init = bias_init
+    self.transpose_kernel = transpose_kernel
+
+    if self.transpose_kernel:
+      kernel_shape = kernel_size + (self.out_features, in_features)
+    else:
+      kernel_shape = kernel_size + (in_features, self.out_features)
+
+    self.kernel_shape = kernel_shape
+    self.kernel = nnx.Param(
+      self.kernel_init(rngs.params(), kernel_shape, self.param_dtype)
+    )
+
+    if self.use_bias:
+      self.bias = nnx.Param(
+        self.bias_init(rngs.params(), (self.out_features,), self.param_dtype)
+      )
+    else:
+      self.bias = None
+
+  def __call__(self, inputs: Array) -> Array:
+    """Applies a transposed convolution to the inputs.
+
+    Behaviour mirrors that of ``jax.lax.conv_transpose``.
+
+    Args:
+      inputs: input data with dimensions (*batch_dims, spatial_dims...,
+        features). This is the channels-last convention, i.e. NHWC for a 2D
+        convolution and NDHWC for a 3D convolution. Note: this is different
+        from the input convention used by ``lax.conv_general_dilated``, which
+        puts the spatial dimensions last.
+        Note: If the input has more than 1 batch dimension, all batch
+        dimensions are flattened into a single dimension for the convolution
+        and restored before returning. In some cases directly vmap'ing the
+        layer may yield better performance than this default flattening
+        approach. If the input lacks a batch dimension it will be added for
+        the convolution and removed on return, an allowance made to enable
+        writing single-example code.
+
+    Returns:
+      The convolved data.
+    """
+    kernel_size = self.kernel_size
+
+    def maybe_broadcast(
+      x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
+    ) -> tuple[int, ...]:
+      if x is None:
+        # backward compatibility with using None as sentinel for
+        # broadcast 1
+        x = 1
+      if isinstance(x, int):
+        return (x,) * len(kernel_size)
+      return tuple(x)
+
+    # Combine all input batch dimensions into a single leading batch axis.
+    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
+    if num_batch_dimensions != 1:
+      input_batch_shape = inputs.shape[:num_batch_dimensions]
+      total_batch_size = int(np.prod(input_batch_shape))
+      flat_input_shape = (total_batch_size,) + inputs.shape[
+        num_batch_dimensions:
+      ]
+      inputs = jnp.reshape(inputs, flat_input_shape)
+
+    strides = maybe_broadcast(self.strides)
+    kernel_dilation = maybe_broadcast(self.kernel_dilation)
+
+    kernel_shape = self.kernel_shape
+
+    if self.mask is not None and self.mask.shape != kernel_shape:
+      raise ValueError(
+        'Mask needs to have the same shape as weights. '
+        f'Shapes are: {self.mask.shape}, {kernel_shape}'
+      )
+
+    kernel = self.kernel.value
+
+    if self.mask is not None:
+      kernel *= self.mask
+
+    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
+    if padding_lax == 'CIRCULAR':
+      padding_lax = 'VALID'
+
+    bias = self.bias.value if self.bias is not None else None
+
+    inputs, kernel, bias = dtypes.promote_dtype(
+      inputs, kernel, bias, dtype=self.dtype
+    )
+
+    y = lax.conv_transpose(
+      inputs,
+      kernel,
+      strides,
+      padding_lax,
+      rhs_dilation=kernel_dilation,
+      transpose_kernel=self.transpose_kernel,
+      precision=self.precision,
+    )
+
+    if self.padding == 'CIRCULAR':
+      # For circular padding, we need to identify the size of the final output
+      # ("period") along each spatial dimension, pad each dimension to an
+      # integer number of periods, and wrap the array periodically around each
+      # dimension. Padding should be done in such a way that the start of the
+      # original input data inside the padded array is located at integer
+      # number of periods - otherwise the result would be circularly shifted.
+
+      # Compute period along each spatial dimension - it's input size scaled
+      # by the stride.
+      scaled_x_dims = [
+        x_dim * stride
+        for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides)
+      ]
+      # Compute difference between the current size of y and the final output
+      # size, and complement this difference to 2 * period - that gives how
+      # much we need to pad.
+      size_diffs = [
+        -(y_dim - x_dim) % (2 * x_dim)
+        for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
+      ]
+      if self.transpose_kernel:
+        # If the kernel is transposed, the "+1" is put on the right to
+        # mirror the regular convolution. If the same kernel parameters are used
+        # as for Conv, this layer then computes the proper transpose convolution.
+        total_pad = [
+          (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
+        ]
+      else:
+        # Divide the padding equally between left and right. The choice to put
+        # "+1" on the left (and not on the right) represents a convention for
+        # aligning even-sized kernels.
+        total_pad = [
+          ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
+        ]
+      y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
+      # Wrap the result periodically around each spatial dimension,
+      # one by one.
+      for i in range(1, y.ndim - 1):
+        y = y.reshape(
+          y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]
+        )
+        y = y.sum(axis=i)
+
+    if self.use_bias:
+      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))  # type: ignore
+
+    if num_batch_dimensions != 1:
+      output_shape = input_batch_shape + y.shape[1:]
+      y = jnp.reshape(y, output_shape)
+
+    return y
+
 
 default_embed_init = initializers.variance_scaling(
   1.0, 'fan_in', 'normal', out_axis=0
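
Usage sketch (not part of the diff): a minimal example of how the new ConvTranspose and the reworked Sequential are expected to compose under the NNX API added above. The input shape, kernel/stride choices, and the jax.nn.relu activation are illustrative assumptions, not code from this change.

import jax
import jax.numpy as jnp
from flax.experimental import nnx

rngs = nnx.Rngs(0)

# Sequential now stores its callables in self.layers and threads each output
# into the next callable.
model = nnx.Sequential(
  nnx.ConvTranspose(
    in_features=3,
    out_features=8,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding='SAME',
    rngs=rngs,
  ),
  jax.nn.relu,
)

x = jnp.ones((1, 16, 16, 3))  # NHWC input: (batch, height, width, in_features)
y = model(x)  # with 'SAME' padding and stride 2, y should have shape (1, 32, 32, 8)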